token.go 7.0 KB


  1. package model
  2. import (
  3. "errors"
  4. "fmt"
  5. "gorm.io/gorm"
  6. "one-api/common"
  7. "strings"
  8. )
  // Token is an access key owned by a user, together with its status,
// lifetime and quota bookkeeping.
//
// NOTE(review): GORM does not recognise a bare "bigint" tag key (the explicit
// form is "type:bigint"), so those tags are currently no-ops. They are kept
// verbatim here because changing them could alter the migrated schema.
type Token struct {
	Id     int    `json:"id"`
	UserId int    `json:"user_id"` // owner of the token
	Key    string `json:"key" gorm:"type:char(48);uniqueIndex"`
	Status int    `json:"status" gorm:"default:1"` // compared against common.TokenStatus* values
	Name   string `json:"name" gorm:"index" `

	// Timestamps; AccessedTime is refreshed from common.GetTimestamp() on
	// quota changes (unit is whatever that helper returns — confirm).
	CreatedTime  int64 `json:"created_time" gorm:"bigint"`
	AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
	ExpiredTime  int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means the token never expires

	// Quota accounting: RemainQuota is ignored when UnlimitedQuota is set;
	// UsedQuota accumulates what has been consumed.
	RemainQuota    int  `json:"remain_quota" gorm:"default:0"`
	UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
	UsedQuota      int  `json:"used_quota" gorm:"default:0"`
}
  22. func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
  23. var tokens []*Token
  24. var err error
  25. err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error
  26. return tokens, err
  27. }
  28. func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) {
  29. if token != "" {
  30. token = strings.Trim(token, "sk-")
  31. }
  32. err = DB.Where("user_id = ?", userId).Where("name LIKE ?", keyword+"%").Where("key LIKE ?", token+"%").Find(&tokens).Error
  33. return tokens, err
  34. }
  35. func ValidateUserToken(key string) (token *Token, err error) {
  36. if key == "" {
  37. return nil, errors.New("未提供令牌")
  38. }
  39. token, err = CacheGetTokenByKey(key)
  40. if err == nil {
  41. if token.Status == common.TokenStatusExhausted {
  42. return nil, errors.New("该令牌额度已用尽")
  43. } else if token.Status == common.TokenStatusExpired {
  44. return nil, errors.New("该令牌已过期")
  45. }
  46. if token.Status != common.TokenStatusEnabled {
  47. return nil, errors.New("该令牌状态不可用")
  48. }
  49. if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
  50. if !common.RedisEnabled {
  51. token.Status = common.TokenStatusExpired
  52. err := token.SelectUpdate()
  53. if err != nil {
  54. common.SysError("failed to update token status" + err.Error())
  55. }
  56. }
  57. return nil, errors.New("该令牌已过期")
  58. }
  59. if !token.UnlimitedQuota && token.RemainQuota <= 0 {
  60. if !common.RedisEnabled {
  61. // in this case, we can make sure the token is exhausted
  62. token.Status = common.TokenStatusExhausted
  63. err := token.SelectUpdate()
  64. if err != nil {
  65. common.SysError("failed to update token status" + err.Error())
  66. }
  67. }
  68. return nil, errors.New("该令牌额度已用尽")
  69. }
  70. return token, nil
  71. }
  72. return nil, errors.New("无效的令牌")
  73. }
  74. func GetTokenByIds(id int, userId int) (*Token, error) {
  75. if id == 0 || userId == 0 {
  76. return nil, errors.New("id 或 userId 为空!")
  77. }
  78. token := Token{Id: id, UserId: userId}
  79. var err error = nil
  80. err = DB.First(&token, "id = ? and user_id = ?", id, userId).Error
  81. return &token, err
  82. }
  83. func GetTokenById(id int) (*Token, error) {
  84. if id == 0 {
  85. return nil, errors.New("id 为空!")
  86. }
  87. token := Token{Id: id}
  88. var err error = nil
  89. err = DB.First(&token, "id = ?", id).Error
  90. return &token, err
  91. }
  92. func (token *Token) Insert() error {
  93. var err error
  94. err = DB.Create(token).Error
  95. return err
  96. }
  97. // Update Make sure your token's fields is completed, because this will update non-zero values
  98. func (token *Token) Update() error {
  99. var err error
  100. err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
  101. return err
  102. }
  103. func (token *Token) SelectUpdate() error {
  104. // This can update zero values
  105. return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
  106. }
  107. func (token *Token) Delete() error {
  108. var err error
  109. err = DB.Delete(token).Error
  110. return err
  111. }
  112. func DeleteTokenById(id int, userId int) (err error) {
  113. // Why we need userId here? In case user want to delete other's token.
  114. if id == 0 || userId == 0 {
  115. return errors.New("id 或 userId 为空!")
  116. }
  117. token := Token{Id: id, UserId: userId}
  118. err = DB.Where(token).First(&token).Error
  119. if err != nil {
  120. return err
  121. }
  122. return token.Delete()
  123. }
  124. func IncreaseTokenQuota(id int, quota int) (err error) {
  125. if quota < 0 {
  126. return errors.New("quota 不能为负数!")
  127. }
  128. if common.BatchUpdateEnabled {
  129. addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
  130. return nil
  131. }
  132. return increaseTokenQuota(id, quota)
  133. }
  134. func increaseTokenQuota(id int, quota int) (err error) {
  135. err = DB.Model(&Token{}).Where("id = ?", id).Updates(
  136. map[string]interface{}{
  137. "remain_quota": gorm.Expr("remain_quota + ?", quota),
  138. "used_quota": gorm.Expr("used_quota - ?", quota),
  139. "accessed_time": common.GetTimestamp(),
  140. },
  141. ).Error
  142. return err
  143. }
  144. func DecreaseTokenQuota(id int, quota int) (err error) {
  145. if quota < 0 {
  146. return errors.New("quota 不能为负数!")
  147. }
  148. if common.BatchUpdateEnabled {
  149. addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
  150. return nil
  151. }
  152. return decreaseTokenQuota(id, quota)
  153. }
  154. func decreaseTokenQuota(id int, quota int) (err error) {
  155. err = DB.Model(&Token{}).Where("id = ?", id).Updates(
  156. map[string]interface{}{
  157. "remain_quota": gorm.Expr("remain_quota - ?", quota),
  158. "used_quota": gorm.Expr("used_quota + ?", quota),
  159. "accessed_time": common.GetTimestamp(),
  160. },
  161. ).Error
  162. return err
  163. }
  164. func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
  165. if quota < 0 {
  166. return errors.New("quota 不能为负数!")
  167. }
  168. token, err := GetTokenById(tokenId)
  169. if err != nil {
  170. return err
  171. }
  172. if !token.UnlimitedQuota && token.RemainQuota < quota {
  173. return errors.New("令牌额度不足")
  174. }
  175. userQuota, err := GetUserQuota(token.UserId)
  176. if err != nil {
  177. return err
  178. }
  179. if userQuota < quota {
  180. return errors.New("用户额度不足")
  181. }
  182. quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold
  183. noMoreQuota := userQuota-quota <= 0
  184. if quotaTooLow || noMoreQuota {
  185. go func() {
  186. email, err := GetUserEmail(token.UserId)
  187. if err != nil {
  188. common.SysError("failed to fetch user email: " + err.Error())
  189. }
  190. prompt := "您的额度即将用尽"
  191. if noMoreQuota {
  192. prompt = "您的额度已用尽"
  193. }
  194. if email != "" {
  195. topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
  196. err = common.SendEmail(prompt, email,
  197. fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
  198. if err != nil {
  199. common.SysError("failed to send email" + err.Error())
  200. }
  201. }
  202. }()
  203. }
  204. if !token.UnlimitedQuota {
  205. err = DecreaseTokenQuota(tokenId, quota)
  206. if err != nil {
  207. return err
  208. }
  209. }
  210. err = DecreaseUserQuota(token.UserId, quota)
  211. return err
  212. }
  213. func PostConsumeTokenQuota(tokenId int, quota int) (err error) {
  214. token, err := GetTokenById(tokenId)
  215. if quota > 0 {
  216. err = DecreaseUserQuota(token.UserId, quota)
  217. } else {
  218. err = IncreaseUserQuota(token.UserId, -quota)
  219. }
  220. if err != nil {
  221. return err
  222. }
  223. if !token.UnlimitedQuota {
  224. if quota > 0 {
  225. err = DecreaseTokenQuota(tokenId, quota)
  226. } else {
  227. err = IncreaseTokenQuota(tokenId, -quota)
  228. }
  229. if err != nil {
  230. return err
  231. }
  232. }
  233. return nil
  234. }