CaIon 4 місяців тому
батько
коміт
29ec328f46
2 змінених файлів з 20 додано та 27 видалено
  1. 3 27
      controller/playground.go
  2. 17 0
      middleware/distributor.go

+ 3 - 27
controller/playground.go

@@ -5,10 +5,8 @@ import (
 	"fmt"
 	"one-api/common"
 	"one-api/constant"
-	"one-api/dto"
 	"one-api/middleware"
 	"one-api/model"
-	"one-api/setting"
 	"one-api/types"
 	"time"
 
@@ -32,30 +30,8 @@ func Playground(c *gin.Context) {
 		return
 	}
 
-	playgroundRequest := &dto.PlayGroundRequest{}
-	err := common.UnmarshalBodyReusable(c, playgroundRequest)
-	if err != nil {
-		newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
-		return
-	}
-
-	if playgroundRequest.Model == "" {
-		newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
-		return
-	}
-	c.Set("original_model", playgroundRequest.Model)
-	group := playgroundRequest.Group
-	userGroup := c.GetString("group")
-
-	if group == "" {
-		group = userGroup
-	} else {
-		if !setting.GroupInUserUsableGroups(group) && group != userGroup {
-			newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
-			return
-		}
-		c.Set("group", group)
-	}
+	group := c.GetString("group")
+	modelName := c.GetString("original_model")
 
 	userId := c.GetInt("id")
 
@@ -73,7 +49,7 @@ func Playground(c *gin.Context) {
 		Group:  group,
 	}
 	_ = middleware.SetupContextForToken(c, tempToken)
-	_, newAPIError = getChannel(c, group, playgroundRequest.Model, 0)
+	_, newAPIError = getChannel(c, group, modelName, 0)
 	if newAPIError != nil {
 		return
 	}

+ 17 - 0
middleware/distributor.go

@@ -10,6 +10,7 @@ import (
 	"one-api/model"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
+	"one-api/setting"
 	"one-api/setting/ratio_setting"
 	"one-api/types"
 	"strconv"
@@ -78,6 +79,22 @@ func Distribute() func(c *gin.Context) {
 				}
 				var selectGroup string
 				userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
+				// check path is /pg/chat/completions
+				if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
+					playgroundRequest := &dto.PlayGroundRequest{}
+					err = common.UnmarshalBodyReusable(c, playgroundRequest)
+					if err != nil {
+						abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
+						return
+					}
+					if playgroundRequest.Group != "" {
+						if !setting.GroupInUserUsableGroups(playgroundRequest.Group) && playgroundRequest.Group != userGroup {
+							abortWithOpenAiMessage(c, http.StatusForbidden, "无权访问该分组")
+							return
+						}
+						userGroup = playgroundRequest.Group
+					}
+				}
 				channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
 				if err != nil {
 					showGroup := userGroup