Просмотр исходного кода

Merge pull request #2355 from QuantumNous/feat/optimize-token-counter

feat: refactor token estimation logic
Calcium-Ion 4 недель назад
Родитель
Сommit
48635360cd

+ 17 - 0
common/model.go

@@ -17,6 +17,13 @@ var (
 		"flux-",
 		"flux.1-",
 	}
+	OpenAITextModels = []string{
+		"gpt-",
+		"o1",
+		"o3",
+		"o4",
+		"chatgpt",
+	}
 )
 
 func IsOpenAIResponseOnlyModel(modelName string) bool {
@@ -40,3 +47,13 @@ func IsImageGenerationModel(modelName string) bool {
 	}
 	return false
 }
+
+func IsOpenAITextModel(modelName string) bool {
+	modelName = strings.ToLower(modelName)
+	for _, m := range OpenAITextModels {
+		if strings.Contains(modelName, m) {
+			return true
+		}
+	}
+	return false
+}

+ 3 - 2
constant/context_key.go

@@ -3,8 +3,9 @@ package constant
 type ContextKey string
 
 const (
-	ContextKeyTokenCountMeta ContextKey = "token_count_meta"
-	ContextKeyPromptTokens   ContextKey = "prompt_tokens"
+	ContextKeyTokenCountMeta  ContextKey = "token_count_meta"
+	ContextKeyPromptTokens    ContextKey = "prompt_tokens"
+	ContextKeyEstimatedTokens ContextKey = "estimated_tokens"
 
 	ContextKeyOriginalModel    ContextKey = "original_model"
 	ContextKeyRequestStartTime ContextKey = "request_start_time"

+ 1 - 1
controller/channel-test.go

@@ -351,7 +351,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
 			newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
 		}
 	}
-	info.PromptTokens = usage.PromptTokens
+	info.SetEstimatePromptTokens(usage.PromptTokens)
 
 	quota := 0
 	if !priceData.UsePrice {

+ 2 - 2
controller/relay.go

@@ -125,13 +125,13 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 		}
 	}
 
-	tokens, err := service.CountRequestToken(c, meta, relayInfo)
+	tokens, err := service.EstimateRequestToken(c, meta, relayInfo)
 	if err != nil {
 		newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed)
 		return
 	}
 
-	relayInfo.SetPromptTokens(tokens)
+	relayInfo.SetEstimatePromptTokens(tokens)
 
 	priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
 	if err != nil {

+ 2 - 5
relay/channel/claude/relay-claude.go

@@ -673,7 +673,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
 
 	if requestMode == RequestModeCompletion {
-		claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
+		claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
 	} else {
 		if claudeInfo.Usage.PromptTokens == 0 {
 			//上游出错
@@ -734,10 +734,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 		return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
 	}
 	if requestMode == RequestModeCompletion {
-		completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
-		claudeInfo.Usage.PromptTokens = info.PromptTokens
-		claudeInfo.Usage.CompletionTokens = completionTokens
-		claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
+		claudeInfo.Usage = service.ResponseText2Usage(c, claudeResponse.Completion, info.UpstreamModelName, info.GetEstimatePromptTokens())
 	} else {
 		claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
 		claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens

+ 3 - 7
relay/channel/cloudflare/relay_cloudflare.go

@@ -74,7 +74,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
 	if err := scanner.Err(); err != nil {
 		logger.LogError(c, "error_scanning_stream_response: "+err.Error())
 	}
-	usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
+	usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
 	if info.ShouldIncludeUsage {
 		response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
 		err := helper.ObjectData(c, response)
@@ -105,7 +105,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response)
 	for _, choice := range response.Choices {
 		responseText += choice.Message.StringContent()
 	}
-	usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
+	usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
 	response.Usage = *usage
 	response.Id = helper.GetResponseID(c)
 	jsonResponse, err := json.Marshal(response)
@@ -142,10 +142,6 @@ func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
 	c.Writer.WriteHeader(resp.StatusCode)
 	_, _ = c.Writer.Write(jsonResponse)
 
-	usage := &dto.Usage{}
-	usage.PromptTokens = info.PromptTokens
-	usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
-	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
-
+	usage := service.ResponseText2Usage(c, cfResp.Result.Text, info.UpstreamModelName, info.GetEstimatePromptTokens())
 	return nil, usage
 }

+ 3 - 3
relay/channel/cohere/relay-cohere.go

@@ -165,7 +165,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 		}
 	})
 	if usage.PromptTokens == 0 {
-		usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
 	}
 	return usage, nil
 }
