| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704 |
- package openai
- import (
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "strings"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/dto"
- "github.com/QuantumNous/new-api/logger"
- "github.com/QuantumNous/new-api/relay/channel/openrouter"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- "github.com/QuantumNous/new-api/relay/helper"
- "github.com/QuantumNous/new-api/service"
- "github.com/QuantumNous/new-api/types"
- "github.com/bytedance/gopkg/util/gopool"
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
- )
- func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
- if data == "" {
- return nil
- }
- if !forceFormat && !thinkToContent {
- return helper.StringData(c, data)
- }
- var lastStreamResponse dto.ChatCompletionsStreamResponse
- if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
- return err
- }
- if !thinkToContent {
- return helper.ObjectData(c, lastStreamResponse)
- }
- hasThinkingContent := false
- hasContent := false
- var thinkingContent strings.Builder
- for _, choice := range lastStreamResponse.Choices {
- if len(choice.Delta.GetReasoningContent()) > 0 {
- hasThinkingContent = true
- thinkingContent.WriteString(choice.Delta.GetReasoningContent())
- }
- if len(choice.Delta.GetContentString()) > 0 {
- hasContent = true
- }
- }
- // Handle think to content conversion
- if info.ThinkingContentInfo.IsFirstThinkingContent {
- if hasThinkingContent {
- response := lastStreamResponse.Copy()
- for i := range response.Choices {
- // send `think` tag with thinking content
- response.Choices[i].Delta.SetContentString("<think>\n" + thinkingContent.String())
- response.Choices[i].Delta.ReasoningContent = nil
- response.Choices[i].Delta.Reasoning = nil
- }
- info.ThinkingContentInfo.IsFirstThinkingContent = false
- info.ThinkingContentInfo.HasSentThinkingContent = true
- return helper.ObjectData(c, response)
- }
- }
- if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
- return helper.ObjectData(c, lastStreamResponse)
- }
- // Process each choice
- for i, choice := range lastStreamResponse.Choices {
- // Handle transition from thinking to content
- // only send `</think>` tag when previous thinking content has been sent
- if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
- response := lastStreamResponse.Copy()
- for j := range response.Choices {
- response.Choices[j].Delta.SetContentString("\n</think>\n")
- response.Choices[j].Delta.ReasoningContent = nil
- response.Choices[j].Delta.Reasoning = nil
- }
- info.ThinkingContentInfo.SendLastThinkingContent = true
- helper.ObjectData(c, response)
- }
- // Convert reasoning content to regular content if any
- if len(choice.Delta.GetReasoningContent()) > 0 {
- lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
- lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
- lastStreamResponse.Choices[i].Delta.Reasoning = nil
- } else if !hasThinkingContent && !hasContent {
- // flush thinking content
- lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
- lastStreamResponse.Choices[i].Delta.Reasoning = nil
- }
- }
- return helper.ObjectData(c, lastStreamResponse)
- }
- func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- if resp == nil || resp.Body == nil {
- logger.LogError(c, "invalid response or response body")
- return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
- }
- defer service.CloseResponseBodyGracefully(resp)
- model := info.UpstreamModelName
- var responseId string
- var createAt int64 = 0
- var systemFingerprint string
- var containStreamUsage bool
- var responseTextBuilder strings.Builder
- var toolCount int
- var usage = &dto.Usage{}
- var streamItems []string // store stream items
- var lastStreamData string
- var secondLastStreamData string // 存储倒数第二个stream data,用于音频模型
- // 检查是否为音频模型
- isAudioModel := strings.Contains(strings.ToLower(model), "audio")
- helper.StreamScannerHandler(c, resp, info, func(data string) bool {
- if lastStreamData != "" {
- err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
- if err != nil {
- common.SysLog("error handling stream format: " + err.Error())
- }
- }
- if len(data) > 0 {
- // 对音频模型,保存倒数第二个stream data
- if isAudioModel && lastStreamData != "" {
- secondLastStreamData = lastStreamData
- }
- lastStreamData = data
- streamItems = append(streamItems, data)
- }
- return true
- })
- // 对音频模型,从倒数第二个stream data中提取usage信息
- if isAudioModel && secondLastStreamData != "" {
- var streamResp struct {
- Usage *dto.Usage `json:"usage"`
- }
- err := json.Unmarshal([]byte(secondLastStreamData), &streamResp)
- if err == nil && streamResp.Usage != nil && service.ValidUsage(streamResp.Usage) {
- usage = streamResp.Usage
- containStreamUsage = true
- if common.DebugEnabled {
- logger.LogDebug(c, fmt.Sprintf("Audio model usage extracted from second last SSE: PromptTokens=%d, CompletionTokens=%d, TotalTokens=%d, InputTokens=%d, OutputTokens=%d",
- usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens,
- usage.InputTokens, usage.OutputTokens))
- }
- }
- }
- // 处理最后的响应
- shouldSendLastResp := true
- if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
- &containStreamUsage, info, &shouldSendLastResp); err != nil {
- logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
- }
- if info.RelayFormat == types.RelayFormatOpenAI {
- if shouldSendLastResp {
- _ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
- }
- }
- // 处理token计算
- if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
- logger.LogError(c, "error processing tokens: "+err.Error())
- }
- if !containStreamUsage {
- usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
- usage.CompletionTokens += toolCount * 7
- }
- applyUsagePostProcessing(info, usage, nil)
- HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
- return usage, nil
- }
- func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- defer service.CloseResponseBodyGracefully(resp)
- var simpleResponse dto.OpenAITextResponse
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
- }
- if common.DebugEnabled {
- println("upstream response body:", string(responseBody))
- }
- // Unmarshal to simpleResponse
- if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.IsOpenRouterEnterprise() {
- // 尝试解析为 openrouter enterprise
- var enterpriseResponse openrouter.OpenRouterEnterpriseResponse
- err = common.Unmarshal(responseBody, &enterpriseResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- if enterpriseResponse.Success {
- responseBody = enterpriseResponse.Data
- } else {
- logger.LogError(c, fmt.Sprintf("openrouter enterprise response success=false, data: %s", enterpriseResponse.Data))
- return nil, types.NewOpenAIError(fmt.Errorf("openrouter response success=false"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- }
- err = common.Unmarshal(responseBody, &simpleResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
- return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
- }
- forceFormat := false
- if info.ChannelSetting.ForceFormat {
- forceFormat = true
- }
- usageModified := false
- if simpleResponse.Usage.PromptTokens == 0 {
- completionTokens := simpleResponse.Usage.CompletionTokens
- if completionTokens == 0 {
- for _, choice := range simpleResponse.Choices {
- ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
- completionTokens += ctkm
- }
- }
- simpleResponse.Usage = dto.Usage{
- PromptTokens: info.PromptTokens,
- CompletionTokens: completionTokens,
- TotalTokens: info.PromptTokens + completionTokens,
- }
- usageModified = true
- }
- applyUsagePostProcessing(info, &simpleResponse.Usage, responseBody)
- switch info.RelayFormat {
- case types.RelayFormatOpenAI:
- if usageModified {
- var bodyMap map[string]interface{}
- err = common.Unmarshal(responseBody, &bodyMap)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- bodyMap["usage"] = simpleResponse.Usage
- responseBody, _ = common.Marshal(bodyMap)
- }
- if forceFormat {
- responseBody, err = common.Marshal(simpleResponse)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- } else {
- break
- }
- case types.RelayFormatClaude:
- claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
- claudeRespStr, err := common.Marshal(claudeResp)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- responseBody = claudeRespStr
- case types.RelayFormatGemini:
- geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
- geminiRespStr, err := common.Marshal(geminiResp)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- responseBody = geminiRespStr
- }
- service.IOCopyBytesGracefully(c, resp, responseBody)
- return &simpleResponse.Usage, nil
- }
- func streamTTSResponse(c *gin.Context, resp *http.Response) {
- c.Writer.WriteHeaderNow()
- flusher, ok := c.Writer.(http.Flusher)
- if !ok {
- logger.LogWarn(c, "streaming not supported")
- _, err := io.Copy(c.Writer, resp.Body)
- if err != nil {
- logger.LogWarn(c, err.Error())
- }
- return
- }
- buffer := make([]byte, 4096)
- for {
- n, err := resp.Body.Read(buffer)
- //logger.LogInfo(c, fmt.Sprintf("streamTTSResponse read %d bytes", n))
- if n > 0 {
- if _, writeErr := c.Writer.Write(buffer[:n]); writeErr != nil {
- logger.LogError(c, writeErr.Error())
- break
- }
- flusher.Flush()
- }
- if err != nil {
- if err != io.EOF {
- logger.LogError(c, err.Error())
- }
- break
- }
- }
- }
- func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
- // the status code has been judged before, if there is a body reading failure,
- // it should be regarded as a non-recoverable error, so it should not return err for external retry.
- // Analogous to nginx's load balancing, it will only retry if it can't be requested or
- // if the upstream returns a specific status code, once the upstream has already written the header,
- // the subsequent failure of the response body should be regarded as a non-recoverable error,
- // and can be terminated directly.
- defer service.CloseResponseBodyGracefully(resp)
- usage := &dto.Usage{}
- usage.PromptTokens = info.PromptTokens
- usage.TotalTokens = info.PromptTokens
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- c.Writer.WriteHeader(resp.StatusCode)
- isStreaming := resp.ContentLength == -1 || resp.Header.Get("Content-Length") == ""
- if isStreaming {
- streamTTSResponse(c, resp)
- } else {
- c.Writer.WriteHeaderNow()
- _, err := io.Copy(c.Writer, resp.Body)
- if err != nil {
- logger.LogError(c, err.Error())
- }
- }
- return usage
- }
- func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
- defer service.CloseResponseBodyGracefully(resp)
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
- }
- // 写入新的 response body
- service.IOCopyBytesGracefully(c, resp, responseBody)
- var responseData struct {
- Usage *dto.Usage `json:"usage"`
- }
- if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
- if responseData.Usage.TotalTokens > 0 {
- usage := responseData.Usage
- if usage.PromptTokens == 0 {
- usage.PromptTokens = usage.InputTokens
- }
- if usage.CompletionTokens == 0 {
- usage.CompletionTokens = usage.OutputTokens
- }
- return nil, usage
- }
- }
- usage := &dto.Usage{}
- usage.PromptTokens = info.PromptTokens
- usage.CompletionTokens = 0
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
- return nil, usage
- }
- func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
- if info == nil || info.ClientWs == nil || info.TargetWs == nil {
- return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
- }
- info.IsStream = true
- clientConn := info.ClientWs
- targetConn := info.TargetWs
- clientClosed := make(chan struct{})
- targetClosed := make(chan struct{})
- sendChan := make(chan []byte, 100)
- receiveChan := make(chan []byte, 100)
- errChan := make(chan error, 2)
- usage := &dto.RealtimeUsage{}
- localUsage := &dto.RealtimeUsage{}
- sumUsage := &dto.RealtimeUsage{}
- gopool.Go(func() {
- defer func() {
- if r := recover(); r != nil {
- errChan <- fmt.Errorf("panic in client reader: %v", r)
- }
- }()
- for {
- select {
- case <-c.Done():
- return
- default:
- _, message, err := clientConn.ReadMessage()
- if err != nil {
- if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
- errChan <- fmt.Errorf("error reading from client: %v", err)
- }
- close(clientClosed)
- return
- }
- realtimeEvent := &dto.RealtimeEvent{}
- err = common.Unmarshal(message, realtimeEvent)
- if err != nil {
- errChan <- fmt.Errorf("error unmarshalling message: %v", err)
- return
- }
- if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
- if realtimeEvent.Session != nil {
- if realtimeEvent.Session.Tools != nil {
- info.RealtimeTools = realtimeEvent.Session.Tools
- }
- }
- }
- textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
- if err != nil {
- errChan <- fmt.Errorf("error counting text token: %v", err)
- return
- }
- logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
- localUsage.TotalTokens += textToken + audioToken
- localUsage.InputTokens += textToken + audioToken
- localUsage.InputTokenDetails.TextTokens += textToken
- localUsage.InputTokenDetails.AudioTokens += audioToken
- err = helper.WssString(c, targetConn, string(message))
- if err != nil {
- errChan <- fmt.Errorf("error writing to target: %v", err)
- return
- }
- select {
- case sendChan <- message:
- default:
- }
- }
- }
- })
- gopool.Go(func() {
- defer func() {
- if r := recover(); r != nil {
- errChan <- fmt.Errorf("panic in target reader: %v", r)
- }
- }()
- for {
- select {
- case <-c.Done():
- return
- default:
- _, message, err := targetConn.ReadMessage()
- if err != nil {
- if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
- errChan <- fmt.Errorf("error reading from target: %v", err)
- }
- close(targetClosed)
- return
- }
- info.SetFirstResponseTime()
- realtimeEvent := &dto.RealtimeEvent{}
- err = common.Unmarshal(message, realtimeEvent)
- if err != nil {
- errChan <- fmt.Errorf("error unmarshalling message: %v", err)
- return
- }
- if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
- realtimeUsage := realtimeEvent.Response.Usage
- if realtimeUsage != nil {
- usage.TotalTokens += realtimeUsage.TotalTokens
- usage.InputTokens += realtimeUsage.InputTokens
- usage.OutputTokens += realtimeUsage.OutputTokens
- usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
- usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
- usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
- usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
- usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
- err := preConsumeUsage(c, info, usage, sumUsage)
- if err != nil {
- errChan <- fmt.Errorf("error consume usage: %v", err)
- return
- }
- // 本次计费完成,清除
- usage = &dto.RealtimeUsage{}
- localUsage = &dto.RealtimeUsage{}
- } else {
- textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
- if err != nil {
- errChan <- fmt.Errorf("error counting text token: %v", err)
- return
- }
- logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
- localUsage.TotalTokens += textToken + audioToken
- info.IsFirstRequest = false
- localUsage.InputTokens += textToken + audioToken
- localUsage.InputTokenDetails.TextTokens += textToken
- localUsage.InputTokenDetails.AudioTokens += audioToken
- err = preConsumeUsage(c, info, localUsage, sumUsage)
- if err != nil {
- errChan <- fmt.Errorf("error consume usage: %v", err)
- return
- }
- // 本次计费完成,清除
- localUsage = &dto.RealtimeUsage{}
- // print now usage
- }
- logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
- logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
- logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
- } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
- realtimeSession := realtimeEvent.Session
- if realtimeSession != nil {
- // update audio format
- info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
- info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
- }
- } else {
- textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
- if err != nil {
- errChan <- fmt.Errorf("error counting text token: %v", err)
- return
- }
- logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
- localUsage.TotalTokens += textToken + audioToken
- localUsage.OutputTokens += textToken + audioToken
- localUsage.OutputTokenDetails.TextTokens += textToken
- localUsage.OutputTokenDetails.AudioTokens += audioToken
- }
- err = helper.WssString(c, clientConn, string(message))
- if err != nil {
- errChan <- fmt.Errorf("error writing to client: %v", err)
- return
- }
- select {
- case receiveChan <- message:
- default:
- }
- }
- }
- })
- select {
- case <-clientClosed:
- case <-targetClosed:
- case err := <-errChan:
- //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
- logger.LogError(c, "realtime error: "+err.Error())
- case <-c.Done():
- }
- if usage.TotalTokens != 0 {
- _ = preConsumeUsage(c, info, usage, sumUsage)
- }
- if localUsage.TotalTokens != 0 {
- _ = preConsumeUsage(c, info, localUsage, sumUsage)
- }
- // check usage total tokens, if 0, use local usage
- return nil, sumUsage
- }
- func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
- if usage == nil || totalUsage == nil {
- return fmt.Errorf("invalid usage pointer")
- }
- totalUsage.TotalTokens += usage.TotalTokens
- totalUsage.InputTokens += usage.InputTokens
- totalUsage.OutputTokens += usage.OutputTokens
- totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
- totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
- totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
- totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
- totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
- // clear usage
- err := service.PreWssConsumeQuota(ctx, info, usage)
- return err
- }
- func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- defer service.CloseResponseBodyGracefully(resp)
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
- }
- var usageResp dto.SimpleResponse
- err = common.Unmarshal(responseBody, &usageResp)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- // 写入新的 response body
- service.IOCopyBytesGracefully(c, resp, responseBody)
- // Once we've written to the client, we should not return errors anymore
- // because the upstream has already consumed resources and returned content
- // We should still perform billing even if parsing fails
- // format
- if usageResp.InputTokens > 0 {
- usageResp.PromptTokens += usageResp.InputTokens
- }
- if usageResp.OutputTokens > 0 {
- usageResp.CompletionTokens += usageResp.OutputTokens
- }
- if usageResp.InputTokensDetails != nil {
- usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
- usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
- }
- applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
- return &usageResp.Usage, nil
- }
- func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
- if info == nil || usage == nil {
- return
- }
- switch info.ChannelType {
- case constant.ChannelTypeDeepSeek:
- if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
- usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
- }
- case constant.ChannelTypeZhipu_v4:
- if usage.PromptTokensDetails.CachedTokens == 0 {
- if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
- usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
- } else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
- usage.PromptTokensDetails.CachedTokens = cachedTokens
- } else if usage.PromptCacheHitTokens > 0 {
- usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
- }
- }
- }
- }
// extractCachedTokensFromBody looks for a cached-token count inside the
// "usage" object of a JSON response body. It checks, in order,
// usage.prompt_tokens_details.cached_tokens, usage.cached_tokens and
// usage.prompt_cache_hit_tokens, and returns the first one that is present
// (an explicit 0 counts as present). The boolean is false when the body is
// empty, is not valid JSON for this shape, or contains none of the fields.
func extractCachedTokensFromBody(body []byte) (int, bool) {
	if len(body) == 0 {
		return 0, false
	}

	// Pointer fields distinguish "absent" from an explicit zero.
	var payload struct {
		Usage struct {
			PromptTokensDetails struct {
				CachedTokens *int `json:"cached_tokens"`
			} `json:"prompt_tokens_details"`
			CachedTokens         *int `json:"cached_tokens"`
			PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
		} `json:"usage"`
	}
	if json.Unmarshal(body, &payload) != nil {
		return 0, false
	}

	// Candidates in priority order.
	candidates := [...]*int{
		payload.Usage.PromptTokensDetails.CachedTokens,
		payload.Usage.CachedTokens,
		payload.Usage.PromptCacheHitTokens,
	}
	for _, candidate := range candidates {
		if candidate != nil {
			return *candidate, true
		}
	}
	return 0, false
}
|