Selaa lähdekoodia

feat: Add ContextKeyLocalCountTokens and update ResponseText2Usage to use context in multiple channels

CaIon 1 kuukausi sitten
vanhempi
sitoutus
84745d5ca4

+ 2 - 0
constant/context_key.go

@@ -46,5 +46,7 @@ const (
 	ContextKeyUsingGroup  ContextKey = "group"
 	ContextKeyUserName    ContextKey = "username"
 
+	ContextKeyLocalCountTokens ContextKey = "local_count_tokens"
+
 	ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
 )

+ 2 - 2
relay/channel/claude/relay-claude.go

@@ -673,7 +673,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
 
 	if requestMode == RequestModeCompletion {
-		claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
+		claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
 	} else {
 		if claudeInfo.Usage.PromptTokens == 0 {
 			//上游出错
@@ -682,7 +682,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
 			if common.DebugEnabled {
 				common.SysLog("claude response usage is not complete, maybe upstream error")
 			}
-			claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
+			claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
 		}
 	}
 

+ 2 - 2
relay/channel/cloudflare/relay_cloudflare.go

@@ -74,7 +74,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
 	if err := scanner.Err(); err != nil {
 		logger.LogError(c, "error_scanning_stream_response: "+err.Error())
 	}
-	usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
 	if info.ShouldIncludeUsage {
 		response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
 		err := helper.ObjectData(c, response)
@@ -105,7 +105,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response)
 	for _, choice := range response.Choices {
 		responseText += choice.Message.StringContent()
 	}
-	usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
 	response.Usage = *usage
 	response.Id = helper.GetResponseID(c)
 	jsonResponse, err := json.Marshal(response)

+ 1 - 1
relay/channel/cohere/relay-cohere.go

@@ -165,7 +165,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 		}
 	})
 	if usage.PromptTokens == 0 {
-		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
 	}
 	return usage, nil
 }

+ 1 - 1
relay/channel/coze/relay-coze.go

@@ -142,7 +142,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
 	helper.Done(c)
 
 	if usage.TotalTokens == 0 {
-		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
+		usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
 	}
 
 	return usage, nil

+ 1 - 1
relay/channel/dify/relay-dify.go

@@ -246,7 +246,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 	})
 	helper.Done(c)
 	if usage.TotalTokens == 0 {
-		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
 	}
 	usage.CompletionTokens += nodeToken
 	return usage, nil

+ 2 - 70
relay/channel/gemini/relay-gemini-native.go

@@ -3,7 +3,6 @@ package gemini
 import (
 	"io"
 	"net/http"
-	"strings"
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/dto"
@@ -13,8 +12,6 @@ import (
 	"github.com/QuantumNous/new-api/service"
 	"github.com/QuantumNous/new-api/types"
 
-	"github.com/pkg/errors"
-
 	"github.com/gin-gonic/gin"
 )
 
@@ -97,80 +94,15 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
 }
 
 func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
-	var usage = &dto.Usage{}
-	var imageCount int
-
 	helper.SetEventStreamHeaders(c)
 
-	responseText := strings.Builder{}
-
-	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
-		var geminiResponse dto.GeminiChatResponse
-		err := common.UnmarshalJsonStr(data, &geminiResponse)
-		if err != nil {
-			logger.LogError(c, "error unmarshalling stream response: "+err.Error())
-			return false
-		}
-
-		// 统计图片数量
-		for _, candidate := range geminiResponse.Candidates {
-			for _, part := range candidate.Content.Parts {
-				if part.InlineData != nil && part.InlineData.MimeType != "" {
-					imageCount++
-				}
-				if part.Text != "" {
-					responseText.WriteString(part.Text)
-				}
-			}
-		}
-
-		// 更新使用量统计
-		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
-			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
-			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
-			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
-			usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
-			for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
-				if detail.Modality == "AUDIO" {
-					usage.PromptTokensDetails.AudioTokens = detail.TokenCount
-				} else if detail.Modality == "TEXT" {
-					usage.PromptTokensDetails.TextTokens = detail.TokenCount
-				}
-			}
-		}
-
+	return geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
 		// 直接发送 GeminiChatResponse 响应
