فهرست منبع

feat: 显式指定 error 跳过重试

Xyfacai 5 ماه پیش
والد
کامیت
1f5ef24ecd

+ 5 - 5
controller/playground.go

@@ -28,19 +28,19 @@ func Playground(c *gin.Context) {
 
 	useAccessToken := c.GetBool("use_access_token")
 	if useAccessToken {
-		newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
+		newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
 		return
 	}
 
 	playgroundRequest := &dto.PlayGroundRequest{}
 	err := common.UnmarshalBodyReusable(c, playgroundRequest)
 	if err != nil {
-		newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
+		newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 		return
 	}
 
 	if playgroundRequest.Model == "" {
-		newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
+		newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 		return
 	}
 	c.Set("original_model", playgroundRequest.Model)
@@ -51,7 +51,7 @@ func Playground(c *gin.Context) {
 		group = userGroup
 	} else {
 		if !setting.GroupInUserUsableGroups(group) && group != userGroup {
-			newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
+			newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
 			return
 		}
 		c.Set("group", group)
@@ -62,7 +62,7 @@ func Playground(c *gin.Context) {
 	// Write user context to ensure acceptUnsetRatio is available
 	userCache, err := model.GetUserCache(userId)
 	if err != nil {
-		newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
+		newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
 		return
 	}
 	userCache.WriteContext(c)

+ 4 - 4
controller/relay.go

@@ -127,7 +127,7 @@ func WssRelay(c *gin.Context) {
 	defer ws.Close()
 
 	if err != nil {
-		helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError())
+		helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
 		return
 	}
 
@@ -258,10 +258,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
 	}
 	channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
 	if err != nil {
-		return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
+		return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
 	}
 	if channel == nil {
-		return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed)
+		return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
 	}
 	newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
 	if newAPIError != nil {
@@ -277,7 +277,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
 	if types.IsChannelError(openaiErr) {
 		return true
 	}
-	if types.IsLocalError(openaiErr) {
+	if types.IsSkipRetryError(openaiErr) {
 		return false
 	}
 	if retryTimes <= 0 {

+ 1 - 1
middleware/distributor.go

@@ -247,7 +247,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
 	c.Set("original_model", modelName) // for retry
 	if channel == nil {
-		return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed)
+		return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
 	}
 	common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
 	common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)

+ 1 - 1
model/channel.go

@@ -138,7 +138,7 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
 
 		channelInfo, err := CacheGetChannelInfo(channel.Id)
 		if err != nil {
-			return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed)
+			return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
 		}
 		//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
 		defer func() {

+ 5 - 5
relay/audio_handler.go

@@ -62,7 +62,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 	if err != nil {
 		common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
-		return types.NewError(err, types.ErrorCodeInvalidRequest)
+		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}
 
 	promptTokens := 0
@@ -75,7 +75,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 	priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError)
+		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
 	}
 
 	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
@@ -90,18 +90,18 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 	err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	adaptor.Init(relayInfo)
 
 	ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 	}
 
 	resp, err := adaptor.DoRequest(c, relayInfo, ioReader)

+ 9 - 9
relay/claude_handler.go

@@ -40,7 +40,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	// get & validate textRequest 获取并验证文本请求
 	textRequest, err := getAndValidateClaudeRequest(c)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeInvalidRequest)
+		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}
 
 	if textRequest.Stream {
@@ -49,18 +49,18 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 	err = helper.ModelMappedHelper(c, relayInfo, textRequest)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 
 	promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
 	// count messages token error 计算promptTokens错误
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeCountTokenFailed)
+		return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
 	}
 
 	priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError)
+		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
 	}
 
 	// pre-consume quota 预消耗配额
@@ -77,7 +77,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	adaptor.Init(relayInfo)
 
@@ -111,17 +111,17 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
 		body, err := common.GetRequestBody(c)
 		if err != nil {
-			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
+			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 		}
 		requestBody = bytes.NewBuffer(body)
 	} else {
 		convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 		jsonData, err := common.Marshal(convertedRequest)
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 
 		// apply param override
@@ -133,7 +133,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 			}
 			jsonData, err = common.Marshal(reqMap)
 			if err != nil {
-				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
+				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}
 		}
 

+ 7 - 8
relay/embedding_handler.go

@@ -41,17 +41,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	err := common.UnmarshalBodyReusable(c, &embeddingRequest)
 	if err != nil {
 		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
-		return types.NewError(err, types.ErrorCodeInvalidRequest)
+		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}
 
 	err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeInvalidRequest)
