Procházet zdrojové kódy

Add Grok 4 Fast and Grok 4.1 Fast, plus xAI native tool calling (#9690)

* Add Grok 4 Fast and Grok 4.1 Fast

* Add native tool calling support
Matt Rubens před 1 měsícem
rodič
revize
ae29777062

+ 60 - 0
packages/types/src/providers/xai.ts

@@ -11,17 +11,71 @@ export const xaiModels = {
 		contextWindow: 262_144,
 		supportsImages: false,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 0.2,
 		outputPrice: 1.5,
 		cacheWritesPrice: 0.02,
 		cacheReadsPrice: 0.02,
 		description: "xAI's Grok Code Fast model with 256K context window",
 	},
+	"grok-4-1-fast-reasoning": {
+		maxTokens: 65_536,
+		contextWindow: 2_000_000,
+		supportsImages: true,
+		supportsPromptCache: true,
+		supportsNativeTools: true,
+		inputPrice: 0.2,
+		outputPrice: 0.5,
+		cacheWritesPrice: 0.05,
+		cacheReadsPrice: 0.05,
+		description:
+			"xAI's Grok 4.1 Fast model with 2M context window, optimized for high-performance agentic tool calling with reasoning",
+	},
+	"grok-4-1-fast-non-reasoning": {
+		maxTokens: 65_536,
+		contextWindow: 2_000_000,
+		supportsImages: true,
+		supportsPromptCache: true,
+		supportsNativeTools: true,
+		inputPrice: 0.2,
+		outputPrice: 0.5,
+		cacheWritesPrice: 0.05,
+		cacheReadsPrice: 0.05,
+		description:
+			"xAI's Grok 4.1 Fast model with 2M context window, optimized for high-performance agentic tool calling",
+	},
+	"grok-4-fast-reasoning": {
+		maxTokens: 65_536,
+		contextWindow: 2_000_000,
+		supportsImages: true,
+		supportsPromptCache: true,
+		supportsNativeTools: true,
+		inputPrice: 0.2,
+		outputPrice: 0.5,
+		cacheWritesPrice: 0.05,
+		cacheReadsPrice: 0.05,
+		description:
+			"xAI's Grok 4 Fast model with 2M context window, optimized for high-performance agentic tool calling with reasoning",
+	},
+	"grok-4-fast-non-reasoning": {
+		maxTokens: 65_536,
+		contextWindow: 2_000_000,
+		supportsImages: true,
+		supportsPromptCache: true,
+		supportsNativeTools: true,
+		inputPrice: 0.2,
+		outputPrice: 0.5,
+		cacheWritesPrice: 0.05,
+		cacheReadsPrice: 0.05,
+		description:
+			"xAI's Grok 4 Fast model with 2M context window, optimized for high-performance agentic tool calling",
+	},
 	"grok-4": {
 		maxTokens: 8192,
 		contextWindow: 256000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 3.0,
 		outputPrice: 15.0,
 		cacheWritesPrice: 0.75,
@@ -33,6 +87,7 @@ export const xaiModels = {
 		contextWindow: 131072,
 		supportsImages: false,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 3.0,
 		outputPrice: 15.0,
 		cacheWritesPrice: 0.75,
@@ -44,6 +99,7 @@ export const xaiModels = {
 		contextWindow: 131072,
 		supportsImages: false,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 5.0,
 		outputPrice: 25.0,
 		cacheWritesPrice: 1.25,
@@ -55,6 +111,7 @@ export const xaiModels = {
 		contextWindow: 131072,
 		supportsImages: false,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 0.3,
 		outputPrice: 0.5,
 		cacheWritesPrice: 0.07,
@@ -67,6 +124,7 @@ export const xaiModels = {
 		contextWindow: 131072,
 		supportsImages: false,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 0.6,
 		outputPrice: 4.0,
 		cacheWritesPrice: 0.15,
@@ -79,6 +137,7 @@ export const xaiModels = {
 		contextWindow: 131072,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 2.0,
 		outputPrice: 10.0,
 		description: "xAI's Grok-2 model (version 1212) with 128K context window",
@@ -88,6 +147,7 @@ export const xaiModels = {
 		contextWindow: 32768,
 		supportsImages: true,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 2.0,
 		outputPrice: 10.0,
 		description: "xAI's Grok-2 Vision model (version 1212) with image support and 32K context window",

+ 216 - 0
src/api/providers/__tests__/xai.spec.ts

@@ -280,4 +280,220 @@ describe("XAIHandler", () => {
 			}),
 		)
 	})
+
+	describe("Native Tool Calling", () => {
+		const testTools = [
+			{
+				type: "function" as const,
+				function: {
+					name: "test_tool",
+					description: "A test tool",
+					parameters: {
+						type: "object",
+						properties: {
+							arg1: { type: "string", description: "First argument" },
+						},
+						required: ["arg1"],
+					},
+				},
+			},
+		]
+
+		it("should include tools in request when model supports native tools and tools are provided", async () => {
+			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
+
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				}
+			})
+
+			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+			})
+			await messageGenerator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tools: expect.arrayContaining([
+						expect.objectContaining({
+							type: "function",
+							function: expect.objectContaining({
+								name: "test_tool",
+							}),
+						}),
+					]),
+					parallel_tool_calls: false,
+				}),
+			)
+		})
+
+		it("should include tool_choice when provided", async () => {
+			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
+
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				}
+			})
+
+			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+				tool_choice: "auto",
+			})
+			await messageGenerator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tool_choice: "auto",
+				}),
+			)
+		})
+
+		it("should not include tools when toolProtocol is xml", async () => {
+			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
+
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				}
+			})
+
+			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "xml",
+			})
+			await messageGenerator.next()
+
+			const callArgs = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0]
+			expect(callArgs).not.toHaveProperty("tools")
+			expect(callArgs).not.toHaveProperty("tool_choice")
+		})
+
+		it("should yield tool_call_partial chunks during streaming", async () => {
+			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
+
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						next: vi
+							.fn()
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {
+												tool_calls: [
+													{
+														index: 0,
+														id: "call_123",
+														function: {
+															name: "test_tool",
+															arguments: '{"arg1":',
+														},
+													},
+												],
+											},
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({
+								done: false,
+								value: {
+									choices: [
+										{
+											delta: {
+												tool_calls: [
+													{
+														index: 0,
+														function: {
+															arguments: '"value"}',
+														},
+													},
+												],
+											},
+										},
+									],
+								},
+							})
+							.mockResolvedValueOnce({ done: true }),
+					}),
+				}
+			})
+
+			const stream = handlerWithTools.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+			})
+
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks).toContainEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: "call_123",
+				name: "test_tool",
+				arguments: '{"arg1":',
+			})
+
+			expect(chunks).toContainEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: undefined,
+				name: undefined,
+				arguments: '"value"}',
+			})
+		})
+
+		it("should set parallel_tool_calls based on metadata", async () => {
+			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
+
+			mockCreate.mockImplementationOnce(() => {
+				return {
+					[Symbol.asyncIterator]: () => ({
+						async next() {
+							return { done: true }
+						},
+					}),
+				}
+			})
+
+			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
+				taskId: "test-task-id",
+				tools: testTools,
+				toolProtocol: "native",
+				parallelToolCalls: true,
+			})
+			await messageGenerator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					parallel_tool_calls: true,
+				}),
+			)
+		})
+	})
 })

