package relay

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/model"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/relay/helper"
	"github.com/QuantumNous/new-api/service"
	"github.com/QuantumNous/new-api/setting/model_setting"
	"github.com/QuantumNous/new-api/setting/operation_setting"
	"github.com/QuantumNous/new-api/types"
	"github.com/gin-gonic/gin"
	"github.com/shopspring/decimal"
)

// TextHelper relays a text/chat completion request held in info.Request
// (which must be a *dto.GeneralOpenAIRequest) through the adaptor selected by
// info.ApiType, then settles the consumed quota.
//
// The upstream request body is either the raw client body (when pass-through
// is enabled globally or on the channel) or the adaptor-converted request with
// the channel system prompt, disabled-field removal and parameter overrides
// applied, in that order.
//
// It returns nil on success. On failure it returns a *types.NewAPIError; errors
// produced by the upstream response have their status code remapped using the
// "status_code_mapping" context value.
func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
	info.InitChannelMeta(c)

	textReq, ok := info.Request.(*dto.GeneralOpenAIRequest)
	if !ok {
		return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
	}

	// Work on a copy so the mutations below do not touch info.Request.
	request, err := common.DeepCopy(textReq)
	if err != nil {
		return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
	}

	// Remember the requested web search context size; postConsumeQuota reads
	// it back when pricing "search-preview" models.
	if request.WebSearchOptions != nil {
		c.Set("chat_completion_web_search_context_size", request.WebSearchOptions.SearchContextSize)
	}

	err = helper.ModelMappedHelper(c, info, request)
	if err != nil {
		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
	}

	// Record whether the client asked for usage to be returned; defaults to
	// true when the request carries no stream options.
	includeUsage := true
	if request.StreamOptions != nil {
		includeUsage = request.StreamOptions.IncludeUsage
	}

	if !info.SupportStreamOptions || !request.Stream {
		// Stream options are unsupported or the request is not streaming:
		// drop them from the upstream request.
		request.StreamOptions = nil
	} else if constant.ForceStreamOption {
		// Stream options are supported and forced by configuration: always
		// ask the upstream for usage, replacing whatever the client sent.
		request.StreamOptions = &dto.StreamOptions{
			IncludeUsage: true,
		}
	}

	info.ShouldIncludeUsage = includeUsage

	adaptor := GetAdaptor(info.ApiType)
	if adaptor == nil {
		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
	}
	adaptor.Init(info)

	var requestBody io.Reader
	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
		// Pass-through: forward the client's original body untouched.
		body, err := common.GetRequestBody(c)
		if err != nil {
			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
		}
		if common.DebugEnabled {
			println("requestBody: ", string(body))
		}
		requestBody = bytes.NewBuffer(body)
	} else {
		convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
		if err != nil {
			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
		}

		if info.ChannelSetting.SystemPrompt != "" {
			// The channel defines a system prompt. It can only be injected
			// when the converted request is still in OpenAI format.
			openAIReq, isOpenAIReq := convertedRequest.(*dto.GeneralOpenAIRequest)
			if isOpenAIReq {
				containSystemPrompt := false
				for _, message := range openAIReq.Messages {
					if message.Role == openAIReq.GetSystemRoleName() {
						containSystemPrompt = true
						break
					}
				}
				if !containSystemPrompt {
					// No system message yet: prepend the channel one.
					systemMessage := dto.Message{
						Role:    openAIReq.GetSystemRoleName(),
						Content: info.ChannelSetting.SystemPrompt,
					}
					openAIReq.Messages = append([]dto.Message{systemMessage}, openAIReq.Messages...)
				} else if info.ChannelSetting.SystemPromptOverride {
					common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
					// A system message exists and override is enabled:
					// prefix the channel prompt to the first system message.
					for i, message := range openAIReq.Messages {
						if message.Role == openAIReq.GetSystemRoleName() {
							if message.IsStringContent() {
								openAIReq.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
							} else {
								contents := message.ParseContent()
								contents = append([]dto.MediaContent{
									{
										Type: dto.ContentTypeText,
										Text: info.ChannelSetting.SystemPrompt,
									},
								}, contents...)
								openAIReq.Messages[i].Content = contents
							}
							break
						}
					}
				}
			}
		}

		jsonData, err := common.Marshal(convertedRequest)
		if err != nil {
			return types.NewError(err, types.ErrorCodeJsonMarshalFailed, types.ErrOptionWithSkipRetry())
		}

		// Remove fields disabled by the channel settings.
		jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
		if err != nil {
			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
		}

		// Apply the channel parameter overrides, if any.
		if len(info.ParamOverride) > 0 {
			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
			if err != nil {
				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
			}
		}

		logger.LogDebug(c, fmt.Sprintf("text request body: %s", string(jsonData)))

		requestBody = bytes.NewBuffer(jsonData)
	}

	var httpResp *http.Response
	resp, err := adaptor.DoRequest(c, info, requestBody)
	if err != nil {
		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
	}

	statusCodeMappingStr := c.GetString("status_code_mapping")

	if resp != nil {
		// Checked assertion: an adaptor returning an unexpected type is
		// reported as a request failure instead of panicking the handler.
		var isHTTPResp bool
		httpResp, isHTTPResp = resp.(*http.Response)
		if !isHTTPResp {
			return types.NewOpenAIError(fmt.Errorf("unexpected response type %T from adaptor", resp), types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
		}
		// Treat the exchange as streaming if the upstream answered with SSE,
		// even when the request did not ask for it.
		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
		if httpResp.StatusCode != http.StatusOK {
			newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false)
			// Remap the status code according to the configured mapping.
			service.ResetStatusCode(newApiErr, statusCodeMappingStr)
			return newApiErr
		}
	}

	usage, newApiErr := adaptor.DoResponse(c, httpResp, info)
	if newApiErr != nil {
		// Remap the status code according to the configured mapping.
		service.ResetStatusCode(newApiErr, statusCodeMappingStr)
		return newApiErr
	}

	// Checked assertion: a nil (or unexpected) usage value becomes a nil
	// *dto.Usage rather than a panic, so postConsumeQuota's nil fallback is
	// actually reachable.
	// NOTE(review): service.PostAudioConsumeQuota's handling of a nil usage is
	// not visible from this file — confirm it tolerates nil.
	textUsage, _ := usage.(*dto.Usage)
	if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
		service.PostAudioConsumeQuota(c, info, textUsage, "")
	} else {
		postConsumeQuota(c, info, textUsage, "")
	}
	return nil
}