+		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}
 
 	err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 
 	promptToken := getEmbeddingPromptToken(*embeddingRequest)
@@ -59,7 +59,7 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 	priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError)
+		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
 	}
 	// pre-consume quota 预消耗配额
 	preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
@@ -74,18 +74,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	adaptor.Init(relayInfo)
 
 	convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
-
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 	}
 	jsonData, err := json.Marshal(convertedRequest)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 	}
 	requestBody := bytes.NewBuffer(jsonData)
 	statusCodeMappingStr := c.GetString("status_code_mapping")

+ 8 - 8
relay/gemini_handler.go

@@ -109,7 +109,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	req, err := getAndValidateGeminiRequest(c)
 	if err != nil {
 		common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
-		return types.NewError(err, types.ErrorCodeInvalidRequest)
+		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}
 
 	relayInfo := relaycommon.GenRelayInfoGemini(c)
@@ -121,14 +121,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		sensitiveWords, err := checkGeminiInputSensitive(req)
 		if err != nil {
 			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
-			return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
+			return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
 		}
 	}
 
 	// model mapped 模型映射
 	err = helper.ModelMappedHelper(c, relayInfo, req)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 
 	if value, exists := c.Get("prompt_tokens"); exists {
@@ -159,7 +159,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError)
+		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
 	}
 
 	// pre consume quota
@@ -175,7 +175,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 
 	adaptor.Init(relayInfo)
@@ -198,13 +198,13 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
 		body, err := common.GetRequestBody(c)
 		if err != nil {
-			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
+			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 		}
 		requestBody = bytes.NewReader(body)
 	} else {
 		jsonData, err := common.Marshal(req)
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 
 		// apply param override
@@ -216,7 +216,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 			}
 			jsonData, err = common.Marshal(reqMap)
 			if err != nil {
-				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
+				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}
 		}
 

+ 10 - 10
relay/image_handler.go

@@ -115,17 +115,17 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	imageRequest, err := getAndValidImageRequest(c, relayInfo)
 	if err != nil {
 		common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
-		return types.NewError(err, types.ErrorCodeInvalidRequest)
+		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}
 
 	err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 
 	priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError)
+		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
 	}
 	var preConsumedQuota int
 	var quota int
@@ -173,16 +173,16 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
 		userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeQueryDataError)
+			return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
 		}
 		if userQuota-quota < 0 {
-			return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota)
+			return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota, types.ErrOptionWithSkipRetry())
 		}
 	}
 
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	adaptor.Init(relayInfo)
 
@@ -191,20 +191,20 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
 		body, err := common.GetRequestBody(c)
 		if err != nil {
-			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
+			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 		}
 		requestBody = bytes.NewBuffer(body)
 	} else {
 		convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 		if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
 			requestBody = convertedRequest.(io.Reader)
 		} else {
 			jsonData, err := json.Marshal(convertedRequest)
 			if err != nil {
-				return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+				return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 			}
 
 			// apply param override
@@ -216,7 +216,7 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 				}
 				jsonData, err = common.Marshal(reqMap)
 				if err != nil {
-					return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
+					return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 				}
 			}
 

+ 15 - 17
relay/relay-text.go

@@ -90,9 +90,8 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 	// get & validate textRequest 获取并验证文本请求
 	textRequest, err := getAndValidateTextRequest(c, relayInfo)
-
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeInvalidRequest)
+		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}
 
 	if textRequest.WebSearchOptions != nil {
@@ -103,13 +102,13 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		words, err := checkRequestSensitive(textRequest, relayInfo)
 		if err != nil {
 			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
-			return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
+			return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
 		}
 	}
 
 	err = helper.ModelMappedHelper(c, relayInfo, textRequest)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 
 	// 获取 promptTokens,如果上下文中已经存在,则直接使用
@@ -121,14 +120,14 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		promptTokens, err = getPromptTokens(textRequest, relayInfo)
 		// count messages token error 计算promptTokens错误
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeCountTokenFailed)
+			return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
 		}
 		c.Set("prompt_tokens", promptTokens)
 	}
 
 	priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError)
+		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
 	}
 
 	// pre-consume quota 预消耗配额
@@ -165,7 +164,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	adaptor.Init(relayInfo)
 	var requestBody io.Reader
