text.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. package xai
  2. import (
  3. "encoding/json"
  4. "io"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/dto"
  8. "one-api/relay/channel/openai"
  9. relaycommon "one-api/relay/common"
  10. "one-api/relay/helper"
  11. "one-api/service"
  12. "one-api/types"
  13. "strings"
  14. "github.com/gin-gonic/gin"
  15. )
  16. func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
  17. if xAIResp == nil {
  18. return nil
  19. }
  20. if xAIResp.Usage != nil {
  21. xAIResp.Usage.CompletionTokens = usage.CompletionTokens
  22. }
  23. openAIResp := &dto.ChatCompletionsStreamResponse{
  24. Id: xAIResp.Id,
  25. Object: xAIResp.Object,
  26. Created: xAIResp.Created,
  27. Model: xAIResp.Model,
  28. Choices: xAIResp.Choices,
  29. Usage: xAIResp.Usage,
  30. }
  31. return openAIResp
  32. }
  33. func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  34. usage := &dto.Usage{}
  35. var responseTextBuilder strings.Builder
  36. var toolCount int
  37. var containStreamUsage bool
  38. helper.SetEventStreamHeaders(c)
  39. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  40. var xAIResp *dto.ChatCompletionsStreamResponse
  41. err := json.Unmarshal([]byte(data), &xAIResp)
  42. if err != nil {
  43. common.SysError("error unmarshalling stream response: " + err.Error())
  44. return true
  45. }
  46. // 把 xAI 的usage转换为 OpenAI 的usage
  47. if xAIResp.Usage != nil {
  48. containStreamUsage = true
  49. usage.PromptTokens = xAIResp.Usage.PromptTokens
  50. usage.TotalTokens = xAIResp.Usage.TotalTokens
  51. usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
  52. }
  53. openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
  54. _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
  55. err = helper.ObjectData(c, openaiResponse)
  56. if err != nil {
  57. common.SysError(err.Error())
  58. }
  59. return true
  60. })
  61. if !containStreamUsage {
  62. usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
  63. usage.CompletionTokens += toolCount * 7
  64. }
  65. helper.Done(c)
  66. common.CloseResponseBodyGracefully(resp)
  67. return usage, nil
  68. }
  69. func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  70. defer common.CloseResponseBodyGracefully(resp)
  71. responseBody, err := io.ReadAll(resp.Body)
  72. var response *dto.SimpleResponse
  73. err = common.Unmarshal(responseBody, &response)
  74. if err != nil {
  75. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  76. }
  77. response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
  78. response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
  79. // new body
  80. encodeJson, err := common.Marshal(response)
  81. if err != nil {
  82. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  83. }
  84. common.IOCopyBytesGracefully(c, resp, encodeJson)
  85. return &response.Usage, nil
  86. }