CaIon 1 год назад
Родитель
Сommit
a3de309175
1 измененных файлов с 11 добавлено и 13 удалено
  1. 11 13
      service/token_counter.go

+ 11 - 13
service/token_counter.go

@@ -67,7 +67,11 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 	return len(tokenEncoder.Encode(text, nil, nil))
 }
 
-func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
+func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
+	// TODO: 非流模式下不计算图片token数量
+	if model == "glm-4v" {
+		return 1047, nil
+	}
 	if imageUrl.Detail == "low" {
 		return 85, nil
 	}
@@ -123,7 +127,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
 
 func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
 	tkm := 0
-	msgTokens, err, b := CountTokenMessages(request.Messages, model, checkSensitive)
+	msgTokens, err, b := CountTokenMessages(request.Messages, model, request.Stream, checkSensitive)
 	if err != nil {
 		return 0, err, b
 	}
@@ -156,7 +160,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, check
 	return tkm, nil, false
 }
 
-func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
+func CountTokenMessages(messages []dto.Message, model string, stream bool, checkSensitive bool) (int, error, bool) {
 	//recover when panic
 	tokenEncoder := getTokenEncoder(model)
 	// Reference:
@@ -193,19 +197,13 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
 					tokenNum += getTokenNum(tokenEncoder, *message.Name)
 				}
 			} else {
-				var err error
 				arrayContent := message.ParseContent()
 				for _, m := range arrayContent {
 					if m.Type == "image_url" {
-						var imageTokenNum int
-						if model == "glm-4v" {
-							imageTokenNum = 1047
-						} else {
-							imageUrl := m.ImageUrl.(dto.MessageImageUrl)
-							imageTokenNum, err = getImageToken(&imageUrl)
-							if err != nil {
-								return 0, err, false
-							}
+						imageUrl := m.ImageUrl.(dto.MessageImageUrl)
+						imageTokenNum, err := getImageToken(&imageUrl, model, stream)
+						if err != nil {
+							return 0, err, false
 						}
 						tokenNum += imageTokenNum
 						log.Printf("image token num: %d", imageTokenNum)