tiktoken.go 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. package tiktoken
  2. import (
  3. "strings"
  4. "sync"
  5. "github.com/pkoukk/tiktoken-go"
  6. log "github.com/sirupsen/logrus"
  7. )
  8. // tokenEncoderMap won't grow after initialization
  9. var (
  10. tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
  11. defaultTokenEncoder *tiktoken.Tiktoken
  12. tokenEncoderLock sync.RWMutex
  13. )
  14. func init() {
  15. tiktoken.SetBpeLoader(&embedBpeLoader{})
  16. gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
  17. if err != nil {
  18. log.Fatal("failed to get gpt-3.5-turbo token encoder: " + err.Error())
  19. }
  20. defaultTokenEncoder = gpt35TokenEncoder
  21. }
  22. func GetTokenEncoder(model string) *tiktoken.Tiktoken {
  23. tokenEncoderLock.RLock()
  24. tokenEncoder, ok := tokenEncoderMap[model]
  25. tokenEncoderLock.RUnlock()
  26. if ok {
  27. return tokenEncoder
  28. }
  29. tokenEncoderLock.Lock()
  30. defer tokenEncoderLock.Unlock()
  31. if tokenEncoder, ok := tokenEncoderMap[model]; ok {
  32. return tokenEncoder
  33. }
  34. log.Info("loading encoding for model " + model)
  35. tokenEncoder, err := tiktoken.EncodingForModel(model)
  36. if err != nil {
  37. if strings.Contains(err.Error(), "no encoding for model") {
  38. log.Warnf("no encoding for model %s, using default encoder", model)
  39. tokenEncoderMap[model] = defaultTokenEncoder
  40. } else {
  41. log.Errorf("failed to get token encoder for model %s: %v", model, err)
  42. }
  43. return defaultTokenEncoder
  44. }
  45. log.Infof("load encoding for model %s success", model)
  46. tokenEncoderMap[model] = tokenEncoder
  47. return tokenEncoder
  48. }