|
|
@@ -8,9 +8,11 @@ import (
|
|
|
"net/http"
|
|
|
"one-api/common"
|
|
|
"one-api/dto"
|
|
|
+ "one-api/relay/channel/openai"
|
|
|
relaycommon "one-api/relay/common"
|
|
|
"one-api/relay/helper"
|
|
|
"one-api/service"
|
|
|
+ "strings"
|
|
|
)
|
|
|
|
|
|
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
|
|
|
@@ -34,6 +36,9 @@ func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage
|
|
|
|
|
|
func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
|
usage := &dto.Usage{}
|
|
|
+ var responseTextBuilder strings.Builder
|
|
|
+ var toolCount int
|
|
|
+ var containStreamUsage bool
|
|
|
|
|
|
helper.SetEventStreamHeaders(c)
|
|
|
|
|
|
@@ -47,12 +52,14 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|
|
|
|
|
// 把 xAI 的usage转换为 OpenAI 的usage
|
|
|
if xAIResp.Usage != nil {
|
|
|
+ containStreamUsage = true
|
|
|
usage.PromptTokens = xAIResp.Usage.PromptTokens
|
|
|
usage.TotalTokens = xAIResp.Usage.TotalTokens
|
|
|
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
|
|
}
|
|
|
|
|
|
openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
|
|
|
+ _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
|
|
|
err = helper.ObjectData(c, openaiResponse)
|
|
|
if err != nil {
|
|
|
common.SysError(err.Error())
|
|
|
@@ -60,6 +67,11 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|
|
return true
|
|
|
})
|
|
|
|
|
|
+ if !containStreamUsage {
|
|
|
+ usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
|
|
+ usage.CompletionTokens += toolCount * 7
|
|
|
+ }
|
|
|
+
|
|
|
helper.Done(c)
|
|
|
err := resp.Body.Close()
|
|
|
if err != nil {
|