subscription.go 34 KB


  1. package model
import (
	"errors"
	"fmt"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/pkg/cachex"
	"github.com/samber/hot"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)
  14. // Subscription duration units
  15. const (
  16. SubscriptionDurationYear = "year"
  17. SubscriptionDurationMonth = "month"
  18. SubscriptionDurationDay = "day"
  19. SubscriptionDurationHour = "hour"
  20. SubscriptionDurationCustom = "custom"
  21. )
  22. // Subscription quota reset period
  23. const (
  24. SubscriptionResetNever = "never"
  25. SubscriptionResetDaily = "daily"
  26. SubscriptionResetWeekly = "weekly"
  27. SubscriptionResetMonthly = "monthly"
  28. SubscriptionResetCustom = "custom"
  29. )
  30. var (
  31. ErrSubscriptionOrderNotFound = errors.New("subscription order not found")
  32. ErrSubscriptionOrderStatusInvalid = errors.New("subscription order status invalid")
  33. )
  34. const (
  35. subscriptionPlanCacheNamespace = "new-api:subscription_plan:v1"
  36. subscriptionPlanInfoCacheNamespace = "new-api:subscription_plan_info:v1"
  37. )
  38. var (
  39. subscriptionPlanCacheOnce sync.Once
  40. subscriptionPlanInfoCacheOnce sync.Once
  41. subscriptionPlanCache *cachex.HybridCache[SubscriptionPlan]
  42. subscriptionPlanInfoCache *cachex.HybridCache[SubscriptionPlanInfo]
  43. )
  44. func subscriptionPlanCacheTTL() time.Duration {
  45. ttlSeconds := common.GetEnvOrDefault("SUBSCRIPTION_PLAN_CACHE_TTL", 300)
  46. if ttlSeconds <= 0 {
  47. ttlSeconds = 300
  48. }
  49. return time.Duration(ttlSeconds) * time.Second
  50. }
  51. func subscriptionPlanInfoCacheTTL() time.Duration {
  52. ttlSeconds := common.GetEnvOrDefault("SUBSCRIPTION_PLAN_INFO_CACHE_TTL", 120)
  53. if ttlSeconds <= 0 {
  54. ttlSeconds = 120
  55. }
  56. return time.Duration(ttlSeconds) * time.Second
  57. }
  58. func subscriptionPlanCacheCapacity() int {
  59. capacity := common.GetEnvOrDefault("SUBSCRIPTION_PLAN_CACHE_CAP", 5000)
  60. if capacity <= 0 {
  61. capacity = 5000
  62. }
  63. return capacity
  64. }
  65. func subscriptionPlanInfoCacheCapacity() int {
  66. capacity := common.GetEnvOrDefault("SUBSCRIPTION_PLAN_INFO_CACHE_CAP", 10000)
  67. if capacity <= 0 {
  68. capacity = 10000
  69. }
  70. return capacity
  71. }
  72. func getSubscriptionPlanCache() *cachex.HybridCache[SubscriptionPlan] {
  73. subscriptionPlanCacheOnce.Do(func() {
  74. ttl := subscriptionPlanCacheTTL()
  75. subscriptionPlanCache = cachex.NewHybridCache[SubscriptionPlan](cachex.HybridCacheConfig[SubscriptionPlan]{
  76. Namespace: cachex.Namespace(subscriptionPlanCacheNamespace),
  77. Redis: common.RDB,
  78. RedisEnabled: func() bool {
  79. return common.RedisEnabled && common.RDB != nil
  80. },
  81. RedisCodec: cachex.JSONCodec[SubscriptionPlan]{},
  82. Memory: func() *hot.HotCache[string, SubscriptionPlan] {
  83. return hot.NewHotCache[string, SubscriptionPlan](hot.LRU, subscriptionPlanCacheCapacity()).
  84. WithTTL(ttl).
  85. WithJanitor().
  86. Build()
  87. },
  88. })
  89. })
  90. return subscriptionPlanCache
  91. }
  92. func getSubscriptionPlanInfoCache() *cachex.HybridCache[SubscriptionPlanInfo] {
  93. subscriptionPlanInfoCacheOnce.Do(func() {
  94. ttl := subscriptionPlanInfoCacheTTL()
  95. subscriptionPlanInfoCache = cachex.NewHybridCache[SubscriptionPlanInfo](cachex.HybridCacheConfig[SubscriptionPlanInfo]{
  96. Namespace: cachex.Namespace(subscriptionPlanInfoCacheNamespace),
  97. Redis: common.RDB,
  98. RedisEnabled: func() bool {
  99. return common.RedisEnabled && common.RDB != nil
  100. },
  101. RedisCodec: cachex.JSONCodec[SubscriptionPlanInfo]{},
  102. Memory: func() *hot.HotCache[string, SubscriptionPlanInfo] {
  103. return hot.NewHotCache[string, SubscriptionPlanInfo](hot.LRU, subscriptionPlanInfoCacheCapacity()).
  104. WithTTL(ttl).
  105. WithJanitor().
  106. Build()
  107. },
  108. })
  109. })
  110. return subscriptionPlanInfoCache
  111. }
  112. func subscriptionPlanCacheKey(id int) string {
  113. if id <= 0 {
  114. return ""
  115. }
  116. return strconv.Itoa(id)
  117. }
  118. func InvalidateSubscriptionPlanCache(planId int) {
  119. if planId <= 0 {
  120. return
  121. }
  122. cache := getSubscriptionPlanCache()
  123. _, _ = cache.DeleteMany([]string{subscriptionPlanCacheKey(planId)})
  124. infoCache := getSubscriptionPlanInfoCache()
  125. _ = infoCache.Purge()
  126. }