@@ -225,9 +225,9 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	}
 	usage := dto.Usage{}
 	if cohereResp.Meta.BilledUnits.InputTokens == 0 {
-		usage.PromptTokens = info.PromptTokens
+		usage.PromptTokens = info.GetEstimatePromptTokens()
 		usage.CompletionTokens = 0
-		usage.TotalTokens = info.PromptTokens
+		usage.TotalTokens = info.GetEstimatePromptTokens()
 	} else {
 		usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
 		usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens

+ 1 - 1
relay/channel/dify/relay-dify.go

@@ -246,7 +246,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 	})
 	helper.Done(c)
 	if usage.TotalTokens == 0 {
-		usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
 	}
 	usage.CompletionTokens += nodeToken
 	return usage, nil

+ 1 - 7
relay/channel/gemini/relay-gemini-native.go

@@ -5,7 +5,6 @@ import (
 	"net/http"
 
 	"github.com/QuantumNous/new-api/common"
-	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/logger"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
@@ -70,12 +69,7 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
 		println(string(responseBody))
 	}
 
-	usage := &dto.Usage{
-		PromptTokens: info.PromptTokens,
-		TotalTokens:  info.PromptTokens,
-	}
-
-	common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
+	usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens())
 
 	if info.IsGeminiBatchEmbedding {
 		var geminiResponse dto.GeminiBatchEmbeddingResponse

+ 2 - 6
relay/channel/gemini/relay-gemini.go

@@ -1115,7 +1115,7 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 	if usage.CompletionTokens <= 0 {
 		str := responseText.String()
 		if len(str) > 0 {
-			usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.PromptTokens)
+			usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
 		} else {
 			usage = &dto.Usage{}
 		}
@@ -1288,11 +1288,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 	// Google has not yet clarified how embedding models will be billed
 	// refer to openai billing method to use input tokens billing
 	// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
-	usage := &dto.Usage{
-		PromptTokens:     info.PromptTokens,
-		CompletionTokens: 0,
-		TotalTokens:      info.PromptTokens,
-	}
+	usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens())
 	openAIResponse.Usage = *usage
 
 	jsonResponse, jsonErr := common.Marshal(openAIResponse)

+ 1 - 1
relay/channel/minimax/tts.go

@@ -163,7 +163,7 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	}
 
 	usage = &dto.Usage{
-		PromptTokens:     info.PromptTokens,
+		PromptTokens:     info.GetEstimatePromptTokens(),
 		CompletionTokens: 0,
 		TotalTokens:      int(minimaxResp.ExtraInfo.UsageCharacters),
 	}

+ 6 - 6
relay/channel/openai/relay-openai.go

@@ -183,7 +183,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 	}
 
 	if !containStreamUsage {
-		usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
 		usage.CompletionTokens += toolCount * 7
 	}
 
@@ -245,9 +245,9 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 			}
 		}
 		simpleResponse.Usage = dto.Usage{
-			PromptTokens:     info.PromptTokens,
+			PromptTokens:     info.GetEstimatePromptTokens(),
 			CompletionTokens: completionTokens,
-			TotalTokens:      info.PromptTokens + completionTokens,
+			TotalTokens:      info.GetEstimatePromptTokens() + completionTokens,
 		}
 		usageModified = true
 	}
@@ -336,8 +336,8 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	// and can be terminated directly.
 	defer service.CloseResponseBodyGracefully(resp)
 	usage := &dto.Usage{}
-	usage.PromptTokens = info.PromptTokens
-	usage.TotalTokens = info.PromptTokens
+	usage.PromptTokens = info.GetEstimatePromptTokens()
+	usage.TotalTokens = info.GetEstimatePromptTokens()
 	for k, v := range resp.Header {
 		c.Writer.Header().Set(k, v[0])
 	}
@@ -383,7 +383,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	}
 
 	usage := &dto.Usage{}
-	usage.PromptTokens = info.PromptTokens
+	usage.PromptTokens = info.GetEstimatePromptTokens()
 	usage.CompletionTokens = 0
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return nil, usage

+ 1 - 1
relay/channel/openai/relay_responses.go

@@ -141,7 +141,7 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
 	}
 
 	if usage.PromptTokens == 0 && usage.CompletionTokens != 0 {
-		usage.PromptTokens = info.PromptTokens
+		usage.PromptTokens = info.GetEstimatePromptTokens()
 	}
 
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens

+ 1 - 1
relay/channel/palm/adaptor.go

@@ -81,7 +81,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		err, responseText = palmStreamHandler(c, resp)
-		usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
 	} else {
 		usage, err = palmHandler(c, info, resp)
 	}

+ 3 - 8
relay/channel/palm/relay-palm.go

