token.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. package controller
  2. import (
  3. "fmt"
  4. "net/http"
  5. "strconv"
  6. "strings"
  7. "github.com/QuantumNous/new-api/common"
  8. "github.com/QuantumNous/new-api/i18n"
  9. "github.com/QuantumNous/new-api/model"
  10. "github.com/QuantumNous/new-api/setting/operation_setting"
  11. "github.com/gin-gonic/gin"
  12. )
  13. func buildMaskedTokenResponse(token *model.Token) *model.Token {
  14. if token == nil {
  15. return nil
  16. }
  17. maskedToken := *token
  18. maskedToken.Key = token.GetMaskedKey()
  19. return &maskedToken
  20. }
  21. func buildMaskedTokenResponses(tokens []*model.Token) []*model.Token {
  22. maskedTokens := make([]*model.Token, 0, len(tokens))
  23. for _, token := range tokens {
  24. maskedTokens = append(maskedTokens, buildMaskedTokenResponse(token))
  25. }
  26. return maskedTokens
  27. }
  28. func GetAllTokens(c *gin.Context) {
  29. userId := c.GetInt("id")
  30. pageInfo := common.GetPageQuery(c)
  31. tokens, err := model.GetAllUserTokens(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
  32. if err != nil {
  33. common.ApiError(c, err)
  34. return
  35. }
  36. total, _ := model.CountUserTokens(userId)
  37. pageInfo.SetTotal(int(total))
  38. pageInfo.SetItems(buildMaskedTokenResponses(tokens))
  39. common.ApiSuccess(c, pageInfo)
  40. }
  41. func SearchTokens(c *gin.Context) {
  42. userId := c.GetInt("id")
  43. keyword := c.Query("keyword")
  44. token := c.Query("token")
  45. pageInfo := common.GetPageQuery(c)
  46. tokens, total, err := model.SearchUserTokens(userId, keyword, token, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
  47. if err != nil {
  48. common.ApiError(c, err)
  49. return
  50. }
  51. pageInfo.SetTotal(int(total))
  52. pageInfo.SetItems(buildMaskedTokenResponses(tokens))
  53. common.ApiSuccess(c, pageInfo)
  54. }
  55. func GetToken(c *gin.Context) {
  56. id, err := strconv.Atoi(c.Param("id"))
  57. userId := c.GetInt("id")
  58. if err != nil {
  59. common.ApiError(c, err)
  60. return
  61. }
  62. token, err := model.GetTokenByIds(id, userId)
  63. if err != nil {
  64. common.ApiError(c, err)
  65. return
  66. }
  67. common.ApiSuccess(c, buildMaskedTokenResponse(token))
  68. }
  69. func GetTokenKey(c *gin.Context) {
  70. id, err := strconv.Atoi(c.Param("id"))
  71. userId := c.GetInt("id")
  72. if err != nil {
  73. common.ApiError(c, err)
  74. return
  75. }
  76. token, err := model.GetTokenByIds(id, userId)
  77. if err != nil {
  78. common.ApiError(c, err)
  79. return
  80. }
  81. common.ApiSuccess(c, gin.H{
  82. "key": token.GetFullKey(),
  83. })
  84. }
  85. func GetTokenStatus(c *gin.Context) {
  86. tokenId := c.GetInt("token_id")
  87. userId := c.GetInt("id")
  88. token, err := model.GetTokenByIds(tokenId, userId)
  89. if err != nil {
  90. common.ApiError(c, err)
  91. return
  92. }
  93. expiredAt := token.ExpiredTime
  94. if expiredAt == -1 {
  95. expiredAt = 0
  96. }
  97. c.JSON(http.StatusOK, gin.H{
  98. "object": "credit_summary",
  99. "total_granted": token.RemainQuota,
  100. "total_used": 0, // not supported currently
  101. "total_available": token.RemainQuota,
  102. "expires_at": expiredAt * 1000,
  103. })
  104. }
  105. func GetTokenUsage(c *gin.Context) {
  106. authHeader := c.GetHeader("Authorization")
  107. if authHeader == "" {
  108. c.JSON(http.StatusUnauthorized, gin.H{
  109. "success": false,
  110. "message": "No Authorization header",
  111. })
  112. return
  113. }
  114. parts := strings.Split(authHeader, " ")
  115. if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
  116. c.JSON(http.StatusUnauthorized, gin.H{
  117. "success": false,
  118. "message": "Invalid Bearer token",
  119. })
  120. return
  121. }
  122. tokenKey := parts[1]
  123. token, err := model.GetTokenByKey(strings.TrimPrefix(tokenKey, "sk-"), false)
  124. if err != nil {
  125. common.SysError("failed to get token by key: " + err.Error())
  126. common.ApiErrorI18n(c, i18n.MsgTokenGetInfoFailed)
  127. return
  128. }
  129. expiredAt := token.ExpiredTime
  130. if expiredAt == -1 {
  131. expiredAt = 0
  132. }
  133. c.JSON(http.StatusOK, gin.H{
  134. "code": true,
  135. "message": "ok",
  136. "data": gin.H{
  137. "object": "token_usage",
  138. "name": token.Name,
  139. "total_granted": token.RemainQuota + token.UsedQuota,
  140. "total_used": token.UsedQuota,
  141. "total_available": token.RemainQuota,
  142. "unlimited_quota": token.UnlimitedQuota,
  143. "model_limits": token.GetModelLimitsMap(),
  144. "model_limits_enabled": token.ModelLimitsEnabled,
  145. "expires_at": expiredAt,
  146. },
  147. })
  148. }
  149. func AddToken(c *gin.Context) {
  150. token := model.Token{}
  151. err := c.ShouldBindJSON(&token)
  152. if err != nil {
  153. common.ApiError(c, err)
  154. return
  155. }
  156. if len(token.Name) > 50 {
  157. common.ApiErrorI18n(c, i18n.MsgTokenNameTooLong)
  158. return
  159. }
  160. // 非无限额度时,检查额度值是否超出有效范围
  161. if !token.UnlimitedQuota {
  162. if token.RemainQuota < 0 {
  163. common.ApiErrorI18n(c, i18n.MsgTokenQuotaNegative)
  164. return
  165. }
  166. maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
  167. if token.RemainQuota > maxQuotaValue {
  168. common.ApiErrorI18n(c, i18n.MsgTokenQuotaExceedMax, map[string]any{"Max": maxQuotaValue})
  169. return
  170. }
  171. }
  172. // 检查用户令牌数量是否已达上限
  173. maxTokens := operation_setting.GetMaxUserTokens()
  174. count, err := model.CountUserTokens(c.GetInt("id"))
  175. if err != nil {
  176. common.ApiError(c, err)
  177. return
  178. }
  179. if int(count) >= maxTokens {
  180. c.JSON(http.StatusOK, gin.H{
  181. "success": false,
  182. "message": fmt.Sprintf("已达到最大令牌数量限制 (%d)", maxTokens),
  183. })
  184. return
  185. }
  186. key, err := common.GenerateKey()
  187. if err != nil {
  188. common.ApiErrorI18n(c, i18n.MsgTokenGenerateFailed)
  189. common.SysLog("failed to generate token key: " + err.Error())
  190. return
  191. }
  192. cleanToken := model.Token{
  193. UserId: c.GetInt("id"),
  194. Name: token.Name,
  195. Key: key,
  196. CreatedTime: common.GetTimestamp(),
  197. AccessedTime: common.GetTimestamp(),
  198. ExpiredTime: token.ExpiredTime,
  199. RemainQuota: token.RemainQuota,
  200. UnlimitedQuota: token.UnlimitedQuota,
  201. ModelLimitsEnabled: token.ModelLimitsEnabled,
  202. ModelLimits: token.ModelLimits,
  203. AllowIps: token.AllowIps,
  204. Group: token.Group,
  205. CrossGroupRetry: token.CrossGroupRetry,
  206. }
  207. err = cleanToken.Insert()
  208. if err != nil {
  209. common.ApiError(c, err)
  210. return
  211. }
  212. c.JSON(http.StatusOK, gin.H{
  213. "success": true,
  214. "message": "",
  215. })
  216. }
  217. func DeleteToken(c *gin.Context) {
  218. id, _ := strconv.Atoi(c.Param("id"))
  219. userId := c.GetInt("id")
  220. err := model.DeleteTokenById(id, userId)
  221. if err != nil {
  222. common.ApiError(c, err)
  223. return
  224. }
  225. c.JSON(http.StatusOK, gin.H{
  226. "success": true,
  227. "message": "",
  228. })
  229. }
  230. func UpdateToken(c *gin.Context) {
  231. userId := c.GetInt("id")
  232. statusOnly := c.Query("status_only")
  233. token := model.Token{}
  234. err := c.ShouldBindJSON(&token)
  235. if err != nil {
  236. common.ApiError(c, err)
  237. return
  238. }
  239. if len(token.Name) > 50 {
  240. common.ApiErrorI18n(c, i18n.MsgTokenNameTooLong)
  241. return
  242. }
  243. if !token.UnlimitedQuota {
  244. if token.RemainQuota < 0 {
  245. common.ApiErrorI18n(c, i18n.MsgTokenQuotaNegative)
  246. return
  247. }
  248. maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
  249. if token.RemainQuota > maxQuotaValue {
  250. common.ApiErrorI18n(c, i18n.MsgTokenQuotaExceedMax, map[string]any{"Max": maxQuotaValue})
  251. return
  252. }
  253. }
  254. cleanToken, err := model.GetTokenByIds(token.Id, userId)
  255. if err != nil {
  256. common.ApiError(c, err)
  257. return
  258. }
  259. if token.Status == common.TokenStatusEnabled {
  260. if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
  261. common.ApiErrorI18n(c, i18n.MsgTokenExpiredCannotEnable)
  262. return
  263. }
  264. if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
  265. common.ApiErrorI18n(c, i18n.MsgTokenExhaustedCannotEable)
  266. return
  267. }
  268. }
  269. if statusOnly != "" {
  270. cleanToken.Status = token.Status
  271. } else {
  272. // If you add more fields, please also update token.Update()
  273. cleanToken.Name = token.Name
  274. cleanToken.ExpiredTime = token.ExpiredTime
  275. cleanToken.RemainQuota = token.RemainQuota
  276. cleanToken.UnlimitedQuota = token.UnlimitedQuota
  277. cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
  278. cleanToken.ModelLimits = token.ModelLimits
  279. cleanToken.AllowIps = token.AllowIps
  280. cleanToken.Group = token.Group
  281. cleanToken.CrossGroupRetry = token.CrossGroupRetry
  282. }
  283. err = cleanToken.Update()
  284. if err != nil {
  285. common.ApiError(c, err)
  286. return
  287. }
  288. c.JSON(http.StatusOK, gin.H{
  289. "success": true,
  290. "message": "",
  291. "data": buildMaskedTokenResponse(cleanToken),
  292. })
  293. }
  294. type TokenBatch struct {
  295. Ids []int `json:"ids"`
  296. }
  297. func DeleteTokenBatch(c *gin.Context) {
  298. tokenBatch := TokenBatch{}
  299. if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
  300. common.ApiErrorI18n(c, i18n.MsgInvalidParams)
  301. return
  302. }
  303. userId := c.GetInt("id")
  304. count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
  305. if err != nil {
  306. common.ApiError(c, err)
  307. return
  308. }
  309. c.JSON(http.StatusOK, gin.H{
  310. "success": true,
  311. "message": "",
  312. "data": count,
  313. })
  314. }