Browse Source

feat: embed tiktoken (#34)

* feat: embed tiktoken

* fix: image build

* fix: image build

* fix: cache no encoding for model

* fix: cache no encoding for model
zijiren 9 months ago
parent
commit
bc92ea5653

+ 4 - 0
.github/workflows/release.yml

@@ -47,6 +47,10 @@ jobs:
     steps:
       - name: Checkout
         uses: actions/checkout@v4
+      
+      - name: Download tiktoken
+        run: |
+          bash common/tiktoken/assest.sh
 
       - name: Setup Go
         uses: actions/setup-go@v5

+ 3 - 1
.gitignore

@@ -1,2 +1,4 @@
 aiproxy.db*
-aiproxy
+aiproxy
+common/tiktoken/assets/*
+!*.gitkeep

+ 4 - 0
Dockerfile

@@ -4,6 +4,10 @@ WORKDIR /aiproxy
 
 COPY ./ ./
 
+RUN apk add --no-cache curl
+
+RUN sh common/tiktoken/assest.sh
+
 RUN go build -trimpath -tags "jsoniter" -ldflags "-s -w" -o aiproxy
 
 FROM alpine:latest

+ 52 - 0
common/tiktoken/assest.go

@@ -0,0 +1,52 @@
+package tiktoken
+
+import (
+	"embed"
+	"encoding/base64"
+	"errors"
+	"os"
+	"path"
+	"strconv"
+	"strings"
+
+	"github.com/labring/aiproxy/common/conv"
+	"github.com/pkoukk/tiktoken-go"
+)
+
+//go:embed all:assets
+var assets embed.FS
+
+// embedBpeLoader serves .tiktoken BPE files from the embedded assets
+// filesystem, falling back to the library's default loader when a file
+// was not embedded at build time.
+var (
+	// Compile-time assertion that *embedBpeLoader implements tiktoken.BpeLoader.
+	_                tiktoken.BpeLoader = (*embedBpeLoader)(nil)
+	// defaultBpeLoader is the fallback used when an encoding file is
+	// missing from the embedded assets directory.
+	defaultBpeLoader                    = tiktoken.NewDefaultBpeLoader()
+)
+
+// embedBpeLoader implements tiktoken.BpeLoader backed by the go:embed assets FS.
+type embedBpeLoader struct{}
+
+// LoadTiktokenBpe loads a BPE rank table, preferring the copy embedded
+// under assets/ (downloaded at build time by assest.sh). If the file is
+// not embedded, it defers to tiktoken's default loader, which downloads
+// and caches the file at runtime.
+//
+// Each non-empty line is expected to have the form "<base64-token> <rank>".
+func (e *embedBpeLoader) LoadTiktokenBpe(tiktokenBpeFile string) (map[string]int, error) {
+	embedPath := path.Join("assets", path.Base(tiktokenBpeFile))
+	contents, err := assets.ReadFile(embedPath)
+	if err != nil {
+		if errors.Is(err, os.ErrNotExist) {
+			// Encoding not embedded; fall back to the default (downloading) loader.
+			return defaultBpeLoader.LoadTiktokenBpe(tiktokenBpeFile)
+		}
+		return nil, err
+	}
+	bpeRanks := make(map[string]int)
+	for _, line := range strings.Split(conv.BytesToString(contents), "\n") {
+		if line == "" {
+			continue
+		}
+		// Reject malformed lines explicitly instead of panicking with an
+		// index-out-of-range on the missing rank field.
+		tokenB64, rankStr, ok := strings.Cut(line, " ")
+		if !ok {
+			return nil, errors.New("malformed bpe line: " + line)
+		}
+		token, err := base64.StdEncoding.DecodeString(tokenB64)
+		if err != nil {
+			return nil, err
+		}
+		rank, err := strconv.Atoi(rankStr)
+		if err != nil {
+			return nil, err
+		}
+		bpeRanks[string(token)] = rank
+	}
+	return bpeRanks, nil
+}

+ 20 - 0
common/tiktoken/assest.sh

@@ -0,0 +1,20 @@
+#!/bin/bash
+
+set -ex
+
+ASSETS=$(
+    cat <<EOF
+https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken
+https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken
+https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken
+https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken
+EOF
+)
+
+mkdir -p "$(dirname $0)/assets"
+
+rm -f "$(dirname $0)/assets/*"
+
+for asset in $ASSETS; do
+    curl -L -f -o "$(dirname $0)/assets/$(basename $asset)" "$asset"
+done

+ 0 - 0
common/tiktoken/assets/.gitkeep


+ 54 - 0
common/tiktoken/tiktoken.go

@@ -0,0 +1,54 @@
+package tiktoken
+
+import (
+	"strings"
+	"sync"
+
+	"github.com/pkoukk/tiktoken-go"
+	log "github.com/sirupsen/logrus"
+)
+
+// tokenEncoderMap caches one encoder per model name. It grows lazily as
+// GetTokenEncoder sees new model names (misses for unknown models are
+// cached too, mapped to defaultTokenEncoder), guarded by tokenEncoderLock.
+var (
+	tokenEncoderMap     = map[string]*tiktoken.Tiktoken{}
+	defaultTokenEncoder *tiktoken.Tiktoken
+	tokenEncoderLock    sync.RWMutex
+)
+
+// init installs the embedded BPE loader and eagerly builds the fallback
+// encoder. It deliberately terminates the process (log.Fatal) if the
+// gpt-3.5-turbo encoding cannot be loaded, since every token count would
+// otherwise be broken at runtime.
+func init() {
+	tiktoken.SetBpeLoader(&embedBpeLoader{})
+	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
+	if err != nil {
+		log.Fatal("failed to get gpt-3.5-turbo token encoder: " + err.Error())
+	}
+	defaultTokenEncoder = gpt35TokenEncoder
+}
+
+// GetTokenEncoder returns the tiktoken encoder for model, caching results
+// in tokenEncoderMap. Models the library does not know ("no encoding for
+// model") are cached as the gpt-3.5-turbo fallback so the lookup cost is
+// paid only once; any other error is logged and the fallback is returned
+// WITHOUT caching, so it is retried on the next call. Safe for concurrent
+// use via a double-checked RWMutex.
+func GetTokenEncoder(model string) *tiktoken.Tiktoken {
+	// Fast path: read lock for the common cache-hit case.
+	tokenEncoderLock.RLock()
+	tokenEncoder, ok := tokenEncoderMap[model]
+	tokenEncoderLock.RUnlock()
+	if ok {
+		return tokenEncoder
+	}
+
+	tokenEncoderLock.Lock()
+	defer tokenEncoderLock.Unlock()
+	// Re-check under the write lock: another goroutine may have filled the
+	// entry between RUnlock and Lock.
+	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
+		return tokenEncoder
+	}
+
+	tokenEncoder, err := tiktoken.EncodingForModel(model)
+	if err != nil {
+		// NOTE(review): matching by error string is fragile — breaks if the
+		// library rewords the message; confirm tiktoken-go exposes no sentinel.
+		if strings.Contains(err.Error(), "no encoding for model") {
+			log.Warnf("no encoding for model %s, using encoder for gpt-3.5-turbo", model)
+			tokenEncoderMap[model] = defaultTokenEncoder
+		} else {
+			log.Errorf("failed to get token encoder for model %s: %v", model, err)
+		}
+		return defaultTokenEncoder
+	}
+
+	tokenEncoderMap[model] = tokenEncoder
+	return tokenEncoder
+}

+ 0 - 3
main.go

@@ -25,7 +25,6 @@ import (
 	"github.com/labring/aiproxy/controller"
 	"github.com/labring/aiproxy/middleware"
 	"github.com/labring/aiproxy/model"
-	"github.com/labring/aiproxy/relay/adaptor/openai"
 	"github.com/labring/aiproxy/router"
 	log "github.com/sirupsen/logrus"
 )
@@ -41,8 +40,6 @@ func initializeServices() error {
 
 	initializeNotifier()
 
-	openai.InitDefaultTokenEncoder()
-
 	if err := initializeBalance(); err != nil {
 		return err
 	}

+ 3 - 46
relay/adaptor/openai/token.go

@@ -4,59 +4,16 @@ import (
 	"errors"
 	"math"
 	"strings"
-	"sync"
 	"unicode/utf8"
 
 	"github.com/labring/aiproxy/common/config"
 	"github.com/labring/aiproxy/common/image"
+	intertiktoken "github.com/labring/aiproxy/common/tiktoken"
 	"github.com/labring/aiproxy/relay/model"
 	"github.com/pkoukk/tiktoken-go"
 	log "github.com/sirupsen/logrus"
 )
 
-// tokenEncoderMap won't grow after initialization
-var (
-	tokenEncoderMap         = map[string]*tiktoken.Tiktoken{}
-	defaultTokenEncoder     *tiktoken.Tiktoken
-	defaultTokenEncoderOnce sync.Once
-	tokenEncoderLock        sync.RWMutex
-)
-
-func InitDefaultTokenEncoder() {
-	defaultTokenEncoderOnce.Do(func() {
-		gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
-		if err != nil {
-			log.Fatal("failed to get gpt-3.5-turbo token encoder: " + err.Error())
-		}
-		defaultTokenEncoder = gpt35TokenEncoder
-	})
-}
-
-func getTokenEncoder(model string) *tiktoken.Tiktoken {
-	tokenEncoderLock.RLock()
-	tokenEncoder, ok := tokenEncoderMap[model]
-	tokenEncoderLock.RUnlock()
-	if ok {
-		return tokenEncoder
-	}
-
-	InitDefaultTokenEncoder()
-
-	tokenEncoderLock.Lock()
-	defer tokenEncoderLock.Unlock()
-	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
-		return tokenEncoder
-	}
-
-	tokenEncoder, err := tiktoken.EncodingForModel(model)
-	if err != nil {
-		log.Warnf("failed to get token encoder for model %s: %v, using encoder for gpt-3.5-turbo", model, err)
-		tokenEncoder = defaultTokenEncoder
-	}
-	tokenEncoderMap[model] = tokenEncoder
-	return tokenEncoder
-}
-
 func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 	return len(tokenEncoder.Encode(text, nil, nil))
 }
@@ -65,7 +22,7 @@ func CountTokenMessages(messages []*model.Message, model string) int {
 	if !config.GetBillingEnabled() {
 		return 0
 	}
-	tokenEncoder := getTokenEncoder(model)
+	tokenEncoder := intertiktoken.GetTokenEncoder(model)
 	// Reference:
 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 	// https://github.com/pkoukk/tiktoken-go/issues/6
@@ -241,5 +198,5 @@ func CountTokenText(text string, model string) int {
 	if strings.HasPrefix(model, "tts") {
 		return utf8.RuneCountInString(text)
 	}
-	return getTokenNum(getTokenEncoder(model), text)
+	return getTokenNum(intertiktoken.GetTokenEncoder(model), text)
 }