token.go 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995
  1. package model
import (
	crand "crypto/rand"
	"errors"
	"fmt"
	"math/rand/v2"
	"strings"
	"time"

	"github.com/labring/aiproxy/core/common"
	"github.com/labring/aiproxy/core/common/config"
	"github.com/labring/aiproxy/core/common/conv"
	log "github.com/sirupsen/logrus"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)
  15. const (
  16. ErrTokenNotFound = "token"
  17. )
  18. const (
  19. PeriodTypeDaily = "daily"
  20. PeriodTypeWeekly = "weekly"
  21. PeriodTypeMonthly = "monthly"
  22. )
  23. const (
  24. TokenStatusEnabled = 1
  25. TokenStatusDisabled = 2
  26. )
// Token is an API key issued to a group. Besides the secret Key it carries
// optional Subnets/Models lists, a Status, cumulative usage counters, and an
// optional lifetime quota plus a quota that is reset every period.
type Token struct {
	CreatedAt time.Time `json:"created_at"`
	// Group is the owning group, joined through GroupID; never serialized to JSON.
	Group *Group `json:"-" gorm:"foreignKey:GroupID"`
	// Key is the secret token value (unique). BeforeCreate replaces it with a
	// freshly generated key unless it is exactly 48 characters long.
	Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
	// Name is unique within a group (composite index idx_group_name together
	// with GroupID); BeforeSave rejects names longer than 32.
	Name EmptyNullString `json:"name" gorm:"size:32;index;uniqueIndex:idx_group_name;not null"`
	// GroupID identifies the owning group; it is serialized as "group".
	GroupID string `json:"group" gorm:"size:64;index;uniqueIndex:idx_group_name"`
	// Subnets and Models are stored as JSON text via the fastjson serializer.
	// NOTE(review): presumably allow-lists for client subnets / models; the
	// enforcement is not in this file — confirm.
	Subnets []string `json:"subnets" gorm:"serializer:fastjson;type:text"`
	Models []string `json:"models" gorm:"serializer:fastjson;type:text"`
	// Status is TokenStatusEnabled (1, the column default) or TokenStatusDisabled (2).
	Status int `json:"status" gorm:"default:1;index"`
	ID int `json:"id" gorm:"primaryKey"`
	// UsedAmount and RequestCount are cumulative counters, incremented by
	// UpdateTokenUsedAmount.
	UsedAmount float64 `json:"used_amount" gorm:"index"`
	RequestCount int `json:"request_count" gorm:"index"`
	// Quota caps UsedAmount over the token's lifetime; a value <= 0 means no cap.
	Quota float64 `json:"quota"`
	// PeriodQuota caps the usage within the current period
	// (UsedAmount - PeriodLastUpdateAmount); a value <= 0 disables the check.
	PeriodQuota float64 `json:"period_quota"`
	PeriodType EmptyNullString `json:"period_type" gorm:"size:20"` // daily, weekly, monthly, default is monthly
	PeriodLastUpdateTime time.Time `json:"period_last_update_time"` // Last time period was reset
	PeriodLastUpdateAmount float64 `json:"period_last_update_amount"` // Total usage at last period reset
}
  45. func (t *Token) BeforeCreate(_ *gorm.DB) error {
  46. if t.Key == "" || len(t.Key) != 48 {
  47. t.Key = generateKey()
  48. }
  49. return nil
  50. }
  51. func (t *Token) BeforeSave(_ *gorm.DB) error {
  52. if len(t.Name) > 32 {
  53. return errors.New("token name is too long")
  54. }
  55. return nil
  56. }
  57. // GetEffectiveQuotaStatus returns the effective quota status for token
  58. func (t *Token) GetEffectiveQuotaStatus() (totalExceeded, periodExceeded bool, err error) {
  59. // Check total quota (if set)
  60. if t.Quota > 0 && t.UsedAmount >= t.Quota {
  61. totalExceeded = true
  62. }
  63. if t.PeriodQuota > 0 {
  64. // Check if we need to reset period usage
  65. if needsReset, err := t.NeedsPeriodReset(); err != nil {
  66. return false, false, err
  67. } else if needsReset {
  68. // Period usage should be considered as reset (0) but we don't modify the struct here
  69. // The actual database reset should be handled separately
  70. periodExceeded = false // Consider as reset, so no period limit exceeded
  71. // Trigger async period reset - don't wait for it to complete
  72. go func() {
  73. if err := ResetTokenPeriodUsage(t.ID); err != nil {
  74. log.Error("failed to reset token period usage: " + err.Error())
  75. }
  76. }()
  77. } else {
  78. // Period is still valid, check against current usage
  79. // Calculate period usage: current total - last recorded total at period reset
  80. periodUsage := t.UsedAmount - t.PeriodLastUpdateAmount
  81. if periodUsage >= t.PeriodQuota {
  82. periodExceeded = true
  83. }
  84. }
  85. }
  86. return totalExceeded, periodExceeded, nil
  87. }
  88. // NeedsPeriodReset checks if the period usage should be reset
  89. // Uses PeriodLastUpdateTime to determine when the last period reset occurred
  90. func (t *Token) NeedsPeriodReset() (bool, error) {
  91. // If never been reset, use PeriodStartTime as baseline
  92. baseTime := t.PeriodLastUpdateTime
  93. if baseTime.IsZero() {
  94. return true, nil // Never initialized
  95. }
  96. now := time.Now()
  97. switch t.PeriodType {
  98. case "", PeriodTypeMonthly:
  99. // Check if we've crossed a month boundary since last reset
  100. baseMonth := baseTime.Month()
  101. baseYear := baseTime.Year()
  102. currentMonth := now.Month()
  103. currentYear := now.Year()
  104. if currentYear > baseYear {
  105. return true, nil
  106. }
  107. if currentYear == baseYear && currentMonth > baseMonth {
  108. return true, nil
  109. }
  110. return false, nil
  111. case PeriodTypeWeekly:
  112. // Check if we've passed 7 days since last reset
  113. return now.Sub(baseTime) >= 7*24*time.Hour, nil
  114. case PeriodTypeDaily:
  115. // Check if we've crossed to a new day since last reset
  116. baseDate := baseTime.Truncate(24 * time.Hour)
  117. currentDate := now.Truncate(24 * time.Hour)
  118. return currentDate.After(baseDate), nil
  119. default:
  120. return false, fmt.Errorf("unknown period type: %s", t.PeriodType)
  121. }
  122. }
  123. const (
  124. keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
  125. )
  126. func generateKey() string {
  127. key := make([]byte, 48)
  128. for i := range key {
  129. key[i] = keyChars[rand.IntN(len(keyChars))]
  130. }
  131. return conv.BytesToString(key)
  132. }
  133. func getTokenOrder(order string) string {
  134. prefix, suffix, _ := strings.Cut(order, "-")
  135. switch prefix {
  136. case "name", "expired_at", "group", "used_amount", "request_count", "id", "created_at":
  137. switch suffix {
  138. case "asc":
  139. return prefix + " asc"
  140. default:
  141. return prefix + " desc"
  142. }
  143. default:
  144. return "id desc"
  145. }
  146. }
  147. func InsertToken(token *Token, autoCreateGroup, ignoreExist bool) error {
  148. if autoCreateGroup {
  149. group := &Group{
  150. ID: token.GroupID,
  151. }
  152. if err := OnConflictDoNothing().Create(group).Error; err != nil {
  153. return err
  154. }
  155. }
  156. maxTokenNum := config.GetGroupMaxTokenNum()
  157. err := DB.Transaction(func(tx *gorm.DB) error {
  158. if maxTokenNum > 0 {
  159. var count int64
  160. err := tx.Model(&Token{}).Where("group_id = ?", token.GroupID).Count(&count).Error
  161. if err != nil {
  162. return err
  163. }
  164. if count >= maxTokenNum {
  165. return errors.New("group max token num reached")
  166. }
  167. }
  168. if ignoreExist {
  169. return tx.
  170. Where("group_id = ? and name = ?", token.GroupID, token.Name).
  171. FirstOrCreate(token).Error
  172. }
  173. return tx.Create(token).Error
  174. })
  175. if err != nil {
  176. if errors.Is(err, gorm.ErrDuplicatedKey) {
  177. if ignoreExist {
  178. return nil
  179. }
  180. return errors.New("token name already exists in this group")
  181. }
  182. return err
  183. }
  184. return nil
  185. }
  186. func GetTokens(
  187. group string,
  188. page, perPage int,
  189. order string,
  190. status int,
  191. ) (tokens []*Token, total int64, err error) {
  192. tx := DB.Model(&Token{})
  193. if group != "" {
  194. tx = tx.Where("group_id = ?", group)
  195. }
  196. if status != 0 {
  197. tx = tx.Where("status = ?", status)
  198. }
  199. err = tx.Count(&total).Error
  200. if err != nil {
  201. return nil, 0, err
  202. }
  203. if total <= 0 {
  204. return nil, 0, nil
  205. }
  206. limit, offset := toLimitOffset(page, perPage)
  207. err = tx.Order(getTokenOrder(order)).Limit(limit).Offset(offset).Find(&tokens).Error
  208. return tokens, total, err
  209. }
  210. func SearchTokens(
  211. group, keyword string,
  212. page, perPage int,
  213. order string,
  214. status int,
  215. name, key string,
  216. ) (tokens []*Token, total int64, err error) {
  217. tx := DB.Model(&Token{})
  218. if group != "" {
  219. tx = tx.Where("group_id = ?", group)
  220. }
  221. if status != 0 {
  222. tx = tx.Where("status = ?", status)
  223. }
  224. if name != "" {
  225. tx = tx.Where("name = ?", name)
  226. }
  227. if key != "" {
  228. tx = tx.Where("key = ?", key)
  229. }
  230. if keyword != "" {
  231. var (
  232. conditions []string
  233. values []any
  234. )
  235. if group == "" {
  236. if !common.UsingSQLite {
  237. conditions = append(conditions, "group_id ILIKE ?")
  238. } else {
  239. conditions = append(conditions, "group_id LIKE ?")
  240. }
  241. values = append(values, "%"+keyword+"%")
  242. }
  243. if name == "" {
  244. if !common.UsingSQLite {
  245. conditions = append(conditions, "name ILIKE ?")
  246. } else {
  247. conditions = append(conditions, "name LIKE ?")
  248. }
  249. values = append(values, "%"+keyword+"%")
  250. }
  251. if key == "" {
  252. if !common.UsingSQLite {
  253. conditions = append(conditions, "key ILIKE ?")
  254. } else {
  255. conditions = append(conditions, "key LIKE ?")
  256. }
  257. values = append(values, "%"+keyword+"%")
  258. }
  259. if !common.UsingSQLite {
  260. conditions = append(conditions, "models ILIKE ?")
  261. } else {
  262. conditions = append(conditions, "models LIKE ?")
  263. }
  264. values = append(values, "%"+keyword+"%")
  265. if len(conditions) > 0 {
  266. tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...)
  267. }
  268. }
  269. err = tx.Count(&total).Error
  270. if err != nil {
  271. return nil, 0, err
  272. }
  273. if total <= 0 {
  274. return nil, 0, nil
  275. }
  276. limit, offset := toLimitOffset(page, perPage)
  277. err = tx.Order(getTokenOrder(order)).Limit(limit).Offset(offset).Find(&tokens).Error
  278. return tokens, total, err
  279. }
  280. func SearchGroupTokens(
  281. group, keyword string,
  282. page, perPage int,
  283. order string,
  284. status int,
  285. name, key string,
  286. ) (tokens []*Token, total int64, err error) {
  287. if group == "" {
  288. return nil, 0, errors.New("group is empty")
  289. }
  290. tx := DB.Model(&Token{}).
  291. Where("group_id = ?", group)
  292. if name != "" {
  293. tx = tx.Where("name = ?", name)
  294. }
  295. if key != "" {
  296. tx = tx.Where("key = ?", key)
  297. }
  298. if status != 0 {
  299. tx = tx.Where("status = ?", status)
  300. }
  301. if keyword != "" {
  302. var (
  303. conditions []string
  304. values []any
  305. )
  306. if name == "" {
  307. if !common.UsingSQLite {
  308. conditions = append(conditions, "name ILIKE ?")
  309. } else {
  310. conditions = append(conditions, "name LIKE ?")
  311. }
  312. values = append(values, "%"+keyword+"%")
  313. }
  314. if key == "" {
  315. if !common.UsingSQLite {
  316. conditions = append(conditions, "key ILIKE ?")
  317. } else {
  318. conditions = append(conditions, "key LIKE ?")
  319. }
  320. values = append(values, "%"+keyword+"%")
  321. }
  322. if !common.UsingSQLite {
  323. conditions = append(conditions, "models ILIKE ?")
  324. } else {
  325. conditions = append(conditions, "models LIKE ?")
  326. }
  327. values = append(values, "%"+keyword+"%")
  328. if len(conditions) > 0 {
  329. tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...)
  330. }
  331. }
  332. err = tx.Count(&total).Error
  333. if err != nil {
  334. return nil, 0, err
  335. }
  336. if total <= 0 {
  337. return nil, 0, nil
  338. }
  339. limit, offset := toLimitOffset(page, perPage)
  340. err = tx.Order(getTokenOrder(order)).Limit(limit).Offset(offset).Find(&tokens).Error
  341. return tokens, total, err
  342. }
  343. func GetTokenByKey(key string) (*Token, error) {
  344. if key == "" {
  345. return nil, errors.New("key is empty")
  346. }
  347. var token Token
  348. err := DB.Where("key = ?", key).First(&token).Error
  349. return &token, HandleNotFound(err, ErrTokenNotFound)
  350. }
  351. // GetAndValidateToken validates a token and checks quota limits
  352. // This function is safe for concurrent use and handles period resets atomically
  353. func GetAndValidateToken(key string) (token *TokenCache, err error) {
  354. if key == "" {
  355. return nil, errors.New("no token provided")
  356. }
  357. token, err = CacheGetTokenByKey(key)
  358. if err != nil {
  359. if errors.Is(err, gorm.ErrRecordNotFound) {
  360. return nil, errors.New("invalid token")
  361. }
  362. log.Error("get token from cache failed: " + err.Error())
  363. return nil, errors.New("token validation failed")
  364. }
  365. if token.Status == TokenStatusDisabled {
  366. return nil, fmt.Errorf("token (%s[%d]) is disabled", token.Name, token.ID)
  367. }
  368. // Convert TokenCache to Token for quota checking
  369. tokenModel := Token{
  370. ID: token.ID,
  371. Quota: token.Quota,
  372. UsedAmount: token.UsedAmount,
  373. PeriodQuota: token.PeriodQuota,
  374. PeriodType: EmptyNullString(token.PeriodType),
  375. PeriodLastUpdateTime: time.Time(token.PeriodLastUpdateTime),
  376. PeriodLastUpdateAmount: token.PeriodLastUpdateAmount,
  377. }
  378. totalExceeded, periodExceeded, err := tokenModel.GetEffectiveQuotaStatus()
  379. if err != nil {
  380. return nil, fmt.Errorf("token (%s[%d]) quota check failed: %w", token.Name, token.ID, err)
  381. }
  382. if totalExceeded {
  383. return nil, fmt.Errorf("token (%s[%d]) total quota is exhausted", token.Name, token.ID)
  384. }
  385. if periodExceeded {
  386. return nil, fmt.Errorf("token (%s[%d]) period quota is exhausted", token.Name, token.ID)
  387. }
  388. return token, nil
  389. }
  390. func GetGroupTokenByID(group string, id int) (*Token, error) {
  391. if id == 0 || group == "" {
  392. return nil, errors.New("id or group is empty")
  393. }
  394. token := Token{}
  395. err := DB.
  396. Where("id = ? and group_id = ?", id, group).
  397. First(&token).Error
  398. return &token, HandleNotFound(err, ErrTokenNotFound)
  399. }
  400. func GetTokenByID(id int) (*Token, error) {
  401. if id == 0 {
  402. return nil, errors.New("id is empty")
  403. }
  404. token := Token{ID: id}
  405. err := DB.First(&token, "id = ?", id).Error
  406. return &token, HandleNotFound(err, ErrTokenNotFound)
  407. }
  408. func UpdateTokenStatus(id, status int) (err error) {
  409. token := Token{ID: id}
  410. defer func() {
  411. if err == nil {
  412. if err := CacheUpdateTokenStatus(token.Key, status); err != nil {
  413. log.Error("update token status in cache failed: " + err.Error())
  414. }
  415. }
  416. }()
  417. result := DB.
  418. Model(&token).
  419. Clauses(clause.Returning{
  420. Columns: []clause.Column{
  421. {Name: "key"},
  422. },
  423. }).
  424. Where("id = ?", id).
  425. Updates(
  426. map[string]any{
  427. "status": status,
  428. },
  429. )
  430. return HandleUpdateResult(result, ErrTokenNotFound)
  431. }
  432. func UpdateGroupTokenStatus(group string, id, status int) (err error) {
  433. if id == 0 || group == "" {
  434. return errors.New("id or group is empty")
  435. }
  436. token := Token{}
  437. defer func() {
  438. if err == nil {
  439. if err := CacheUpdateTokenStatus(token.Key, status); err != nil {
  440. log.Error("update token status in cache failed: " + err.Error())
  441. }
  442. }
  443. }()
  444. result := DB.
  445. Model(&token).
  446. Clauses(clause.Returning{
  447. Columns: []clause.Column{
  448. {Name: "key"},
  449. },
  450. }).
  451. Where("id = ? and group_id = ?", id, group).
  452. Updates(
  453. map[string]any{
  454. "status": status,
  455. },
  456. )
  457. return HandleUpdateResult(result, ErrTokenNotFound)
  458. }
  459. func DeleteGroupTokenByID(groupID string, id int) (err error) {
  460. if id == 0 || groupID == "" {
  461. return errors.New("id or group is empty")
  462. }
  463. token := Token{ID: id, GroupID: groupID}
  464. defer func() {
  465. if err == nil {
  466. if err := CacheDeleteToken(token.Key); err != nil {
  467. log.Error("delete token from cache failed: " + err.Error())
  468. }
  469. }
  470. }()
  471. result := DB.
  472. Clauses(clause.Returning{
  473. Columns: []clause.Column{
  474. {Name: "key"},
  475. },
  476. }).
  477. Where(token).
  478. Delete(&token)
  479. return HandleUpdateResult(result, ErrTokenNotFound)
  480. }
  481. func DeleteGroupTokensByIDs(group string, ids []int) (err error) {
  482. if group == "" {
  483. return errors.New("group is empty")
  484. }
  485. if len(ids) == 0 {
  486. return nil
  487. }
  488. tokens := make([]Token, len(ids))
  489. defer func() {
  490. if err == nil {
  491. for _, token := range tokens {
  492. if err := CacheDeleteToken(token.Key); err != nil {
  493. log.Error("delete token from cache failed: " + err.Error())
  494. }
  495. }
  496. }
  497. }()
  498. return DB.Transaction(func(tx *gorm.DB) error {
  499. return tx.
  500. Clauses(clause.Returning{
  501. Columns: []clause.Column{
  502. {Name: "key"},
  503. },
  504. }).
  505. Where("group_id = ?", group).
  506. Where("id IN (?)", ids).
  507. Delete(&tokens).
  508. Error
  509. })
  510. }
  511. func DeleteTokenByID(id int) (err error) {
  512. if id == 0 {
  513. return errors.New("id is empty")
  514. }
  515. token := Token{ID: id}
  516. defer func() {
  517. if err == nil {
  518. if err := CacheDeleteToken(token.Key); err != nil {
  519. log.Error("delete token from cache failed: " + err.Error())
  520. }
  521. }
  522. }()
  523. result := DB.
  524. Clauses(clause.Returning{
  525. Columns: []clause.Column{
  526. {Name: "key"},
  527. },
  528. }).
  529. Where(token).
  530. Delete(&token)
  531. return HandleUpdateResult(result, ErrTokenNotFound)
  532. }
  533. func DeleteTokensByIDs(ids []int) (err error) {
  534. if len(ids) == 0 {
  535. return nil
  536. }
  537. tokens := make([]Token, len(ids))
  538. defer func() {
  539. if err == nil {
  540. for _, token := range tokens {
  541. if err := CacheDeleteToken(token.Key); err != nil {
  542. log.Error("delete token from cache failed: " + err.Error())
  543. }
  544. }
  545. }
  546. }()
  547. return DB.Transaction(func(tx *gorm.DB) error {
  548. return tx.
  549. Clauses(clause.Returning{
  550. Columns: []clause.Column{
  551. {Name: "key"},
  552. },
  553. }).
  554. Where("id IN (?)", ids).
  555. Delete(&tokens).
  556. Error
  557. })
  558. }
