token.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. package model
  2. import (
  3. "errors"
  4. "fmt"
  5. "log"
  6. "one-api/common"
  7. "strings"
  8. "time"
  9. "github.com/bytedance/gopkg/util/gopool"
  10. "gorm.io/gorm"
  11. )
  12. type Token struct {
  13. Id int `json:"id"`
  14. UserId int `json:"user_id" gorm:"index"`
  15. Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
  16. Status int `json:"status" gorm:"default:1"`
  17. Name string `json:"name" gorm:"index" `
  18. CreatedTime int64 `json:"created_time" gorm:"bigint"`
  19. AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
  20. ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
  21. RemainQuota int `json:"remain_quota" gorm:"default:0"`
  22. UnlimitedQuota bool `json:"unlimited_quota"`
  23. ModelLimitsEnabled bool `json:"model_limits_enabled"`
  24. ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
  25. AllowIps *string `json:"allow_ips" gorm:"default:''"`
  26. UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
  27. Group string `json:"group" gorm:"default:''"`
  28. DailyUsageCount int `json:"daily_usage_count" gorm:"default:0"` // 今日使用次数
  29. TotalUsageCount int `json:"total_usage_count" gorm:"default:0"` // 总使用次数
  30. LastUsageDate string `json:"last_usage_date" gorm:"default:''"` // 最后使用日期(YYYY-MM-DD)
  31. RateLimitPerMinute int `json:"rate_limit_per_minute" gorm:"default:0"` // 每分钟访问次数限制,0表示不限制
  32. RateLimitPerDay int `json:"rate_limit_per_day" gorm:"default:0"` // 每日访问次数限制,0表示不限制
  33. LastRateLimitReset int64 `json:"last_rate_limit_reset" gorm:"default:0"` // 最后重置时间戳
  34. ChannelTag *string `json:"channel_tag" gorm:"default:''"` // 渠道标签限制
  35. TotalUsageLimit *int `json:"total_usage_limit" gorm:"default:null"` // 总使用次数限制,nil表示不限制
  36. DeletedAt gorm.DeletedAt `gorm:"index"`
  37. }
  38. func (token *Token) Clean() {
  39. token.Key = ""
  40. }
  41. func (token *Token) GetIpLimitsMap() map[string]any {
  42. // delete empty spaces
  43. //split with \n
  44. ipLimitsMap := make(map[string]any)
  45. if token.AllowIps == nil {
  46. return ipLimitsMap
  47. }
  48. cleanIps := strings.ReplaceAll(*token.AllowIps, " ", "")
  49. if cleanIps == "" {
  50. return ipLimitsMap
  51. }
  52. ips := strings.Split(cleanIps, "\n")
  53. for _, ip := range ips {
  54. ip = strings.TrimSpace(ip)
  55. ip = strings.ReplaceAll(ip, ",", "")
  56. if common.IsIP(ip) {
  57. ipLimitsMap[ip] = true
  58. }
  59. }
  60. return ipLimitsMap
  61. }
  62. func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
  63. var tokens []*Token
  64. var err error
  65. err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error
  66. return tokens, err
  67. }
  68. func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) {
  69. if token != "" {
  70. token = strings.Trim(token, "sk-")
  71. }
  72. err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
  73. return tokens, err
  74. }
  75. func ValidateUserToken(key string) (token *Token, err error) {
  76. log.Println("===========", key)
  77. if key == "" {
  78. return nil, errors.New("未提供令牌")
  79. }
  80. token, err = GetTokenByKey(key, false)
  81. if err == nil {
  82. if token.Status == common.TokenStatusExhausted {
  83. keyPrefix := key[:3]
  84. keySuffix := key[len(key)-3:]
  85. return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
  86. } else if token.Status == common.TokenStatusExpired {
  87. return token, errors.New("该令牌已过期")
  88. }
  89. if token.Status != common.TokenStatusEnabled {
  90. return token, errors.New("该令牌状态不可用")
  91. }
  92. if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
  93. if !common.RedisEnabled {
  94. token.Status = common.TokenStatusExpired
  95. err := token.SelectUpdate()
  96. if err != nil {
  97. common.SysError("failed to update token status" + err.Error())
  98. }
  99. }
  100. return token, errors.New("该令牌已过期")
  101. }
  102. if !token.UnlimitedQuota && token.RemainQuota <= 0 {
  103. if !common.RedisEnabled {
  104. // in this case, we can make sure the token is exhausted
  105. token.Status = common.TokenStatusExhausted
  106. err := token.SelectUpdate()
  107. if err != nil {
  108. common.SysError("failed to update token status" + err.Error())
  109. }
  110. }
  111. keyPrefix := key[:3]
  112. keySuffix := key[len(key)-3:]
  113. return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
  114. }
  115. // 检查总使用次数限制
  116. if token.TotalUsageLimit != nil && *token.TotalUsageLimit > 0 && token.TotalUsageCount >= *token.TotalUsageLimit {
  117. keyPrefix := key[:3]
  118. keySuffix := key[len(key)-3:]
  119. return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌总使用次数已用完,限制次数: %d,已使用次数: %d", keyPrefix, keySuffix, *token.TotalUsageLimit, token.TotalUsageCount))
  120. }
  121. // 检查访问频率限制
  122. if err := CheckRateLimit(token); err != nil {
  123. return token, err
  124. }
  125. return token, nil
  126. }
  127. return nil, errors.New("无效的令牌")
  128. }
  129. func GetTokenByIds(id int, userId int) (*Token, error) {
  130. if id == 0 || userId == 0 {
  131. return nil, errors.New("id 或 userId 为空!")
  132. }
  133. token := Token{Id: id, UserId: userId}
  134. var err error = nil
  135. err = DB.First(&token, "id = ? and user_id = ?", id, userId).Error
  136. return &token, err
  137. }
  138. func GetTokenById(id int) (*Token, error) {
  139. if id == 0 {
  140. return nil, errors.New("id 为空!")
  141. }
  142. token := Token{Id: id}
  143. var err error = nil
  144. err = DB.First(&token, "id = ?", id).Error
  145. if shouldUpdateRedis(true, err) {
  146. gopool.Go(func() {
  147. if err := cacheSetToken(token); err != nil {
  148. common.SysError("failed to update user status cache: " + err.Error())
  149. }
  150. })
  151. }
  152. return &token, err
  153. }
  154. func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
  155. defer func() {
  156. // Update Redis cache asynchronously on successful DB read
  157. if shouldUpdateRedis(fromDB, err) && token != nil {
  158. gopool.Go(func() {
  159. if err := cacheSetToken(*token); err != nil {
  160. common.SysError("failed to update user status cache: " + err.Error())
  161. }
  162. })
  163. }
  164. }()
  165. if !fromDB && common.RedisEnabled {
  166. // Try Redis first
  167. token, err := cacheGetTokenByKey(key)
  168. if err == nil {
  169. return token, nil
  170. }
  171. // Don't return error - fall through to DB
  172. }
  173. fromDB = true
  174. err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error
  175. return token, err
  176. }
  177. func (token *Token) Insert() error {
  178. var err error
  179. err = DB.Create(token).Error
  180. return err
  181. }
  182. // Update Make sure your token's fields is completed, because this will update non-zero values
  183. func (token *Token) Update() (err error) {
  184. defer func() {
  185. if shouldUpdateRedis(true, err) {
  186. gopool.Go(func() {
  187. err := cacheSetToken(*token)
  188. if err != nil {
  189. common.SysError("failed to update token cache: " + err.Error())
  190. }
  191. })
  192. }
  193. }()
  194. err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
  195. "model_limits_enabled", "model_limits", "allow_ips", "group", "daily_usage_count", "total_usage_count", "last_usage_date",
  196. "rate_limit_per_minute", "rate_limit_per_day", "last_rate_limit_reset", "channel_tag", "total_usage_limit").Updates(token).Error
  197. return err
  198. }
  199. func (token *Token) SelectUpdate() (err error) {
  200. defer func() {
  201. if shouldUpdateRedis(true, err) {
  202. gopool.Go(func() {
  203. err := cacheSetToken(*token)
  204. if err != nil {
  205. common.SysError("failed to update token cache: " + err.Error())
  206. }
  207. })
  208. }
  209. }()
  210. // This can update zero values
  211. return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
  212. }
  213. func (token *Token) Delete() (err error) {
  214. defer func() {
  215. if shouldUpdateRedis(true, err) {
  216. gopool.Go(func() {
  217. err := cacheDeleteToken(token.Key)
  218. if err != nil {
  219. common.SysError("failed to delete token cache: " + err.Error())
  220. }
  221. })
  222. }
  223. }()
  224. err = DB.Delete(token).Error
  225. return err
  226. }
  227. func (token *Token) IsModelLimitsEnabled() bool {
  228. return token.ModelLimitsEnabled
  229. }
  230. func (token *Token) GetModelLimits() []string {
  231. if token.ModelLimits == "" {
  232. return []string{}
  233. }
  234. return strings.Split(token.ModelLimits, ",")
  235. }
  236. func (token *Token) GetModelLimitsMap() map[string]bool {
  237. limits := token.GetModelLimits()
  238. limitsMap := make(map[string]bool)
  239. for _, limit := range limits {
  240. limitsMap[limit] = true
  241. }
  242. return limitsMap
  243. }
  244. func DisableModelLimits(tokenId int) error {
  245. token, err := GetTokenById(tokenId)
  246. if err != nil {
  247. return err
  248. }
  249. token.ModelLimitsEnabled = false
  250. token.ModelLimits = ""
  251. return token.Update()
  252. }
  253. func DeleteTokenById(id int, userId int) (err error) {
  254. // Why we need userId here? In case user want to delete other's token.
  255. if id == 0 || userId == 0 {
  256. return errors.New("id 或 userId 为空!")
  257. }
  258. token := Token{Id: id, UserId: userId}
  259. err = DB.Where(token).First(&token).Error
  260. if err != nil {
  261. return err
  262. }
  263. return token.Delete()
  264. }
  265. func IncreaseTokenQuota(id int, key string, quota int) (err error) {
  266. if quota < 0 {
  267. return errors.New("quota 不能为负数!")
  268. }
  269. if common.RedisEnabled {
  270. gopool.Go(func() {
  271. err := cacheIncrTokenQuota(key, int64(quota))
  272. if err != nil {
  273. common.SysError("failed to increase token quota: " + err.Error())
  274. }
  275. })
  276. }
  277. if common.BatchUpdateEnabled {
  278. addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
  279. return nil
  280. }
  281. return increaseTokenQuota(id, quota)
  282. }
  283. func increaseTokenQuota(id int, quota int) (err error) {
  284. err = DB.Model(&Token{}).Where("id = ?", id).Updates(
  285. map[string]interface{}{
  286. "remain_quota": gorm.Expr("remain_quota + ?", quota),
  287. "used_quota": gorm.Expr("used_quota - ?", quota),
  288. "accessed_time": common.GetTimestamp(),
  289. },
  290. ).Error
  291. return err
  292. }
  293. func DecreaseTokenQuota(id int, key string, quota int) (err error) {
  294. if quota < 0 {
  295. return errors.New("quota 不能为负数!")
  296. }
  297. if common.RedisEnabled {
  298. gopool.Go(func() {
  299. err := cacheDecrTokenQuota(key, int64(quota))
  300. if err != nil {
  301. common.SysError("failed to decrease token quota: " + err.Error())
  302. }
  303. })
  304. }
  305. if common.BatchUpdateEnabled {
  306. addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
  307. return nil
  308. }
  309. return decreaseTokenQuota(id, quota)
  310. }
  311. func decreaseTokenQuota(id int, quota int) (err error) {
  312. err = DB.Model(&Token{}).Where("id = ?", id).Updates(
  313. map[string]interface{}{
  314. "remain_quota": gorm.Expr("remain_quota - ?", quota),
  315. "used_quota": gorm.Expr("used_quota + ?", quota),
  316. "accessed_time": common.GetTimestamp(),
  317. },
  318. ).Error
  319. return err
  320. }
  321. // CountUserTokens returns total number of tokens for the given user, used for pagination
  322. func CountUserTokens(userId int) (int64, error) {
  323. var total int64
  324. err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
  325. return total, err
  326. }
  327. // BatchDeleteTokens 删除指定用户的一组令牌,返回成功删除数量
  328. func BatchDeleteTokens(ids []int, userId int) (int, error) {
  329. if len(ids) == 0 {
  330. return 0, errors.New("ids 不能为空!")
  331. }
  332. tx := DB.Begin()
  333. var tokens []Token
  334. if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Find(&tokens).Error; err != nil {
  335. tx.Rollback()
  336. return 0, err
  337. }
  338. if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Delete(&Token{}).Error; err != nil {
  339. tx.Rollback()
  340. return 0, err
  341. }
  342. if err := tx.Commit().Error; err != nil {
  343. return 0, err
  344. }
  345. if common.RedisEnabled {
  346. gopool.Go(func() {
  347. for _, t := range tokens {
  348. _ = cacheDeleteToken(t.Key)
  349. }
  350. })
  351. }
  352. return len(tokens), nil
  353. }
  354. // IncreaseTokenUsageCount 增加Token的使用次数
  355. func IncreaseTokenUsageCount(key string) error {
  356. if key == "" {
  357. return errors.New("key 不能为空")
  358. }
  359. // 获取当前日期
  360. currentDate := common.GetTimeString()[:10] // YYYY-MM-DD格式
  361. // 更新数据库
  362. err := DB.Model(&Token{}).Where(commonKeyCol+" = ?", key).Updates(map[string]interface{}{
  363. "total_usage_count": gorm.Expr("total_usage_count + 1"),
  364. "daily_usage_count": gorm.Expr("CASE WHEN last_usage_date = ? THEN daily_usage_count + 1 ELSE 1 END", currentDate),
  365. "last_usage_date": currentDate,
  366. "accessed_time": common.GetTimestamp(),
  367. }).Error
  368. // 更新缓存
  369. if common.RedisEnabled && err == nil {
  370. gopool.Go(func() {
  371. // 重新缓存Token信息
  372. token, getErr := GetTokenByKey(key, true) // 从DB获取最新数据
  373. if getErr == nil {
  374. _ = cacheSetToken(*token)
  375. }
  376. })
  377. }
  378. return err
  379. }
  380. // CheckRateLimit 检查令牌的访问频率限制
  381. func CheckRateLimit(token *Token) error {
  382. if token == nil {
  383. return errors.New("token不能为空")
  384. }
  385. // 如果没有设置限制,直接返回
  386. if token.RateLimitPerMinute <= 0 && token.RateLimitPerDay <= 0 {
  387. return nil
  388. }
  389. currentTime := time.Now()
  390. currentTimestamp := currentTime.Unix()
  391. // 获取当前时间的分钟级时间戳(用于分钟级限制)
  392. currentMinute := currentTime.Truncate(time.Minute).Unix()
  393. // 获取当天开始的时间戳(用于日级限制)
  394. currentDay := time.Date(currentTime.Year(), currentTime.Month(), currentTime.Day(), 0, 0, 0, 0, currentTime.Location()).Unix()
  395. // 检查是否需要重置计数器
  396. needUpdate := false
  397. originalToken := *token // 保存原始状态
  398. // 检查分钟级限制
  399. if token.RateLimitPerMinute > 0 {
  400. // 如果上次重置时间不在当前分钟内,重置分钟计数器
  401. if token.LastRateLimitReset < currentMinute {
  402. token.LastRateLimitReset = currentTimestamp
  403. needUpdate = true
  404. }
  405. // 计算当前分钟内的使用次数
  406. var minuteCount int64
  407. err := DB.Model(&TokenUsageLog{}).Where("token_id = ? AND created_at >= ?", token.Id, currentMinute).Count(&minuteCount).Error
  408. if err != nil {
  409. common.SysError("检查分钟级使用次数失败: " + err.Error())
  410. return errors.New("系统错误,请稍后再试")
  411. }
  412. if int(minuteCount) >= token.RateLimitPerMinute {
  413. return errors.New("超出分钟限制,请稍后再试")
  414. }
  415. }
  416. // 检查日级限制
  417. if token.RateLimitPerDay > 0 {
  418. // 计算当天内的使用次数
  419. var dayCount int64
  420. err := DB.Model(&TokenUsageLog{}).Where("token_id = ? AND created_at >= ?", token.Id, currentDay).Count(&dayCount).Error
  421. if err != nil {
  422. common.SysError("检查日级使用次数失败: " + err.Error())
  423. return errors.New("系统错误,请稍后再试")
  424. }
  425. if int(dayCount) >= token.RateLimitPerDay {
  426. return errors.New("超出日限制,请稍后再试")
  427. }
  428. }
  429. // 如果需要更新,更新数据库
  430. if needUpdate {
  431. err := DB.Model(token).Select("last_rate_limit_reset").Updates(token).Error
  432. if err != nil {
  433. common.SysError("更新令牌重置时间失败: " + err.Error())
  434. // 恢复原始状态
  435. *token = originalToken
  436. }
  437. }
  438. return nil
  439. }
  440. // TokenUsageLog Token使用日志表(用于频率限制)
  441. type TokenUsageLog struct {
  442. Id int `json:"id" gorm:"primaryKey"`
  443. TokenId int `json:"token_id" gorm:"index:idx_token_created"`
  444. CreatedAt int64 `json:"created_at" gorm:"index:idx_token_created;index:idx_created"`
  445. }
  446. func (TokenUsageLog) TableName() string {
  447. return "token_usage_logs"
  448. }
  449. // RecordTokenUsage 记录令牌使用(用于频率限制)
  450. func RecordTokenUsage(tokenId int) error {
  451. if tokenId <= 0 {
  452. return errors.New("tokenId不能为空")
  453. }
  454. usageLog := TokenUsageLog{
  455. TokenId: tokenId,
  456. CreatedAt: time.Now().Unix(),
  457. }
  458. return DB.Create(&usageLog).Error
  459. }