// postConsumeQuota computes the final quota for a finished text request,
// settles the difference against the pre-consumed quota and records a
// consume log entry.
//
// usage may be nil, in which case only the prompt tokens from relayInfo are
// counted and the log content is marked as a probable request error.
// extraContent is appended to the log content together with per-tool billing
// notes generated here.
//
// The quota is either token based (prompt tokens with cached, cache-creation,
// image and separately-priced audio tokens split out, plus completion tokens,
// multiplied by model ratio * group ratio) or a fixed model price when
// relayInfo.PriceData.UsePrice is set. Tool call surcharges (web search, file
// search, image generation) and separately priced audio input are added on
// top.
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
	if usage == nil {
		usage = &dto.Usage{
			PromptTokens:     relayInfo.PromptTokens,
			CompletionTokens: 0,
			TotalTokens:      relayInfo.PromptTokens,
		}
		extraContent += "(可能是请求出错)"
	}
	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
	promptTokens := usage.PromptTokens
	cacheTokens := usage.PromptTokensDetails.CachedTokens
	imageTokens := usage.PromptTokensDetails.ImageTokens
	audioTokens := usage.PromptTokensDetails.AudioTokens
	completionTokens := usage.CompletionTokens
	cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens

	modelName := relayInfo.OriginModelName

	tokenName := ctx.GetString("token_name")
	completionRatio := relayInfo.PriceData.CompletionRatio
	cacheRatio := relayInfo.PriceData.CacheRatio
	imageRatio := relayInfo.PriceData.ImageRatio
	modelRatio := relayInfo.PriceData.ModelRatio
	groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
	modelPrice := relayInfo.PriceData.ModelPrice
	cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio

	// Convert everything to decimal so the arithmetic below is exact.
	dPromptTokens := decimal.NewFromInt(int64(promptTokens))
	dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
	dImageTokens := decimal.NewFromInt(int64(imageTokens))
	dAudioTokens := decimal.NewFromInt(int64(audioTokens))
	dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
	dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens))
	dCompletionRatio := decimal.NewFromFloat(completionRatio)
	dCacheRatio := decimal.NewFromFloat(cacheRatio)
	dImageRatio := decimal.NewFromFloat(imageRatio)
	dModelRatio := decimal.NewFromFloat(modelRatio)
	dGroupRatio := decimal.NewFromFloat(groupRatio)
	dModelPrice := decimal.NewFromFloat(modelPrice)
	dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio)
	dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)

	ratio := dModelRatio.Mul(dGroupRatio)

	// OpenAI web search tool billing.
	var dWebSearchQuota decimal.Decimal
	var webSearchPrice float64
	if relayInfo.ResponsesUsageInfo != nil {
		// Responses API format: the call count comes from the built-in tool
		// usage info.
		if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
			// quota = price * call count / 1000 * group ratio * quota per unit
			webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, webSearchTool.SearchContextSize)
			dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
				Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
				Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
			extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
				webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
		}
	} else if strings.HasSuffix(modelName, "search-preview") {
		// "search-preview" models are not served through the Responses API;
		// bill exactly one call using the context size stored by TextHelper,
		// defaulting to "medium".
		searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
		if searchContextSize == "" {
			searchContextSize = "medium"
		}
		webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
		dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
			Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
		extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
			searchContextSize, dWebSearchQuota.String())
	}

	// Claude web search tool billing.
	var dClaudeWebSearchQuota decimal.Decimal
	var claudeWebSearchPrice float64
	claudeWebSearchCallCount := ctx.GetInt("claude_web_search_requests")
	if claudeWebSearchCallCount > 0 {
		claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
		dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
			Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
		extraContent += fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
			claudeWebSearchCallCount, dClaudeWebSearchQuota.String())
	}

	// File search tool billing (Responses API only).
	var dFileSearchQuota decimal.Decimal
	var fileSearchPrice float64
	if relayInfo.ResponsesUsageInfo != nil {
		if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
			fileSearchPrice = operation_setting.GetFileSearchPricePerThousand()
			dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
				Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
				Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
			extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
				fileSearchTool.CallCount, dFileSearchQuota.String())
		}
	}

	// Image generation call billing, driven by context flags.
	var dImageGenerationCallQuota decimal.Decimal
	var imageGenerationCallPrice float64
	if ctx.GetBool("image_generation_call") {
		imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
		dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
		extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String())
	}

	var quotaCalculateDecimal decimal.Decimal

	var audioInputQuota decimal.Decimal
	var audioInputPrice float64
	if !relayInfo.PriceData.UsePrice {
		// Token-based pricing: start from all prompt tokens and move the
		// specially priced categories out of the base.
		baseTokens := dPromptTokens

		// Cached tokens are billed at the cache ratio.
		var cachedTokensWithRatio decimal.Decimal
		if !dCacheTokens.IsZero() {
			baseTokens = baseTokens.Sub(dCacheTokens)
			cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
		}

		// Cache-creation tokens are billed at the cache-creation ratio.
		var dCachedCreationTokensWithRatio decimal.Decimal
		if !dCachedCreationTokens.IsZero() {
			baseTokens = baseTokens.Sub(dCachedCreationTokens)
			dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
		}

		// Image tokens are billed at the image ratio.
		var imageTokensWithRatio decimal.Decimal
		if !dImageTokens.IsZero() {
			baseTokens = baseTokens.Sub(dImageTokens)
			imageTokensWithRatio = dImageTokens.Mul(dImageRatio)
		}

		// Gemini audio input tokens are billed separately, but only when a
		// per-million price is configured for this model; otherwise they stay
		// in the base token count.
		if !dAudioTokens.IsZero() {
			audioInputPrice = operation_setting.GetGeminiInputAudioPricePerMillionTokens(modelName)
			if audioInputPrice > 0 {
				baseTokens = baseTokens.Sub(dAudioTokens)
				audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
				extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
			}
		}
		promptQuota := baseTokens.Add(cachedTokensWithRatio).
			Add(imageTokensWithRatio).
			Add(dCachedCreationTokensWithRatio)

		completionQuota := dCompletionTokens.Mul(dCompletionRatio)

		quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)

		// With a non-zero ratio, charge at least 1.
		if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
			quotaCalculateDecimal = decimal.NewFromInt(1)
		}
	} else {
		// Fixed per-call pricing.
		quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
	}
	// Add the tool call surcharges.
	// NOTE(review): dClaudeWebSearchQuota is computed, logged and reported in
	// the log's "other" data above/below, but it is never added to the billed
	// total here — confirm whether that is intentional before changing billing.
	quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
	quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
	// Add the separately billed audio input.
	quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
	// Add the image generation call surcharge.
	quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)

	quota := int(quotaCalculateDecimal.Round(0).IntPart())
	totalTokens := promptTokens + completionTokens

	var logContent string

	// Record the consume log even when the quota is 0.
	if totalTokens == 0 {
		// Zero tokens means something went wrong upstream. Do not return
		// early: the pre-consumed quota may still need to be refunded below.
		quota = 0
		logContent += "(可能是上游超时)"
		logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
	} else {
		if !ratio.IsZero() && quota == 0 {
			quota = 1
		}
		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
	}

	// Settle the difference between the actual and the pre-consumed quota.
	quotaDelta := quota - relayInfo.FinalPreConsumedQuota

	//logger.LogInfo(ctx, fmt.Sprintf("request quota delta: %s", logger.FormatQuota(quotaDelta)))

	if quotaDelta > 0 {
		logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
			logger.FormatQuota(quotaDelta),
			logger.FormatQuota(quota),
			logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
		))
	} else if quotaDelta < 0 {
		logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
			logger.FormatQuota(-quotaDelta),
			logger.FormatQuota(quota),
			logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
		))
	}

	if quotaDelta != 0 {
		err := service.PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
		if err != nil {
			logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
		}
	}

	// Collapse gizmo model variants into a wildcard name for the log, keeping
	// the concrete model name in the log content.
	logModel := modelName
	if strings.HasPrefix(logModel, "gpt-4-gizmo") {
		logModel = "gpt-4-gizmo-*"
		logContent += fmt.Sprintf(",模型 %s", modelName)
	}
	if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
		logModel = "gpt-4o-gizmo-*"
		logContent += fmt.Sprintf(",模型 %s", modelName)
	}
	if extraContent != "" {
		logContent += ", " + extraContent
	}

	// Build the structured "other" data attached to the consume log.
	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
	if imageTokens != 0 {
		other["image"] = true
		other["image_ratio"] = imageRatio
		other["image_output"] = imageTokens
	}
	if cachedCreationTokens != 0 {
		other["cache_creation_tokens"] = cachedCreationTokens
		other["cache_creation_ratio"] = cachedCreationRatio
	}
	if !dWebSearchQuota.IsZero() {
		if relayInfo.ResponsesUsageInfo != nil {
			if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
				other["web_search"] = true
				other["web_search_call_count"] = webSearchTool.CallCount
				other["web_search_price"] = webSearchPrice
			}
		} else if strings.HasSuffix(modelName, "search-preview") {
			other["web_search"] = true
			other["web_search_call_count"] = 1
			other["web_search_price"] = webSearchPrice
		}
	} else if !dClaudeWebSearchQuota.IsZero() {
		// Claude web search details are only reported when no OpenAI web
		// search quota was computed; both share the same keys.
		other["web_search"] = true
		other["web_search_call_count"] = claudeWebSearchCallCount
		other["web_search_price"] = claudeWebSearchPrice
	}
	if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
		if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {
			other["file_search"] = true
			other["file_search_call_count"] = fileSearchTool.CallCount
			other["file_search_price"] = fileSearchPrice
		}
	}
	if !audioInputQuota.IsZero() {
		other["audio_input_seperate_price"] = true
		other["audio_input_token_count"] = audioTokens
		other["audio_input_price"] = audioInputPrice
	}
	if !dImageGenerationCallQuota.IsZero() {
		other["image_generation_call"] = true
		other["image_generation_call_price"] = imageGenerationCallPrice
	}
	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
		ChannelId:        relayInfo.ChannelId,
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		ModelName:        logModel,
		TokenName:        tokenName,
		Quota:            quota,
		Content:          logContent,
		TokenId:          relayInfo.TokenId,
		UseTimeSeconds:   int(useTimeSeconds),
		IsStream:         relayInfo.IsStream,
		Group:            relayInfo.UsingGroup,
		Other:            other,
	})
}