subscription.go 35 KB


  1. package model
  2. import (
  3. "errors"
  4. "fmt"
  5. "strconv"
  6. "strings"
  7. "sync"
  8. "time"
  9. "github.com/QuantumNous/new-api/common"
  10. "github.com/QuantumNous/new-api/pkg/cachex"
  11. "github.com/samber/hot"
  12. "gorm.io/gorm"
  13. )
  14. // Subscription duration units
  15. const (
  16. SubscriptionDurationYear = "year"
  17. SubscriptionDurationMonth = "month"
  18. SubscriptionDurationDay = "day"
  19. SubscriptionDurationHour = "hour"
  20. SubscriptionDurationCustom = "custom"
  21. )
  22. // Subscription quota reset period
  23. const (
  24. SubscriptionResetNever = "never"
  25. SubscriptionResetDaily = "daily"
  26. SubscriptionResetWeekly = "weekly"
  27. SubscriptionResetMonthly = "monthly"
  28. SubscriptionResetCustom = "custom"
  29. )
  30. var (
  31. ErrSubscriptionOrderNotFound = errors.New("subscription order not found")
  32. ErrSubscriptionOrderStatusInvalid = errors.New("subscription order status invalid")
  33. )
  34. const (
  35. subscriptionPlanCacheNamespace = "new-api:subscription_plan:v1"
  36. subscriptionPlanInfoCacheNamespace = "new-api:subscription_plan_info:v1"
  37. )
  38. var (
  39. subscriptionPlanCacheOnce sync.Once
  40. subscriptionPlanInfoCacheOnce sync.Once
  41. subscriptionPlanCache *cachex.HybridCache[SubscriptionPlan]
  42. subscriptionPlanInfoCache *cachex.HybridCache[SubscriptionPlanInfo]
  43. )
  44. func subscriptionPlanCacheTTL() time.Duration {
  45. ttlSeconds := common.GetEnvOrDefault("SUBSCRIPTION_PLAN_CACHE_TTL", 300)
  46. if ttlSeconds <= 0 {
  47. ttlSeconds = 300
  48. }
  49. return time.Duration(ttlSeconds) * time.Second
  50. }
  51. func subscriptionPlanInfoCacheTTL() time.Duration {
  52. ttlSeconds := common.GetEnvOrDefault("SUBSCRIPTION_PLAN_INFO_CACHE_TTL", 120)
  53. if ttlSeconds <= 0 {
  54. ttlSeconds = 120
  55. }
  56. return time.Duration(ttlSeconds) * time.Second
  57. }
  58. func subscriptionPlanCacheCapacity() int {
  59. capacity := common.GetEnvOrDefault("SUBSCRIPTION_PLAN_CACHE_CAP", 5000)
  60. if capacity <= 0 {
  61. capacity = 5000
  62. }
  63. return capacity
  64. }
  65. func subscriptionPlanInfoCacheCapacity() int {
  66. capacity := common.GetEnvOrDefault("SUBSCRIPTION_PLAN_INFO_CACHE_CAP", 10000)
  67. if capacity <= 0 {
  68. capacity = 10000
  69. }
  70. return capacity
  71. }
  72. func getSubscriptionPlanCache() *cachex.HybridCache[SubscriptionPlan] {
  73. subscriptionPlanCacheOnce.Do(func() {
  74. ttl := subscriptionPlanCacheTTL()
  75. subscriptionPlanCache = cachex.NewHybridCache[SubscriptionPlan](cachex.HybridCacheConfig[SubscriptionPlan]{
  76. Namespace: cachex.Namespace(subscriptionPlanCacheNamespace),
  77. Redis: common.RDB,
  78. RedisEnabled: func() bool {
  79. return common.RedisEnabled && common.RDB != nil
  80. },
  81. RedisCodec: cachex.JSONCodec[SubscriptionPlan]{},
  82. Memory: func() *hot.HotCache[string, SubscriptionPlan] {
  83. return hot.NewHotCache[string, SubscriptionPlan](hot.LRU, subscriptionPlanCacheCapacity()).
  84. WithTTL(ttl).
  85. WithJanitor().
  86. Build()
  87. },
  88. })
  89. })
  90. return subscriptionPlanCache
  91. }
  92. func getSubscriptionPlanInfoCache() *cachex.HybridCache[SubscriptionPlanInfo] {
  93. subscriptionPlanInfoCacheOnce.Do(func() {
  94. ttl := subscriptionPlanInfoCacheTTL()
  95. subscriptionPlanInfoCache = cachex.NewHybridCache[SubscriptionPlanInfo](cachex.HybridCacheConfig[SubscriptionPlanInfo]{
  96. Namespace: cachex.Namespace(subscriptionPlanInfoCacheNamespace),
  97. Redis: common.RDB,
  98. RedisEnabled: func() bool {
  99. return common.RedisEnabled && common.RDB != nil
  100. },
  101. RedisCodec: cachex.JSONCodec[SubscriptionPlanInfo]{},
  102. Memory: func() *hot.HotCache[string, SubscriptionPlanInfo] {
  103. return hot.NewHotCache[string, SubscriptionPlanInfo](hot.LRU, subscriptionPlanInfoCacheCapacity()).
  104. WithTTL(ttl).
  105. WithJanitor().
  106. Build()
  107. },
  108. })
  109. })
  110. return subscriptionPlanInfoCache
  111. }
  112. func subscriptionPlanCacheKey(id int) string {
  113. if id <= 0 {
  114. return ""
  115. }
  116. return strconv.Itoa(id)
  117. }
  118. func InvalidateSubscriptionPlanCache(planId int) {
  119. if planId <= 0 {
  120. return
  121. }
  122. cache := getSubscriptionPlanCache()
  123. _, _ = cache.DeleteMany([]string{subscriptionPlanCacheKey(planId)})
  124. infoCache := getSubscriptionPlanInfoCache()
  125. _ = infoCache.Purge()
  126. }
  127. // Subscription plan
  128. type SubscriptionPlan struct {
  129. Id int `json:"id"`
  130. Title string `json:"title" gorm:"type:varchar(128);not null"`
  131. Subtitle string `json:"subtitle" gorm:"type:varchar(255);default:''"`
  132. // Display money amount (follow existing code style: float64 for money)
  133. PriceAmount float64 `json:"price_amount" gorm:"type:decimal(10,6);not null;default:0"`
  134. Currency string `json:"currency" gorm:"type:varchar(8);not null;default:'USD'"`
  135. DurationUnit string `json:"duration_unit" gorm:"type:varchar(16);not null;default:'month'"`
  136. DurationValue int `json:"duration_value" gorm:"type:int;not null;default:1"`
  137. CustomSeconds int64 `json:"custom_seconds" gorm:"type:bigint;not null;default:0"`
  138. Enabled bool `json:"enabled" gorm:"default:true"`
  139. SortOrder int `json:"sort_order" gorm:"type:int;default:0"`
  140. StripePriceId string `json:"stripe_price_id" gorm:"type:varchar(128);default:''"`
  141. CreemProductId string `json:"creem_product_id" gorm:"type:varchar(128);default:''"`
  142. // Max purchases per user (0 = unlimited)
  143. MaxPurchasePerUser int `json:"max_purchase_per_user" gorm:"type:int;default:0"`
  144. // Upgrade user group after purchase (empty = no change)
  145. UpgradeGroup string `json:"upgrade_group" gorm:"type:varchar(64);default:''"`
  146. // Total quota (amount in quota units, 0 = unlimited)
  147. TotalAmount int64 `json:"total_amount" gorm:"type:bigint;not null;default:0"`
  148. // Quota reset period for plan
  149. QuotaResetPeriod string `json:"quota_reset_period" gorm:"type:varchar(16);default:'never'"`
  150. QuotaResetCustomSeconds int64 `json:"quota_reset_custom_seconds" gorm:"type:bigint;default:0"`
  151. CreatedAt int64 `json:"created_at" gorm:"bigint"`
  152. UpdatedAt int64 `json:"updated_at" gorm:"bigint"`
  153. }
  154. func (p *SubscriptionPlan) BeforeCreate(tx *gorm.DB) error {
  155. now := common.GetTimestamp()
  156. p.CreatedAt = now
  157. p.UpdatedAt = now
  158. return nil
  159. }
  160. func (p *SubscriptionPlan) BeforeUpdate(tx *gorm.DB) error {
  161. p.UpdatedAt = common.GetTimestamp()
  162. return nil
  163. }
  164. // Subscription order (payment -> webhook -> create UserSubscription)
  165. type SubscriptionOrder struct {
  166. Id int `json:"id"`
  167. UserId int `json:"user_id" gorm:"index"`
  168. PlanId int `json:"plan_id" gorm:"index"`
  169. Money float64 `json:"money"`
  170. TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
  171. PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
  172. Status string `json:"status"`
  173. CreateTime int64 `json:"create_time"`
  174. CompleteTime int64 `json:"complete_time"`
  175. ProviderPayload string `json:"provider_payload" gorm:"type:text"`
  176. }
  177. func (o *SubscriptionOrder) Insert() error {
  178. if o.CreateTime == 0 {
  179. o.CreateTime = common.GetTimestamp()
  180. }
  181. return DB.Create(o).Error
  182. }
  183. func (o *SubscriptionOrder) Update() error {
  184. return DB.Save(o).Error
  185. }
  186. func GetSubscriptionOrderByTradeNo(tradeNo string) *SubscriptionOrder {
  187. if tradeNo == "" {
  188. return nil
  189. }
  190. var order SubscriptionOrder
  191. if err := DB.Where("trade_no = ?", tradeNo).First(&order).Error; err != nil {
  192. return nil
  193. }
  194. return &order
  195. }
  196. // User subscription instance
  197. type UserSubscription struct {
  198. Id int `json:"id"`
  199. UserId int `json:"user_id" gorm:"index;index:idx_user_sub_active,priority:1"`
  200. PlanId int `json:"plan_id" gorm:"index"`
  201. AmountTotal int64 `json:"amount_total" gorm:"type:bigint;not null;default:0"`
  202. AmountUsed int64 `json:"amount_used" gorm:"type:bigint;not null;default:0"`
  203. StartTime int64 `json:"start_time" gorm:"bigint"`
  204. EndTime int64 `json:"end_time" gorm:"bigint;index;index:idx_user_sub_active,priority:3"`
  205. Status string `json:"status" gorm:"type:varchar(32);index;index:idx_user_sub_active,priority:2"` // active/expired/cancelled
  206. Source string `json:"source" gorm:"type:varchar(32);default:'order'"` // order/admin
  207. LastResetTime int64 `json:"last_reset_time" gorm:"type:bigint;default:0"`
  208. NextResetTime int64 `json:"next_reset_time" gorm:"type:bigint;default:0;index"`
  209. UpgradeGroup string `json:"upgrade_group" gorm:"type:varchar(64);default:''"`
  210. PrevUserGroup string `json:"prev_user_group" gorm:"type:varchar(64);default:''"`
  211. CreatedAt int64 `json:"created_at" gorm:"bigint"`
  212. UpdatedAt int64 `json:"updated_at" gorm:"bigint"`
  213. }
  214. func (s *UserSubscription) BeforeCreate(tx *gorm.DB) error {
  215. now := common.GetTimestamp()
  216. s.CreatedAt = now
  217. s.UpdatedAt = now
  218. return nil
  219. }
  220. func (s *UserSubscription) BeforeUpdate(tx *gorm.DB) error {
  221. s.UpdatedAt = common.GetTimestamp()
  222. return nil
  223. }
  224. type SubscriptionSummary struct {
  225. Subscription *UserSubscription `json:"subscription"`
  226. }
  227. func calcPlanEndTime(start time.Time, plan *SubscriptionPlan) (int64, error) {
  228. if plan == nil {
  229. return 0, errors.New("plan is nil")
  230. }
  231. if plan.DurationValue <= 0 && plan.DurationUnit != SubscriptionDurationCustom {
  232. return 0, errors.New("duration_value must be > 0")
  233. }
  234. switch plan.DurationUnit {
  235. case SubscriptionDurationYear:
  236. return start.AddDate(plan.DurationValue, 0, 0).Unix(), nil
  237. case SubscriptionDurationMonth:
  238. return start.AddDate(0, plan.DurationValue, 0).Unix(), nil
  239. case SubscriptionDurationDay:
  240. return start.Add(time.Duration(plan.DurationValue) * 24 * time.Hour).Unix(), nil
  241. case SubscriptionDurationHour:
  242. return start.Add(time.Duration(plan.DurationValue) * time.Hour).Unix(), nil
  243. case SubscriptionDurationCustom:
  244. if plan.CustomSeconds <= 0 {
  245. return 0, errors.New("custom_seconds must be > 0")
  246. }
  247. return start.Add(time.Duration(plan.CustomSeconds) * time.Second).Unix(), nil
  248. default:
  249. return 0, fmt.Errorf("invalid duration_unit: %s", plan.DurationUnit)
  250. }
  251. }
  252. func NormalizeResetPeriod(period string) string {
  253. switch strings.TrimSpace(period) {
  254. case SubscriptionResetDaily, SubscriptionResetWeekly, SubscriptionResetMonthly, SubscriptionResetCustom:
  255. return strings.TrimSpace(period)
  256. default:
  257. return SubscriptionResetNever
  258. }
  259. }
  260. func calcNextResetTime(base time.Time, plan *SubscriptionPlan, endUnix int64) int64 {
  261. if plan == nil {
  262. return 0
  263. }
  264. period := NormalizeResetPeriod(plan.QuotaResetPeriod)
  265. if period == SubscriptionResetNever {
  266. return 0
  267. }
  268. var next time.Time
  269. switch period {
  270. case SubscriptionResetDaily:
  271. next = time.Date(base.Year(), base.Month(), base.Day(), 0, 0, 0, 0, base.Location()).
  272. AddDate(0, 0, 1)
  273. case SubscriptionResetWeekly:
  274. // Align to next Monday 00:00
  275. weekday := int(base.Weekday()) // Sunday=0
  276. // Convert to Monday=1..Sunday=7
  277. if weekday == 0 {
  278. weekday = 7
  279. }
  280. daysUntil := 8 - weekday
  281. next = time.Date(base.Year(), base.Month(), base.Day(), 0, 0, 0, 0, base.Location()).
  282. AddDate(0, 0, daysUntil)
  283. case SubscriptionResetMonthly:
  284. // Align to first day of next month 00:00
  285. next = time.Date(base.Year(), base.Month(), 1, 0, 0, 0, 0, base.Location()).
  286. AddDate(0, 1, 0)
  287. case SubscriptionResetCustom:
  288. if plan.QuotaResetCustomSeconds <= 0 {
  289. return 0
  290. }
  291. next = base.Add(time.Duration(plan.QuotaResetCustomSeconds) * time.Second)
  292. default:
  293. return 0
  294. }
  295. if endUnix > 0 && next.Unix() > endUnix {
  296. return 0
  297. }
  298. return next.Unix()
  299. }
  300. func GetSubscriptionPlanById(id int) (*SubscriptionPlan, error) {
  301. return getSubscriptionPlanByIdTx(nil, id)
  302. }
  303. func getSubscriptionPlanByIdTx(tx *gorm.DB, id int) (*SubscriptionPlan, error) {
  304. if id <= 0 {
  305. return nil, errors.New("invalid plan id")
  306. }
  307. key := subscriptionPlanCacheKey(id)
  308. if key != "" {
  309. if cached, found, err := getSubscriptionPlanCache().Get(key); err == nil && found {
  310. return &cached, nil
  311. }
  312. }
  313. var plan SubscriptionPlan
  314. query := DB
  315. if tx != nil {
  316. query = tx
  317. }
  318. if err := query.Where("id = ?", id).First(&plan).Error; err != nil {
  319. return nil, err
  320. }
  321. _ = getSubscriptionPlanCache().SetWithTTL(key, plan, subscriptionPlanCacheTTL())
  322. return &plan, nil
  323. }
  324. func CountUserSubscriptionsByPlan(userId int, planId int) (int64, error) {
  325. if userId <= 0 || planId <= 0 {
  326. return 0, errors.New("invalid userId or planId")
  327. }
  328. var count int64
  329. if err := DB.Model(&UserSubscription{}).
  330. Where("user_id = ? AND plan_id = ?", userId, planId).
  331. Count(&count).Error; err != nil {
  332. return 0, err
  333. }
  334. return count, nil
  335. }
  336. func getUserGroupByIdTx(tx *gorm.DB, userId int) (string, error) {
  337. if userId <= 0 {
  338. return "", errors.New("invalid userId")
  339. }
  340. if tx == nil {
  341. tx = DB
  342. }
  343. var group string
  344. if err := tx.Model(&User{}).Where("id = ?", userId).Select(commonGroupCol).Find(&group).Error; err != nil {
  345. return "", err
  346. }
  347. return group, nil
  348. }
  349. func downgradeUserGroupForSubscriptionTx(tx *gorm.DB, sub *UserSubscription, now int64) (string, error) {
  350. if tx == nil || sub == nil {
  351. return "", errors.New("invalid downgrade args")
  352. }
  353. upgradeGroup := strings.TrimSpace(sub.UpgradeGroup)
  354. if upgradeGroup == "" {
  355. return "", nil
  356. }
  357. currentGroup, err := getUserGroupByIdTx(tx, sub.UserId)
  358. if err != nil {
  359. return "", err
  360. }
  361. if currentGroup != upgradeGroup {
  362. return "", nil
  363. }
  364. var activeSub UserSubscription
  365. activeQuery := tx.Where("user_id = ? AND status = ? AND end_time > ? AND id <> ? AND upgrade_group <> ''",
  366. sub.UserId, "active", now, sub.Id).
  367. Order("end_time desc, id desc").
  368. Limit(1).
  369. Find(&activeSub)
  370. if activeQuery.Error == nil && activeQuery.RowsAffected > 0 {
  371. return "", nil
  372. }
  373. prevGroup := strings.TrimSpace(sub.PrevUserGroup)
  374. if prevGroup == "" || prevGroup == currentGroup {
  375. return "", nil
  376. }
  377. if err := tx.Model(&User{}).Where("id = ?", sub.UserId).
  378. Update("group", prevGroup).Error; err != nil {
  379. return "", err
  380. }
  381. return prevGroup, nil
  382. }
  383. func CreateUserSubscriptionFromPlanTx(tx *gorm.DB, userId int, plan *SubscriptionPlan, source string) (*UserSubscription, error) {
  384. if tx == nil {
  385. return nil, errors.New("tx is nil")
  386. }
  387. if plan == nil || plan.Id == 0 {
  388. return nil, errors.New("invalid plan")
  389. }
  390. if userId <= 0 {
  391. return nil, errors.New("invalid user id")
  392. }
  393. if plan.MaxPurchasePerUser > 0 {
  394. var count int64
  395. if err := tx.Model(&UserSubscription{}).
  396. Where("user_id = ? AND plan_id = ?", userId, plan.Id).
  397. Count(&count).Error; err != nil {
  398. return nil, err
  399. }
  400. if count >= int64(plan.MaxPurchasePerUser) {
  401. return nil, errors.New("已达到该套餐购买上限")
  402. }
  403. }
  404. nowUnix := GetDBTimestamp()
  405. now := time.Unix(nowUnix, 0)
  406. endUnix, err := calcPlanEndTime(now, plan)
  407. if err != nil {
  408. return nil, err
  409. }
  410. resetBase := now
  411. nextReset := calcNextResetTime(resetBase, plan, endUnix)
  412. lastReset := int64(0)
  413. if nextReset > 0 {
  414. lastReset = now.Unix()
  415. }
  416. upgradeGroup := strings.TrimSpace(plan.UpgradeGroup)
  417. prevGroup := ""
  418. if upgradeGroup != "" {
  419. currentGroup, err := getUserGroupByIdTx(tx, userId)
  420. if err != nil {
  421. return nil, err
  422. }
  423. if currentGroup != upgradeGroup {
  424. prevGroup = currentGroup
  425. if err := tx.Model(&User{}).Where("id = ?", userId).
  426. Update("group", upgradeGroup).Error; err != nil {
  427. return nil, err
  428. }
  429. }
  430. }
  431. sub := &UserSubscription{
  432. UserId: userId,
  433. PlanId: plan.Id,
  434. AmountTotal: plan.TotalAmount,
  435. AmountUsed: 0,
  436. StartTime: now.Unix(),
  437. EndTime: endUnix,
  438. Status: "active",
  439. Source: source,
  440. LastResetTime: lastReset,
  441. NextResetTime: nextReset,
  442. UpgradeGroup: upgradeGroup,
  443. PrevUserGroup: prevGroup,
  444. CreatedAt: common.GetTimestamp(),
  445. UpdatedAt: common.GetTimestamp(),
  446. }
  447. if err := tx.Create(sub).Error; err != nil {
  448. return nil, err
  449. }
  450. return sub, nil
  451. }
  452. // Complete a subscription order (idempotent). Creates a UserSubscription snapshot from the plan.
  453. func CompleteSubscriptionOrder(tradeNo string, providerPayload string) error {
  454. if tradeNo == "" {
  455. return errors.New("tradeNo is empty")
  456. }
  457. refCol := "`trade_no`"
  458. if common.UsingPostgreSQL {
  459. refCol = `"trade_no"`
  460. }
  461. var logUserId int
  462. var logPlanTitle string
  463. var logMoney float64
  464. var logPaymentMethod string
  465. var upgradeGroup string
  466. err := DB.Transaction(func(tx *gorm.DB) error {
  467. var order SubscriptionOrder
  468. if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
  469. return ErrSubscriptionOrderNotFound
  470. }
  471. if order.Status == common.TopUpStatusSuccess {
  472. return nil
  473. }
  474. if order.Status != common.TopUpStatusPending {
  475. return ErrSubscriptionOrderStatusInvalid
  476. }
  477. plan, err := GetSubscriptionPlanById(order.PlanId)
  478. if err != nil {
  479. return err
  480. }
  481. if !plan.Enabled {
  482. // still allow completion for already purchased orders
  483. }
  484. upgradeGroup = strings.TrimSpace(plan.UpgradeGroup)
  485. _, err = CreateUserSubscriptionFromPlanTx(tx, order.UserId, plan, "order")
  486. if err != nil {
  487. return err
  488. }
  489. if err := upsertSubscriptionTopUpTx(tx, &order); err != nil {
  490. return err
  491. }
  492. order.Status = common.TopUpStatusSuccess
  493. order.CompleteTime = common.GetTimestamp()
  494. if providerPayload != "" {
  495. order.ProviderPayload = providerPayload
  496. }
  497. if err := tx.Save(&order).Error; err != nil {
  498. return err
  499. }
  500. logUserId = order.UserId
  501. logPlanTitle = plan.Title
  502. logMoney = order.Money
  503. logPaymentMethod = order.PaymentMethod
  504. return nil
  505. })
  506. if err != nil {
  507. return err
  508. }
  509. if upgradeGroup != "" && logUserId > 0 {
  510. _ = UpdateUserGroupCache(logUserId, upgradeGroup)
  511. }
  512. if logUserId > 0 {
  513. msg := fmt.Sprintf("订阅购买成功,套餐: %s,支付金额: %.2f,支付方式: %s", logPlanTitle, logMoney, logPaymentMethod)
  514. RecordLog(logUserId, LogTypeTopup, msg)
  515. }
  516. return nil
  517. }
  518. func upsertSubscriptionTopUpTx(tx *gorm.DB, order *SubscriptionOrder) error {
  519. if tx == nil || order == nil {
  520. return errors.New("invalid subscription order")
  521. }
  522. now := common.GetTimestamp()
  523. var topup TopUp
  524. if err := tx.Where("trade_no = ?", order.TradeNo).First(&topup).Error; err != nil {
  525. if errors.Is(err, gorm.ErrRecordNotFound) {
  526. topup = TopUp{
  527. UserId: order.UserId,
  528. Amount: 0,
  529. Money: order.Money,
  530. TradeNo: order.TradeNo,
  531. PaymentMethod: order.PaymentMethod,
  532. CreateTime: order.CreateTime,
  533. CompleteTime: now,
  534. Status: common.TopUpStatusSuccess,
  535. }
  536. return tx.Create(&topup).Error
  537. }
  538. return err
  539. }
  540. topup.Money = order.Money
  541. if topup.PaymentMethod == "" {
  542. topup.PaymentMethod = order.PaymentMethod
  543. }
  544. if topup.CreateTime == 0 {
  545. topup.CreateTime = order.CreateTime
  546. }
  547. topup.CompleteTime = now
  548. topup.Status = common.TopUpStatusSuccess
  549. return tx.Save(&topup).Error
  550. }
  551. func ExpireSubscriptionOrder(tradeNo string) error {
  552. if tradeNo == "" {
  553. return errors.New("tradeNo is empty")
  554. }
  555. refCol := "`trade_no`"
  556. if common.UsingPostgreSQL {
  557. refCol = `"trade_no"`
  558. }
  559. return DB.Transaction(func(tx *gorm.DB) error {
  560. var order SubscriptionOrder
  561. if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
  562. return ErrSubscriptionOrderNotFound
  563. }
  564. if order.Status != common.TopUpStatusPending {
  565. return nil
  566. }
  567. order.Status = common.TopUpStatusExpired
  568. order.CompleteTime = common.GetTimestamp()
  569. return tx.Save(&order).Error
  570. })
  571. }
  572. // Admin bind (no payment). Creates a UserSubscription from a plan.
  573. func AdminBindSubscription(userId int, planId int, sourceNote string) (string, error) {
  574. if userId <= 0 || planId <= 0 {
  575. return "", errors.New("invalid userId or planId")
  576. }
  577. plan, err := GetSubscriptionPlanById(planId)
  578. if err != nil {
  579. return "", err
  580. }
  581. err = DB.Transaction(func(tx *gorm.DB) error {
  582. _, err := CreateUserSubscriptionFromPlanTx(tx, userId, plan, "admin")
  583. return err
  584. })
  585. if err != nil {
  586. return "", err
  587. }
  588. if strings.TrimSpace(plan.UpgradeGroup) != "" {
  589. _ = UpdateUserGroupCache(userId, plan.UpgradeGroup)
  590. return fmt.Sprintf("用户分组将升级到 %s", plan.UpgradeGroup), nil
  591. }
  592. return "", nil
  593. }
  594. // GetAllActiveUserSubscriptions returns all active subscriptions for a user.
  595. func GetAllActiveUserSubscriptions(userId int) ([]SubscriptionSummary, error) {
  596. if userId <= 0 {
  597. return nil, errors.New("invalid userId")
  598. }
  599. now := common.GetTimestamp()
  600. var subs []UserSubscription
  601. err := DB.Where("user_id = ? AND status = ? AND end_time > ?", userId, "active", now).
  602. Order("end_time desc, id desc").
  603. Find(&subs).Error
  604. if err != nil {
  605. return nil, err
  606. }
  607. return buildSubscriptionSummaries(subs), nil
  608. }
  609. // HasActiveUserSubscription returns whether the user has any active subscription.
  610. // This is a lightweight existence check to avoid heavy pre-consume transactions.
  611. func HasActiveUserSubscription(userId int) (bool, error) {
  612. if userId <= 0 {
  613. return false, errors.New("invalid userId")
  614. }
  615. now := common.GetTimestamp()
  616. var count int64
  617. if err := DB.Model(&UserSubscription{}).
  618. Where("user_id = ? AND status = ? AND end_time > ?", userId, "active", now).
  619. Count(&count).Error; err != nil {
  620. return false, err
  621. }
  622. return count > 0, nil
  623. }
  624. // GetAllUserSubscriptions returns all subscriptions (active and expired) for a user.
  625. func GetAllUserSubscriptions(userId int) ([]SubscriptionSummary, error) {
  626. if userId <= 0 {
  627. return nil, errors.New("invalid userId")
  628. }
  629. var subs []UserSubscription
  630. err := DB.Where("user_id = ?", userId).
  631. Order("end_time desc, id desc").
  632. Find(&subs).Error
  633. if err != nil {
  634. return nil, err
  635. }
  636. return buildSubscriptionSummaries(subs), nil
  637. }
  638. func buildSubscriptionSummaries(subs []UserSubscription) []SubscriptionSummary {
  639. if len(subs) == 0 {
  640. return []SubscriptionSummary{}
  641. }
  642. result := make([]SubscriptionSummary, 0, len(subs))
  643. for _, sub := range subs {
  644. subCopy := sub
  645. result = append(result, SubscriptionSummary{
  646. Subscription: &subCopy,
  647. })
  648. }
  649. return result
  650. }
  651. // AdminInvalidateUserSubscription marks a user subscription as cancelled and ends it immediately.
  652. func AdminInvalidateUserSubscription(userSubscriptionId int) (string, error) {
  653. if userSubscriptionId <= 0 {
  654. return "", errors.New("invalid userSubscriptionId")
  655. }
  656. now := common.GetTimestamp()
  657. cacheGroup := ""
  658. downgradeGroup := ""
  659. var userId int
  660. err := DB.Transaction(func(tx *gorm.DB) error {
  661. var sub UserSubscription
  662. if err := tx.Set("gorm:query_option", "FOR UPDATE").
  663. Where("id = ?", userSubscriptionId).First(&sub).Error; err != nil {
  664. return err
  665. }
  666. userId = sub.UserId
  667. if err := tx.Model(&sub).Updates(map[string]interface{}{
  668. "status": "cancelled",
  669. "end_time": now,
  670. "updated_at": now,
  671. }).Error; err != nil {
  672. return err
  673. }
  674. target, err := downgradeUserGroupForSubscriptionTx(tx, &sub, now)
  675. if err != nil {
  676. return err
  677. }
  678. if target != "" {
  679. cacheGroup = target
  680. downgradeGroup = target
  681. }
  682. return nil
  683. })
  684. if err != nil {
  685. return "", err
  686. }
  687. if cacheGroup != "" && userId > 0 {
  688. _ = UpdateUserGroupCache(userId, cacheGroup)
  689. }
  690. if downgradeGroup != "" {
  691. return fmt.Sprintf("用户分组将回退到 %s", downgradeGroup), nil
  692. }
  693. return "", nil
  694. }
  695. // AdminDeleteUserSubscription hard-deletes a user subscription.
  696. func AdminDeleteUserSubscription(userSubscriptionId int) (string, error) {
  697. if userSubscriptionId <= 0 {
  698. return "", errors.New("invalid userSubscriptionId")
  699. }
  700. now := common.GetTimestamp()
  701. cacheGroup := ""
  702. downgradeGroup := ""
  703. var userId int
  704. err := DB.Transaction(func(tx *gorm.DB) error {
  705. var sub UserSubscription
  706. if err := tx.Set("gorm:query_option", "FOR UPDATE").
  707. Where("id = ?", userSubscriptionId).First(&sub).Error; err != nil {
  708. return err
  709. }
  710. userId = sub.UserId
  711. target, err := downgradeUserGroupForSubscriptionTx(tx, &sub, now)
  712. if err != nil {
  713. return err
  714. }
  715. if target != "" {
  716. cacheGroup = target
  717. downgradeGroup = target
  718. }
  719. if err := tx.Where("id = ?", userSubscriptionId).Delete(&UserSubscription{}).Error; err != nil {
  720. return err
  721. }
  722. return nil
  723. })
  724. if err != nil {
  725. return "", err
  726. }
  727. if cacheGroup != "" && userId > 0 {
  728. _ = UpdateUserGroupCache(userId, cacheGroup)
  729. }
  730. if downgradeGroup != "" {
  731. return fmt.Sprintf("用户分组将回退到 %s", downgradeGroup), nil
  732. }
  733. return "", nil
  734. }
  735. type SubscriptionPreConsumeResult struct {
  736. UserSubscriptionId int
  737. PreConsumed int64
  738. AmountTotal int64
  739. AmountUsedBefore int64
  740. AmountUsedAfter int64
  741. }
  742. // ExpireDueSubscriptions marks expired subscriptions and handles group downgrade.
  743. func ExpireDueSubscriptions(limit int) (int, error) {
  744. if limit <= 0 {
  745. limit = 200
  746. }
  747. now := GetDBTimestamp()
  748. var subs []UserSubscription
  749. if err := DB.Where("status = ? AND end_time > 0 AND end_time <= ?", "active", now).
  750. Order("end_time asc, id asc").
  751. Limit(limit).
  752. Find(&subs).Error; err != nil {
  753. return 0, err
  754. }
  755. if len(subs) == 0 {
  756. return 0, nil
  757. }
  758. expiredCount := 0
  759. userIds := make(map[int]struct{}, len(subs))
  760. for _, sub := range subs {
  761. if sub.UserId > 0 {
  762. userIds[sub.UserId] = struct{}{}
  763. }
  764. }
  765. for userId := range userIds {
  766. cacheGroup := ""
  767. err := DB.Transaction(func(tx *gorm.DB) error {
  768. res := tx.Model(&UserSubscription{}).
  769. Where("user_id = ? AND status = ? AND end_time > 0 AND end_time <= ?", userId, "active", now).
  770. Updates(map[string]interface{}{
  771. "status": "expired",
  772. "updated_at": common.GetTimestamp(),
  773. })
  774. if res.Error != nil {
  775. return res.Error
  776. }
  777. expiredCount += int(res.RowsAffected)
  778. // If there's an active upgraded subscription, keep current group.
  779. var activeSub UserSubscription
  780. activeQuery := tx.Where("user_id = ? AND status = ? AND end_time > ? AND upgrade_group <> ''",
  781. userId, "active", now).
  782. Order("end_time desc, id desc").
  783. Limit(1).
  784. Find(&activeSub)
  785. if activeQuery.Error == nil && activeQuery.RowsAffected > 0 {
  786. return nil
  787. }
  788. // No active upgraded subscription, downgrade to previous group if needed.
  789. var lastExpired UserSubscription
  790. expiredQuery := tx.Where("user_id = ? AND status = ? AND upgrade_group <> ''",
  791. userId, "expired").
  792. Order("end_time desc, id desc").
  793. Limit(1).
  794. Find(&lastExpired)
  795. if expiredQuery.Error != nil || expiredQuery.RowsAffected == 0 {
  796. return nil
  797. }
  798. upgradeGroup := strings.TrimSpace(lastExpired.UpgradeGroup)
  799. prevGroup := strings.TrimSpace(lastExpired.PrevUserGroup)
  800. if upgradeGroup == "" || prevGroup == "" {
  801. return nil
  802. }
  803. currentGroup, err := getUserGroupByIdTx(tx, userId)
  804. if err != nil {
  805. return err
  806. }
  807. if currentGroup != upgradeGroup || currentGroup == prevGroup {
  808. return nil
  809. }
  810. if err := tx.Model(&User{}).Where("id = ?", userId).
  811. Update("group", prevGroup).Error; err != nil {
  812. return err
  813. }
  814. cacheGroup = prevGroup
  815. return nil
  816. })
  817. if err != nil {
  818. return expiredCount, err
  819. }
  820. if cacheGroup != "" {
  821. _ = UpdateUserGroupCache(userId, cacheGroup)
  822. }
  823. }
  824. return expiredCount, nil
  825. }
  826. // SubscriptionPreConsumeRecord stores idempotent pre-consume operations per request.
  827. type SubscriptionPreConsumeRecord struct {
  828. Id int `json:"id"`
  829. RequestId string `json:"request_id" gorm:"type:varchar(64);uniqueIndex"`
  830. UserId int `json:"user_id" gorm:"index"`
  831. UserSubscriptionId int `json:"user_subscription_id" gorm:"index"`
  832. PreConsumed int64 `json:"pre_consumed" gorm:"type:bigint;not null;default:0"`
  833. Status string `json:"status" gorm:"type:varchar(32);index"` // consumed/refunded
  834. CreatedAt int64 `json:"created_at" gorm:"bigint"`
  835. UpdatedAt int64 `json:"updated_at" gorm:"bigint;index"`
  836. }
  837. func (r *SubscriptionPreConsumeRecord) BeforeCreate(tx *gorm.DB) error {
  838. now := common.GetTimestamp()
  839. r.CreatedAt = now
  840. r.UpdatedAt = now
  841. return nil
  842. }
  843. func (r *SubscriptionPreConsumeRecord) BeforeUpdate(tx *gorm.DB) error {
  844. r.UpdatedAt = common.GetTimestamp()
  845. return nil
  846. }
  847. func maybeResetUserSubscriptionWithPlanTx(tx *gorm.DB, sub *UserSubscription, plan *SubscriptionPlan, now int64) error {
  848. if tx == nil || sub == nil || plan == nil {
  849. return errors.New("invalid reset args")
  850. }
  851. if sub.NextResetTime > 0 && sub.NextResetTime > now {
  852. return nil
  853. }
  854. if NormalizeResetPeriod(plan.QuotaResetPeriod) == SubscriptionResetNever {
  855. return nil
  856. }
  857. baseUnix := sub.LastResetTime
  858. if baseUnix <= 0 {
  859. baseUnix = sub.StartTime
  860. }
  861. base := time.Unix(baseUnix, 0)
  862. next := calcNextResetTime(base, plan, sub.EndTime)
  863. advanced := false
  864. for next > 0 && next <= now {
  865. advanced = true
  866. base = time.Unix(next, 0)
  867. next = calcNextResetTime(base, plan, sub.EndTime)
  868. }
  869. if !advanced {
  870. if sub.NextResetTime == 0 && next > 0 {
  871. sub.NextResetTime = next
  872. sub.LastResetTime = base.Unix()
  873. return tx.Save(sub).Error
  874. }
  875. return nil
  876. }
  877. sub.AmountUsed = 0
  878. sub.LastResetTime = base.Unix()
  879. sub.NextResetTime = next
  880. return tx.Save(sub).Error
  881. }
  882. // PreConsumeUserSubscription pre-consumes from any active subscription total quota.
  883. func PreConsumeUserSubscription(requestId string, userId int, modelName string, quotaType int, amount int64) (*SubscriptionPreConsumeResult, error) {
  884. if userId <= 0 {
  885. return nil, errors.New("invalid userId")
  886. }
  887. if strings.TrimSpace(requestId) == "" {
  888. return nil, errors.New("requestId is empty")
  889. }
  890. if amount <= 0 {
  891. return nil, errors.New("amount must be > 0")
  892. }
  893. now := GetDBTimestamp()
  894. returnValue := &SubscriptionPreConsumeResult{}
  895. err := DB.Transaction(func(tx *gorm.DB) error {
  896. var existing SubscriptionPreConsumeRecord
  897. query := tx.Where("request_id = ?", requestId).Limit(1).Find(&existing)
  898. if query.Error != nil {
  899. return query.Error
  900. }
  901. if query.RowsAffected > 0 {
  902. if existing.Status == "refunded" {
  903. return errors.New("subscription pre-consume already refunded")
  904. }
  905. var sub UserSubscription
  906. if err := tx.Where("id = ?", existing.UserSubscriptionId).First(&sub).Error; err != nil {
  907. return err
  908. }
  909. returnValue.UserSubscriptionId = sub.Id
  910. returnValue.PreConsumed = existing.PreConsumed
  911. returnValue.AmountTotal = sub.AmountTotal
  912. returnValue.AmountUsedBefore = sub.AmountUsed
  913. returnValue.AmountUsedAfter = sub.AmountUsed
  914. return nil
  915. }
  916. var subs []UserSubscription
  917. if err := tx.Set("gorm:query_option", "FOR UPDATE").
  918. Where("user_id = ? AND status = ? AND end_time > ?", userId, "active", now).
  919. Order("end_time asc, id asc").
  920. Find(&subs).Error; err != nil {
  921. return errors.New("no active subscription")
  922. }
  923. if len(subs) == 0 {
  924. return errors.New("no active subscription")
  925. }
  926. for _, candidate := range subs {
  927. sub := candidate
  928. plan, err := getSubscriptionPlanByIdTx(tx, sub.PlanId)
  929. if err != nil {
  930. return err
  931. }
  932. if err := maybeResetUserSubscriptionWithPlanTx(tx, &sub, plan, now); err != nil {
  933. return err
  934. }
  935. usedBefore := sub.AmountUsed
  936. if sub.AmountTotal > 0 {
  937. remain := sub.AmountTotal - usedBefore
  938. if remain < amount {
  939. continue
  940. }
  941. }
  942. record := &SubscriptionPreConsumeRecord{
  943. RequestId: requestId,
  944. UserId: userId,
  945. UserSubscriptionId: sub.Id,
  946. PreConsumed: amount,
  947. Status: "consumed",
  948. }
  949. if err := tx.Create(record).Error; err != nil {
  950. var dup SubscriptionPreConsumeRecord
  951. if err2 := tx.Where("request_id = ?", requestId).First(&dup).Error; err2 == nil {
  952. if dup.Status == "refunded" {
  953. return errors.New("subscription pre-consume already refunded")
  954. }
  955. returnValue.UserSubscriptionId = sub.Id
  956. returnValue.PreConsumed = dup.PreConsumed
  957. returnValue.AmountTotal = sub.AmountTotal
  958. returnValue.AmountUsedBefore = sub.AmountUsed
  959. returnValue.AmountUsedAfter = sub.AmountUsed
  960. return nil
  961. }
  962. return err
  963. }
  964. sub.AmountUsed += amount
  965. if err := tx.Save(&sub).Error; err != nil {
  966. return err
  967. }
  968. returnValue.UserSubscriptionId = sub.Id
  969. returnValue.PreConsumed = amount
  970. returnValue.AmountTotal = sub.AmountTotal
  971. returnValue.AmountUsedBefore = usedBefore
  972. returnValue.AmountUsedAfter = sub.AmountUsed
  973. return nil
  974. }
  975. return fmt.Errorf("subscription quota insufficient, need=%d", amount)
  976. })
  977. if err != nil {
  978. return nil, err
  979. }
  980. return returnValue, nil
  981. }
  982. // RefundSubscriptionPreConsume is idempotent and refunds pre-consumed subscription quota by requestId.
  983. func RefundSubscriptionPreConsume(requestId string) error {
  984. if strings.TrimSpace(requestId) == "" {
  985. return errors.New("requestId is empty")
  986. }
  987. return DB.Transaction(func(tx *gorm.DB) error {
  988. var record SubscriptionPreConsumeRecord
  989. if err := tx.Set("gorm:query_option", "FOR UPDATE").
  990. Where("request_id = ?", requestId).First(&record).Error; err != nil {
  991. return err
  992. }
  993. if record.Status == "refunded" {
  994. return nil
  995. }
  996. if record.PreConsumed <= 0 {
  997. record.Status = "refunded"
  998. return tx.Save(&record).Error
  999. }
  1000. if err := PostConsumeUserSubscriptionDelta(record.UserSubscriptionId, -record.PreConsumed); err != nil {
  1001. return err
  1002. }
  1003. record.Status = "refunded"
  1004. return tx.Save(&record).Error
  1005. })
  1006. }
  1007. // ResetDueSubscriptions resets subscriptions whose next_reset_time has passed.
  1008. func ResetDueSubscriptions(limit int) (int, error) {
  1009. if limit <= 0 {
  1010. limit = 200
  1011. }
  1012. now := GetDBTimestamp()
  1013. var subs []UserSubscription
  1014. if err := DB.Where("next_reset_time > 0 AND next_reset_time <= ? AND status = ?", now, "active").
  1015. Order("next_reset_time asc").
  1016. Limit(limit).
  1017. Find(&subs).Error; err != nil {
  1018. return 0, err
  1019. }
  1020. if len(subs) == 0 {
  1021. return 0, nil
  1022. }
  1023. resetCount := 0
  1024. for _, sub := range subs {
  1025. subCopy := sub
  1026. plan, err := getSubscriptionPlanByIdTx(nil, sub.PlanId)
  1027. if err != nil || plan == nil {
  1028. continue
  1029. }
  1030. err = DB.Transaction(func(tx *gorm.DB) error {
  1031. var locked UserSubscription
  1032. if err := tx.Set("gorm:query_option", "FOR UPDATE").
  1033. Where("id = ? AND next_reset_time > 0 AND next_reset_time <= ?", subCopy.Id, now).
  1034. First(&locked).Error; err != nil {
  1035. return nil
  1036. }
  1037. if err := maybeResetUserSubscriptionWithPlanTx(tx, &locked, plan, now); err != nil {
  1038. return err
  1039. }
  1040. resetCount++
  1041. return nil
  1042. })
  1043. if err != nil {
  1044. return resetCount, err
  1045. }
  1046. }
  1047. return resetCount, nil
  1048. }
  1049. // CleanupSubscriptionPreConsumeRecords removes old idempotency records to keep table small.
  1050. func CleanupSubscriptionPreConsumeRecords(olderThanSeconds int64) (int64, error) {
  1051. if olderThanSeconds <= 0 {
  1052. olderThanSeconds = 7 * 24 * 3600
  1053. }
  1054. cutoff := GetDBTimestamp() - olderThanSeconds
  1055. res := DB.Where("updated_at < ?", cutoff).Delete(&SubscriptionPreConsumeRecord{})
  1056. return res.RowsAffected, res.Error
  1057. }
