ソースを参照

feat: 统一错误提示

CaIon 1 年間 前
コミット
a232afe9fd

+ 1 - 1
relay/channel/claude/relay-claude.go

@@ -313,7 +313,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
 		}, nil
 	}
 	fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
-	completionTokens, err := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive())
+	completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, constant.ShouldCheckCompletionSensitive())
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
 	}

+ 1 - 1
relay/channel/gemini/relay-gemini.go

@@ -257,7 +257,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 		}, nil
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
-	completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive())
+	completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, constant.ShouldCheckCompletionSensitive())
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,

+ 1 - 1
relay/channel/openai/relay-openai.go

@@ -154,7 +154,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 		completionTokens := 0
 		for _, choice := range textResponse.Choices {
 			stringContent := string(choice.Message.Content)
-			ctkm, _ := service.CountTokenText(stringContent, model, false)
+			ctkm, _, _ := service.CountTokenText(stringContent, model, false)
 			completionTokens += ctkm
 			if checkSensitive {
 				sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)

+ 1 - 1
relay/channel/palm/relay-palm.go

@@ -157,7 +157,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 		}, nil
 	}
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-	completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive())
+	completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, constant.ShouldCheckCompletionSensitive())
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,

+ 2 - 2
relay/relay-audio.go

@@ -67,7 +67,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	promptTokens := 0
 	preConsumedTokens := common.PreConsumedQuota
 	if strings.HasPrefix(audioRequest.Model, "tts-1") {
-		promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive())
+		promptTokens, err, _ = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive())
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
 		}
@@ -173,7 +173,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 			if strings.HasPrefix(audioRequest.Model, "tts-1") {
 				quota = promptTokens
 			} else {
-				quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive())
+				quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, constant.ShouldCheckCompletionSensitive())
 			}
 			quota = int(float64(quota) * ratio)
 			if ratio != 0 && quota <= 0 {

+ 11 - 7
relay/relay-text.go

@@ -98,10 +98,13 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	var ratio float64
 	var modelRatio float64
 	//err := service.SensitiveWordsCheck(textRequest)
-	promptTokens, err := getPromptTokens(textRequest, relayInfo)
+	promptTokens, err, sensitiveTrigger := getPromptTokens(textRequest, relayInfo)
 
 	// count messages token error 计算promptTokens错误
 	if err != nil {
+		if sensitiveTrigger {
+			return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
+		}
 		return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
 	}
 
@@ -180,25 +183,26 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	return nil
 }
 
-func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
+func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error, bool) {
 	var promptTokens int
 	var err error
+	var sensitiveTrigger bool
 	checkSensitive := constant.ShouldCheckPromptSensitive()
 	switch info.RelayMode {
 	case relayconstant.RelayModeChatCompletions:
-		promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
+		promptTokens, err, sensitiveTrigger = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
 	case relayconstant.RelayModeCompletions:
-		promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
+		promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
 	case relayconstant.RelayModeModerations:
-		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
+		promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
 	case relayconstant.RelayModeEmbeddings:
-		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
+		promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
 	default:
 		err = errors.New("unknown relay mode")
 		promptTokens = 0
 	}
 	info.PromptTokens = promptTokens
-	return promptTokens, err
+	return promptTokens, err, sensitiveTrigger
 }
 
 // 预扣费并返回用户剩余配额

+ 17 - 11
service/token_counter.go

@@ -116,7 +116,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
 	return tiles*170 + 85, nil
 }
 
-func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error) {
+func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
 	//recover when panic
 	tokenEncoder := getTokenEncoder(model)
 	// Reference:
@@ -142,13 +142,13 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
 			if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
 				var stringContent string
 				if err := json.Unmarshal(message.Content, &stringContent); err != nil {
-					return 0, err
+					return 0, err, false
 				} else {
 					if checkSensitive {
 						contains, words := SensitiveWordContains(stringContent)
 						if contains {
 							err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
-							return 0, err
+							return 0, err, true
 						}
 					}
 					tokenNum += getTokenNum(tokenEncoder, stringContent)
@@ -181,7 +181,7 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
 								imageTokenNum, err = getImageToken(&imageUrl)
 							}
 							if err != nil {
-								return 0, err
+								return 0, err, false
 							}
 						}
 						tokenNum += imageTokenNum
@@ -194,10 +194,10 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
 		}
 	}
 	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
-	return tokenNum, nil
+	return tokenNum, nil, false
 }
 
-func CountTokenInput(input any, model string, check bool) (int, error) {
+func CountTokenInput(input any, model string, check bool) (int, error, bool) {
 	switch v := input.(type) {
 	case string:
 		return CountTokenText(v, model, check)
@@ -208,26 +208,32 @@ func CountTokenInput(input any, model string, check bool) (int, error) {
 		}
 		return CountTokenText(text, model, check)
 	}
-	return 0, errors.New("unsupported input type")
+	return 0, errors.New("unsupported input type"), false
 }
 
-func CountAudioToken(text string, model string, check bool) (int, error) {
+func CountAudioToken(text string, model string, check bool) (int, error, bool) {
 	if strings.HasPrefix(model, "tts") {
-		return utf8.RuneCountInString(text), nil
+		contains, words := SensitiveWordContains(text)
+		if contains {
+			return utf8.RuneCountInString(text), fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")), true
+		}
+		return utf8.RuneCountInString(text), nil, false
 	} else {
 		return CountTokenText(text, model, check)
 	}
 }
 
 // CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
-func CountTokenText(text string, model string, check bool) (int, error) {
+func CountTokenText(text string, model string, check bool) (int, error, bool) {
 	var err error
+	var trigger bool
 	if check {
 		contains, words := SensitiveWordContains(text)
 		if contains {
 			err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
+			trigger = true
 		}
 	}
 	tokenEncoder := getTokenEncoder(model)
-	return getTokenNum(tokenEncoder, text), err
+	return getTokenNum(tokenEncoder, text), err, trigger
 }

+ 1 - 1
service/usage_helpr.go

@@ -19,7 +19,7 @@ import (
 func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
 	usage := &dto.Usage{}
 	usage.PromptTokens = promptTokens
-	ctkm, err := CountTokenText(responseText, modeName, false)
+	ctkm, err, _ := CountTokenText(responseText, modeName, false)
 	usage.CompletionTokens = ctkm
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return usage, err