// SubscriptionPlan is a purchasable subscription product. Plans are read
// through the plan cache (getSubscriptionPlanByIdTx) and snapshotted into a
// UserSubscription when granted (CreateUserSubscriptionFromPlanTx).
type SubscriptionPlan struct {
	Id       int    `json:"id"`
	Title    string `json:"title" gorm:"type:varchar(128);not null"`
	Subtitle string `json:"subtitle" gorm:"type:varchar(255);default:''"`
	// Display money amount (follow existing code style: float64 for money)
	PriceAmount float64 `json:"price_amount" gorm:"type:decimal(10,6);not null;default:0"`
	Currency    string  `json:"currency" gorm:"type:varchar(8);not null;default:'USD'"`
	// Subscription length: DurationValue units of DurationUnit (one of the
	// SubscriptionDuration* constants). For "custom", CustomSeconds is used
	// and DurationValue is ignored. See calcPlanEndTime.
	DurationUnit  string `json:"duration_unit" gorm:"type:varchar(16);not null;default:'month'"`
	DurationValue int    `json:"duration_value" gorm:"type:int;not null;default:1"`
	CustomSeconds int64  `json:"custom_seconds" gorm:"type:bigint;not null;default:0"`
	Enabled       bool   `json:"enabled" gorm:"default:true"`
	SortOrder     int    `json:"sort_order" gorm:"type:int;default:0"`
	// Identifiers of this plan at payment providers (presumably used when
	// creating Stripe / Creem checkouts — confirm against the payment code).
	StripePriceId  string `json:"stripe_price_id" gorm:"type:varchar(128);default:''"`
	CreemProductId string `json:"creem_product_id" gorm:"type:varchar(128);default:''"`
	// Max purchases per user (0 = unlimited). CreateUserSubscriptionFromPlanTx
	// enforces it by counting all of the user's subscriptions to this plan,
	// whatever their status.
	MaxPurchasePerUser int `json:"max_purchase_per_user" gorm:"type:int;default:0"`
	// Upgrade user group after purchase (empty = no change)
	UpgradeGroup string `json:"upgrade_group" gorm:"type:varchar(64);default:''"`
	// Total quota (amount in quota units, 0 = unlimited)
	TotalAmount int64 `json:"total_amount" gorm:"type:bigint;not null;default:0"`
	// Quota reset period for plan: one of the SubscriptionReset* constants
	// (NormalizeResetPeriod treats anything else as "never").
	// QuotaResetCustomSeconds is the period length when it is "custom".
	QuotaResetPeriod        string `json:"quota_reset_period" gorm:"type:varchar(16);default:'never'"`
	QuotaResetCustomSeconds int64  `json:"quota_reset_custom_seconds" gorm:"type:bigint;default:0"`
	// Unix timestamps maintained by the BeforeCreate/BeforeUpdate hooks.
	// NOTE(review): the bare gorm:"bigint" tag has no "type:" key, so GORM
	// probably ignores it; "type:bigint" may have been intended. Changing it
	// can affect auto-migration, so confirm before editing.
	CreatedAt int64 `json:"created_at" gorm:"bigint"`
	UpdatedAt int64 `json:"updated_at" gorm:"bigint"`
}
  154. func (p *SubscriptionPlan) BeforeCreate(tx *gorm.DB) error {
  155. now := common.GetTimestamp()
  156. p.CreatedAt = now
  157. p.UpdatedAt = now
  158. return nil
  159. }
  160. func (p *SubscriptionPlan) BeforeUpdate(tx *gorm.DB) error {
  161. p.UpdatedAt = common.GetTimestamp()
  162. return nil
  163. }
// SubscriptionOrder is one payment for a plan (payment -> webhook -> create
// UserSubscription). Status holds common.TopUpStatus* values:
// CompleteSubscriptionOrder moves a pending order to success, and
// ExpireSubscriptionOrder moves a pending order to expired.
type SubscriptionOrder struct {
	Id            int     `json:"id"`
	UserId        int     `json:"user_id" gorm:"index"`
	PlanId        int     `json:"plan_id" gorm:"index"`
	Money         float64 `json:"money"`
	TradeNo       string  `json:"trade_no" gorm:"unique;type:varchar(255);index"`
	PaymentMethod string  `json:"payment_method" gorm:"type:varchar(50)"`
	Status        string  `json:"status"`
	// Unix seconds; Insert defaults CreateTime to now when it is 0.
	CreateTime int64 `json:"create_time"`
	// Unix seconds at which the order was completed or expired.
	CompleteTime int64 `json:"complete_time"`
	// Payload passed to CompleteSubscriptionOrder when non-empty (presumably
	// the provider's callback body — confirm with the payment handlers).
	ProviderPayload string `json:"provider_payload" gorm:"type:text"`
}
  177. func (o *SubscriptionOrder) Insert() error {
  178. if o.CreateTime == 0 {
  179. o.CreateTime = common.GetTimestamp()
  180. }
  181. return DB.Create(o).Error
  182. }
  183. func (o *SubscriptionOrder) Update() error {
  184. return DB.Save(o).Error
  185. }
  186. func GetSubscriptionOrderByTradeNo(tradeNo string) *SubscriptionOrder {
  187. if tradeNo == "" {
  188. return nil
  189. }
  190. var order SubscriptionOrder
  191. if err := DB.Where("trade_no = ?", tradeNo).First(&order).Error; err != nil {
  192. return nil
  193. }
  194. return &order
  195. }
// UserSubscription is one subscription instance owned by a user. It snapshots
// the plan's quota and upgrade group when it is created
// (CreateUserSubscriptionFromPlanTx).
type UserSubscription struct {
	Id     int `json:"id"`
	UserId int `json:"user_id" gorm:"index;index:idx_user_sub_active,priority:1"`
	PlanId int `json:"plan_id" gorm:"index"`
	// Quota granted (copied from SubscriptionPlan.TotalAmount) and consumed so
	// far. AmountUsed is zeroed on each quota reset.
	AmountTotal int64 `json:"amount_total" gorm:"type:bigint;not null;default:0"`
	AmountUsed  int64 `json:"amount_used" gorm:"type:bigint;not null;default:0"`
	// Validity window in Unix seconds; EndTime is part of idx_user_sub_active
	// together with UserId and Status.
	StartTime int64  `json:"start_time" gorm:"bigint"`
	EndTime   int64  `json:"end_time" gorm:"bigint;index;index:idx_user_sub_active,priority:3"`
	Status    string `json:"status" gorm:"type:varchar(32);index;index:idx_user_sub_active,priority:2"` // active/expired/cancelled
	Source    string `json:"source" gorm:"type:varchar(32);default:'order'"`                          // order/admin
	// Quota reset schedule in Unix seconds; NextResetTime is 0 when no further
	// reset is scheduled.
	LastResetTime int64 `json:"last_reset_time" gorm:"type:bigint;default:0"`
	NextResetTime int64 `json:"next_reset_time" gorm:"type:bigint;default:0;index"`
	// Group this subscription granted and the group the user held before it.
	// PrevUserGroup is what the expiry/cancel paths restore.
	UpgradeGroup  string `json:"upgrade_group" gorm:"type:varchar(64);default:''"`
	PrevUserGroup string `json:"prev_user_group" gorm:"type:varchar(64);default:''"`
	// Unix timestamps maintained by the BeforeCreate/BeforeUpdate hooks.
	// NOTE(review): bare gorm:"bigint" tags (here and on StartTime/EndTime)
	// carry no "type:" key and are probably ignored by GORM — confirm.
	CreatedAt int64 `json:"created_at" gorm:"bigint"`
	UpdatedAt int64 `json:"updated_at" gorm:"bigint"`
}
  214. func (s *UserSubscription) BeforeCreate(tx *gorm.DB) error {
  215. now := common.GetTimestamp()
  216. s.CreatedAt = now
  217. s.UpdatedAt = now
  218. return nil
  219. }
  220. func (s *UserSubscription) BeforeUpdate(tx *gorm.DB) error {
  221. s.UpdatedAt = common.GetTimestamp()
  222. return nil
  223. }