// SubscriptionPlanInfo is the minimal plan description resolved for a user
// subscription. It is the value type held in the subscription plan-info
// cache, keyed by user subscription id.
type SubscriptionPlanInfo struct {
	PlanId    int    // id of the SubscriptionPlan the user subscription belongs to
	PlanTitle string // the plan's Title
}
  1062. func GetSubscriptionPlanInfoByUserSubscriptionId(userSubscriptionId int) (*SubscriptionPlanInfo, error) {
  1063. if userSubscriptionId <= 0 {
  1064. return nil, errors.New("invalid userSubscriptionId")
  1065. }
  1066. cacheKey := fmt.Sprintf("sub:%d", userSubscriptionId)
  1067. if cached, found, err := getSubscriptionPlanInfoCache().Get(cacheKey); err == nil && found {
  1068. return &cached, nil
  1069. }
  1070. var sub UserSubscription
  1071. if err := DB.Where("id = ?", userSubscriptionId).First(&sub).Error; err != nil {
  1072. return nil, err
  1073. }
  1074. plan, err := getSubscriptionPlanByIdTx(nil, sub.PlanId)
  1075. if err != nil {
  1076. return nil, err
  1077. }
  1078. info := &SubscriptionPlanInfo{
  1079. PlanId: sub.PlanId,
  1080. PlanTitle: plan.Title,
  1081. }
  1082. _ = getSubscriptionPlanInfoCache().SetWithTTL(cacheKey, *info, subscriptionPlanInfoCacheTTL())
  1083. return info, nil
  1084. }
  1085. // Update subscription used amount by delta (positive consume more, negative refund).
  1086. func PostConsumeUserSubscriptionDelta(userSubscriptionId int, delta int64) error {
  1087. if userSubscriptionId <= 0 {
  1088. return errors.New("invalid userSubscriptionId")
  1089. }
  1090. if delta == 0 {
  1091. return nil
  1092. }
  1093. return DB.Transaction(func(tx *gorm.DB) error {
  1094. var sub UserSubscription
  1095. if err := tx.Set("gorm:query_option", "FOR UPDATE").
  1096. Where("id = ?", userSubscriptionId).
  1097. First(&sub).Error; err != nil {
  1098. return err
  1099. }
  1100. newUsed := sub.AmountUsed + delta
  1101. if newUsed < 0 {
  1102. newUsed = 0
  1103. }
  1104. if sub.AmountTotal > 0 && newUsed > sub.AmountTotal {
  1105. return fmt.Errorf("subscription used exceeds total, used=%d total=%d", newUsed, sub.AmountTotal)
  1106. }
  1107. sub.AmountUsed = newUsed
  1108. return tx.Save(&sub).Error
  1109. })
  1110. }