@@ -121,13 +121,8 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
 		}, resp.StatusCode)
 	}
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-	completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, info.UpstreamModelName)
-	usage := dto.Usage{
-		PromptTokens:     info.PromptTokens,
-		CompletionTokens: completionTokens,
-		TotalTokens:      info.PromptTokens + completionTokens,
-	}
-	fullTextResponse.Usage = usage
+	usage := service.ResponseText2Usage(c, palmResponse.Candidates[0].Content, info.UpstreamModelName, info.GetEstimatePromptTokens())
+	fullTextResponse.Usage = *usage
 	jsonResponse, err := common.Marshal(fullTextResponse)
 	if err != nil {
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
@@ -135,5 +130,5 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
 	service.IOCopyBytesGracefully(c, resp, jsonResponse)
-	return &usage, nil
+	return usage, nil
 }

+ 2 - 2
relay/channel/tencent/relay-tencent.go

@@ -105,7 +105,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
 		data = strings.TrimPrefix(data, "data:")
 
 		var tencentResponse TencentChatResponse
-		err := json.Unmarshal([]byte(data), &tencentResponse)
+		err := common.Unmarshal([]byte(data), &tencentResponse)
 		if err != nil {
 			common.SysLog("error unmarshalling stream response: " + err.Error())
 			continue
@@ -130,7 +130,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
 
 	service.CloseResponseBodyGracefully(resp)
 
-	return service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens), nil
+	return service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens()), nil
 }
 
 func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {

+ 6 - 6
relay/channel/volcengine/tts.go

@@ -184,9 +184,9 @@ func handleTTSResponse(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	c.Data(http.StatusOK, contentType, audioData)
 
 	usage = &dto.Usage{
-		PromptTokens:     info.PromptTokens,
+		PromptTokens:     info.GetEstimatePromptTokens(),
 		CompletionTokens: 0,
-		TotalTokens:      info.PromptTokens,
+		TotalTokens:      info.GetEstimatePromptTokens(),
 	}
 
 	return usage, nil
@@ -284,9 +284,9 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
 			if msg.Sequence < 0 {
 				c.Status(http.StatusOK)
 				usage = &dto.Usage{
-					PromptTokens:     info.PromptTokens,
+					PromptTokens:     info.GetEstimatePromptTokens(),
 					CompletionTokens: 0,
-					TotalTokens:      info.PromptTokens,
+					TotalTokens:      info.GetEstimatePromptTokens(),
 				}
 				return usage, nil
 			}
@@ -297,9 +297,9 @@ func handleTTSWebSocketResponse(c *gin.Context, requestURL string, volcRequest V
 
 	c.Status(http.StatusOK)
 	usage = &dto.Usage{
-		PromptTokens:     info.PromptTokens,
+		PromptTokens:     info.GetEstimatePromptTokens(),
 		CompletionTokens: 0,
-		TotalTokens:      info.PromptTokens,
+		TotalTokens:      info.GetEstimatePromptTokens(),
 	}
 	return usage, nil
 }

+ 1 - 1
relay/channel/xai/text.go

@@ -70,7 +70,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 	})
 
 	if !containStreamUsage {
-		usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
 		usage.CompletionTokens += toolCount * 7
 	}
 

+ 21 - 5
relay/common/relay_info.go

@@ -73,6 +73,11 @@ type ChannelMeta struct {
 	SupportStreamOptions bool // 是否支持流式选项
 }
 