+ 21 - 0
src/api/providers/xai.ts

@@ -52,6 +52,11 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
 	): ApiStream {
 		const { id: modelId, info: modelInfo, reasoning } = this.getModel()
 
+		// Check if model supports native tools and tools are provided with native protocol
+		const supportsNativeTools = modelInfo.supportsNativeTools ?? false
+		const useNativeTools =
+			supportsNativeTools && metadata?.tools && metadata.tools.length > 0 && metadata?.toolProtocol !== "xml"
+
 		// Use the OpenAI-compatible API.
 		let stream
 		try {
@@ -63,6 +68,9 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
 				stream: true,
 				stream_options: { include_usage: true },
 				...(reasoning && reasoning),
+				...(useNativeTools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+				...(useNativeTools && metadata.tool_choice && { tool_choice: metadata.tool_choice }),
+				...(useNativeTools && { parallel_tool_calls: metadata?.parallelToolCalls ?? false }),
 			})
 		} catch (error) {
 			throw handleOpenAIError(error, this.providerName)
@@ -85,6 +93,19 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler
 				}
 			}
 
+			// Handle tool calls in stream - emit partial chunks for NativeToolCallParser
+			if (delta?.tool_calls) {
+				for (const toolCall of delta.tool_calls) {
+					yield {
+						type: "tool_call_partial",
+						index: toolCall.index,
+						id: toolCall.id,
+						name: toolCall.function?.name,
+						arguments: toolCall.function?.arguments,
+					}
+				}
+			}
+
 			if (chunk.usage) {
 				// Extract detailed token information if available
 				// First check for prompt_tokens_details structure (real API response)