auth.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. package middleware
  2. import (
  3. "fmt"
  4. "log"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/constant"
  8. "one-api/model"
  9. "strconv"
  10. "strings"
  11. "github.com/gin-contrib/sessions"
  12. "github.com/gin-gonic/gin"
  13. )
  14. func validUserInfo(username string, role int) bool {
  15. // check username is empty
  16. if strings.TrimSpace(username) == "" {
  17. return false
  18. }
  19. if !common.IsValidateRole(role) {
  20. return false
  21. }
  22. return true
  23. }
  24. func authHelper(c *gin.Context, minRole int) {
  25. session := sessions.Default(c)
  26. username := session.Get("username")
  27. role := session.Get("role")
  28. id := session.Get("id")
  29. status := session.Get("status")
  30. useAccessToken := false
  31. if username == nil {
  32. // Check access token
  33. accessToken := c.Request.Header.Get("Authorization")
  34. if accessToken == "" {
  35. c.JSON(http.StatusUnauthorized, gin.H{
  36. "success": false,
  37. "message": "无权进行此操作,未登录且未提供 access token",
  38. })
  39. c.Abort()
  40. return
  41. }
  42. user := model.ValidateAccessToken(accessToken)
  43. if user != nil && user.Username != "" {
  44. if !validUserInfo(user.Username, user.Role) {
  45. c.JSON(http.StatusOK, gin.H{
  46. "success": false,
  47. "message": "无权进行此操作,用户信息无效",
  48. })
  49. c.Abort()
  50. return
  51. }
  52. // Token is valid
  53. username = user.Username
  54. role = user.Role
  55. id = user.Id
  56. status = user.Status
  57. useAccessToken = true
  58. } else {
  59. c.JSON(http.StatusOK, gin.H{
  60. "success": false,
  61. "message": "无权进行此操作,access token 无效",
  62. })
  63. c.Abort()
  64. return
  65. }
  66. }
  67. // get header New-Api-User
  68. apiUserIdStr := c.Request.Header.Get("New-Api-User")
  69. if apiUserIdStr == "" {
  70. c.JSON(http.StatusUnauthorized, gin.H{
  71. "success": false,
  72. "message": "无权进行此操作,未提供 New-Api-User",
  73. })
  74. c.Abort()
  75. return
  76. }
  77. apiUserId, err := strconv.Atoi(apiUserIdStr)
  78. if err != nil {
  79. c.JSON(http.StatusUnauthorized, gin.H{
  80. "success": false,
  81. "message": "无权进行此操作,New-Api-User 格式错误",
  82. })
  83. c.Abort()
  84. return
  85. }
  86. if id != apiUserId {
  87. c.JSON(http.StatusUnauthorized, gin.H{
  88. "success": false,
  89. "message": "无权进行此操作,New-Api-User 与登录用户不匹配",
  90. })
  91. c.Abort()
  92. return
  93. }
  94. if status.(int) == common.UserStatusDisabled {
  95. c.JSON(http.StatusOK, gin.H{
  96. "success": false,
  97. "message": "用户已被封禁",
  98. })
  99. c.Abort()
  100. return
  101. }
  102. if role.(int) < minRole {
  103. c.JSON(http.StatusOK, gin.H{
  104. "success": false,
  105. "message": "无权进行此操作,权限不足",
  106. })
  107. c.Abort()
  108. return
  109. }
  110. if !validUserInfo(username.(string), role.(int)) {
  111. c.JSON(http.StatusOK, gin.H{
  112. "success": false,
  113. "message": "无权进行此操作,用户信息无效",
  114. })
  115. c.Abort()
  116. return
  117. }
  118. c.Set("username", username)
  119. c.Set("role", role)
  120. c.Set("id", id)
  121. c.Set("group", session.Get("group"))
  122. c.Set("use_access_token", useAccessToken)
  123. //userCache, err := model.GetUserCache(id.(int))
  124. //if err != nil {
  125. // c.JSON(http.StatusOK, gin.H{
  126. // "success": false,
  127. // "message": err.Error(),
  128. // })
  129. // c.Abort()
  130. // return
  131. //}
  132. //userCache.WriteContext(c)
  133. c.Next()
  134. }
  135. func TryUserAuth() func(c *gin.Context) {
  136. return func(c *gin.Context) {
  137. session := sessions.Default(c)
  138. id := session.Get("id")
  139. if id != nil {
  140. c.Set("id", id)
  141. }
  142. c.Next()
  143. }
  144. }
  145. func UserAuth() func(c *gin.Context) {
  146. return func(c *gin.Context) {
  147. authHelper(c, common.RoleCommonUser)
  148. }
  149. }
  150. func AdminAuth() func(c *gin.Context) {
  151. return func(c *gin.Context) {
  152. authHelper(c, common.RoleAdminUser)
  153. }
  154. }
  155. func RootAuth() func(c *gin.Context) {
  156. return func(c *gin.Context) {
  157. authHelper(c, common.RoleRootUser)
  158. }
  159. }
  160. func WssAuth(c *gin.Context) {
  161. }
  162. func TokenAuth() func(c *gin.Context) {
  163. return func(c *gin.Context) {
  164. log.Println("********************", c)
  165. // 先检测是否为ws
  166. if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
  167. // Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
  168. // read sk from Sec-WebSocket-Protocol
  169. key := c.Request.Header.Get("Sec-WebSocket-Protocol")
  170. parts := strings.Split(key, ",")
  171. for _, part := range parts {
  172. part = strings.TrimSpace(part)
  173. if strings.HasPrefix(part, "openai-insecure-api-key") {
  174. key = strings.TrimPrefix(part, "openai-insecure-api-key.")
  175. break
  176. }
  177. }
  178. c.Request.Header.Set("Authorization", "Bearer "+key)
  179. }
  180. // 检查path包含/v1/messages
  181. if strings.Contains(c.Request.URL.Path, "/v1/messages") {
  182. // 从x-api-key中获取key
  183. key := c.Request.Header.Get("x-api-key")
  184. if key != "" {
  185. c.Request.Header.Set("Authorization", "Bearer "+key)
  186. }
  187. }
  188. // gemini api 从query中获取key
  189. if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
  190. skKey := c.Query("key")
  191. if skKey != "" {
  192. c.Request.Header.Set("Authorization", "Bearer "+skKey)
  193. }
  194. // 从x-goog-api-key header中获取key
  195. xGoogKey := c.Request.Header.Get("x-goog-api-key")
  196. if xGoogKey != "" {
  197. c.Request.Header.Set("Authorization", "Bearer "+xGoogKey)
  198. }
  199. }
  200. key := c.Request.Header.Get("Authorization")
  201. parts := make([]string, 0)
  202. key = strings.TrimPrefix(key, "Bearer ")
  203. if key == "" || key == "midjourney-proxy" {
  204. key = c.Request.Header.Get("mj-api-secret")
  205. key = strings.TrimPrefix(key, "Bearer ")
  206. key = strings.TrimPrefix(key, "sk-")
  207. parts = strings.Split(key, "-")
  208. key = parts[0]
  209. } else {
  210. key = strings.TrimPrefix(key, "sk-")
  211. parts = strings.Split(key, "-")
  212. key = parts[0]
  213. }
  214. token, err := model.ValidateUserToken(key)
  215. if token != nil {
  216. id := c.GetInt("id")
  217. if id == 0 {
  218. c.Set("id", token.UserId)
  219. }
  220. }
  221. if err != nil {
  222. abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
  223. return
  224. }
  225. userCache, err := model.GetUserCache(token.UserId)
  226. if err != nil {
  227. abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
  228. return
  229. }
  230. userEnabled := userCache.Status == common.UserStatusEnabled
  231. if !userEnabled {
  232. abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
  233. return
  234. }
  235. userCache.WriteContext(c)
  236. err = SetupContextForToken(c, token, parts...)
  237. if err != nil {
  238. return
  239. }
  240. // 增加Token使用次数
  241. go func() {
  242. if increaseErr := model.IncreaseTokenUsageCount(token.Key); increaseErr != nil {
  243. common.SysError("failed to increase token usage count: " + increaseErr.Error())
  244. }
  245. }()
  246. // 记录Token使用日志(用于频率限制)
  247. go func() {
  248. if recordErr := model.RecordTokenUsage(token.Id); recordErr != nil {
  249. common.SysError("failed to record token usage: " + recordErr.Error())
  250. }
  251. }()
  252. c.Next()
  253. }
  254. }
  255. func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
  256. if token == nil {
  257. return fmt.Errorf("token is nil")
  258. }
  259. c.Set("id", token.UserId)
  260. c.Set("token_id", token.Id)
  261. c.Set("token_key", token.Key)
  262. c.Set("token_name", token.Name)
  263. c.Set("token_unlimited_quota", token.UnlimitedQuota)
  264. if !token.UnlimitedQuota {
  265. c.Set("token_quota", token.RemainQuota)
  266. }
  267. if token.ModelLimitsEnabled {
  268. c.Set("token_model_limit_enabled", true)
  269. c.Set("token_model_limit", token.GetModelLimitsMap())
  270. } else {
  271. c.Set("token_model_limit_enabled", false)
  272. }
  273. c.Set("allow_ips", token.GetIpLimitsMap())
  274. c.Set("token_group", token.Group)
  275. // 设置令牌渠道标签到上下文中
  276. if token.ChannelTag != nil && *token.ChannelTag != "" {
  277. c.Set(string(constant.ContextKeyTokenChannelTag), *token.ChannelTag)
  278. }
  279. if len(parts) > 1 {
  280. if model.IsAdmin(token.UserId) {
  281. c.Set("specific_channel_id", parts[1])
  282. } else {
  283. abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
  284. return fmt.Errorf("普通用户不支持指定渠道")
  285. }
  286. }
  287. return nil
  288. }