|
|
@@ -1,7 +1,6 @@
|
|
|
package service
|
|
|
|
|
|
import (
|
|
|
- "encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"image"
|
|
|
@@ -12,7 +11,6 @@ import (
|
|
|
"math"
|
|
|
"path/filepath"
|
|
|
"strings"
|
|
|
- "sync"
|
|
|
"unicode/utf8"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
|
@@ -23,64 +21,8 @@ import (
|
|
|
"github.com/QuantumNous/new-api/types"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
- "github.com/tiktoken-go/tokenizer"
|
|
|
- "github.com/tiktoken-go/tokenizer/codec"
|
|
|
)
|
|
|
|
|
|
-// tokenEncoderMap won't grow after initialization
|
|
|
-var defaultTokenEncoder tokenizer.Codec
|
|
|
-
|
|
|
-// tokenEncoderMap is used to store token encoders for different models
|
|
|
-var tokenEncoderMap = make(map[string]tokenizer.Codec)
|
|
|
-
|
|
|
-// tokenEncoderMutex protects tokenEncoderMap for concurrent access
|
|
|
-var tokenEncoderMutex sync.RWMutex
|
|
|
-
|
|
|
-func InitTokenEncoders() {
|
|
|
- common.SysLog("initializing token encoders")
|
|
|
- defaultTokenEncoder = codec.NewCl100kBase()
|
|
|
- common.SysLog("token encoders initialized")
|
|
|
-}
|
|
|
-
|
|
|
-func getTokenEncoder(model string) tokenizer.Codec {
|
|
|
- // First, try to get the encoder from cache with read lock
|
|
|
- tokenEncoderMutex.RLock()
|
|
|
- if encoder, exists := tokenEncoderMap[model]; exists {
|
|
|
- tokenEncoderMutex.RUnlock()
|
|
|
- return encoder
|
|
|
- }
|
|
|
- tokenEncoderMutex.RUnlock()
|
|
|
-
|
|
|
- // If not in cache, create new encoder with write lock
|
|
|
- tokenEncoderMutex.Lock()
|
|
|
- defer tokenEncoderMutex.Unlock()
|
|
|
-
|
|
|
- // Double-check if another goroutine already created the encoder
|
|
|
- if encoder, exists := tokenEncoderMap[model]; exists {
|
|
|
- return encoder
|
|
|
- }
|
|
|
-
|
|
|
- // Create new encoder
|
|
|
- modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
|
|
|
- if err != nil {
|
|
|
- // Cache the default encoder for this model to avoid repeated failures
|
|
|
- tokenEncoderMap[model] = defaultTokenEncoder
|
|
|
- return defaultTokenEncoder
|
|
|
- }
|
|
|
-
|
|
|
- // Cache the new encoder
|
|
|
- tokenEncoderMap[model] = modelCodec
|
|
|
- return modelCodec
|
|
|
-}
|
|
|
-
|
|
|
-func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
|
|
|
- if text == "" {
|
|
|
- return 0
|
|
|
- }
|
|
|
- tkm, _ := tokenEncoder.Count(text)
|
|
|
- return tkm
|
|
|
-}
|
|
|
-
|
|
|
func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
|
|
|
if fileMeta == nil {
|
|
|
return 0, fmt.Errorf("image_url_is_nil")
|
|
|
@@ -257,7 +199,7 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
|
|
return tiles*tileTokens + baseTokens, nil
|
|
|
}
|
|
|
|
|
|
-func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
|
|
+func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
|
|
// 是否统计token
|
|
|
if !constant.CountToken {
|
|
|
return 0, nil
|
|
|
@@ -375,14 +317,14 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
|
|
for i, file := range meta.Files {
|
|
|
switch file.FileType {
|
|
|
case types.FileTypeImage:
|
|
|
- if info.RelayFormat == types.RelayFormatGemini {
|
|
|
- tkm += 520 // gemini per input image tokens
|
|
|
- } else {
|
|
|
+ if common.IsOpenAITextModel(info.UpstreamModelName) {
|
|
|
token, err := getImageToken(file, model, info.IsStream)
|
|
|
if err != nil {
|
|
|
return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err)
|
|
|
}
|
|
|
tkm += token
|
|
|
+ } else {
|
|
|
+ tkm += 520
|
|
|
}
|
|
|
case types.FileTypeAudio:
|
|
|
tkm += 256
|
|
|
@@ -399,111 +341,6 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
|
|
return tkm, nil
|
|
|
}
|
|
|
|
|
|
-func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
|
|
|
- tkm := 0
|
|
|
-
|
|
|
- // Count tokens in messages
|
|
|
- msgTokens, err := CountTokenClaudeMessages(request.Messages, model, request.Stream)
|
|
|
- if err != nil {
|
|
|
- return 0, err
|
|
|
- }
|
|
|
- tkm += msgTokens
|
|
|
-
|
|
|
- // Count tokens in system message
|
|
|
- if request.System != "" {
|
|
|
- systemTokens := CountTokenInput(request.System, model)
|
|
|
- tkm += systemTokens
|
|
|
- }
|
|
|
-
|
|
|
- if request.Tools != nil {
|
|
|
- // check is array
|
|
|
- if tools, ok := request.Tools.([]any); ok {
|
|
|
- if len(tools) > 0 {
|
|
|
- parsedTools, err1 := common.Any2Type[[]dto.Tool](request.Tools)
|
|
|
- if err1 != nil {
|
|
|
- return 0, fmt.Errorf("tools: Input should be a valid list: %v", err)
|
|
|
- }
|
|
|
- toolTokens, err2 := CountTokenClaudeTools(parsedTools, model)
|
|
|
- if err2 != nil {
|
|
|
- return 0, fmt.Errorf("tools: %v", err)
|
|
|
- }
|
|
|
- tkm += toolTokens
|
|
|
- }
|
|
|
- } else {
|
|
|
- return 0, errors.New("tools: Input should be a valid list")
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- return tkm, nil
|
|
|
-}
|
|
|
-
|
|
|
-func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream bool) (int, error) {
|
|
|
- tokenEncoder := getTokenEncoder(model)
|
|
|
- tokenNum := 0
|
|
|
-
|
|
|
- for _, message := range messages {
|
|
|
- // Count tokens for role
|
|
|
- tokenNum += getTokenNum(tokenEncoder, message.Role)
|
|
|
- if message.IsStringContent() {
|
|
|
- tokenNum += getTokenNum(tokenEncoder, message.GetStringContent())
|
|
|
- } else {
|
|
|
- content, err := message.ParseContent()
|
|
|
- if err != nil {
|
|
|
- return 0, err
|
|
|
- }
|
|
|
- for _, mediaMessage := range content {
|
|
|
- switch mediaMessage.Type {
|
|
|
- case "text":
|
|
|
- tokenNum += getTokenNum(tokenEncoder, mediaMessage.GetText())
|
|
|
- case "image":
|
|
|
- //imageTokenNum, err := getClaudeImageToken(mediaMsg.Source, model, stream)
|
|
|
- //if err != nil {
|
|
|
- // return 0, err
|
|
|
- //}
|
|
|
- tokenNum += 1000
|
|
|
- case "tool_use":
|
|
|
- if mediaMessage.Input != nil {
|
|
|
- tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
|
|
|
- inputJSON, _ := json.Marshal(mediaMessage.Input)
|
|
|
- tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
|
|
|
- }
|
|
|
- case "tool_result":
|
|
|
- if mediaMessage.Content != nil {
|
|
|
- contentJSON, _ := json.Marshal(mediaMessage.Content)
|
|
|
- tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // Add a constant for message formatting (this may need adjustment based on Claude's exact formatting)
|
|
|
- tokenNum += len(messages) * 2 // Assuming 2 tokens per message for formatting
|
|
|
-
|
|
|
- return tokenNum, nil
|
|
|
-}
|
|
|
-
|
|
|
-func CountTokenClaudeTools(tools []dto.Tool, model string) (int, error) {
|
|
|
- tokenEncoder := getTokenEncoder(model)
|
|
|
- tokenNum := 0
|
|
|
-
|
|
|
- for _, tool := range tools {
|
|
|
- tokenNum += getTokenNum(tokenEncoder, tool.Name)
|
|
|
- tokenNum += getTokenNum(tokenEncoder, tool.Description)
|
|
|
-
|
|
|
- schemaJSON, err := json.Marshal(tool.InputSchema)
|
|
|
- if err != nil {
|
|
|
- return 0, errors.New(fmt.Sprintf("marshal_tool_schema_fail: %s", err.Error()))
|
|
|
- }
|
|
|
- tokenNum += getTokenNum(tokenEncoder, string(schemaJSON))
|
|
|
- }
|
|
|
-
|
|
|
- // Add a constant for tool formatting (this may need adjustment based on Claude's exact formatting)
|
|
|
- tokenNum += len(tools) * 3 // Assuming 3 tokens per tool for formatting
|
|
|
-
|
|
|
- return tokenNum, nil
|
|
|
-}
|
|
|
-
|
|
|
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
|
|
|
audioToken := 0
|
|
|
textToken := 0
|
|
|
@@ -578,31 +415,6 @@ func CountTokenInput(input any, model string) int {
|
|
|
return CountTokenInput(fmt.Sprintf("%v", input), model)
|
|
|
}
|
|
|
|
|
|
-func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
|
|
- tokens := 0
|
|
|
- for _, message := range messages {
|
|
|
- tkm := CountTokenInput(message.Delta.GetContentString(), model)
|
|
|
- tokens += tkm
|
|
|
- if message.Delta.ToolCalls != nil {
|
|
|
- for _, tool := range message.Delta.ToolCalls {
|
|
|
- tkm := CountTokenInput(tool.Function.Name, model)
|
|
|
- tokens += tkm
|
|
|
- tkm = CountTokenInput(tool.Function.Arguments, model)
|
|
|
- tokens += tkm
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- return tokens
|
|
|
-}
|
|
|
-
|
|
|
-func CountTTSToken(text string, model string) int {
|
|
|
- if strings.HasPrefix(model, "tts") {
|
|
|
- return utf8.RuneCountInString(text)
|
|
|
- } else {
|
|
|
- return CountTextToken(text, model)
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
|
|
|
if audioBase64 == "" {
|
|
|
return 0, nil
|
|
|
@@ -625,17 +437,16 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
|
|
|
return int(duration / 60 * 200 / 0.24), nil
|
|
|
}
|
|
|
|
|
|
-//func CountAudioToken(sec float64, audioType string) {
|
|
|
-// if audioType == "input" {
|
|
|
-//
|
|
|
-// }
|
|
|
-//}
|
|
|
-
|
|
|
-// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
|
|
+// CountTextToken returns the token count of text (0 for empty text); only OpenAI text models use the tokenizer, every other model uses an estimate.
|
|
|
func CountTextToken(text string, model string) int {
|
|
|
if text == "" {
|
|
|
return 0
|
|
|
}
|
|
|
- tokenEncoder := getTokenEncoder(model)
|
|
|
- return getTokenNum(tokenEncoder, text)
|
|
|
+ if common.IsOpenAITextModel(model) {
|
|
|
+ tokenEncoder := getTokenEncoder(model)
|
|
|
+ return getTokenNum(tokenEncoder, text)
|
|
|
+ } else {
|
|
|
+ // For non-OpenAI models, counting with tiktoken-go is not meaningful, so use an estimate to save resources.
|
|
|
+ return EstimateTokenByModel(model, text)
|
|
|
+ }
|
|
|
}
|