// SubscriptionSummary is a JSON-serializable wrapper around one subscription;
// lists of them are produced by buildSubscriptionSummaries.
type SubscriptionSummary struct {
	Subscription *UserSubscription `json:"subscription"`
}
  227. func calcPlanEndTime(start time.Time, plan *SubscriptionPlan) (int64, error) {
  228. if plan == nil {
  229. return 0, errors.New("plan is nil")
  230. }
  231. if plan.DurationValue <= 0 && plan.DurationUnit != SubscriptionDurationCustom {
  232. return 0, errors.New("duration_value must be > 0")
  233. }
  234. switch plan.DurationUnit {
  235. case SubscriptionDurationYear:
  236. return start.AddDate(plan.DurationValue, 0, 0).Unix(), nil
  237. case SubscriptionDurationMonth:
  238. return start.AddDate(0, plan.DurationValue, 0).Unix(), nil
  239. case SubscriptionDurationDay:
  240. return start.Add(time.Duration(plan.DurationValue) * 24 * time.Hour).Unix(), nil
  241. case SubscriptionDurationHour:
  242. return start.Add(time.Duration(plan.DurationValue) * time.Hour).Unix(), nil
  243. case SubscriptionDurationCustom:
  244. if plan.CustomSeconds <= 0 {
  245. return 0, errors.New("custom_seconds must be > 0")
  246. }
  247. return start.Add(time.Duration(plan.CustomSeconds) * time.Second).Unix(), nil
  248. default:
  249. return 0, fmt.Errorf("invalid duration_unit: %s", plan.DurationUnit)
  250. }
  251. }
  252. func NormalizeResetPeriod(period string) string {
  253. switch strings.TrimSpace(period) {
  254. case SubscriptionResetDaily, SubscriptionResetWeekly, SubscriptionResetMonthly, SubscriptionResetCustom:
  255. return strings.TrimSpace(period)
  256. default:
  257. return SubscriptionResetNever
  258. }
  259. }
  260. func calcNextResetTime(base time.Time, plan *SubscriptionPlan, endUnix int64) int64 {
  261. if plan == nil {
  262. return 0
  263. }
  264. period := NormalizeResetPeriod(plan.QuotaResetPeriod)
  265. if period == SubscriptionResetNever {
  266. return 0
  267. }
  268. var next time.Time
  269. switch period {
  270. case SubscriptionResetDaily:
  271. next = time.Date(base.Year(), base.Month(), base.Day(), 0, 0, 0, 0, base.Location()).
  272. AddDate(0, 0, 1)
  273. case SubscriptionResetWeekly:
  274. // Align to next Monday 00:00
  275. weekday := int(base.Weekday()) // Sunday=0
  276. // Convert to Monday=1..Sunday=7
  277. if weekday == 0 {
  278. weekday = 7
  279. }
  280. daysUntil := 8 - weekday
  281. next = time.Date(base.Year(), base.Month(), base.Day(), 0, 0, 0, 0, base.Location()).
  282. AddDate(0, 0, daysUntil)
  283. case SubscriptionResetMonthly:
  284. // Align to first day of next month 00:00
  285. next = time.Date(base.Year(), base.Month(), 1, 0, 0, 0, 0, base.Location()).
  286. AddDate(0, 1, 0)
  287. case SubscriptionResetCustom:
  288. if plan.QuotaResetCustomSeconds <= 0 {
  289. return 0
  290. }
  291. next = base.Add(time.Duration(plan.QuotaResetCustomSeconds) * time.Second)
  292. default:
  293. return 0
  294. }
  295. if endUnix > 0 && next.Unix() > endUnix {
  296. return 0
  297. }
  298. return next.Unix()
  299. }
  300. func GetSubscriptionPlanById(id int) (*SubscriptionPlan, error) {
  301. return getSubscriptionPlanByIdTx(nil, id)
  302. }
  303. func getSubscriptionPlanByIdTx(tx *gorm.DB, id int) (*SubscriptionPlan, error) {
  304. if id <= 0 {
  305. return nil, errors.New("invalid plan id")
  306. }
  307. key := subscriptionPlanCacheKey(id)
  308. if key != "" {
  309. if cached, found, err := getSubscriptionPlanCache().Get(key); err == nil && found {
  310. return &cached, nil
  311. }
  312. }
  313. var plan SubscriptionPlan
  314. query := DB
  315. if tx != nil {
  316. query = tx
  317. }
  318. if err := query.Where("id = ?", id).First(&plan).Error; err != nil {
  319. return nil, err
  320. }
  321. _ = getSubscriptionPlanCache().SetWithTTL(key, plan, subscriptionPlanCacheTTL())
  322. return &plan, nil
  323. }
  324. func CountUserSubscriptionsByPlan(userId int, planId int) (int64, error) {
  325. if userId <= 0 || planId <= 0 {
  326. return 0, errors.New("invalid userId or planId")
  327. }
  328. var count int64
  329. if err := DB.Model(&UserSubscription{}).
  330. Where("user_id = ? AND plan_id = ?", userId, planId).
  331. Count(&count).Error; err != nil {
  332. return 0, err
  333. }
  334. return count, nil
  335. }
  336. func getUserGroupByIdTx(tx *gorm.DB, userId int) (string, error) {
  337. if userId <= 0 {
  338. return "", errors.New("invalid userId")
  339. }
  340. if tx == nil {
  341. tx = DB
  342. }
  343. var group string
  344. if err := tx.Model(&User{}).Where("id = ?", userId).Select(commonGroupCol).Find(&group).Error; err != nil {
  345. return "", err
  346. }
  347. return group, nil
  348. }
  349. func downgradeUserGroupForSubscriptionTx(tx *gorm.DB, sub *UserSubscription, now int64) (string, error) {
  350. if tx == nil || sub == nil {
  351. return "", errors.New("invalid downgrade args")
  352. }
  353. upgradeGroup := strings.TrimSpace(sub.UpgradeGroup)
  354. if upgradeGroup == "" {
  355. return "", nil
  356. }
  357. currentGroup, err := getUserGroupByIdTx(tx, sub.UserId)
  358. if err != nil {
  359. return "", err
  360. }
  361. if currentGroup != upgradeGroup {
  362. return "", nil
  363. }
  364. var activeSub UserSubscription
  365. activeQuery := tx.Where("user_id = ? AND status = ? AND end_time > ? AND id <> ? AND upgrade_group <> ''",
  366. sub.UserId, "active", now, sub.Id).
  367. Order("end_time desc, id desc").
  368. Limit(1).
  369. Find(&activeSub)
  370. if activeQuery.Error == nil && activeQuery.RowsAffected > 0 {
  371. return "", nil
  372. }
  373. prevGroup := strings.TrimSpace(sub.PrevUserGroup)
  374. if prevGroup == "" || prevGroup == currentGroup {
  375. return "", nil
  376. }
  377. if err := tx.Model(&User{}).Where("id = ?", sub.UserId).
  378. Update("group", prevGroup).Error; err != nil {
  379. return "", err
  380. }
  381. return prevGroup, nil
  382. }
  383. func CreateUserSubscriptionFromPlanTx(tx *gorm.DB, userId int, plan *SubscriptionPlan, source string) (*UserSubscription, error) {
  384. if tx == nil {
  385. return nil, errors.New("tx is nil")
  386. }
  387. if plan == nil || plan.Id == 0 {
  388. return nil, errors.New("invalid plan")
  389. }
  390. if userId <= 0 {
  391. return nil, errors.New("invalid user id")
  392. }
  393. if plan.MaxPurchasePerUser > 0 {
  394. var count int64
  395. if err := tx.Model(&UserSubscription{}).
  396. Where("user_id = ? AND plan_id = ?", userId, plan.Id).
  397. Count(&count).Error; err != nil {
  398. return nil, err
  399. }
  400. if count >= int64(plan.MaxPurchasePerUser) {
  401. return nil, errors.New("已达到该套餐购买上限")
  402. }
  403. }
  404. nowUnix := GetDBTimestamp()
  405. now := time.Unix(nowUnix, 0)
  406. endUnix, err := calcPlanEndTime(now, plan)
  407. if err != nil {
  408. return nil, err
  409. }
  410. resetBase := now
  411. nextReset := calcNextResetTime(resetBase, plan, endUnix)
  412. lastReset := int64(0)
  413. if nextReset > 0 {
  414. lastReset = now.Unix()
  415. }
  416. upgradeGroup := strings.TrimSpace(plan.UpgradeGroup)
  417. prevGroup := ""
  418. if upgradeGroup != "" {
  419. currentGroup, err := getUserGroupByIdTx(tx, userId)
  420. if err != nil {
  421. return nil, err
  422. }
  423. if currentGroup != upgradeGroup {
  424. prevGroup = currentGroup
  425. if err := tx.Model(&User{}).Where("id = ?", userId).
  426. Update("group", upgradeGroup).Error; err != nil {
  427. return nil, err
  428. }
  429. }
  430. }
  431. sub := &UserSubscription{
  432. UserId: userId,
  433. PlanId: plan.Id,
  434. AmountTotal: plan.TotalAmount,
  435. AmountUsed: 0,
  436. StartTime: now.Unix(),
  437. EndTime: endUnix,
  438. Status: "active",
  439. Source: source,
  440. LastResetTime: lastReset,
  441. NextResetTime: nextReset,
  442. UpgradeGroup: upgradeGroup,
  443. PrevUserGroup: prevGroup,
  444. CreatedAt: common.GetTimestamp(),
  445. UpdatedAt: common.GetTimestamp(),
  446. }
  447. if err := tx.Create(sub).Error; err != nil {
  448. return nil, err
  449. }
  450. return sub, nil
  451. }
  452. // Complete a subscription order (idempotent). Creates a UserSubscription snapshot from the plan.
  453. func CompleteSubscriptionOrder(tradeNo string, providerPayload string) error {
  454. if tradeNo == "" {
  455. return errors.New("tradeNo is empty")
  456. }
  457. refCol := "`trade_no`"
  458. if common.UsingPostgreSQL {
  459. refCol = `"trade_no"`
  460. }
  461. var logUserId int
  462. var logPlanTitle string
  463. var logMoney float64
  464. var logPaymentMethod string
  465. var upgradeGroup string
  466. err := DB.Transaction(func(tx *gorm.DB) error {
  467. var order SubscriptionOrder
  468. if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
  469. return ErrSubscriptionOrderNotFound
  470. }
  471. if order.Status == common.TopUpStatusSuccess {
  472. return nil
  473. }
  474. if order.Status != common.TopUpStatusPending {
  475. return ErrSubscriptionOrderStatusInvalid
  476. }
  477. plan, err := GetSubscriptionPlanById(order.PlanId)
  478. if err != nil {
  479. return err
  480. }
  481. if !plan.Enabled {
  482. // still allow completion for already purchased orders
  483. }
  484. upgradeGroup = strings.TrimSpace(plan.UpgradeGroup)
  485. _, err = CreateUserSubscriptionFromPlanTx(tx, order.UserId, plan, "order")
  486. if err != nil {
  487. return err
  488. }
  489. if err := upsertSubscriptionTopUpTx(tx, &order); err != nil {
  490. return err
  491. }
  492. order.Status = common.TopUpStatusSuccess
  493. order.CompleteTime = common.GetTimestamp()
  494. if providerPayload != "" {
  495. order.ProviderPayload = providerPayload
  496. }
  497. if err := tx.Save(&order).Error; err != nil {
  498. return err
  499. }
  500. logUserId = order.UserId
  501. logPlanTitle = plan.Title
  502. logMoney = order.Money
  503. logPaymentMethod = order.PaymentMethod
  504. return nil
  505. })
  506. if err != nil {
  507. return err
  508. }
  509. if upgradeGroup != "" && logUserId > 0 {
  510. _ = UpdateUserGroupCache(logUserId, upgradeGroup)
  511. }
  512. if logUserId > 0 {
  513. msg := fmt.Sprintf("订阅购买成功,套餐: %s,支付金额: %.2f,支付方式: %s", logPlanTitle, logMoney, logPaymentMethod)
  514. RecordLog(logUserId, LogTypeTopup, msg)
  515. }
  516. return nil
  517. }
  518. func upsertSubscriptionTopUpTx(tx *gorm.DB, order *SubscriptionOrder) error {
  519. if tx == nil || order == nil {
  520. return errors.New("invalid subscription order")
  521. }
  522. now := common.GetTimestamp()
  523. var topup TopUp
  524. if err := tx.Where("trade_no = ?", order.TradeNo).First(&topup).Error; err != nil {
  525. if errors.Is(err, gorm.ErrRecordNotFound) {
  526. topup = TopUp{
  527. UserId: order.UserId,
  528. Amount: 0,
  529. Money: order.Money,
  530. TradeNo: order.TradeNo,
  531. PaymentMethod: order.PaymentMethod,
  532. CreateTime: order.CreateTime,
  533. CompleteTime: now,
  534. Status: common.TopUpStatusSuccess,
  535. }
  536. return tx.Create(&topup).Error
  537. }
  538. return err
  539. }
  540. topup.Money = order.Money
  541. if topup.PaymentMethod == "" {
  542. topup.PaymentMethod = order.PaymentMethod
  543. }
  544. if topup.CreateTime == 0 {
  545. topup.CreateTime = order.CreateTime
  546. }
  547. topup.CompleteTime = now
  548. topup.Status = common.TopUpStatusSuccess
  549. return tx.Save(&topup).Error
  550. }
  551. func ExpireSubscriptionOrder(tradeNo string) error {
  552. if tradeNo == "" {
  553. return errors.New("tradeNo is empty")
  554. }
  555. refCol := "`trade_no`"
  556. if common.UsingPostgreSQL {
  557. refCol = `"trade_no"`
  558. }
  559. return DB.Transaction(func(tx *gorm.DB) error {
  560. var order SubscriptionOrder
  561. if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
  562. return ErrSubscriptionOrderNotFound
  563. }
  564. if order.Status != common.TopUpStatusPending {
  565. return nil
  566. }
  567. order.Status = common.TopUpStatusExpired
  568. order.CompleteTime = common.GetTimestamp()
  569. return tx.Save(&order).Error
  570. })
  571. }
  572. // Admin bind (no payment). Creates a UserSubscription from a plan.
  573. func AdminBindSubscription(userId int, planId int, sourceNote string) (string, error) {
  574. if userId <= 0 || planId <= 0 {
  575. return "", errors.New("invalid userId or planId")
  576. }
  577. plan, err := GetSubscriptionPlanById(planId)
  578. if err != nil {
  579. return "", err
  580. }
  581. err = DB.Transaction(func(tx *gorm.DB) error {
  582. _, err := CreateUserSubscriptionFromPlanTx(tx, userId, plan, "admin")
  583. return err
  584. })
  585. if err != nil {
  586. return "", err
  587. }
  588. if strings.TrimSpace(plan.UpgradeGroup) != "" {
  589. _ = UpdateUserGroupCache(userId, plan.UpgradeGroup)
  590. return fmt.Sprintf("用户分组将升级到 %s", plan.UpgradeGroup), nil
  591. }
  592. return "", nil
  593. }
  594. // GetAllActiveUserSubscriptions returns all active subscriptions for a user.
  595. func GetAllActiveUserSubscriptions(userId int) ([]SubscriptionSummary, error) {
  596. if userId <= 0 {
  597. return nil, errors.New("invalid userId")
  598. }
  599. now := common.GetTimestamp()
  600. var subs []UserSubscription
  601. err := DB.Where("user_id = ? AND status = ? AND end_time > ?", userId, "active", now).
  602. Order("end_time desc, id desc").
  603. Find(&subs).Error
  604. if err != nil {
  605. return nil, err
  606. }
  607. return buildSubscriptionSummaries(subs), nil
  608. }
  609. // GetAllUserSubscriptions returns all subscriptions (active and expired) for a user.
  610. func GetAllUserSubscriptions(userId int) ([]SubscriptionSummary, error) {
  611. if userId <= 0 {
  612. return nil, errors.New("invalid userId")
  613. }
  614. var subs []UserSubscription
  615. err := DB.Where("user_id = ?", userId).
  616. Order("end_time desc, id desc").
  617. Find(&subs).Error
  618. if err != nil {
  619. return nil, err
  620. }
  621. return buildSubscriptionSummaries(subs), nil
  622. }
  623. func buildSubscriptionSummaries(subs []UserSubscription) []SubscriptionSummary {
  624. if len(subs) == 0 {
  625. return []SubscriptionSummary{}
  626. }
  627. result := make([]SubscriptionSummary, 0, len(subs))
  628. for _, sub := range subs {
  629. subCopy := sub
  630. result = append(result, SubscriptionSummary{
  631. Subscription: &subCopy,
  632. })
  633. }
  634. return result
  635. }
  636. // AdminInvalidateUserSubscription marks a user subscription as cancelled and ends it immediately.
  637. func AdminInvalidateUserSubscription(userSubscriptionId int) (string, error) {
  638. if userSubscriptionId <= 0 {
  639. return "", errors.New("invalid userSubscriptionId")
  640. }
  641. now := common.GetTimestamp()
  642. cacheGroup := ""
  643. downgradeGroup := ""
  644. var userId int
  645. err := DB.Transaction(func(tx *gorm.DB) error {
  646. var sub UserSubscription
  647. if err := tx.Set("gorm:query_option", "FOR UPDATE").
  648. Where("id = ?", userSubscriptionId).First(&sub).Error; err != nil {
  649. return err
  650. }
  651. userId = sub.UserId
  652. if err := tx.Model(&sub).Updates(map[string]interface{}{
  653. "status": "cancelled",
  654. "end_time": now,
  655. "updated_at": now,
  656. }).Error; err != nil {
  657. return err
  658. }
  659. target, err := downgradeUserGroupForSubscriptionTx(tx, &sub, now)
  660. if err != nil {
  661. return err
  662. }
  663. if target != "" {
  664. cacheGroup = target
  665. downgradeGroup = target
  666. }
  667. return nil
  668. })
  669. if err != nil {
  670. return "", err
  671. }
  672. if cacheGroup != "" && userId > 0 {
  673. _ = UpdateUserGroupCache(userId, cacheGroup)
  674. }
  675. if downgradeGroup != "" {
  676. return fmt.Sprintf("用户分组将回退到 %s", downgradeGroup), nil
  677. }
  678. return "", nil
  679. }
  680. // AdminDeleteUserSubscription hard-deletes a user subscription.
  681. func AdminDeleteUserSubscription(userSubscriptionId int) (string, error) {
  682. if userSubscriptionId <= 0 {
  683. return "", errors.New("invalid userSubscriptionId")
  684. }
  685. now := common.GetTimestamp()
  686. cacheGroup := ""
  687. downgradeGroup := ""
  688. var userId int
  689. err := DB.Transaction(func(tx *gorm.DB) error {
  690. var sub UserSubscription
  691. if err := tx.Set("gorm:query_option", "FOR UPDATE").
  692. Where("id = ?", userSubscriptionId).First(&sub).Error; err != nil {
  693. return err
  694. }
  695. userId = sub.UserId
  696. target, err := downgradeUserGroupForSubscriptionTx(tx, &sub, now)
  697. if err != nil {
  698. return err
  699. }
  700. if target != "" {
  701. cacheGroup = target
  702. downgradeGroup = target
  703. }
  704. if err := tx.Where("id = ?", userSubscriptionId).Delete(&UserSubscription{}).Error; err != nil {
  705. return err
  706. }
  707. return nil
  708. })
  709. if err != nil {
  710. return "", err
  711. }
  712. if cacheGroup != "" && userId > 0 {
  713. _ = UpdateUserGroupCache(userId, cacheGroup)
  714. }
  715. if downgradeGroup != "" {
  716. return fmt.Sprintf("用户分组将回退到 %s", downgradeGroup), nil
  717. }
  718. return "", nil
  719. }