@@ -173,7 +172,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
 		body, err := common.GetRequestBody(c)
 		if err != nil {
-			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
+			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 		}
 		if common.DebugEnabled {
 			println("requestBody: ", string(body))
@@ -182,7 +181,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	} else {
 		convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 
 		if relayInfo.ChannelSetting.SystemPrompt != "" {
@@ -207,7 +206,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 		jsonData, err := common.Marshal(convertedRequest)
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 
 		// apply param override
@@ -219,7 +218,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 			}
 			jsonData, err = common.Marshal(reqMap)
 			if err != nil {
-				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
+				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}
 		}
 
@@ -231,7 +230,6 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 	var httpResp *http.Response
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
-
 	if err != nil {
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 	}
@@ -304,13 +302,13 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom
 func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) {
 	userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
 	if err != nil {
-		return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError)
+		return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
 	}
 	if userQuota <= 0 {
-		return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
+		return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry())
 	}
 	if userQuota-preConsumedQuota < 0 {
-		return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
+		return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry())
 	}
 	relayInfo.UserQuota = userQuota
 	if userQuota > 100*preConsumedQuota {
@@ -334,11 +332,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 	if preConsumedQuota > 0 {
 		err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
 		if err != nil {
-			return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden)
+			return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry())
 		}
 		err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
 		if err != nil {
-			return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError)
+			return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
 		}
 	}
 	return preConsumedQuota, userQuota, nil

+ 10 - 10
relay/rerank_handler.go

@@ -31,21 +31,21 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
 	err := common.UnmarshalBodyReusable(c, &rerankRequest)
 	if err != nil {
 		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
-		return types.NewError(err, types.ErrorCodeInvalidRequest)
+		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}
 
 	relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
 
 	if rerankRequest.Query == "" {
-		return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest)
+		return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}
 	if len(rerankRequest.Documents) == 0 {
-		return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest)
+		return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}
 
 	err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 
 	promptToken := getRerankPromptToken(*rerankRequest)
@@ -53,7 +53,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
 
 	priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError)
+		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
 	}
 	// pre-consume quota 预消耗配额
 	preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
@@ -68,7 +68,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
 
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	adaptor.Init(relayInfo)
 
@@ -76,17 +76,17 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
 	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
 		body, err := common.GetRequestBody(c)
 		if err != nil {
-			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
+			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 		}
 		requestBody = bytes.NewBuffer(body)
 	} else {
 		convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 		jsonData, err := common.Marshal(convertedRequest)
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 
 		// apply param override
@@ -98,7 +98,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
 			}
 			jsonData, err = common.Marshal(reqMap)
 			if err != nil {
-				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
+				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}
 		}
 

+ 10 - 10
relay/responses_handler.go

@@ -51,7 +51,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	req, err := getAndValidateResponsesRequest(c)
 	if err != nil {
 		common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
-		return types.NewError(err, types.ErrorCodeInvalidRequest)
+		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}
 
 	relayInfo := relaycommon.GenRelayInfoResponses(c, req)
@@ -60,13 +60,13 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		sensitiveWords, err := checkInputSensitive(req, relayInfo)
 		if err != nil {
 			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
-			return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
+			return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
 		}
 	}
 
 	err = helper.ModelMappedHelper(c, relayInfo, req)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 
 	if value, exists := c.Get("prompt_tokens"); exists {
@@ -79,7 +79,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError)
+		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
 	}
 	// pre consume quota
 	preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
@@ -93,38 +93,38 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	}()
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	adaptor.Init(relayInfo)
 	var requestBody io.Reader
 	if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
 		body, err := common.GetRequestBody(c)
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeReadRequestBodyFailed)
+			return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
 		}
 		requestBody = bytes.NewBuffer(body)
 	} else {
 		convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 		jsonData, err := json.Marshal(convertedRequest)
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 		// apply param override
 		if len(relayInfo.ParamOverride) > 0 {
 			reqMap := make(map[string]interface{})
 			err = json.Unmarshal(jsonData, &reqMap)
 			if err != nil {
-				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
+				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}
 			for key, value := range relayInfo.ParamOverride {
 				reqMap[key] = value
 			}
 			jsonData, err = json.Marshal(reqMap)
 			if err != nil {
-				return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+				return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 			}
 		}
 

+ 3 - 3
relay/websocket.go

@@ -24,12 +24,12 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr
 
 	err := helper.ModelMappedHelper(c, relayInfo, nil)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 
 	priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError)
+		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
 	}
 
 	// pre-consume quota 预消耗配额
