Browse Source

feat(mistral): add native tool calling support (#9625)

Hannes Rudolph 1 month ago
parent
commit
240bc0b6b1

+ 18 - 9
packages/types/src/providers/mistral.ts

@@ -11,73 +11,82 @@ export const mistralModels = {
 		contextWindow: 128_000,
 		supportsImages: true,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 2.0,
 		outputPrice: 5.0,
 	},
 	"devstral-medium-latest": {
-		maxTokens: 131_000,
+		maxTokens: 8192,
 		contextWindow: 131_000,
 		supportsImages: true,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0.4,
 		outputPrice: 2.0,
 	},
 	"mistral-medium-latest": {
-		maxTokens: 131_000,
+		maxTokens: 8192,
 		contextWindow: 131_000,
 		supportsImages: true,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0.4,
 		outputPrice: 2.0,
 	},
 	"codestral-latest": {
-		maxTokens: 256_000,
+		maxTokens: 8192,
 		contextWindow: 256_000,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0.3,
 		outputPrice: 0.9,
 	},
 	"mistral-large-latest": {
-		maxTokens: 131_000,
+		maxTokens: 8192,
 		contextWindow: 131_000,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 2.0,
 		outputPrice: 6.0,
 	},
 	"ministral-8b-latest": {
-		maxTokens: 131_000,
+		maxTokens: 8192,
 		contextWindow: 131_000,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0.1,
 		outputPrice: 0.1,
 	},
 	"ministral-3b-latest": {
-		maxTokens: 131_000,
+		maxTokens: 8192,
 		contextWindow: 131_000,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0.04,
 		outputPrice: 0.04,
 	},
 	"mistral-small-latest": {
-		maxTokens: 32_000,
+		maxTokens: 8192,
 		contextWindow: 32_000,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0.2,
 		outputPrice: 0.6,
 	},
 	"pixtral-large-latest": {
-		maxTokens: 131_000,
+		maxTokens: 8192,
 		contextWindow: 131_000,
 		supportsImages: true,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 2.0,
 		outputPrice: 6.0,
 	},
 } as const satisfies Record<string, ModelInfo>
 
-export const MISTRAL_DEFAULT_TEMPERATURE = 0
+export const MISTRAL_DEFAULT_TEMPERATURE = 1

+ 220 - 1
src/api/providers/__tests__/mistral.spec.ts

@@ -39,9 +39,11 @@ vi.mock("@mistralai/mistralai", () => {
 })
 
 import type { Anthropic } from "@anthropic-ai/sdk"
+import type OpenAI from "openai"
 import { MistralHandler } from "../mistral"
 import type { ApiHandlerOptions } from "../../../shared/api"
-import type { ApiStreamTextChunk, ApiStreamReasoningChunk } from "../../transform/stream"
+import type { ApiHandlerCreateMessageMetadata } from "../../index"
+import type { ApiStreamTextChunk, ApiStreamReasoningChunk, ApiStreamToolCallPartialChunk } from "../../transform/stream"
 
 describe("MistralHandler", () => {
 	let handler: MistralHandler
@@ -223,6 +225,223 @@ describe("MistralHandler", () => {
 		})
 	})
 
