|
|
@@ -67,7 +67,11 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
|
|
return len(tokenEncoder.Encode(text, nil, nil))
|
|
|
}
|
|
|
|
|
|
-func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
|
|
|
+func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
|
|
|
+ // TODO: 非流模式下不计算图片token数量
|
|
|
+ if model == "glm-4v" {
|
|
|
+ return 1047, nil
|
|
|
+ }
|
|
|
if imageUrl.Detail == "low" {
|
|
|
return 85, nil
|
|
|
}
|
|
|
@@ -123,7 +127,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
|
|
|
|
|
|
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
|
|
|
tkm := 0
|
|
|
- msgTokens, err, b := CountTokenMessages(request.Messages, model, checkSensitive)
|
|
|
+ msgTokens, err, b := CountTokenMessages(request.Messages, model, request.Stream, checkSensitive)
|
|
|
if err != nil {
|
|
|
return 0, err, b
|
|
|
}
|
|
|
@@ -156,7 +160,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check
|
|
|
return tkm, nil, false
|
|
|
}
|
|
|
|
|
|
-func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
|
|
|
+func CountTokenMessages(messages []dto.Message, model string, stream bool, checkSensitive bool) (int, error, bool) {
|
|
|
//recover when panic
|
|
|
tokenEncoder := getTokenEncoder(model)
|
|
|
// Reference:
|
|
|
@@ -193,19 +197,13 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
|
|
|
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
|
|
}
|
|
|
} else {
|
|
|
- var err error
|
|
|
arrayContent := message.ParseContent()
|
|
|
for _, m := range arrayContent {
|
|
|
if m.Type == "image_url" {
|
|
|
- var imageTokenNum int
|
|
|
- if model == "glm-4v" {
|
|
|
- imageTokenNum = 1047
|
|
|
- } else {
|
|
|
- imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
|
|
- imageTokenNum, err = getImageToken(&imageUrl)
|
|
|
- if err != nil {
|
|
|
- return 0, err, false
|
|
|
- }
|
|
|
+ imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
|
|
+ imageTokenNum, err := getImageToken(&imageUrl, model, stream)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err, false
|
|
|
}
|
|
|
tokenNum += imageTokenNum
|
|
|
log.Printf("image token num: %d", imageTokenNum)
|