Browse Source

feat: update token encoder

[email protected] 1 year ago
parent
commit
ecdcb379fe
1 changed files with 17 additions and 7 deletions
  1. 17 7
      service/token_counter.go

+ 17 - 7
service/token_counter.go

@@ -17,6 +17,7 @@ import (
 // tokenEncoderMap won't grow after initialization
 var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 var defaultTokenEncoder *tiktoken.Tiktoken
+var cl200kTokenEncoder *tiktoken.Tiktoken
 
 func InitTokenEncoders() {
 	common.SysLog("initializing token encoders")
@@ -29,7 +30,7 @@ func InitTokenEncoders() {
 	if err != nil {
 		common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
 	}
-	gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o")
+	cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
 	if err != nil {
 		common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
 	}
@@ -38,7 +39,7 @@ func InitTokenEncoders() {
 			tokenEncoderMap[model] = gpt35TokenEncoder
 		} else if strings.HasPrefix(model, "gpt-4") {
 			if strings.HasPrefix(model, "gpt-4o") {
-				tokenEncoderMap[model] = gpt4oTokenEncoder
+				tokenEncoderMap[model] = cl200kTokenEncoder
 			} else {
 				tokenEncoderMap[model] = gpt4TokenEncoder
 			}
@@ -49,21 +50,30 @@ func InitTokenEncoders() {
 	common.SysLog("token encoders initialized")
 }
 
+func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
+	if strings.HasPrefix(model, "gpt-4o") {
+		return cl200kTokenEncoder
+	}
+	return defaultTokenEncoder
+}
+
 func getTokenEncoder(model string) *tiktoken.Tiktoken {
 	tokenEncoder, ok := tokenEncoderMap[model]
 	if ok && tokenEncoder != nil {
 		return tokenEncoder
 	}
+	// 如果ok(即model在tokenEncoderMap中),但是tokenEncoder为nil,说明可能是自定义模型
 	if ok {
 		tokenEncoder, err := tiktoken.EncodingForModel(model)
 		if err != nil {
 			common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
-			tokenEncoder = defaultTokenEncoder
+			tokenEncoder = getModelDefaultTokenEncoder(model)
 		}
 		tokenEncoderMap[model] = tokenEncoder
 		return tokenEncoder
 	}
-	return defaultTokenEncoder
+	// 如果model不在tokenEncoderMap中,直接返回默认的tokenEncoder
+	return getModelDefaultTokenEncoder(model)
 }
 
 func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
@@ -75,13 +85,13 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
 	if model == "glm-4v" {
 		return 1047, nil
 	}
+	if imageUrl.Detail == "low" {
+		return 85, nil
+	}
 	// 同步One API的图片计费逻辑
 	if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
 		imageUrl.Detail = "high"
 	}
-	if imageUrl.Detail == "low" {
-		return 85, nil
-	}
 	var config image.Config
 	var err error
 	var format string