| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- package xai
- import (
- "encoding/json"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/relay/channel/openai"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/types"
- "strings"
- "github.com/gin-gonic/gin"
- )
- func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
- if xAIResp == nil {
- return nil
- }
- if xAIResp.Usage != nil {
- xAIResp.Usage.CompletionTokens = usage.CompletionTokens
- }
- openAIResp := &dto.ChatCompletionsStreamResponse{
- Id: xAIResp.Id,
- Object: xAIResp.Object,
- Created: xAIResp.Created,
- Model: xAIResp.Model,
- Choices: xAIResp.Choices,
- Usage: xAIResp.Usage,
- }
- return openAIResp
- }
- func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- usage := &dto.Usage{}
- var responseTextBuilder strings.Builder
- var toolCount int
- var containStreamUsage bool
- helper.SetEventStreamHeaders(c)
- helper.StreamScannerHandler(c, resp, info, func(data string) bool {
- var xAIResp *dto.ChatCompletionsStreamResponse
- err := json.Unmarshal([]byte(data), &xAIResp)
- if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
- return true
- }
- // 把 xAI 的usage转换为 OpenAI 的usage
- if xAIResp.Usage != nil {
- containStreamUsage = true
- usage.PromptTokens = xAIResp.Usage.PromptTokens
- usage.TotalTokens = xAIResp.Usage.TotalTokens
- usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
- }
- openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
- _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
- err = helper.ObjectData(c, openaiResponse)
- if err != nil {
- common.SysError(err.Error())
- }
- return true
- })
- if !containStreamUsage {
- usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
- usage.CompletionTokens += toolCount * 7
- }
- helper.Done(c)
- common.CloseResponseBodyGracefully(resp)
- return usage, nil
- }
- func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- defer common.CloseResponseBodyGracefully(resp)
- responseBody, err := io.ReadAll(resp.Body)
- var response *dto.SimpleResponse
- err = common.Unmarshal(responseBody, &response)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
- response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
- // new body
- encodeJson, err := common.Marshal(response)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
- }
- common.IOCopyBytesGracefully(c, resp, encodeJson)
- return &response.Usage, nil
- }
|