@@ -46,7 +46,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr
 
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	adaptor.Init(relayInfo)
 	//var requestBody io.Reader

+ 1 - 1
service/channel.go

@@ -45,7 +45,7 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool {
 	if types.IsChannelError(err) {
 		return true
 	}
-	if types.IsLocalError(err) {
+	if types.IsSkipRetryError(err) {
 		return false
 	}
 	if err.StatusCode == http.StatusUnauthorized {

+ 40 - 14
types/error.go

@@ -78,6 +78,7 @@ const (
 type NewAPIError struct {
 	Err        error
 	RelayError any
+	skipRetry  bool
 	errorType  ErrorType
 	errorCode  ErrorCode
 	StatusCode int
@@ -170,33 +171,39 @@ func (e *NewAPIError) ToClaudeError() ClaudeError {
 	return result
 }
 
-func NewError(err error, errorCode ErrorCode) *NewAPIError {
-	return &NewAPIError{
+type NewAPIErrorOptions func(*NewAPIError)
+
+func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError {
+	e := &NewAPIError{
 		Err:        err,
 		RelayError: nil,
 		errorType:  ErrorTypeNewAPIError,
 		StatusCode: http.StatusInternalServerError,
 		errorCode:  errorCode,
 	}
+	for _, op := range ops {
+		op(e)
+	}
+	return e
 }
 
-func NewOpenAIError(err error, errorCode ErrorCode, statusCode int) *NewAPIError {
+func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
 	openaiError := OpenAIError{
 		Message: err.Error(),
 		Type:    string(errorCode),
 	}
-	return WithOpenAIError(openaiError, statusCode)
+	return WithOpenAIError(openaiError, statusCode, ops...)
 }
 
-func InitOpenAIError(errorCode ErrorCode, statusCode int) *NewAPIError {
+func InitOpenAIError(errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
 	openaiError := OpenAIError{
 		Type: string(errorCode),
 	}
-	return WithOpenAIError(openaiError, statusCode)
+	return WithOpenAIError(openaiError, statusCode, ops...)
 }
 
-func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *NewAPIError {
-	return &NewAPIError{
+func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
+	e := &NewAPIError{
 		Err: err,
 		RelayError: OpenAIError{
 			Message: err.Error(),
@@ -206,9 +213,14 @@ func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *New
 		StatusCode: statusCode,
 		errorCode:  errorCode,
 	}
+	for _, op := range ops {
+		op(e)
+	}
+
+	return e
 }
 
-func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError {
+func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
 	code, ok := openAIError.Code.(string)
 	if !ok {
 		code = fmt.Sprintf("%v", openAIError.Code)
@@ -216,26 +228,34 @@ func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError {
 	if openAIError.Type == "" {
 		openAIError.Type = "upstream_error"
 	}
-	return &NewAPIError{
+	e := &NewAPIError{
 		RelayError: openAIError,
 		errorType:  ErrorTypeOpenAIError,
 		StatusCode: statusCode,
 		Err:        errors.New(openAIError.Message),
 		errorCode:  ErrorCode(code),
 	}
+	for _, op := range ops {
+		op(e)
+	}
+	return e
 }
 
-func WithClaudeError(claudeError ClaudeError, statusCode int) *NewAPIError {
+func WithClaudeError(claudeError ClaudeError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
 	if claudeError.Type == "" {
 		claudeError.Type = "upstream_error"
 	}
-	return &NewAPIError{
+	e := &NewAPIError{
 		RelayError: claudeError,
 		errorType:  ErrorTypeClaudeError,
 		StatusCode: statusCode,
 		Err:        errors.New(claudeError.Message),
 		errorCode:  ErrorCode(claudeError.Type),
 	}
+	for _, op := range ops {
+		op(e)
+	}
+	return e
 }
 
 func IsChannelError(err *NewAPIError) bool {
@@ -245,10 +265,16 @@ func IsChannelError(err *NewAPIError) bool {
 	return strings.HasPrefix(string(err.errorCode), "channel:")
 }
 
-func IsLocalError(err *NewAPIError) bool {
+func IsSkipRetryError(err *NewAPIError) bool {
 	if err == nil {
 		return false
 	}
 
-	return err.errorType == ErrorTypeNewAPIError
+	return err.skipRetry
+}
+
+func ErrOptionWithSkipRetry() NewAPIErrorOptions {
+	return func(e *NewAPIError) {
+		e.skipRetry = true
+	}
 }