|
|
@@ -19,42 +19,40 @@ import (
|
|
|
// tokenEncoderMap won't grow after initialization
|
|
|
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
|
|
var defaultTokenEncoder *tiktoken.Tiktoken
|
|
|
-var cl200kTokenEncoder *tiktoken.Tiktoken
|
|
|
+var o200kTokenEncoder *tiktoken.Tiktoken
|
|
|
|
|
|
func InitTokenEncoders() {
|
|
|
common.SysLog("initializing token encoders")
|
|
|
- gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
|
|
+ cl100TokenEncoder, err := tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
|
|
|
if err != nil {
|
|
|
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
|
|
|
}
|
|
|
- defaultTokenEncoder = gpt35TokenEncoder
|
|
|
- gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
|
|
|
- if err != nil {
|
|
|
- common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
|
|
- }
|
|
|
- cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
|
|
|
+ defaultTokenEncoder = cl100TokenEncoder
|
|
|
+ o200kTokenEncoder, err = tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
|
|
|
if err != nil {
|
|
|
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
|
|
|
}
|
|
|
for model, _ := range common.GetDefaultModelRatioMap() {
|
|
|
if strings.HasPrefix(model, "gpt-3.5") {
|
|
|
- tokenEncoderMap[model] = gpt35TokenEncoder
|
|
|
+ tokenEncoderMap[model] = cl100TokenEncoder
|
|
|
} else if strings.HasPrefix(model, "gpt-4") {
|
|
|
if strings.HasPrefix(model, "gpt-4o") {
|
|
|
- tokenEncoderMap[model] = cl200kTokenEncoder
|
|
|
+ tokenEncoderMap[model] = o200kTokenEncoder
|
|
|
} else {
|
|
|
- tokenEncoderMap[model] = gpt4TokenEncoder
|
|
|
+ tokenEncoderMap[model] = defaultTokenEncoder
|
|
|
}
|
|
|
+ } else if strings.HasPrefix(model, "o1") {
|
|
|
+ tokenEncoderMap[model] = o200kTokenEncoder
|
|
|
} else {
|
|
|
- tokenEncoderMap[model] = nil
|
|
|
+ tokenEncoderMap[model] = defaultTokenEncoder
|
|
|
}
|
|
|
}
|
|
|
common.SysLog("token encoders initialized")
|
|
|
}
|
|
|
|
|
|
func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
|
|
|
- if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") {
|
|
|
- return cl200kTokenEncoder
|
|
|
+ if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") || strings.HasPrefix(model, "o1") {
|
|
|
+ return o200kTokenEncoder
|
|
|
}
|
|
|
return defaultTokenEncoder
|
|
|
}
|