|
|
@@ -9,44 +9,53 @@ import (
|
|
|
"net/http"
|
|
|
"one-api/common"
|
|
|
"strconv"
|
|
|
+ "strings"
|
|
|
)
|
|
|
|
|
|
var stopFinishReason = "stop"
|
|
|
|
|
|
+// tokenEncoderMap won't grow after initialization
|
|
|
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
|
|
+var defaultTokenEncoder *tiktoken.Tiktoken
|
|
|
|
|
|
func InitTokenEncoders() {
|
|
|
common.SysLog("initializing token encoders")
|
|
|
- fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
|
|
+ gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
|
|
if err != nil {
|
|
|
- common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error()))
|
|
|
+ common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
|
|
|
+ }
|
|
|
+ defaultTokenEncoder = gpt35TokenEncoder
|
|
|
+ gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
|
|
|
+ if err != nil {
|
|
|
+ common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
|
|
}
|
|
|
for model, _ := range common.ModelRatio {
|
|
|
- tokenEncoder, err := tiktoken.EncodingForModel(model)
|
|
|
- if err != nil {
|
|
|
- common.SysError(fmt.Sprintf("using fallback encoder for model %s", model))
|
|
|
- tokenEncoderMap[model] = fallbackTokenEncoder
|
|
|
- continue
|
|
|
+ if strings.HasPrefix(model, "gpt-3.5") {
|
|
|
+ tokenEncoderMap[model] = gpt35TokenEncoder
|
|
|
+ } else if strings.HasPrefix(model, "gpt-4") {
|
|
|
+ tokenEncoderMap[model] = gpt4TokenEncoder
|
|
|
+ } else {
|
|
|
+ tokenEncoderMap[model] = nil
|
|
|
}
|
|
|
- tokenEncoderMap[model] = tokenEncoder
|
|
|
}
|
|
|
common.SysLog("token encoders initialized")
|
|
|
}
|
|
|
|
|
|
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
|
|
- if tokenEncoder, ok := tokenEncoderMap[model]; ok {
|
|
|
+ tokenEncoder, ok := tokenEncoderMap[model]
|
|
|
+ if ok && tokenEncoder != nil {
|
|
|
return tokenEncoder
|
|
|
}
|
|
|
- tokenEncoder, err := tiktoken.EncodingForModel(model)
|
|
|
- if err != nil {
|
|
|
- common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
|
|
- tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
|
|
|
+ if ok {
|
|
|
+ tokenEncoder, err := tiktoken.EncodingForModel(model)
|
|
|
if err != nil {
|
|
|
- common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
|
|
|
+ common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
|
|
+ tokenEncoder = defaultTokenEncoder
|
|
|
}
|
|
|
+ tokenEncoderMap[model] = tokenEncoder
|
|
|
+ return tokenEncoder
|
|
|
}
|
|
|
- tokenEncoderMap[model] = tokenEncoder
|
|
|
- return tokenEncoder
|
|
|
+ return defaultTokenEncoder
|
|
|
}
|
|
|
|
|
|
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|