token.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. package controller
  2. import (
  3. "fmt"
  4. "net/http"
  5. "strconv"
  6. "strings"
  7. "github.com/QuantumNous/new-api/common"
  8. "github.com/QuantumNous/new-api/model"
  9. "github.com/gin-gonic/gin"
  10. )
  11. func GetAllTokens(c *gin.Context) {
  12. userId := c.GetInt("id")
  13. pageInfo := common.GetPageQuery(c)
  14. tokens, err := model.GetAllUserTokens(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
  15. if err != nil {
  16. common.ApiError(c, err)
  17. return
  18. }
  19. total, _ := model.CountUserTokens(userId)
  20. pageInfo.SetTotal(int(total))
  21. pageInfo.SetItems(tokens)
  22. common.ApiSuccess(c, pageInfo)
  23. return
  24. }
  25. func SearchTokens(c *gin.Context) {
  26. userId := c.GetInt("id")
  27. keyword := c.Query("keyword")
  28. token := c.Query("token")
  29. tokens, err := model.SearchUserTokens(userId, keyword, token)
  30. if err != nil {
  31. common.ApiError(c, err)
  32. return
  33. }
  34. c.JSON(http.StatusOK, gin.H{
  35. "success": true,
  36. "message": "",
  37. "data": tokens,
  38. })
  39. return
  40. }
  41. func GetToken(c *gin.Context) {
  42. id, err := strconv.Atoi(c.Param("id"))
  43. userId := c.GetInt("id")
  44. if err != nil {
  45. common.ApiError(c, err)
  46. return
  47. }
  48. token, err := model.GetTokenByIds(id, userId)
  49. if err != nil {
  50. common.ApiError(c, err)
  51. return
  52. }
  53. c.JSON(http.StatusOK, gin.H{
  54. "success": true,
  55. "message": "",
  56. "data": token,
  57. })
  58. return
  59. }
  60. func GetTokenStatus(c *gin.Context) {
  61. tokenId := c.GetInt("token_id")
  62. userId := c.GetInt("id")
  63. token, err := model.GetTokenByIds(tokenId, userId)
  64. if err != nil {
  65. common.ApiError(c, err)
  66. return
  67. }
  68. expiredAt := token.ExpiredTime
  69. if expiredAt == -1 {
  70. expiredAt = 0
  71. }
  72. c.JSON(http.StatusOK, gin.H{
  73. "object": "credit_summary",
  74. "total_granted": token.RemainQuota,
  75. "total_used": 0, // not supported currently
  76. "total_available": token.RemainQuota,
  77. "expires_at": expiredAt * 1000,
  78. })
  79. }
  80. func GetTokenUsage(c *gin.Context) {
  81. authHeader := c.GetHeader("Authorization")
  82. if authHeader == "" {
  83. c.JSON(http.StatusUnauthorized, gin.H{
  84. "success": false,
  85. "message": "No Authorization header",
  86. })
  87. return
  88. }
  89. parts := strings.Split(authHeader, " ")
  90. if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
  91. c.JSON(http.StatusUnauthorized, gin.H{
  92. "success": false,
  93. "message": "Invalid Bearer token",
  94. })
  95. return
  96. }
  97. tokenKey := parts[1]
  98. token, err := model.GetTokenByKey(strings.TrimPrefix(tokenKey, "sk-"), false)
  99. if err != nil {
  100. common.SysError("failed to get token by key: " + err.Error())
  101. c.JSON(http.StatusOK, gin.H{
  102. "success": false,
  103. "message": "获取令牌信息失败,请稍后重试",
  104. })
  105. return
  106. }
  107. expiredAt := token.ExpiredTime
  108. if expiredAt == -1 {
  109. expiredAt = 0
  110. }
  111. c.JSON(http.StatusOK, gin.H{
  112. "code": true,
  113. "message": "ok",
  114. "data": gin.H{
  115. "object": "token_usage",
  116. "name": token.Name,
  117. "total_granted": token.RemainQuota + token.UsedQuota,
  118. "total_used": token.UsedQuota,
  119. "total_available": token.RemainQuota,
  120. "unlimited_quota": token.UnlimitedQuota,
  121. "model_limits": token.GetModelLimitsMap(),
  122. "model_limits_enabled": token.ModelLimitsEnabled,
  123. "expires_at": expiredAt,
  124. },
  125. })
  126. }
  127. func AddToken(c *gin.Context) {
  128. token := model.Token{}
  129. err := c.ShouldBindJSON(&token)
  130. if err != nil {
  131. common.ApiError(c, err)
  132. return
  133. }
  134. if len(token.Name) > 50 {
  135. c.JSON(http.StatusOK, gin.H{
  136. "success": false,
  137. "message": "令牌名称过长",
  138. })
  139. return
  140. }
  141. // 非无限额度时,检查额度值是否超出有效范围
  142. if !token.UnlimitedQuota {
  143. if token.RemainQuota < 0 {
  144. c.JSON(http.StatusOK, gin.H{
  145. "success": false,
  146. "message": "额度值不能为负数",
  147. })
  148. return
  149. }
  150. maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
  151. if token.RemainQuota > maxQuotaValue {
  152. c.JSON(http.StatusOK, gin.H{
  153. "success": false,
  154. "message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
  155. })
  156. return
  157. }
  158. }
  159. key, err := common.GenerateKey()
  160. if err != nil {
  161. c.JSON(http.StatusOK, gin.H{
  162. "success": false,
  163. "message": "生成令牌失败",
  164. })
  165. common.SysLog("failed to generate token key: " + err.Error())
  166. return
  167. }
  168. cleanToken := model.Token{
  169. UserId: c.GetInt("id"),
  170. Name: token.Name,
  171. Key: key,
  172. CreatedTime: common.GetTimestamp(),
  173. AccessedTime: common.GetTimestamp(),
  174. ExpiredTime: token.ExpiredTime,
  175. RemainQuota: token.RemainQuota,
  176. UnlimitedQuota: token.UnlimitedQuota,
  177. ModelLimitsEnabled: token.ModelLimitsEnabled,
  178. ModelLimits: token.ModelLimits,
  179. AllowIps: token.AllowIps,
  180. Group: token.Group,
  181. CrossGroupRetry: token.CrossGroupRetry,
  182. }
  183. err = cleanToken.Insert()
  184. if err != nil {
  185. common.ApiError(c, err)
  186. return
  187. }
  188. c.JSON(http.StatusOK, gin.H{
  189. "success": true,
  190. "message": "",
  191. })
  192. return
  193. }
  194. func DeleteToken(c *gin.Context) {
  195. id, _ := strconv.Atoi(c.Param("id"))
  196. userId := c.GetInt("id")
  197. err := model.DeleteTokenById(id, userId)
  198. if err != nil {
  199. common.ApiError(c, err)
  200. return
  201. }
  202. c.JSON(http.StatusOK, gin.H{
  203. "success": true,
  204. "message": "",
  205. })
  206. return
  207. }
  208. func UpdateToken(c *gin.Context) {
  209. userId := c.GetInt("id")
  210. statusOnly := c.Query("status_only")
  211. token := model.Token{}
  212. err := c.ShouldBindJSON(&token)
  213. if err != nil {
  214. common.ApiError(c, err)
  215. return
  216. }
  217. if len(token.Name) > 50 {
  218. c.JSON(http.StatusOK, gin.H{
  219. "success": false,
  220. "message": "令牌名称过长",
  221. })
  222. return
  223. }
  224. if !token.UnlimitedQuota {
  225. if token.RemainQuota < 0 {
  226. c.JSON(http.StatusOK, gin.H{
  227. "success": false,
  228. "message": "额度值不能为负数",
  229. })
  230. return
  231. }
  232. maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
  233. if token.RemainQuota > maxQuotaValue {
  234. c.JSON(http.StatusOK, gin.H{
  235. "success": false,
  236. "message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
  237. })
  238. return
  239. }
  240. }
  241. cleanToken, err := model.GetTokenByIds(token.Id, userId)
  242. if err != nil {
  243. common.ApiError(c, err)
  244. return
  245. }
  246. if token.Status == common.TokenStatusEnabled {
  247. if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
  248. c.JSON(http.StatusOK, gin.H{
  249. "success": false,
  250. "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
  251. })
  252. return
  253. }
  254. if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
  255. c.JSON(http.StatusOK, gin.H{
  256. "success": false,
  257. "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",
  258. })
  259. return
  260. }
  261. }
  262. if statusOnly != "" {
  263. cleanToken.Status = token.Status
  264. } else {
  265. // If you add more fields, please also update token.Update()
  266. cleanToken.Name = token.Name
  267. cleanToken.ExpiredTime = token.ExpiredTime
  268. cleanToken.RemainQuota = token.RemainQuota
  269. cleanToken.UnlimitedQuota = token.UnlimitedQuota
  270. cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
  271. cleanToken.ModelLimits = token.ModelLimits
  272. cleanToken.AllowIps = token.AllowIps
  273. cleanToken.Group = token.Group
  274. cleanToken.CrossGroupRetry = token.CrossGroupRetry
  275. }
  276. err = cleanToken.Update()
  277. if err != nil {
  278. common.ApiError(c, err)
  279. return
  280. }
  281. c.JSON(http.StatusOK, gin.H{
  282. "success": true,
  283. "message": "",
  284. "data": cleanToken,
  285. })
  286. }
  287. type TokenBatch struct {
  288. Ids []int `json:"ids"`
  289. }
  290. func DeleteTokenBatch(c *gin.Context) {
  291. tokenBatch := TokenBatch{}
  292. if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
  293. c.JSON(http.StatusOK, gin.H{
  294. "success": false,
  295. "message": "参数错误",
  296. })
  297. return
  298. }
  299. userId := c.GetInt("id")
  300. count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
  301. if err != nil {
  302. common.ApiError(c, err)
  303. return
  304. }
  305. c.JSON(http.StatusOK, gin.H{
  306. "success": true,
  307. "message": "",
  308. "data": count,
  309. })
  310. }