[email protected] 1 год назад
Родитель
Commit
bfbbe67fcd

+ 30 - 0
common/str.go

@@ -1,5 +1,13 @@
 package common
 
+import (
+	"bytes"
+	"fmt"
+	goahocorasick "github.com/anknown/ahocorasick"
+	"one-api/constant"
+	"strings"
+)
+
 func SundaySearch(text string, pattern string) bool {
 	// 计算偏移表
 	offset := make(map[rune]int)
@@ -48,3 +56,25 @@ func RemoveDuplicate(s []string) []string {
 	}
 	return result
 }
+
+func InitAc() *goahocorasick.Machine {
+	m := new(goahocorasick.Machine)
+	dict := readRunes()
+	if err := m.Build(dict); err != nil {
+		fmt.Println(err)
+		return nil
+	}
+	return m
+}
+
+func readRunes() [][]rune {
+	var dict [][]rune
+
+	for _, word := range constant.SensitiveWords {
+		word = strings.ToLower(word)
+		l := bytes.TrimSpace([]byte(word))
+		dict = append(dict, bytes.Runes(l))
+	}
+
+	return dict
+}

+ 1 - 1
constant/sensitive.go

@@ -16,7 +16,7 @@ var StreamCacheQueueLength = 0
 // SensitiveWords 敏感词
 // var SensitiveWords []string
 var SensitiveWords = []string{
-	"test",
+	"test_sensitive",
 }
 
 func SensitiveWordsToString() string {

+ 1 - 1
relay/channel/claude/relay-claude.go

@@ -370,7 +370,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
 		}, nil
 	}
 	fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
-	completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false)
+	completionTokens, err := service.CountTokenText(claudeResponse.Completion, model)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
 	}

+ 1 - 1
relay/channel/gemini/relay-gemini.go

@@ -256,7 +256,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 		}, nil
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
-	completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, false)
+	completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,

+ 1 - 1
relay/channel/openai/relay-openai.go

@@ -190,7 +190,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 	if simpleResponse.Usage.TotalTokens == 0 {
 		completionTokens := 0
 		for _, choice := range simpleResponse.Choices {
-			ctkm, _, _ := service.CountTokenText(string(choice.Message.Content), model, false)
+			ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
 			completionTokens += ctkm
 		}
 		simpleResponse.Usage = dto.Usage{

+ 1 - 1
relay/channel/palm/relay-palm.go

@@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 		}, nil
 	}
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-	completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, false)
+	completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model)
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,

+ 8 - 2
relay/relay-audio.go

@@ -55,7 +55,13 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	promptTokens := 0
 	preConsumedTokens := common.PreConsumedQuota
 	if strings.HasPrefix(audioRequest.Model, "tts-1") {
-		promptTokens, err, _ = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive())
+		if constant.ShouldCheckPromptSensitive() {
+			err = service.CheckSensitiveInput(audioRequest.Input)
+			if err != nil {
+				return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
+			}
+		}
+		promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
 		}
@@ -178,7 +184,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 			if strings.HasPrefix(audioRequest.Model, "tts-1") {
 				quota = promptTokens
 			} else {
-				quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, false)
+				quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
 			}
 			quota = int(float64(quota) * ratio)
 			if ratio != 0 && quota <= 0 {

+ 8 - 0
relay/relay-image.go

@@ -10,6 +10,7 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/model"
 	relaycommon "one-api/relay/common"
@@ -47,6 +48,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
 		return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
 	}
 
+	if constant.ShouldCheckPromptSensitive() {
+		err = service.CheckSensitiveInput(imageRequest.Prompt)
+		if err != nil {
+			return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
+		}
+	}
+
 	if strings.Contains(imageRequest.Size, "×") {
 		return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
 	}

+ 33 - 16
relay/relay-text.go

@@ -98,13 +98,17 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	var ratio float64
 	var modelRatio float64
 	//err := service.SensitiveWordsCheck(textRequest)
-	promptTokens, err, sensitiveTrigger := getPromptTokens(textRequest, relayInfo)
 
-	// count messages token error 计算promptTokens错误
-	if err != nil {
-		if sensitiveTrigger {
+	if constant.ShouldCheckPromptSensitive() {
+		err = checkRequestSensitive(textRequest, relayInfo)
+		if err != nil {
 			return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
 		}
+	}
+
+	promptTokens, err := getPromptTokens(textRequest, relayInfo)
+	// count messages token error 计算promptTokens错误
+	if err != nil {
 		return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
 	}
 
@@ -128,7 +132,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
-		return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 	}
 	adaptor.Init(relayInfo, *textRequest)
 	var requestBody io.Reader
@@ -136,7 +140,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		if isModelMapped {
 			jsonStr, err := json.Marshal(textRequest)
 			if err != nil {
-				return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+				return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
 			}
 			requestBody = bytes.NewBuffer(jsonStr)
 		} else {
@@ -145,11 +149,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	} else {
 		convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
 		if err != nil {
-			return service.OpenAIErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
+			return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
 		}
 		jsonData, err := json.Marshal(convertedRequest)
 		if err != nil {
-			return service.OpenAIErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
+			return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
 		}
 		requestBody = bytes.NewBuffer(jsonData)
 	}
@@ -182,26 +186,39 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	return nil
 }
 
