|
|
@@ -116,7 +116,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
|
|
|
return tiles*170 + 85, nil
|
|
|
}
|
|
|
|
|
|
-func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error) {
|
|
|
+func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
|
|
|
//recover when panic
|
|
|
tokenEncoder := getTokenEncoder(model)
|
|
|
// Reference:
|
|
|
@@ -142,13 +142,13 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
|
|
|
if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
|
|
|
var stringContent string
|
|
|
if err := json.Unmarshal(message.Content, &stringContent); err != nil {
|
|
|
- return 0, err
|
|
|
+ return 0, err, false
|
|
|
} else {
|
|
|
if checkSensitive {
|
|
|
contains, words := SensitiveWordContains(stringContent)
|
|
|
if contains {
|
|
|
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
|
|
|
- return 0, err
|
|
|
+ return 0, err, true
|
|
|
}
|
|
|
}
|
|
|
tokenNum += getTokenNum(tokenEncoder, stringContent)
|
|
|
@@ -181,7 +181,7 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
|
|
|
imageTokenNum, err = getImageToken(&imageUrl)
|
|
|
}
|
|
|
if err != nil {
|
|
|
- return 0, err
|
|
|
+ return 0, err, false
|
|
|
}
|
|
|
}
|
|
|
tokenNum += imageTokenNum
|
|
|
@@ -194,10 +194,10 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
|
|
|
}
|
|
|
}
|
|
|
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
|
|
- return tokenNum, nil
|
|
|
+ return tokenNum, nil, false
|
|
|
}
|
|
|
|
|
|
-func CountTokenInput(input any, model string, check bool) (int, error) {
|
|
|
+func CountTokenInput(input any, model string, check bool) (int, error, bool) {
|
|
|
switch v := input.(type) {
|
|
|
case string:
|
|
|
return CountTokenText(v, model, check)
|
|
|
@@ -208,26 +208,32 @@ func CountTokenInput(input any, model string, check bool) (int, error) {
|
|
|
}
|
|
|
return CountTokenText(text, model, check)
|
|
|
}
|
|
|
- return 0, errors.New("unsupported input type")
|
|
|
+ return 0, errors.New("unsupported input type"), false
|
|
|
}
|
|
|
|
|
|
-func CountAudioToken(text string, model string, check bool) (int, error) {
|
|
|
+func CountAudioToken(text string, model string, check bool) (int, error, bool) {
|
|
|
if strings.HasPrefix(model, "tts") {
|
|
|
- return utf8.RuneCountInString(text), nil
|
|
|
+ contains, words := SensitiveWordContains(text)
|
|
|
+ if contains {
|
|
|
+ return utf8.RuneCountInString(text), fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ",")), true
|
|
|
+ }
|
|
|
+ return utf8.RuneCountInString(text), nil, false
|
|
|
} else {
|
|
|
return CountTokenText(text, model, check)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
|
|
-func CountTokenText(text string, model string, check bool) (int, error) {
|
|
|
+func CountTokenText(text string, model string, check bool) (int, error, bool) {
|
|
|
var err error
|
|
|
+ var trigger bool
|
|
|
if check {
|
|
|
contains, words := SensitiveWordContains(text)
|
|
|
if contains {
|
|
|
err = fmt.Errorf("input contains sensitive words: [%s]", strings.Join(words, ","))
|
|
|
+ trigger = true
|
|
|
}
|
|
|
}
|
|
|
tokenEncoder := getTokenEncoder(model)
|
|
|
- return getTokenNum(tokenEncoder, text), err
|
|
|
+ return getTokenNum(tokenEncoder, text), err, trigger
|
|
|
}
|