token.go 15 KB


  1. package model
  2. import (
  3. "errors"
  4. "fmt"
  5. "log"
  6. "one-api/common"
  7. "strings"
  8. "time"
  9. "github.com/bytedance/gopkg/util/gopool"
  10. "gorm.io/gorm"
  11. )
// Token is a user's API access token, persisted through GORM and serialized to
// JSON. Because of the gorm.DeletedAt field, deletes are soft deletes.
type Token struct {
	Id     int `json:"id"`
	UserId int `json:"user_id" gorm:"index"`
	// Key is the secret token value and is unique across tokens; Clean blanks it.
	// NOTE(review): appears to be stored without the "sk-" prefix (SearchUserTokens
	// strips "sk-" before matching) — confirm.
	Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
	// Status is compared against common.TokenStatusEnabled / TokenStatusExpired /
	// TokenStatusExhausted in ValidateUserToken.
	Status       int    `json:"status" gorm:"default:1"`
	Name         string `json:"name" gorm:"index" `
	CreatedTime  int64  `json:"created_time" gorm:"bigint"`
	AccessedTime int64  `json:"accessed_time" gorm:"bigint"`
	// ExpiredTime is compared with common.GetTimestamp(); -1 means the token never expires.
	ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"`
	// RemainQuota is the quota left; ValidateUserToken ignores it when UnlimitedQuota is true.
	RemainQuota    int  `json:"remain_quota" gorm:"default:0"`
	UnlimitedQuota bool `json:"unlimited_quota"`
	// ModelLimitsEnabled and ModelLimits are cleared together by DisableModelLimits.
	ModelLimitsEnabled bool `json:"model_limits_enabled"`
	// ModelLimits is a comma-separated list of model names (parsed by GetModelLimits).
	ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
	// AllowIps holds IP addresses parsed by GetIpLimitsMap; nil or empty yields no entries.
	AllowIps *string `json:"allow_ips" gorm:"default:''"`
	// UsedQuota is the quota consumed: decreaseTokenQuota adds to it, increaseTokenQuota subtracts.
	UsedQuota int    `json:"used_quota" gorm:"default:0"` // used quota
	Group     string `json:"group" gorm:"default:''"`
	// DailyUsageCount is the number of uses on LastUsageDate ("today's usage count").
	DailyUsageCount int `json:"daily_usage_count" gorm:"default:0"`
	// TotalUsageCount is the total number of uses.
	TotalUsageCount int `json:"total_usage_count" gorm:"default:0"`
	// LastUsageDate is the date of the last use (format YYYY-MM-DD).
	LastUsageDate string `json:"last_usage_date" gorm:"default:''"`
	// RateLimitPerMinute is the maximum number of requests per minute; 0 means no limit.
	RateLimitPerMinute int `json:"rate_limit_per_minute" gorm:"default:0"`
	// RateLimitPerDay is the maximum number of requests per day; 0 means no limit.
	RateLimitPerDay int `json:"rate_limit_per_day" gorm:"default:0"`
	// LastRateLimitReset is the Unix timestamp (seconds) of the last rate-limit
	// reset, written by CheckRateLimit.
	LastRateLimitReset int64 `json:"last_rate_limit_reset" gorm:"default:0"`
	// DeletedAt is GORM's soft-delete marker.
	DeletedAt gorm.DeletedAt `gorm:"index"`
}
  36. func (token *Token) Clean() {
  37. token.Key = ""
  38. }
  39. func (token *Token) GetIpLimitsMap() map[string]any {
  40. // delete empty spaces
  41. //split with \n
  42. ipLimitsMap := make(map[string]any)
  43. if token.AllowIps == nil {
  44. return ipLimitsMap
  45. }
  46. cleanIps := strings.ReplaceAll(*token.AllowIps, " ", "")
  47. if cleanIps == "" {
  48. return ipLimitsMap
  49. }
  50. ips := strings.Split(cleanIps, "\n")
  51. for _, ip := range ips {
  52. ip = strings.TrimSpace(ip)
  53. ip = strings.ReplaceAll(ip, ",", "")
  54. if common.IsIP(ip) {
  55. ipLimitsMap[ip] = true
  56. }
  57. }
  58. return ipLimitsMap
  59. }
  60. func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
  61. var tokens []*Token
  62. var err error
  63. err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error
  64. return tokens, err
  65. }
  66. func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) {
  67. if token != "" {
  68. token = strings.Trim(token, "sk-")
  69. }
  70. err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
  71. return tokens, err
  72. }
  73. func ValidateUserToken(key string) (token *Token, err error) {
  74. log.Println("===========", key)
  75. if key == "" {
  76. return nil, errors.New("未提供令牌")
  77. }
  78. token, err = GetTokenByKey(key, false)
  79. if err == nil {
  80. if token.Status == common.TokenStatusExhausted {
  81. keyPrefix := key[:3]
  82. keySuffix := key[len(key)-3:]
  83. return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
  84. } else if token.Status == common.TokenStatusExpired {
  85. return token, errors.New("该令牌已过期")
  86. }
  87. if token.Status != common.TokenStatusEnabled {
  88. return token, errors.New("该令牌状态不可用")
  89. }
  90. if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
  91. if !common.RedisEnabled {
  92. token.Status = common.TokenStatusExpired
  93. err := token.SelectUpdate()
  94. if err != nil {
  95. common.SysError("failed to update token status" + err.Error())
  96. }
  97. }
  98. return token, errors.New("该令牌已过期")
  99. }
  100. if !token.UnlimitedQuota && token.RemainQuota <= 0 {
  101. if !common.RedisEnabled {
  102. // in this case, we can make sure the token is exhausted
  103. token.Status = common.TokenStatusExhausted
  104. err := token.SelectUpdate()
  105. if err != nil {
  106. common.SysError("failed to update token status" + err.Error())
  107. }
  108. }
  109. keyPrefix := key[:3]
  110. keySuffix := key[len(key)-3:]
  111. return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
  112. }
  113. // 检查访问频率限制
  114. if err := CheckRateLimit(token); err != nil {
  115. return token, err
  116. }
  117. return token, nil
  118. }
  119. return nil, errors.New("无效的令牌")
  120. }
  121. func GetTokenByIds(id int, userId int) (*Token, error) {
  122. if id == 0 || userId == 0 {
  123. return nil, errors.New("id 或 userId 为空!")
  124. }
  125. token := Token{Id: id, UserId: userId}
  126. var err error = nil
  127. err = DB.First(&token, "id = ? and user_id = ?", id, userId).Error
  128. return &token, err
  129. }
  130. func GetTokenById(id int) (*Token, error) {
  131. if id == 0 {
  132. return nil, errors.New("id 为空!")
  133. }
  134. token := Token{Id: id}
  135. var err error = nil
  136. err = DB.First(&token, "id = ?", id).Error
  137. if shouldUpdateRedis(true, err) {
  138. gopool.Go(func() {
  139. if err := cacheSetToken(token); err != nil {
  140. common.SysError("failed to update user status cache: " + err.Error())
  141. }
  142. })
  143. }
  144. return &token, err
  145. }
  146. func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
  147. defer func() {
  148. // Update Redis cache asynchronously on successful DB read
  149. if shouldUpdateRedis(fromDB, err) && token != nil {
  150. gopool.Go(func() {
  151. if err := cacheSetToken(*token); err != nil {
  152. common.SysError("failed to update user status cache: " + err.Error())
  153. }
  154. })
  155. }
  156. }()
  157. if !fromDB && common.RedisEnabled {
  158. // Try Redis first
  159. token, err := cacheGetTokenByKey(key)
  160. if err == nil {
  161. return token, nil
  162. }
  163. // Don't return error - fall through to DB
  164. }
  165. fromDB = true
  166. err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error
  167. return token, err
  168. }
  169. func (token *Token) Insert() error {
  170. var err error
  171. err = DB.Create(token).Error
  172. return err
  173. }
  174. // Update Make sure your token's fields is completed, because this will update non-zero values
  175. func (token *Token) Update() (err error) {
  176. defer func() {
  177. if shouldUpdateRedis(true, err) {
  178. gopool.Go(func() {
  179. err := cacheSetToken(*token)
  180. if err != nil {
  181. common.SysError("failed to update token cache: " + err.Error())
  182. }
  183. })
  184. }
  185. }()
  186. err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
  187. "model_limits_enabled", "model_limits", "allow_ips", "group", "daily_usage_count", "total_usage_count", "last_usage_date",
  188. "rate_limit_per_minute", "rate_limit_per_day", "last_rate_limit_reset").Updates(token).Error
  189. return err
  190. }
  191. func (token *Token) SelectUpdate() (err error) {
  192. defer func() {
  193. if shouldUpdateRedis(true, err) {
  194. gopool.Go(func() {
  195. err := cacheSetToken(*token)
  196. if err != nil {
  197. common.SysError("failed to update token cache: " + err.Error())
  198. }
  199. })
  200. }
  201. }()
  202. // This can update zero values
  203. return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
  204. }
  205. func (token *Token) Delete() (err error) {
  206. defer func() {
  207. if shouldUpdateRedis(true, err) {
  208. gopool.Go(func() {
  209. err := cacheDeleteToken(token.Key)
  210. if err != nil {
  211. common.SysError("failed to delete token cache: " + err.Error())
  212. }
  213. })
  214. }
  215. }()
  216. err = DB.Delete(token).Error
  217. return err
  218. }
  219. func (token *Token) IsModelLimitsEnabled() bool {
  220. return token.ModelLimitsEnabled
  221. }
  222. func (token *Token) GetModelLimits() []string {
  223. if token.ModelLimits == "" {
  224. return []string{}
  225. }
  226. return strings.Split(token.ModelLimits, ",")
  227. }
  228. func (token *Token) GetModelLimitsMap() map[string]bool {
  229. limits := token.GetModelLimits()
  230. limitsMap := make(map[string]bool)
  231. for _, limit := range limits {
  232. limitsMap[limit] = true
  233. }
  234. return limitsMap
  235. }
  236. func DisableModelLimits(tokenId int) error {
  237. token, err := GetTokenById(tokenId)
  238. if err != nil {
  239. return err
  240. }
  241. token.ModelLimitsEnabled = false
  242. token.ModelLimits = ""
  243. return token.Update()
  244. }
  245. func DeleteTokenById(id int, userId int) (err error) {
  246. // Why we need userId here? In case user want to delete other's token.
  247. if id == 0 || userId == 0 {
  248. return errors.New("id 或 userId 为空!")
  249. }
  250. token := Token{Id: id, UserId: userId}
  251. err = DB.Where(token).First(&token).Error
  252. if err != nil {
  253. return err
  254. }
  255. return token.Delete()
  256. }
  257. func IncreaseTokenQuota(id int, key string, quota int) (err error) {
  258. if quota < 0 {
  259. return errors.New("quota 不能为负数!")
  260. }
  261. if common.RedisEnabled {
  262. gopool.Go(func() {
  263. err := cacheIncrTokenQuota(key, int64(quota))
  264. if err != nil {
  265. common.SysError("failed to increase token quota: " + err.Error())
  266. }
  267. })
  268. }
  269. if common.BatchUpdateEnabled {
  270. addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
  271. return nil
  272. }
  273. return increaseTokenQuota(id, quota)
  274. }
  275. func increaseTokenQuota(id int, quota int) (err error) {
  276. err = DB.Model(&Token{}).Where("id = ?", id).Updates(
  277. map[string]interface{}{
  278. "remain_quota": gorm.Expr("remain_quota + ?", quota),
  279. "used_quota": gorm.Expr("used_quota - ?", quota),
  280. "accessed_time": common.GetTimestamp(),
  281. },
  282. ).Error
  283. return err
  284. }
  285. func DecreaseTokenQuota(id int, key string, quota int) (err error) {
  286. if quota < 0 {
  287. return errors.New("quota 不能为负数!")
  288. }
  289. if common.RedisEnabled {
  290. gopool.Go(func() {
  291. err := cacheDecrTokenQuota(key, int64(quota))
  292. if err != nil {
  293. common.SysError("failed to decrease token quota: " + err.Error())
  294. }
  295. })
  296. }
  297. if common.BatchUpdateEnabled {
  298. addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
  299. return nil
  300. }
  301. return decreaseTokenQuota(id, quota)
  302. }
  303. func decreaseTokenQuota(id int, quota int) (err error) {
  304. err = DB.Model(&Token{}).Where("id = ?", id).Updates(
  305. map[string]interface{}{
  306. "remain_quota": gorm.Expr("remain_quota - ?", quota),
  307. "used_quota": gorm.Expr("used_quota + ?", quota),
  308. "accessed_time": common.GetTimestamp(),
  309. },
  310. ).Error
  311. return err
  312. }
  313. // CountUserTokens returns total number of tokens for the given user, used for pagination
  314. func CountUserTokens(userId int) (int64, error) {
  315. var total int64
  316. err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
  317. return total, err
  318. }
  319. // BatchDeleteTokens 删除指定用户的一组令牌,返回成功删除数量
  320. func BatchDeleteTokens(ids []int, userId int) (int, error) {
  321. if len(ids) == 0 {
  322. return 0, errors.New("ids 不能为空!")
  323. }
  324. tx := DB.Begin()
  325. var tokens []Token
  326. if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Find(&tokens).Error; err != nil {
  327. tx.Rollback()
  328. return 0, err
  329. }
  330. if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Delete(&Token{}).Error; err != nil {
  331. tx.Rollback()
  332. return 0, err
  333. }
  334. if err := tx.Commit().Error; err != nil {
  335. return 0, err
  336. }
  337. if common.RedisEnabled {
  338. gopool.Go(func() {
  339. for _, t := range tokens {
  340. _ = cacheDeleteToken(t.Key)
  341. }
  342. })
  343. }
  344. return len(tokens), nil
  345. }
  346. // IncreaseTokenUsageCount 增加Token的使用次数
  347. func IncreaseTokenUsageCount(key string) error {
  348. if key == "" {
  349. return errors.New("key 不能为空")
  350. }
  351. // 获取当前日期
  352. currentDate := common.GetTimeString()[:10] // YYYY-MM-DD格式
  353. // 更新数据库
  354. err := DB.Model(&Token{}).Where(commonKeyCol+" = ?", key).Updates(map[string]interface{}{
  355. "total_usage_count": gorm.Expr("total_usage_count + 1"),
  356. "daily_usage_count": gorm.Expr("CASE WHEN last_usage_date = ? THEN daily_usage_count + 1 ELSE 1 END", currentDate),
  357. "last_usage_date": currentDate,
  358. "accessed_time": common.GetTimestamp(),
  359. }).Error
  360. // 更新缓存
  361. if common.RedisEnabled && err == nil {
  362. gopool.Go(func() {
  363. // 重新缓存Token信息
  364. token, getErr := GetTokenByKey(key, true) // 从DB获取最新数据
  365. if getErr == nil {
  366. _ = cacheSetToken(*token)
  367. }
  368. })
  369. }
  370. return err
  371. }
  372. // CheckRateLimit 检查令牌的访问频率限制
  373. func CheckRateLimit(token *Token) error {
  374. if token == nil {
  375. return errors.New("token不能为空")
  376. }
  377. // 如果没有设置限制,直接返回
  378. if token.RateLimitPerMinute <= 0 && token.RateLimitPerDay <= 0 {
  379. return nil
  380. }
  381. currentTime := time.Now()
  382. currentTimestamp := currentTime.Unix()
  383. // 获取当前时间的分钟级时间戳(用于分钟级限制)
  384. currentMinute := currentTime.Truncate(time.Minute).Unix()
  385. // 获取当天开始的时间戳(用于日级限制)
  386. currentDay := time.Date(currentTime.Year(), currentTime.Month(), currentTime.Day(), 0, 0, 0, 0, currentTime.Location()).Unix()
  387. // 检查是否需要重置计数器
  388. needUpdate := false
  389. originalToken := *token // 保存原始状态
  390. // 检查分钟级限制
  391. if token.RateLimitPerMinute > 0 {
  392. // 如果上次重置时间不在当前分钟内,重置分钟计数器
  393. if token.LastRateLimitReset < currentMinute {
  394. token.LastRateLimitReset = currentTimestamp
  395. needUpdate = true
  396. }
  397. // 计算当前分钟内的使用次数
  398. var minuteCount int64
  399. err := DB.Model(&TokenUsageLog{}).Where("token_id = ? AND created_at >= ?", token.Id, currentMinute).Count(&minuteCount).Error
  400. if err != nil {
  401. common.SysError("检查分钟级使用次数失败: " + err.Error())
  402. return errors.New("系统错误,请稍后再试")
  403. }
  404. if int(minuteCount) >= token.RateLimitPerMinute {
  405. return errors.New("超出分钟限制,请稍后再试")
  406. }
  407. }
  408. // 检查日级限制
  409. if token.RateLimitPerDay > 0 {
  410. // 计算当天内的使用次数
  411. var dayCount int64
  412. err := DB.Model(&TokenUsageLog{}).Where("token_id = ? AND created_at >= ?", token.Id, currentDay).Count(&dayCount).Error
  413. if err != nil {
  414. common.SysError("检查日级使用次数失败: " + err.Error())
  415. return errors.New("系统错误,请稍后再试")
  416. }
  417. if int(dayCount) >= token.RateLimitPerDay {
  418. return errors.New("超出日限制,请稍后再试")
  419. }
  420. }
  421. // 如果需要更新,更新数据库
  422. if needUpdate {
  423. err := DB.Model(token).Select("last_rate_limit_reset").Updates(token).Error
  424. if err != nil {
  425. common.SysError("更新令牌重置时间失败: " + err.Error())
  426. // 恢复原始状态
  427. *token = originalToken
  428. }
  429. }
  430. return nil
  431. }
