Browse Source

refactor(relay): update channel retrieval to use RelayInfo structure

CaIon 3 weeks ago
parent
commit
ce6fb95f96
4 changed files with 29 additions and 21 deletions
  1. 9 5
      controller/playground.go
  2. 13 11
      controller/relay.go
  3. 2 0
      relay/common/relay_info.go
  4. 5 5
      service/channel_select.go

+ 9 - 5
controller/playground.go

@@ -9,6 +9,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/middleware"
 	"github.com/QuantumNous/new-api/model"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/types"
 
 	"github.com/gin-gonic/gin"
@@ -31,8 +32,11 @@ func Playground(c *gin.Context) {
 		return
 	}
 
-	group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
-	modelName := c.GetString("original_model")
+	relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatOpenAI, nil, nil)
+	if err != nil {
+		newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+		return
+	}
 
 	userId := c.GetInt("id")
 
@@ -46,11 +50,11 @@ func Playground(c *gin.Context) {
 
 	tempToken := &model.Token{
 		UserId: userId,
-		Name:   fmt.Sprintf("playground-%s", group),
-		Group:  group,
+		Name:   fmt.Sprintf("playground-%s", relayInfo.UsingGroup),
+		Group:  relayInfo.UsingGroup,
 	}
 	_ = middleware.SetupContextForToken(c, tempToken)
-	_, newAPIError = getChannel(c, group, modelName, 0)
+	_, newAPIError = getChannel(c, relayInfo, 0)
 	if newAPIError != nil {
 		return
 	}

+ 13 - 11
controller/relay.go

@@ -64,8 +64,8 @@ func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewA
 func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 
 	requestId := c.GetString(common.RequestIdKey)
-	group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
-	originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
+	//group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
+	//originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
 
 	var (
 		newAPIError *types.NewAPIError
@@ -158,7 +158,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 	}()
 
 	for i := 0; i <= common.RetryTimes; i++ {
-		channel, err := getChannel(c, group, originalModel, i)
+		channel, err := getChannel(c, relayInfo, i)
 		if err != nil {
 			logger.LogError(c, err.Error())
 			newAPIError = err
@@ -211,7 +211,7 @@ func addUsedChannel(c *gin.Context, channelId int) {
 	c.Set("use_channel", useChannel)
 }
 
-func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
+func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryCount int) (*model.Channel, *types.NewAPIError) {
 	if retryCount == 0 {
 		autoBan := c.GetBool("auto_ban")
 		autoBanInt := 1
@@ -225,14 +225,18 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
 			AutoBan: &autoBanInt,
 		}, nil
 	}
-	channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
+	channel, selectGroup, err := service.CacheGetRandomSatisfiedChannel(c, info.TokenGroup, info.OriginModelName, retryCount)
+
+	info.PriceData.GroupRatioInfo = helper.HandleGroupRatio(c, info)
+
 	if err != nil {
-		return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
+		return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, info.OriginModelName, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
 	}
 	if channel == nil {
-		return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
+		return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(retry)", selectGroup, info.OriginModelName), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
 	}
-	newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+
+	newAPIError := middleware.SetupContextForSelectedChannel(c, channel, info.OriginModelName)
 	if newAPIError != nil {
 		return nil, newAPIError
 	}
@@ -392,8 +396,6 @@ func RelayNotFound(c *gin.Context) {
 func RelayTask(c *gin.Context) {
 	retryTimes := common.RetryTimes
 	channelId := c.GetInt("channel_id")
-	group := c.GetString("group")
-	originalModel := c.GetString("original_model")
 	c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
 	relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
 	if err != nil {
@@ -404,7 +406,7 @@ func RelayTask(c *gin.Context) {
 		retryTimes = 0
 	}
 	for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
-		channel, newAPIError := getChannel(c, group, originalModel, i)
+		channel, newAPIError := getChannel(c, relayInfo, i)
 		if newAPIError != nil {
 			logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
 			taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)

+ 2 - 0
relay/common/relay_info.go

@@ -81,6 +81,7 @@ type TokenCountMeta struct {
 type RelayInfo struct {
 	TokenId           int
 	TokenKey          string
+	TokenGroup        string
 	UserId            int
 	UsingGroup        string // 使用的分组
 	UserGroup         string // 用户所在分组
@@ -400,6 +401,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
 		TokenId:        common.GetContextKeyInt(c, constant.ContextKeyTokenId),
 		TokenKey:       common.GetContextKeyString(c, constant.ContextKeyTokenKey),
 		TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
+		TokenGroup:     common.GetContextKeyString(c, constant.ContextKeyTokenGroup),
 
 		isFirstResponse: true,
 		RelayMode:       relayconstant.Path2RelayMode(c.Request.URL.Path),

+ 5 - 5
service/channel_select.go

@@ -12,12 +12,12 @@ import (
 )
 
 // CacheGetRandomSatisfiedChannel tries to get a random channel that satisfies the requirements.
-func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName string, retry int) (*model.Channel, string, error) {
+func CacheGetRandomSatisfiedChannel(c *gin.Context, tokenGroup string, modelName string, retry int) (*model.Channel, string, error) {
 	var channel *model.Channel
 	var err error
-	selectGroup := group
+	selectGroup := tokenGroup
 	userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
-	if group == "auto" {
+	if tokenGroup == "auto" {
 		if len(setting.GetAutoGroups()) == 0 {
 			return nil, selectGroup, errors.New("auto groups is not enabled")
 		}
@@ -49,9 +49,9 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, modelName stri
 			}
 		}
 	} else {
-		channel, err = model.GetRandomSatisfiedChannel(group, modelName, retry)
+		channel, err = model.GetRandomSatisfiedChannel(tokenGroup, modelName, retry)
 		if err != nil {
-			return nil, group, err
+			return nil, tokenGroup, err
 		}
 	}
 	return channel, selectGroup, nil