CalciumIon 1 год назад
Родитель
Сommit
8a9ff36fbf
2 измененных файлов с 49 добавлено и 27 удалено
  1. 49 26
      controller/relay.go
  2. 0 1
      middleware/distributor.go

+ 49 - 26
controller/relay.go

@@ -39,38 +39,28 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 
 func Relay(c *gin.Context) {
 	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
-	retryTimes := common.RetryTimes
 	requestId := c.GetString(common.RequestIdKey)
-	channelId := c.GetInt("channel_id")
-	channelType := c.GetInt("channel_type")
-	channelName := c.GetString("channel_name")
 	group := c.GetString("group")
 	originalModel := c.GetString("original_model")
-	openaiErr := relayHandler(c, relayMode)
-	c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
-	if openaiErr != nil {
-		go processChannelError(c, channelId, channelType, channelName, openaiErr)
-	} else {
-		retryTimes = 0
-	}
-	for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
-		channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
+	var openaiErr *dto.OpenAIErrorWithStatusCode
+
+	for i := 0; i <= common.RetryTimes; i++ {
+		channel, err := getChannel(c, group, originalModel, i)
 		if err != nil {
-			common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
+			common.LogError(c, fmt.Sprintf("Failed to get channel: %s", err.Error()))
 			break
 		}
-		channelId = channel.Id
-		useChannel := c.GetStringSlice("use_channel")
-		useChannel = append(useChannel, fmt.Sprintf("%d", channel.Id))
-		c.Set("use_channel", useChannel)
-		common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
-		middleware.SetupContextForSelectedChannel(c, channel, originalModel)
 
-		requestBody, err := common.GetRequestBody(c)
-		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
-		openaiErr = relayHandler(c, relayMode)
-		if openaiErr != nil {
-			go processChannelError(c, channel.Id, channel.Type, channel.Name, openaiErr)
+		openaiErr = relayRequest(c, relayMode, channel)
+
+		if openaiErr == nil {
+			return // 成功处理请求,直接返回
+		}
+
+		go processChannelError(c, channel.Id, channel.Type, channel.Name, openaiErr)
+
+		if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
+			break
 		}
 	}
 	useChannel := c.GetStringSlice("use_channel")
@@ -90,7 +80,36 @@ func Relay(c *gin.Context) {
 	}
 }
 
-func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
+func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
+	addUsedChannel(c, channel.Id)
+	requestBody, _ := common.GetRequestBody(c)
+	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+	return relayHandler(c, relayMode)
+}
+
+func addUsedChannel(c *gin.Context, channelId int) {
+	useChannel := c.GetStringSlice("use_channel")
+	useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
+	c.Set("use_channel", useChannel)
+}
+
+func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
+	if retryCount == 0 {
+		return &model.Channel{
+			Id:   c.GetInt("channel_id"),
+			Type: c.GetInt("channel_type"),
+			Name: c.GetString("channel_name"),
+		}, nil
+	}
+	channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
+	if err != nil {
+		return nil, err
+	}
+	middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+	return channel, nil
+}
+
+func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
 	if openaiErr == nil {
 		return false
 	}
@@ -114,6 +133,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
 		return true
 	}
 	if openaiErr.StatusCode == http.StatusBadRequest {
+		channelType := c.GetInt("channel_type")
+		if channelType == common.ChannelTypeAnthropic {
+			return true
+		}
 		return false
 	}
 	if openaiErr.StatusCode == 408 {

+ 0 - 1
middleware/distributor.go

@@ -184,7 +184,6 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 	if channel == nil {
 		return
 	}
-	c.Set("channel", channel.Type)
 	c.Set("channel_id", channel.Id)
 	c.Set("channel_name", channel.Name)
 	c.Set("channel_type", channel.Type)