Ver Fonte

fix: xAI usage

CaIon há 8 meses atrás
pai
commit
ef8ae4db80

+ 6 - 7
relay/channel/openai/helper.go

@@ -41,12 +41,7 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
 	return nil
 }
 
-func processStreamResponse(item string, responseTextBuilder *strings.Builder, toolCount *int) error {
-	var streamResponse dto.ChatCompletionsStreamResponse
-	if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
-		return err
-	}
-
+func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
 	for _, choice := range streamResponse.Choices {
 		responseTextBuilder.WriteString(choice.Delta.GetContentString())
 		responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
@@ -81,7 +76,11 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex
 		// 一次性解析失败,逐个解析
 		common.SysError("error unmarshalling stream response: " + err.Error())
 		for _, item := range streamItems {
-			if err := processStreamResponse(item, responseTextBuilder, toolCount); err != nil {
+			var streamResponse dto.ChatCompletionsStreamResponse
+			if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
+				return err
+			}
+			if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
 				common.SysError("error processing stream response: " + err.Error())
 			}
 		}

+ 1 - 2
relay/channel/openai/relay-openai.go

@@ -117,6 +117,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	model := info.UpstreamModelName
 
 	var responseTextBuilder strings.Builder
+	var toolCount int
 	var usage = &dto.Usage{}
 	var streamItems []string // store stream items
 	var forceFormat bool
@@ -130,8 +131,6 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 		thinkToContent = think2Content
 	}
 
-	toolCount := 0
-
 	var (
 		lastStreamData string
 	)

+ 0 - 1
relay/channel/xai/adaptor.go

@@ -48,7 +48,6 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	request.StreamOptions = nil
 	if strings.HasPrefix(request.Model, "grok-3-mini") {
 		if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
 			request.MaxCompletionTokens = request.MaxTokens

+ 12 - 0
relay/channel/xai/text.go

@@ -8,9 +8,11 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
+	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/service"
+	"strings"
 )
 
 func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
@@ -34,6 +36,9 @@ func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage
 
 func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	usage := &dto.Usage{}
+	var responseTextBuilder strings.Builder
+	var toolCount int
+	var containStreamUsage bool
 
 	helper.SetEventStreamHeaders(c)
 
@@ -47,12 +52,14 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 
 		// 把 xAI 的usage转换为 OpenAI 的usage
 		if xAIResp.Usage != nil {
+			containStreamUsage = true
 			usage.PromptTokens = xAIResp.Usage.PromptTokens
 			usage.TotalTokens = xAIResp.Usage.TotalTokens
 			usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
 		}
 
 		openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
+		_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
 		err = helper.ObjectData(c, openaiResponse)
 		if err != nil {
 			common.SysError(err.Error())
@@ -60,6 +67,11 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 		return true
 	})
 
+	if !containStreamUsage {
+		usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+		usage.CompletionTokens += toolCount * 7
+	}
+
 	helper.Done(c)
 	err := resp.Body.Close()
 	if err != nil {

+ 1 - 0
relay/common/relay_info.go

@@ -102,6 +102,7 @@ var streamSupportedChannels = map[int]bool{
 	common.ChannelTypeAzure:      true,
 	common.ChannelTypeVolcEngine: true,
 	common.ChannelTypeOllama:     true,
+	common.ChannelTypeXai:        true,
 }
 
 func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {