@@ -4,59 +4,16 @@ import (
 	"errors"
 	"math"
 	"strings"
-	"sync"
 	"unicode/utf8"
 
 	"github.com/labring/aiproxy/common/config"
 	"github.com/labring/aiproxy/common/image"
+	intertiktoken "github.com/labring/aiproxy/common/tiktoken"
 	"github.com/labring/aiproxy/relay/model"
 	"github.com/pkoukk/tiktoken-go"
 	log "github.com/sirupsen/logrus"
 )
 
-// tokenEncoderMap won't grow after initialization
-var (
-	tokenEncoderMap         = map[string]*tiktoken.Tiktoken{}
-	defaultTokenEncoder     *tiktoken.Tiktoken
-	defaultTokenEncoderOnce sync.Once
-	tokenEncoderLock        sync.RWMutex
-)
-
-func InitDefaultTokenEncoder() {
-	defaultTokenEncoderOnce.Do(func() {
-		gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
-		if err != nil {
-			log.Fatal("failed to get gpt-3.5-turbo token encoder: " + err.Error())
-		}
-		defaultTokenEncoder = gpt35TokenEncoder
-	})
-}
-
-func getTokenEncoder(model string) *tiktoken.Tiktoken {
-	tokenEncoderLock.RLock()
-	tokenEncoder, ok := tokenEncoderMap[model]
-	tokenEncoderLock.RUnlock()
-	if ok {
-		return tokenEncoder
-	}
-
-	InitDefaultTokenEncoder()
-
-	tokenEncoderLock.Lock()
-	defer tokenEncoderLock.Unlock()
-	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
-		return tokenEncoder
-	}
-
-	tokenEncoder, err := tiktoken.EncodingForModel(model)
-	if err != nil {
-		log.Warnf("failed to get token encoder for model %s: %v, using encoder for gpt-3.5-turbo", model, err)
-		tokenEncoder = defaultTokenEncoder
-	}
-	tokenEncoderMap[model] = tokenEncoder
-	return tokenEncoder
-}
-
 func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 	return len(tokenEncoder.Encode(text, nil, nil))
 }
@@ -65,7 +22,7 @@ func CountTokenMessages(messages []*model.Message, model string) int {
 	if !config.GetBillingEnabled() {
 		return 0
 	}
-	tokenEncoder := getTokenEncoder(model)
+	tokenEncoder := intertiktoken.GetTokenEncoder(model)
 	// Reference:
 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 	// https://github.com/pkoukk/tiktoken-go/issues/6
@@ -241,5 +198,5 @@ func CountTokenText(text string, model string) int {
 	if strings.HasPrefix(model, "tts") {
 		return utf8.RuneCountInString(text)
 	}
-	return getTokenNum(getTokenEncoder(model), text)
+	return getTokenNum(intertiktoken.GetTokenEncoder(model), text)
 }
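
Note: the shared "github.com/labring/aiproxy/common/tiktoken" package (imported as intertiktoken) is not part of this diff. A minimal sketch of what its GetTokenEncoder helper could look like, assuming it keeps the behavior of the removed getTokenEncoder (per-model encoder caching with a gpt-3.5-turbo fallback); the real implementation may differ:

// Hypothetical sketch only, not the actual common/tiktoken implementation.
package tiktoken

import (
	"sync"

	"github.com/pkoukk/tiktoken-go"
	log "github.com/sirupsen/logrus"
)

var (
	encoders    sync.Map // model name -> *tiktoken.Tiktoken
	defaultOnce sync.Once
	defaultEnc  *tiktoken.Tiktoken
)

// defaultEncoder lazily initializes the gpt-3.5-turbo fallback encoder.
func defaultEncoder() *tiktoken.Tiktoken {
	defaultOnce.Do(func() {
		enc, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
		if err != nil {
			log.Fatal("failed to get gpt-3.5-turbo token encoder: " + err.Error())
		}
		defaultEnc = enc
	})
	return defaultEnc
}

// GetTokenEncoder returns a cached encoder for the model, falling back to the
// gpt-3.5-turbo encoder when tiktoken has no encoding for that model.
func GetTokenEncoder(model string) *tiktoken.Tiktoken {
	if enc, ok := encoders.Load(model); ok {
		return enc.(*tiktoken.Tiktoken)
	}
	enc, err := tiktoken.EncodingForModel(model)
	if err != nil {
		log.Warnf("failed to get token encoder for model %s: %v, using encoder for gpt-3.5-turbo", model, err)
		enc = defaultEncoder()
	}
	encoders.Store(model, enc)
	return enc
}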