relay_responses.go 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. package openai
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/dto"
  8. relaycommon "one-api/relay/common"
  9. "one-api/relay/helper"
  10. "one-api/service"
  11. "one-api/types"
  12. "strings"
  13. "github.com/gin-gonic/gin"
  14. )
  15. func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  16. defer common.CloseResponseBodyGracefully(resp)
  17. // read response body
  18. var responsesResponse dto.OpenAIResponsesResponse
  19. responseBody, err := io.ReadAll(resp.Body)
  20. if err != nil {
  21. return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
  22. }
  23. err = common.Unmarshal(responseBody, &responsesResponse)
  24. if err != nil {
  25. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  26. }
  27. if responsesResponse.Error != nil {
  28. return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode)
  29. }
  30. // 写入新的 response body
  31. common.IOCopyBytesGracefully(c, resp, responseBody)
  32. // compute usage
  33. usage := dto.Usage{}
  34. usage.PromptTokens = responsesResponse.Usage.InputTokens
  35. usage.CompletionTokens = responsesResponse.Usage.OutputTokens
  36. usage.TotalTokens = responsesResponse.Usage.TotalTokens
  37. // 解析 Tools 用量
  38. for _, tool := range responsesResponse.Tools {
  39. info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])].CallCount++
  40. }
  41. return &usage, nil
  42. }
  43. func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  44. if resp == nil || resp.Body == nil {
  45. common.LogError(c, "invalid response or response body")
  46. return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
  47. }
  48. var usage = &dto.Usage{}
  49. var responseTextBuilder strings.Builder
  50. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  51. // 检查当前数据是否包含 completed 状态和 usage 信息
  52. var streamResponse dto.ResponsesStreamResponse
  53. if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
  54. sendResponsesStreamData(c, streamResponse, data)
  55. switch streamResponse.Type {
  56. case "response.completed":
  57. usage.PromptTokens = streamResponse.Response.Usage.InputTokens
  58. usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
  59. usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
  60. case "response.output_text.delta":
  61. // 处理输出文本
  62. responseTextBuilder.WriteString(streamResponse.Delta)
  63. case dto.ResponsesOutputTypeItemDone:
  64. // 函数调用处理
  65. if streamResponse.Item != nil {
  66. switch streamResponse.Item.Type {
  67. case dto.BuildInCallWebSearchCall:
  68. info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview].CallCount++
  69. }
  70. }
  71. }
  72. }
  73. return true
  74. })
  75. if usage.CompletionTokens == 0 {
  76. // 计算输出文本的 token 数量
  77. tempStr := responseTextBuilder.String()
  78. if len(tempStr) > 0 {
  79. // 非正常结束,使用输出文本的 token 数量
  80. completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
  81. usage.CompletionTokens = completionTokens
  82. }
  83. }
  84. return usage, nil
  85. }