// UpdateTokenRequest is a partial update for a token, consumed by UpdateToken
// and UpdateGroupToken: a nil pointer field leaves the column unchanged.
type UpdateTokenRequest struct {
	// Name is ignored when nil or empty.
	Name *string `json:"name"`
	Subnets *[]string `json:"subnets"`
	Models *[]string `json:"models"`
	// Status is applied only when non-zero.
	Status int `json:"status"`
	// Quota system
	Quota *float64 `json:"quota"`
	PeriodQuota *float64 `json:"period_quota"`
	PeriodType *string `json:"period_type"`
	// PeriodLastUpdateTime is a Unix timestamp in milliseconds (converted with time.UnixMilli).
	PeriodLastUpdateTime *int64 `json:"period_last_update_time"`
}
  570. func UpdateToken(id int, update UpdateTokenRequest) (token *Token, err error) {
  571. if id == 0 {
  572. return nil, errors.New("id is empty")
  573. }
  574. token = &Token{
  575. ID: id,
  576. Status: update.Status,
  577. }
  578. defer func() {
  579. if err == nil {
  580. if err := CacheDeleteToken(token.Key); err != nil {
  581. log.Error("delete token from cache failed: " + err.Error())
  582. }
  583. }
  584. }()
  585. selects := []string{}
  586. if update.Name != nil && *update.Name != "" {
  587. token.Name = EmptyNullString(*update.Name)
  588. selects = append(selects, "name")
  589. }
  590. if update.Quota != nil {
  591. token.Quota = *update.Quota
  592. selects = append(selects, "quota")
  593. }
  594. if update.PeriodQuota != nil {
  595. token.PeriodQuota = *update.PeriodQuota
  596. selects = append(selects, "period_quota")
  597. }
  598. if update.PeriodType != nil {
  599. token.PeriodType = EmptyNullString(*update.PeriodType)
  600. selects = append(selects, "period_type")
  601. }
  602. if update.PeriodLastUpdateTime != nil {
  603. token.PeriodLastUpdateTime = time.UnixMilli(*update.PeriodLastUpdateTime)
  604. selects = append(selects, "period_last_update_time")
  605. }
  606. if update.Subnets != nil {
  607. token.Subnets = *update.Subnets
  608. selects = append(selects, "subnets")
  609. }
  610. if update.Models != nil {
  611. token.Models = *update.Models
  612. selects = append(selects, "models")
  613. }
  614. if update.Status != 0 {
  615. selects = append(selects, "status")
  616. }
  617. if len(selects) == 0 {
  618. return nil, errors.New("empty update request")
  619. }
  620. result := DB.
  621. Select(selects).
  622. Where("id = ?", id).
  623. Clauses(clause.Returning{}).
  624. Updates(token)
  625. if result.Error != nil {
  626. if errors.Is(result.Error, gorm.ErrDuplicatedKey) {
  627. return nil, errors.New("token name already exists in this group")
  628. }
  629. }
  630. return token, HandleUpdateResult(result, ErrTokenNotFound)
  631. }
  632. func UpdateGroupToken(
  633. id int,
  634. group string,
  635. update UpdateTokenRequest,
  636. ) (token *Token, err error) {
  637. if id == 0 || group == "" {
  638. return nil, errors.New("id or group is empty")
  639. }
  640. token = &Token{
  641. ID: id,
  642. GroupID: group,
  643. Status: update.Status,
  644. }
  645. defer func() {
  646. if err == nil {
  647. if err := CacheDeleteToken(token.Key); err != nil {
  648. log.Error("delete token from cache failed: " + err.Error())
  649. }
  650. }
  651. }()
  652. selects := []string{}
  653. if update.Name != nil && *update.Name != "" {
  654. token.Name = EmptyNullString(*update.Name)
  655. selects = append(selects, "name")
  656. }
  657. if update.Quota != nil {
  658. token.Quota = *update.Quota
  659. selects = append(selects, "quota")
  660. }
  661. if update.PeriodQuota != nil {
  662. token.PeriodQuota = *update.PeriodQuota
  663. selects = append(selects, "period_quota")
  664. }
  665. if update.PeriodType != nil {
  666. token.PeriodType = EmptyNullString(*update.PeriodType)
  667. selects = append(selects, "period_type")
  668. }
  669. if update.PeriodLastUpdateTime != nil {
  670. token.PeriodLastUpdateTime = time.UnixMilli(*update.PeriodLastUpdateTime)
  671. selects = append(selects, "period_last_update_time")
  672. }
  673. if update.Subnets != nil {
  674. token.Subnets = *update.Subnets
  675. selects = append(selects, "subnets")
  676. }
  677. if update.Models != nil {
  678. token.Models = *update.Models
  679. selects = append(selects, "models")
  680. }
  681. if update.Status != 0 {
  682. selects = append(selects, "status")
  683. }
  684. if len(selects) == 0 {
  685. return nil, errors.New("empty update request")
  686. }
  687. result := DB.
  688. Select(selects).
  689. Where("id = ? and group_id = ?", id, group).
  690. Clauses(clause.Returning{}).
  691. Updates(token)
  692. if result.Error != nil {
  693. if errors.Is(result.Error, gorm.ErrDuplicatedKey) {
  694. return nil, errors.New("token name already exists in this group")
  695. }
  696. }
  697. return token, HandleUpdateResult(result, ErrTokenNotFound)
  698. }
  699. func UpdateTokenUsedAmount(id int, amount float64, requestCount int) (err error) {
  700. token := &Token{}
  701. defer func() {
  702. if amount > 0 && err == nil && (token.Quota > 0 || token.PeriodQuota > 0) {
  703. if err := CacheUpdateTokenUsedAmountOnlyIncrease(token.Key, token.UsedAmount); err != nil {
  704. log.Error("update token used amount in cache failed: " + err.Error())
  705. }
  706. }
  707. }()
  708. result := DB.
  709. Model(token).
  710. Clauses(clause.Returning{
  711. Columns: []clause.Column{
  712. {Name: "key"},
  713. {Name: "quota"},
  714. {Name: "used_amount"},
  715. {Name: "period_quota"},
  716. },
  717. }).
  718. Where("id = ?", id).
  719. Updates(
  720. map[string]any{
  721. "used_amount": gorm.Expr("used_amount + ?", amount),
  722. "request_count": gorm.Expr("request_count + ?", requestCount),
  723. },
  724. )
  725. return HandleUpdateResult(result, ErrTokenNotFound)
  726. }
  727. // ResetTokenPeriodUsage resets the period usage for a token with concurrency safety
  728. // This updates PeriodLastUpdateTime and PeriodLastUpdateAmount to current values
  729. func ResetTokenPeriodUsage(id int) error {
  730. token := &Token{}
  731. // Use database transaction with optimistic locking to prevent concurrent resets
  732. err := DB.Transaction(func(tx *gorm.DB) error {
  733. // First, read the current state with FOR UPDATE lock
  734. if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
  735. Where("id = ?", id).
  736. First(token).Error; err != nil {
  737. return err
  738. }
  739. // Check if period still needs reset (another concurrent request might have already reset it)
  740. needsReset, err := token.NeedsPeriodReset()
  741. if err != nil {
  742. return err
  743. }
  744. // If period no longer needs reset, skip the update
  745. if !needsReset {
  746. return nil
  747. }
  748. // Perform the reset with the lock held - update period last update time and amount
  749. result := tx.
  750. Model(token).
  751. Clauses(clause.Returning{
  752. Columns: []clause.Column{
  753. {Name: "key"},
  754. },
  755. }).
  756. Where("id = ?", id).
  757. Updates(
  758. map[string]any{
  759. "period_last_update_time": time.Now(),
  760. "period_last_update_amount": gorm.Expr(
  761. "used_amount",
  762. ), // Set to current total usage
  763. },
  764. )
  765. return HandleUpdateResult(result, ErrTokenNotFound)
  766. })
  767. // Update cache only if database update succeeded
  768. if err == nil && token.Key != "" {
  769. if cacheErr := CacheResetTokenPeriodUsage(token.Key, time.Now(), token.UsedAmount); cacheErr != nil {
  770. log.Error("reset token period usage in cache failed: " + cacheErr.Error())
  771. }
  772. }
  773. return err
  774. }
  775. func UpdateTokenName(id int, name string) (err error) {
  776. token := &Token{ID: id}
  777. defer func() {
  778. if err == nil {
  779. if err := CacheUpdateTokenName(token.Key, name); err != nil {
  780. log.Error("update token name in cache failed: " + err.Error())
  781. }
  782. }
  783. }()
  784. result := DB.
  785. Model(token).
  786. Clauses(clause.Returning{
  787. Columns: []clause.Column{
  788. {Name: "key"},
  789. },
  790. }).
  791. Where("id = ?", id).
  792. Update("name", name)
  793. if result.Error != nil && errors.Is(result.Error, gorm.ErrDuplicatedKey) {
  794. return errors.New("token name already exists in this group")
  795. }
  796. return HandleUpdateResult(result, ErrTokenNotFound)
  797. }
  798. func UpdateGroupTokenName(group string, id int, name string) (err error) {
  799. token := &Token{ID: id, GroupID: group}
  800. defer func() {
  801. if err == nil {
  802. if err := CacheUpdateTokenName(token.Key, name); err != nil {
  803. log.Error("update token name in cache failed: " + err.Error())
  804. }
  805. }
  806. }()
  807. result := DB.
  808. Model(token).
  809. Clauses(clause.Returning{
  810. Columns: []clause.Column{
  811. {Name: "key"},
  812. },
  813. }).
  814. Where("id = ? and group_id = ?", id, group).
  815. Update("name", name)
  816. if result.Error != nil && errors.Is(result.Error, gorm.ErrDuplicatedKey) {
  817. return errors.New("token name already exists in this group")
  818. }
  819. return HandleUpdateResult(result, ErrTokenNotFound)
  820. }