tiktoken.go 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. package tiktoken
  2. import (
  3. "errors"
  4. "sync"
  5. log "github.com/sirupsen/logrus"
  6. "github.com/tiktoken-go/tokenizer"
  7. )
  8. // tokenEncoderMap won't grow after initialization
  9. var (
  10. tokenEncoderMap = map[string]tokenizer.Codec{}
  11. defaultTokenEncoder tokenizer.Codec
  12. tokenEncoderLock sync.RWMutex
  13. )
  14. func init() {
  15. gpt4oTokenEncoder, err := tokenizer.ForModel(tokenizer.GPT4o)
  16. if err != nil {
  17. log.Fatal("failed to get gpt-4o token encoder: " + err.Error())
  18. }
  19. defaultTokenEncoder = gpt4oTokenEncoder
  20. }
  21. func GetTokenEncoder(model string) tokenizer.Codec {
  22. tokenEncoderLock.RLock()
  23. tokenEncoder, ok := tokenEncoderMap[model]
  24. tokenEncoderLock.RUnlock()
  25. if ok {
  26. return tokenEncoder
  27. }
  28. tokenEncoderLock.Lock()
  29. defer tokenEncoderLock.Unlock()
  30. if tokenEncoder, ok := tokenEncoderMap[model]; ok {
  31. return tokenEncoder
  32. }
  33. log.Info("loading encoding for model " + model)
  34. // ForModel has built-in prefix matching for model names
  35. tokenEncoder, err := tokenizer.ForModel(tokenizer.Model(model))
  36. if err != nil {
  37. if errors.Is(err, tokenizer.ErrModelNotSupported) {
  38. log.Warnf("model %s not supported, using default encoder (gpt-4o)", model)
  39. tokenEncoderMap[model] = defaultTokenEncoder
  40. return defaultTokenEncoder
  41. }
  42. log.Errorf(
  43. "failed to get token encoder for model %s: %v, using default encoder",
  44. model,
  45. err,
  46. )
  47. tokenEncoderMap[model] = defaultTokenEncoder
  48. return defaultTokenEncoder
  49. }
  50. log.Infof("loaded encoding for model %s: %s", model, tokenEncoder.GetName())
  51. tokenEncoderMap[model] = tokenEncoder
  52. return tokenEncoder
  53. }