JustSong 2 лет назад
Родитель
Сommit
852af57c03
12 измененных файлов с 225 добавлено и 70 удалено
  1. 18 0
      common/constants.go
  2. 1 1
      controller/channel.go
  3. 52 0
      controller/relay.go
  4. 20 49
      middleware/auth.go
  5. 68 0
      middleware/distributor.go
  6. 20 3
      model/channel.go
  7. 1 0
      model/main.go
  8. 26 1
      model/token.go
  9. 0 13
      model/user.go
  10. 3 3
      router/api-router.go
  11. 1 0
      router/main.go
  12. 15 0
      router/relay-router.go

+ 18 - 0
common/constants.go

@@ -12,6 +12,8 @@ var SystemName = "One API"
 var ServerAddress = "http://localhost:3000"
 var Footer = ""
 
+var UsingSQLite = false
+
 // Any options with "Secret", "Token" in its key won't be return by GetOptions
 
 var SessionSecret = uuid.New().String()
@@ -84,6 +86,11 @@ const (
 	UserStatusDisabled = 2 // also don't use 0
 )
 
+const (
+	TokenStatusEnabled  = 1 // don't use 0, 0 is the default value!
+	TokenStatusDisabled = 2 // also don't use 0
+)
+
 const (
 	ChannelStatusUnknown  = 0
 	ChannelStatusEnabled  = 1 // don't use 0, 0 is the default value!
@@ -100,3 +107,14 @@ const (
 	ChannelTypeOpenAIMax = 6
 	ChannelTypeOhMyGPT   = 7
 )
+
+var ChannelHosts = []string{
+	"",                            // 0
+	"https://api.openai.com",      // 1
+	"https://openai.api2d.net",    // 2
+	"",                            // 3
+	"https://api.openai-asia.com", // 4
+	"https://api.openai-sb.com",   // 5
+	"https://api.openaimax.com",   // 6
+	"https://api.ohmygpt.com",     // 7
+}

+ 1 - 1
controller/channel.go

@@ -56,7 +56,7 @@ func GetChannel(c *gin.Context) {
 		})
 		return
 	}
-	channel, err := model.GetChannelById(id)
+	channel, err := model.GetChannelById(id, false)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,

+ 52 - 0
controller/relay.go

@@ -0,0 +1,52 @@
+package controller
+
+import (
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+)
+
+func Relay(c *gin.Context) {
+	channelType := c.GetInt("channel")
+	host := common.ChannelHosts[channelType]
+	req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s/%s", host, c.Request.URL.String()), c.Request.Body)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"error": gin.H{
+				"message": err.Error(),
+				"type":    "one_api_error",
+			},
+		})
+		return
+	}
+	req.Header = c.Request.Header.Clone()
+	client := &http.Client{}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"error": gin.H{
+				"message": err.Error(),
+				"type":    "one_api_error",
+			},
+		})
+		return
+	}
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	_, err = io.Copy(c.Writer, resp.Body)
+	//body, err := io.ReadAll(resp.Body)
+	//_, err = c.Writer.Write(body)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"error": gin.H{
+				"message": err.Error(),
+				"type":    "one_api_error",
+			},
+		})
+		return
+	}
+}

+ 20 - 49
middleware/auth.go

