| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- package middleware
- import (
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
- "net/http"
- "one-api/common"
- "one-api/model"
- "strings"
- )
- func authHelper(c *gin.Context, minRole int) {
- session := sessions.Default(c)
- username := session.Get("username")
- role := session.Get("role")
- id := session.Get("id")
- status := session.Get("status")
- if username == nil {
- // Check access token
- accessToken := c.Request.Header.Get("Authorization")
- if accessToken == "" {
- c.JSON(http.StatusUnauthorized, gin.H{
- "success": false,
- "message": "无权进行此操作,未登录且未提供 access token",
- })
- c.Abort()
- return
- }
- user := model.ValidateAccessToken(accessToken)
- if user != nil && user.Username != "" {
- // Token is valid
- username = user.Username
- role = user.Role
- id = user.Id
- status = user.Status
- } else {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无权进行此操作,access token 无效",
- })
- c.Abort()
- return
- }
- }
- if status.(int) == common.UserStatusDisabled {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "用户已被封禁",
- })
- c.Abort()
- return
- }
- if role.(int) < minRole {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "无权进行此操作,权限不足",
- })
- c.Abort()
- return
- }
- c.Set("username", username)
- c.Set("role", role)
- c.Set("id", id)
- c.Next()
- }
- func UserAuth() func(c *gin.Context) {
- return func(c *gin.Context) {
- authHelper(c, common.RoleCommonUser)
- }
- }
- func AdminAuth() func(c *gin.Context) {
- return func(c *gin.Context) {
- authHelper(c, common.RoleAdminUser)
- }
- }
- func RootAuth() func(c *gin.Context) {
- return func(c *gin.Context) {
- authHelper(c, common.RoleRootUser)
- }
- }
- func TokenAuth() func(c *gin.Context) {
- return func(c *gin.Context) {
- key := c.Request.Header.Get("Authorization")
- key = strings.TrimPrefix(key, "Bearer ")
- key = strings.TrimPrefix(key, "sk-")
- parts := strings.Split(key, "-")
- key = parts[0]
- token, err := model.ValidateUserToken(key)
- if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{
- "error": gin.H{
- "message": err.Error(),
- "type": "one_api_error",
- },
- })
- c.Abort()
- return
- }
- if !model.CacheIsUserEnabled(token.UserId) {
- c.JSON(http.StatusForbidden, gin.H{
- "error": gin.H{
- "message": "用户已被封禁",
- "type": "one_api_error",
- },
- })
- c.Abort()
- return
- }
- c.Set("id", token.UserId)
- c.Set("token_id", token.Id)
- c.Set("token_name", token.Name)
- requestURL := c.Request.URL.String()
- consumeQuota := true
- if strings.HasPrefix(requestURL, "/v1/models") {
- consumeQuota = false
- }
- c.Set("consume_quota", consumeQuota)
- if len(parts) > 1 {
- if model.IsAdmin(token.UserId) {
- c.Set("channelId", parts[1])
- } else {
- c.JSON(http.StatusForbidden, gin.H{
- "error": gin.H{
- "message": "普通用户不支持指定渠道",
- "type": "one_api_error",
- },
- })
- c.Abort()
- return
- }
- }
- c.Next()
- }
- }
|