token.go 25 KB


  1. package model
  2. import (
  3. "errors"
  4. "fmt"
  5. "math/rand/v2"
  6. "strings"
  7. "time"
  8. "github.com/labring/aiproxy/core/common"
  9. "github.com/labring/aiproxy/core/common/config"
  10. "github.com/labring/aiproxy/core/common/conv"
  11. log "github.com/sirupsen/logrus"
  12. "gorm.io/gorm"
  13. "gorm.io/gorm/clause"
  14. )
  15. const (
  16. ErrTokenNotFound = "token"
  17. )
  18. const (
  19. PeriodTypeDaily = "daily"
  20. PeriodTypeWeekly = "weekly"
  21. PeriodTypeMonthly = "monthly"
  22. )
  23. const (
  24. TokenStatusEnabled = 1
  25. TokenStatusDisabled = 2
  26. )
// Token is an API key owned by a group. It carries optional subnet and model
// restrictions plus two spending limits: a lifetime Quota on UsedAmount and a
// PeriodQuota on the spend accumulated within the current period
// (UsedAmount - PeriodLastUpdateAmount). A quota of 0 disables that limit.
type Token struct {
	// CreatedAt is the creation timestamp.
	CreatedAt time.Time `json:"created_at"`
	// Group is the owning group, joined through GroupID; never serialized.
	Group *Group `json:"-" gorm:"foreignKey:GroupID"`
	// Key is the secret API key. BeforeCreate generates a random
	// 48-character key when it is empty or not exactly 48 bytes long.
	Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
	// Name is unique within a group (idx_group_name); BeforeSave rejects
	// over-long names.
	Name EmptyNullString `json:"name" gorm:"size:32;index;uniqueIndex:idx_group_name;not null"`
	// GroupID is the owning group's ID, exposed in JSON as "group".
	GroupID string `json:"group" gorm:"size:64;index;uniqueIndex:idx_group_name"`
	// Subnets presumably lists the client networks allowed to use the token;
	// enforcement is not in this file — confirm against the caller.
	Subnets []string `json:"subnets" gorm:"serializer:fastjson;type:text"`
	// Models presumably lists the models the token may call (enforced
	// elsewhere); SearchTokens also keyword-matches against this column.
	Models []string `json:"models" gorm:"serializer:fastjson;type:text"`
	// Status is TokenStatusEnabled or TokenStatusDisabled; the column
	// defaults to 1 (enabled).
	Status int `json:"status" gorm:"default:1;index"`
	// ID is the primary key.
	ID int `json:"id" gorm:"primaryKey"`
	// UsedAmount is the cumulative spend, incremented by UpdateTokenUsedAmount.
	UsedAmount float64 `json:"used_amount" gorm:"index"`
	// RequestCount is the cumulative request count, incremented by
	// UpdateTokenUsedAmount.
	RequestCount int `json:"request_count" gorm:"index"`
	// Quota is the lifetime limit on UsedAmount; 0 or less means unlimited.
	Quota float64 `json:"quota"`
	// PeriodQuota is the limit on spend within the current period; 0 or less
	// means unlimited.
	PeriodQuota float64 `json:"period_quota"`
	// PeriodType selects the period length; an empty value is treated as monthly.
	PeriodType EmptyNullString `json:"period_type" gorm:"size:20"` // daily, weekly, monthly, default is monthly
	// PeriodLastUpdateTime is the start of the current period; the zero time
	// means the period has never been initialized.
	PeriodLastUpdateTime time.Time `json:"period_last_update_time"` // Last time period was reset
	// PeriodLastUpdateAmount is the UsedAmount snapshot taken when the current
	// period started.
	PeriodLastUpdateAmount float64 `json:"period_last_update_amount"` // Total usage at last period reset
}
  45. func (t *Token) BeforeCreate(_ *gorm.DB) error {
  46. if t.Key == "" || len(t.Key) != 48 {
  47. t.Key = generateKey()
  48. }
  49. return nil
  50. }
  51. func (t *Token) BeforeSave(_ *gorm.DB) error {
  52. if len(t.Name) > 32 {
  53. return errors.New("token name is too long")
  54. }
  55. return nil
  56. }
  57. // GetEffectiveQuotaStatus returns the effective quota status for token
  58. func (t *Token) GetEffectiveQuotaStatus() (totalExceeded, periodExceeded bool, err error) {
  59. // Check total quota (if set)
  60. if t.Quota > 0 && t.UsedAmount >= t.Quota {
  61. totalExceeded = true
  62. }
  63. if t.PeriodQuota > 0 {
  64. // Check if we need to reset period usage
  65. if needsReset, err := t.NeedsPeriodReset(); err != nil {
  66. return false, false, err
  67. } else if needsReset {
  68. // Period usage should be considered as reset (0) but we don't modify the struct here
  69. // The actual database reset should be handled separately
  70. periodExceeded = false // Consider as reset, so no period limit exceeded
  71. // Trigger async period reset - don't wait for it to complete
  72. go func() {
  73. if err := ResetTokenPeriodUsage(t.ID); err != nil {
  74. log.Error("failed to reset token period usage: " + err.Error())
  75. }
  76. }()
  77. } else {
  78. // Period is still valid, check against current usage
  79. // Calculate period usage: current total - last recorded total at period reset
  80. periodUsage := t.UsedAmount - t.PeriodLastUpdateAmount
  81. if periodUsage >= t.PeriodQuota {
  82. periodExceeded = true
  83. }
  84. }
  85. }
  86. return totalExceeded, periodExceeded, nil
  87. }
  88. // NeedsPeriodReset checks if the period usage should be reset
  89. // Uses PeriodLastUpdateTime to determine when the last period reset occurred
  90. func (t *Token) NeedsPeriodReset() (bool, error) {
  91. // If never been reset, use PeriodStartTime as baseline
  92. baseTime := t.PeriodLastUpdateTime
  93. if baseTime.IsZero() {
  94. return true, nil // Never initialized
  95. }
  96. now := time.Now()
  97. switch t.PeriodType {
  98. case "", PeriodTypeMonthly:
  99. // Check if we've crossed a month boundary since last reset
  100. baseMonth := baseTime.Month()
  101. baseYear := baseTime.Year()
  102. currentMonth := now.Month()
  103. currentYear := now.Year()
  104. if currentYear > baseYear {
  105. return true, nil
  106. }
  107. if currentYear == baseYear && currentMonth > baseMonth {
  108. return true, nil
  109. }
  110. return false, nil
  111. case PeriodTypeWeekly:
  112. // Check if we've passed 7 days since last reset
  113. return now.Sub(baseTime) >= 7*24*time.Hour, nil
  114. case PeriodTypeDaily:
  115. // Check if we've crossed to a new day since last reset
  116. baseDate := baseTime.Truncate(24 * time.Hour)
  117. currentDate := now.Truncate(24 * time.Hour)
  118. return currentDate.After(baseDate), nil
  119. default:
  120. return false, fmt.Errorf("unknown period type: %s", t.PeriodType)
  121. }
  122. }
  123. const (
  124. keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
  125. )
  126. func generateKey() string {
  127. key := make([]byte, 48)
  128. for i := range key {
  129. key[i] = keyChars[rand.IntN(len(keyChars))]
  130. }
  131. return conv.BytesToString(key)
  132. }
  133. func getTokenOrder(order string) string {
  134. prefix, suffix, _ := strings.Cut(order, "-")
  135. switch prefix {
  136. case "name", "expired_at", "group", "used_amount", "request_count", "id", "created_at":
  137. switch suffix {
  138. case "asc":
  139. return prefix + " asc"
  140. default:
  141. return prefix + " desc"
  142. }
  143. default:
  144. return "id desc"
  145. }
  146. }
  147. func InsertToken(token *Token, autoCreateGroup, ignoreExist bool) error {
  148. if autoCreateGroup {
  149. group := &Group{
  150. ID: token.GroupID,
  151. }
  152. if err := OnConflictDoNothing().Create(group).Error; err != nil {
  153. return err
  154. }
  155. }
  156. maxTokenNum := config.GetGroupMaxTokenNum()
  157. err := DB.Transaction(func(tx *gorm.DB) error {
  158. if maxTokenNum > 0 {
  159. var count int64
  160. err := tx.Model(&Token{}).Where("group_id = ?", token.GroupID).Count(&count).Error
  161. if err != nil {
  162. return err
  163. }
  164. if count >= maxTokenNum {
  165. return errors.New("group max token num reached")
  166. }
  167. }
  168. if ignoreExist {
  169. return tx.
  170. Where("group_id = ? and name = ?", token.GroupID, token.Name).
  171. FirstOrCreate(token).Error
  172. }
  173. return tx.Create(token).Error
  174. })
  175. if err != nil {
  176. if errors.Is(err, gorm.ErrDuplicatedKey) {
  177. if ignoreExist {
  178. return nil
  179. }
  180. return errors.New("token name already exists in this group")
  181. }
  182. return err
  183. }
  184. return nil
  185. }
  186. func GetTokens(
  187. group string,
  188. page, perPage int,
  189. order string,
  190. status int,
  191. ) (tokens []*Token, total int64, err error) {
  192. tx := DB.Model(&Token{})
  193. if group != "" {
  194. tx = tx.Where("group_id = ?", group)
  195. }
  196. if status != 0 {
  197. tx = tx.Where("status = ?", status)
  198. }
  199. err = tx.Count(&total).Error
  200. if err != nil {
  201. return nil, 0, err
  202. }
  203. if total <= 0 {
  204. return nil, 0, nil
  205. }
  206. limit, offset := toLimitOffset(page, perPage)
  207. err = tx.Order(getTokenOrder(order)).Limit(limit).Offset(offset).Find(&tokens).Error
  208. return tokens, total, err
  209. }
  210. func SearchTokens(
  211. group, keyword string,
  212. page, perPage int,
  213. order string,
  214. status int,
  215. name, key string,
  216. ) (tokens []*Token, total int64, err error) {
  217. tx := DB.Model(&Token{})
  218. if group != "" {
  219. tx = tx.Where("group_id = ?", group)
  220. }
  221. if status != 0 {
  222. tx = tx.Where("status = ?", status)
  223. }
  224. if name != "" {
  225. tx = tx.Where("name = ?", name)
  226. }
  227. if key != "" {
  228. tx = tx.Where("key = ?", key)
  229. }
  230. if keyword != "" {
  231. var (
  232. conditions []string
  233. values []any
  234. )
  235. if group == "" {
  236. if !common.UsingSQLite {
  237. conditions = append(conditions, "group_id ILIKE ?")
  238. } else {
  239. conditions = append(conditions, "group_id LIKE ?")
  240. }
  241. values = append(values, "%"+keyword+"%")
  242. }
  243. if name == "" {
  244. if !common.UsingSQLite {
  245. conditions = append(conditions, "name ILIKE ?")
  246. } else {
  247. conditions = append(conditions, "name LIKE ?")
  248. }
  249. values = append(values, "%"+keyword+"%")
  250. }
  251. if key == "" {
  252. if !common.UsingSQLite {
  253. conditions = append(conditions, "key ILIKE ?")
  254. } else {
  255. conditions = append(conditions, "key LIKE ?")
  256. }
  257. values = append(values, "%"+keyword+"%")
  258. }
  259. if !common.UsingSQLite {
  260. conditions = append(conditions, "models ILIKE ?")
  261. } else {
  262. conditions = append(conditions, "models LIKE ?")
  263. }
  264. values = append(values, "%"+keyword+"%")
  265. if len(conditions) > 0 {
  266. tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...)
  267. }
  268. }
  269. err = tx.Count(&total).Error
  270. if err != nil {
  271. return nil, 0, err
  272. }
  273. if total <= 0 {
  274. return nil, 0, nil
  275. }
  276. limit, offset := toLimitOffset(page, perPage)
  277. err = tx.Order(getTokenOrder(order)).Limit(limit).Offset(offset).Find(&tokens).Error
  278. return tokens, total, err
  279. }
  280. func SearchGroupTokens(
  281. group, keyword string,
  282. page, perPage int,
  283. order string,
  284. status int,
  285. name, key string,
  286. ) (tokens []*Token, total int64, err error) {
  287. if group == "" {
  288. return nil, 0, errors.New("group is empty")
  289. }
  290. tx := DB.Model(&Token{}).
  291. Where("group_id = ?", group)
  292. if name != "" {
  293. tx = tx.Where("name = ?", name)
  294. }
  295. if key != "" {
  296. tx = tx.Where("key = ?", key)
  297. }
  298. if status != 0 {
  299. tx = tx.Where("status = ?", status)
  300. }
  301. if keyword != "" {
  302. var (
  303. conditions []string
  304. values []any
  305. )
  306. if name == "" {
  307. if !common.UsingSQLite {
  308. conditions = append(conditions, "name ILIKE ?")
  309. } else {
  310. conditions = append(conditions, "name LIKE ?")
  311. }
  312. values = append(values, "%"+keyword+"%")
  313. }
  314. if key == "" {
  315. if !common.UsingSQLite {
  316. conditions = append(conditions, "key ILIKE ?")
  317. } else {
  318. conditions = append(conditions, "key LIKE ?")
  319. }
  320. values = append(values, "%"+keyword+"%")
  321. }
  322. if !common.UsingSQLite {
  323. conditions = append(conditions, "models ILIKE ?")
  324. } else {
  325. conditions = append(conditions, "models LIKE ?")
  326. }
  327. values = append(values, "%"+keyword+"%")
  328. if len(conditions) > 0 {
  329. tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...)
  330. }
  331. }
  332. err = tx.Count(&total).Error
  333. if err != nil {
  334. return nil, 0, err
  335. }
  336. if total <= 0 {
  337. return nil, 0, nil
  338. }
  339. limit, offset := toLimitOffset(page, perPage)
  340. err = tx.Order(getTokenOrder(order)).Limit(limit).Offset(offset).Find(&tokens).Error
  341. return tokens, total, err
  342. }
  343. func GetTokenByKey(key string) (*Token, error) {
  344. if key == "" {
  345. return nil, errors.New("key is empty")
  346. }
  347. var token Token
  348. err := DB.Where("key = ?", key).First(&token).Error
  349. return &token, HandleNotFound(err, ErrTokenNotFound)
  350. }
  351. // GetAndValidateToken validates a token and checks quota limits
  352. // This function is safe for concurrent use and handles period resets atomically
  353. func GetAndValidateToken(key string) (token *TokenCache, err error) {
  354. if key == "" {
  355. return nil, errors.New("no token provided")
  356. }
  357. token, err = CacheGetTokenByKey(key)
  358. if err != nil {
  359. if errors.Is(err, gorm.ErrRecordNotFound) {
  360. return nil, errors.New("invalid token")
  361. }
  362. log.Error("get token from cache failed: " + err.Error())
  363. return nil, errors.New("token validation failed")
  364. }
  365. if token.Status == TokenStatusDisabled {
  366. return nil, fmt.Errorf("token (%s[%d]) is disabled", token.Name, token.ID)
  367. }
  368. // Convert TokenCache to Token for quota checking
  369. tokenModel := Token{
  370. ID: token.ID,
  371. Quota: token.Quota,
  372. UsedAmount: token.UsedAmount,
  373. PeriodQuota: token.PeriodQuota,
  374. PeriodType: EmptyNullString(token.PeriodType),
  375. PeriodLastUpdateTime: time.Time(token.PeriodLastUpdateTime),
  376. PeriodLastUpdateAmount: token.PeriodLastUpdateAmount,
  377. }
  378. totalExceeded, periodExceeded, err := tokenModel.GetEffectiveQuotaStatus()
  379. if err != nil {
  380. return nil, fmt.Errorf("token (%s[%d]) quota check failed: %w", token.Name, token.ID, err)
  381. }
  382. if totalExceeded {
  383. return nil, fmt.Errorf("token (%s[%d]) total quota is exhausted", token.Name, token.ID)
  384. }
  385. if periodExceeded {
  386. return nil, fmt.Errorf("token (%s[%d]) period quota is exhausted", token.Name, token.ID)
  387. }
  388. return token, nil
  389. }
  390. func GetGroupTokenByID(group string, id int) (*Token, error) {
  391. if id == 0 || group == "" {
  392. return nil, errors.New("id or group is empty")
  393. }
  394. token := Token{}
  395. err := DB.
  396. Where("id = ? and group_id = ?", id, group).
  397. First(&token).Error
  398. return &token, HandleNotFound(err, ErrTokenNotFound)
  399. }
  400. func GetTokenByID(id int) (*Token, error) {
  401. if id == 0 {
  402. return nil, errors.New("id is empty")
  403. }
  404. token := Token{ID: id}
  405. err := DB.First(&token, "id = ?", id).Error
  406. return &token, HandleNotFound(err, ErrTokenNotFound)
  407. }
  408. func UpdateTokenStatus(id, status int) (err error) {
  409. token := Token{ID: id}
  410. defer func() {
  411. if err == nil {
  412. if err := CacheUpdateTokenStatus(token.Key, status); err != nil {
  413. log.Error("update token status in cache failed: " + err.Error())
  414. }
  415. }
  416. }()
  417. result := DB.
  418. Model(&token).
  419. Clauses(clause.Returning{
  420. Columns: []clause.Column{
  421. {Name: "key"},
  422. },
  423. }).
  424. Where("id = ?", id).
  425. Updates(
  426. map[string]any{
  427. "status": status,
  428. },
  429. )
  430. return HandleUpdateResult(result, ErrTokenNotFound)
  431. }
  432. func UpdateGroupTokenStatus(group string, id, status int) (err error) {
  433. if id == 0 || group == "" {
  434. return errors.New("id or group is empty")
  435. }
  436. token := Token{}
  437. defer func() {
  438. if err == nil {
  439. if err := CacheUpdateTokenStatus(token.Key, status); err != nil {
  440. log.Error("update token status in cache failed: " + err.Error())
  441. }
  442. }
  443. }()
  444. result := DB.
  445. Model(&token).
  446. Clauses(clause.Returning{
  447. Columns: []clause.Column{
  448. {Name: "key"},
  449. },
  450. }).
  451. Where("id = ? and group_id = ?", id, group).
  452. Updates(
  453. map[string]any{
  454. "status": status,
  455. },
  456. )
  457. return HandleUpdateResult(result, ErrTokenNotFound)
  458. }
  459. func DeleteGroupTokenByID(groupID string, id int) (err error) {
  460. if id == 0 || groupID == "" {
  461. return errors.New("id or group is empty")
  462. }
  463. token := Token{ID: id, GroupID: groupID}
  464. defer func() {
  465. if err == nil {
  466. if err := CacheDeleteToken(token.Key); err != nil {
  467. log.Error("delete token from cache failed: " + err.Error())
  468. }
  469. }
  470. }()
  471. result := DB.
  472. Clauses(clause.Returning{
  473. Columns: []clause.Column{
  474. {Name: "key"},
  475. },
  476. }).
  477. Where(token).
  478. Delete(&token)
  479. return HandleUpdateResult(result, ErrTokenNotFound)
  480. }
  481. func DeleteGroupTokensByIDs(group string, ids []int) (err error) {
  482. if group == "" {
  483. return errors.New("group is empty")
  484. }
  485. if len(ids) == 0 {
  486. return nil
  487. }
  488. tokens := make([]Token, len(ids))
  489. defer func() {
  490. if err == nil {
  491. for _, token := range tokens {
  492. if err := CacheDeleteToken(token.Key); err != nil {
  493. log.Error("delete token from cache failed: " + err.Error())
  494. }
  495. }
  496. }
  497. }()
  498. return DB.Transaction(func(tx *gorm.DB) error {
  499. return tx.
  500. Clauses(clause.Returning{
  501. Columns: []clause.Column{
  502. {Name: "key"},
  503. },
  504. }).
  505. Where("group_id = ?", group).
  506. Where("id IN (?)", ids).
  507. Delete(&tokens).
  508. Error
  509. })
  510. }
  511. func DeleteTokenByID(id int) (err error) {
  512. if id == 0 {
  513. return errors.New("id is empty")
  514. }
  515. token := Token{ID: id}
  516. defer func() {
  517. if err == nil {
  518. if err := CacheDeleteToken(token.Key); err != nil {
  519. log.Error("delete token from cache failed: " + err.Error())
  520. }
  521. }
  522. }()
  523. result := DB.
  524. Clauses(clause.Returning{
  525. Columns: []clause.Column{
  526. {Name: "key"},
  527. },
  528. }).
  529. Where(token).
  530. Delete(&token)
  531. return HandleUpdateResult(result, ErrTokenNotFound)
  532. }
  533. func DeleteTokensByIDs(ids []int) (err error) {
  534. if len(ids) == 0 {
  535. return nil
  536. }
  537. tokens := make([]Token, len(ids))
  538. defer func() {
  539. if err == nil {
  540. for _, token := range tokens {
  541. if err := CacheDeleteToken(token.Key); err != nil {
  542. log.Error("delete token from cache failed: " + err.Error())
  543. }
  544. }
  545. }
  546. }()
  547. return DB.Transaction(func(tx *gorm.DB) error {
  548. return tx.
  549. Clauses(clause.Returning{
  550. Columns: []clause.Column{
  551. {Name: "key"},
  552. },
  553. }).
  554. Where("id IN (?)", ids).
  555. Delete(&tokens).
  556. Error
  557. })
  558. }
