token.go 15 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725
  1. package model
import (
	crand "crypto/rand"
	"errors"
	"fmt"
	"math/rand/v2"
	"strings"
	"time"

	"github.com/labring/aiproxy/core/common"
	"github.com/labring/aiproxy/core/common/config"
	"github.com/labring/aiproxy/core/common/conv"
	log "github.com/sirupsen/logrus"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)
  15. const (
  16. ErrTokenNotFound = "token"
  17. )
  18. const (
  19. TokenStatusEnabled = 1
  20. TokenStatusDisabled = 2
  21. )
// Token is an API key issued to a group and persisted through gorm.
type Token struct {
	// CreatedAt is the row's creation timestamp (gorm CreatedAt convention).
	CreatedAt time.Time `json:"created_at"`
	// ExpiredAt is when the token stops being valid. ValidateAndGetToken
	// treats a zero value (on the cached copy) as "never expires".
	ExpiredAt time.Time `json:"expired_at"`
	// Group is the owning group, joined through GroupID; omitted from JSON.
	Group *Group `json:"-" gorm:"foreignKey:GroupID"`
	// Key is the unique 48-character secret. BeforeCreate generates a new one
	// whenever the supplied value is not exactly 48 characters long.
	Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
	// Name is unique within a group (composite unique index idx_group_name).
	Name EmptyNullString `json:"name" gorm:"index;uniqueIndex:idx_group_name;not null"`
	// GroupID is the owning group's ID; exposed in JSON as "group".
	GroupID string `json:"group" gorm:"index;uniqueIndex:idx_group_name"`
	// Subnets is stored as JSON text.
	// NOTE(review): presumably the client subnets allowed to use the token —
	// confirm against its consumer.
	Subnets []string `json:"subnets" gorm:"serializer:fastjson;type:text"`
	// Models is stored as JSON text and is matched by the keyword search in
	// SearchTokens/SearchGroupTokens.
	// NOTE(review): presumably an allow-list of model names — confirm.
	Models []string `json:"models" gorm:"serializer:fastjson;type:text"`
	// Status is TokenStatusEnabled or TokenStatusDisabled; new rows default
	// to 1 (enabled).
	Status int `json:"status" gorm:"default:1;index"`
	// ID is the primary key.
	ID int `json:"id" gorm:"primaryKey"`
	// Quota is the spending limit. ValidateAndGetToken only enforces it (on
	// the cached copy) when it is positive.
	Quota float64 `json:"quota"`
	// UsedAmount is the accumulated spend, incremented by
	// UpdateTokenUsedAmount.
	UsedAmount float64 `json:"used_amount" gorm:"index"`
	// RequestCount is the accumulated request count, incremented by
	// UpdateTokenUsedAmount.
	RequestCount int `json:"request_count" gorm:"index"`
}
  37. func (t *Token) BeforeCreate(_ *gorm.DB) (err error) {
  38. if t.Key == "" || len(t.Key) != 48 {
  39. t.Key = generateKey()
  40. }
  41. return
  42. }
  43. const (
  44. keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
  45. )
  46. func generateKey() string {
  47. key := make([]byte, 48)
  48. for i := range key {
  49. key[i] = keyChars[rand.IntN(len(keyChars))]
  50. }
  51. return conv.BytesToString(key)
  52. }
  53. func getTokenOrder(order string) string {
  54. prefix, suffix, _ := strings.Cut(order, "-")
  55. switch prefix {
  56. case "name", "expired_at", "group", "used_amount", "request_count", "id", "created_at":
  57. switch suffix {
  58. case "asc":
  59. return prefix + " asc"
  60. default:
  61. return prefix + " desc"
  62. }
  63. default:
  64. return "id desc"
  65. }
  66. }
  67. func InsertToken(token *Token, autoCreateGroup, ignoreExist bool) error {
  68. if autoCreateGroup {
  69. group := &Group{
  70. ID: token.GroupID,
  71. }
  72. if err := OnConflictDoNothing().Create(group).Error; err != nil {
  73. return err
  74. }
  75. }
  76. maxTokenNum := config.GetGroupMaxTokenNum()
  77. err := DB.Transaction(func(tx *gorm.DB) error {
  78. if maxTokenNum > 0 {
  79. var count int64
  80. err := tx.Model(&Token{}).Where("group_id = ?", token.GroupID).Count(&count).Error
  81. if err != nil {
  82. return err
  83. }
  84. if count >= maxTokenNum {
  85. return errors.New("group max token num reached")
  86. }
  87. }
  88. if ignoreExist {
  89. return tx.
  90. Where("group_id = ? and name = ?", token.GroupID, token.Name).
  91. FirstOrCreate(token).Error
  92. }
  93. return tx.Create(token).Error
  94. })
  95. if err != nil {
  96. if errors.Is(err, gorm.ErrDuplicatedKey) {
  97. if ignoreExist {
  98. return nil
  99. }
  100. return errors.New("token name already exists in this group")
  101. }
  102. return err
  103. }
  104. return nil
  105. }
  106. func GetTokens(
  107. group string,
  108. page, perPage int,
  109. order string,
  110. status int,
  111. ) (tokens []*Token, total int64, err error) {
  112. tx := DB.Model(&Token{})
  113. if group != "" {
  114. tx = tx.Where("group_id = ?", group)
  115. }
  116. if status != 0 {
  117. tx = tx.Where("status = ?", status)
  118. }
  119. err = tx.Count(&total).Error
  120. if err != nil {
  121. return nil, 0, err
  122. }
  123. if total <= 0 {
  124. return nil, 0, nil
  125. }
  126. limit, offset := toLimitOffset(page, perPage)
  127. err = tx.Order(getTokenOrder(order)).Limit(limit).Offset(offset).Find(&tokens).Error
  128. return tokens, total, err
  129. }
  130. func SearchTokens(
  131. group, keyword string,
  132. page, perPage int,
  133. order string,
  134. status int,
  135. name, key string,
  136. ) (tokens []*Token, total int64, err error) {
  137. tx := DB.Model(&Token{})
  138. if group != "" {
  139. tx = tx.Where("group_id = ?", group)
  140. }
  141. if status != 0 {
  142. tx = tx.Where("status = ?", status)
  143. }
  144. if name != "" {
  145. tx = tx.Where("name = ?", name)
  146. }
  147. if key != "" {
  148. tx = tx.Where("key = ?", key)
  149. }
  150. if keyword != "" {
  151. var (
  152. conditions []string
  153. values []any
  154. )
  155. if group == "" {
  156. if common.UsingPostgreSQL {
  157. conditions = append(conditions, "group_id ILIKE ?")
  158. } else {
  159. conditions = append(conditions, "group_id LIKE ?")
  160. }
  161. values = append(values, "%"+keyword+"%")
  162. }
  163. if name == "" {
  164. if common.UsingPostgreSQL {
  165. conditions = append(conditions, "name ILIKE ?")
  166. } else {
  167. conditions = append(conditions, "name LIKE ?")
  168. }
  169. values = append(values, "%"+keyword+"%")
  170. }
  171. if key == "" {
  172. if common.UsingPostgreSQL {
  173. conditions = append(conditions, "key ILIKE ?")
  174. } else {
  175. conditions = append(conditions, "key LIKE ?")
  176. }
  177. values = append(values, "%"+keyword+"%")
  178. }
  179. if common.UsingPostgreSQL {
  180. conditions = append(conditions, "models ILIKE ?")
  181. } else {
  182. conditions = append(conditions, "models LIKE ?")
  183. }
  184. values = append(values, "%"+keyword+"%")
  185. if len(conditions) > 0 {
  186. tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...)
  187. }
  188. }
  189. err = tx.Count(&total).Error
  190. if err != nil {
  191. return nil, 0, err
  192. }
  193. if total <= 0 {
  194. return nil, 0, nil
  195. }
  196. limit, offset := toLimitOffset(page, perPage)
  197. err = tx.Order(getTokenOrder(order)).Limit(limit).Offset(offset).Find(&tokens).Error
  198. return tokens, total, err
  199. }
  200. func SearchGroupTokens(
  201. group, keyword string,
  202. page, perPage int,
  203. order string,
  204. status int,
  205. name, key string,
  206. ) (tokens []*Token, total int64, err error) {
  207. if group == "" {
  208. return nil, 0, errors.New("group is empty")
  209. }
  210. tx := DB.Model(&Token{}).
  211. Where("group_id = ?", group)
  212. if name != "" {
  213. tx = tx.Where("name = ?", name)
  214. }
  215. if key != "" {
  216. tx = tx.Where("key = ?", key)
  217. }
  218. if status != 0 {
  219. tx = tx.Where("status = ?", status)
  220. }
  221. if keyword != "" {
  222. var (
  223. conditions []string
  224. values []any
  225. )
  226. if name == "" {
  227. if common.UsingPostgreSQL {
  228. conditions = append(conditions, "name ILIKE ?")
  229. } else {
  230. conditions = append(conditions, "name LIKE ?")
  231. }
  232. values = append(values, "%"+keyword+"%")
  233. }
  234. if key == "" {
  235. if common.UsingPostgreSQL {
  236. conditions = append(conditions, "key ILIKE ?")
  237. } else {
  238. conditions = append(conditions, "key LIKE ?")
  239. }
  240. values = append(values, "%"+keyword+"%")
  241. }
  242. if common.UsingPostgreSQL {
  243. conditions = append(conditions, "models ILIKE ?")
  244. } else {
  245. conditions = append(conditions, "models LIKE ?")
  246. }
  247. values = append(values, "%"+keyword+"%")
  248. if len(conditions) > 0 {
  249. tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...)
  250. }
  251. }
  252. err = tx.Count(&total).Error
  253. if err != nil {
  254. return nil, 0, err
  255. }
  256. if total <= 0 {
  257. return nil, 0, nil
  258. }
  259. limit, offset := toLimitOffset(page, perPage)
  260. err = tx.Order(getTokenOrder(order)).Limit(limit).Offset(offset).Find(&tokens).Error
  261. return tokens, total, err
  262. }
  263. func GetTokenByKey(key string) (*Token, error) {
  264. if key == "" {
  265. return nil, errors.New("key is empty")
  266. }
  267. var token Token
  268. err := DB.Where("key = ?", key).First(&token).Error
  269. return &token, HandleNotFound(err, ErrTokenNotFound)
  270. }
  271. func ValidateAndGetToken(key string) (token *TokenCache, err error) {
  272. if key == "" {
  273. return nil, errors.New("no token provided")
  274. }
  275. token, err = CacheGetTokenByKey(key)
  276. if err != nil {
  277. if errors.Is(err, gorm.ErrRecordNotFound) {
  278. return nil, errors.New("invalid token")
  279. }
  280. log.Error("get token from cache failed: " + err.Error())
  281. return nil, errors.New("token validation failed")
  282. }
  283. if token.Status == TokenStatusDisabled {
  284. return nil, fmt.Errorf("token (%s[%d]) is disabled", token.Name, token.ID)
  285. }
  286. if !time.Time(token.ExpiredAt).IsZero() && time.Time(token.ExpiredAt).Before(time.Now()) {
  287. return nil, fmt.Errorf("token (%s[%d]) is expired", token.Name, token.ID)
  288. }
  289. if token.Quota > 0 && token.UsedAmount >= token.Quota {
  290. return nil, fmt.Errorf("token (%s[%d]) quota is exhausted", token.Name, token.ID)
  291. }
  292. return token, nil
  293. }
  294. func GetGroupTokenByID(group string, id int) (*Token, error) {
  295. if id == 0 || group == "" {
  296. return nil, errors.New("id or group is empty")
  297. }
  298. token := Token{}
  299. err := DB.
  300. Where("id = ? and group_id = ?", id, group).
  301. First(&token).Error
  302. return &token, HandleNotFound(err, ErrTokenNotFound)
  303. }
  304. func GetTokenByID(id int) (*Token, error) {
  305. if id == 0 {
  306. return nil, errors.New("id is empty")
  307. }
  308. token := Token{ID: id}
  309. err := DB.First(&token, "id = ?", id).Error
  310. return &token, HandleNotFound(err, ErrTokenNotFound)
  311. }
  312. func UpdateTokenStatus(id, status int) (err error) {
  313. token := Token{ID: id}
  314. defer func() {
  315. if err == nil {
  316. if err := CacheUpdateTokenStatus(token.Key, status); err != nil {
  317. log.Error("update token status in cache failed: " + err.Error())
  318. }
  319. }
  320. }()
  321. result := DB.
  322. Model(&token).
  323. Clauses(clause.Returning{
  324. Columns: []clause.Column{
  325. {Name: "key"},
  326. },
  327. }).
  328. Where("id = ?", id).
  329. Updates(
  330. map[string]any{
  331. "status": status,
  332. },
  333. )
  334. return HandleUpdateResult(result, ErrTokenNotFound)
  335. }
  336. func UpdateGroupTokenStatus(group string, id, status int) (err error) {
  337. if id == 0 || group == "" {
  338. return errors.New("id or group is empty")
  339. }
  340. token := Token{}
  341. defer func() {
  342. if err == nil {
  343. if err := CacheUpdateTokenStatus(token.Key, status); err != nil {
  344. log.Error("update token status in cache failed: " + err.Error())
  345. }
  346. }
  347. }()
  348. result := DB.
  349. Model(&token).
  350. Clauses(clause.Returning{
  351. Columns: []clause.Column{
  352. {Name: "key"},
  353. },
  354. }).
  355. Where("id = ? and group_id = ?", id, group).
  356. Updates(
  357. map[string]any{
  358. "status": status,
  359. },
  360. )
  361. return HandleUpdateResult(result, ErrTokenNotFound)
  362. }
  363. func DeleteGroupTokenByID(groupID string, id int) (err error) {
  364. if id == 0 || groupID == "" {
  365. return errors.New("id or group is empty")
  366. }
  367. token := Token{ID: id, GroupID: groupID}
  368. defer func() {
  369. if err == nil {
  370. if err := CacheDeleteToken(token.Key); err != nil {
  371. log.Error("delete token from cache failed: " + err.Error())
  372. }
  373. }
  374. }()
  375. result := DB.
  376. Clauses(clause.Returning{
  377. Columns: []clause.Column{
  378. {Name: "key"},
  379. },
  380. }).
  381. Where(token).
  382. Delete(&token)
  383. return HandleUpdateResult(result, ErrTokenNotFound)
  384. }
  385. func DeleteGroupTokensByIDs(group string, ids []int) (err error) {
  386. if group == "" {
  387. return errors.New("group is empty")
  388. }
  389. if len(ids) == 0 {
  390. return nil
  391. }
  392. tokens := make([]Token, len(ids))
  393. defer func() {
  394. if err == nil {
  395. for _, token := range tokens {
  396. if err := CacheDeleteToken(token.Key); err != nil {
  397. log.Error("delete token from cache failed: " + err.Error())
  398. }
  399. }
  400. }
  401. }()
  402. return DB.Transaction(func(tx *gorm.DB) error {
  403. return tx.
  404. Clauses(clause.Returning{
  405. Columns: []clause.Column{
  406. {Name: "key"},
  407. },
  408. }).
  409. Where("group_id = ?", group).
  410. Where("id IN (?)", ids).
  411. Delete(&tokens).
  412. Error
  413. })
  414. }
  415. func DeleteTokenByID(id int) (err error) {
  416. if id == 0 {
  417. return errors.New("id is empty")
  418. }
  419. token := Token{ID: id}
  420. defer func() {
  421. if err == nil {
  422. if err := CacheDeleteToken(token.Key); err != nil {
  423. log.Error("delete token from cache failed: " + err.Error())
  424. }
  425. }
  426. }()
  427. result := DB.
  428. Clauses(clause.Returning{
  429. Columns: []clause.Column{
  430. {Name: "key"},
  431. },
  432. }).
  433. Where(token).
  434. Delete(&token)
  435. return HandleUpdateResult(result, ErrTokenNotFound)
  436. }
  437. func DeleteTokensByIDs(ids []int) (err error) {
  438. if len(ids) == 0 {
  439. return nil
  440. }
  441. tokens := make([]Token, len(ids))
  442. defer func() {
  443. if err == nil {
  444. for _, token := range tokens {
  445. if err := CacheDeleteToken(token.Key); err != nil {
  446. log.Error("delete token from cache failed: " + err.Error())
  447. }
  448. }
  449. }
  450. }()
  451. return DB.Transaction(func(tx *gorm.DB) error {
  452. return tx.
  453. Clauses(clause.Returning{
  454. Columns: []clause.Column{
  455. {Name: "key"},
  456. },
  457. }).
  458. Where("id IN (?)", ids).
  459. Delete(&tokens).
  460. Error
  461. })
  462. }
  463. func UpdateToken(id int, token *Token) (err error) {
  464. if id == 0 {
  465. return errors.New("id is empty")
  466. }
  467. defer func() {
  468. if err == nil {
  469. if err := CacheDeleteToken(token.Key); err != nil {
  470. log.Error("delete token from cache failed: " + err.Error())
  471. }
  472. }
  473. }()
  474. selects := []string{
  475. "subnets",
  476. "quota",
  477. "models",
  478. "expired_at",
  479. }
  480. if token.Name != "" {
  481. selects = append(selects, "name")
  482. }
  483. if token.Status != 0 {
  484. selects = append(selects, "status")
  485. }
  486. result := DB.
  487. Select(selects).
  488. Where("id = ?", id).
  489. Clauses(clause.Returning{}).
  490. Updates(token)
  491. if result.Error != nil {
  492. if errors.Is(result.Error, gorm.ErrDuplicatedKey) {
  493. return errors.New("token name already exists in this group")
  494. }
  495. }
  496. return HandleUpdateResult(result, ErrTokenNotFound)
  497. }
  498. func UpdateGroupToken(id int, group string, token *Token) (err error) {
  499. if id == 0 || group == "" {
  500. return errors.New("id or group is empty")
  501. }
  502. defer func() {
  503. if err == nil {
  504. if err := CacheDeleteToken(token.Key); err != nil {
  505. log.Error("delete token from cache failed: " + err.Error())
  506. }
  507. }
  508. }()
  509. selects := []string{
  510. "subnets",
  511. "quota",
  512. "models",
  513. "expired_at",
  514. }
  515. if token.Name != "" {
  516. selects = append(selects, "name")
  517. }
  518. if token.Status != 0 {
  519. selects = append(selects, "status")
  520. }
  521. result := DB.
  522. Select(selects).
  523. Where("id = ? and group_id = ?", id, group).
  524. Clauses(clause.Returning{}).
  525. Updates(token)
  526. if result.Error != nil {
  527. if errors.Is(result.Error, gorm.ErrDuplicatedKey) {
  528. return errors.New("token name already exists in this group")
  529. }
  530. }
  531. return HandleUpdateResult(result, ErrTokenNotFound)
  532. }
  533. func UpdateTokenUsedAmount(id int, amount float64, requestCount int) (err error) {
  534. token := &Token{}
  535. defer func() {
  536. if amount > 0 && err == nil && token.Quota > 0 {
  537. if err := CacheUpdateTokenUsedAmountOnlyIncrease(token.Key, token.UsedAmount); err != nil {
  538. log.Error("update token used amount in cache failed: " + err.Error())
  539. }
  540. }
  541. }()
  542. result := DB.
  543. Model(token).
  544. Clauses(clause.Returning{
  545. Columns: []clause.Column{
  546. {Name: "key"},
  547. {Name: "quota"},
  548. {Name: "used_amount"},
  549. },
  550. }).
  551. Where("id = ?", id).
  552. Updates(
  553. map[string]any{
  554. "used_amount": gorm.Expr("used_amount + ?", amount),
  555. "request_count": gorm.Expr("request_count + ?", requestCount),
  556. },
  557. )
  558. return HandleUpdateResult(result, ErrTokenNotFound)
  559. }
  560. func UpdateTokenName(id int, name string) (err error) {
  561. token := &Token{ID: id}
  562. defer func() {
  563. if err == nil {
  564. if err := CacheUpdateTokenName(token.Key, name); err != nil {
  565. log.Error("update token name in cache failed: " + err.Error())
  566. }
  567. }
  568. }()
  569. result := DB.
  570. Model(token).
  571. Clauses(clause.Returning{
  572. Columns: []clause.Column{
  573. {Name: "key"},
  574. },
  575. }).
  576. Where("id = ?", id).
  577. Update("name", name)
  578. if result.Error != nil && errors.Is(result.Error, gorm.ErrDuplicatedKey) {
  579. return errors.New("token name already exists in this group")
  580. }
  581. return HandleUpdateResult(result, ErrTokenNotFound)
  582. }
  583. func UpdateGroupTokenName(group string, id int, name string) (err error) {
  584. token := &Token{ID: id, GroupID: group}
  585. defer func() {
  586. if err == nil {
  587. if err := CacheUpdateTokenName(token.Key, name); err != nil {
  588. log.Error("update token name in cache failed: " + err.Error())
  589. }
  590. }
  591. }()
  592. result := DB.
  593. Model(token).
  594. Clauses(clause.Returning{
  595. Columns: []clause.Column{
  596. {Name: "key"},
  597. },
  598. }).
  599. Where("id = ? and group_id = ?", id, group).
  600. Update("name", name)
  601. if result.Error != nil && errors.Is(result.Error, gorm.ErrDuplicatedKey) {
  602. return errors.New("token name already exists in this group")
  603. }
  604. return HandleUpdateResult(result, ErrTokenNotFound)
  605. }