package openai

import (
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/logger"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/relay/helper"
	"github.com/QuantumNous/new-api/service"
	"github.com/QuantumNous/new-api/types"
	"github.com/gin-gonic/gin"
)
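
// OaiResponsesHandler handles a non-streaming OpenAI Responses API response:
// it relays the upstream body to the client, extracts token usage, and records
// image-generation and built-in tool calls for accounting.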
func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	defer service.CloseResponseBodyGracefully(resp)

	// read response body
	var responsesResponse dto.OpenAIResponsesResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
	}
	err = common.Unmarshal(responseBody, &responsesResponse)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}
	if oaiError := responsesResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
		return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
	}

	if responsesResponse.HasImageGenerationCall() {
		c.Set("image_generation_call", true)
		c.Set("image_generation_call_quality", responsesResponse.GetQuality())
		c.Set("image_generation_call_size", responsesResponse.GetSize())
	}
	// write the new response body back to the client
	service.IOCopyBytesGracefully(c, resp, responseBody)
	// compute usage
	usage := dto.Usage{}
	if responsesResponse.Usage != nil {
		usage.PromptTokens = responsesResponse.Usage.InputTokens
		usage.CompletionTokens = responsesResponse.Usage.OutputTokens
		usage.TotalTokens = responsesResponse.Usage.TotalTokens
		if responsesResponse.Usage.InputTokensDetails != nil {
			usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
		}
	}
	if info == nil || info.ResponsesUsageInfo == nil || info.ResponsesUsageInfo.BuiltInTools == nil {
		return &usage, nil
	}
	// count built-in tool usage
	for _, tool := range responsesResponse.Tools {
		builtInToolInfo, ok := info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])]
		if !ok || builtInToolInfo == nil {
			logger.LogError(c, fmt.Sprintf("BuiltInTools not found for tool type: %v", tool["type"]))
			continue
		}
		builtInToolInfo.CallCount++
	}
	return &usage, nil
}
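
// OaiResponsesStreamHandler handles a streaming OpenAI Responses API response:
// it forwards each stream event to the client, accumulates output text deltas,
// and derives token usage from the final "response.completed" event, falling
// back to counting the accumulated text when the upstream reports no usage.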
func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	if resp == nil || resp.Body == nil {
		logger.LogError(c, "invalid response or response body")
		return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
	}
	defer service.CloseResponseBodyGracefully(resp)

	var usage = &dto.Usage{}
	var responseTextBuilder strings.Builder

	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
		// check whether the current chunk carries a completed status and usage information
		var streamResponse dto.ResponsesStreamResponse
		if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
			sendResponsesStreamData(c, streamResponse, data)
			switch streamResponse.Type {
			case "response.completed":
				if streamResponse.Response != nil {
					if streamResponse.Response.Usage != nil {
						if streamResponse.Response.Usage.InputTokens != 0 {
							usage.PromptTokens = streamResponse.Response.Usage.InputTokens
						}
						if streamResponse.Response.Usage.OutputTokens != 0 {
							usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
						}
						if streamResponse.Response.Usage.TotalTokens != 0 {
							usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
						}
						if streamResponse.Response.Usage.InputTokensDetails != nil {
							usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
						}
					}
					if streamResponse.Response.HasImageGenerationCall() {
						c.Set("image_generation_call", true)
						c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
						c.Set("image_generation_call_size", streamResponse.Response.GetSize())
					}
				}
			case "response.output_text.delta":
				// accumulate the output text delta
				responseTextBuilder.WriteString(streamResponse.Delta)
			case dto.ResponsesOutputTypeItemDone:
				// handle built-in tool calls
				if streamResponse.Item != nil {
					switch streamResponse.Item.Type {
					case dto.BuildInCallWebSearchCall:
						if info != nil && info.ResponsesUsageInfo != nil && info.ResponsesUsageInfo.BuiltInTools != nil {
							if webSearchTool, exists := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool != nil {
								webSearchTool.CallCount++
							}
						}
					}
				}
			}
		} else {
			logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
		}
		return true
	})
	if usage.CompletionTokens == 0 {
		// count the tokens of the accumulated output text
		tempStr := responseTextBuilder.String()
		if len(tempStr) > 0 {
			// the stream did not finish normally, so fall back to counting the output text tokens
			completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
			usage.CompletionTokens = completionTokens
		}
	}
	if usage.PromptTokens == 0 && usage.CompletionTokens != 0 {
		usage.PromptTokens = info.PromptTokens
	}
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	return usage, nil
}