relay_responses.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. package openai
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "strings"
  7. "github.com/QuantumNous/new-api/common"
  8. "github.com/QuantumNous/new-api/dto"
  9. "github.com/QuantumNous/new-api/logger"
  10. relaycommon "github.com/QuantumNous/new-api/relay/common"
  11. "github.com/QuantumNous/new-api/relay/helper"
  12. "github.com/QuantumNous/new-api/service"
  13. "github.com/QuantumNous/new-api/types"
  14. "github.com/gin-gonic/gin"
  15. )
  16. func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  17. defer service.CloseResponseBodyGracefully(resp)
  18. // read response body
  19. var responsesResponse dto.OpenAIResponsesResponse
  20. responseBody, err := io.ReadAll(resp.Body)
  21. if err != nil {
  22. return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
  23. }
  24. err = common.Unmarshal(responseBody, &responsesResponse)
  25. if err != nil {
  26. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  27. }
  28. if oaiError := responsesResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
  29. return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
  30. }
  31. if responsesResponse.HasImageGenerationCall() {
  32. c.Set("image_generation_call", true)
  33. c.Set("image_generation_call_quality", responsesResponse.GetQuality())
  34. c.Set("image_generation_call_size", responsesResponse.GetSize())
  35. }
  36. // 写入新的 response body
  37. service.IOCopyBytesGracefully(c, resp, responseBody)
  38. // compute usage
  39. usage := dto.Usage{}
  40. if responsesResponse.Usage != nil {
  41. usage.PromptTokens = responsesResponse.Usage.InputTokens
  42. usage.CompletionTokens = responsesResponse.Usage.OutputTokens
  43. usage.TotalTokens = responsesResponse.Usage.TotalTokens
  44. if responsesResponse.Usage.InputTokensDetails != nil {
  45. usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
  46. }
  47. }
  48. if info == nil || info.ResponsesUsageInfo == nil || info.ResponsesUsageInfo.BuiltInTools == nil {
  49. return &usage, nil
  50. }
  51. // 解析 Tools 用量
  52. for _, tool := range responsesResponse.Tools {
  53. buildToolinfo, ok := info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])]
  54. if !ok || buildToolinfo == nil {
  55. logger.LogError(c, fmt.Sprintf("BuiltInTools not found for tool type: %v", tool["type"]))
  56. continue
  57. }
  58. buildToolinfo.CallCount++
  59. }
  60. return &usage, nil
  61. }
  62. func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  63. if resp == nil || resp.Body == nil {
  64. logger.LogError(c, "invalid response or response body")
  65. return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
  66. }
  67. defer service.CloseResponseBodyGracefully(resp)
  68. var usage = &dto.Usage{}
  69. var responseTextBuilder strings.Builder
  70. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  71. // 检查当前数据是否包含 completed 状态和 usage 信息
  72. var streamResponse dto.ResponsesStreamResponse
  73. if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
  74. sendResponsesStreamData(c, streamResponse, data)
  75. switch streamResponse.Type {
  76. case "response.completed":
  77. if streamResponse.Response != nil {
  78. if streamResponse.Response.Usage != nil {
  79. if streamResponse.Response.Usage.InputTokens != 0 {
  80. usage.PromptTokens = streamResponse.Response.Usage.InputTokens
  81. }
  82. if streamResponse.Response.Usage.OutputTokens != 0 {
  83. usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
  84. }
  85. if streamResponse.Response.Usage.TotalTokens != 0 {
  86. usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
  87. }
  88. if streamResponse.Response.Usage.InputTokensDetails != nil {
  89. usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
  90. }
  91. }
  92. if streamResponse.Response.HasImageGenerationCall() {
  93. c.Set("image_generation_call", true)
  94. c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
  95. c.Set("image_generation_call_size", streamResponse.Response.GetSize())
  96. }
  97. }
  98. case "response.output_text.delta":
  99. // 处理输出文本
  100. responseTextBuilder.WriteString(streamResponse.Delta)
  101. case dto.ResponsesOutputTypeItemDone:
  102. // 函数调用处理
  103. if streamResponse.Item != nil {
  104. switch streamResponse.Item.Type {
  105. case dto.BuildInCallWebSearchCall:
  106. if info != nil && info.ResponsesUsageInfo != nil && info.ResponsesUsageInfo.BuiltInTools != nil {
  107. if webSearchTool, exists := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool != nil {
  108. webSearchTool.CallCount++
  109. }
  110. }
  111. }
  112. }
  113. }
  114. } else {
  115. logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
  116. }
  117. return true
  118. })
  119. if usage.CompletionTokens == 0 {
  120. // 计算输出文本的 token 数量
  121. tempStr := responseTextBuilder.String()
  122. if len(tempStr) > 0 {
  123. // 非正常结束,使用输出文本的 token 数量
  124. completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
  125. usage.CompletionTokens = completionTokens
  126. }
  127. }
  128. if usage.PromptTokens == 0 && usage.CompletionTokens != 0 {
  129. usage.PromptTokens = info.PromptTokens
  130. }
  131. usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
  132. return usage, nil
  133. }