feat(api): migrate Mistral provider to AI SDK (#11089)

Co-authored-by: Matt Rubens <[email protected]>
Co-authored-by: Roo Code <[email protected]>
Daniel 1 week ago
Parent
Commit
b020f6be43

+ 15 - 0
pnpm-lock.yaml

@@ -758,6 +758,9 @@ importers:
       '@ai-sdk/groq':
         specifier: ^3.0.19
         version: 3.0.19(zod@3.25.76)
+      '@ai-sdk/mistral':
+        specifier: ^3.0.0
+        version: 3.0.16(zod@3.25.76)
       '@anthropic-ai/bedrock-sdk':
         specifier: ^0.10.2
         version: 0.10.4
@@ -1432,6 +1435,12 @@ packages:
     peerDependencies:
       zod: 3.25.76
 
+  '@ai-sdk/mistral@3.0.16':
+    resolution: {integrity: sha512-8I/gxXJwghaDLbQQHMBwd61WxYz/PaFUFlG8I38daNYj5qRTMmQ5V10Idi6GJJC0wWEqQkal31lidm9+Y+u6TQ==}
+    engines: {node: '>=18'}
+    peerDependencies:
+      zod: 3.25.76
+
   '@ai-sdk/[email protected]':
     resolution: {integrity: sha512-znBvaVHM0M6yWNerIEy3hR+O8ZK2sPcE7e2cxfb6kYLEX3k//JH5VDnRnajseVofg7LXtTCFFdjsB7WLf1BdeQ==}
     engines: {node: '>=18'}
@@ -11077,6 +11086,12 @@ snapshots:
       '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76)
       zod: 3.25.76
 
+  '@ai-sdk/mistral@3.0.16(zod@3.25.76)':
+    dependencies:
+      '@ai-sdk/provider': 3.0.6
+      '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76)
+      zod: 3.25.76
+
   '@ai-sdk/[email protected]([email protected])':
     dependencies:
       '@ai-sdk/provider': 2.0.1
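
The lockfile (and the matching package.json entry further down) only registers @ai-sdk/mistral; the pattern it enables is what the rewritten handler builds on. Below is a minimal sketch of that pattern, assuming an environment variable for the API key and a placeholder prompt. It mirrors the calls the new mistral.ts makes (createMistral for the provider, streamText for streaming) but is not the handler itself.

import { createMistral } from "@ai-sdk/mistral"
import { streamText } from "ai"

// Codestral models are served from a separate endpoint, so the provider is
// created with a model-dependent baseURL (the same routing the handler does).
const modelId = "codestral-latest"
const mistral = createMistral({
	apiKey: process.env.MISTRAL_API_KEY ?? "not-provided",
	baseURL: modelId.startsWith("codestral-")
		? "https://codestral.mistral.ai/v1"
		: "https://api.mistral.ai/v1",
})

async function main() {
	// streamText returns synchronously; fullStream yields typed parts and
	// usage resolves once the stream has finished.
	const result = streamText({
		model: mistral(modelId),
		system: "You are a helpful assistant.",
		prompt: "Say hello.",
		temperature: 0,
		maxOutputTokens: 256,
	})

	for await (const part of result.fullStream) {
		if (part.type === "text-delta") process.stdout.write(part.text)
	}

	const usage = await result.usage
	console.log("\ntokens in/out:", usage.inputTokens, usage.outputTokens)
}

main()

The handler wraps exactly this flow, adding message conversion, tool wiring, and usage mapping on top.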

+ 417 - 366
src/api/providers/__tests__/mistral.spec.ts

@@ -1,59 +1,36 @@
-// Mock TelemetryService - must come before other imports
-const mockCaptureException = vi.hoisted(() => vi.fn())
-vi.mock("@roo-code/telemetry", () => ({
-	TelemetryService: {
-		instance: {
-			captureException: mockCaptureException,
-		},
-	},
+// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls
+const { mockStreamText, mockGenerateText, mockCreateMistral } = vi.hoisted(() => ({
+	mockStreamText: vi.fn(),
+	mockGenerateText: vi.fn(),
+	mockCreateMistral: vi.fn(() => {
+		// Return a function that returns a mock language model
+		return vi.fn(() => ({
+			modelId: "codestral-latest",
+			provider: "mistral",
+		}))
+	}),
 }))
 
-// Mock Mistral client - must come before other imports
-const mockCreate = vi.fn()
-const mockComplete = vi.fn()
-vi.mock("@mistralai/mistralai", () => {
+vi.mock("ai", async (importOriginal) => {
+	const actual = await importOriginal<typeof import("ai")>()
 	return {
-		Mistral: vi.fn().mockImplementation(() => ({
-			chat: {
-				stream: mockCreate.mockImplementation(async (_options) => {
-					const stream = {
-						[Symbol.asyncIterator]: async function* () {
-							yield {
-								data: {
-									choices: [
-										{
-											delta: { content: "Test response" },
-											index: 0,
-										},
-									],
-								},
-							}
-						},
-					}
-					return stream
-				}),
-				complete: mockComplete.mockImplementation(async (_options) => {
-					return {
-						choices: [
-							{
-								message: {
-									content: "Test response",
-								},
-							},
-						],
-					}
-				}),
-			},
-		})),
+		...actual,
+		streamText: mockStreamText,
+		generateText: mockGenerateText,
 	}
 })
 
+vi.mock("@ai-sdk/mistral", () => ({
+	createMistral: mockCreateMistral,
+}))
+
 import type { Anthropic } from "@anthropic-ai/sdk"
-import type OpenAI from "openai"
-import { MistralHandler } from "../mistral"
+
+import { mistralDefaultModelId, mistralModels, type MistralModelId } from "@roo-code/types"
+
 import type { ApiHandlerOptions } from "../../../shared/api"
-import type { ApiHandlerCreateMessageMetadata } from "../../index"
-import type { ApiStreamTextChunk, ApiStreamReasoningChunk, ApiStreamToolCallPartialChunk } from "../../transform/stream"
+
+import { MistralHandler } from "../mistral"
 
 describe("MistralHandler", () => {
 	let handler: MistralHandler
@@ -61,15 +38,11 @@ describe("MistralHandler", () => {
 
 	beforeEach(() => {
 		mockOptions = {
-			apiModelId: "codestral-latest", // Update to match the actual model ID
 			mistralApiKey: "test-api-key",
-			includeMaxTokens: true,
-			modelTemperature: 0,
+			apiModelId: "codestral-latest" as MistralModelId,
 		}
 		handler = new MistralHandler(mockOptions)
-		mockCreate.mockClear()
-		mockComplete.mockClear()
-		mockCaptureException.mockClear()
+		vi.clearAllMocks()
 	})
 
 	describe("constructor", () => {
@@ -78,32 +51,53 @@ describe("MistralHandler", () => {
 			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
 		})
 
-		it("should throw error if API key is missing", () => {
-			expect(() => {
-				new MistralHandler({
-					...mockOptions,
-					mistralApiKey: undefined,
-				})
-			}).toThrow("Mistral API key is required")
-		})
-
-		it("should use custom base URL if provided", () => {
-			const customBaseUrl = "https://custom.mistral.ai/v1"
-			const handlerWithCustomUrl = new MistralHandler({
+		it("should use default model ID if not provided", () => {
+			const handlerWithoutModel = new MistralHandler({
 				...mockOptions,
-				mistralCodestralUrl: customBaseUrl,
+				apiModelId: undefined,
 			})
-			expect(handlerWithCustomUrl).toBeInstanceOf(MistralHandler)
+			expect(handlerWithoutModel.getModel().id).toBe(mistralDefaultModelId)
 		})
 	})
 
 	describe("getModel", () => {
-		it("should return correct model info", () => {
+		it("should return model info for valid model ID", () => {
 			const model = handler.getModel()
 			expect(model.id).toBe(mockOptions.apiModelId)
 			expect(model.info).toBeDefined()
+			expect(model.info.maxTokens).toBe(8192)
+			expect(model.info.contextWindow).toBe(256_000)
+			expect(model.info.supportsImages).toBe(false)
 			expect(model.info.supportsPromptCache).toBe(false)
 		})
+
+		it("should return provided model ID with default model info if model does not exist", () => {
+			const handlerWithInvalidModel = new MistralHandler({
+				...mockOptions,
+				apiModelId: "invalid-model",
+			})
+			const model = handlerWithInvalidModel.getModel()
+			expect(model.id).toBe("invalid-model") // Returns provided ID
+			expect(model.info).toBeDefined()
+			// Should have the same base properties as default model
+			expect(model.info.contextWindow).toBe(mistralModels[mistralDefaultModelId].contextWindow)
+		})
+
+		it("should return default model if no model ID is provided", () => {
+			const handlerWithoutModel = new MistralHandler({
+				...mockOptions,
+				apiModelId: undefined,
+			})
+			const model = handlerWithoutModel.getModel()
+			expect(model.id).toBe(mistralDefaultModelId)
+			expect(model.info).toBeDefined()
+		})
+
+		it("should include model parameters from getModelParams", () => {
+			const model = handler.getModel()
+			expect(model).toHaveProperty("temperature")
+			expect(model).toHaveProperty("maxTokens")
+		})
 	})
 
 	describe("createMessage", () => {
@@ -111,389 +105,446 @@ describe("MistralHandler", () => {
 		const messages: Anthropic.Messages.MessageParam[] = [
 			{
 				role: "user",
-				content: [{ type: "text", text: "Hello!" }],
+				content: [
+					{
+						type: "text" as const,
+						text: "Hello!",
+					},
+				],
 			},
 		]
 
-		it("should create message successfully", async () => {
-			const iterator = handler.createMessage(systemPrompt, messages)
-			const result = await iterator.next()
+		it("should handle streaming responses", async () => {
+			// Mock the fullStream async generator
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			// Mock usage promise
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBeGreaterThan(0)
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Test response")
+		})
+
+		it("should include usage information", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].inputTokens).toBe(10)
+			expect(usageChunks[0].outputTokens).toBe(5)
+		})
+
+		it("should handle reasoning content in streaming responses", async () => {
+			// Mock the fullStream async generator with reasoning content
+			async function* mockFullStream() {
+				yield { type: "reasoning", text: "Let me think about this..." }
+				yield { type: "reasoning", text: " I'll analyze step by step." }
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+				details: {
+					reasoningTokens: 15,
+				},
+			})
 
-			expect(mockCreate).toHaveBeenCalledWith(
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Should have reasoning chunks
+			const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning")
+			expect(reasoningChunks.length).toBe(2)
+			expect(reasoningChunks[0].text).toBe("Let me think about this...")
+			expect(reasoningChunks[1].text).toBe(" I'll analyze step by step.")
+
+			// Should also have text chunks
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks.length).toBe(1)
+			expect(textChunks[0].text).toBe("Test response")
+		})
+	})
+
+	describe("completePrompt", () => {
+		it("should complete a prompt using generateText", async () => {
+			mockGenerateText.mockResolvedValue({
+				text: "Test completion",
+			})
+
+			const result = await handler.completePrompt("Test prompt")
+
+			expect(result).toBe("Test completion")
+			expect(mockGenerateText).toHaveBeenCalledWith(
 				expect.objectContaining({
-					model: mockOptions.apiModelId,
-					messages: expect.any(Array),
-					maxTokens: expect.any(Number),
-					temperature: 0,
-					// Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS)
-					tools: expect.any(Array),
-					toolChoice: "any",
+					prompt: "Test prompt",
 				}),
 			)
-
-			expect(result.value).toBeDefined()
-			expect(result.done).toBe(false)
 		})
+	})
 
-		it("should handle streaming response correctly", async () => {
-			const iterator = handler.createMessage(systemPrompt, messages)
-			const results: ApiStreamTextChunk[] = []
-
-			for await (const chunk of iterator) {
-				if ("text" in chunk) {
-					results.push(chunk as ApiStreamTextChunk)
+	describe("processUsageMetrics", () => {
+		it("should correctly process usage metrics", () => {
+			// We need to access the protected method, so we'll create a test subclass
+			class TestMistralHandler extends MistralHandler {
+				public testProcessUsageMetrics(usage: any) {
+					return this.processUsageMetrics(usage)
 				}
 			}
 
-			expect(results.length).toBeGreaterThan(0)
-			expect(results[0].text).toBe("Test response")
-		})
+			const testHandler = new TestMistralHandler(mockOptions)
+
+			const usage = {
+				inputTokens: 100,
+				outputTokens: 50,
+				details: {
+					cachedInputTokens: 20,
+					reasoningTokens: 30,
+				},
+			}
+
+			const result = testHandler.testProcessUsageMetrics(usage)
 
-		it("should handle errors gracefully", async () => {
-			mockCreate.mockRejectedValueOnce(new Error("API Error"))
-			await expect(handler.createMessage(systemPrompt, messages).next()).rejects.toThrow("API Error")
+			expect(result.type).toBe("usage")
+			expect(result.inputTokens).toBe(100)
+			expect(result.outputTokens).toBe(50)
+			expect(result.cacheReadTokens).toBe(20)
+			expect(result.reasoningTokens).toBe(30)
 		})
 
-		it("should handle thinking content as reasoning chunks", async () => {
-			// Mock stream with thinking content matching new SDK structure
-			mockCreate.mockImplementationOnce(async (_options) => {
-				const stream = {
-					[Symbol.asyncIterator]: async function* () {
-						yield {
-							data: {
-								choices: [
-									{
-										delta: {
-											content: [
-												{
-													type: "thinking",
-													thinking: [{ type: "text", text: "Let me think about this..." }],
-												},
-												{ type: "text", text: "Here's the answer" },
-											],
-										},
-										index: 0,
-									},
-								],
-							},
-						}
-					},
+		it("should handle missing cache metrics gracefully", () => {
+			class TestMistralHandler extends MistralHandler {
+				public testProcessUsageMetrics(usage: any) {
+					return this.processUsageMetrics(usage)
 				}
-				return stream
-			})
+			}
+
+			const testHandler = new TestMistralHandler(mockOptions)
 
-			const iterator = handler.createMessage(systemPrompt, messages)
-			const results: (ApiStreamTextChunk | ApiStreamReasoningChunk)[] = []
+			const usage = {
+				inputTokens: 100,
+				outputTokens: 50,
+			}
+
+			const result = testHandler.testProcessUsageMetrics(usage)
+
+			expect(result.type).toBe("usage")
+			expect(result.inputTokens).toBe(100)
+			expect(result.outputTokens).toBe(50)
+			expect(result.cacheReadTokens).toBeUndefined()
+			expect(result.reasoningTokens).toBeUndefined()
+		})
+	})
 
-			for await (const chunk of iterator) {
-				if ("text" in chunk) {
-					results.push(chunk as ApiStreamTextChunk | ApiStreamReasoningChunk)
+	describe("getMaxOutputTokens", () => {
+		it("should return maxTokens from model info", () => {
+			class TestMistralHandler extends MistralHandler {
+				public testGetMaxOutputTokens() {
+					return this.getMaxOutputTokens()
 				}
 			}
 
-			expect(results).toHaveLength(2)
-			expect(results[0]).toEqual({ type: "reasoning", text: "Let me think about this..." })
-			expect(results[1]).toEqual({ type: "text", text: "Here's the answer" })
+			const testHandler = new TestMistralHandler(mockOptions)
+			const result = testHandler.testGetMaxOutputTokens()
+
+			// codestral-latest maxTokens is 8192
+			expect(result).toBe(8192)
 		})
 
-		it("should handle mixed content arrays correctly", async () => {
-			// Mock stream with mixed content matching new SDK structure
-			mockCreate.mockImplementationOnce(async (_options) => {
-				const stream = {
-					[Symbol.asyncIterator]: async function* () {
-						yield {
-							data: {
-								choices: [
-									{
-										delta: {
-											content: [
-												{ type: "text", text: "First text" },
-												{
-													type: "thinking",
-													thinking: [{ type: "text", text: "Some reasoning" }],
-												},
-												{ type: "text", text: "Second text" },
-											],
-										},
-										index: 0,
-									},
-								],
-							},
-						}
-					},
+		it("should use modelMaxTokens when provided", () => {
+			class TestMistralHandler extends MistralHandler {
+				public testGetMaxOutputTokens() {
+					return this.getMaxOutputTokens()
 				}
-				return stream
+			}
+
+			const customMaxTokens = 5000
+			const testHandler = new TestMistralHandler({
+				...mockOptions,
+				modelMaxTokens: customMaxTokens,
 			})
 
-			const iterator = handler.createMessage(systemPrompt, messages)
-			const results: (ApiStreamTextChunk | ApiStreamReasoningChunk)[] = []
+			const result = testHandler.testGetMaxOutputTokens()
+			expect(result).toBe(customMaxTokens)
+		})
 
-			for await (const chunk of iterator) {
-				if ("text" in chunk) {
-					results.push(chunk as ApiStreamTextChunk | ApiStreamReasoningChunk)
+		it("should fall back to modelInfo.maxTokens when modelMaxTokens is not provided", () => {
+			class TestMistralHandler extends MistralHandler {
+				public testGetMaxOutputTokens() {
+					return this.getMaxOutputTokens()
 				}
 			}
 
-			expect(results).toHaveLength(3)
-			expect(results[0]).toEqual({ type: "text", text: "First text" })
-			expect(results[1]).toEqual({ type: "reasoning", text: "Some reasoning" })
-			expect(results[2]).toEqual({ type: "text", text: "Second text" })
+			const testHandler = new TestMistralHandler(mockOptions)
+			const result = testHandler.testGetMaxOutputTokens()
+
+			// codestral-latest has maxTokens of 8192
+			expect(result).toBe(8192)
 		})
 	})
 
-	describe("native tool calling", () => {
+	describe("tool handling", () => {
 		const systemPrompt = "You are a helpful assistant."
 		const messages: Anthropic.Messages.MessageParam[] = [
 			{
 				role: "user",
-				content: [{ type: "text", text: "What's the weather?" }],
+				content: [{ type: "text" as const, text: "Hello!" }],
 			},
 		]
 
-		const mockTools: OpenAI.Chat.ChatCompletionTool[] = [
-			{
-				type: "function",
-				function: {
-					name: "get_weather",
-					description: "Get the current weather",
-					parameters: {
-						type: "object",
-						properties: {
-							location: { type: "string" },
+		it("should handle tool calls in streaming", async () => {
+			async function* mockFullStream() {
+				yield {
+					type: "tool-input-start",
+					id: "tool-call-1",
+					toolName: "read_file",
+				}
+				yield {
+					type: "tool-input-delta",
+					id: "tool-call-1",
+					delta: '{"path":"test.ts"}',
+				}
+				yield {
+					type: "tool-input-end",
+					id: "tool-call-1",
+				}
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: [
+					{
+						type: "function",
+						function: {
+							name: "read_file",
+							description: "Read a file",
+							parameters: {
+								type: "object",
+								properties: { path: { type: "string" } },
+								required: ["path"],
+							},
 						},
-						required: ["location"],
 					},
-				},
-			},
-		]
+				],
+			})
 
-		it("should include tools in request by default (native is default)", async () => {
-			const metadata: ApiHandlerCreateMessageMetadata = {
-				taskId: "test-task",
-				tools: mockTools,
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
 			}
 
-			const iterator = handler.createMessage(systemPrompt, messages, metadata)
-			await iterator.next()
+			const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start")
+			const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta")
+			const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end")
 
-			expect(mockCreate).toHaveBeenCalledWith(
-				expect.objectContaining({
-					tools: expect.arrayContaining([
-						expect.objectContaining({
-							type: "function",
-							function: expect.objectContaining({
-								name: "get_weather",
-								description: "Get the current weather",
-								parameters: expect.any(Object),
-							}),
-						}),
-					]),
-					toolChoice: "any",
-				}),
-			)
+			expect(toolCallStartChunks.length).toBe(1)
+			expect(toolCallStartChunks[0].id).toBe("tool-call-1")
+			expect(toolCallStartChunks[0].name).toBe("read_file")
+
+			expect(toolCallDeltaChunks.length).toBe(1)
+			expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}')
+
+			expect(toolCallEndChunks.length).toBe(1)
+			expect(toolCallEndChunks[0].id).toBe("tool-call-1")
 		})
 
-		it("should always include tools in request (tools are always present after PR #10841)", async () => {
-			const metadata: ApiHandlerCreateMessageMetadata = {
-				taskId: "test-task",
+		it("should ignore tool-call events to prevent duplicate tools in UI", async () => {
+			// tool-call events are intentionally ignored because tool-input-start/delta/end
+			// already provide complete tool call information. Emitting tool-call would cause
+			// duplicate tools in the UI for AI SDK providers.
+			async function* mockFullStream() {
+				yield {
+					type: "tool-call",
+					toolCallId: "tool-call-1",
+					toolName: "read_file",
+					input: { path: "test.ts" },
+				}
 			}
 
-			const iterator = handler.createMessage(systemPrompt, messages, metadata)
-			await iterator.next()
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
 
-			// Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS)
-			expect(mockCreate).toHaveBeenCalledWith(
-				expect.objectContaining({
-					tools: expect.any(Array),
-					toolChoice: "any",
-				}),
-			)
-		})
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
 
-		it("should handle tool calls in streaming response", async () => {
-			// Mock stream with tool calls
-			mockCreate.mockImplementationOnce(async (_options) => {
-				const stream = {
-					[Symbol.asyncIterator]: async function* () {
-						yield {
-							data: {
-								choices: [
-									{
-										delta: {
-											toolCalls: [
-												{
-													id: "call_123",
-													type: "function",
-													function: {
-														name: "get_weather",
-														arguments: '{"location":"New York"}',
-													},
-												},
-											],
-										},
-										index: 0,
-									},
-								],
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: [
+					{
+						type: "function",
+						function: {
+							name: "read_file",
+							description: "Read a file",
+							parameters: {
+								type: "object",
+								properties: { path: { type: "string" } },
+								required: ["path"],
 							},
-						}
+						},
 					},
-				}
-				return stream
+				],
 			})
 
-			const metadata: ApiHandlerCreateMessageMetadata = {
-				taskId: "test-task",
-				tools: mockTools,
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
 			}
 
-			const iterator = handler.createMessage(systemPrompt, messages, metadata)
-			const results: ApiStreamToolCallPartialChunk[] = []
+			// tool-call events are ignored, so no tool_call chunks should be emitted
+			const toolCallChunks = chunks.filter((c) => c.type === "tool_call")
+			expect(toolCallChunks.length).toBe(0)
+		})
+	})
 
-			for await (const chunk of iterator) {
-				if (chunk.type === "tool_call_partial") {
-					results.push(chunk)
+	describe("mapToolChoice", () => {
+		it("should handle string tool choices", () => {
+			class TestMistralHandler extends MistralHandler {
+				public testMapToolChoice(toolChoice: any) {
+					return this.mapToolChoice(toolChoice)
 				}
 			}
 
-			expect(results).toHaveLength(1)
-			expect(results[0]).toEqual({
-				type: "tool_call_partial",
-				index: 0,
-				id: "call_123",
-				name: "get_weather",
-				arguments: '{"location":"New York"}',
-			})
+			const testHandler = new TestMistralHandler(mockOptions)
+
+			expect(testHandler.testMapToolChoice("auto")).toBe("auto")
+			expect(testHandler.testMapToolChoice("none")).toBe("none")
+			expect(testHandler.testMapToolChoice("required")).toBe("required")
+			expect(testHandler.testMapToolChoice("any")).toBe("required")
+			expect(testHandler.testMapToolChoice("unknown")).toBe("auto")
 		})
 
-		it("should handle multiple tool calls in a single response", async () => {
-			// Mock stream with multiple tool calls
-			mockCreate.mockImplementationOnce(async (_options) => {
-				const stream = {
-					[Symbol.asyncIterator]: async function* () {
-						yield {
-							data: {
-								choices: [
-									{
-										delta: {
-											toolCalls: [
-												{
-													id: "call_1",
-													type: "function",
-													function: {
-														name: "get_weather",
-														arguments: '{"location":"NYC"}',
-													},
-												},
-												{
-													id: "call_2",
-													type: "function",
-													function: {
-														name: "get_weather",
-														arguments: '{"location":"LA"}',
-													},
-												},
-											],
-										},
-										index: 0,
-									},
-								],
-							},
-						}
-					},
+		it("should handle object tool choice with function name", () => {
+			class TestMistralHandler extends MistralHandler {
+				public testMapToolChoice(toolChoice: any) {
+					return this.mapToolChoice(toolChoice)
 				}
-				return stream
-			})
-
-			const metadata: ApiHandlerCreateMessageMetadata = {
-				taskId: "test-task",
-				tools: mockTools,
 			}
 
-			const iterator = handler.createMessage(systemPrompt, messages, metadata)
-			const results: ApiStreamToolCallPartialChunk[] = []
+			const testHandler = new TestMistralHandler(mockOptions)
 
-			for await (const chunk of iterator) {
-				if (chunk.type === "tool_call_partial") {
-					results.push(chunk)
+			const result = testHandler.testMapToolChoice({
+				type: "function",
+				function: { name: "my_tool" },
+			})
+
+			expect(result).toEqual({ type: "tool", toolName: "my_tool" })
+		})
+
+		it("should return undefined for null or undefined", () => {
+			class TestMistralHandler extends MistralHandler {
+				public testMapToolChoice(toolChoice: any) {
+					return this.mapToolChoice(toolChoice)
 				}
 			}
 
-			expect(results).toHaveLength(2)
-			expect(results[0]).toEqual({
-				type: "tool_call_partial",
-				index: 0,
-				id: "call_1",
-				name: "get_weather",
-				arguments: '{"location":"NYC"}',
-			})
-			expect(results[1]).toEqual({
-				type: "tool_call_partial",
-				index: 1,
-				id: "call_2",
-				name: "get_weather",
-				arguments: '{"location":"LA"}',
-			})
+			const testHandler = new TestMistralHandler(mockOptions)
+
+			expect(testHandler.testMapToolChoice(null)).toBeUndefined()
+			expect(testHandler.testMapToolChoice(undefined)).toBeUndefined()
 		})
+	})
 
-		it("should always set toolChoice to 'any' when tools are provided", async () => {
-			// Even if tool_choice is provided in metadata, we override it to "any"
-			const metadata: ApiHandlerCreateMessageMetadata = {
-				taskId: "test-task",
-				tools: mockTools,
-				tool_choice: "auto", // This should be ignored
-			}
+	describe("Codestral URL handling", () => {
+		beforeEach(() => {
+			mockCreateMistral.mockClear()
+		})
 
-			const iterator = handler.createMessage(systemPrompt, messages, metadata)
-			await iterator.next()
+		it("should use default Codestral URL for codestral models", () => {
+			new MistralHandler({
+				...mockOptions,
+				apiModelId: "codestral-latest",
+			})
 
-			expect(mockCreate).toHaveBeenCalledWith(
+			expect(mockCreateMistral).toHaveBeenCalledWith(
 				expect.objectContaining({
-					toolChoice: "any",
+					baseURL: "https://codestral.mistral.ai/v1",
 				}),
 			)
 		})
-	})
 
-	describe("completePrompt", () => {
-		it("should complete prompt successfully", async () => {
-			const prompt = "Test prompt"
-			const result = await handler.completePrompt(prompt)
-
-			expect(mockComplete).toHaveBeenCalledWith({
-				model: mockOptions.apiModelId,
-				messages: [{ role: "user", content: prompt }],
-				temperature: 0,
+		it("should use custom Codestral URL when provided", () => {
+			new MistralHandler({
+				...mockOptions,
+				apiModelId: "codestral-latest",
+				mistralCodestralUrl: "https://custom.codestral.url/v1",
 			})
 
-			expect(result).toBe("Test response")
+			expect(mockCreateMistral).toHaveBeenCalledWith(
+				expect.objectContaining({
+					baseURL: "https://custom.codestral.url/v1",
+				}),
+			)
 		})
 
-		it("should filter out thinking content in completePrompt", async () => {
-			mockComplete.mockImplementationOnce(async (_options) => {
-				return {
-					choices: [
-						{
-							message: {
-								content: [
-									{ type: "thinking", text: "Let me think..." },
-									{ type: "text", text: "Answer part 1" },
-									{ type: "text", text: "Answer part 2" },
-								],
-							},
-						},
-					],
-				}
+		it("should use default Mistral URL for non-codestral models", () => {
+			new MistralHandler({
+				...mockOptions,
+				apiModelId: "mistral-large-latest",
 			})
 
-			const prompt = "Test prompt"
-			const result = await handler.completePrompt(prompt)
-
-			expect(result).toBe("Answer part 1Answer part 2")
-		})
-
-		it("should handle errors in completePrompt", async () => {
-			mockComplete.mockRejectedValueOnce(new Error("API Error"))
-			await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Mistral completion error: API Error")
+			expect(mockCreateMistral).toHaveBeenCalledWith(
+				expect.objectContaining({
+					baseURL: "https://api.mistral.ai/v1",
+				}),
+			)
 		})
 	})
 })
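
The rewritten tests above no longer mock the Mistral SDK client; they mock the AI SDK's streamText/generateText and feed the handler a fake fullStream. The chunk translation they assert is done by processAiSdkStreamPart (imported from ../transform/ai-sdk in the next file), whose implementation is not part of this diff. The sketch below is a simplified, assumed stand-in for that mapping, reconstructed only from the part shapes the mocks emit and the chunk types the assertions expect; anything beyond those assertions is a guess.

// Hypothetical, simplified stand-in for processAiSdkStreamPart, inferred from
// mistral.spec.ts. The real helper in src/api/transform/ai-sdk may differ.
type ApiStreamChunk =
	| { type: "text"; text: string }
	| { type: "reasoning"; text: string }
	| { type: "tool_call_start"; id: string; name: string }
	| { type: "tool_call_delta"; id: string; delta: string }
	| { type: "tool_call_end"; id: string }

function* mapPart(part: {
	type: string
	text?: string
	id?: string
	toolName?: string
	delta?: string
}): Generator<ApiStreamChunk> {
	switch (part.type) {
		case "text-delta":
			if (part.text) yield { type: "text", text: part.text }
			break
		case "reasoning":
			if (part.text) yield { type: "reasoning", text: part.text }
			break
		case "tool-input-start":
			yield { type: "tool_call_start", id: part.id!, name: part.toolName! }
			break
		case "tool-input-delta":
			yield { type: "tool_call_delta", id: part.id!, delta: part.delta! }
			break
		case "tool-input-end":
			yield { type: "tool_call_end", id: part.id! }
			break
		// "tool-call" parts are skipped on purpose: tool-input-start/delta/end
		// already describe the call, and emitting both would duplicate tools
		// in the UI (covered by the "should ignore tool-call events" test).
	}
}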

+ 155 - 178
src/api/providers/mistral.ts

@@ -1,224 +1,201 @@
 import { Anthropic } from "@anthropic-ai/sdk"
-import { Mistral } from "@mistralai/mistralai"
-import OpenAI from "openai"
+import { createMistral } from "@ai-sdk/mistral"
+import { streamText, generateText, ToolSet, LanguageModel } from "ai"
 
 import {
-	type MistralModelId,
-	mistralDefaultModelId,
 	mistralModels,
+	mistralDefaultModelId,
+	type MistralModelId,
+	type ModelInfo,
 	MISTRAL_DEFAULT_TEMPERATURE,
-	ApiProviderError,
 } from "@roo-code/types"
-import { TelemetryService } from "@roo-code/telemetry"
 
-import { ApiHandlerOptions } from "../../shared/api"
-
-import { convertToMistralMessages } from "../transform/mistral-format"
-import { ApiStream } from "../transform/stream"
-import { handleProviderError } from "./utils/error-handler"
+import type { ApiHandlerOptions } from "../../shared/api"
 
+import {
+	convertToAiSdkMessages,
+	convertToolsForAiSdk,
+	processAiSdkStreamPart,
+	handleAiSdkError,
+} from "../transform/ai-sdk"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+import { getModelParams } from "../transform/model-params"
+
+import { DEFAULT_HEADERS } from "./constants"
 import { BaseProvider } from "./base-provider"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 
-// Type helper to handle thinking chunks from Mistral API
-// The SDK includes ThinkChunk but TypeScript has trouble with the discriminated union
-type ContentChunkWithThinking = {
-	type: string
-	text?: string
-	thinking?: Array<{ type: string; text?: string }>
-}
-
-// Type for Mistral tool calls in stream delta
-type MistralToolCall = {
-	id?: string
-	type?: string
-	function?: {
-		name?: string
-		arguments?: string
-	}
-}
-
-// Type for Mistral tool definition - matches Mistral SDK Tool type
-type MistralTool = {
-	type: "function"
-	function: {
-		name: string
-		description?: string
-		parameters: Record<string, unknown>
-	}
-}
-
+/**
+ * Mistral provider using the dedicated @ai-sdk/mistral package.
+ * Provides access to Mistral AI models including Codestral, Mistral Large, and more.
+ */
 export class MistralHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
-	private client: Mistral
-	private readonly providerName = "Mistral"
+	protected provider: ReturnType<typeof createMistral>
 
 	constructor(options: ApiHandlerOptions) {
 		super()
+		this.options = options
 
-		if (!options.mistralApiKey) {
-			throw new Error("Mistral API key is required")
-		}
+		const modelId = options.apiModelId ?? mistralDefaultModelId
 
-		// Set default model ID if not provided.
-		const apiModelId = options.apiModelId || mistralDefaultModelId
-		this.options = { ...options, apiModelId }
+		// Determine the base URL based on the model (Codestral uses a different endpoint)
+		const baseURL = modelId.startsWith("codestral-")
+			? options.mistralCodestralUrl || "https://codestral.mistral.ai/v1"
+			: "https://api.mistral.ai/v1"
 
-		this.client = new Mistral({
-			serverURL: apiModelId.startsWith("codestral-")
-				? this.options.mistralCodestralUrl || "https://codestral.mistral.ai"
-				: "https://api.mistral.ai",
-			apiKey: this.options.mistralApiKey,
+		// Create the Mistral provider using AI SDK
+		this.provider = createMistral({
+			apiKey: options.mistralApiKey ?? "not-provided",
+			baseURL,
+			headers: DEFAULT_HEADERS,
 		})
 	}
 
-	override async *createMessage(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		metadata?: ApiHandlerCreateMessageMetadata,
-	): ApiStream {
-		const { id: model, info, maxTokens, temperature } = this.getModel()
-
-		// Build request options
-		const requestOptions: {
-			model: string
-			messages: ReturnType<typeof convertToMistralMessages>
-			maxTokens: number
-			temperature: number
-			tools?: MistralTool[]
-			toolChoice?: "auto" | "none" | "any" | "required" | { type: "function"; function: { name: string } }
-		} = {
-			model,
-			messages: [{ role: "system", content: systemPrompt }, ...convertToMistralMessages(messages)],
-			maxTokens: maxTokens ?? info.maxTokens,
-			temperature,
-		}
-
-		requestOptions.tools = this.convertToolsForMistral(metadata?.tools ?? [])
-		// Always use "any" to require tool use
-		requestOptions.toolChoice = "any"
+	override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } {
+		const id = (this.options.apiModelId ?? mistralDefaultModelId) as MistralModelId
+		const info = mistralModels[id as keyof typeof mistralModels] || mistralModels[mistralDefaultModelId]
+		const params = getModelParams({ format: "openai", modelId: id, model: info, settings: this.options })
+		return { id, info, ...params }
+	}
 
-		// Temporary debug log for QA
-		// console.log("[MISTRAL DEBUG] Raw API request body:", requestOptions)
+	/**
+	 * Get the language model for the configured model ID.
+	 */
+	protected getLanguageModel(): LanguageModel {
+		const { id } = this.getModel()
+		// Type assertion needed due to version mismatch between @ai-sdk/mistral and ai packages
+		return this.provider(id) as unknown as LanguageModel
+	}
 
-		let response
-		try {
-			response = await this.client.chat.stream(requestOptions)
-		} catch (error) {
-			const errorMessage = error instanceof Error ? error.message : String(error)
-			const apiError = new ApiProviderError(errorMessage, this.providerName, model, "createMessage")
-			TelemetryService.instance.captureException(apiError)
-			throw new Error(`Mistral completion error: ${errorMessage}`)
+	/**
+	 * Process usage metrics from the AI SDK response.
+	 */
+	protected processUsageMetrics(usage: {
+		inputTokens?: number
+		outputTokens?: number
+		details?: {
+			cachedInputTokens?: number
+			reasoningTokens?: number
+		}
+	}): ApiStreamUsageChunk {
+		return {
+			type: "usage",
+			inputTokens: usage.inputTokens || 0,
+			outputTokens: usage.outputTokens || 0,
+			cacheReadTokens: usage.details?.cachedInputTokens,
+			reasoningTokens: usage.details?.reasoningTokens,
 		}
+	}
 
-		for await (const event of response) {
-			const delta = event.data.choices[0]?.delta
-
-			if (delta?.content) {
-				if (typeof delta.content === "string") {
-					// Handle string content as text
-					yield { type: "text", text: delta.content }
-				} else if (Array.isArray(delta.content)) {
-					// Handle array of content chunks
-					// The SDK v1.9.18 supports ThinkChunk with type "thinking"
-					for (const chunk of delta.content as ContentChunkWithThinking[]) {
-						if (chunk.type === "thinking" && chunk.thinking) {
-							// Handle thinking content as reasoning chunks
-							// ThinkChunk has a 'thinking' property that contains an array of text/reference chunks
-							for (const thinkingPart of chunk.thinking) {
-								if (thinkingPart.type === "text" && thinkingPart.text) {
-									yield { type: "reasoning", text: thinkingPart.text }
-								}
-							}
-						} else if (chunk.type === "text" && chunk.text) {
-							// Handle text content normally
-							yield { type: "text", text: chunk.text }
-						}
-					}
-				}
-			}
+	/**
+	 * Map OpenAI tool_choice to AI SDK toolChoice format.
+	 */
+	protected mapToolChoice(
+		toolChoice: any,
+	): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined {
+		if (!toolChoice) {
+			return undefined
+		}
 
-			// Handle tool calls in stream
-			// Mistral SDK provides tool_calls in delta similar to OpenAI format
-			const toolCalls = (delta as { toolCalls?: MistralToolCall[] })?.toolCalls
-			if (toolCalls) {
-				for (let i = 0; i < toolCalls.length; i++) {
-					const toolCall = toolCalls[i]
-					yield {
-						type: "tool_call_partial",
-						index: i,
-						id: toolCall.id,
-						name: toolCall.function?.name,
-						arguments: toolCall.function?.arguments,
-					}
-				}
+		// Handle string values
+		if (typeof toolChoice === "string") {
+			switch (toolChoice) {
+				case "auto":
+					return "auto"
+				case "none":
+					return "none"
+				case "required":
+				case "any":
+					return "required"
+				default:
+					return "auto"
 			}
+		}
 
-			if (event.data.usage) {
-				yield {
-					type: "usage",
-					inputTokens: event.data.usage.promptTokens || 0,
-					outputTokens: event.data.usage.completionTokens || 0,
-				}
+		// Handle object values (OpenAI ChatCompletionNamedToolChoice format)
+		if (typeof toolChoice === "object" && "type" in toolChoice) {
+			if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) {
+				return { type: "tool", toolName: toolChoice.function.name }
 			}
 		}
+
+		return undefined
 	}
 
 	/**
-	 * Convert OpenAI tool definitions to Mistral format.
-	 * Mistral uses the same format as OpenAI for function tools.
+	 * Get the max tokens parameter to include in the request.
 	 */
-	private convertToolsForMistral(tools: OpenAI.Chat.ChatCompletionTool[]): MistralTool[] {
-		return tools
-			.filter((tool) => tool.type === "function")
-			.map((tool) => ({
-				type: "function" as const,
-				function: {
-					name: tool.function.name,
-					description: tool.function.description,
-					// Mistral SDK requires parameters to be defined, use empty object as fallback
-					parameters: (tool.function.parameters as Record<string, unknown>) || {},
-				},
-			}))
+	protected getMaxOutputTokens(): number | undefined {
+		const { info } = this.getModel()
+		return this.options.modelMaxTokens || info.maxTokens || undefined
 	}
 
-	override getModel() {
-		const id = this.options.apiModelId ?? mistralDefaultModelId
-		const info = mistralModels[id as MistralModelId] ?? mistralModels[mistralDefaultModelId]
-
-		// @TODO: Move this to the `getModelParams` function.
-		const maxTokens = this.options.includeMaxTokens ? info.maxTokens : undefined
-		const temperature = this.options.modelTemperature ?? MISTRAL_DEFAULT_TEMPERATURE
-
-		return { id, info, maxTokens, temperature }
-	}
+	/**
+	 * Create a message stream using the AI SDK.
+	 */
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const languageModel = this.getLanguageModel()
+
+		// Convert messages to AI SDK format
+		const aiSdkMessages = convertToAiSdkMessages(messages)
+
+		// Convert tools to OpenAI format first, then to AI SDK format
+		const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
+		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined
+
+		// Build the request options
+		// Use MISTRAL_DEFAULT_TEMPERATURE (1) as fallback to match original behavior
+		const requestOptions: Parameters<typeof streamText>[0] = {
+			model: languageModel,
+			system: systemPrompt,
+			messages: aiSdkMessages,
+			temperature: this.options.modelTemperature ?? MISTRAL_DEFAULT_TEMPERATURE,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			tools: aiSdkTools,
+			toolChoice: this.mapToolChoice(metadata?.tool_choice),
+		}
 
-	async completePrompt(prompt: string): Promise<string> {
-		const { id: model, temperature } = this.getModel()
+		// Use streamText for streaming responses
+		const result = streamText(requestOptions)
 
 		try {
-			const response = await this.client.chat.complete({
-				model,
-				messages: [{ role: "user", content: prompt }],
-				temperature,
-			})
-
-			const content = response.choices?.[0]?.message.content
-
-			if (Array.isArray(content)) {
-				// Only return text content, filter out thinking content for non-streaming
-				return (content as ContentChunkWithThinking[])
-					.filter((c) => c.type === "text" && c.text)
-					.map((c) => c.text || "")
-					.join("")
+			// Process the full stream to get all events including reasoning
+			for await (const part of result.fullStream) {
+				for (const chunk of processAiSdkStreamPart(part)) {
+					yield chunk
+				}
 			}
 
-			return content || ""
+			// Yield usage metrics at the end
+			const usage = await result.usage
+			if (usage) {
+				yield this.processUsageMetrics(usage)
+			}
 		} catch (error) {
-			const errorMessage = error instanceof Error ? error.message : String(error)
-			const apiError = new ApiProviderError(errorMessage, this.providerName, model, "completePrompt")
-			TelemetryService.instance.captureException(apiError)
-			throw new Error(`Mistral completion error: ${errorMessage}`)
+			// Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.)
+			throw handleAiSdkError(error, "Mistral")
 		}
 	}
+
+	/**
+	 * Complete a prompt using the AI SDK generateText.
+	 */
+	async completePrompt(prompt: string): Promise<string> {
+		const languageModel = this.getLanguageModel()
+
+		// Use MISTRAL_DEFAULT_TEMPERATURE (1) as fallback to match original behavior
+		const { text } = await generateText({
+			model: languageModel,
+			prompt,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			temperature: this.options.modelTemperature ?? MISTRAL_DEFAULT_TEMPERATURE,
+		})
+
+		return text
+	}
 }
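
To see the new handler end to end, a hedged usage sketch: construct MistralHandler with the option fields this diff actually reads (mistralApiKey, apiModelId, mistralCodestralUrl, modelTemperature, modelMaxTokens) and iterate the ApiStream from createMessage. The chunk names in the switch are the ones the tests assert; other fields of ApiHandlerOptions and of the chunk union are outside this diff and treated as assumptions.

import { MistralHandler } from "./mistral"

// Only the fields the handler reads are set; ApiHandlerOptions has many more
// unrelated fields, hence the cast in this sketch.
const handler = new MistralHandler({
	mistralApiKey: process.env.MISTRAL_API_KEY,
	apiModelId: "codestral-latest",
	modelTemperature: 0,
} as any)

async function run() {
	const stream = handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: [{ type: "text" as const, text: "Hello!" }] },
	])

	for await (const chunk of stream) {
		if (chunk.type === "text") {
			process.stdout.write(chunk.text)
		} else if (chunk.type === "reasoning") {
			// Reasoning content arrives separately from the final answer.
		} else if (chunk.type === "usage") {
			console.log("\ntokens in/out:", chunk.inputTokens, chunk.outputTokens)
		}
	}

	// Non-streaming path: completePrompt wraps generateText.
	console.log(await handler.completePrompt("Summarize this change."))
}

run()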

+ 18 - 24
src/core/tools/__tests__/useMcpToolTool.spec.ts

@@ -676,14 +676,12 @@ describe("useMcpToolTool", () => {
 			mockProviderRef.deref.mockReturnValue({
 				getMcpHub: () => ({
 					callTool: vi.fn().mockResolvedValue(mockToolResult),
-					getAllServers: vi
-						.fn()
-						.mockReturnValue([
-							{
-								name: "figma-server",
-								tools: [{ name: "get_screenshot", description: "Get screenshot" }],
-							},
-						]),
+					getAllServers: vi.fn().mockReturnValue([
+						{
+							name: "figma-server",
+							tools: [{ name: "get_screenshot", description: "Get screenshot" }],
+						},
+					]),
 				}),
 				postMessageToWebview: vi.fn(),
 			})
@@ -790,14 +788,12 @@ describe("useMcpToolTool", () => {
 			mockProviderRef.deref.mockReturnValue({
 				getMcpHub: () => ({
 					callTool: vi.fn().mockResolvedValue(mockToolResult),
-					getAllServers: vi
-						.fn()
-						.mockReturnValue([
-							{
-								name: "figma-server",
-								tools: [{ name: "get_screenshot", description: "Get screenshot" }],
-							},
-						]),
+					getAllServers: vi.fn().mockReturnValue([
+						{
+							name: "figma-server",
+							tools: [{ name: "get_screenshot", description: "Get screenshot" }],
+						},
+					]),
 				}),
 				postMessageToWebview: vi.fn(),
 			})
@@ -852,14 +848,12 @@ describe("useMcpToolTool", () => {
 			mockProviderRef.deref.mockReturnValue({
 				getMcpHub: () => ({
 					callTool: vi.fn().mockResolvedValue(mockToolResult),
-					getAllServers: vi
-						.fn()
-						.mockReturnValue([
-							{
-								name: "figma-server",
-								tools: [{ name: "get_screenshots", description: "Get screenshots" }],
-							},
-						]),
+					getAllServers: vi.fn().mockReturnValue([
+						{
+							name: "figma-server",
+							tools: [{ name: "get_screenshots", description: "Get screenshots" }],
+						},
+					]),
 				}),
 				postMessageToWebview: vi.fn(),
 			})

+ 1 - 0
src/package.json

@@ -454,6 +454,7 @@
 		"@ai-sdk/deepseek": "^2.0.14",
 		"@ai-sdk/fireworks": "^2.0.26",
 		"@ai-sdk/groq": "^3.0.19",
+		"@ai-sdk/mistral": "^3.0.0",
 		"@anthropic-ai/bedrock-sdk": "^0.10.2",
 		"@anthropic-ai/sdk": "^0.37.0",
 		"@anthropic-ai/vertex-sdk": "^0.7.0",