Browse source code

Make Grok3 streaming work with OpenAI Compatible (#2449)

amittell 8 months ago
parent
commit
1445bb0a3c
2 changed files with 51 additions and 2 deletions
  1. src/api/providers/__tests__/openai.test.ts (+40 −0)
  2. src/api/providers/openai.ts (+11 −2)
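
Context for the change: xAI's OpenAI-compatible Grok endpoint rejects requests that include the stream_options field, so the handler now detects an x.ai base URL and omits that field when streaming. A minimal standalone sketch of the conditional-spread pattern the diff applies (buildStreamParams and the isGrokXAI flag are illustrative names, not repo code):

	// Sketch only: mirrors the conditional-spread pattern used in the diff below.
	import OpenAI from "openai"

	function buildStreamParams(
		modelId: string,
		messages: OpenAI.Chat.Completions.ChatCompletionMessageParam[],
		isGrokXAI: boolean,
	): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming {
		return {
			model: modelId,
			messages,
			stream: true as const,
			// xAI rejects stream_options, so include it only for providers
			// that accept it (assumption drawn from this commit).
			...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
		}
	}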

+ 40 - 0
src/api/providers/__tests__/openai.test.ts

@@ -352,4 +352,44 @@ describe("OpenAiHandler", () => {
 			)
 		})
 	})
+
+	describe("Grok xAI Provider", () => {
+		const grokOptions = {
+			...mockOptions,
+			openAiBaseUrl: "https://api.x.ai/v1",
+			openAiModelId: "grok-1",
+		}
+
+		it("should initialize with Grok xAI configuration", () => {
+			const grokHandler = new OpenAiHandler(grokOptions)
+			expect(grokHandler).toBeInstanceOf(OpenAiHandler)
+			expect(grokHandler.getModel().id).toBe(grokOptions.openAiModelId)
+		})
+
+		it("should exclude stream_options when streaming with Grok xAI", async () => {
+			const grokHandler = new OpenAiHandler(grokOptions)
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			const stream = grokHandler.createMessage(systemPrompt, messages)
+			await stream.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: grokOptions.openAiModelId,
+					stream: true,
+				}),
+				{},
+			)
+
+			const mockCalls = mockCreate.mock.calls
+			const lastCall = mockCalls[mockCalls.length - 1]
+			expect(lastCall[0]).not.toHaveProperty("stream_options")
+		})
+	})
 })
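
The new test leans on a mockCreate spy that is defined earlier in openai.test.ts and is therefore not visible in this diff. A minimal sketch of how such a mock is typically wired up with Jest (the module shape and the chunk payload are assumptions, not the file's actual setup):

	// Sketch only: one plausible Jest wiring for the mockCreate spy the test uses.
	const mockCreate = jest.fn((..._args: unknown[]) =>
		(async function* () {
			// A single minimal chunk is enough for the test's `await stream.next()`.
			yield { choices: [{ delta: { content: "Hi" } }] }
		})(),
	)

	jest.mock("openai", () => ({
		__esModule: true,
		default: jest.fn().mockImplementation(() => ({
			chat: { completions: { create: (...args: unknown[]) => mockCreate(...args) } },
		})),
	}))

With a mock of this shape, the assertion expect(lastCall[0]).not.toHaveProperty("stream_options") inspects the first argument passed to create, which is exactly the request body the handler built.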

+ 11 - 2
src/api/providers/openai.ts

@@ -137,12 +137,14 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				}
 			}
 
+			const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
+
 			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
 				model: modelId,
 				temperature: this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
 				messages: convertedMessages,
 				stream: true as const,
-				stream_options: { include_usage: true },
+				...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
 			}
 			if (this.options.includeMaxTokens) {
 				requestOptions.max_tokens = modelInfo.maxTokens
@@ -265,6 +267,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		if (this.options.openAiStreamingEnabled ?? true) {
 			const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
 
+			const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
+
 			const stream = await this.client.chat.completions.create(
 				{
 					model: modelId,
@@ -276,7 +280,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 						...convertToOpenAiMessages(messages),
 					],
 					stream: true,
-					stream_options: { include_usage: true },
+					...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
 					reasoning_effort: this.getModel().info.reasoningEffort,
 				},
 				methodIsAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
@@ -337,6 +341,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		}
 	}
 
+	private _isGrokXAI(baseUrl?: string): boolean {
+		const urlHost = this._getUrlHost(baseUrl)
+		return urlHost.includes("x.ai")
+	}
+
 	private _isAzureAiInference(baseUrl?: string): boolean {
 		const urlHost = this._getUrlHost(baseUrl)
 		return urlHost.endsWith(".services.ai.azure.com")
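
Both the new _isGrokXAI and the existing _isAzureAiInference rely on a private _getUrlHost helper that already lives in openai.ts outside this diff. A plausible reconstruction of what that helper does, not the committed implementation (the empty-string fallback is an assumption):

	// Sketch only: reconstruction of the _getUrlHost helper referenced above.
	function getUrlHost(baseUrl?: string): string {
		try {
			// URL#host is the hostname plus any explicit port.
			return new URL(baseUrl ?? "").host
		} catch {
			// Assumption: a missing or unparsable base URL yields an empty
			// host, so both provider checks simply return false.
			return ""
		}
	}

One design note: _isGrokXAI uses a loose substring match (urlHost.includes("x.ai")), which would also match unrelated hosts such as notx.ai, whereas the Azure check uses the stricter endsWith. The substring form does have the advantage of catching both api.x.ai and any other x.ai subdomain.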