// SubscriptionPreConsumeResult reports the outcome of
// PreConsumeUserSubscription.
type SubscriptionPreConsumeResult struct {
	UserSubscriptionId int   // subscription the request was charged against
	PreConsumed        int64 // amount pre-consumed for the request
	AmountTotal        int64 // that subscription's total quota
	AmountUsedBefore   int64 // its AmountUsed before this call
	AmountUsedAfter    int64 // its AmountUsed after this call (equal to Before when an existing record is replayed)
}
  727. // ExpireDueSubscriptions marks expired subscriptions and handles group downgrade.
  728. func ExpireDueSubscriptions(limit int) (int, error) {
  729. if limit <= 0 {
  730. limit = 200
  731. }
  732. now := GetDBTimestamp()
  733. var subs []UserSubscription
  734. if err := DB.Where("status = ? AND end_time > 0 AND end_time <= ?", "active", now).
  735. Order("end_time asc, id asc").
  736. Limit(limit).
  737. Find(&subs).Error; err != nil {
  738. return 0, err
  739. }
  740. if len(subs) == 0 {
  741. return 0, nil
  742. }
  743. expiredCount := 0
  744. userIds := make(map[int]struct{}, len(subs))
  745. for _, sub := range subs {
  746. if sub.UserId > 0 {
  747. userIds[sub.UserId] = struct{}{}
  748. }
  749. }
  750. for userId := range userIds {
  751. cacheGroup := ""
  752. err := DB.Transaction(func(tx *gorm.DB) error {
  753. res := tx.Model(&UserSubscription{}).
  754. Where("user_id = ? AND status = ? AND end_time > 0 AND end_time <= ?", userId, "active", now).
  755. Updates(map[string]interface{}{
  756. "status": "expired",
  757. "updated_at": common.GetTimestamp(),
  758. })
  759. if res.Error != nil {
  760. return res.Error
  761. }
  762. expiredCount += int(res.RowsAffected)
  763. // If there's an active upgraded subscription, keep current group.
  764. var activeSub UserSubscription
  765. activeQuery := tx.Where("user_id = ? AND status = ? AND end_time > ? AND upgrade_group <> ''",
  766. userId, "active", now).
  767. Order("end_time desc, id desc").
  768. Limit(1).
  769. Find(&activeSub)
  770. if activeQuery.Error == nil && activeQuery.RowsAffected > 0 {
  771. return nil
  772. }
  773. // No active upgraded subscription, downgrade to previous group if needed.
  774. var lastExpired UserSubscription
  775. expiredQuery := tx.Where("user_id = ? AND status = ? AND upgrade_group <> ''",
  776. userId, "expired").
  777. Order("end_time desc, id desc").
  778. Limit(1).
  779. Find(&lastExpired)
  780. if expiredQuery.Error != nil || expiredQuery.RowsAffected == 0 {
  781. return nil
  782. }
  783. upgradeGroup := strings.TrimSpace(lastExpired.UpgradeGroup)
  784. prevGroup := strings.TrimSpace(lastExpired.PrevUserGroup)
  785. if upgradeGroup == "" || prevGroup == "" {
  786. return nil
  787. }
  788. currentGroup, err := getUserGroupByIdTx(tx, userId)
  789. if err != nil {
  790. return err
  791. }
  792. if currentGroup != upgradeGroup || currentGroup == prevGroup {
  793. return nil
  794. }
  795. if err := tx.Model(&User{}).Where("id = ?", userId).
  796. Update("group", prevGroup).Error; err != nil {
  797. return err
  798. }
  799. cacheGroup = prevGroup
  800. return nil
  801. })
  802. if err != nil {
  803. return expiredCount, err
  804. }
  805. if cacheGroup != "" {
  806. _ = UpdateUserGroupCache(userId, cacheGroup)
  807. }
  808. }
  809. return expiredCount, nil
  810. }
