assest.go 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. package tiktoken
  2. import (
  3. "embed"
  4. "encoding/base64"
  5. "errors"
  6. "os"
  7. "path"
  8. "strconv"
  9. "strings"
  10. "github.com/labring/aiproxy/core/common/conv"
  11. "github.com/pkoukk/tiktoken-go"
  12. )
  13. //go:embed all:assets
  14. var assets embed.FS
  15. var (
  16. _ tiktoken.BpeLoader = (*embedBpeLoader)(nil)
  17. defaultBpeLoader = tiktoken.NewDefaultBpeLoader()
  18. )
  19. type embedBpeLoader struct{}
  20. func (e *embedBpeLoader) LoadTiktokenBpe(tiktokenBpeFile string) (map[string]int, error) {
  21. embedPath := path.Join("assets", path.Base(tiktokenBpeFile))
  22. contents, err := assets.ReadFile(embedPath)
  23. if err != nil {
  24. if errors.Is(err, os.ErrNotExist) {
  25. return defaultBpeLoader.LoadTiktokenBpe(tiktokenBpeFile)
  26. }
  27. return nil, err
  28. }
  29. bpeRanks := make(map[string]int)
  30. for _, line := range strings.Split(conv.BytesToString(contents), "\n") {
  31. if line == "" {
  32. continue
  33. }
  34. parts := strings.Split(line, " ")
  35. token, err := base64.StdEncoding.DecodeString(parts[0])
  36. if err != nil {
  37. return nil, err
  38. }
  39. rank, err := strconv.Atoi(parts[1])
  40. if err != nil {
  41. return nil, err
  42. }
  43. bpeRanks[string(token)] = rank
  44. }
  45. return bpeRanks, nil
  46. }