Browse Source

feat: retry delay when using same channel (#155)

zijiren 8 months ago
parent
commit
64002f5841

+ 20 - 9
core/controller/relay-controller.go

@@ -522,6 +522,8 @@ func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayControlle
 	i := 0
 
 	for {
+		lastStatusCode := state.result.Error.StatusCode
+		lastChannelID := state.meta.Channel.ID
 		newChannel, err := getRetryChannel(state)
 		if err == nil {
 			err = prepareRetry(c)
@@ -552,6 +554,11 @@ func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayControlle
 			state.retryTimes-i,
 		)
 
+		// Check if we should delay (using the same channel)
+		if shouldDelay(lastStatusCode, lastChannelID, newChannel.ID) {
+			relayDelay()
+		}
+
 		state.meta = NewMetaByContext(
 			c,
 			newChannel,
@@ -582,10 +589,6 @@ func getRetryChannel(state *retryState) (*model.Channel, error) {
 		if state.lastHasPermissionChannel == nil {
 			return nil, ErrChannelsExhausted
 		}
-		if shouldDelay(state.result.Error.StatusCode) {
-			//nolint:gosec
-			time.Sleep(time.Duration(rand.Float64()*float64(time.Second)) + time.Second)
-		}
 		return state.lastHasPermissionChannel, nil
 	}
 
@@ -595,10 +598,6 @@ func getRetryChannel(state *retryState) (*model.Channel, error) {
 			return nil, err
 		}
 		state.exhausted = true
-		if shouldDelay(state.result.Error.StatusCode) {
-			//nolint:gosec
-			time.Sleep(time.Duration(rand.Float64()*float64(time.Second)) + time.Second)
-		}
 		return state.lastHasPermissionChannel, nil
 	}
 
@@ -671,11 +670,23 @@ func channelHasPermission(relayErr relaymodel.ErrorWithStatusCode) bool {
 	return !ok
 }
 
-func shouldDelay(statusCode int) bool {
+// shouldDelay checks if we need to add a delay before retrying
+// Only adds delay when retrying with the same channel for rate limiting issues
+func shouldDelay(statusCode int, lastChannelID, newChannelID int) bool {
+	if lastChannelID != newChannelID {
+		return false
+	}
+
+	// Only delay for rate limiting or service unavailable errors
 	return statusCode == http.StatusTooManyRequests ||
 		statusCode == http.StatusServiceUnavailable
 }
 
+func relayDelay() {
+	//nolint:gosec
+	time.Sleep(time.Duration(rand.Float64()*float64(time.Second)) + time.Second)
+}
+
 func RelayNotImplemented(c *gin.Context) {
 	c.JSON(http.StatusNotImplemented, gin.H{
 		"error": &relaymodel.Error{

+ 0 - 4
core/relay/adaptor/gemini/main.go

@@ -614,10 +614,6 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model
 		}
 		data = data[6:]
 
-		if conv.BytesToString(data) == "[DONE]" {
-			break
-		}
-
 		var geminiResponse ChatResponse
 		err := sonic.Unmarshal(data, &geminiResponse)
 		if err != nil {

+ 9 - 1
core/relay/adaptor/xai/error.go

@@ -3,6 +3,7 @@ package xai
 import (
 	"io"
 	"net/http"
+	"strings"
 
 	"github.com/bytedance/sonic"
 	"github.com/labring/aiproxy/core/common/conv"
@@ -28,5 +29,12 @@ func ErrorHandler(resp *http.Response) *model.ErrorWithStatusCode {
 	if err != nil {
 		return openai.ErrorWrapperWithMessage(conv.BytesToString(data), nil, http.StatusInternalServerError)
 	}
-	return openai.ErrorWrapperWithMessage(er.Error, er.Code, resp.StatusCode)
+
+	statusCode := resp.StatusCode
+
+	if strings.Contains(er.Error, "Incorrect API key provided") {
+		statusCode = http.StatusUnauthorized
+	}
+
+	return openai.ErrorWrapperWithMessage(er.Error, er.Code, statusCode)
 }

+ 1 - 1
core/relay/controller/image.go

@@ -51,7 +51,7 @@ func GetImageRequestPrice(c *gin.Context, mc *model.ModelConfig) (model.Price, e
 
 	imageCostPrice, ok := GetImageOutputPrice(mc, imageRequest.Size, imageRequest.Quality)
 	if !ok {
-		return model.Price{}, fmt.Errorf("invalid image size: %s", imageRequest.Size)
+		return model.Price{}, fmt.Errorf("invalid image size `%s` or quality `%s`", imageRequest.Size, imageRequest.Quality)
 	}
 
 	return model.Price{