2
0
Эх сурвалжийг харах

feat: update o1 default token encoder

CalciumIon 1 жил өмнө
parent
commit
d2297d2723
1 өөрчлөгдсөн 12 нэмэгдсэн , 14 устгасан
  1. 12 14
      service/token_counter.go

+ 12 - 14
service/token_counter.go

@@ -19,42 +19,40 @@ import (
 // tokenEncoderMap won't grow after initialization
 var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 var defaultTokenEncoder *tiktoken.Tiktoken
-var cl200kTokenEncoder *tiktoken.Tiktoken
+var o200kTokenEncoder *tiktoken.Tiktoken
 
 func InitTokenEncoders() {
 	common.SysLog("initializing token encoders")
-	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
+	cl100TokenEncoder, err := tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
 	if err != nil {
 		common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
 	}
-	defaultTokenEncoder = gpt35TokenEncoder
-	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
-	if err != nil {
-		common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
-	}
-	cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
+	defaultTokenEncoder = cl100TokenEncoder
+	o200kTokenEncoder, err = tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
 	if err != nil {
 		common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
 	}
 	for model, _ := range common.GetDefaultModelRatioMap() {
 		if strings.HasPrefix(model, "gpt-3.5") {
-			tokenEncoderMap[model] = gpt35TokenEncoder
+			tokenEncoderMap[model] = cl100TokenEncoder
 		} else if strings.HasPrefix(model, "gpt-4") {
 			if strings.HasPrefix(model, "gpt-4o") {
-				tokenEncoderMap[model] = cl200kTokenEncoder
+				tokenEncoderMap[model] = o200kTokenEncoder
 			} else {
-				tokenEncoderMap[model] = gpt4TokenEncoder
+				tokenEncoderMap[model] = defaultTokenEncoder
 			}
+		} else if strings.HasPrefix(model, "o1") {
+			tokenEncoderMap[model] = o200kTokenEncoder
 		} else {
-			tokenEncoderMap[model] = nil
+			tokenEncoderMap[model] = defaultTokenEncoder
 		}
 	}
 	common.SysLog("token encoders initialized")
 }
 
 func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
-	if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") {
-		return cl200kTokenEncoder
+	if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") || strings.HasPrefix(model, "o1") {
+		return o200kTokenEncoder
 	}
 	return defaultTokenEncoder
 }