소스 검색

feat: update token encoder

[email protected] 1 년 전
부모
커밋
ecdcb379fe
1개의 변경된 파일, 17개의 추가 및 7개의 삭제
  1. 17 7
      service/token_counter.go

+ 17 - 7
service/token_counter.go

@@ -17,6 +17,7 @@ import (
 // tokenEncoderMap won't grow after initialization
 // tokenEncoderMap won't grow after initialization
 var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 var defaultTokenEncoder *tiktoken.Tiktoken
 var defaultTokenEncoder *tiktoken.Tiktoken
+var cl200kTokenEncoder *tiktoken.Tiktoken // assigned in InitTokenEncoders from tiktoken.EncodingForModel("gpt-4o")
 
 
 func InitTokenEncoders() {
 func InitTokenEncoders() {
 	common.SysLog("initializing token encoders")
 	common.SysLog("initializing token encoders")
@@ -29,7 +30,7 @@ func InitTokenEncoders() {
 	if err != nil {
 	if err != nil {
 		common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
 		common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
 	}
 	}
-	gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o")
+	cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
 	if err != nil {
 	if err != nil {
 		common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
 		common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
 	}
 	}
@@ -38,7 +39,7 @@ func InitTokenEncoders() {
 			tokenEncoderMap[model] = gpt35TokenEncoder
 			tokenEncoderMap[model] = gpt35TokenEncoder
 		} else if strings.HasPrefix(model, "gpt-4") {
 		} else if strings.HasPrefix(model, "gpt-4") {
 			if strings.HasPrefix(model, "gpt-4o") {
 			if strings.HasPrefix(model, "gpt-4o") {
-				tokenEncoderMap[model] = gpt4oTokenEncoder
+				tokenEncoderMap[model] = cl200kTokenEncoder
 			} else {
 			} else {
 				tokenEncoderMap[model] = gpt4TokenEncoder
 				tokenEncoderMap[model] = gpt4TokenEncoder
 			}
 			}
@@ -49,21 +50,30 @@ func InitTokenEncoders() {
 	common.SysLog("token encoders initialized")
 	common.SysLog("token encoders initialized")
 }
 }
 
 
+func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken { // picks the fallback encoder for a model name
+	if strings.HasPrefix(model, "gpt-4o") { // any name starting with "gpt-4o" gets the encoder loaded for "gpt-4o"
+		return cl200kTokenEncoder
+	}
+	return defaultTokenEncoder // every other model name falls back to the package-level default encoder
+}
+
 func getTokenEncoder(model string) *tiktoken.Tiktoken { // returns the encoder stored in tokenEncoderMap for model, resolving and storing it when the entry is nil
 func getTokenEncoder(model string) *tiktoken.Tiktoken { // returns the encoder stored in tokenEncoderMap for model, resolving and storing it when the entry is nil
 	tokenEncoder, ok := tokenEncoderMap[model]
 	tokenEncoder, ok := tokenEncoderMap[model]
 	if ok && tokenEncoder != nil {
 	if ok && tokenEncoder != nil {
 		return tokenEncoder
 		return tokenEncoder
 	}
 	}
+	// If ok (i.e. model is a key in tokenEncoderMap) but tokenEncoder is nil, it is probably a custom model
 	if ok {
 	if ok {
 		tokenEncoder, err := tiktoken.EncodingForModel(model)
 		tokenEncoder, err := tiktoken.EncodingForModel(model)
 		if err != nil {
 		if err != nil {
 			common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
 			common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
-			tokenEncoder = defaultTokenEncoder
+			tokenEncoder = getModelDefaultTokenEncoder(model) // NOTE(review): the log text above still names gpt-3.5-turbo, but for "gpt-4o"-prefixed names this fallback returns cl200kTokenEncoder
 		}
 		}
 		tokenEncoderMap[model] = tokenEncoder
 		tokenEncoderMap[model] = tokenEncoder
 		return tokenEncoder
 		return tokenEncoder
 	}
 	}
-	return defaultTokenEncoder
+	// If model is not in tokenEncoderMap, return the default tokenEncoder directly
+	return getModelDefaultTokenEncoder(model)
 }
 }
 
 
 func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
@@ -75,13 +85,13 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
 	if model == "glm-4v" {
 	if model == "glm-4v" {
 		return 1047, nil
 		return 1047, nil
 	}
 	}
+	if imageUrl.Detail == "low" {
+		return 85, nil
+	}
 	// 同步One API的图片计费逻辑
 	// 同步One API的图片计费逻辑
 	if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
 	if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
 		imageUrl.Detail = "high"
 		imageUrl.Detail = "high"
 	}
 	}
-	if imageUrl.Detail == "low" {
-		return 85, nil
-	}
 	var config image.Config
 	var config image.Config
 	var err error
 	var err error
 	var format string
 	var format string