@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/model"
+	"strings"
 )
 
 func authHelper(c *gin.Context, minRole int) {
@@ -14,34 +15,13 @@ func authHelper(c *gin.Context, minRole int) {
 	role := session.Get("role")
 	id := session.Get("id")
 	status := session.Get("status")
-	authByToken := false
 	if username == nil {
-		// Check token
-		token := c.Request.Header.Get("Authorization")
-		if token == "" {
-			c.JSON(http.StatusOK, gin.H{
-				"success": false,
-				"message": "无权进行此操作,未登录或 token 无效",
-			})
-			c.Abort()
-			return
-		}
-		user := model.ValidateUserToken(token)
-		if user != nil && user.Username != "" {
-			// Token is valid
-			username = user.Username
-			role = user.Role
-			id = user.Id
-			status = user.Status
-		} else {
-			c.JSON(http.StatusOK, gin.H{
-				"success": false,
-				"message": "无权进行此操作,token 无效",
-			})
-			c.Abort()
-			return
-		}
-		authByToken = true
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "无权进行此操作,未登录",
+		})
+		c.Abort()
+		return
 	}
 	if status.(int) == common.UserStatusDisabled {
 		c.JSON(http.StatusOK, gin.H{
@@ -62,7 +42,6 @@ func authHelper(c *gin.Context, minRole int) {
 	c.Set("username", username)
 	c.Set("role", role)
 	c.Set("id", id)
-	c.Set("authByToken", authByToken)
 	c.Next()
 }
 
@@ -84,33 +63,25 @@ func RootAuth() func(c *gin.Context) {
 	}
 }
 
-// NoTokenAuth You should always use this after normal auth middlewares.
-func NoTokenAuth() func(c *gin.Context) {
+func TokenAuth() func(c *gin.Context) {
 	return func(c *gin.Context) {
-		authByToken := c.GetBool("authByToken")
-		if authByToken {
+		key := c.Request.Header.Get("Authorization")
+		parts := strings.Split(key, "-")
+		key = parts[0]
+		token, err := model.ValidateUserToken(key)
+		if err != nil {
 			c.JSON(http.StatusOK, gin.H{
-				"success": false,
-				"message": "本接口不支持使用 token 进行验证",
+				"error": gin.H{
+					"message": err.Error(),
+					"type":    "one_api_error",
+				},
 			})
 			c.Abort()
 			return
 		}
-		c.Next()
-	}
-}
-
-// TokenOnlyAuth You should always use this after normal auth middlewares.
-func TokenOnlyAuth() func(c *gin.Context) {
-	return func(c *gin.Context) {
-		authByToken := c.GetBool("authByToken")
-		if !authByToken {
-			c.JSON(http.StatusOK, gin.H{
-				"success": false,
-				"message": "本接口仅支持使用 token 进行验证",
-			})
-			c.Abort()
-			return
+		c.Set("id", token.UserId)
+		if len(parts) > 1 {
+			c.Set("channelId", parts[1])
 		}
 		c.Next()
 	}

+ 68 - 0
middleware/distributor.go

@@ -0,0 +1,68 @@
+package middleware
+
+import (
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"net/http"
+	"one-api/common"
+	"one-api/model"
+	"strconv"
+)
+
+func Distribute() func(c *gin.Context) {
+	return func(c *gin.Context) {
+		var channel *model.Channel
+		channelId, ok := c.Get("channelId")
+		if ok {
+			id, err := strconv.Atoi(channelId.(string))
+			if err != nil {
+				c.JSON(http.StatusOK, gin.H{
+					"error": gin.H{
+						"message": "无效的渠道 ID",
+						"type":    "one_api_error",
+					},
+				})
+				c.Abort()
+				return
+			}
+			channel, err = model.GetChannelById(id, true)
+			if err != nil {
+				c.JSON(200, gin.H{
+					"error": gin.H{
+						"message": "无效的渠道 ID",
+						"type":    "one_api_error",
+					},
+				})
+				c.Abort()
+				return
+			}
+			if channel.Status != common.ChannelStatusEnabled {
+				c.JSON(200, gin.H{
+					"error": gin.H{
+						"message": "该渠道已被禁用",
+						"type":    "one_api_error",
+					},
+				})
+				c.Abort()
+				return
+			}
+		} else {
+			// Select a channel for the user
+			var err error
+			channel, err = model.GetRandomChannel()
+			if err != nil {
+				c.JSON(200, gin.H{
+					"error": gin.H{
+						"message": "无可用渠道",
+						"type":    "one_api_error",
+					},
+				})
+				c.Abort()
+				return
+			}
+		}
+		c.Set("channel", channel.Type)
+		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+		c.Next()
+	}
+}

+ 20 - 3
model/channel.go

@@ -2,12 +2,13 @@ package model
 
 import (
 	_ "gorm.io/driver/sqlite"
+	"one-api/common"
 )
 
 type Channel struct {
 	Id           int    `json:"id"`
 	Type         int    `json:"type" gorm:"default:0"`
-	Key          string `json:"key"`
+	Key          string `json:"key" gorm:"not null"`
 	Status       int    `json:"status" gorm:"default:1"`
 	Name         string `json:"name" gorm:"index"`
 	Weight       int    `json:"weight"`
@@ -27,10 +28,26 @@ func SearchChannels(keyword string) (channels []*Channel, err error) {
 	return channels, err
 }
 
-func GetChannelById(id int) (*Channel, error) {
+func GetChannelById(id int, selectAll bool) (*Channel, error) {
 	channel := Channel{Id: id}
 	var err error = nil
-	err = DB.Omit("key").First(&channel, "id = ?", id).Error
+	if selectAll {
+		err = DB.First(&channel, "id = ?", id).Error
+	} else {
+		err = DB.Omit("key").First(&channel, "id = ?", id).Error
+	}
+	return &channel, err
+}
+
+func GetRandomChannel() (*Channel, error) {
+	// TODO: consider weight
+	channel := Channel{}
+	var err error = nil
+	if common.UsingSQLite {
+		err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RANDOM()").Limit(1).First(&channel).Error
+	} else {
+		err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RAND()").Limit(1).First(&channel).Error
+	}
 	return &channel, err
 }
 

+ 1 - 0
model/main.go

@@ -45,6 +45,7 @@ func InitDB() (err error) {
 		})
 	} else {
 		// Use SQLite
+		common.UsingSQLite = true
 		db, err = gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
 			PrepareStmt: true, // precompile SQL
 		})

+ 26 - 1
model/token.go

@@ -3,12 +3,14 @@ package model
 import (
 	"errors"
 	_ "gorm.io/driver/sqlite"
+	"one-api/common"
+	"strings"
 )
 
 type Token struct {
 	Id           int    `json:"id"`
 	UserId       int    `json:"user_id"`
-	Key          string `json:"key"`
+	Key          string `json:"key" gorm:"uniqueIndex"`
 	Status       int    `json:"status" gorm:"default:1"`
 	Name         string `json:"name" gorm:"index" `
 	CreatedTime  int64  `json:"created_time" gorm:"bigint"`
@@ -27,6 +29,29 @@ func SearchUserTokens(userId int, keyword string) (tokens []*Token, err error) {
 	return tokens, err
 }
 
+func ValidateUserToken(key string) (token *Token, err error) {
+	if key == "" {
+		return nil, errors.New("未提供 token")
+	}
+	key = strings.Replace(key, "Bearer ", "", 1)
+	token = &Token{}
+	err = DB.Where("key = ?", key).First(token).Error
+	if err == nil {
+		if token.Status != common.TokenStatusEnabled {
+			return nil, errors.New("该 token 已被禁用")
+		}
+		go func() {
+			token.AccessedTime = common.GetTimestamp()
+			err := token.Update()
+			if err != nil {
+				common.SysError("更新 token 访问时间失败:" + err.Error())
+			}
+		}()
+		return token, nil
+	}
+	return nil, err
+}
+
 func GetTokenByIds(id int, userId int) (*Token, error) {
 	if id == 0 || userId == 0 {
 		return nil, errors.New("id 或 userId 为空!")

+ 0 - 13
model/user.go

@@ -3,7 +3,6 @@ package model
 import (
 	"errors"
 	"one-api/common"
-	"strings"
 )
 
 // User if you add sensitive fields, don't forget to clean them in setupLogin function.
@@ -149,18 +148,6 @@ func (user *User) FillUserByUsername() error {
 	return nil
 }
 
-func ValidateUserToken(token string) (user *User) {
-	if token == "" {
-		return nil
-	}
-	token = strings.Replace(token, "Bearer ", "", 1)
-	user = &User{}
-	if DB.Where("token = ?", token).First(user).RowsAffected == 1 {
-		return user
-	}
-	return nil
-}
-
 func IsEmailAlreadyTaken(email string) bool {
 	return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1
 }

+ 3 - 3
router/api-router.go

@@ -28,7 +28,7 @@ func SetApiRouter(router *gin.Engine) {
 			userRoute.GET("/logout", controller.Logout)
 
 			selfRoute := userRoute.Group("/")
-			selfRoute.Use(middleware.UserAuth(), middleware.NoTokenAuth())
+			selfRoute.Use(middleware.UserAuth())
 			{
 				selfRoute.GET("/self", controller.GetSelf)
 				selfRoute.PUT("/self", controller.UpdateSelf)
@@ -36,7 +36,7 @@ func SetApiRouter(router *gin.Engine) {
 			}
 
 			adminRoute := userRoute.Group("/")
-			adminRoute.Use(middleware.AdminAuth(), middleware.NoTokenAuth())
+			adminRoute.Use(middleware.AdminAuth())
 			{
 				adminRoute.GET("/", controller.GetAllUsers)
 				adminRoute.GET("/search", controller.SearchUsers)
@@ -48,7 +48,7 @@ func SetApiRouter(router *gin.Engine) {
 			}
 		}
 		optionRoute := apiRouter.Group("/option")
-		optionRoute.Use(middleware.RootAuth(), middleware.NoTokenAuth())
+		optionRoute.Use(middleware.RootAuth())
 		{
 			optionRoute.GET("/", controller.GetOptions)
 			optionRoute.PUT("/", controller.UpdateOption)

+ 1 - 0
router/main.go

@@ -7,5 +7,6 @@ import (
 
 func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
 	SetApiRouter(router)
+	SetRelayRouter(router)
 	setWebRouter(router, buildFS, indexPage)
 }

+ 15 - 0
router/relay-router.go

@@ -0,0 +1,15 @@
+package router
+
+import (
+	"github.com/gin-gonic/gin"
+	"one-api/controller"
+	"one-api/middleware"
+)
+
+func SetRelayRouter(router *gin.Engine) {
+	relayRouter := router.Group("/v1")
+	relayRouter.Use(middleware.GlobalAPIRateLimit(), middleware.TokenAuth(), middleware.Distribute())
+	{
+		relayRouter.POST("/chat/completions", controller.Relay)
+	}
+}