+	describe("native tool calling", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [{ type: "text", text: "What's the weather?" }],
+			},
+		]
+
+		const mockTools: OpenAI.Chat.ChatCompletionTool[] = [
+			{
+				type: "function",
+				function: {
+					name: "get_weather",
+					description: "Get the current weather",
+					parameters: {
+						type: "object",
+						properties: {
+							location: { type: "string" },
+						},
+						required: ["location"],
+					},
+				},
+			},
+		]
+
+		it("should include tools in request when toolProtocol is native", async () => {
+			const metadata: ApiHandlerCreateMessageMetadata = {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+			}
+
+			const iterator = handler.createMessage(systemPrompt, messages, metadata)
+			await iterator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tools: expect.arrayContaining([
+						expect.objectContaining({
+							type: "function",
+							function: expect.objectContaining({
+								name: "get_weather",
+								description: "Get the current weather",
+								parameters: expect.any(Object),
+							}),
+						}),
+					]),
+					toolChoice: "any",
+				}),
+			)
+		})
+
+		it("should not include tools when toolProtocol is xml", async () => {
+			const metadata: ApiHandlerCreateMessageMetadata = {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "xml",
+			}
+
+			const iterator = handler.createMessage(systemPrompt, messages, metadata)
+			await iterator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					tools: expect.anything(),
+				}),
+			)
+		})
+
+		it("should handle tool calls in streaming response", async () => {
+			// Mock stream with tool calls
+			mockCreate.mockImplementationOnce(async (_options) => {
+				const stream = {
+					[Symbol.asyncIterator]: async function* () {
+						yield {
+							data: {
+								choices: [
+									{
+										delta: {
+											toolCalls: [
+												{
+													id: "call_123",
+													type: "function",
+													function: {
+														name: "get_weather",
+														arguments: '{"location":"New York"}',
+													},
+												},
+											],
+										},
+										index: 0,
+									},
+								],
+							},
+						}
+					},
+				}
+				return stream
+			})
+
+			const metadata: ApiHandlerCreateMessageMetadata = {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+			}
+
+			const iterator = handler.createMessage(systemPrompt, messages, metadata)
+			const results: ApiStreamToolCallPartialChunk[] = []
+
+			for await (const chunk of iterator) {
+				if (chunk.type === "tool_call_partial") {
+					results.push(chunk)
+				}
+			}
+
+			expect(results).toHaveLength(1)
+			expect(results[0]).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: "call_123",
+				name: "get_weather",
+				arguments: '{"location":"New York"}',
+			})
+		})
+
+		it("should handle multiple tool calls in a single response", async () => {
+			// Mock stream with multiple tool calls
+			mockCreate.mockImplementationOnce(async (_options) => {
+				const stream = {
+					[Symbol.asyncIterator]: async function* () {
+						yield {
+							data: {
+								choices: [
+									{
+										delta: {
+											toolCalls: [
+												{
+													id: "call_1",
+													type: "function",
+													function: {
+														name: "get_weather",
+														arguments: '{"location":"NYC"}',
+													},
+												},
+												{
+													id: "call_2",
+													type: "function",
+													function: {
+														name: "get_weather",
+														arguments: '{"location":"LA"}',
+													},
+												},
+											],
+										},
+										index: 0,
+									},
+								],
+							},
+						}
+					},
+				}
+				return stream
+			})
+
+			const metadata: ApiHandlerCreateMessageMetadata = {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+			}
+
+			const iterator = handler.createMessage(systemPrompt, messages, metadata)
+			const results: ApiStreamToolCallPartialChunk[] = []
+
+			for await (const chunk of iterator) {
+				if (chunk.type === "tool_call_partial") {
+					results.push(chunk)
+				}
+			}
+
+			expect(results).toHaveLength(2)
+			expect(results[0]).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: "call_1",
+				name: "get_weather",
+				arguments: '{"location":"NYC"}',
+			})
+			expect(results[1]).toEqual({
+				type: "tool_call_partial",
+				index: 1,
+				id: "call_2",
+				name: "get_weather",
+				arguments: '{"location":"LA"}',
+			})
+		})
+
+		it("should always set toolChoice to 'any' when tools are provided", async () => {
+			// Even if tool_choice is provided in metadata, we override it to "any"
+			const metadata: ApiHandlerCreateMessageMetadata = {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+				tool_choice: "auto", // This should be ignored
+			}
+
+			const iterator = handler.createMessage(systemPrompt, messages, metadata)
+			await iterator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					toolChoice: "any",
+				}),
+			)
+		})
+	})
+
 	describe("completePrompt", () => {
 		it("should complete prompt successfully", async () => {
 			const prompt = "Test prompt"

+ 81 - 5
src/api/providers/mistral.ts

@@ -1,5 +1,6 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import { Mistral } from "@mistralai/mistralai"
+import OpenAI from "openai"
 
 import { type MistralModelId, mistralDefaultModelId, mistralModels, MISTRAL_DEFAULT_TEMPERATURE } from "@roo-code/types"
 
@@ -19,6 +20,26 @@ type ContentChunkWithThinking = {
 	thinking?: Array<{ type: string; text?: string }>
 }
 
+// Type for Mistral tool calls in stream delta
+type MistralToolCall = {
+	id?: string
+	type?: string
+	function?: {
+		name?: string
+		arguments?: string
+	}
+}
+
+// Type for Mistral tool definition - matches Mistral SDK Tool type
+type MistralTool = {
+	type: "function"
+	function: {
+		name: string
+		description?: string
+		parameters: Record<string, unknown>
+	}
+}
+
 export class MistralHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
 	private client: Mistral
@@ -47,14 +68,35 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHand
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		const { id: model, maxTokens, temperature } = this.getModel()
-
-		const response = await this.client.chat.stream({
+		const { id: model, info, maxTokens, temperature } = this.getModel()
+
+		// Build request options
+		const requestOptions: {
+			model: string
+			messages: ReturnType<typeof convertToMistralMessages>
+			maxTokens: number
+			temperature: number
+			tools?: MistralTool[]
+			toolChoice?: "auto" | "none" | "any" | "required" | { type: "function"; function: { name: string } }
+		} = {
 			model,
 			messages: [{ role: "system", content: systemPrompt }, ...convertToMistralMessages(messages)],
-			maxTokens,
+			maxTokens: maxTokens ?? info.maxTokens,
 			temperature,
-		})
+		}
+
+		// Add tools if provided and toolProtocol is not 'xml' and model supports native tools
+		const supportsNativeTools = info.supportsNativeTools ?? false
+		if (metadata?.tools && metadata.tools.length > 0 && metadata?.toolProtocol !== "xml" && supportsNativeTools) {
+			requestOptions.tools = this.convertToolsForMistral(metadata.tools)
+			// Always use "any" to require tool use
+			requestOptions.toolChoice = "any"
+		}
+
+		// Temporary debug log for QA
+		// console.log("[MISTRAL DEBUG] Raw API request body:", requestOptions)
+
+		const response = await this.client.chat.stream(requestOptions)
 
 		for await (const event of response) {
 			const delta = event.data.choices[0]?.delta
@@ -83,6 +125,22 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHand
 				}
 			}
 
+			// Handle tool calls in stream
+			// Mistral SDK provides tool_calls in delta similar to OpenAI format
+			const toolCalls = (delta as { toolCalls?: MistralToolCall[] })?.toolCalls
+			if (toolCalls) {
+				for (let i = 0; i < toolCalls.length; i++) {
+					const toolCall = toolCalls[i]
+					yield {
+						type: "tool_call_partial",
+						index: i,
+						id: toolCall.id,
+						name: toolCall.function?.name,
+						arguments: toolCall.function?.arguments,
+					}
+				}
+			}
+
 			if (event.data.usage) {
 				yield {
 					type: "usage",
@@ -93,6 +151,24 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHand
 		}
 	}
 
+	/**
+	 * Convert OpenAI tool definitions to Mistral format.
+	 * Mistral uses the same format as OpenAI for function tools.
+	 */
+	private convertToolsForMistral(tools: OpenAI.Chat.ChatCompletionTool[]): MistralTool[] {
+		return tools
+			.filter((tool) => tool.type === "function")
+			.map((tool) => ({
+				type: "function" as const,
+				function: {
+					name: tool.function.name,
+					description: tool.function.description,
+					// Mistral SDK requires parameters to be defined, use empty object as fallback
+					parameters: (tool.function.parameters as Record<string, unknown>) || {},
+				},
+			}))
+	}
+
 	override getModel() {
 		const id = this.options.apiModelId ?? mistralDefaultModelId
 		const info = mistralModels[id as MistralModelId] ?? mistralModels[mistralDefaultModelId]

+ 20 - 23
src/api/transform/__tests__/mistral-format.spec.ts

@@ -83,10 +83,12 @@ describe("convertToMistralMessages", () => {
 			},
 		]
 
-		// Based on the implementation, tool results without accompanying text/image
-		// don't generate any messages
+		// Tool results are converted to Mistral "tool" role messages
 		const mistralMessages = convertToMistralMessages(anthropicMessages)
-		expect(mistralMessages).toHaveLength(0)
+		expect(mistralMessages).toHaveLength(1)
+		expect(mistralMessages[0].role).toBe("tool")
+		expect((mistralMessages[0] as { toolCallId?: string }).toolCallId).toBe("weather-123")
+		expect(mistralMessages[0].content).toBe("Current temperature in London: 20°C")
 	})
 
 	it("should handle user messages with mixed content (text, image, and tool results)", () => {
@@ -116,24 +118,14 @@ describe("convertToMistralMessages", () => {
 		]
 
 		const mistralMessages = convertToMistralMessages(anthropicMessages)
-		// Based on the implementation, only the text and image content is included
-		// Tool results are not converted to separate messages
+		// Mistral doesn't allow user messages after tool messages, so only tool results are converted
+		// User content (text/images) is intentionally skipped when there are tool results
 		expect(mistralMessages).toHaveLength(1)
 
-		// Message should be the user message with text and image
-		expect(mistralMessages[0].role).toBe("user")
-		const userContent = mistralMessages[0].content as Array<{
-			type: string
-			text?: string
-			imageUrl?: { url: string }
-		}>
-		expect(Array.isArray(userContent)).toBe(true)
-		expect(userContent).toHaveLength(2)
-		expect(userContent[0]).toEqual({ type: "text", text: "Here's the weather data and an image:" })
-		expect(userContent[1]).toEqual({
-			type: "image_url",
-			imageUrl: { url: "" },
-		})
+		// Only the tool result should be present
+		expect(mistralMessages[0].role).toBe("tool")
+		expect((mistralMessages[0] as { toolCallId?: string }).toolCallId).toBe("weather-123")
+		expect(mistralMessages[0].content).toBe("Current temperature in London: 20°C")
 	})
 
 	it("should handle assistant messages with text content", () => {
@@ -254,8 +246,8 @@ describe("convertToMistralMessages", () => {
 		]
 
 		const mistralMessages = convertToMistralMessages(anthropicMessages)
-		// Based on the implementation, user messages with only tool results don't generate messages
-		expect(mistralMessages).toHaveLength(3)
+		// Tool results are now converted to tool messages
+		expect(mistralMessages).toHaveLength(4)
 
 		// User message with image
 		expect(mistralMessages[0].role).toBe("user")
@@ -267,12 +259,17 @@ describe("convertToMistralMessages", () => {
 		expect(Array.isArray(userContent)).toBe(true)
 		expect(userContent).toHaveLength(2)
 
-		// Assistant message with text (tool_use is not included in Mistral format)
+		// Assistant message with text and toolCalls
 		expect(mistralMessages[1].role).toBe("assistant")
 		expect(mistralMessages[1].content).toBe("This image shows a landscape with mountains.")
 
+		// Tool result message
+		expect(mistralMessages[2].role).toBe("tool")
+		expect((mistralMessages[2] as { toolCallId?: string }).toolCallId).toBe("search-123")
+		expect(mistralMessages[2].content).toBe("Found information about different mountain types.")
+
 		// Final assistant message
-		expect(mistralMessages[2]).toEqual({
+		expect(mistralMessages[3]).toEqual({
 			role: "assistant",
 			content: "Based on the search results, I can tell you more about the mountains in the image.",
 		})

+ 70 - 5
src/api/transform/mistral-format.ts

@@ -10,6 +10,16 @@ export type MistralMessage =
 	| (AssistantMessage & { role: "assistant" })
 	| (ToolMessage & { role: "tool" })
 
+// Type for Mistral tool calls in assistant messages
+type MistralToolCallMessage = {
+	id: string
+	type: "function"
+	function: {
+		name: string
+		arguments: string
+	}
+}
+
 export function convertToMistralMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): MistralMessage[] {
 	const mistralMessages: MistralMessage[] = []
 
@@ -21,7 +31,7 @@ export function convertToMistralMessages(anthropicMessages: Anthropic.Messages.M
 			})
 		} else {
 			if (anthropicMessage.role === "user") {
-				const { nonToolMessages } = anthropicMessage.content.reduce<{
+				const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
 					nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
 					toolMessages: Anthropic.ToolResultBlockParam[]
 				}>(
@@ -36,7 +46,35 @@ export function convertToMistralMessages(anthropicMessages: Anthropic.Messages.M
 					{ nonToolMessages: [], toolMessages: [] },
 				)
 
-				if (nonToolMessages.length > 0) {
+				// If there are tool results, handle them
+				// Mistral's message order is strict: user → assistant → tool → assistant
+				// We CANNOT put user messages after tool messages
+				if (toolMessages.length > 0) {
+					// Convert tool_result blocks to Mistral tool messages
+					for (const toolResult of toolMessages) {
+						let resultContent: string
+						if (typeof toolResult.content === "string") {
+							resultContent = toolResult.content
+						} else if (Array.isArray(toolResult.content)) {
+							// Extract text from content blocks
+							resultContent = toolResult.content
+								.filter((block): block is Anthropic.TextBlockParam => block.type === "text")
+								.map((block) => block.text)
+								.join("\n")
+						} else {
+							resultContent = ""
+						}
+
+						mistralMessages.push({
+							role: "tool",
+							toolCallId: toolResult.tool_use_id,
+							content: resultContent,
+						} as ToolMessage & { role: "tool" })
+					}
+					// Note: We intentionally skip any non-tool user content when there are tool results
+					// because Mistral doesn't allow user messages after tool messages
+				} else if (nonToolMessages.length > 0) {
+					// Only add user content if there are NO tool results
 					mistralMessages.push({
 						role: "user",
 						content: nonToolMessages.map((part) => {
@@ -53,7 +91,7 @@ export function convertToMistralMessages(anthropicMessages: Anthropic.Messages.M
 					})
 				}
 			} else if (anthropicMessage.role === "assistant") {
-				const { nonToolMessages } = anthropicMessage.content.reduce<{
+				const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
 					nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
 					toolMessages: Anthropic.ToolUseBlockParam[]
 				}>(
@@ -80,10 +118,37 @@ export function convertToMistralMessages(anthropicMessages: Anthropic.Messages.M
 						.join("\n")
 				}
 
-				mistralMessages.push({
+				// Convert tool_use blocks to Mistral toolCalls format
+				let toolCalls: MistralToolCallMessage[] | undefined
+				if (toolMessages.length > 0) {
+					toolCalls = toolMessages.map((toolUse) => ({
+						id: toolUse.id,
+						type: "function" as const,
+						function: {
+							name: toolUse.name,
+							arguments:
+								typeof toolUse.input === "string" ? toolUse.input : JSON.stringify(toolUse.input),
+						},
+					}))
+				}
+
+				// Mistral requires either content or toolCalls to be non-empty
+				// If we have toolCalls but no content, we need to handle this properly
+				const assistantMessage: AssistantMessage & { role: "assistant" } = {
 					role: "assistant",
 					content,
-				})
+				}
+
+				if (toolCalls && toolCalls.length > 0) {
+					;(
+						assistantMessage as AssistantMessage & {
+							role: "assistant"
+							toolCalls?: MistralToolCallMessage[]
+						}
+					).toolCalls = toolCalls
+				}
+
+				mistralMessages.push(assistantMessage)
 			}
 		}
 	}