|
|
@@ -1,32 +1,30 @@
|
|
|
package tiktoken
|
|
|
|
|
|
import (
|
|
|
- "strings"
|
|
|
+ "errors"
|
|
|
"sync"
|
|
|
|
|
|
- "github.com/pkoukk/tiktoken-go"
|
|
|
log "github.com/sirupsen/logrus"
|
|
|
+ "github.com/tiktoken-go/tokenizer"
|
|
|
)
|
|
|
|
|
|
// tokenEncoderMap won't grow after initialization
|
|
|
var (
|
|
|
- tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
|
|
- defaultTokenEncoder *tiktoken.Tiktoken
|
|
|
+ tokenEncoderMap = map[string]tokenizer.Codec{}
|
|
|
+ defaultTokenEncoder tokenizer.Codec
|
|
|
tokenEncoderLock sync.RWMutex
|
|
|
)
|
|
|
|
|
|
func init() {
|
|
|
- tiktoken.SetBpeLoader(&embedBpeLoader{})
|
|
|
-
|
|
|
- gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
|
|
+ gpt4oTokenEncoder, err := tokenizer.ForModel(tokenizer.GPT4o)
|
|
|
if err != nil {
|
|
|
- log.Fatal("failed to get gpt-3.5-turbo token encoder: " + err.Error())
|
|
|
+ log.Fatal("failed to get gpt-4o token encoder: " + err.Error())
|
|
|
}
|
|
|
|
|
|
- defaultTokenEncoder = gpt35TokenEncoder
|
|
|
+ defaultTokenEncoder = gpt4oTokenEncoder
|
|
|
}
|
|
|
|
|
|
-func GetTokenEncoder(model string) *tiktoken.Tiktoken {
|
|
|
+func GetTokenEncoder(model string) tokenizer.Codec {
|
|
|
tokenEncoderLock.RLock()
|
|
|
|
|
|
tokenEncoder, ok := tokenEncoderMap[model]
|
|
|
@@ -46,19 +44,26 @@ func GetTokenEncoder(model string) *tiktoken.Tiktoken {
|
|
|
|
|
|
log.Info("loading encoding for model " + model)
|
|
|
|
|
|
- tokenEncoder, err := tiktoken.EncodingForModel(model)
|
|
|
+ // ForModel has built-in prefix matching for model names
|
|
|
+ tokenEncoder, err := tokenizer.ForModel(tokenizer.Model(model))
|
|
|
if err != nil {
|
|
|
- if strings.Contains(err.Error(), "no encoding for model") {
|
|
|
- log.Warnf("no encoding for model %s, using default encoder", model)
|
|
|
+ if errors.Is(err, tokenizer.ErrModelNotSupported) {
|
|
|
+ log.Warnf("model %s not supported, using default encoder (gpt-4o)", model)
|
|
|
tokenEncoderMap[model] = defaultTokenEncoder
|
|
|
- } else {
|
|
|
- log.Errorf("failed to get token encoder for model %s: %v", model, err)
|
|
|
+ return defaultTokenEncoder
|
|
|
}
|
|
|
|
|
|
+ log.Errorf(
|
|
|
+ "failed to get token encoder for model %s: %v, using default encoder",
|
|
|
+ model,
|
|
|
+ err,
|
|
|
+ )
|
|
|
+ tokenEncoderMap[model] = defaultTokenEncoder
|
|
|
+
|
|
|
return defaultTokenEncoder
|
|
|
}
|
|
|
|
|
|
- log.Infof("load encoding for model %s success", model)
|
|
|
+ log.Infof("loaded encoding for model %s: %s", model, tokenEncoder.GetName())
|
|
|
|
|
|
tokenEncoderMap[model] = tokenEncoder
|
|
|
|