feat: migrate Cerebras provider to AI SDK (#11086)

Daniel, 1 week ago
parent commit 0f43cc9814
4 changed files with 583 additions and 472 deletions
  1. pnpm-lock.yaml  (+44 -7)
  2. src/api/providers/__tests__/cerebras.spec.ts  (+414 -161)
  3. src/api/providers/cerebras.ts  (+124 -304)
  4. src/package.json  (+1 -0)

+ 44 - 7
pnpm-lock.yaml

@@ -743,9 +743,12 @@ importers:
 
   src:
     dependencies:
+      '@ai-sdk/cerebras':
+        specifier: ^1.0.0
+        version: 1.0.35(zod@3.25.76)
       '@ai-sdk/deepseek':
         specifier: ^2.0.14
-        version: 2.0.14(zod@3.25.76)
+        version: 2.0.15(zod@3.25.76)
       '@anthropic-ai/bedrock-sdk':
         specifier: ^0.10.2
         version: 0.10.4
@@ -1390,8 +1393,14 @@ packages:
   '@adobe/[email protected]':
     resolution: {integrity: sha512-baYZExFpsdkBNuvGKTKWCwKH57HRZLVtycZS05WTQNVOiXVSeAki3nU35zlRbToeMW8aHlJfyS+1C4BOv27q0A==}
 
-  '@ai-sdk/deepseek@2.0.14':
-    resolution: {integrity: sha512-1vXh8sVwRJYd1JO57qdy1rACucaNLDoBRCwOER3EbPgSF2vNVPcdJywGutA01Bhn7Cta+UJQ+k5y/yzMAIpP2w==}
+  '@ai-sdk/cerebras@1.0.35':
+    resolution: {integrity: sha512-JrNdMYptrOUjNthibgBeAcBjZ/H+fXb49sSrWhOx5Aq8eUcrYvwQ2DtSAi8VraHssZu78NAnBMrgFWSUOTXFxw==}
+    engines: {node: '>=18'}
+    peerDependencies:
+      zod: 3.25.76
+
+  '@ai-sdk/deepseek@2.0.15':
+    resolution: {integrity: sha512-3wJUjNjGrTZS3K8OEfHD1PZYhzkcXuoL8KIVtzi6WrC5xrDQPjCBPATmdKPV7DgDCF+wujQOaMz5cv40Yg+hog==}
     engines: {node: '>=18'}
     peerDependencies:
       zod: 3.25.76
@@ -1420,6 +1429,12 @@ packages:
     peerDependencies:
       zod: 3.25.76
 
+  '@ai-sdk/provider-utils@4.0.11':
+    resolution: {integrity: sha512-y/WOPpcZaBjvNaogy83mBsCRPvbtaK0y1sY9ckRrrbTGMvG2HC/9Y/huqNXKnLAxUIME2PGa2uvF2CDwIsxoXQ==}
+    engines: {node: '>=18'}
+    peerDependencies:
+      zod: 3.25.76
+
  '@ai-sdk/provider@2.0.1':
     resolution: {integrity: sha512-KCUwswvsC5VsW2PWFqF8eJgSCu5Ysj7m1TxiHTVA6g7k360bk0RNQENT8KTMAYEs+8fWPD3Uu4dEmzGHc+jGng==}
     engines: {node: '>=18'}
@@ -1428,6 +1443,10 @@ packages:
     resolution: {integrity: sha512-2Xmoq6DBJqmSl80U6V9z5jJSJP7ehaJJQMy2iFUqTay06wdCqTnPVBBQbtEL8RCChenL+q5DC5H5WzU3vV3v8w==}
     engines: {node: '>=18'}
 
+  '@ai-sdk/provider@3.0.6':
+    resolution: {integrity: sha512-hSfoJtLtpMd7YxKM+iTqlJ0ZB+kJ83WESMiWuWrNVey3X8gg97x0OdAAaeAeclZByCX3UdPOTqhvJdK8qYA3ww==}
+    engines: {node: '>=18'}
+
   '@alcalzone/[email protected]':
     resolution: {integrity: sha512-jsElTJ0sQ4wHRz+C45tfect76BwbTbgkgKByOzpCN9xG61N5V6u/glvg1CsNJhq2xJIFpKHSwG3D2wPPuEYOrQ==}
     engines: {node: '>=18'}
@@ -10819,10 +10838,17 @@ snapshots:
 
   '@adobe/[email protected]': {}
 
-  '@ai-sdk/deepseek@2.0.14(zod@3.25.76)':
+  '@ai-sdk/cerebras@1.0.35(zod@3.25.76)':
     dependencies:
-      '@ai-sdk/provider': 3.0.5
-      '@ai-sdk/provider-utils': 4.0.10(zod@3.25.76)
+      '@ai-sdk/openai-compatible': 1.0.31(zod@3.25.76)
+      '@ai-sdk/provider': 2.0.1
+      '@ai-sdk/provider-utils': 3.0.20(zod@3.25.76)
+      zod: 3.25.76
+
+  '@ai-sdk/deepseek@2.0.15(zod@3.25.76)':
+    dependencies:
+      '@ai-sdk/provider': 3.0.6
+      '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76)
       zod: 3.25.76
 
  '@ai-sdk/openai-compatible@1.0.31(zod@3.25.76)':
@@ -10852,6 +10878,13 @@ snapshots:
       eventsource-parser: 3.0.6
       zod: 3.25.76
 
+  '@ai-sdk/provider-utils@4.0.11(zod@3.25.76)':
+    dependencies:
+      '@ai-sdk/provider': 3.0.6
+      '@standard-schema/spec': 1.1.0
+      eventsource-parser: 3.0.6
+      zod: 3.25.76
+
  '@ai-sdk/provider@2.0.1':
     dependencies:
       json-schema: 0.4.0
@@ -10860,6 +10893,10 @@ snapshots:
     dependencies:
       json-schema: 0.4.0
 
+  '@ai-sdk/provider@3.0.6':
+    dependencies:
+      json-schema: 0.4.0
+
   '@alcalzone/[email protected]':
     dependencies:
       ansi-styles: 6.2.3
@@ -14686,7 +14723,7 @@ snapshots:
       sirv: 3.0.1
       tinyglobby: 0.2.14
       tinyrainbow: 2.0.0
-      vitest: 3.2.4(@types/[email protected])(@types/node@20.17.50)(@vitest/[email protected])([email protected])([email protected])([email protected])([email protected])([email protected])
+      vitest: 3.2.4(@types/[email protected])(@types/node@24.2.1)(@vitest/[email protected])([email protected])([email protected])([email protected])([email protected])([email protected])
 
   '@vitest/[email protected]':
     dependencies:

+ 414 - 161
src/api/providers/__tests__/cerebras.spec.ts

@@ -1,249 +1,502 @@
-// Mock i18n
-vi.mock("../../i18n", () => ({
-	t: vi.fn((key: string, params?: Record<string, any>) => {
-		// Return a simplified mock translation for testing
-		if (key.startsWith("common:errors.cerebras.")) {
-			return `Mocked: ${key.replace("common:errors.cerebras.", "")}`
-		}
-		return key
-	}),
+// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls
+const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({
+	mockStreamText: vi.fn(),
+	mockGenerateText: vi.fn(),
 }))
 
-// Mock DEFAULT_HEADERS
-vi.mock("../constants", () => ({
-	DEFAULT_HEADERS: {
-		"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
-		"X-Title": "Roo Code",
-		"User-Agent": "RooCode/1.0.0",
-	},
+vi.mock("ai", async (importOriginal) => {
+	const actual = await importOriginal<typeof import("ai")>()
+	return {
+		...actual,
+		streamText: mockStreamText,
+		generateText: mockGenerateText,
+	}
+})
+
+vi.mock("@ai-sdk/cerebras", () => ({
+	createCerebras: vi.fn(() => {
+		// Return a function that returns a mock language model
+		return vi.fn(() => ({
+			modelId: "llama-3.3-70b",
+			provider: "cerebras",
+		}))
+	}),
 }))
 
-import { CerebrasHandler } from "../cerebras"
-import { cerebrasModels, type CerebrasModelId } from "@roo-code/types"
+import type { Anthropic } from "@anthropic-ai/sdk"
+
+import { cerebrasDefaultModelId, cerebrasModels, type CerebrasModelId } from "@roo-code/types"
 
-// Mock fetch globally
-global.fetch = vi.fn()
+import type { ApiHandlerOptions } from "../../../shared/api"
+
+import { CerebrasHandler } from "../cerebras"
 
 describe("CerebrasHandler", () => {
 	let handler: CerebrasHandler
-	const mockOptions = {
-		cerebrasApiKey: "test-api-key",
-		apiModelId: "llama-3.3-70b" as CerebrasModelId,
-	}
+	let mockOptions: ApiHandlerOptions
 
 	beforeEach(() => {
-		vi.clearAllMocks()
+		mockOptions = {
+			cerebrasApiKey: "test-api-key",
+			apiModelId: "llama-3.3-70b" as CerebrasModelId,
+		}
 		handler = new CerebrasHandler(mockOptions)
+		vi.clearAllMocks()
 	})
 
 	describe("constructor", () => {
-		it("should throw error when API key is missing", () => {
-			expect(() => new CerebrasHandler({ cerebrasApiKey: "" })).toThrow("Cerebras API key is required")
+		it("should initialize with provided options", () => {
+			expect(handler).toBeInstanceOf(CerebrasHandler)
+			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
 		})
 
-		it("should initialize with valid API key", () => {
-			expect(() => new CerebrasHandler(mockOptions)).not.toThrow()
+		it("should use default model ID if not provided", () => {
+			const handlerWithoutModel = new CerebrasHandler({
+				...mockOptions,
+				apiModelId: undefined,
+			})
+			expect(handlerWithoutModel.getModel().id).toBe(cerebrasDefaultModelId)
 		})
 	})
 
 	describe("getModel", () => {
-		it("should return correct model info", () => {
-			const { id, info } = handler.getModel()
-			expect(id).toBe("llama-3.3-70b")
-			expect(info).toEqual(cerebrasModels["llama-3.3-70b"])
+		it("should return model info for valid model ID", () => {
+			const model = handler.getModel()
+			expect(model.id).toBe(mockOptions.apiModelId)
+			expect(model.info).toBeDefined()
+			expect(model.info.maxTokens).toBe(16384)
+			expect(model.info.contextWindow).toBe(64000)
+			expect(model.info.supportsImages).toBe(false)
+			expect(model.info.supportsPromptCache).toBe(false)
 		})
 
-		it("should fallback to default model when apiModelId is not provided", () => {
-			const handlerWithoutModel = new CerebrasHandler({ cerebrasApiKey: "test" })
-			const { id } = handlerWithoutModel.getModel()
-			expect(id).toBe("gpt-oss-120b") // cerebrasDefaultModelId
-		})
-	})
-
-	describe("message conversion", () => {
-		it("should strip thinking tokens from assistant messages", () => {
-			// This would test the stripThinkingTokens function
-			// Implementation details would test the regex functionality
+		it("should return provided model ID with default model info if model does not exist", () => {
+			const handlerWithInvalidModel = new CerebrasHandler({
+				...mockOptions,
+				apiModelId: "invalid-model",
+			})
+			const model = handlerWithInvalidModel.getModel()
+			expect(model.id).toBe("invalid-model") // Returns provided ID
+			expect(model.info).toBeDefined()
+			// Should have the same base properties as default model
+			expect(model.info.contextWindow).toBe(cerebrasModels[cerebrasDefaultModelId].contextWindow)
 		})
 
-		it("should flatten complex message content to strings", () => {
-			// This would test the flattenMessageContent function
-			// Test various content types: strings, arrays, image objects
+		it("should return default model if no model ID is provided", () => {
+			const handlerWithoutModel = new CerebrasHandler({
+				...mockOptions,
+				apiModelId: undefined,
+			})
+			const model = handlerWithoutModel.getModel()
+			expect(model.id).toBe(cerebrasDefaultModelId)
+			expect(model.info).toBeDefined()
 		})
 
-		it("should convert OpenAI messages to Cerebras format", () => {
-			// This would test the convertToCerebrasMessages function
-			// Ensure all messages have string content and proper role/content structure
+		it("should include model parameters from getModelParams", () => {
+			const model = handler.getModel()
+			expect(model).toHaveProperty("temperature")
+			expect(model).toHaveProperty("maxTokens")
 		})
 	})
 
 	describe("createMessage", () => {
-		it("should make correct API request", async () => {
-			// Mock successful API response
-			const mockResponse = {
-				ok: true,
-				body: {
-					getReader: () => ({
-						read: vi.fn().mockResolvedValueOnce({ done: true, value: new Uint8Array() }),
-						releaseLock: vi.fn(),
-					}),
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [
+					{
+						type: "text" as const,
+						text: "Hello!",
+					},
+				],
+			},
+		]
+
+		it("should handle streaming responses", async () => {
+			// Mock the fullStream async generator
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			// Mock usage promise
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBeGreaterThan(0)
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Test response")
+		})
+
+		it("should include usage information", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].inputTokens).toBe(10)
+			expect(usageChunks[0].outputTokens).toBe(5)
+		})
+
+		it("should handle reasoning content in streaming responses", async () => {
+			// Mock the fullStream async generator with reasoning content
+			async function* mockFullStream() {
+				yield { type: "reasoning", text: "Let me think about this..." }
+				yield { type: "reasoning", text: " I'll analyze step by step." }
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+				details: {
+					reasoningTokens: 15,
 				},
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
 			}
-			vi.mocked(fetch).mockResolvedValueOnce(mockResponse as any)
 
-			const generator = handler.createMessage("System prompt", [])
-			await generator.next() // Actually start the generator to trigger the fetch call
+			// Should have reasoning chunks
+			const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning")
+			expect(reasoningChunks.length).toBe(2)
+			expect(reasoningChunks[0].text).toBe("Let me think about this...")
+			expect(reasoningChunks[1].text).toBe(" I'll analyze step by step.")
+
+			// Should also have text chunks
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks.length).toBe(1)
+			expect(textChunks[0].text).toBe("Test response")
+		})
+	})
+
+	describe("completePrompt", () => {
+		it("should complete a prompt using generateText", async () => {
+			mockGenerateText.mockResolvedValue({
+				text: "Test completion",
+			})
+
+			const result = await handler.completePrompt("Test prompt")
 
-			// Test that fetch was called with correct parameters
-			expect(fetch).toHaveBeenCalledWith(
-				"https://api.cerebras.ai/v1/chat/completions",
+			expect(result).toBe("Test completion")
+			expect(mockGenerateText).toHaveBeenCalledWith(
 				expect.objectContaining({
-					method: "POST",
-					headers: expect.objectContaining({
-						"Content-Type": "application/json",
-						Authorization: "Bearer test-api-key",
-						"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
-						"X-Title": "Roo Code",
-						"User-Agent": "RooCode/1.0.0",
-					}),
+					prompt: "Test prompt",
 				}),
 			)
 		})
+	})
 
-		it("should handle API errors properly", async () => {
-			const mockErrorResponse = {
-				ok: false,
-				status: 400,
-				text: () => Promise.resolve('{"error": {"message": "Bad Request"}}'),
+	describe("processUsageMetrics", () => {
+		it("should correctly process usage metrics", () => {
+			// We need to access the protected method, so we'll create a test subclass
+			class TestCerebrasHandler extends CerebrasHandler {
+				public testProcessUsageMetrics(usage: any) {
+					return this.processUsageMetrics(usage)
+				}
 			}
-			vi.mocked(fetch).mockResolvedValueOnce(mockErrorResponse as any)
 
-			const generator = handler.createMessage("System prompt", [])
-			// Since the mock isn't working, let's just check that an error is thrown
-			await expect(generator.next()).rejects.toThrow()
-		})
+			const testHandler = new TestCerebrasHandler(mockOptions)
 
-		it("should parse streaming responses correctly", async () => {
-			// Test streaming response parsing
-			// Mock ReadableStream with various data chunks
-			// Verify thinking token extraction and usage tracking
+			const usage = {
+				inputTokens: 100,
+				outputTokens: 50,
+				details: {
+					cachedInputTokens: 20,
+					reasoningTokens: 30,
+				},
+			}
+
+			const result = testHandler.testProcessUsageMetrics(usage)
+
+			expect(result.type).toBe("usage")
+			expect(result.inputTokens).toBe(100)
+			expect(result.outputTokens).toBe(50)
+			expect(result.cacheReadTokens).toBe(20)
+			expect(result.reasoningTokens).toBe(30)
 		})
 
-		it("should handle temperature clamping", async () => {
-			const handlerWithTemp = new CerebrasHandler({
-				...mockOptions,
-				modelTemperature: 2.0, // Above Cerebras max of 1.5
-			})
+		it("should handle missing cache metrics gracefully", () => {
+			class TestCerebrasHandler extends CerebrasHandler {
+				public testProcessUsageMetrics(usage: any) {
+					return this.processUsageMetrics(usage)
+				}
+			}
+
+			const testHandler = new TestCerebrasHandler(mockOptions)
 
-			vi.mocked(fetch).mockResolvedValueOnce({
-				ok: true,
-				body: { getReader: () => ({ read: () => Promise.resolve({ done: true }), releaseLock: vi.fn() }) },
-			} as any)
+			const usage = {
+				inputTokens: 100,
+				outputTokens: 50,
+			}
 
-			await handlerWithTemp.createMessage("test", []).next()
+			const result = testHandler.testProcessUsageMetrics(usage)
 
-			const requestBody = JSON.parse(vi.mocked(fetch).mock.calls[0][1]?.body as string)
-			expect(requestBody.temperature).toBe(1.5) // Should be clamped
+			expect(result.type).toBe("usage")
+			expect(result.inputTokens).toBe(100)
+			expect(result.outputTokens).toBe(50)
+			expect(result.cacheReadTokens).toBeUndefined()
+			expect(result.reasoningTokens).toBeUndefined()
 		})
 	})
 
-	describe("completePrompt", () => {
-		it("should handle non-streaming completion", async () => {
-			const mockResponse = {
-				ok: true,
-				json: () =>
-					Promise.resolve({
-						choices: [{ message: { content: "Test response" } }],
-					}),
+	describe("getMaxOutputTokens", () => {
+		it("should return maxTokens from model info", () => {
+			class TestCerebrasHandler extends CerebrasHandler {
+				public testGetMaxOutputTokens() {
+					return this.getMaxOutputTokens()
+				}
 			}
-			vi.mocked(fetch).mockResolvedValueOnce(mockResponse as any)
 
-			const result = await handler.completePrompt("Test prompt")
-			expect(result).toBe("Test response")
+			const testHandler = new TestCerebrasHandler(mockOptions)
+			const result = testHandler.testGetMaxOutputTokens()
+
+			// llama-3.3-70b maxTokens is 16384
+			expect(result).toBe(16384)
 		})
-	})
 
-	describe("token usage and cost calculation", () => {
-		it("should track token usage properly", () => {
-			// Test that lastUsage is updated correctly
-			// Test getApiCost returns calculated cost based on actual usage
+		it("should use modelMaxTokens when provided", () => {
+			class TestCerebrasHandler extends CerebrasHandler {
+				public testGetMaxOutputTokens() {
+					return this.getMaxOutputTokens()
+				}
+			}
+
+			const customMaxTokens = 5000
+			const testHandler = new TestCerebrasHandler({
+				...mockOptions,
+				modelMaxTokens: customMaxTokens,
+			})
+
+			const result = testHandler.testGetMaxOutputTokens()
+			expect(result).toBe(customMaxTokens)
 		})
 
-		it("should provide usage estimates when API doesn't return usage", () => {
-			// Test fallback token estimation logic
+		it("should fall back to modelInfo.maxTokens when modelMaxTokens is not provided", () => {
+			class TestCerebrasHandler extends CerebrasHandler {
+				public testGetMaxOutputTokens() {
+					return this.getMaxOutputTokens()
+				}
+			}
+
+			const testHandler = new TestCerebrasHandler(mockOptions)
+			const result = testHandler.testGetMaxOutputTokens()
+
+			// llama-3.3-70b has maxTokens of 16384
+			expect(result).toBe(16384)
 		})
 	})
 
-	describe("convertToolsForOpenAI", () => {
-		it("should set all tools to strict: false for Cerebras API consistency", () => {
-			// Access the protected method through a test subclass
-			const regularTool = {
-				type: "function",
-				function: {
-					name: "read_file",
-					parameters: {
-						type: "object",
-						properties: {
-							path: { type: "string" },
+	describe("tool handling", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [{ type: "text" as const, text: "Hello!" }],
+			},
+		]
+
+		it("should handle tool calls in streaming", async () => {
+			async function* mockFullStream() {
+				yield {
+					type: "tool-input-start",
+					id: "tool-call-1",
+					toolName: "read_file",
+				}
+				yield {
+					type: "tool-input-delta",
+					id: "tool-call-1",
+					delta: '{"path":"test.ts"}',
+				}
+				yield {
+					type: "tool-input-end",
+					id: "tool-call-1",
+				}
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: [
+					{
+						type: "function",
+						function: {
+							name: "read_file",
+							description: "Read a file",
+							parameters: {
+								type: "object",
+								properties: { path: { type: "string" } },
+								required: ["path"],
+							},
 						},
-						required: ["path"],
 					},
-				},
+				],
+			})
+
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
 			}
 
-			// MCP tool with the 'mcp--' prefix
-			const mcpTool = {
-				type: "function",
-				function: {
-					name: "mcp--server--tool",
-					parameters: {
-						type: "object",
-						properties: {
-							arg: { type: "string" },
+			const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start")
+			const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta")
+			const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end")
+
+			expect(toolCallStartChunks.length).toBe(1)
+			expect(toolCallStartChunks[0].id).toBe("tool-call-1")
+			expect(toolCallStartChunks[0].name).toBe("read_file")
+
+			expect(toolCallDeltaChunks.length).toBe(1)
+			expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}')
+
+			expect(toolCallEndChunks.length).toBe(1)
+			expect(toolCallEndChunks[0].id).toBe("tool-call-1")
+		})
+
+		it("should ignore tool-call events to prevent duplicate tools in UI", async () => {
+			// tool-call events are intentionally ignored because tool-input-start/delta/end
+			// already provide complete tool call information. Emitting tool-call would cause
+			// duplicate tools in the UI for AI SDK providers (e.g., DeepSeek, Moonshot, Cerebras).
+			async function* mockFullStream() {
+				yield {
+					type: "tool-call",
+					toolCallId: "tool-call-1",
+					toolName: "read_file",
+					input: { path: "test.ts" },
+				}
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: [
+					{
+						type: "function",
+						function: {
+							name: "read_file",
+							description: "Read a file",
+							parameters: {
+								type: "object",
+								properties: { path: { type: "string" } },
+								required: ["path"],
+							},
 						},
 					},
-				},
+				],
+			})
+
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
 			}
 
-			// Create a test wrapper to access protected method
+			// tool-call events are ignored, so no tool_call chunks should be emitted
+			const toolCallChunks = chunks.filter((c) => c.type === "tool_call")
+			expect(toolCallChunks.length).toBe(0)
+		})
+	})
+
+	describe("mapToolChoice", () => {
+		it("should handle string tool choices", () => {
 			class TestCerebrasHandler extends CerebrasHandler {
-				public testConvertToolsForOpenAI(tools: any[]) {
-					return this.convertToolsForOpenAI(tools)
+				public testMapToolChoice(toolChoice: any) {
+					return this.mapToolChoice(toolChoice)
 				}
 			}
 
-			const testHandler = new TestCerebrasHandler({ cerebrasApiKey: "test" })
-			const converted = testHandler.testConvertToolsForOpenAI([regularTool, mcpTool])
+			const testHandler = new TestCerebrasHandler(mockOptions)
 
-			// Both tools should have strict: false
-			expect(converted).toHaveLength(2)
-			expect(converted![0].function.strict).toBe(false)
-			expect(converted![1].function.strict).toBe(false)
+			expect(testHandler.testMapToolChoice("auto")).toBe("auto")
+			expect(testHandler.testMapToolChoice("none")).toBe("none")
+			expect(testHandler.testMapToolChoice("required")).toBe("required")
+			expect(testHandler.testMapToolChoice("unknown")).toBe("auto")
 		})
 
-		it("should return undefined when tools is undefined", () => {
+		it("should handle object tool choice with function name", () => {
 			class TestCerebrasHandler extends CerebrasHandler {
-				public testConvertToolsForOpenAI(tools: any[] | undefined) {
-					return this.convertToolsForOpenAI(tools)
+				public testMapToolChoice(toolChoice: any) {
+					return this.mapToolChoice(toolChoice)
 				}
 			}
 
-			const testHandler = new TestCerebrasHandler({ cerebrasApiKey: "test" })
-			expect(testHandler.testConvertToolsForOpenAI(undefined)).toBeUndefined()
+			const testHandler = new TestCerebrasHandler(mockOptions)
+
+			const result = testHandler.testMapToolChoice({
+				type: "function",
+				function: { name: "my_tool" },
+			})
+
+			expect(result).toEqual({ type: "tool", toolName: "my_tool" })
 		})
 
-		it("should pass through non-function tools unchanged", () => {
+		it("should return undefined for null or undefined", () => {
 			class TestCerebrasHandler extends CerebrasHandler {
-				public testConvertToolsForOpenAI(tools: any[]) {
-					return this.convertToolsForOpenAI(tools)
+				public testMapToolChoice(toolChoice: any) {
+					return this.mapToolChoice(toolChoice)
 				}
 			}
 
-			const nonFunctionTool = { type: "other", data: "test" }
-			const testHandler = new TestCerebrasHandler({ cerebrasApiKey: "test" })
-			const converted = testHandler.testConvertToolsForOpenAI([nonFunctionTool])
+			const testHandler = new TestCerebrasHandler(mockOptions)
 
-			expect(converted![0]).toEqual(nonFunctionTool)
+			expect(testHandler.testMapToolChoice(null)).toBeUndefined()
+			expect(testHandler.testMapToolChoice(undefined)).toBeUndefined()
 		})
 	})
 })
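
Aside: the vi.hoisted wrapper at the top of this spec exists because Vitest hoists vi.mock() factory calls above all imports, so a plain top-level variable would still be uninitialized when the factory runs. A stripped-down sketch of the same pattern (same names as the spec):

	import { vi } from "vitest"

	// vi.hoisted runs this initializer before the hoisted vi.mock factories,
	// so mockStreamText already exists when the "ai" factory executes.
	const { mockStreamText } = vi.hoisted(() => ({ mockStreamText: vi.fn() }))

	vi.mock("ai", async (importOriginal) => ({
		...(await importOriginal<typeof import("ai")>()),
		streamText: mockStreamText,
	}))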

+ 124 - 304
src/api/providers/cerebras.ts

@@ -1,362 +1,182 @@
 import { Anthropic } from "@anthropic-ai/sdk"
+import { createCerebras } from "@ai-sdk/cerebras"
+import { streamText, generateText, ToolSet } from "ai"
 
-import { type CerebrasModelId, cerebrasDefaultModelId, cerebrasModels } from "@roo-code/types"
+import { cerebrasModels, cerebrasDefaultModelId, type CerebrasModelId, type ModelInfo } from "@roo-code/types"
 
 import type { ApiHandlerOptions } from "../../shared/api"
-import { calculateApiCostOpenAI } from "../../shared/cost"
-import { ApiStream } from "../transform/stream"
-import { convertToOpenAiMessages } from "../transform/openai-format"
-import { TagMatcher } from "../../utils/tag-matcher"
 
-import type { ApiHandlerCreateMessageMetadata, SingleCompletionHandler } from "../index"
-import { BaseProvider } from "./base-provider"
-import { DEFAULT_HEADERS } from "./constants"
-import { t } from "../../i18n"
+import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+import { getModelParams } from "../transform/model-params"
 
-const CEREBRAS_BASE_URL = "https://api.cerebras.ai/v1"
-const CEREBRAS_DEFAULT_TEMPERATURE = 0
+import { DEFAULT_HEADERS } from "./constants"
+import { BaseProvider } from "./base-provider"
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 
 const CEREBRAS_INTEGRATION_HEADER = "X-Cerebras-3rd-Party-Integration"
 const CEREBRAS_INTEGRATION_NAME = "roocode"
+const CEREBRAS_DEFAULT_TEMPERATURE = 0
 
+/**
+ * Cerebras provider using the dedicated @ai-sdk/cerebras package.
+ * Provides high-speed inference powered by Wafer-Scale Engines.
+ */
 export class CerebrasHandler extends BaseProvider implements SingleCompletionHandler {
-	private apiKey: string
-	private providerModels: typeof cerebrasModels
-	private defaultProviderModelId: CerebrasModelId
-	private options: ApiHandlerOptions
-	private lastUsage: { inputTokens: number; outputTokens: number } = { inputTokens: 0, outputTokens: 0 }
+	protected options: ApiHandlerOptions
+	protected provider: ReturnType<typeof createCerebras>
 
 	constructor(options: ApiHandlerOptions) {
 		super()
 		this.options = options
-		this.apiKey = options.cerebrasApiKey || ""
-		this.providerModels = cerebrasModels
-		this.defaultProviderModelId = cerebrasDefaultModelId
 
-		if (!this.apiKey) {
-			throw new Error("Cerebras API key is required")
-		}
+		// Create the Cerebras provider using AI SDK
+		this.provider = createCerebras({
+			apiKey: options.cerebrasApiKey ?? "not-provided",
+			headers: {
+				...DEFAULT_HEADERS,
+				[CEREBRAS_INTEGRATION_HEADER]: CEREBRAS_INTEGRATION_NAME,
+			},
+		})
 	}
 
-	getModel(): { id: CerebrasModelId; info: (typeof cerebrasModels)[CerebrasModelId] } {
-		const modelId = this.options.apiModelId as CerebrasModelId
-		const validModelId = modelId && this.providerModels[modelId] ? modelId : this.defaultProviderModelId
-
-		return {
-			id: validModelId,
-			info: this.providerModels[validModelId],
-		}
+	override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } {
+		const id = (this.options.apiModelId ?? cerebrasDefaultModelId) as CerebrasModelId
+		const info = cerebrasModels[id as keyof typeof cerebrasModels] || cerebrasModels[cerebrasDefaultModelId]
+		const params = getModelParams({ format: "openai", modelId: id, model: info, settings: this.options })
+		return { id, info, ...params }
 	}
 
 	/**
-	 * Override convertToolSchemaForOpenAI to remove unsupported schema fields for Cerebras.
-	 * Cerebras doesn't support minItems/maxItems in array schemas with strict mode.
+	 * Get the language model for the configured model ID.
 	 */
-	protected override convertToolSchemaForOpenAI(schema: any): any {
-		const converted = super.convertToolSchemaForOpenAI(schema)
-		return this.stripUnsupportedSchemaFields(converted)
+	protected getLanguageModel() {
+		const { id } = this.getModel()
+		return this.provider(id)
 	}
 
 	/**
-	 * Recursively strips unsupported schema fields for Cerebras.
-	 * Cerebras strict mode doesn't support minItems, maxItems on arrays.
+	 * Process usage metrics from the AI SDK response.
 	 */
-	private stripUnsupportedSchemaFields(schema: any): any {
-		if (!schema || typeof schema !== "object") {
-			return schema
+	protected processUsageMetrics(usage: {
+		inputTokens?: number
+		outputTokens?: number
+		details?: {
+			cachedInputTokens?: number
+			reasoningTokens?: number
 		}
+	}): ApiStreamUsageChunk {
+		return {
+			type: "usage",
+			inputTokens: usage.inputTokens || 0,
+			outputTokens: usage.outputTokens || 0,
+			cacheReadTokens: usage.details?.cachedInputTokens,
+			reasoningTokens: usage.details?.reasoningTokens,
+		}
+	}
 
-		const result = { ...schema }
-
-		// Remove unsupported array constraints
-		if (result.type === "array" || (Array.isArray(result.type) && result.type.includes("array"))) {
-			delete result.minItems
-			delete result.maxItems
+	/**
+	 * Map OpenAI tool_choice to AI SDK toolChoice format.
+	 */
+	protected mapToolChoice(
+		toolChoice: any,
+	): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined {
+		if (!toolChoice) {
+			return undefined
 		}
 
-		// Recursively process properties
-		if (result.properties) {
-			const newProps = { ...result.properties }
-			for (const key of Object.keys(newProps)) {
-				newProps[key] = this.stripUnsupportedSchemaFields(newProps[key])
+		// Handle string values
+		if (typeof toolChoice === "string") {
+			switch (toolChoice) {
+				case "auto":
+					return "auto"
+				case "none":
+					return "none"
+				case "required":
+					return "required"
+				default:
+					return "auto"
 			}
-			result.properties = newProps
 		}
 
-		// Recursively process array items
-		if (result.items) {
-			result.items = this.stripUnsupportedSchemaFields(result.items)
+		// Handle object values (OpenAI ChatCompletionNamedToolChoice format)
+		if (typeof toolChoice === "object" && "type" in toolChoice) {
+			if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) {
+				return { type: "tool", toolName: toolChoice.function.name }
+			}
 		}
 
-		return result
+		return undefined
 	}
 
 	/**
-	 * Override convertToolsForOpenAI to ensure all tools have consistent strict values.
-	 * Cerebras API requires all tools to have the same strict mode setting.
-	 * We use strict: false for all tools since MCP tools cannot use strict mode
-	 * (they have optional parameters from the MCP server schema).
+	 * Get the max tokens parameter to include in the request.
 	 */
-	protected override convertToolsForOpenAI(tools: any[] | undefined): any[] | undefined {
-		if (!tools) {
-			return undefined
-		}
-
-		return tools.map((tool) => {
-			if (tool.type !== "function") {
-				return tool
-			}
-
-			return {
-				...tool,
-				function: {
-					...tool.function,
-					strict: false,
-					parameters: this.convertToolSchemaForOpenAI(tool.function.parameters),
-				},
-			}
-		})
+	protected getMaxOutputTokens(): number | undefined {
+		const { info } = this.getModel()
+		return this.options.modelMaxTokens || info.maxTokens || undefined
 	}
 
-	async *createMessage(
+	/**
+	 * Create a message stream using the AI SDK.
+	 */
+	override async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		const { id: model, info: modelInfo } = this.getModel()
-		const max_tokens = modelInfo.maxTokens
-		const temperature = this.options.modelTemperature ?? CEREBRAS_DEFAULT_TEMPERATURE
-
-		// Convert Anthropic messages to OpenAI format (Cerebras is OpenAI-compatible)
-		const openaiMessages = convertToOpenAiMessages(messages)
-
-		// Prepare request body following Cerebras API specification exactly
-		const requestBody: Record<string, any> = {
-			model,
-			messages: [{ role: "system", content: systemPrompt }, ...openaiMessages],
-			stream: true,
-			// Use max_completion_tokens (Cerebras-specific parameter)
-			...(max_tokens && max_tokens > 0 && max_tokens <= 32768 ? { max_completion_tokens: max_tokens } : {}),
-			// Clamp temperature to Cerebras range (0 to 1.5)
-			...(temperature !== undefined && temperature !== CEREBRAS_DEFAULT_TEMPERATURE
-				? {
-						temperature: Math.max(0, Math.min(1.5, temperature)),
-					}
-				: {}),
-			// Native tool calling support
-			tools: this.convertToolsForOpenAI(metadata?.tools),
-			tool_choice: metadata?.tool_choice,
-			parallel_tool_calls: metadata?.parallelToolCalls ?? true,
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		// Convert messages to AI SDK format
+		const aiSdkMessages = convertToAiSdkMessages(messages)
+
+		// Convert tools to OpenAI format first, then to AI SDK format
+		const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
+		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined
+
+		// Build the request options
+		const requestOptions: Parameters<typeof streamText>[0] = {
+			model: languageModel,
+			system: systemPrompt,
+			messages: aiSdkMessages,
+			temperature: this.options.modelTemperature ?? temperature ?? CEREBRAS_DEFAULT_TEMPERATURE,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			tools: aiSdkTools,
+			toolChoice: this.mapToolChoice(metadata?.tool_choice),
 		}
 
-		try {
-			const response = await fetch(`${CEREBRAS_BASE_URL}/chat/completions`, {
-				method: "POST",
-				headers: {
-					...DEFAULT_HEADERS,
-					"Content-Type": "application/json",
-					Authorization: `Bearer ${this.apiKey}`,
-					[CEREBRAS_INTEGRATION_HEADER]: CEREBRAS_INTEGRATION_NAME,
-				},
-				body: JSON.stringify(requestBody),
-			})
-
-			if (!response.ok) {
-				const errorText = await response.text()
-
-				let errorMessage = "Unknown error"
-				try {
-					const errorJson = JSON.parse(errorText)
-					errorMessage = errorJson.error?.message || errorJson.message || JSON.stringify(errorJson, null, 2)
-				} catch {
-					errorMessage = errorText || `HTTP ${response.status}`
-				}
-
-				// Provide more actionable error messages
-				if (response.status === 401) {
-					throw new Error(t("common:errors.cerebras.authenticationFailed"))
-				} else if (response.status === 403) {
-					throw new Error(t("common:errors.cerebras.accessForbidden"))
-				} else if (response.status === 429) {
-					throw new Error(t("common:errors.cerebras.rateLimitExceeded"))
-				} else if (response.status >= 500) {
-					throw new Error(t("common:errors.cerebras.serverError", { status: response.status }))
-				} else {
-					throw new Error(
-						t("common:errors.cerebras.genericError", { status: response.status, message: errorMessage }),
-					)
-				}
-			}
-
-			if (!response.body) {
-				throw new Error(t("common:errors.cerebras.noResponseBody"))
-			}
-
-			// Initialize TagMatcher to parse <think>...</think> tags
-			const matcher = new TagMatcher(
-				"think",
-				(chunk) =>
-					({
-						type: chunk.matched ? "reasoning" : "text",
-						text: chunk.data,
-					}) as const,
-			)
-
-			const reader = response.body.getReader()
-			const decoder = new TextDecoder()
-			let buffer = ""
-			let inputTokens = 0
-			let outputTokens = 0
-
-			try {
-				while (true) {
-					const { done, value } = await reader.read()
-					if (done) break
-
-					buffer += decoder.decode(value, { stream: true })
-					const lines = buffer.split("\n")
-					buffer = lines.pop() || "" // Keep the last incomplete line in the buffer
-
-					for (const line of lines) {
-						if (line.trim() === "") continue
-
-						try {
-							if (line.startsWith("data: ")) {
-								const jsonStr = line.slice(6).trim()
-								if (jsonStr === "[DONE]") {
-									continue
-								}
-
-								const parsed = JSON.parse(jsonStr)
+		// Use streamText for streaming responses
+		const result = streamText(requestOptions)
 
-								const delta = parsed.choices?.[0]?.delta
-
-								// Handle text content - parse for thinking tokens
-								if (delta?.content) {
-									const content = delta.content
-
-									// Use TagMatcher to parse <think>...</think> tags
-									for (const chunk of matcher.update(content)) {
-										yield chunk
-									}
-								}
-
-								// Handle tool calls in stream - emit partial chunks for NativeToolCallParser
-								if (delta?.tool_calls) {
-									for (const toolCall of delta.tool_calls) {
-										yield {
-											type: "tool_call_partial",
-											index: toolCall.index,
-											id: toolCall.id,
-											name: toolCall.function?.name,
-											arguments: toolCall.function?.arguments,
-										}
-									}
-								}
-
-								// Handle usage information if available
-								if (parsed.usage) {
-									inputTokens = parsed.usage.prompt_tokens || 0
-									outputTokens = parsed.usage.completion_tokens || 0
-								}
-							}
-						} catch (error) {
-							// Silently ignore malformed streaming data lines
-						}
-					}
-				}
-			} finally {
-				reader.releaseLock()
-			}
-
-			// Process any remaining content in the matcher
-			for (const chunk of matcher.final()) {
+		// Process the full stream to get all events including reasoning
+		for await (const part of result.fullStream) {
+			for (const chunk of processAiSdkStreamPart(part)) {
 				yield chunk
 			}
+		}
 
-			// Provide token usage estimate if not available from API
-			if (inputTokens === 0 || outputTokens === 0) {
-				const inputText =
-					systemPrompt +
-					openaiMessages
-						.map((m: any) => (typeof m.content === "string" ? m.content : JSON.stringify(m.content)))
-						.join("")
-				inputTokens = inputTokens || Math.ceil(inputText.length / 4) // Rough estimate: 4 chars per token
-				outputTokens = outputTokens || Math.ceil((max_tokens || 1000) / 10) // Rough estimate
-			}
-
-			// Store usage for cost calculation
-			this.lastUsage = { inputTokens, outputTokens }
-
-			yield {
-				type: "usage",
-				inputTokens,
-				outputTokens,
-			}
-		} catch (error) {
-			if (error instanceof Error) {
-				throw new Error(t("common:errors.cerebras.completionError", { error: error.message }))
-			}
-			throw error
+		// Yield usage metrics at the end
+		const usage = await result.usage
+		if (usage) {
+			yield this.processUsageMetrics(usage)
 		}
 	}
 
+	/**
+	 * Complete a prompt using the AI SDK generateText.
+	 */
 	async completePrompt(prompt: string): Promise<string> {
-		const { id: model } = this.getModel()
-
-		// Prepare request body for non-streaming completion
-		const requestBody = {
-			model,
-			messages: [{ role: "user", content: prompt }],
-			stream: false,
-		}
-
-		try {
-			const response = await fetch(`${CEREBRAS_BASE_URL}/chat/completions`, {
-				method: "POST",
-				headers: {
-					...DEFAULT_HEADERS,
-					"Content-Type": "application/json",
-					Authorization: `Bearer ${this.apiKey}`,
-					[CEREBRAS_INTEGRATION_HEADER]: CEREBRAS_INTEGRATION_NAME,
-				},
-				body: JSON.stringify(requestBody),
-			})
-
-			if (!response.ok) {
-				const errorText = await response.text()
-
-				// Provide consistent error handling with createMessage
-				if (response.status === 401) {
-					throw new Error(t("common:errors.cerebras.authenticationFailed"))
-				} else if (response.status === 403) {
-					throw new Error(t("common:errors.cerebras.accessForbidden"))
-				} else if (response.status === 429) {
-					throw new Error(t("common:errors.cerebras.rateLimitExceeded"))
-				} else if (response.status >= 500) {
-					throw new Error(t("common:errors.cerebras.serverError", { status: response.status }))
-				} else {
-					throw new Error(
-						t("common:errors.cerebras.genericError", { status: response.status, message: errorText }),
-					)
-				}
-			}
-
-			const result = await response.json()
-			return result.choices?.[0]?.message?.content || ""
-		} catch (error) {
-			if (error instanceof Error) {
-				throw new Error(t("common:errors.cerebras.completionError", { error: error.message }))
-			}
-			throw error
-		}
-	}
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		const { text } = await generateText({
+			model: languageModel,
+			prompt,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			temperature: this.options.modelTemperature ?? temperature ?? CEREBRAS_DEFAULT_TEMPERATURE,
+		})
 
-	getApiCost(metadata: ApiHandlerCreateMessageMetadata): number {
-		const { info } = this.getModel()
-		// Use actual token usage from the last request
-		const { inputTokens, outputTokens } = this.lastUsage
-		const { totalCost } = calculateApiCostOpenAI(info, inputTokens, outputTokens)
-		return totalCost
+		return text
 	}
 }
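
For orientation, a hypothetical consumer of the migrated handler, mirroring the fixtures in the spec above (not part of the commit; the chunk types shown are the ones the tests assert on):

	const handler = new CerebrasHandler({
		cerebrasApiKey: process.env.CEREBRAS_API_KEY ?? "",
		apiModelId: "llama-3.3-70b",
	})

	const stream = handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: [{ type: "text", text: "Hello!" }] },
	])

	for await (const chunk of stream) {
		// ApiStream chunks: "reasoning", "text", "tool_call_start"/"_delta"/"_end", "usage"
		if (chunk.type === "text") process.stdout.write(chunk.text)
		if (chunk.type === "usage") console.log(chunk.inputTokens, chunk.outputTokens)
	}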

+ 1 - 0
src/package.json

@@ -450,6 +450,7 @@
 		"clean": "rimraf README.md CHANGELOG.md LICENSE dist logs mock .turbo"
 	},
 	"dependencies": {
+		"@ai-sdk/cerebras": "^1.0.0",
 		"@ai-sdk/deepseek": "^2.0.14",
 		"@anthropic-ai/bedrock-sdk": "^0.10.2",
 		"@anthropic-ai/sdk": "^0.37.0",