auth.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. package middleware
  2. import (
  3. "github.com/gin-contrib/sessions"
  4. "github.com/gin-gonic/gin"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/model"
  8. "strconv"
  9. "strings"
  10. )
  11. func authHelper(c *gin.Context, minRole int) {
  12. session := sessions.Default(c)
  13. username := session.Get("username")
  14. role := session.Get("role")
  15. id := session.Get("id")
  16. status := session.Get("status")
  17. useAccessToken := false
  18. if username == nil {
  19. // Check access token
  20. accessToken := c.Request.Header.Get("Authorization")
  21. if accessToken == "" {
  22. c.JSON(http.StatusUnauthorized, gin.H{
  23. "success": false,
  24. "message": "无权进行此操作,未登录且未提供 access token",
  25. })
  26. c.Abort()
  27. return
  28. }
  29. user := model.ValidateAccessToken(accessToken)
  30. if user != nil && user.Username != "" {
  31. // Token is valid
  32. username = user.Username
  33. role = user.Role
  34. id = user.Id
  35. status = user.Status
  36. useAccessToken = true
  37. } else {
  38. c.JSON(http.StatusOK, gin.H{
  39. "success": false,
  40. "message": "无权进行此操作,access token 无效",
  41. })
  42. c.Abort()
  43. return
  44. }
  45. }
  46. if !useAccessToken {
  47. // get header New-Api-User
  48. apiUserIdStr := c.Request.Header.Get("New-Api-User")
  49. if apiUserIdStr == "" {
  50. c.JSON(http.StatusUnauthorized, gin.H{
  51. "success": false,
  52. "message": "无权进行此操作,请刷新页面或清空缓存后重试",
  53. })
  54. c.Abort()
  55. return
  56. }
  57. apiUserId, err := strconv.Atoi(apiUserIdStr)
  58. if err != nil {
  59. c.JSON(http.StatusUnauthorized, gin.H{
  60. "success": false,
  61. "message": "无权进行此操作,登录信息无效,请重新登录",
  62. })
  63. c.Abort()
  64. return
  65. }
  66. if id != apiUserId {
  67. c.JSON(http.StatusUnauthorized, gin.H{
  68. "success": false,
  69. "message": "无权进行此操作,与登录用户不匹配,请重新登录",
  70. })
  71. c.Abort()
  72. return
  73. }
  74. }
  75. if status.(int) == common.UserStatusDisabled {
  76. c.JSON(http.StatusOK, gin.H{
  77. "success": false,
  78. "message": "用户已被封禁",
  79. })
  80. c.Abort()
  81. return
  82. }
  83. if role.(int) < minRole {
  84. c.JSON(http.StatusOK, gin.H{
  85. "success": false,
  86. "message": "无权进行此操作,权限不足",
  87. })
  88. c.Abort()
  89. return
  90. }
  91. c.Set("username", username)
  92. c.Set("role", role)
  93. c.Set("id", id)
  94. c.Next()
  95. }
  96. func TryUserAuth() func(c *gin.Context) {
  97. return func(c *gin.Context) {
  98. session := sessions.Default(c)
  99. id := session.Get("id")
  100. if id != nil {
  101. c.Set("id", id)
  102. }
  103. c.Next()
  104. }
  105. }
  106. func UserAuth() func(c *gin.Context) {
  107. return func(c *gin.Context) {
  108. authHelper(c, common.RoleCommonUser)
  109. }
  110. }
  111. func AdminAuth() func(c *gin.Context) {
  112. return func(c *gin.Context) {
  113. authHelper(c, common.RoleAdminUser)
  114. }
  115. }
  116. func RootAuth() func(c *gin.Context) {
  117. return func(c *gin.Context) {
  118. authHelper(c, common.RoleRootUser)
  119. }
  120. }
  121. func TokenAuth() func(c *gin.Context) {
  122. return func(c *gin.Context) {
  123. key := c.Request.Header.Get("Authorization")
  124. parts := make([]string, 0)
  125. key = strings.TrimPrefix(key, "Bearer ")
  126. if key == "" || key == "midjourney-proxy" {
  127. key = c.Request.Header.Get("mj-api-secret")
  128. key = strings.TrimPrefix(key, "Bearer ")
  129. key = strings.TrimPrefix(key, "sk-")
  130. parts = strings.Split(key, "-")
  131. key = parts[0]
  132. } else {
  133. key = strings.TrimPrefix(key, "sk-")
  134. parts = strings.Split(key, "-")
  135. key = parts[0]
  136. }
  137. token, err := model.ValidateUserToken(key)
  138. if err != nil {
  139. abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
  140. return
  141. }
  142. userEnabled, err := model.CacheIsUserEnabled(token.UserId)
  143. if err != nil {
  144. abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
  145. return
  146. }
  147. if !userEnabled {
  148. abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
  149. return
  150. }
  151. c.Set("id", token.UserId)
  152. c.Set("token_id", token.Id)
  153. c.Set("token_name", token.Name)
  154. c.Set("token_unlimited_quota", token.UnlimitedQuota)
  155. if !token.UnlimitedQuota {
  156. c.Set("token_quota", token.RemainQuota)
  157. }
  158. if token.ModelLimitsEnabled {
  159. c.Set("token_model_limit_enabled", true)
  160. c.Set("token_model_limit", token.GetModelLimitsMap())
  161. } else {
  162. c.Set("token_model_limit_enabled", false)
  163. }
  164. if len(parts) > 1 {
  165. if model.IsAdmin(token.UserId) {
  166. c.Set("specific_channel_id", parts[1])
  167. } else {
  168. abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
  169. return
  170. }
  171. }
  172. c.Next()
  173. }
  174. }