-		err = helper.StringData(c, data)
+		err := helper.StringData(c, data)
 		if err != nil {
 			logger.LogError(c, err.Error())
 		}
 		info.SendResponseCount++
 		return true
 	})
-
-	if info.SendResponseCount == 0 {
-		return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
-	}
-
-	if imageCount != 0 {
-		if usage.CompletionTokens == 0 {
-			usage.CompletionTokens = imageCount * 258
-		}
-	}
-
-	// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
-	if usage.CompletionTokens == 0 {
-		str := responseText.String()
-		if len(str) > 0 {
-			usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
-		} else {
-			// 空补全,不需要使用量
-			usage = &dto.Usage{}
-		}
-	}
-
-	// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
-	//helper.Done(c)
-
-	return usage, nil
 }

+ 52 - 45
relay/channel/gemini/relay-gemini.go

@@ -954,14 +954,10 @@ func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.Ch
 	return nil
 }
 
-func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
-	// responseText := ""
-	id := helper.GetResponseID(c)
-	createAt := common.GetTimestamp()
-	responseText := strings.Builder{}
+func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response, callback func(data string, geminiResponse *dto.GeminiChatResponse) bool) (*dto.Usage, *types.NewAPIError) {
 	var usage = &dto.Usage{}
 	var imageCount int
-	finishReason := constant.FinishReasonStop
+	responseText := strings.Builder{}
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var geminiResponse dto.GeminiChatResponse
@@ -971,6 +967,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 			return false
 		}
 
+		// 统计图片数量
 		for _, candidate := range geminiResponse.Candidates {
 			for _, part := range candidate.Content.Parts {
 				if part.InlineData != nil && part.InlineData.MimeType != "" {
@@ -982,14 +979,10 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 			}
 		}
 
-		response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
-
-		response.Id = id
-		response.Created = createAt
-		response.Model = info.UpstreamModelName
+		// 更新使用量统计
 		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
 			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
-			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
+			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
 			usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
 			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
 			for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
@@ -1000,6 +993,45 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 				}
 			}
 		}
+
+		return callback(data, &geminiResponse)
+	})
+
+	if imageCount != 0 {
+		if usage.CompletionTokens == 0 {
+			usage.CompletionTokens = imageCount * 1400
+		}
+	}
+
+	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
+	if usage.TotalTokens > 0 {
+		usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+	}
+
+	if usage.CompletionTokens <= 0 {
+		str := responseText.String()
+		if len(str) > 0 {
+			usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.PromptTokens)
+		} else {
+			usage = &dto.Usage{}
+		}
+	}
+
+	return usage, nil
+}
+
+func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+	id := helper.GetResponseID(c)
+	createAt := common.GetTimestamp()
+	finishReason := constant.FinishReasonStop
+
+	usage, err := geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
+		response, isStop := streamResponseGeminiChat2OpenAI(geminiResponse)
+
+		response.Id = id
+		response.Created = createAt
+		response.Model = info.UpstreamModelName
+
 		logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
 		if info.SendResponseCount == 0 {
 			// send first response
@@ -1015,7 +1047,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 					emptyResponse.Choices[0].Delta.ToolCalls = copiedToolCalls
 				}
 				finishReason = constant.FinishReasonToolCalls
-				err = handleStream(c, info, emptyResponse)
+				err := handleStream(c, info, emptyResponse)
 				if err != nil {
 					logger.LogError(c, err.Error())
 				}
@@ -1025,14 +1057,14 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 					response.Choices[0].FinishReason = nil
 				}
 			} else {
-				err = handleStream(c, info, emptyResponse)
+				err := handleStream(c, info, emptyResponse)
 				if err != nil {
 					logger.LogError(c, err.Error())
 				}
 			}
 		}
 
-		err = handleStream(c, info, response)
+		err := handleStream(c, info, response)
 		if err != nil {
 			logger.LogError(c, err.Error())
 		}
@@ -1042,40 +1074,15 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 		return true
 	})
 