// UpdateTokenRequest is a partial update for a token, consumed by UpdateToken
// and UpdateGroupToken. A nil pointer field leaves that column unchanged.
type UpdateTokenRequest struct {
	// Name is the new token name; nil or "" leaves the name unchanged.
	Name *string `json:"name"`
	// Subnets replaces the token's subnet list when non-nil.
	Subnets *[]string `json:"subnets"`
	// Models replaces the token's model list when non-nil.
	Models *[]string `json:"models"`
	// Status is the new status; 0 leaves the status unchanged.
	Status int `json:"status"`
	// Quota system
	// Quota replaces the lifetime quota when non-nil.
	Quota *float64 `json:"quota"`
	// PeriodQuota replaces the per-period quota when non-nil.
	PeriodQuota *float64 `json:"period_quota"`
	// PeriodType replaces the period granularity when non-nil.
	PeriodType *string `json:"period_type"`
	// PeriodLastUpdateTime overrides the current period start when non-nil,
	// given in Unix milliseconds.
	PeriodLastUpdateTime *int64 `json:"period_last_update_time"`
}
  570. func UpdateToken(id int, update UpdateTokenRequest) (token *Token, err error) {
  571. if id == 0 {
  572. return nil, errors.New("id is empty")
  573. }
  574. token = &Token{
  575. ID: id,
  576. Status: update.Status,
  577. }
  578. defer func() {
  579. if err == nil {
  580. if err := CacheDeleteToken(token.Key); err != nil {
  581. log.Error("delete token from cache failed: " + err.Error())
  582. }
  583. }
  584. }()
  585. selects := []string{}
  586. if update.Name != nil && *update.Name != "" {
  587. token.Name = EmptyNullString(*update.Name)
  588. selects = append(selects, "name")
  589. }
  590. if update.Quota != nil {
  591. token.Quota = *update.Quota
  592. selects = append(selects, "quota")
  593. }
  594. if update.PeriodQuota != nil {
  595. token.PeriodQuota = *update.PeriodQuota
  596. selects = append(selects, "period_quota")
  597. }
  598. if update.PeriodType != nil {
  599. token.PeriodType = EmptyNullString(*update.PeriodType)
  600. selects = append(selects, "period_type")
  601. }
  602. if update.PeriodLastUpdateTime != nil {
  603. token.PeriodLastUpdateTime = time.UnixMilli(*update.PeriodLastUpdateTime)
  604. selects = append(selects, "period_last_update_time")
  605. }
  606. if update.Subnets != nil {
  607. token.Subnets = *update.Subnets
  608. selects = append(selects, "subnets")
  609. }
  610. if update.Models != nil {
  611. token.Models = *update.Models
  612. selects = append(selects, "models")
  613. }
  614. if update.Status != 0 {
  615. selects = append(selects, "status")
  616. }
  617. if len(selects) == 0 {
  618. return nil, errors.New("empty update request")
  619. }
  620. result := DB.
  621. Select(selects).
  622. Where("id = ?", id).
  623. Clauses(clause.Returning{}).
  624. Updates(token)
  625. if result.Error != nil {
  626. if errors.Is(result.Error, gorm.ErrDuplicatedKey) {
  627. return nil, errors.New("token name already exists in this group")
  628. }
  629. }
  630. return token, HandleUpdateResult(result, ErrTokenNotFound)
  631. }
  632. func UpdateGroupToken(
  633. id int,
  634. group string,
  635. update UpdateTokenRequest,
  636. ) (token *Token, err error) {
  637. if id == 0 || group == "" {
  638. return nil, errors.New("id or group is empty")
  639. }
  640. token = &Token{
  641. ID: id,
  642. GroupID: group,
  643. Status: update.Status,
  644. }
  645. defer func() {
  646. if err == nil {
  647. if err := CacheDeleteToken(token.Key); err != nil {
  648. log.Error("delete token from cache failed: " + err.Error())
  649. }
  650. }
  651. }()
  652. selects := []string{}
  653. if update.Name != nil && *update.Name != "" {
  654. token.Name = EmptyNullString(*update.Name)
  655. selects = append(selects, "name")
  656. }
  657. if update.Quota != nil {
  658. token.Quota = *update.Quota
  659. selects = append(selects, "quota")
  660. }
  661. if update.PeriodQuota != nil {
  662. token.PeriodQuota = *update.PeriodQuota
  663. selects = append(selects, "period_quota")
  664. }
  665. if update.PeriodType != nil {
  666. token.PeriodType = EmptyNullString(*update.PeriodType)
  667. selects = append(selects, "period_type")
  668. }
  669. if update.PeriodLastUpdateTime != nil {
  670. token.PeriodLastUpdateTime = time.UnixMilli(*update.PeriodLastUpdateTime)
  671. selects = append(selects, "period_last_update_time")
  672. }
  673. if update.Subnets != nil {
  674. token.Subnets = *update.Subnets
  675. selects = append(selects, "subnets")
  676. }
  677. if update.Models != nil {
  678. token.Models = *update.Models
  679. selects = append(selects, "models")
  680. }
  681. if update.Status != 0 {
  682. selects = append(selects, "status")
  683. }
  684. if len(selects) == 0 {
  685. return nil, errors.New("empty update request")
  686. }
  687. result := DB.
  688. Select(selects).
  689. Where("id = ? and group_id = ?", id, group).
  690. Clauses(clause.Returning{}).
  691. Updates(token)
  692. if result.Error != nil {
  693. if errors.Is(result.Error, gorm.ErrDuplicatedKey) {
  694. return nil, errors.New("token name already exists in this group")
  695. }
  696. }
  697. return token, HandleUpdateResult(result, ErrTokenNotFound)
  698. }
  699. func UpdateTokenUsedAmount(id int, amount float64, requestCount int) (err error) {
  700. token := &Token{}
  701. defer func() {
  702. if amount > 0 && err == nil && (token.Quota > 0 || token.PeriodQuota > 0) {
  703. if err := CacheUpdateTokenUsedAmountOnlyIncrease(token.Key, token.UsedAmount); err != nil {
  704. log.Error("update token used amount in cache failed: " + err.Error())
  705. }
  706. }
  707. }()
  708. result := DB.
  709. Model(token).
  710. Clauses(clause.Returning{
  711. Columns: []clause.Column{
  712. {Name: "key"},
  713. {Name: "quota"},
  714. {Name: "used_amount"},
  715. {Name: "period_quota"},
  716. },
  717. }).
  718. Where("id = ?", id).
  719. Updates(
  720. map[string]any{
  721. "used_amount": gorm.Expr("used_amount + ?", amount),
  722. "request_count": gorm.Expr("request_count + ?", requestCount),
  723. },
  724. )
  725. return HandleUpdateResult(result, ErrTokenNotFound)
  726. }
  727. // calculateNextPeriodStartTime calculates the next period start time based on the last update time and period type
  728. // This finds the most recent period boundary by incrementing from lastUpdateTime until we reach the current time
  729. // This maintains period continuity - e.g., if reset was on Jan 15, next periods are Feb 15, Mar 15, etc.
  730. func calculateNextPeriodStartTime(lastUpdateTime time.Time, periodType EmptyNullString) time.Time {
  731. if lastUpdateTime.IsZero() {
  732. // If never initialized, return current time
  733. return time.Now()
  734. }
  735. now := time.Now()
  736. // If we haven't passed the period yet, no reset needed
  737. if !now.After(lastUpdateTime) {
  738. return lastUpdateTime
  739. }
  740. switch periodType {
  741. case "", PeriodTypeMonthly:
  742. // Start from lastUpdateTime and keep adding months until we find the most recent period start
  743. nextPeriod := lastUpdateTime
  744. for {
  745. // Calculate next month period
  746. candidate := time.Date(
  747. nextPeriod.Year(),
  748. nextPeriod.Month()+1,
  749. nextPeriod.Day(),
  750. nextPeriod.Hour(),
  751. nextPeriod.Minute(),
  752. nextPeriod.Second(),
  753. nextPeriod.Nanosecond(),
  754. nextPeriod.Location(),
  755. )
  756. // If candidate is in the future, the current nextPeriod is the one we want
  757. if candidate.After(now) {
  758. return nextPeriod
  759. }
  760. nextPeriod = candidate
  761. }
  762. case PeriodTypeWeekly:
  763. // Calculate how many complete weeks have passed since lastUpdateTime
  764. daysSinceLastUpdate := now.Sub(lastUpdateTime).Hours() / 24
  765. weeksPassed := int(daysSinceLastUpdate / 7)
  766. if weeksPassed == 0 {
  767. // Still in the same week period, no reset needed
  768. return lastUpdateTime
  769. }
  770. // Return the start of the most recent week period
  771. // This is lastUpdateTime + (weeksPassed * 7 days)
  772. return lastUpdateTime.Add(time.Duration(weeksPassed*7*24) * time.Hour)
  773. case PeriodTypeDaily:
  774. // Calculate how many complete days have passed since lastUpdateTime
  775. daysSinceLastUpdate := int(now.Sub(lastUpdateTime).Hours() / 24)
  776. if daysSinceLastUpdate == 0 {
  777. // Still in the same day period, no reset needed
  778. return lastUpdateTime
  779. }
  780. // Return the start of the most recent day period
  781. // This is lastUpdateTime + (daysPassed * 1 day)
  782. return lastUpdateTime.Add(time.Duration(daysSinceLastUpdate*24) * time.Hour)
  783. default:
  784. // Fallback to current time for unknown period types
  785. return now
  786. }
  787. }
  788. // ResetTokenPeriodUsage resets the period usage for a token with concurrency safety
  789. // This updates PeriodLastUpdateTime and PeriodLastUpdateAmount to current values
  790. func ResetTokenPeriodUsage(id int) error {
  791. token := &Token{}
  792. var newPeriodStartTime time.Time
  793. // Use database transaction with optimistic locking to prevent concurrent resets
  794. err := DB.Transaction(func(tx *gorm.DB) error {
  795. // First, read the current state with FOR UPDATE lock
  796. if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
  797. Where("id = ?", id).
  798. First(token).Error; err != nil {
  799. return err
  800. }
  801. // Check if period still needs reset (another concurrent request might have already reset it)
  802. needsReset, err := token.NeedsPeriodReset()
  803. if err != nil {
  804. return err
  805. }
  806. // If period no longer needs reset, skip the update
  807. if !needsReset {
  808. return nil
  809. }
  810. // Calculate the correct next period start time based on period type
  811. newPeriodStartTime = calculateNextPeriodStartTime(
  812. token.PeriodLastUpdateTime,
  813. token.PeriodType,
  814. )
  815. if newPeriodStartTime.IsZero() {
  816. return errors.New("next period start time is zero")
  817. }
  818. // Perform the reset with the lock held - update period last update time and amount
  819. result := tx.
  820. Model(token).
  821. Clauses(clause.Returning{
  822. Columns: []clause.Column{
  823. {Name: "key"},
  824. },
  825. }).
  826. Where("id = ?", id).
  827. Updates(
  828. map[string]any{
  829. "period_last_update_time": newPeriodStartTime,
  830. "period_last_update_amount": gorm.Expr(
  831. "used_amount",
  832. ), // Set to current total usage
  833. },
  834. )
  835. return HandleUpdateResult(result, ErrTokenNotFound)
  836. })
  837. // Update cache only if database update succeeded
  838. if err == nil && token.Key != "" && !newPeriodStartTime.IsZero() {
  839. if cacheErr := CacheResetTokenPeriodUsage(token.Key, newPeriodStartTime, token.UsedAmount); cacheErr != nil {
  840. log.Error("reset token period usage in cache failed: " + cacheErr.Error())
  841. }
  842. }
  843. return err
  844. }
  845. func UpdateTokenName(id int, name string) (err error) {
  846. token := &Token{ID: id}
  847. defer func() {
  848. if err == nil {
  849. if err := CacheUpdateTokenName(token.Key, name); err != nil {
  850. log.Error("update token name in cache failed: " + err.Error())
  851. }
  852. }
  853. }()
  854. result := DB.
  855. Model(token).
  856. Clauses(clause.Returning{
  857. Columns: []clause.Column{
  858. {Name: "key"},
  859. },
  860. }).
  861. Where("id = ?", id).
  862. Update("name", name)
  863. if result.Error != nil && errors.Is(result.Error, gorm.ErrDuplicatedKey) {
  864. return errors.New("token name already exists in this group")
  865. }
  866. return HandleUpdateResult(result, ErrTokenNotFound)
  867. }
  868. func UpdateGroupTokenName(group string, id int, name string) (err error) {
  869. token := &Token{ID: id, GroupID: group}
  870. defer func() {
  871. if err == nil {
  872. if err := CacheUpdateTokenName(token.Key, name); err != nil {
  873. log.Error("update token name in cache failed: " + err.Error())
  874. }
  875. }
  876. }()
  877. result := DB.
  878. Model(token).
  879. Clauses(clause.Returning{
  880. Columns: []clause.Column{
  881. {Name: "key"},
  882. },
  883. }).
  884. Where("id = ? and group_id = ?", id, group).
  885. Update("name", name)
  886. if result.Error != nil && errors.Is(result.Error, gorm.ErrDuplicatedKey) {
  887. return errors.New("token name already exists in this group")
  888. }
  889. return HandleUpdateResult(result, ErrTokenNotFound)
  890. }