+type TokenCountMeta struct {
+	//promptTokens int
+	estimatePromptTokens int
+}
+
 type RelayInfo struct {
 	TokenId           int
 	TokenKey          string
@@ -91,7 +96,6 @@ type RelayInfo struct {
 	RelayMode              int
 	OriginModelName        string
 	RequestURLPath         string
-	PromptTokens           int
 	ShouldIncludeUsage     bool
 	DisablePing            bool // 是否禁止向下游发送自定义 Ping
 	ClientWs               *websocket.Conn
@@ -115,6 +119,7 @@ type RelayInfo struct {
 	Request dto.Request
 
 	ThinkingContentInfo
+	TokenCountMeta
 	*ClaudeConvertInfo
 	*RerankerInfo
 	*ResponsesUsageInfo
@@ -189,7 +194,7 @@ func (info *RelayInfo) ToString() string {
 	fmt.Fprintf(b, "IsPlayground: %t, ", info.IsPlayground)
 	fmt.Fprintf(b, "RequestURLPath: %q, ", info.RequestURLPath)
 	fmt.Fprintf(b, "OriginModelName: %q, ", info.OriginModelName)
-	fmt.Fprintf(b, "PromptTokens: %d, ", info.PromptTokens)
+	fmt.Fprintf(b, "EstimatePromptTokens: %d, ", info.estimatePromptTokens)
 	fmt.Fprintf(b, "ShouldIncludeUsage: %t, ", info.ShouldIncludeUsage)
 	fmt.Fprintf(b, "DisablePing: %t, ", info.DisablePing)
 	fmt.Fprintf(b, "SendResponseCount: %d, ", info.SendResponseCount)
@@ -391,7 +396,6 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
 		UserEmail:  common.GetContextKeyString(c, constant.ContextKeyUserEmail),
 
 		OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
-		PromptTokens:    common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
 
 		TokenId:        common.GetContextKeyInt(c, constant.ContextKeyTokenId),
 		TokenKey:       common.GetContextKeyString(c, constant.ContextKeyTokenKey),
@@ -408,6 +412,10 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
 			IsFirstThinkingContent:  true,
 			SendLastThinkingContent: false,
 		},
+		TokenCountMeta: TokenCountMeta{
+			//promptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
+			estimatePromptTokens: common.GetContextKeyInt(c, constant.ContextKeyEstimatedTokens),
+		},
 	}
 
 	if info.RelayMode == relayconstant.RelayModeUnknown {
@@ -463,8 +471,16 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req
 	}
 }
 
-func (info *RelayInfo) SetPromptTokens(promptTokens int) {
-	info.PromptTokens = promptTokens
+//func (info *RelayInfo) SetPromptTokens(promptTokens int) {
+//	info.promptTokens = promptTokens
+//}
+
+func (info *RelayInfo) SetEstimatePromptTokens(promptTokens int) {
+	info.estimatePromptTokens = promptTokens
+}
+
+func (info *RelayInfo) GetEstimatePromptTokens() int {
+	return info.estimatePromptTokens
 }
 
 func (info *RelayInfo) SetFirstResponseTime() {

+ 2 - 2
relay/common_handler/rerank.go

@@ -57,8 +57,8 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 		jinaResp = dto.RerankResponse{
 			Results: jinaRespResults,
 			Usage: dto.Usage{
-				PromptTokens: info.PromptTokens,
-				TotalTokens:  info.PromptTokens,
+				PromptTokens: info.GetEstimatePromptTokens(),
+				TotalTokens:  info.GetEstimatePromptTokens(),
 			},
 		}
 	} else {

+ 2 - 2
relay/compatible_handler.go

@@ -192,9 +192,9 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
 func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
 	if usage == nil {
 		usage = &dto.Usage{
-			PromptTokens:     relayInfo.PromptTokens,
+			PromptTokens:     relayInfo.GetEstimatePromptTokens(),
 			CompletionTokens: 0,
-			TotalTokens:      relayInfo.PromptTokens,
+			TotalTokens:      relayInfo.GetEstimatePromptTokens(),
 		}
 		extraContent += "(可能是请求出错)"
 	}

+ 9 - 3
service/convert.go

@@ -209,7 +209,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
 			Type:  "message",
 			Role:  "assistant",
 			Usage: &dto.ClaudeUsage{
-				InputTokens:  info.PromptTokens,
+				InputTokens:  info.GetEstimatePromptTokens(),
 				OutputTokens: 0,
 			},
 		}
@@ -734,12 +734,18 @@ func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamRespon
 	geminiResponse := &dto.GeminiChatResponse{
 		Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
 		UsageMetadata: dto.GeminiUsageMetadata{
-			PromptTokenCount:     info.PromptTokens,
+			PromptTokenCount:     info.GetEstimatePromptTokens(),
 			CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
-			TotalTokenCount:      info.PromptTokens,
+			TotalTokenCount:      info.GetEstimatePromptTokens(),
 		},
 	}
 
+	if openAIResponse.Usage != nil {
+		geminiResponse.UsageMetadata.PromptTokenCount = openAIResponse.Usage.PromptTokens
+		geminiResponse.UsageMetadata.CandidatesTokenCount = openAIResponse.Usage.CompletionTokens
+		geminiResponse.UsageMetadata.TotalTokenCount = openAIResponse.Usage.TotalTokens
+	}
+
 	for _, choice := range openAIResponse.Choices {
 		candidate := dto.GeminiChatCandidate{
 			Index:         int64(choice.Index),

+ 12 - 201
service/token_counter.go

@@ -1,7 +1,6 @@
 package service
 
 import (
-	"encoding/json"
 	"errors"
 	"fmt"
 	"image"
@@ -12,7 +11,6 @@ import (
 	"math"
 	"path/filepath"
 	"strings"
-	"sync"
 	"unicode/utf8"
 
 	"github.com/QuantumNous/new-api/common"
@@ -23,64 +21,8 @@ import (
 	"github.com/QuantumNous/new-api/types"
 
 	"github.com/gin-gonic/gin"
-	"github.com/tiktoken-go/tokenizer"
-	"github.com/tiktoken-go/tokenizer/codec"
 )
 
-// tokenEncoderMap won't grow after initialization
-var defaultTokenEncoder tokenizer.Codec
-
-// tokenEncoderMap is used to store token encoders for different models
-var tokenEncoderMap = make(map[string]tokenizer.Codec)
-
-// tokenEncoderMutex protects tokenEncoderMap for concurrent access
-var tokenEncoderMutex sync.RWMutex
-
-func InitTokenEncoders() {
-	common.SysLog("initializing token encoders")
-	defaultTokenEncoder = codec.NewCl100kBase()
-	common.SysLog("token encoders initialized")
-}
-
-func getTokenEncoder(model string) tokenizer.Codec {
-	// First, try to get the encoder from cache with read lock
-	tokenEncoderMutex.RLock()
-	if encoder, exists := tokenEncoderMap[model]; exists {
-		tokenEncoderMutex.RUnlock()
-		return encoder
-	}
-	tokenEncoderMutex.RUnlock()
-
-	// If not in cache, create new encoder with write lock
-	tokenEncoderMutex.Lock()
-	defer tokenEncoderMutex.Unlock()
-
-	// Double-check if another goroutine already created the encoder
-	if encoder, exists := tokenEncoderMap[model]; exists {
-		return encoder
-	}
-
-	// Create new encoder
-	modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
-	if err != nil {
-		// Cache the default encoder for this model to avoid repeated failures
-		tokenEncoderMap[model] = defaultTokenEncoder
-		return defaultTokenEncoder
-	}
-
-	// Cache the new encoder
-	tokenEncoderMap[model] = modelCodec
-	return modelCodec
-}
-
-func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
-	if text == "" {
-		return 0
-	}
-	tkm, _ := tokenEncoder.Count(text)
-	return tkm
-}
-
 func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
 	if fileMeta == nil {
 		return 0, fmt.Errorf("image_url_is_nil")
@@ -257,7 +199,7 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
 	return tiles*tileTokens + baseTokens, nil
 }
 
-func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
+func EstimateRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
 	// 是否统计token
 	if !constant.CountToken {
 		return 0, nil
@@ -375,14 +317,14 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
 	for i, file := range meta.Files {
 		switch file.FileType {
 		case types.FileTypeImage:
-			if info.RelayFormat == types.RelayFormatGemini {
-				tkm += 520 // gemini per input image tokens
-			} else {
+			if common.IsOpenAITextModel(info.UpstreamModelName) {
 				token, err := getImageToken(file, model, info.IsStream)
 				if err != nil {
 					return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err)
 				}
 				tkm += token
+			} else {
+				tkm += 520
 			}
 		case types.FileTypeAudio:
 			tkm += 256
@@ -399,111 +341,6 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
 	return tkm, nil
 }
 
-func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
-	tkm := 0
-
-	// Count tokens in messages
-	msgTokens, err := CountTokenClaudeMessages(request.Messages, model, request.Stream)
-	if err != nil {
-		return 0, err
-	}
-	tkm += msgTokens
-
-	// Count tokens in system message
-	if request.System != "" {
-		systemTokens := CountTokenInput(request.System, model)
-		tkm += systemTokens
-	}
-
-	if request.Tools != nil {
-		// check is array
-		if tools, ok := request.Tools.([]any); ok {
-			if len(tools) > 0 {
-				parsedTools, err1 := common.Any2Type[[]dto.Tool](request.Tools)
-				if err1 != nil {
-					return 0, fmt.Errorf("tools: Input should be a valid list: %v", err)
-				}
-				toolTokens, err2 := CountTokenClaudeTools(parsedTools, model)
-				if err2 != nil {
-					return 0, fmt.Errorf("tools: %v", err)
-				}
-				tkm += toolTokens
-			}
-		} else {
-			return 0, errors.New("tools: Input should be a valid list")
-		}
-	}
-
-	return tkm, nil
-}
-
-func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream bool) (int, error) {
-	tokenEncoder := getTokenEncoder(model)
-	tokenNum := 0
-
-	for _, message := range messages {
-		// Count tokens for role
-		tokenNum += getTokenNum(tokenEncoder, message.Role)
-		if message.IsStringContent() {
-			tokenNum += getTokenNum(tokenEncoder, message.GetStringContent())
-		} else {
-			content, err := message.ParseContent()
-			if err != nil {
-				return 0, err
-			}
-			for _, mediaMessage := range content {
-				switch mediaMessage.Type {
-				case "text":
-					tokenNum += getTokenNum(tokenEncoder, mediaMessage.GetText())
-				case "image":
-					//imageTokenNum, err := getClaudeImageToken(mediaMsg.Source, model, stream)
-					//if err != nil {
-					//	return 0, err
-					//}
-					tokenNum += 1000
-				case "tool_use":
-					if mediaMessage.Input != nil {
-						tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
-						inputJSON, _ := json.Marshal(mediaMessage.Input)
-						tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
-					}
-				case "tool_result":
-					if mediaMessage.Content != nil {
-						contentJSON, _ := json.Marshal(mediaMessage.Content)
-						tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
-					}
-				}
-			}
-		}
-	}
-
-	// Add a constant for message formatting (this may need adjustment based on Claude's exact formatting)
-	tokenNum += len(messages) * 2 // Assuming 2 tokens per message for formatting
-
-	return tokenNum, nil
-}
-
-func CountTokenClaudeTools(tools []dto.Tool, model string) (int, error) {
-	tokenEncoder := getTokenEncoder(model)
-	tokenNum := 0
-
-	for _, tool := range tools {
-		tokenNum += getTokenNum(tokenEncoder, tool.Name)
-		tokenNum += getTokenNum(tokenEncoder, tool.Description)
-
-		schemaJSON, err := json.Marshal(tool.InputSchema)
-		if err != nil {
-			return 0, errors.New(fmt.Sprintf("marshal_tool_schema_fail: %s", err.Error()))
-		}
-		tokenNum += getTokenNum(tokenEncoder, string(schemaJSON))
-	}
-
-	// Add a constant for tool formatting (this may need adjustment based on Claude's exact formatting)
-	tokenNum += len(tools) * 3 // Assuming 3 tokens per tool for formatting
-
-	return tokenNum, nil
-}
-
 func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
 	audioToken := 0
 	textToken := 0
@@ -578,31 +415,6 @@ func CountTokenInput(input any, model string) int {
 	return CountTokenInput(fmt.Sprintf("%v", input), model)
 }
 
-func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
-	tokens := 0
-	for _, message := range messages {
-		tkm := CountTokenInput(message.Delta.GetContentString(), model)
-		tokens += tkm
-		if message.Delta.ToolCalls != nil {
-			for _, tool := range message.Delta.ToolCalls {
-				tkm := CountTokenInput(tool.Function.Name, model)
-				tokens += tkm
-				tkm = CountTokenInput(tool.Function.Arguments, model)
-				tokens += tkm
-			}
-		}
-	}
-	return tokens
-}
-
-func CountTTSToken(text string, model string) int {
-	if strings.HasPrefix(model, "tts") {
-		return utf8.RuneCountInString(text)
-	} else {
-		return CountTextToken(text, model)
-	}
-}
-
 func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
 	if audioBase64 == "" {
 		return 0, nil
@@ -625,17 +437,16 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
 	return int(duration / 60 * 200 / 0.24), nil
 }
 
-//func CountAudioToken(sec float64, audioType string) {
-//	if audioType == "input" {
-//
-//	}
-//}
-
-// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
+// CountTextToken 统计文本的token数量,仅OpenAI模型使用tokenizer,其余模型使用估算
 func CountTextToken(text string, model string) int {
 	if text == "" {
 		return 0
 	}
-	tokenEncoder := getTokenEncoder(model)
-	return getTokenNum(tokenEncoder, text)
+	if common.IsOpenAITextModel(model) {
+		tokenEncoder := getTokenEncoder(model)
+		return getTokenNum(tokenEncoder, text)
+	} else {
+		// 非openai模型,使用tiktoken-go计算没有意义,使用估算节省资源
+		return EstimateTokenByModel(model, text)
+	}
 }

+ 230 - 0
service/token_estimator.go

@@ -0,0 +1,230 @@
+package service
+
+import (
+	"math"
+	"strings"
+	"sync"
+	"unicode"
+)
+
+// Provider 定义模型厂商大类
+type Provider string
+
+const (
+	OpenAI  Provider = "openai"  // 代表 GPT-3.5, GPT-4, GPT-4o
+	Gemini  Provider = "gemini"  // 代表 Gemini 1.0, 1.5 Pro/Flash
+	Claude  Provider = "claude"  // 代表 Claude 3, 3.5 Sonnet
+	Unknown Provider = "unknown" // 兜底默认
+)
+
+// multipliers 定义不同厂商的计费权重
+type multipliers struct {
+	Word       float64 // 英文单词 (每词)
+	Number     float64 // 数字 (每连续数字串)
+	CJK        float64 // 中日韩字符 (每字)
+	Symbol     float64 // 普通标点符号 (每个)
+	MathSymbol float64 // 数学符号 (∑,∫,∂,√等,每个)
+	URLDelim   float64 // URL分隔符 (/,:,?,&,=,#,%) - tokenizer优化好
+	AtSign     float64 // @符号 - 导致单词切分,消耗较高
+	Emoji      float64 // Emoji表情 (每个)
+	Newline    float64 // 换行符/制表符 (每个)
+	Space      float64 // 空格 (每个)
+	BasePad    int     // 基础起步消耗 (Start/End tokens)
+}
+
+var (
+	multipliersMap = map[Provider]multipliers{
+		Gemini: {
+			Word: 1.15, Number: 2.8, CJK: 0.68, Symbol: 0.38, MathSymbol: 1.05, URLDelim: 1.2, AtSign: 2.5, Emoji: 1.08, Newline: 1.15, Space: 0.2, BasePad: 0,
+		},
+		Claude: {
+			Word: 1.13, Number: 1.63, CJK: 1.21, Symbol: 0.4, MathSymbol: 4.52, URLDelim: 1.26, AtSign: 2.82, Emoji: 2.6, Newline: 0.89, Space: 0.39, BasePad: 0,
+		},
+		OpenAI: {
+			Word: 1.02, Number: 1.55, CJK: 0.85, Symbol: 0.4, MathSymbol: 2.68, URLDelim: 1.0, AtSign: 2.0, Emoji: 2.12, Newline: 0.5, Space: 0.42, BasePad: 0,
+		},
+	}
+	multipliersLock sync.RWMutex
+)
+
+// getMultipliers 根据厂商获取权重配置
+func getMultipliers(p Provider) multipliers {
+	multipliersLock.RLock()
+	defer multipliersLock.RUnlock()
+
+	switch p {
+	case Gemini:
+		return multipliersMap[Gemini]
+	case Claude:
+		return multipliersMap[Claude]
+	case OpenAI:
+		return multipliersMap[OpenAI]
+	default:
+		// 默认兜底 (按 OpenAI 的算)
+		return multipliersMap[OpenAI]
+	}
+}
+
+// EstimateToken 计算 Token 数量
+func EstimateToken(provider Provider, text string) int {
+	m := getMultipliers(provider)
+	var count float64
+
+	// 状态机变量
+	type WordType int
+	const (
+		None WordType = iota
+		Latin
+		Number
+	)
+	currentWordType := None
+
+	for _, r := range text {
+		// 1. 处理空格和换行符
+		if unicode.IsSpace(r) {
+			currentWordType = None
+			// 换行符和制表符使用Newline权重
+			if r == '\n' || r == '\t' {
+				count += m.Newline
+			} else {
+				// 普通空格使用Space权重
+				count += m.Space
+			}
+			continue
+		}
+
+		// 2. 处理 CJK (中日韩) - 按字符计费
+		if isCJK(r) {
+			currentWordType = None
+			count += m.CJK
+			continue
+		}
+
+		// 3. 处理Emoji - 使用专门的Emoji权重
+		if isEmoji(r) {
+			currentWordType = None
+			count += m.Emoji
+			continue
+		}
+
+		// 4. 处理拉丁字母/数字 (英文单词)
+		if isLatinOrNumber(r) {
+			isNum := unicode.IsNumber(r)
+			newType := Latin
+			if isNum {
+				newType = Number
+			}
+
+			// 如果之前不在单词中,或者类型发生变化(字母<->数字),则视为新token
+			// 注意:对于OpenAI,通常"version 3.5"会切分,"abc123xyz"有时也会切分
+			// 这里简单起见,字母和数字切换时增加权重
+			if currentWordType == None || currentWordType != newType {
+				if newType == Number {
+					count += m.Number
+				} else {
+					count += m.Word
+				}
+				currentWordType = newType
+			}
+			// 单词中间的字符不额外计费
+			continue
+		}
+
+		// 5. 处理标点符号/特殊字符 - 按类型使用不同权重
+		currentWordType = None
+		if isMathSymbol(r) {
+			count += m.MathSymbol
+		} else if r == '@' {
+			count += m.AtSign
+		} else if isURLDelim(r) {
+			count += m.URLDelim
+		} else {
+			count += m.Symbol
+		}
+	}
+
+	// 向上取整并加上基础 padding
+	return int(math.Ceil(count)) + m.BasePad
+}
+
+// 辅助:判断是否为 CJK 字符
+func isCJK(r rune) bool {
+	return unicode.Is(unicode.Han, r) ||
+		(r >= 0x3040 && r <= 0x30FF) || // 日文
+		(r >= 0xAC00 && r <= 0xD7A3) // 韩文
+}
+
+// 辅助:判断是否为单词主体 (字母或数字)
+func isLatinOrNumber(r rune) bool {
+	return unicode.IsLetter(r) || unicode.IsNumber(r)
+}
+
+// 辅助:判断是否为Emoji字符
+func isEmoji(r rune) bool {
+	// Emoji的Unicode范围
+	// 基本范围:0x1F300-0x1F9FF (Emoticons, Symbols, Pictographs)
+	// 补充范围:0x2600-0x26FF (Misc Symbols), 0x2700-0x27BF (Dingbats)
+	// 表情符号:0x1F600-0x1F64F (Emoticons)
+	// 其他:0x1F900-0x1F9FF (Supplemental Symbols and Pictographs)
+	return (r >= 0x1F300 && r <= 0x1F9FF) ||
+		(r >= 0x2600 && r <= 0x26FF) ||
+		(r >= 0x2700 && r <= 0x27BF) ||
+		(r >= 0x1F600 && r <= 0x1F64F) ||
+		(r >= 0x1F900 && r <= 0x1F9FF) ||
+		(r >= 0x1FA00 && r <= 0x1FAFF) // Symbols and Pictographs Extended-A
+}
+
+// 辅助:判断是否为数学符号
+func isMathSymbol(r rune) bool {
+	// 数学运算符和符号
+	// 基本数学符号:∑ ∫ ∂ √ ∞ ≤ ≥ ≠ ≈ ± × ÷
+	// 上下标数字:² ³ ¹ ⁴ ⁵ ⁶ ⁷ ⁸ ⁹ ⁰
+	// 希腊字母等也常用于数学
+	mathSymbols := "∑∫∂√∞≤≥≠≈±×÷∈∉∋∌⊂⊃⊆⊇∪∩∧∨¬∀∃∄∅∆∇∝∟∠∡∢°′″‴⁺⁻⁼⁽⁾ⁿ₀₁₂₃₄₅₆₇₈₉₊₋₌₍₎²³¹⁴⁵⁶⁷⁸⁹⁰"
+	for _, m := range mathSymbols {
+		if r == m {
+			return true
+		}
+	}
+	// Mathematical Operators (U+2200–U+22FF)
+	if r >= 0x2200 && r <= 0x22FF {
+		return true
+	}
+	// Supplemental Mathematical Operators (U+2A00–U+2AFF)
+	if r >= 0x2A00 && r <= 0x2AFF {
+		return true
+	}
+	// Mathematical Alphanumeric Symbols (U+1D400–U+1D7FF)
+	if r >= 0x1D400 && r <= 0x1D7FF {
+		return true
+	}
+	return false
+}
+
+// 辅助:判断是否为URL分隔符(tokenizer对这些优化较好)
+func isURLDelim(r rune) bool {
+	// URL中常见的分隔符,tokenizer通常优化处理
+	urlDelims := "/:?&=;#%"
+	for _, d := range urlDelims {
+		if r == d {
+			return true
+		}
+	}
+	return false
+}
+
+func EstimateTokenByModel(model, text string) int {
+	// strings.Contains(model, "gpt-4o")
+	if text == "" {
+		return 0
+	}
+
+	model = strings.ToLower(model)
+	if strings.Contains(model, "gemini") {
+		return EstimateToken(Gemini, text)
+	} else if strings.Contains(model, "claude") {
+		return EstimateToken(Claude, text)
+	} else {
+		return EstimateToken(OpenAI, text)
+	}
+}

+ 63 - 0
service/tokenizer.go

@@ -0,0 +1,63 @@
+package service
+
+import (
+	"sync"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/tiktoken-go/tokenizer"
+	"github.com/tiktoken-go/tokenizer/codec"
+)
+
+// tokenEncoderMap won't grow after initialization
+var defaultTokenEncoder tokenizer.Codec
+
+// tokenEncoderMap is used to store token encoders for different models
+var tokenEncoderMap = make(map[string]tokenizer.Codec)
+
+// tokenEncoderMutex protects tokenEncoderMap for concurrent access
+var tokenEncoderMutex sync.RWMutex
+
+func InitTokenEncoders() {
+	common.SysLog("initializing token encoders")
+	defaultTokenEncoder = codec.NewCl100kBase()
+	common.SysLog("token encoders initialized")
+}
+
+func getTokenEncoder(model string) tokenizer.Codec {
+	// First, try to get the encoder from cache with read lock
+	tokenEncoderMutex.RLock()
+	if encoder, exists := tokenEncoderMap[model]; exists {
+		tokenEncoderMutex.RUnlock()
+		return encoder
+	}
+	tokenEncoderMutex.RUnlock()
+
+	// If not in cache, create new encoder with write lock
+	tokenEncoderMutex.Lock()
+	defer tokenEncoderMutex.Unlock()
+
+	// Double-check if another goroutine already created the encoder
+	if encoder, exists := tokenEncoderMap[model]; exists {
+		return encoder
+	}
+
+	// Create new encoder
+	modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
+	if err != nil {
+		// Cache the default encoder for this model to avoid repeated failures
+		tokenEncoderMap[model] = defaultTokenEncoder
+		return defaultTokenEncoder
+	}
+
+	// Cache the new encoder
+	tokenEncoderMap[model] = modelCodec
+	return modelCodec
+}
+
+func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
+	if text == "" {
+		return 0
+	}
+	tkm, _ := tokenEncoder.Count(text)
+	return tkm
+}

+ 1 - 2
service/usage_helpr.go

@@ -23,8 +23,7 @@ func ResponseText2Usage(c *gin.Context, responseText string, modeName string, pr
 	common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
 	usage := &dto.Usage{}
 	usage.PromptTokens = promptTokens
-	ctkm := CountTextToken(responseText, modeName)
-	usage.CompletionTokens = ctkm
+	usage.CompletionTokens = EstimateTokenByModel(modeName, responseText)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return usage
 }