
feat: migrate Groq provider to @ai-sdk/groq (#11088)

Daniel 1 week ago
Parent
Commit
e771a4936b
4 changed files with 787 additions and 157 deletions
  1. pnpm-lock.yaml (+21 -6)
  2. src/api/providers/__tests__/groq.spec.ts (+574 -141)
  3. src/api/providers/groq.ts (+191 -10)
  4. src/package.json (+1 -0)
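
In short, the handler drops the generic OpenAI-compatible base class and the openai client in favour of the Vercel AI SDK: createGroq from @ai-sdk/groq builds the provider, and streamText / generateText from ai drive streaming and one-shot completions. Condensed from the groq.ts changes below, the new call path looks roughly like this (model ID, prompt, and env var are placeholders):

import { createGroq } from "@ai-sdk/groq"
import { streamText } from "ai"

// Same Groq endpoint as before, now wrapped by the dedicated provider package.
const groq = createGroq({
	baseURL: "https://api.groq.com/openai/v1",
	apiKey: process.env.GROQ_API_KEY ?? "not-provided",
})

const result = streamText({
	model: groq("moonshotai/kimi-k2-instruct-0905"), // placeholder model ID
	system: "You are a helpful assistant.",
	messages: [{ role: "user", content: "Hello!" }],
	temperature: 0.5, // the handler's Groq default
})

for await (const part of result.fullStream) {
	if (part.type === "text-delta") process.stdout.write(part.text)
}

// Usage and Groq cache metrics (via providerMetadata) resolve once the stream ends.
console.log(await result.usage, await result.providerMetadata)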

+ 21 - 6
pnpm-lock.yaml

@@ -748,7 +748,10 @@ importers:
         version: 1.0.35(zod@3.25.76)
       '@ai-sdk/deepseek':
         specifier: ^2.0.14
-        version: 2.0.15(zod@3.25.76)
+        version: 2.0.14(zod@3.25.76)
+      '@ai-sdk/groq':
+        specifier: ^3.0.19
+        version: 3.0.19(zod@3.25.76)
       '@anthropic-ai/bedrock-sdk':
         specifier: ^0.10.2
         version: 0.10.4
@@ -1399,8 +1402,8 @@ packages:
     peerDependencies:
       zod: 3.25.76
 
-  '@ai-sdk/deepseek@2.0.15':
-    resolution: {integrity: sha512-3wJUjNjGrTZS3K8OEfHD1PZYhzkcXuoL8KIVtzi6WrC5xrDQPjCBPATmdKPV7DgDCF+wujQOaMz5cv40Yg+hog==}
+  '@ai-sdk/deepseek@2.0.14':
+    resolution: {integrity: sha512-1vXh8sVwRJYd1JO57qdy1rACucaNLDoBRCwOER3EbPgSF2vNVPcdJywGutA01Bhn7Cta+UJQ+k5y/yzMAIpP2w==}
     engines: {node: '>=18'}
     peerDependencies:
       zod: 3.25.76
@@ -1411,6 +1414,12 @@ packages:
     peerDependencies:
       zod: 3.25.76
 
+  '@ai-sdk/groq@3.0.19':
+    resolution: {integrity: sha512-WAeGVnp9rvU3RUvu6S1HiD8hAjKgNlhq+z3m4j5Z1fIKRXqcKjOscVZGwL36If8qxsqXNVCtG3ltXawM5UAa8w==}
+    engines: {node: '>=18'}
+    peerDependencies:
+      zod: 3.25.76
+
   '@ai-sdk/[email protected]':
     resolution: {integrity: sha512-znBvaVHM0M6yWNerIEy3hR+O8ZK2sPcE7e2cxfb6kYLEX3k//JH5VDnRnajseVofg7LXtTCFFdjsB7WLf1BdeQ==}
     engines: {node: '>=18'}
@@ -11011,10 +11020,10 @@ snapshots:
       '@ai-sdk/provider-utils': 3.0.20(zod@3.25.76)
       zod: 3.25.76
 
-  '@ai-sdk/deepseek@2.0.15(zod@3.25.76)':
+  '@ai-sdk/deepseek@2.0.14(zod@3.25.76)':
     dependencies:
-      '@ai-sdk/provider': 3.0.6
-      '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76)
+      '@ai-sdk/provider': 3.0.5
+      '@ai-sdk/provider-utils': 4.0.10(zod@3.25.76)
       zod: 3.25.76
 
   '@ai-sdk/[email protected]([email protected])':
@@ -11024,6 +11033,12 @@ snapshots:
       '@vercel/oidc': 3.1.0
       zod: 3.25.76
 
+  '@ai-sdk/groq@3.0.19(zod@3.25.76)':
+    dependencies:
+      '@ai-sdk/provider': 3.0.6
+      '@ai-sdk/provider-utils': 4.0.11(zod@3.25.76)
+      zod: 3.25.76
+
   '@ai-sdk/[email protected]([email protected])':
     dependencies:
       '@ai-sdk/provider': 2.0.1

+ 574 - 141
src/api/providers/__tests__/groq.spec.ts

@@ -1,192 +1,625 @@
 // npx vitest run src/api/providers/__tests__/groq.spec.ts
 
-import OpenAI from "openai"
-import { Anthropic } from "@anthropic-ai/sdk"
+// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls
+const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({
+	mockStreamText: vi.fn(),
+	mockGenerateText: vi.fn(),
+}))
 
-import { type GroqModelId, groqDefaultModelId, groqModels } from "@roo-code/types"
-
-import { GroqHandler } from "../groq"
-
-vitest.mock("openai", () => {
-	const createMock = vitest.fn()
+vi.mock("ai", async (importOriginal) => {
+	const actual = await importOriginal<typeof import("ai")>()
 	return {
-		default: vitest.fn(() => ({ chat: { completions: { create: createMock } } })),
+		...actual,
+		streamText: mockStreamText,
+		generateText: mockGenerateText,
 	}
 })
 
+vi.mock("@ai-sdk/groq", () => ({
+	createGroq: vi.fn(() => {
+		// Return a function that returns a mock language model
+		return vi.fn(() => ({
+			modelId: "moonshotai/kimi-k2-instruct-0905",
+			provider: "groq",
+		}))
+	}),
+}))
+
+import type { Anthropic } from "@anthropic-ai/sdk"
+
+import { groqDefaultModelId, groqModels, type GroqModelId } from "@roo-code/types"
+
+import type { ApiHandlerOptions } from "../../../shared/api"
+
+import { GroqHandler } from "../groq"
+
 describe("GroqHandler", () => {
 	let handler: GroqHandler
-	let mockCreate: any
+	let mockOptions: ApiHandlerOptions
 
 	beforeEach(() => {
-		vitest.clearAllMocks()
-		mockCreate = (OpenAI as unknown as any)().chat.completions.create
-		handler = new GroqHandler({ groqApiKey: "test-groq-api-key" })
+		mockOptions = {
+			groqApiKey: "test-groq-api-key",
+			apiModelId: "moonshotai/kimi-k2-instruct-0905",
+		}
+		handler = new GroqHandler(mockOptions)
+		vi.clearAllMocks()
 	})
 
-	it("should use the correct Groq base URL", () => {
-		new GroqHandler({ groqApiKey: "test-groq-api-key" })
-		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.groq.com/openai/v1" }))
-	})
+	describe("constructor", () => {
+		it("should initialize with provided options", () => {
+			expect(handler).toBeInstanceOf(GroqHandler)
+			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
+		})
 
-	it("should use the provided API key", () => {
-		const groqApiKey = "test-groq-api-key"
-		new GroqHandler({ groqApiKey })
-		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: groqApiKey }))
+		it("should use default model ID if not provided", () => {
+			const handlerWithoutModel = new GroqHandler({
+				...mockOptions,
+				apiModelId: undefined,
+			})
+			expect(handlerWithoutModel.getModel().id).toBe(groqDefaultModelId)
+		})
 	})
 
-	it("should return default model when no model is specified", () => {
-		const model = handler.getModel()
-		expect(model.id).toBe(groqDefaultModelId)
-		expect(model.info).toEqual(groqModels[groqDefaultModelId])
-	})
+	describe("getModel", () => {
+		it("should return default model when no model is specified", () => {
+			const handlerWithoutModel = new GroqHandler({
+				groqApiKey: "test-groq-api-key",
+			})
+			const model = handlerWithoutModel.getModel()
+			expect(model.id).toBe(groqDefaultModelId)
+			expect(model.info).toEqual(groqModels[groqDefaultModelId])
+		})
 
-	it("should return specified model when valid model is provided", () => {
-		const testModelId: GroqModelId = "llama-3.3-70b-versatile"
-		const handlerWithModel = new GroqHandler({ apiModelId: testModelId, groqApiKey: "test-groq-api-key" })
-		const model = handlerWithModel.getModel()
-		expect(model.id).toBe(testModelId)
-		expect(model.info).toEqual(groqModels[testModelId])
-	})
+		it("should return specified model when valid model is provided", () => {
+			const testModelId: GroqModelId = "llama-3.3-70b-versatile"
+			const handlerWithModel = new GroqHandler({
+				apiModelId: testModelId,
+				groqApiKey: "test-groq-api-key",
+			})
+			const model = handlerWithModel.getModel()
+			expect(model.id).toBe(testModelId)
+			expect(model.info).toEqual(groqModels[testModelId])
+		})
 
-	it("completePrompt method should return text from Groq API", async () => {
-		const expectedResponse = "This is a test response from Groq"
-		mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] })
-		const result = await handler.completePrompt("test prompt")
-		expect(result).toBe(expectedResponse)
-	})
+		it("should return model info for llama-3.1-8b-instant", () => {
+			const handlerWithLlama = new GroqHandler({
+				...mockOptions,
+				apiModelId: "llama-3.1-8b-instant",
+			})
+			const model = handlerWithLlama.getModel()
+			expect(model.id).toBe("llama-3.1-8b-instant")
+			expect(model.info).toBeDefined()
+			expect(model.info.maxTokens).toBe(8192)
+			expect(model.info.contextWindow).toBe(131072)
+			expect(model.info.supportsImages).toBe(false)
+			expect(model.info.supportsPromptCache).toBe(false)
+		})
 
-	it("should handle errors in completePrompt", async () => {
-		const errorMessage = "Groq API error"
-		mockCreate.mockRejectedValueOnce(new Error(errorMessage))
-		await expect(handler.completePrompt("test prompt")).rejects.toThrow(`Groq completion error: ${errorMessage}`)
+		it("should return model info for kimi-k2 which supports prompt cache", () => {
+			const handlerWithKimi = new GroqHandler({
+				...mockOptions,
+				apiModelId: "moonshotai/kimi-k2-instruct-0905",
+			})
+			const model = handlerWithKimi.getModel()
+			expect(model.id).toBe("moonshotai/kimi-k2-instruct-0905")
+			expect(model.info).toBeDefined()
+			expect(model.info.maxTokens).toBe(16384)
+			expect(model.info.contextWindow).toBe(262144)
+			expect(model.info.supportsPromptCache).toBe(true)
+		})
+
+		it("should return provided model ID with default model info if model does not exist", () => {
+			const handlerWithInvalidModel = new GroqHandler({
+				...mockOptions,
+				apiModelId: "invalid-model",
+			})
+			const model = handlerWithInvalidModel.getModel()
+			expect(model.id).toBe("invalid-model")
+			expect(model.info).toBeDefined()
+			// Should use default model info
+			expect(model.info).toBe(groqModels[groqDefaultModelId])
+		})
+
+		it("should include model parameters from getModelParams", () => {
+			const model = handler.getModel()
+			expect(model).toHaveProperty("temperature")
+			expect(model).toHaveProperty("maxTokens")
+		})
 	})
 
-	it("createMessage should yield text content from stream", async () => {
-		const testContent = "This is test content from Groq stream"
-
-		mockCreate.mockImplementationOnce(() => {
-			return {
-				[Symbol.asyncIterator]: () => ({
-					next: vitest
-						.fn()
-						.mockResolvedValueOnce({
-							done: false,
-							value: { choices: [{ delta: { content: testContent } }] },
-						})
-						.mockResolvedValueOnce({ done: true }),
-				}),
+	describe("createMessage", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [
+					{
+						type: "text" as const,
+						text: "Hello!",
+					},
+				],
+			},
+		]
+
+		it("should handle streaming responses", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response from Groq" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			const mockProviderMetadata = Promise.resolve({})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBeGreaterThan(0)
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Test response from Groq")
+		})
+
+		it("should include usage information", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
 			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 20,
+			})
+
+			const mockProviderMetadata = Promise.resolve({})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].inputTokens).toBe(10)
+			expect(usageChunks[0].outputTokens).toBe(20)
+		})
+
+		it("should handle cached tokens in usage data from providerMetadata", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 100,
+				outputTokens: 50,
+			})
+
+			// Groq provides cache metrics via providerMetadata for supported models
+			const mockProviderMetadata = Promise.resolve({
+				groq: {
+					promptCacheHitTokens: 30,
+					promptCacheMissTokens: 70,
+				},
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].inputTokens).toBe(100)
+			expect(usageChunks[0].outputTokens).toBe(50)
+			expect(usageChunks[0].cacheReadTokens).toBe(30)
+			expect(usageChunks[0].cacheWriteTokens).toBe(70)
 		})
 
-		const stream = handler.createMessage("system prompt", [])
-		const firstChunk = await stream.next()
+		it("should handle usage with details.cachedInputTokens when providerMetadata is not available", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 100,
+				outputTokens: 50,
+				details: {
+					cachedInputTokens: 25,
+				},
+			})
+
+			const mockProviderMetadata = Promise.resolve({})
 
-		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({ type: "text", text: testContent })
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].cacheReadTokens).toBe(25)
+			expect(usageChunks[0].cacheWriteTokens).toBeUndefined()
+		})
+
+		it("should pass correct temperature (0.5 default) to streamText", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+				providerMetadata: Promise.resolve({}),
+			})
+
+			const handlerWithDefaultTemp = new GroqHandler({
+				groqApiKey: "test-key",
+				apiModelId: "llama-3.1-8b-instant",
+			})
+
+			const stream = handlerWithDefaultTemp.createMessage(systemPrompt, messages)
+			for await (const _ of stream) {
+				// consume stream
+			}
+
+			expect(mockStreamText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					temperature: 0.5,
+				}),
+			)
+		})
 	})
 
-	it("createMessage should yield usage data from stream", async () => {
-		mockCreate.mockImplementationOnce(() => {
-			return {
-				[Symbol.asyncIterator]: () => ({
-					next: vitest
-						.fn()
-						.mockResolvedValueOnce({
-							done: false,
-							value: { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 20 } },
-						})
-						.mockResolvedValueOnce({ done: true }),
+	describe("completePrompt", () => {
+		it("should complete a prompt using generateText", async () => {
+			mockGenerateText.mockResolvedValue({
+				text: "Test completion from Groq",
+			})
+
+			const result = await handler.completePrompt("Test prompt")
+
+			expect(result).toBe("Test completion from Groq")
+			expect(mockGenerateText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					prompt: "Test prompt",
+				}),
+			)
+		})
+
+		it("should use default temperature in completePrompt", async () => {
+			mockGenerateText.mockResolvedValue({
+				text: "Test completion",
+			})
+
+			await handler.completePrompt("Test prompt")
+
+			expect(mockGenerateText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					temperature: 0.5,
 				}),
+			)
+		})
+	})
+
+	describe("processUsageMetrics", () => {
+		it("should correctly process usage metrics including cache information from providerMetadata", () => {
+			class TestGroqHandler extends GroqHandler {
+				public testProcessUsageMetrics(usage: any, providerMetadata?: any) {
+					return this.processUsageMetrics(usage, providerMetadata)
+				}
+			}
+
+			const testHandler = new TestGroqHandler(mockOptions)
+
+			const usage = {
+				inputTokens: 100,
+				outputTokens: 50,
 			}
+
+			const providerMetadata = {
+				groq: {
+					promptCacheHitTokens: 20,
+					promptCacheMissTokens: 80,
+				},
+			}
+
+			const result = testHandler.testProcessUsageMetrics(usage, providerMetadata)
+
+			expect(result.type).toBe("usage")
+			expect(result.inputTokens).toBe(100)
+			expect(result.outputTokens).toBe(50)
+			expect(result.cacheWriteTokens).toBe(80)
+			expect(result.cacheReadTokens).toBe(20)
 		})
 
-		const stream = handler.createMessage("system prompt", [])
-		const firstChunk = await stream.next()
+		it("should handle missing cache metrics gracefully", () => {
+			class TestGroqHandler extends GroqHandler {
+				public testProcessUsageMetrics(usage: any, providerMetadata?: any) {
+					return this.processUsageMetrics(usage, providerMetadata)
+				}
+			}
+
+			const testHandler = new TestGroqHandler(mockOptions)
+
+			const usage = {
+				inputTokens: 100,
+				outputTokens: 50,
+			}
+
+			const result = testHandler.testProcessUsageMetrics(usage)
 
-		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toMatchObject({
-			type: "usage",
-			inputTokens: 10,
-			outputTokens: 20,
+			expect(result.type).toBe("usage")
+			expect(result.inputTokens).toBe(100)
+			expect(result.outputTokens).toBe(50)
+			expect(result.cacheWriteTokens).toBeUndefined()
+			expect(result.cacheReadTokens).toBeUndefined()
+		})
+
+		it("should include reasoning tokens when provided", () => {
+			class TestGroqHandler extends GroqHandler {
+				public testProcessUsageMetrics(usage: any, providerMetadata?: any) {
+					return this.processUsageMetrics(usage, providerMetadata)
+				}
+			}
+
+			const testHandler = new TestGroqHandler(mockOptions)
+
+			const usage = {
+				inputTokens: 100,
+				outputTokens: 50,
+				details: {
+					reasoningTokens: 30,
+				},
+			}
+
+			const result = testHandler.testProcessUsageMetrics(usage)
+
+			expect(result.reasoningTokens).toBe(30)
 		})
-		// cacheWriteTokens and cacheReadTokens will be undefined when 0
-		expect(firstChunk.value.cacheWriteTokens).toBeUndefined()
-		expect(firstChunk.value.cacheReadTokens).toBeUndefined()
-		// Check that totalCost is a number (we don't need to test the exact value as that's tested in cost.spec.ts)
-		expect(typeof firstChunk.value.totalCost).toBe("number")
 	})
 
-	it("createMessage should handle cached tokens in usage data", async () => {
-		mockCreate.mockImplementationOnce(() => {
-			return {
-				[Symbol.asyncIterator]: () => ({
-					next: vitest
-						.fn()
-						.mockResolvedValueOnce({
-							done: false,
-							value: {
-								choices: [{ delta: {} }],
-								usage: {
-									prompt_tokens: 100,
-									completion_tokens: 50,
-									prompt_tokens_details: {
-										cached_tokens: 30,
-									},
-								},
+	describe("tool handling", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [{ type: "text" as const, text: "Hello!" }],
+			},
+		]
+
+		it("should handle tool calls in streaming", async () => {
+			async function* mockFullStream() {
+				yield {
+					type: "tool-input-start",
+					id: "tool-call-1",
+					toolName: "read_file",
+				}
+				yield {
+					type: "tool-input-delta",
+					id: "tool-call-1",
+					delta: '{"path":"test.ts"}',
+				}
+				yield {
+					type: "tool-input-end",
+					id: "tool-call-1",
+				}
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			const mockProviderMetadata = Promise.resolve({})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: [
+					{
+						type: "function",
+						function: {
+							name: "read_file",
+							description: "Read a file",
+							parameters: {
+								type: "object",
+								properties: { path: { type: "string" } },
+								required: ["path"],
 							},
-						})
-						.mockResolvedValueOnce({ done: true }),
-				}),
+						},
+					},
+				],
+			})
+
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
 			}
+
+			const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start")
+			const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta")
+			const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end")
+
+			expect(toolCallStartChunks.length).toBe(1)
+			expect(toolCallStartChunks[0].id).toBe("tool-call-1")
+			expect(toolCallStartChunks[0].name).toBe("read_file")
+
+			expect(toolCallDeltaChunks.length).toBe(1)
+			expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}')
+
+			expect(toolCallEndChunks.length).toBe(1)
+			expect(toolCallEndChunks[0].id).toBe("tool-call-1")
 		})
 
-		const stream = handler.createMessage("system prompt", [])
-		const firstChunk = await stream.next()
+		it("should ignore tool-call events to prevent duplicate tools in UI", async () => {
+			async function* mockFullStream() {
+				yield {
+					type: "tool-call",
+					toolCallId: "tool-call-1",
+					toolName: "read_file",
+					input: { path: "test.ts" },
+				}
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			const mockProviderMetadata = Promise.resolve({})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: [
+					{
+						type: "function",
+						function: {
+							name: "read_file",
+							description: "Read a file",
+							parameters: {
+								type: "object",
+								properties: { path: { type: "string" } },
+								required: ["path"],
+							},
+						},
+					},
+				],
+			})
 
-		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toMatchObject({
-			type: "usage",
-			inputTokens: 100,
-			outputTokens: 50,
-			cacheReadTokens: 30,
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// tool-call events are ignored, so no tool_call chunks should be emitted
+			const toolCallChunks = chunks.filter((c) => c.type === "tool_call")
+			expect(toolCallChunks.length).toBe(0)
 		})
-		// cacheWriteTokens will be undefined when 0
-		expect(firstChunk.value.cacheWriteTokens).toBeUndefined()
-		expect(typeof firstChunk.value.totalCost).toBe("number")
 	})
 
-	it("createMessage should pass correct parameters to Groq client", async () => {
-		const modelId: GroqModelId = "llama-3.1-8b-instant"
-		const modelInfo = groqModels[modelId]
-		const handlerWithModel = new GroqHandler({ apiModelId: modelId, groqApiKey: "test-groq-api-key" })
+	describe("getMaxOutputTokens", () => {
+		it("should return maxTokens from model info", () => {
+			class TestGroqHandler extends GroqHandler {
+				public testGetMaxOutputTokens() {
+					return this.getMaxOutputTokens()
+				}
+			}
 
-		mockCreate.mockImplementationOnce(() => {
-			return {
-				[Symbol.asyncIterator]: () => ({
-					async next() {
-						return { done: true }
-					},
-				}),
+			const testHandler = new TestGroqHandler({
+				...mockOptions,
+				apiModelId: "llama-3.1-8b-instant",
+			})
+			const result = testHandler.testGetMaxOutputTokens()
+
+			// llama-3.1-8b-instant has maxTokens of 8192
+			expect(result).toBe(8192)
+		})
+
+		it("should use modelMaxTokens when provided", () => {
+			class TestGroqHandler extends GroqHandler {
+				public testGetMaxOutputTokens() {
+					return this.getMaxOutputTokens()
+				}
 			}
+
+			const customMaxTokens = 5000
+			const testHandler = new TestGroqHandler({
+				...mockOptions,
+				modelMaxTokens: customMaxTokens,
+			})
+
+			const result = testHandler.testGetMaxOutputTokens()
+			expect(result).toBe(customMaxTokens)
 		})
+	})
 
-		const systemPrompt = "Test system prompt for Groq"
-		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Groq" }]
+	describe("mapToolChoice", () => {
+		it("should handle string tool choices", () => {
+			class TestGroqHandler extends GroqHandler {
+				public testMapToolChoice(toolChoice: any) {
+					return this.mapToolChoice(toolChoice)
+				}
+			}
 
-		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
-		await messageGenerator.next()
+			const testHandler = new TestGroqHandler(mockOptions)
+
+			expect(testHandler.testMapToolChoice("auto")).toBe("auto")
+			expect(testHandler.testMapToolChoice("none")).toBe("none")
+			expect(testHandler.testMapToolChoice("required")).toBe("required")
+			expect(testHandler.testMapToolChoice("unknown")).toBe("auto")
+		})
+
+		it("should handle object tool choice with function name", () => {
+			class TestGroqHandler extends GroqHandler {
+				public testMapToolChoice(toolChoice: any) {
+					return this.mapToolChoice(toolChoice)
+				}
+			}
 
-		expect(mockCreate).toHaveBeenCalledWith(
-			expect.objectContaining({
-				model: modelId,
-				max_tokens: modelInfo.maxTokens,
-				temperature: 0.5,
-				messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
-				stream: true,
-				stream_options: { include_usage: true },
-			}),
-			undefined,
-		)
+			const testHandler = new TestGroqHandler(mockOptions)
+
+			const result = testHandler.testMapToolChoice({
+				type: "function",
+				function: { name: "my_tool" },
+			})
+
+			expect(result).toEqual({ type: "tool", toolName: "my_tool" })
+		})
+
+		it("should return undefined for null or undefined", () => {
+			class TestGroqHandler extends GroqHandler {
+				public testMapToolChoice(toolChoice: any) {
+					return this.mapToolChoice(toolChoice)
+				}
+			}
+
+			const testHandler = new TestGroqHandler(mockOptions)
+
+			expect(testHandler.testMapToolChoice(null)).toBeUndefined()
+			expect(testHandler.testMapToolChoice(undefined)).toBeUndefined()
+		})
 	})
 })

+ 191 - 10
src/api/providers/groq.ts

@@ -1,19 +1,200 @@
-import { type GroqModelId, groqDefaultModelId, groqModels } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import { createGroq } from "@ai-sdk/groq"
+import { streamText, generateText, ToolSet } from "ai"
+
+import { groqModels, groqDefaultModelId, type ModelInfo } from "@roo-code/types"
 
 import type { ApiHandlerOptions } from "../../shared/api"
 
-import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
+import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+import { getModelParams } from "../transform/model-params"
+
+import { DEFAULT_HEADERS } from "./constants"
+import { BaseProvider } from "./base-provider"
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+
+const GROQ_DEFAULT_TEMPERATURE = 0.5
+
+/**
+ * Groq provider using the dedicated @ai-sdk/groq package.
+ * Provides native support for reasoning models and prompt caching.
+ */
+export class GroqHandler extends BaseProvider implements SingleCompletionHandler {
+	protected options: ApiHandlerOptions
+	protected provider: ReturnType<typeof createGroq>
 
-export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
 	constructor(options: ApiHandlerOptions) {
-		super({
-			...options,
-			providerName: "Groq",
+		super()
+		this.options = options
+
+		// Create the Groq provider using AI SDK
+		this.provider = createGroq({
 			baseURL: "https://api.groq.com/openai/v1",
-			apiKey: options.groqApiKey,
-			defaultProviderModelId: groqDefaultModelId,
-			providerModels: groqModels,
-			defaultTemperature: 0.5,
+			apiKey: options.groqApiKey ?? "not-provided",
+			headers: DEFAULT_HEADERS,
 		})
 	}
+
+	override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } {
+		const id = this.options.apiModelId ?? groqDefaultModelId
+		const info = groqModels[id as keyof typeof groqModels] || groqModels[groqDefaultModelId]
+		const params = getModelParams({
+			format: "openai",
+			modelId: id,
+			model: info,
+			settings: this.options,
+			defaultTemperature: GROQ_DEFAULT_TEMPERATURE,
+		})
+		return { id, info, ...params }
+	}
+
+	/**
+	 * Get the language model for the configured model ID.
+	 */
+	protected getLanguageModel() {
+		const { id } = this.getModel()
+		return this.provider(id)
+	}
+
+	/**
+	 * Process usage metrics from the AI SDK response, including Groq's cache metrics.
+	 * Groq provides cache hit/miss info via providerMetadata for supported models.
+	 */
+	protected processUsageMetrics(
+		usage: {
+			inputTokens?: number
+			outputTokens?: number
+			details?: {
+				cachedInputTokens?: number
+				reasoningTokens?: number
+			}
+		},
+		providerMetadata?: {
+			groq?: {
+				promptCacheHitTokens?: number
+				promptCacheMissTokens?: number
+			}
+		},
+	): ApiStreamUsageChunk {
+		// Extract cache metrics from Groq's providerMetadata
+		const cacheReadTokens = providerMetadata?.groq?.promptCacheHitTokens ?? usage.details?.cachedInputTokens
+		const cacheWriteTokens = providerMetadata?.groq?.promptCacheMissTokens
+
+		return {
+			type: "usage",
+			inputTokens: usage.inputTokens || 0,
+			outputTokens: usage.outputTokens || 0,
+			cacheReadTokens,
+			cacheWriteTokens,
+			reasoningTokens: usage.details?.reasoningTokens,
+		}
+	}
+
+	/**
+	 * Map OpenAI tool_choice to AI SDK toolChoice format.
+	 */
+	protected mapToolChoice(
+		toolChoice: any,
+	): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined {
+		if (!toolChoice) {
+			return undefined
+		}
+
+		// Handle string values
+		if (typeof toolChoice === "string") {
+			switch (toolChoice) {
+				case "auto":
+					return "auto"
+				case "none":
+					return "none"
+				case "required":
+					return "required"
+				default:
+					return "auto"
+			}
+		}
+
+		// Handle object values (OpenAI ChatCompletionNamedToolChoice format)
+		if (typeof toolChoice === "object" && "type" in toolChoice) {
+			if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) {
+				return { type: "tool", toolName: toolChoice.function.name }
+			}
+		}
+
+		return undefined
+	}
+
+	/**
+	 * Get the max tokens parameter to include in the request.
+	 */
+	protected getMaxOutputTokens(): number | undefined {
+		const { info } = this.getModel()
+		return this.options.modelMaxTokens || info.maxTokens || undefined
+	}
+
+	/**
+	 * Create a message stream using the AI SDK.
+	 * Groq supports reasoning for models like qwen/qwen3-32b via reasoningFormat: 'parsed'.
+	 */
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		// Convert messages to AI SDK format
+		const aiSdkMessages = convertToAiSdkMessages(messages)
+
+		// Convert tools to OpenAI format first, then to AI SDK format
+		const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
+		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined
+
+		// Build the request options
+		const requestOptions: Parameters<typeof streamText>[0] = {
+			model: languageModel,
+			system: systemPrompt,
+			messages: aiSdkMessages,
+			temperature: this.options.modelTemperature ?? temperature ?? GROQ_DEFAULT_TEMPERATURE,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			tools: aiSdkTools,
+			toolChoice: this.mapToolChoice(metadata?.tool_choice),
+		}
+
+		// Use streamText for streaming responses
+		const result = streamText(requestOptions)
+
+		// Process the full stream to get all events including reasoning
+		for await (const part of result.fullStream) {
+			for (const chunk of processAiSdkStreamPart(part)) {
+				yield chunk
+			}
+		}
+
+		// Yield usage metrics at the end, including cache metrics from providerMetadata
+		const usage = await result.usage
+		const providerMetadata = await result.providerMetadata
+		if (usage) {
+			yield this.processUsageMetrics(usage, providerMetadata as any)
+		}
+	}
+
+	/**
+	 * Complete a prompt using the AI SDK generateText.
+	 */
+	async completePrompt(prompt: string): Promise<string> {
+		const { temperature } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		const { text } = await generateText({
+			model: languageModel,
+			prompt,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			temperature: this.options.modelTemperature ?? temperature ?? GROQ_DEFAULT_TEMPERATURE,
+		})
+
+		return text
+	}
 }
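
Consumed the same way the tests above exercise it, the new handler yields ApiStream chunks that callers branch on by type. A rough usage sketch, with an illustrative options object:

import { GroqHandler } from "./groq"

const handler = new GroqHandler({
	groqApiKey: process.env.GROQ_API_KEY, // illustrative; any string key works for the sketch
	apiModelId: "llama-3.3-70b-versatile",
})

const stream = handler.createMessage("You are a helpful assistant.", [
	{ role: "user", content: [{ type: "text", text: "Hello!" }] },
])

for await (const chunk of stream) {
	if (chunk.type === "text") process.stdout.write(chunk.text)
	if (chunk.type === "usage") {
		// cacheReadTokens / cacheWriteTokens are filled from Groq's providerMetadata when available
		console.log(chunk.inputTokens, chunk.outputTokens, chunk.cacheReadTokens)
	}
}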

+ 1 - 0
src/package.json

@@ -452,6 +452,7 @@
 	"dependencies": {
 		"@ai-sdk/cerebras": "^1.0.0",
 		"@ai-sdk/deepseek": "^2.0.14",
+		"@ai-sdk/groq": "^3.0.19",
 		"@anthropic-ai/bedrock-sdk": "^0.10.2",
 		"@anthropic-ai/sdk": "^0.37.0",
 		"@anthropic-ai/vertex-sdk": "^0.7.0",