Jelajahi Sumber

feat(xai): 为xAI渠道添加/v1/responses支持 (#2897)

* feat(xai): 为xAI渠道添加/v1/responses支持

* Add video generation model to constants

* fix: 修正先前更改中对于grok-3-mini的思考预算和"-search"设计
funkpopo 1 Minggu lalu
induk
melakukan
75e533edb0

+ 2 - 0
common/endpoint_type.go

@@ -26,6 +26,8 @@ func GetEndpointTypesByChannelType(channelType int, modelName string) []constant
 		endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
 	case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
 		endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
+	case constant.ChannelTypeXai:
+		endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI, constant.EndpointTypeOpenAIResponse}
 	case constant.ChannelTypeSora:
 		endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIVideo}
 	default:

+ 10 - 5
relay/channel/xai/adaptor.go

@@ -83,9 +83,6 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 		} else if strings.HasSuffix(request.Model, "-low") {
 			request.ReasoningEffort = "low"
 			request.Model = strings.TrimSuffix(request.Model, "-low")
-		} else if strings.HasSuffix(request.Model, "-medium") {
-			request.ReasoningEffort = "medium"
-			request.Model = strings.TrimSuffix(request.Model, "-medium")
 		}
 		info.ReasoningEffort = request.ReasoningEffort
 		info.UpstreamModelName = request.Model
@@ -103,8 +100,10 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 }
 
 func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
-	// TODO implement me
-	return nil, errors.New("not implemented")
+	if request.Model == "" && info != nil {
+		request.Model = info.UpstreamModelName
+	}
+	return request, nil
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
@@ -115,6 +114,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	switch info.RelayMode {
 	case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
 		usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
+	case constant.RelayModeResponses:
+		if info.IsStream {
+			usage, err = openai.OaiResponsesStreamHandler(c, info, resp)
+		} else {
+			usage, err = openai.OaiResponsesHandler(c, info, resp)
+		}
 	default:
 		if info.IsStream {
 			usage, err = xAIStreamHandler(c, info, resp)

+ 26 - 14
relay/channel/xai/constants.go

@@ -1,20 +1,32 @@
 package xai
 
 var ModelList = []string{
-	// grok-4
-	"grok-4", "grok-4-0709", "grok-4-0709-search",
-	// grok-3
-	"grok-3-beta", "grok-3-mini-beta",
-	// grok-3 mini
-	"grok-3-fast-beta", "grok-3-mini-fast-beta",
-	// extend grok-3-mini reasoning
-	"grok-3-mini-beta-high", "grok-3-mini-beta-low", "grok-3-mini-beta-medium",
-	"grok-3-mini-fast-beta-high", "grok-3-mini-fast-beta-low", "grok-3-mini-fast-beta-medium",
-	// image model
-	"grok-2-image",
-	// legacy models
-	"grok-2", "grok-2-vision",
-	"grok-beta", "grok-vision-beta",
+	// language models
+	"grok-4-1-fast-reasoning",
+	"grok-4-1-fast-non-reasoning",
+	"grok-code-fast-1",
+	"grok-4-fast-reasoning",
+	"grok-4-fast-non-reasoning",
+	"grok-4-0709",
+	"grok-3-mini",
+	"grok-3",
+	"grok-2-vision-1212",
+	// search variants
+	"grok-4-1-fast-reasoning-search",
+	"grok-4-1-fast-non-reasoning-search",
+	"grok-4-fast-reasoning-search",
+	"grok-4-fast-non-reasoning-search",
+	"grok-4-0709-search",
+	"grok-3-mini-search",
+	"grok-3-search",
+	// grok-3-mini reasoning effort variants
+	"grok-3-mini-high", "grok-3-mini-low",
+	// image generation models
+	"grok-imagine-image-pro",
+	"grok-imagine-image",
+	"grok-2-image-1212",
+	// video generation model
+	"grok-imagine-video",
 }
 
 var ChannelName = "xai"

+ 1 - 2
relay/channel/xai/text.go

@@ -1,7 +1,6 @@
 package xai
 
 import (
-	"encoding/json"
 	"io"
 	"net/http"
 	"strings"
@@ -46,7 +45,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var xAIResp *dto.ChatCompletionsStreamResponse
-		err := json.Unmarshal([]byte(data), &xAIResp)
+		err := common.UnmarshalJsonStr(data, &xAIResp)
 		if err != nil {
 			common.SysLog("error unmarshalling stream response: " + err.Error())
 			return true