// SubscriptionPreConsumeRecord stores idempotent pre-consume operations per
// request. RequestId is unique, so PreConsumeUserSubscription finds an earlier
// record for a repeated request instead of charging again.
type SubscriptionPreConsumeRecord struct {
	Id                 int    `json:"id"`
	RequestId          string `json:"request_id" gorm:"type:varchar(64);uniqueIndex"`
	UserId             int    `json:"user_id" gorm:"index"`
	UserSubscriptionId int    `json:"user_subscription_id" gorm:"index"`
	PreConsumed        int64  `json:"pre_consumed" gorm:"type:bigint;not null;default:0"`
	Status             string `json:"status" gorm:"type:varchar(32);index"` // consumed/refunded
	// Unix timestamps maintained by the BeforeCreate/BeforeUpdate hooks.
	CreatedAt int64 `json:"created_at" gorm:"bigint"`
	UpdatedAt int64 `json:"updated_at" gorm:"bigint;index"`
}
  822. func (r *SubscriptionPreConsumeRecord) BeforeCreate(tx *gorm.DB) error {
  823. now := common.GetTimestamp()
  824. r.CreatedAt = now
  825. r.UpdatedAt = now
  826. return nil
  827. }
  828. func (r *SubscriptionPreConsumeRecord) BeforeUpdate(tx *gorm.DB) error {
  829. r.UpdatedAt = common.GetTimestamp()
  830. return nil
  831. }
  832. func maybeResetUserSubscriptionWithPlanTx(tx *gorm.DB, sub *UserSubscription, plan *SubscriptionPlan, now int64) error {
  833. if tx == nil || sub == nil || plan == nil {
  834. return errors.New("invalid reset args")
  835. }
  836. if sub.NextResetTime > 0 && sub.NextResetTime > now {
  837. return nil
  838. }
  839. if NormalizeResetPeriod(plan.QuotaResetPeriod) == SubscriptionResetNever {
  840. return nil
  841. }
  842. baseUnix := sub.LastResetTime
  843. if baseUnix <= 0 {
  844. baseUnix = sub.StartTime
  845. }
  846. base := time.Unix(baseUnix, 0)
  847. next := calcNextResetTime(base, plan, sub.EndTime)
  848. advanced := false
  849. for next > 0 && next <= now {
  850. advanced = true
  851. base = time.Unix(next, 0)
  852. next = calcNextResetTime(base, plan, sub.EndTime)
  853. }
  854. if !advanced {
  855. if sub.NextResetTime == 0 && next > 0 {
  856. sub.NextResetTime = next
  857. sub.LastResetTime = base.Unix()
  858. return tx.Save(sub).Error
  859. }
  860. return nil
  861. }
  862. sub.AmountUsed = 0
  863. sub.LastResetTime = base.Unix()
  864. sub.NextResetTime = next
  865. return tx.Save(sub).Error
  866. }
  867. // PreConsumeUserSubscription pre-consumes from any active subscription total quota.
  868. func PreConsumeUserSubscription(requestId string, userId int, modelName string, quotaType int, amount int64) (*SubscriptionPreConsumeResult, error) {
  869. if userId <= 0 {
  870. return nil, errors.New("invalid userId")
  871. }
  872. if strings.TrimSpace(requestId) == "" {
  873. return nil, errors.New("requestId is empty")
  874. }
  875. if amount <= 0 {
  876. return nil, errors.New("amount must be > 0")
  877. }
  878. now := GetDBTimestamp()
  879. returnValue := &SubscriptionPreConsumeResult{}
  880. err := DB.Transaction(func(tx *gorm.DB) error {
  881. var existing SubscriptionPreConsumeRecord
  882. query := tx.Where("request_id = ?", requestId).Limit(1).Find(&existing)
  883. if query.Error != nil {
  884. return query.Error
  885. }
  886. if query.RowsAffected > 0 {
  887. if existing.Status == "refunded" {
  888. return errors.New("subscription pre-consume already refunded")
  889. }
  890. var sub UserSubscription
  891. if err := tx.Where("id = ?", existing.UserSubscriptionId).First(&sub).Error; err != nil {
  892. return err
  893. }
  894. returnValue.UserSubscriptionId = sub.Id
  895. returnValue.PreConsumed = existing.PreConsumed
  896. returnValue.AmountTotal = sub.AmountTotal
  897. returnValue.AmountUsedBefore = sub.AmountUsed
  898. returnValue.AmountUsedAfter = sub.AmountUsed
  899. return nil
  900. }
  901. var subs []UserSubscription
  902. if err := tx.Set("gorm:query_option", "FOR UPDATE").
  903. Where("user_id = ? AND status = ? AND end_time > ?", userId, "active", now).
  904. Order("end_time asc, id asc").
  905. Find(&subs).Error; err != nil {
  906. return errors.New("no active subscription")
  907. }
  908. if len(subs) == 0 {
  909. return errors.New("no active subscription")
  910. }
  911. for _, candidate := range subs {
  912. sub := candidate
  913. plan, err := getSubscriptionPlanByIdTx(tx, sub.PlanId)
  914. if err != nil {
  915. return err
  916. }
  917. if err := maybeResetUserSubscriptionWithPlanTx(tx, &sub, plan, now); err != nil {
  918. return err
  919. }
  920. usedBefore := sub.AmountUsed
  921. if sub.AmountTotal > 0 {
  922. remain := sub.AmountTotal - usedBefore
  923. if remain < amount {
  924. continue
  925. }
  926. }
  927. record := &SubscriptionPreConsumeRecord{
  928. RequestId: requestId,
  929. UserId: userId,
  930. UserSubscriptionId: sub.Id,
  931. PreConsumed: amount,
  932. Status: "consumed",
  933. }
  934. if err := tx.Create(record).Error; err != nil {
  935. var dup SubscriptionPreConsumeRecord
  936. if err2 := tx.Where("request_id = ?", requestId).First(&dup).Error; err2 == nil {
  937. if dup.Status == "refunded" {
  938. return errors.New("subscription pre-consume already refunded")
  939. }
  940. returnValue.UserSubscriptionId = sub.Id
  941. returnValue.PreConsumed = dup.PreConsumed
  942. returnValue.AmountTotal = sub.AmountTotal
  943. returnValue.AmountUsedBefore = sub.AmountUsed
  944. returnValue.AmountUsedAfter = sub.AmountUsed
  945. return nil
  946. }
  947. return err
  948. }
  949. sub.AmountUsed += amount
  950. if err := tx.Save(&sub).Error; err != nil {
  951. return err
  952. }
  953. returnValue.UserSubscriptionId = sub.Id
  954. returnValue.PreConsumed = amount
  955. returnValue.AmountTotal = sub.AmountTotal
  956. returnValue.AmountUsedBefore = usedBefore
  957. returnValue.AmountUsedAfter = sub.AmountUsed
  958. return nil
  959. }
  960. return fmt.Errorf("subscription quota insufficient, need=%d", amount)
  961. })
  962. if err != nil {
  963. return nil, err
  964. }
  965. return returnValue, nil
  966. }
  967. // RefundSubscriptionPreConsume is idempotent and refunds pre-consumed subscription quota by requestId.
  968. func RefundSubscriptionPreConsume(requestId string) error {
  969. if strings.TrimSpace(requestId) == "" {
  970. return errors.New("requestId is empty")
  971. }
  972. return DB.Transaction(func(tx *gorm.DB) error {
  973. var record SubscriptionPreConsumeRecord
  974. if err := tx.Set("gorm:query_option", "FOR UPDATE").
  975. Where("request_id = ?", requestId).First(&record).Error; err != nil {
  976. return err
  977. }
  978. if record.Status == "refunded" {
  979. return nil
  980. }
  981. if record.PreConsumed <= 0 {
  982. record.Status = "refunded"
  983. return tx.Save(&record).Error
  984. }
  985. if err := PostConsumeUserSubscriptionDelta(record.UserSubscriptionId, -record.PreConsumed); err != nil {
  986. return err
  987. }
  988. record.Status = "refunded"
  989. return tx.Save(&record).Error
  990. })
  991. }
  992. // ResetDueSubscriptions resets subscriptions whose next_reset_time has passed.
  993. func ResetDueSubscriptions(limit int) (int, error) {
  994. if limit <= 0 {
  995. limit = 200
  996. }
  997. now := GetDBTimestamp()
  998. var subs []UserSubscription
  999. if err := DB.Where("next_reset_time > 0 AND next_reset_time <= ? AND status = ?", now, "active").
  1000. Order("next_reset_time asc").
  1001. Limit(limit).
  1002. Find(&subs).Error; err != nil {
  1003. return 0, err
  1004. }
  1005. if len(subs) == 0 {
  1006. return 0, nil
  1007. }
  1008. resetCount := 0
  1009. for _, sub := range subs {
  1010. subCopy := sub
  1011. plan, err := getSubscriptionPlanByIdTx(nil, sub.PlanId)
  1012. if err != nil || plan == nil {
  1013. continue
  1014. }
  1015. err = DB.Transaction(func(tx *gorm.DB) error {
  1016. var locked UserSubscription
  1017. if err := tx.Set("gorm:query_option", "FOR UPDATE").
  1018. Where("id = ? AND next_reset_time > 0 AND next_reset_time <= ?", subCopy.Id, now).
  1019. First(&locked).Error; err != nil {
  1020. return nil
  1021. }
  1022. if err := maybeResetUserSubscriptionWithPlanTx(tx, &locked, plan, now); err != nil {
  1023. return err
  1024. }
  1025. resetCount++
  1026. return nil
  1027. })
  1028. if err != nil {
  1029. return resetCount, err
  1030. }
  1031. }
  1032. return resetCount, nil
  1033. }
  1034. // CleanupSubscriptionPreConsumeRecords removes old idempotency records to keep table small.
  1035. func CleanupSubscriptionPreConsumeRecords(olderThanSeconds int64) (int64, error) {
  1036. if olderThanSeconds <= 0 {
  1037. olderThanSeconds = 7 * 24 * 3600
  1038. }
  1039. cutoff := GetDBTimestamp() - olderThanSeconds
  1040. res := DB.Where("updated_at < ?", cutoff).Delete(&SubscriptionPreConsumeRecord{})
  1041. return res.RowsAffected, res.Error
  1042. }
// SubscriptionPlanInfo is a small summary of the plan behind a user subscription.
// It is returned by GetSubscriptionPlanInfoByUserSubscriptionId and stored by
// value in the subscription plan-info cache.
type SubscriptionPlanInfo struct {
	PlanId    int    // id of the subscription's plan
	PlanTitle string // the plan's Title when the info was built
}
  1047. func GetSubscriptionPlanInfoByUserSubscriptionId(userSubscriptionId int) (*SubscriptionPlanInfo, error) {
  1048. if userSubscriptionId <= 0 {
  1049. return nil, errors.New("invalid userSubscriptionId")
  1050. }
  1051. cacheKey := fmt.Sprintf("sub:%d", userSubscriptionId)
  1052. if cached, found, err := getSubscriptionPlanInfoCache().Get(cacheKey); err == nil && found {
  1053. return &cached, nil
  1054. }
  1055. var sub UserSubscription
  1056. if err := DB.Where("id = ?", userSubscriptionId).First(&sub).Error; err != nil {
  1057. return nil, err
  1058. }
  1059. plan, err := getSubscriptionPlanByIdTx(nil, sub.PlanId)
  1060. if err != nil {
  1061. return nil, err
  1062. }
  1063. info := &SubscriptionPlanInfo{
  1064. PlanId: sub.PlanId,
  1065. PlanTitle: plan.Title,
  1066. }
  1067. _ = getSubscriptionPlanInfoCache().SetWithTTL(cacheKey, *info, subscriptionPlanInfoCacheTTL())
  1068. return info, nil
  1069. }
  1070. // Update subscription used amount by delta (positive consume more, negative refund).
  1071. func PostConsumeUserSubscriptionDelta(userSubscriptionId int, delta int64) error {
  1072. if userSubscriptionId <= 0 {
  1073. return errors.New("invalid userSubscriptionId")
  1074. }
  1075. if delta == 0 {
  1076. return nil
  1077. }
  1078. return DB.Transaction(func(tx *gorm.DB) error {
  1079. var sub UserSubscription
  1080. if err := tx.Set("gorm:query_option", "FOR UPDATE").
  1081. Where("id = ?", userSubscriptionId).
  1082. First(&sub).Error; err != nil {
  1083. return err
  1084. }
  1085. newUsed := sub.AmountUsed + delta
  1086. if newUsed < 0 {
  1087. newUsed = 0
  1088. }
  1089. if sub.AmountTotal > 0 && newUsed > sub.AmountTotal {
  1090. return fmt.Errorf("subscription used exceeds total, used=%d total=%d", newUsed, sub.AmountTotal)
  1091. }
  1092. sub.AmountUsed = newUsed
  1093. return tx.Save(&sub).Error
  1094. })
  1095. }