-	if info.SendResponseCount == 0 {
-		// 空补全,报错不计费
-		// empty response, throw an error
-		return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
-	}
-
-	if imageCount != 0 {
-		if usage.CompletionTokens == 0 {
-			usage.CompletionTokens = imageCount * 258
-		}
-	}
-
-	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
-	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
-
-	if usage.CompletionTokens == 0 {
-		str := responseText.String()
-		if len(str) > 0 {
-			usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
-		} else {
-			// 空补全,不需要使用量
-			usage = &dto.Usage{}
-		}
+	if err != nil {
+		return usage, err
 	}
 
 	response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
-	err := handleFinalStream(c, info, response)
-	if err != nil {
-		common.SysLog("send final response failed: " + err.Error())
+	handleErr := handleFinalStream(c, info, response)
+	if handleErr != nil {
+		common.SysLog("send final response failed: " + handleErr.Error())
 	}
-	//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
-	//	helper.Done(c)
-	//}
-	//resp.Body.Close()
 	return usage, nil
 }
 

+ 1 - 1
relay/channel/openai/relay-openai.go

@@ -183,7 +183,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 	}
 
 	if !containStreamUsage {
-		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
 		usage.CompletionTokens += toolCount * 7
 	}
 

+ 1 - 1
relay/channel/palm/adaptor.go

@@ -81,7 +81,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		err, responseText = palmStreamHandler(c, resp)
-		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		usage, err = palmHandler(c, info, resp)
 	}

+ 1 - 1
relay/channel/tencent/relay-tencent.go

@@ -130,7 +130,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
 
 	service.CloseResponseBodyGracefully(resp)
 
-	return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
+	return service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens), nil
 }
 
 func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {

+ 1 - 1
relay/channel/xai/text.go

@@ -70,7 +70,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 	})
 
 	if !containStreamUsage {
-		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
 		usage.CompletionTokens += toolCount * 7
 	}
 

+ 1 - 0
service/image.go

@@ -16,6 +16,7 @@ import (
 	"golang.org/x/image/webp"
 )
 
+// return image.Config, format, clean base64 string, error
 func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
 	// 去除base64数据的URL前缀(如果有)
 	if idx := strings.Index(base64String, ","); idx != -1 {

+ 6 - 0
service/log_info_generate.go

@@ -62,6 +62,12 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
 		adminInfo["is_multi_key"] = true
 		adminInfo["multi_key_index"] = common.GetContextKeyInt(ctx, constant.ContextKeyChannelMultiKeyIndex)
 	}
+
+	isLocalCountTokens := common.GetContextKeyBool(ctx, constant.ContextKeyLocalCountTokens)
+	if isLocalCountTokens {
+		adminInfo["local_count_tokens"] = isLocalCountTokens
+	}
+
 	other["admin_info"] = adminInfo
 	appendRequestPath(ctx, relayInfo, other)
 	return other

+ 5 - 1
service/usage_helpr.go

@@ -1,7 +1,10 @@
 package service
 
 import (
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
+	"github.com/gin-gonic/gin"
 )
 
 //func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
@@ -16,7 +19,8 @@ import (
 //	return 0, errors.New("unknown relay mode")
 //}
 
-func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
+func ResponseText2Usage(c *gin.Context, responseText string, modeName string, promptTokens int) *dto.Usage {
+	common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
 	usage := &dto.Usage{}
 	usage.PromptTokens = promptTokens
 	ctkm := CountTextToken(responseText, modeName)

+ 5 - 0
setting/ratio_setting/model_ratio.go

@@ -598,6 +598,11 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
 			return 2.5 / 0.3, false
 		} else if strings.HasPrefix(name, "gemini-robotics-er-1.5") {
 			return 2.5 / 0.3, false
+		} else if strings.HasPrefix(name, "gemini-3-pro") {
+			if strings.HasPrefix(name, "gemini-3-pro-image") {
+				return 60, false
+			}
+			return 6, false
 		}
 		return 4, false
 	}

+ 12 - 0
web/src/hooks/usage-logs/useUsageLogsData.jsx

@@ -482,6 +482,18 @@ export const useLogsData = () => {
           value: other.request_path,
         });
       }
+      if (isAdminUser) {
+        let localCountMode = '';
+        if (other?.admin_info?.local_count_tokens) {
+          localCountMode = t('本地计费');
+        } else {
+          localCountMode = t('上游返回');
+        }
+        expandDataLocal.push({
+            key: t('计费模式'),
+            value: localCountMode,
+        });
+      }
       expandDatesLocal[logs[i].key] = expandDataLocal;
     }