@@ -125,11 +125,11 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
 	return tiles*170 + 85, nil
 }
 
-func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
+func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) {
 	tkm := 0
-	msgTokens, err, b := CountTokenMessages(request.Messages, model, request.Stream, checkSensitive)
+	msgTokens, err := CountTokenMessages(request.Messages, model, request.Stream)
 	if err != nil {
-		return 0, err, b
+		return 0, err
 	}
 	tkm += msgTokens
 	if request.Tools != nil {
@@ -137,7 +137,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check
 		var openaiTools []dto.OpenAITools
 		err := json.Unmarshal(toolsData, &openaiTools)
 		if err != nil {
-			return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())), false
+			return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
 		}
 		countStr := ""
 		for _, tool := range openaiTools {
@@ -149,18 +149,18 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check
 				countStr += fmt.Sprintf("%v", tool.Function.Parameters)
 			}
 		}
-		toolTokens, err, _ := CountTokenInput(countStr, model, false)
+		toolTokens, err := CountTokenInput(countStr, model)
 		if err != nil {
-			return 0, err, false
+			return 0, err
 		}
 		tkm += 8
 		tkm += toolTokens
 	}
 
-	return tkm, nil, false
+	return tkm, nil
 }
 
-func CountTokenMessages(messages []dto.Message, model string, stream bool, checkSensitive bool) (int, error, bool) {
+func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
 	//recover when panic
 	tokenEncoder := getTokenEncoder(model)
 	// Reference:
@@ -184,13 +184,6 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check
 		if len(message.Content) > 0 {
 			if message.IsStringContent() {
 				stringContent := message.StringContent()
-				if checkSensitive {
-					contains, words := SensitiveWordContains(stringContent)
-					if contains {
-						err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
-						return 0, err, true
-					}
-				}
 				tokenNum += getTokenNum(tokenEncoder, stringContent)
 				if message.Name != nil {
 					tokenNum += tokensPerName
@@ -203,7 +196,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check
 						imageUrl := m.ImageUrl.(dto.MessageImageUrl)
 						imageTokenNum, err := getImageToken(&imageUrl, model, stream)
 						if err != nil {
-							return 0, err, false
+							return 0, err
 						}
 						tokenNum += imageTokenNum
 						log.Printf("image token num: %d", imageTokenNum)
@@ -215,33 +208,33 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool, check
 		}
 	}
 	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
-	return tokenNum, nil, false
+	return tokenNum, nil
 }
 
-func CountTokenInput(input any, model string, check bool) (int, error, bool) {
+func CountTokenInput(input any, model string) (int, error) {
 	switch v := input.(type) {
 	case string:
-		return CountTokenText(v, model, check)
+		return CountTokenText(v, model)
 	case []string:
 		text := ""
 		for _, s := range v {
 			text += s
 		}
-		return CountTokenText(text, model, check)
+		return CountTokenText(text, model)
 	}
-	return CountTokenInput(fmt.Sprintf("%v", input), model, check)
+	return CountTokenInput(fmt.Sprintf("%v", input), model)
 }
 
 func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
 	tokens := 0
 	for _, message := range messages {
-		tkm, _, _ := CountTokenInput(message.Delta.GetContentString(), model, false)
+		tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
 		tokens += tkm
 		if message.Delta.ToolCalls != nil {
 			for _, tool := range message.Delta.ToolCalls {
-				tkm, _, _ := CountTokenInput(tool.Function.Name, model, false)
+				tkm, _ := CountTokenInput(tool.Function.Name, model)
 				tokens += tkm
-				tkm, _, _ = CountTokenInput(tool.Function.Arguments, model, false)
+				tkm, _ = CountTokenInput(tool.Function.Arguments, model)
 				tokens += tkm
 			}
 		}
@@ -249,29 +242,17 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
 	return tokens
 }
 
-func CountAudioToken(text string, model string, check bool) (int, error, bool) {
+func CountAudioToken(text string, model string) (int, error) {
 	if strings.HasPrefix(model, "tts") {
-		contains, words := SensitiveWordContains(text)
-		if contains {
-			return utf8.RuneCountInString(text), fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")), true
-		}
-		return utf8.RuneCountInString(text), nil, false
+		return utf8.RuneCountInString(text), nil
 	} else {
-		return CountTokenText(text, model, check)
+		return CountTokenText(text, model)
 	}
 }
 
-// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
-func CountTokenText(text string, model string, check bool) (int, error, bool) {
+// CountTokenText counts the number of tokens in the given text for the specified model.
+func CountTokenText(text string, model string) (int, error) {
 	var err error
-	var trigger bool
-	if check {
-		contains, words := SensitiveWordContains(text)
-		if contains {
-			err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
-			trigger = true
-		}
-	}
 	tokenEncoder := getTokenEncoder(model)
-	return getTokenNum(tokenEncoder, text), err, trigger
+	return getTokenNum(tokenEncoder, text), err
 }
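
Every helper above loses its trailing bool, so downstream call sites must stop destructuring a third return value. Below is a minimal, self-contained sketch of that caller-side migration; it is an assumption rather than part of this patch: countAudioToken only mirrors the new (count, error) shape of CountAudioToken, the word-count fallback stands in for the real tiktoken path, and no actual call sites appear in this hunk.

```go
package main

import (
	"fmt"
	"strings"
	"unicode/utf8"
)

// countAudioToken mirrors the new two-value shape of CountAudioToken after this
// change: callers receive only (count, error), with no sensitive-word flag.
func countAudioToken(text string, model string) (int, error) {
	if strings.HasPrefix(model, "tts") {
		// TTS models are billed per input character, so count runes directly.
		return utf8.RuneCountInString(text), nil
	}
	// Crude stand-in for the tiktoken-based CountTokenText path.
	return len(strings.Fields(text)), nil
}

func main() {
	// Before this patch a caller destructured three values, e.g.
	//   n, err, sensitiveTriggered := CountAudioToken(text, model, checkSensitive)
	// Afterwards it keeps only the (count, error) pair; any sensitive-word
	// screening now has to happen in the caller, outside the counting helpers.
	n, err := countAudioToken("Hello, token counter", "tts-1")
	if err != nil {
		fmt.Println("count failed:", err)
		return
	}
	fmt.Println("audio tokens:", n)
}
```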