Explorar o código

fix: add fuzzyTokenThreshold option (#401)

zijiren hai 2 meses
pai
achega
7f2185a9b4

+ 0 - 5
.github/workflows/release.yml

@@ -79,11 +79,6 @@ jobs:
           name: web
           path: core/public/dist
 
-      - name: Download tiktoken
-        working-directory: core
-        run: |
-          bash scripts/tiktoken.sh
-
       - name: Setup Go
         uses: actions/setup-go@v6
         with:

+ 0 - 4
Dockerfile

@@ -10,16 +10,12 @@ RUN pnpm install && pnpm run build
 
 FROM golang:1.25.2-alpine AS builder
 
-RUN apk add --no-cache curl
-
 WORKDIR /aiproxy/core
 
 COPY ./ /aiproxy
 
 COPY --from=frontend-builder /aiproxy/web/dist/ /aiproxy/core/public/dist/
 
-RUN sh scripts/tiktoken.sh
-
 RUN go install github.com/swaggo/swag/cmd/swag@latest
 
 RUN sh scripts/swag.sh

+ 0 - 1
core/.gitignore

@@ -1,7 +1,6 @@
 aiproxy.db*
 core
 core.exe
-common/tiktoken/assets/*
 /public/dist/*
 !*.gitkeep
 .env.local

+ 15 - 0
core/common/config/config.go

@@ -35,6 +35,12 @@ var (
 	defaultMCPHost atomic.Value
 	publicMCPHost  atomic.Value
 	groupMCPHost   atomic.Value
+
+	// fuzzyTokenThreshold is the text length threshold (in bytes) for fuzzy token calculation.
+	// If text length is below this threshold, precise token counting is used.
+	// If text length is at or above this threshold, approximate counting (length/4) is used.
+	// Set to 0 to always use precise counting (default behavior).
+	fuzzyTokenThreshold atomic.Int64
 )
 
 func init() {
@@ -279,3 +285,12 @@ func SetUsageAlertMinAvgThreshold(threshold int64) {
 	threshold = env.Int64("USAGE_ALERT_MIN_AVG_THRESHOLD", threshold)
 	usageAlertMinAvgThreshold.Store(threshold)
 }
+
+func GetFuzzyTokenThreshold() int64 {
+	return fuzzyTokenThreshold.Load()
+}
+
+func SetFuzzyTokenThreshold(threshold int64) {
+	threshold = env.Int64("FUZZY_TOKEN_THRESHOLD", threshold)
+	fuzzyTokenThreshold.Store(threshold)
+}

+ 0 - 59
core/common/tiktoken/assest.go

@@ -1,59 +0,0 @@
-package tiktoken
-
-import (
-	"embed"
-	"encoding/base64"
-	"errors"
-	"os"
-	"path"
-	"strconv"
-	"strings"
-
-	"github.com/labring/aiproxy/core/common/conv"
-	"github.com/pkoukk/tiktoken-go"
-)
-
-//go:embed all:assets
-var assets embed.FS
-
-var (
-	_                tiktoken.BpeLoader = (*embedBpeLoader)(nil)
-	defaultBpeLoader                    = tiktoken.NewDefaultBpeLoader()
-)
-
-type embedBpeLoader struct{}
-
-func (e *embedBpeLoader) LoadTiktokenBpe(tiktokenBpeFile string) (map[string]int, error) {
-	embedPath := path.Join("assets", path.Base(tiktokenBpeFile))
-
-	contents, err := assets.ReadFile(embedPath)
-	if err != nil {
-		if errors.Is(err, os.ErrNotExist) {
-			return defaultBpeLoader.LoadTiktokenBpe(tiktokenBpeFile)
-		}
-		return nil, err
-	}
-
-	bpeRanks := make(map[string]int)
-	for _, line := range strings.Split(conv.BytesToString(contents), "\n") {
-		if line == "" {
-			continue
-		}
-
-		parts := strings.Split(line, " ")
-
-		token, err := base64.StdEncoding.DecodeString(parts[0])
-		if err != nil {
-			return nil, err
-		}
-
-		rank, err := strconv.Atoi(parts[1])
-		if err != nil {
-			return nil, err
-		}
-
-		bpeRanks[string(token)] = rank
-	}
-
-	return bpeRanks, nil
-}

+ 0 - 0
core/common/tiktoken/assets/.gitkeep


+ 21 - 16
core/common/tiktoken/tiktoken.go

@@ -1,32 +1,30 @@
 package tiktoken
 
 import (
-	"strings"
+	"errors"
 	"sync"
 
-	"github.com/pkoukk/tiktoken-go"
 	log "github.com/sirupsen/logrus"
+	"github.com/tiktoken-go/tokenizer"
 )
 
 // tokenEncoderMap won't grow after initialization
 var (
-	tokenEncoderMap     = map[string]*tiktoken.Tiktoken{}
-	defaultTokenEncoder *tiktoken.Tiktoken
+	tokenEncoderMap     = map[string]tokenizer.Codec{}
+	defaultTokenEncoder tokenizer.Codec
 	tokenEncoderLock    sync.RWMutex
 )
 
 func init() {
-	tiktoken.SetBpeLoader(&embedBpeLoader{})
-
-	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
+	gpt4oTokenEncoder, err := tokenizer.ForModel(tokenizer.GPT4o)
 	if err != nil {
-		log.Fatal("failed to get gpt-3.5-turbo token encoder: " + err.Error())
+		log.Fatal("failed to get gpt-4o token encoder: " + err.Error())
 	}
 
-	defaultTokenEncoder = gpt35TokenEncoder
+	defaultTokenEncoder = gpt4oTokenEncoder
 }
 
-func GetTokenEncoder(model string) *tiktoken.Tiktoken {
+func GetTokenEncoder(model string) tokenizer.Codec {
 	tokenEncoderLock.RLock()
 
 	tokenEncoder, ok := tokenEncoderMap[model]
@@ -46,19 +44,26 @@ func GetTokenEncoder(model string) *tiktoken.Tiktoken {
 
 	log.Info("loading encoding for model " + model)
 
-	tokenEncoder, err := tiktoken.EncodingForModel(model)
+	// ForModel has built-in prefix matching for model names
+	tokenEncoder, err := tokenizer.ForModel(tokenizer.Model(model))
 	if err != nil {
-		if strings.Contains(err.Error(), "no encoding for model") {
-			log.Warnf("no encoding for model %s, using default encoder", model)
+		if errors.Is(err, tokenizer.ErrModelNotSupported) {
+			log.Warnf("model %s not supported, using default encoder (gpt-4o)", model)
 			tokenEncoderMap[model] = defaultTokenEncoder
-		} else {
-			log.Errorf("failed to get token encoder for model %s: %v", model, err)
+			return defaultTokenEncoder
 		}
 
+		log.Errorf(
+			"failed to get token encoder for model %s: %v, using default encoder",
+			model,
+			err,
+		)
+		tokenEncoderMap[model] = defaultTokenEncoder
+
 		return defaultTokenEncoder
 	}
 
-	log.Infof("load encoding for model %s success", model)
+	log.Infof("loaded encoding for model %s: %s", model, tokenEncoder.GetName())
 
 	tokenEncoderMap[model] = tokenEncoder
 

+ 1 - 1
core/go.mod

@@ -25,7 +25,6 @@ require (
 	github.com/mattn/go-isatty v0.0.20
 	github.com/patrickmn/go-cache v2.1.0+incompatible
 	github.com/pkg/errors v0.9.1
-	github.com/pkoukk/tiktoken-go v0.1.7
 	github.com/redis/go-redis/v9 v9.12.1
 	github.com/shopspring/decimal v1.4.0
 	github.com/sirupsen/logrus v1.9.3
@@ -36,6 +35,7 @@ require (
 	github.com/swaggo/files v1.0.1
 	github.com/swaggo/gin-swagger v1.6.0
 	github.com/swaggo/swag v1.16.6
+	github.com/tiktoken-go/tokenizer v0.7.0
 	golang.org/x/image v0.30.0
 	golang.org/x/sync v0.16.0
 	google.golang.org/api v0.248.0

+ 2 - 2
core/go.sum

@@ -198,8 +198,6 @@ github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0V
 github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
-github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg=
@@ -248,6 +246,8 @@ github.com/swaggo/swag v1.16.6 h1:qBNcx53ZaX+M5dxVyTrgQ0PJ/ACK+NzhwcbieTt+9yI=
 github.com/swaggo/swag v1.16.6/go.mod h1:ngP2etMK5a0P3QBizic5MEwpRmluJZPHjXcMoj4Xesg=
 github.com/temoto/robotstxt v1.1.2 h1:W2pOjSJ6SWvldyEuiFXNxz3xZ8aiWX5LbfDiOFd7Fxg=
 github.com/temoto/robotstxt v1.1.2/go.mod h1:+1AmkuG3IYkh1kv0d2qEB9Le88ehNO0zwOr3ujewlOo=
+github.com/tiktoken-go/tokenizer v0.7.0 h1:VMu6MPT0bXFDHr7UPh9uii7CNItVt3X9K90omxL54vw=
+github.com/tiktoken-go/tokenizer v0.7.0/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
 github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
 github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
 github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA=

+ 12 - 0
core/model/option.go

@@ -123,6 +123,7 @@ func initOptionMap() error {
 		config.GetUsageAlertMinAvgThreshold(),
 		10,
 	)
+	optionMap["FuzzyTokenThreshold"] = strconv.FormatInt(config.GetFuzzyTokenThreshold(), 10)
 
 	optionKeys = make([]string, 0, len(optionMap))
 	for key := range optionMap {
@@ -445,6 +446,17 @@ func updateOption(key, value string, isInit bool) (err error) {
 		}
 
 		config.SetUsageAlertMinAvgThreshold(threshold)
+	case "FuzzyTokenThreshold":
+		threshold, err := strconv.ParseInt(value, 10, 64)
+		if err != nil {
+			return err
+		}
+
+		if threshold < 0 {
+			return errors.New("fuzzy token threshold must be greater than or equal to 0")
+		}
+
+		config.SetFuzzyTokenThreshold(threshold)
 	default:
 		return ErrUnknownOptionKey
 	}

+ 21 - 4
core/relay/adaptor/openai/token.go

@@ -5,22 +5,39 @@ import (
 	"math"
 	"strings"
 
+	"github.com/labring/aiproxy/core/common/config"
 	"github.com/labring/aiproxy/core/common/image"
 	intertiktoken "github.com/labring/aiproxy/core/common/tiktoken"
 	"github.com/labring/aiproxy/core/relay/model"
-	"github.com/pkoukk/tiktoken-go"
 	log "github.com/sirupsen/logrus"
+	"github.com/tiktoken-go/tokenizer"
 )
 
-func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int64 {
-	return int64(len(tokenEncoder.Encode(text, nil, nil)))
+func getTokenNum(tokenEncoder tokenizer.Codec, text string) int64 {
+	// Read the configured fuzzy token threshold; textLen below is len(text), i.e. the byte length
+	threshold := config.GetFuzzyTokenThreshold()
+	textLen := len(text)
+
+	// If a threshold is set (> 0) and the byte length is at or above it, return the length/4 estimate
+	if threshold > 0 && int64(textLen) >= threshold {
+		return int64(textLen / 4)
+	}
+
+	// Otherwise, count tokens precisely with the given encoder
+	count, err := tokenEncoder.Count(text)
+	if err != nil {
+		log.Warnf("failed to count tokens: %v, fallback to length/4", err)
+		// Fall back to the same length/4 estimate if counting fails
+		return int64(textLen / 4)
+	}
+
+	return int64(count)
+}
 
 func CountTokenMessages(messages []model.Message, model string) int64 {
 	tokenEncoder := intertiktoken.GetTokenEncoder(model)
 	// Reference:
 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-	// https://github.com/pkoukk/tiktoken-go/issues/6
 	//
 	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
 	var (

+ 0 - 18
core/scripts/tiktoken.sh

@@ -1,18 +0,0 @@
-#!/bin/bash
-
-set -ex
-
-ASSETS=$(
-    cat <<EOF
-https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken
-https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken
-https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken
-https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken
-EOF
-)
-
-mkdir -p "$(dirname $0)/../common/tiktoken/assets"
-
-for asset in $ASSETS; do
-    curl -L -f -o "$(dirname $0)/../common/tiktoken/assets/$(basename $asset)" "$asset"
-done

+ 1 - 0
go.work.sum

@@ -78,6 +78,7 @@ github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443/go.mod h1:W+zGtBO5Y1Ig
 github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d h1:U+s90UTSYgptZMwQh2aRr3LuazLJIa+Pg3Kc1ylSYVY=
 github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
 github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
+github.com/dlclark/regexp2cg v0.2.0/go.mod h1:K2c4ctxtSQjzgeMKKgi1rEflZVVJWZWlUUdmtjOp/y8=
 github.com/envoyproxy/go-control-plane v0.13.4 h1:zEqyPVyku6IvWCFwux4x9RxkLOMUL+1vC9xUFv5l2/M=
 github.com/envoyproxy/go-control-plane v0.13.4/go.mod h1:kDfuBlDVsSj2MjrLEtRWtHlsWIFcGyB2RMO44Dc5GZA=
 github.com/envoyproxy/go-control-plane/envoy v1.32.4 h1:jb83lalDRZSpPWW2Z7Mck/8kXZ5CQAFYVjQcdVIr83A=