-func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error, bool) {
+func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
 	var promptTokens int
 	var err error
-	var sensitiveTrigger bool
-	checkSensitive := constant.ShouldCheckPromptSensitive()
 	switch info.RelayMode {
 	case relayconstant.RelayModeChatCompletions:
-		promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive)
+		promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model)
 	case relayconstant.RelayModeCompletions:
-		promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
+		promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
 	case relayconstant.RelayModeModerations:
-		promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
+		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
 	case relayconstant.RelayModeEmbeddings:
-		promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
+		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
 	default:
 		err = errors.New("unknown relay mode")
 		promptTokens = 0
 	}
 	info.PromptTokens = promptTokens
-	return promptTokens, err, sensitiveTrigger
+	return promptTokens, err
+}
+
+func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
+	var err error
+	switch info.RelayMode {
+	case relayconstant.RelayModeChatCompletions:
+		err = service.CheckSensitiveMessages(textRequest.Messages)
+	case relayconstant.RelayModeCompletions:
+		err = service.CheckSensitiveInput(textRequest.Prompt)
+	case relayconstant.RelayModeModerations:
+		err = service.CheckSensitiveInput(textRequest.Input)
+	case relayconstant.RelayModeEmbeddings:
+		err = service.CheckSensitiveInput(textRequest.Input)
+	}
+	return err
 }
 
 // 预扣费并返回用户剩余配额

+ 51 - 26
service/sensitive.go

@@ -1,13 +1,60 @@
 package service
 
 import (
-	"bytes"
+	"errors"
 	"fmt"
-	"github.com/anknown/ahocorasick"
+	"one-api/common"
 	"one-api/constant"
+	"one-api/dto"
 	"strings"
 )
 
+func CheckSensitiveMessages(messages []dto.Message) error {
+	for _, message := range messages {
+		if len(message.Content) > 0 {
+			if message.IsStringContent() {
+				stringContent := message.StringContent()
+				if ok, words := SensitiveWordContains(stringContent); ok {
+					return errors.New("sensitive words: " + strings.Join(words, ","))
+				}
+			}
+		} else {
+			arrayContent := message.ParseContent()
+			for _, m := range arrayContent {
+				if m.Type == "image_url" {
+					// TODO: check image url
+				} else {
+					if ok, words := SensitiveWordContains(m.Text); ok {
+						return errors.New("sensitive words: " + strings.Join(words, ","))
+					}
+				}
+			}
+		}
+	}
+	return nil
+}
+
+func CheckSensitiveText(text string) error {
+	if ok, words := SensitiveWordContains(text); ok {
+		return errors.New("sensitive words: " + strings.Join(words, ","))
+	}
+	return nil
+}
+
+func CheckSensitiveInput(input any) error {
+	switch v := input.(type) {
+	case string:
+		return CheckSensitiveText(v)
+	case []string:
+		text := ""
+		for _, s := range v {
+			text += s
+		}
+		return CheckSensitiveText(text)
+	}
+	return CheckSensitiveText(fmt.Sprintf("%v", input))
+}
+
 // SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
 func SensitiveWordContains(text string) (bool, []string) {
 	if len(constant.SensitiveWords) == 0 {
@@ -15,7 +62,7 @@ func SensitiveWordContains(text string) (bool, []string) {
 	}
 	checkText := strings.ToLower(text)
 	// 构建一个AC自动机
-	m := initAc()
+	m := common.InitAc()
 	hits := m.MultiPatternSearch([]rune(checkText), false)
 	if len(hits) > 0 {
 		words := make([]string, 0)
@@ -33,7 +80,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
 		return false, nil, text
 	}
 	checkText := strings.ToLower(text)
-	m := initAc()
+	m := common.InitAc()
 	hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
 	if len(hits) > 0 {
 		words := make([]string, 0)
@@ -47,25 +94,3 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
 	}
 	return false, nil, text
 }
-
-func initAc() *goahocorasick.Machine {
-	m := new(goahocorasick.Machine)
-	dict := readRunes()
-	if err := m.Build(dict); err != nil {
-		fmt.Println(err)
-		return nil
-	}
-	return m
-}
-
-func readRunes() [][]rune {
-	var dict [][]rune
-
-	for _, word := range constant.SensitiveWords {
-		word = strings.ToLower(word)
-		l := bytes.TrimSpace([]byte(word))
-		dict = append(dict, bytes.Runes(l))
-	}
-
-	return dict
-}

+ 22 - 41
service/token_counter.go

@@ -125,11 +125,11 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
 	return tiles*170 + 85, nil
 }
 