// TokenUsageLog is a token usage log table (used for rate limiting): one row per
// use recorded by RecordTokenUsage, counted by CheckRateLimit.
//
// NOTE(review): nothing in this file deletes rows, so the table grows without
// bound unless pruned elsewhere — confirm.
type TokenUsageLog struct {
	Id int `json:"id" gorm:"primaryKey"`
	// TokenId and CreatedAt form the composite index idx_token_created, matching
	// CheckRateLimit's "token_id = ? AND created_at >= ?" queries. The field order
	// here determines the index column order.
	TokenId int `json:"token_id" gorm:"index:idx_token_created"`
	// CreatedAt is a Unix timestamp in seconds (RecordTokenUsage sets time.Now().Unix()).
	CreatedAt int64 `json:"created_at" gorm:"index:idx_token_created;index:idx_created"`
}
  438. func (TokenUsageLog) TableName() string {
  439. return "token_usage_logs"
  440. }
  441. // RecordTokenUsage 记录令牌使用(用于频率限制)
  442. func RecordTokenUsage(tokenId int) error {
  443. if tokenId <= 0 {
  444. return errors.New("tokenId不能为空")
  445. }
  446. usageLog := TokenUsageLog{
  447. TokenId: tokenId,
  448. CreatedAt: time.Now().Unix(),
  449. }
  450. return DB.Create(&usageLog).Error
  451. }