-func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
+func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) {
 	tkm := 0
-	msgTokens, err, b := CountTokenMessages(request.Messages, model, request.Stream, checkSensitive)
+	msgTokens, err := CountTokenMessages(request.Messages, model, request.Stream)
 	if err != nil {
-		return 0, err, b
+		return 0, err
 	}
 	tkm += msgTokens
 	if request.Tools != nil {
@@ -137,7 +137,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check
 		var openaiTools []dto.OpenAITools
 		err := json.Unmarshal(toolsData, &openaiTools)
 		if err != nil {
-			return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())), false
+			return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
 		}
 		countStr := ""
 		for _, tool := range openaiTools {
@@ -149,18 +149,18 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check
 				countStr += fmt.Sprintf("%v", tool.Function.Parameters)
 			}
 		}
-		toolTokens, err, _ := CountTokenInput(countStr, model, false)
+		toolTokens, err := CountTokenInput(countStr, model)
 		if err != nil {
-			return 0, err, false
+			return 0, err
 		}
 		tkm += 8
 		tkm += toolTokens
 	}
 
-	return tkm, nil, false
+	return tkm, nil
 }
 
-func CountTokenMessages(messages []dto.Message, model string, stream bool, checkSensitive bool) (int, error, bool) {
+func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
 	//recover when panic
 	tokenEncoder := getTokenEncoder(model)
 	// Reference:
@@ -184,13 +184,6 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check
 		if len(message.Content) > 0 {
 			if message.IsStringContent() {
 				stringContent := message.StringContent()
-				if checkSensitive {
-					contains, words := SensitiveWordContains(stringContent)
-					if contains {
-						err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
-						return 0, err, true
-					}
-				}
 				tokenNum += getTokenNum(tokenEncoder, stringContent)
 				if message.Name != nil {
 					tokenNum += tokensPerName
@@ -203,7 +196,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check
 						imageUrl := m.ImageUrl.(dto.MessageImageUrl)
 						imageTokenNum, err := getImageToken(&imageUrl, model, stream)
 						if err != nil {
-							return 0, err, false
+							return 0, err
 						}
 						tokenNum += imageTokenNum
 						log.Printf("image token num: %d", imageTokenNum)
@@ -215,33 +208,33 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check
 		}
 	}
 	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
-	return tokenNum, nil, false
+	return tokenNum, nil
 }
 
-func CountTokenInput(input any, model string, check bool) (int, error, bool) {
+func CountTokenInput(input any, model string) (int, error) {
 	switch v := input.(type) {
 	case string:
-		return CountTokenText(v, model, check)
+		return CountTokenText(v, model)
 	case []string:
 		text := ""
 		for _, s := range v {
 			text += s
 		}
-		return CountTokenText(text, model, check)
+		return CountTokenText(text, model)
 	}
-	return CountTokenInput(fmt.Sprintf("%v", input), model, check)
+	return CountTokenInput(fmt.Sprintf("%v", input), model)
 }
 
 func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
 	tokens := 0
 	for _, message := range messages {
-		tkm, _, _ := CountTokenInput(message.Delta.GetContentString(), model, false)
+		tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
 		tokens += tkm
 		if message.Delta.ToolCalls != nil {
 			for _, tool := range message.Delta.ToolCalls {
-				tkm, _, _ := CountTokenInput(tool.Function.Name, model, false)
+				tkm, _ := CountTokenInput(tool.Function.Name, model)
 				tokens += tkm
-				tkm, _, _ = CountTokenInput(tool.Function.Arguments, model, false)
+				tkm, _ = CountTokenInput(tool.Function.Arguments, model)
 				tokens += tkm
 			}
 		}
@@ -249,29 +242,17 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
 	return tokens
 }
 
-func CountAudioToken(text string, model string, check bool) (int, error, bool) {
+func CountAudioToken(text string, model string) (int, error) {
 	if strings.HasPrefix(model, "tts") {
-		contains, words := SensitiveWordContains(text)
-		if contains {
-			return utf8.RuneCountInString(text), fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")), true
-		}
-		return utf8.RuneCountInString(text), nil, false
+		return utf8.RuneCountInString(text), nil
 	} else {
-		return CountTokenText(text, model, check)
+		return CountTokenText(text, model)
 	}
 }
 
 // CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
-func CountTokenText(text string, model string, check bool) (int, error, bool) {
+func CountTokenText(text string, model string) (int, error) {
 	var err error
-	var trigger bool
-	if check {
-		contains, words := SensitiveWordContains(text)
-		if contains {
-			err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
-			trigger = true
-		}
-	}
 	tokenEncoder := getTokenEncoder(model)
-	return getTokenNum(tokenEncoder, text), err, trigger
+	return getTokenNum(tokenEncoder, text), err
 }

+ 1 - 1
service/usage_helpr.go

@@ -19,7 +19,7 @@ import (
 func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
 	usage := &dto.Usage{}
 	usage.PromptTokens = promptTokens
-	ctkm, err, _ := CountTokenText(responseText, modeName, false)
+	ctkm, err := CountTokenText(responseText, modeName)
 	usage.CompletionTokens = ctkm
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return usage, err