Browse files

feat: migrate xAI provider to use dedicated @ai-sdk/xai package (#11158)

Co-authored-by: Roo Code <[email protected]>
Co-authored-by: daniel-lxs <[email protected]>
roomote[bot] 1 week ago
parent
commit
e90e6178e3
5 changed files with 777 additions and 582 deletions
  1. packages/types/src/providers/xai.ts (+4 -0)
  2. pnpm-lock.yaml (+29 -7)
  3. src/api/providers/__tests__/xai.spec.ts (+610 -466)
  4. src/api/providers/xai.ts (+133 -109)
  5. src/package.json (+1 -0)

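The substance of the change: the handler stops constructing a raw OpenAI-compatible client and instead builds an xAI provider with createXai from @ai-sdk/xai, driving it through the AI SDK's streamText/generateText. A minimal standalone sketch of the new call shape, not the handler itself; it assumes XAI_API_KEY is set in the environment and uses grok-code-fast-1 purely as an example model id:

// Minimal sketch of the @ai-sdk/xai call shape used by this PR.
import { createXai } from "@ai-sdk/xai"
import { streamText } from "ai"

const xai = createXai({
	baseURL: "https://api.x.ai/v1",
	apiKey: process.env.XAI_API_KEY,
})

// streamText replaces client.chat.completions.create({ stream: true, ... });
// the rewritten handler iterates result.fullStream instead of textStream so it can
// also surface reasoning and tool events.
const result = streamText({
	model: xai("grok-code-fast-1"),
	system: "You are a helpful assistant.",
	prompt: "Hello",
	temperature: 0,
})

for await (const delta of result.textStream) {
	process.stdout.write(delta)
}
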
+ 4 - 0
packages/types/src/providers/xai.ts

@@ -30,6 +30,8 @@ export const xaiModels = {
 		cacheReadsPrice: 0.05,
 		description:
 			"xAI's Grok 4.1 Fast model with 2M context window, optimized for high-performance agentic tool calling with reasoning",
+		supportsReasoningEffort: ["low", "high"],
+		reasoningEffort: "low",
 		includedTools: ["search_replace"],
 		excludedTools: ["apply_diff"],
 	},
@@ -58,6 +60,8 @@ export const xaiModels = {
 		cacheReadsPrice: 0.05,
 		description:
 			"xAI's Grok 4 Fast model with 2M context window, optimized for high-performance agentic tool calling with reasoning",
+		supportsReasoningEffort: ["low", "high"],
+		reasoningEffort: "low",
 		includedTools: ["search_replace"],
 		excludedTools: ["apply_diff"],
 	},

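The two added fields declare which reasoning-effort levels the Grok Fast models accept and the default to use when the user has not chosen one; in the handler itself this flows through getModelParams into providerOptions (see the src/api/providers/xai.ts diff below). A hedged sketch of how a caller could resolve an effort level from this metadata; resolveReasoningEffort is a hypothetical helper, not part of this PR:

// Illustrative only: resolveReasoningEffort is a hypothetical helper name.
import { xaiModels } from "@roo-code/types"

type ReasoningEffort = "low" | "high"

function resolveReasoningEffort(
	modelId: keyof typeof xaiModels,
	requested?: ReasoningEffort,
): ReasoningEffort | undefined {
	const info = xaiModels[modelId] as {
		supportsReasoningEffort?: ReasoningEffort[]
		reasoningEffort?: ReasoningEffort
	}
	// Models without the new field get no effort parameter at all.
	if (!info.supportsReasoningEffort) return undefined
	// Honor a supported user choice, otherwise fall back to the model default ("low" here).
	return requested && info.supportsReasoningEffort.includes(requested) ? requested : info.reasoningEffort
}
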
+ 29 - 7
pnpm-lock.yaml

@@ -761,6 +761,9 @@ importers:
       '@ai-sdk/mistral':
         specifier: ^3.0.0
         version: 3.0.18([email protected])
+      '@ai-sdk/xai':
+        specifier: ^3.0.46
+        version: 3.0.46([email protected])
       '@anthropic-ai/bedrock-sdk':
         specifier: ^0.10.2
         version: 0.10.4
@@ -1462,6 +1465,12 @@ packages:
     peerDependencies:
       zod: 3.25.76
 
+  '@ai-sdk/[email protected]':
+    resolution: {integrity: sha512-l6jdFjI1C2eDAEm7oo+dnRn0oG1EkcyqfbEZ7ozT0TnYrah6amX2JkftYMP1GRzNtAeCB3WNN8XspXdmi6ZNlQ==}
+    engines: {node: '>=18'}
+    peerDependencies:
+      zod: 3.25.76
+
   '@ai-sdk/[email protected]':
     resolution: {integrity: sha512-iXHVe0apM2zUEzauqJwqmpC37A5rihrStAih5Ks+JE32iTe4LZ58y17UGBjpQQTCRw9YxMeo2UFLxLpBluyvLQ==}
     engines: {node: '>=18'}
@@ -1512,6 +1521,12 @@ packages:
     resolution: {integrity: sha512-VkPLrutM6VdA924/mG8OS+5frbVTcu6e046D2bgDo00tehBANR1QBJ/mPcZ9tXMFOsVcm6SQArOregxePzTFPw==}
     engines: {node: '>=18'}
 
+  '@ai-sdk/[email protected]':
+    resolution: {integrity: sha512-26qM/jYcFhF5krTM7bQT1CiZcdz22EQmA+r5me1hKYFM/yM20sSUMHnAcUzvzuuG9oQVKF0tziU2IcC0HX5huQ==}
+    engines: {node: '>=18'}
+    peerDependencies:
+      zod: 3.25.76
+
   '@alcalzone/[email protected]':
     resolution: {integrity: sha512-jsElTJ0sQ4wHRz+C45tfect76BwbTbgkgKByOzpCN9xG61N5V6u/glvg1CsNJhq2xJIFpKHSwG3D2wPPuEYOrQ==}
     engines: {node: '>=18'}
@@ -6530,10 +6545,6 @@ packages:
     resolution: {integrity: sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==}
     engines: {node: '>=0.8.x'}
 
-  [email protected]:
-    resolution: {integrity: sha512-6RxOBZ/cYgd8usLwsEl+EC09Au/9BcmCKYF2/xbml6DNczf7nv0MQb+7BA2F+li6//I+28VNlQR37XfQtcAJuA==}
-    engines: {node: '>=18.0.0'}
-
   [email protected]:
     resolution: {integrity: sha512-Vo1ab+QXPzZ4tCa8SwIHJFaSzy4R6SHf7BY79rFBDf0idraZWAkYrDjDj8uWaSm3S2TK+hJ7/t1CEmZ7jXw+pg==}
     engines: {node: '>=18.0.0'}
@@ -11146,6 +11157,12 @@ snapshots:
       '@ai-sdk/provider-utils': 4.0.11([email protected])
       zod: 3.25.76
 
+  '@ai-sdk/[email protected]([email protected])':
+    dependencies:
+      '@ai-sdk/provider': 3.0.7
+      '@ai-sdk/provider-utils': 4.0.13([email protected])
+      zod: 3.25.76
+
   '@ai-sdk/[email protected]([email protected])':
     dependencies:
       '@ai-sdk/provider': 2.0.1
@@ -11202,6 +11219,13 @@ snapshots:
     dependencies:
       json-schema: 0.4.0
 
+  '@ai-sdk/[email protected]([email protected])':
+    dependencies:
+      '@ai-sdk/openai-compatible': 2.0.26([email protected])
+      '@ai-sdk/provider': 3.0.7
+      '@ai-sdk/provider-utils': 4.0.13([email protected])
+      zod: 3.25.76
+
   '@alcalzone/[email protected]':
     dependencies:
       ansi-styles: 6.2.3
@@ -17027,13 +17051,11 @@ snapshots:
 
   [email protected]: {}
 
-  [email protected]: {}
-
   [email protected]: {}
 
   [email protected]:
     dependencies:
-      eventsource-parser: 3.0.2
+      eventsource-parser: 3.0.6
 
   [email protected]:
     dependencies:

+ 610 - 466
src/api/providers/__tests__/xai.spec.ts

@@ -1,587 +1,731 @@
-// npx vitest api/providers/__tests__/xai.spec.ts
-
-// Mock TelemetryService - must come before other imports
-const mockCaptureException = vitest.hoisted(() => vitest.fn())
-vitest.mock("@roo-code/telemetry", () => ({
-	TelemetryService: {
-		instance: {
-			captureException: mockCaptureException,
-		},
-	},
-}))
-
-const mockCreate = vitest.fn()
+// npx vitest run api/providers/__tests__/xai.spec.ts
 
-vitest.mock("openai", () => {
-	const mockConstructor = vitest.fn()
+// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls
+const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({
+	mockStreamText: vi.fn(),
+	mockGenerateText: vi.fn(),
+}))
 
+vi.mock("ai", async (importOriginal) => {
+	const actual = await importOriginal<typeof import("ai")>()
 	return {
-		__esModule: true,
-		default: mockConstructor.mockImplementation(() => ({ chat: { completions: { create: mockCreate } } })),
+		...actual,
+		streamText: mockStreamText,
+		generateText: mockGenerateText,
 	}
 })
 
-import OpenAI from "openai"
+vi.mock("@ai-sdk/xai", () => ({
+	createXai: vi.fn(() => {
+		// Return a function that returns a mock language model
+		return vi.fn(() => ({
+			modelId: "grok-code-fast-1",
+			provider: "xai",
+		}))
+	}),
+}))
+
 import type { Anthropic } from "@anthropic-ai/sdk"
 
-import { xaiDefaultModelId, xaiModels } from "@roo-code/types"
+import { xaiDefaultModelId, xaiModels, type XAIModelId } from "@roo-code/types"
+
+import type { ApiHandlerOptions } from "../../../shared/api"
 
 import { XAIHandler } from "../xai"
 
 describe("XAIHandler", () => {
 	let handler: XAIHandler
+	let mockOptions: ApiHandlerOptions
 
 	beforeEach(() => {
-		// Reset all mocks
+		mockOptions = {
+			xaiApiKey: "test-xai-api-key",
+			apiModelId: "grok-code-fast-1",
+		}
+		handler = new XAIHandler(mockOptions)
 		vi.clearAllMocks()
-		mockCreate.mockClear()
-		mockCaptureException.mockClear()
-
-		// Create handler with mock
-		handler = new XAIHandler({})
 	})
 
-	it("should use the correct X.AI base URL", () => {
-		expect(OpenAI).toHaveBeenCalledWith(
-			expect.objectContaining({
-				baseURL: "https://api.x.ai/v1",
-			}),
-		)
-	})
+	describe("constructor", () => {
+		it("should initialize with provided options", () => {
+			expect(handler).toBeInstanceOf(XAIHandler)
+			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
+		})
 
-	it("should use the provided API key", () => {
-		// Clear mocks before this specific test
-		vi.clearAllMocks()
+		it("should use default model ID if not provided", () => {
+			const handlerWithoutModel = new XAIHandler({
+				...mockOptions,
+				apiModelId: undefined,
+			})
+			expect(handlerWithoutModel.getModel().id).toBe(xaiDefaultModelId)
+		})
+	})
 
-		// Create a handler with our API key
-		const xaiApiKey = "test-api-key"
-		new XAIHandler({ xaiApiKey })
+	describe("getModel", () => {
+		it("should return default model when no model is specified", () => {
+			const handlerWithoutModel = new XAIHandler({
+				xaiApiKey: "test-xai-api-key",
+			})
+			const model = handlerWithoutModel.getModel()
+			expect(model.id).toBe(xaiDefaultModelId)
+			expect(model.info).toEqual(xaiModels[xaiDefaultModelId])
+		})
 
-		// Verify the OpenAI constructor was called with our API key
-		expect(OpenAI).toHaveBeenCalledWith(
-			expect.objectContaining({
-				apiKey: xaiApiKey,
-			}),
-		)
-	})
+		it("should return specified model when valid model is provided", () => {
+			const testModelId: XAIModelId = "grok-3"
+			const handlerWithModel = new XAIHandler({
+				apiModelId: testModelId,
+				xaiApiKey: "test-xai-api-key",
+			})
+			const model = handlerWithModel.getModel()
+			expect(model.id).toBe(testModelId)
+			expect(model.info).toEqual(xaiModels[testModelId])
+		})
 
-	it("should return default model when no model is specified", () => {
-		const model = handler.getModel()
-		expect(model.id).toBe(xaiDefaultModelId)
-		expect(model.info).toEqual(xaiModels[xaiDefaultModelId])
-	})
+		it("should return grok-3-mini model with correct configuration", () => {
+			const testModelId: XAIModelId = "grok-3-mini"
+			const handlerWithModel = new XAIHandler({
+				apiModelId: testModelId,
+				xaiApiKey: "test-xai-api-key",
+			})
+			const model = handlerWithModel.getModel()
+			expect(model.id).toBe(testModelId)
+			expect(model.info).toEqual(
+				expect.objectContaining({
+					maxTokens: 8192,
+					contextWindow: 131072,
+					supportsImages: true,
+					supportsPromptCache: true,
+					inputPrice: 0.3,
+					outputPrice: 0.5,
+				}),
+			)
+		})
 
-	test("should return specified model when valid model is provided", () => {
-		const testModelId = "grok-3"
-		const handlerWithModel = new XAIHandler({ apiModelId: testModelId })
-		const model = handlerWithModel.getModel()
+		it("should return grok-4-0709 model with correct configuration", () => {
+			const testModelId: XAIModelId = "grok-4-0709"
+			const handlerWithModel = new XAIHandler({
+				apiModelId: testModelId,
+				xaiApiKey: "test-xai-api-key",
+			})
+			const model = handlerWithModel.getModel()
+			expect(model.id).toBe(testModelId)
+			expect(model.info).toEqual(
+				expect.objectContaining({
+					maxTokens: 8192,
+					contextWindow: 256_000,
+					supportsImages: true,
+					supportsPromptCache: true,
+					inputPrice: 3.0,
+					outputPrice: 15.0,
+				}),
+			)
+		})
 
-		expect(model.id).toBe(testModelId)
-		expect(model.info).toEqual(xaiModels[testModelId])
-	})
+		it("should fall back to default model for invalid model ID", () => {
+			const handlerWithInvalidModel = new XAIHandler({
+				...mockOptions,
+				apiModelId: "invalid-model",
+			})
+			const model = handlerWithInvalidModel.getModel()
+			expect(model.id).toBe(xaiDefaultModelId)
+			expect(model.info).toBe(xaiModels[xaiDefaultModelId])
+		})
 
-	it("should include reasoning_effort parameter for mini models", async () => {
-		const miniModelHandler = new XAIHandler({
-			apiModelId: "grok-3-mini",
-			reasoningEffort: "high",
+		it("should include model parameters from getModelParams", () => {
+			const model = handler.getModel()
+			expect(model).toHaveProperty("temperature")
+			expect(model).toHaveProperty("maxTokens")
 		})
+	})
 
-		// Setup mock for streaming response
-		mockCreate.mockImplementationOnce(() => {
-			return {
-				[Symbol.asyncIterator]: () => ({
-					async next() {
-						return { done: true }
+	describe("createMessage", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [
+					{
+						type: "text" as const,
+						text: "Hello!",
 					},
-				}),
+				],
+			},
+		]
+
+		it("should handle streaming responses", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response from xAI" }
 			}
-		})
 
-		// Start generating a message
-		const messageGenerator = miniModelHandler.createMessage("test prompt", [])
-		await messageGenerator.next() // Start the generator
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
 
-		// Check that reasoning_effort was included
-		expect(mockCreate).toHaveBeenCalledWith(
-			expect.objectContaining({
-				reasoning_effort: "high",
-			}),
-		)
-	})
+			const mockProviderMetadata = Promise.resolve({})
 
-	it("should not include reasoning_effort parameter for non-mini models", async () => {
-		const regularModelHandler = new XAIHandler({
-			apiModelId: "grok-3",
-			reasoningEffort: "high",
-		})
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
 
-		// Setup mock for streaming response
-		mockCreate.mockImplementationOnce(() => {
-			return {
-				[Symbol.asyncIterator]: () => ({
-					async next() {
-						return { done: true }
-					},
-				}),
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
 			}
+
+			expect(chunks.length).toBeGreaterThan(0)
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Test response from xAI")
 		})
 
-		// Start generating a message
-		const messageGenerator = regularModelHandler.createMessage("test prompt", [])
-		await messageGenerator.next() // Start the generator
+		it("should include usage information", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
+			}
 
-		// Check call args for reasoning_effort
-		const calls = mockCreate.mock.calls
-		const lastCall = calls[calls.length - 1][0]
-		expect(lastCall).not.toHaveProperty("reasoning_effort")
-	})
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 20,
+			})
 
-	it("completePrompt method should return text from OpenAI API", async () => {
-		const expectedResponse = "This is a test response"
-		mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] })
+			const mockProviderMetadata = Promise.resolve({})
 
-		const result = await handler.completePrompt("test prompt")
-		expect(result).toBe(expectedResponse)
-	})
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
 
-	it("should handle errors in completePrompt", async () => {
-		const errorMessage = "API error"
-		mockCreate.mockRejectedValueOnce(new Error(errorMessage))
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
 
-		await expect(handler.completePrompt("test prompt")).rejects.toThrow(`xAI completion error: ${errorMessage}`)
-	})
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].inputTokens).toBe(10)
+			expect(usageChunks[0].outputTokens).toBe(20)
+		})
 
-	it("createMessage should yield text content from stream", async () => {
-		const testContent = "This is test content"
-
-		// Setup mock for streaming response
-		mockCreate.mockImplementationOnce(() => {
-			return {
-				[Symbol.asyncIterator]: () => ({
-					next: vi
-						.fn()
-						.mockResolvedValueOnce({
-							done: false,
-							value: {
-								choices: [{ delta: { content: testContent } }],
-							},
-						})
-						.mockResolvedValueOnce({ done: true }),
-				}),
+		it("should handle cached tokens in usage data from providerMetadata", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
 			}
-		})
 
-		// Create and consume the stream
-		const stream = handler.createMessage("system prompt", [])
-		const firstChunk = await stream.next()
+			const mockUsage = Promise.resolve({
+				inputTokens: 100,
+				outputTokens: 50,
+			})
 
-		// Verify the content
-		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({
-			type: "text",
-			text: testContent,
-		})
-	})
+			// xAI provides cache metrics via providerMetadata for supported models
+			const mockProviderMetadata = Promise.resolve({
+				xai: {
+					cachedPromptTokens: 30,
+				},
+			})
 
-	it("createMessage should yield reasoning content from stream", async () => {
-		const testReasoning = "Test reasoning content"
-
-		// Setup mock for streaming response
-		mockCreate.mockImplementationOnce(() => {
-			return {
-				[Symbol.asyncIterator]: () => ({
-					next: vi
-						.fn()
-						.mockResolvedValueOnce({
-							done: false,
-							value: {
-								choices: [{ delta: { reasoning_content: testReasoning } }],
-							},
-						})
-						.mockResolvedValueOnce({ done: true }),
-				}),
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
 			}
+
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].inputTokens).toBe(100)
+			expect(usageChunks[0].outputTokens).toBe(50)
+			expect(usageChunks[0].cacheReadTokens).toBe(30)
 		})
 
-		// Create and consume the stream
-		const stream = handler.createMessage("system prompt", [])
-		const firstChunk = await stream.next()
+		it("should handle usage with details.cachedInputTokens when providerMetadata is not available", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 100,
+				outputTokens: 50,
+				details: {
+					cachedInputTokens: 25,
+				},
+			})
+
+			const mockProviderMetadata = Promise.resolve({})
 
-		// Verify the reasoning content
-		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({
-			type: "reasoning",
-			text: testReasoning,
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].cacheReadTokens).toBe(25)
+			expect(usageChunks[0].cacheWriteTokens).toBeUndefined()
 		})
-	})
 
-	it("createMessage should yield usage data from stream", async () => {
-		// Setup mock for streaming response that includes usage data
-		mockCreate.mockImplementationOnce(() => {
-			return {
-				[Symbol.asyncIterator]: () => ({
-					next: vi
-						.fn()
-						.mockResolvedValueOnce({
-							done: false,
-							value: {
-								choices: [{ delta: {} }], // Needs to have choices array to avoid error
-								usage: {
-									prompt_tokens: 10,
-									completion_tokens: 20,
-									cache_read_input_tokens: 5,
-									cache_creation_input_tokens: 15,
-								},
-							},
-						})
-						.mockResolvedValueOnce({ done: true }),
+		it("should pass correct temperature (0 default) to streamText", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+				providerMetadata: Promise.resolve({}),
+			})
+
+			const handlerWithDefaultTemp = new XAIHandler({
+				xaiApiKey: "test-key",
+				apiModelId: "grok-code-fast-1",
+			})
+
+			const stream = handlerWithDefaultTemp.createMessage(systemPrompt, messages)
+			for await (const _ of stream) {
+				// consume stream
+			}
+
+			expect(mockStreamText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					temperature: 0,
 				}),
+			)
+		})
+
+		it("should use user-specified temperature over default", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+				providerMetadata: Promise.resolve({}),
+			})
+
+			const handlerWithCustomTemp = new XAIHandler({
+				xaiApiKey: "test-key",
+				apiModelId: "grok-3",
+				modelTemperature: 0.7,
+			})
+
+			const stream = handlerWithCustomTemp.createMessage(systemPrompt, messages)
+			for await (const _ of stream) {
+				// consume stream
 			}
+
+			// User-specified temperature should take precedence over everything
+			expect(mockStreamText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					temperature: 0.7,
+				}),
+			)
 		})
 
-		// Create and consume the stream
-		const stream = handler.createMessage("system prompt", [])
-		const firstChunk = await stream.next()
-
-		// Verify the usage data
-		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({
-			type: "usage",
-			inputTokens: 10,
-			outputTokens: 20,
-			cacheReadTokens: 5,
-			cacheWriteTokens: 15,
+		it("should handle stream with multiple chunks", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Hello" }
+				yield { type: "text-delta", text: " world" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 5, outputTokens: 10 }),
+				providerMetadata: Promise.resolve({}),
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const textChunks = chunks.filter((c) => c.type === "text")
+			expect(textChunks[0]).toEqual({ type: "text", text: "Hello" })
+			expect(textChunks[1]).toEqual({ type: "text", text: " world" })
+
+			const usageChunks = chunks.filter((c) => c.type === "usage")
+			expect(usageChunks[0]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 })
 		})
-	})
 
-	it("createMessage should pass correct parameters to OpenAI client", async () => {
-		// Setup a handler with specific model
-		const modelId = "grok-3"
-		const modelInfo = xaiModels[modelId]
-		const handlerWithModel = new XAIHandler({ apiModelId: modelId })
-
-		// Setup mock for streaming response
-		mockCreate.mockImplementationOnce(() => {
-			return {
-				[Symbol.asyncIterator]: () => ({
-					async next() {
-						return { done: true }
-					},
-				}),
+		it("should handle reasoning content from stream", async () => {
+			async function* mockFullStream() {
+				yield { type: "reasoning-delta", text: "Let me think about this..." }
+				yield { type: "text-delta", text: "Here is my answer" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 10, outputTokens: 20 }),
+				providerMetadata: Promise.resolve({}),
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
 			}
+
+			const reasoningChunks = chunks.filter((c) => c.type === "reasoning")
+			expect(reasoningChunks).toHaveLength(1)
+			expect(reasoningChunks[0].text).toBe("Let me think about this...")
+
+			const textChunks = chunks.filter((c) => c.type === "text")
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Here is my answer")
 		})
 
-		// System prompt and messages
-		const systemPrompt = "Test system prompt"
-		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]
-
-		// Start generating a message
-		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
-		await messageGenerator.next() // Start the generator
-
-		// Check that all parameters were passed correctly
-		expect(mockCreate).toHaveBeenCalledWith(
-			expect.objectContaining({
-				model: modelId,
-				max_tokens: modelInfo.maxTokens,
-				temperature: 0,
-				messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
-				stream: true,
-				stream_options: { include_usage: true },
-			}),
-		)
-	})
+		it("should handle errors during streaming", async () => {
+			const mockError = new Error("API error")
+			;(mockError as any).name = "AI_APICallError"
+			;(mockError as any).status = 500
 
-	describe("Native Tool Calling", () => {
-		const testTools = [
-			{
-				type: "function" as const,
-				function: {
-					name: "test_tool",
-					description: "A test tool",
-					parameters: {
-						type: "object",
-						properties: {
-							arg1: { type: "string", description: "First argument" },
-						},
-						required: ["arg1"],
-					},
-				},
-			},
-		]
+			async function* mockFullStream(): AsyncGenerator<never> {
+				// This yield is unreachable but needed to satisfy the require-yield lint rule
+				yield undefined as never
+				throw mockError
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+				providerMetadata: Promise.resolve({}),
+			})
 
-		it("should include tools in request when model supports native tools and tools are provided (native is default)", async () => {
-			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
+			const stream = handler.createMessage(systemPrompt, messages)
 
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						async next() {
-							return { done: true }
-						},
-					}),
+			await expect(async () => {
+				for await (const _ of stream) {
+					// consume stream
 				}
-			})
+			}).rejects.toThrow("xAI")
+		})
+	})
 
-			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
-				taskId: "test-task-id",
-				tools: testTools,
+	describe("completePrompt", () => {
+		it("should complete a prompt using generateText", async () => {
+			mockGenerateText.mockResolvedValue({
+				text: "Test completion from xAI",
 			})
-			await messageGenerator.next()
 
-			expect(mockCreate).toHaveBeenCalledWith(
+			const result = await handler.completePrompt("Test prompt")
+
+			expect(result).toBe("Test completion from xAI")
+			expect(mockGenerateText).toHaveBeenCalledWith(
 				expect.objectContaining({
-					tools: expect.arrayContaining([
-						expect.objectContaining({
-							type: "function",
-							function: expect.objectContaining({
-								name: "test_tool",
-							}),
-						}),
-					]),
-					parallel_tool_calls: true,
+					prompt: "Test prompt",
 				}),
 			)
 		})
 
-		it("should include tool_choice when provided", async () => {
-			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
-
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						async next() {
-							return { done: true }
-						},
-					}),
-				}
+		it("should use default temperature in completePrompt", async () => {
+			mockGenerateText.mockResolvedValue({
+				text: "Test completion",
 			})
 
-			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
-				taskId: "test-task-id",
-				tools: testTools,
-				tool_choice: "auto",
-			})
-			await messageGenerator.next()
+			await handler.completePrompt("Test prompt")
 
-			expect(mockCreate).toHaveBeenCalledWith(
+			expect(mockGenerateText).toHaveBeenCalledWith(
 				expect.objectContaining({
-					tool_choice: "auto",
+					temperature: 0,
 				}),
 			)
 		})
 
-		it("should always include tools and tool_choice (tools are guaranteed to be present after ALWAYS_AVAILABLE_TOOLS)", async () => {
-			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
+		it("should handle errors in completePrompt", async () => {
+			const mockError = new Error("API error")
+			;(mockError as any).name = "AI_APICallError"
+			mockGenerateText.mockRejectedValue(mockError)
 
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						async next() {
-							return { done: true }
-						},
-					}),
+			await expect(handler.completePrompt("Test prompt")).rejects.toThrow("xAI")
+		})
+	})
+
+	describe("processUsageMetrics", () => {
+		it("should correctly process usage metrics including cache information from providerMetadata", () => {
+			class TestXAIHandler extends XAIHandler {
+				public testProcessUsageMetrics(usage: any, providerMetadata?: any) {
+					return this.processUsageMetrics(usage, providerMetadata)
 				}
-			})
+			}
 
-			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
-				taskId: "test-task-id",
-			})
-			await messageGenerator.next()
+			const testHandler = new TestXAIHandler(mockOptions)
+
+			const usage = {
+				inputTokens: 100,
+				outputTokens: 50,
+			}
+
+			const providerMetadata = {
+				xai: {
+					cachedPromptTokens: 20,
+				},
+			}
+
+			const result = testHandler.testProcessUsageMetrics(usage, providerMetadata)
+
+			expect(result.type).toBe("usage")
+			expect(result.inputTokens).toBe(100)
+			expect(result.outputTokens).toBe(50)
+			expect(result.cacheReadTokens).toBe(20)
+			// xAI doesn't report cache write tokens separately
+			expect(result.cacheWriteTokens).toBeUndefined()
+		})
+
+		it("should handle missing cache metrics gracefully", () => {
+			class TestXAIHandler extends XAIHandler {
+				public testProcessUsageMetrics(usage: any, providerMetadata?: any) {
+					return this.processUsageMetrics(usage, providerMetadata)
+				}
+			}
+
+			const testHandler = new TestXAIHandler(mockOptions)
+
+			const usage = {
+				inputTokens: 100,
+				outputTokens: 50,
+			}
+
+			const result = testHandler.testProcessUsageMetrics(usage)
+
+			expect(result.type).toBe("usage")
+			expect(result.inputTokens).toBe(100)
+			expect(result.outputTokens).toBe(50)
+			expect(result.cacheWriteTokens).toBeUndefined()
+			expect(result.cacheReadTokens).toBeUndefined()
+		})
+
+		it("should include reasoning tokens when provided", () => {
+			class TestXAIHandler extends XAIHandler {
+				public testProcessUsageMetrics(usage: any, providerMetadata?: any) {
+					return this.processUsageMetrics(usage, providerMetadata)
+				}
+			}
 
-			// Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS)
-			const callArgs = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0]
-			expect(callArgs).toHaveProperty("tools")
-			expect(callArgs).toHaveProperty("tool_choice")
-			expect(callArgs).toHaveProperty("parallel_tool_calls", true)
+			const testHandler = new TestXAIHandler(mockOptions)
+
+			const usage = {
+				inputTokens: 100,
+				outputTokens: 50,
+				details: {
+					reasoningTokens: 30,
+				},
+			}
+
+			const result = testHandler.testProcessUsageMetrics(usage)
+
+			expect(result.reasoningTokens).toBe(30)
 		})
+	})
 
-		it("should yield tool_call_partial chunks during streaming", async () => {
-			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
-
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vi
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: {
-									choices: [
-										{
-											delta: {
-												tool_calls: [
-													{
-														index: 0,
-														id: "call_123",
-														function: {
-															name: "test_tool",
-															arguments: '{"arg1":',
-														},
-													},
-												],
-											},
-										},
-									],
-								},
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: {
-									choices: [
-										{
-											delta: {
-												tool_calls: [
-													{
-														index: 0,
-														function: {
-															arguments: '"value"}',
-														},
-													},
-												],
-											},
-										},
-									],
-								},
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
+	describe("tool handling", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [{ type: "text" as const, text: "Hello!" }],
+			},
+		]
+
+		it("should handle tool calls in streaming", async () => {
+			async function* mockFullStream() {
+				yield {
+					type: "tool-input-start",
+					id: "tool-call-1",
+					toolName: "read_file",
+				}
+				yield {
+					type: "tool-input-delta",
+					id: "tool-call-1",
+					delta: '{"path":"test.ts"}',
+				}
+				yield {
+					type: "tool-input-end",
+					id: "tool-call-1",
 				}
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
 			})
 
-			const stream = handlerWithTools.createMessage("test prompt", [], {
-				taskId: "test-task-id",
-				tools: testTools,
+			const mockProviderMetadata = Promise.resolve({})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: [
+					{
+						type: "function",
+						function: {
+							name: "read_file",
+							description: "Read a file",
+							parameters: {
+								type: "object",
+								properties: { path: { type: "string" } },
+								required: ["path"],
+							},
+						},
+					},
+				],
 			})
 
-			const chunks = []
+			const chunks: any[] = []
 			for await (const chunk of stream) {
 				chunks.push(chunk)
 			}
 
-			expect(chunks).toContainEqual({
-				type: "tool_call_partial",
-				index: 0,
-				id: "call_123",
-				name: "test_tool",
-				arguments: '{"arg1":',
+			const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start")
+			const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta")
+			const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end")
+
+			expect(toolCallStartChunks.length).toBe(1)
+			expect(toolCallStartChunks[0].id).toBe("tool-call-1")
+			expect(toolCallStartChunks[0].name).toBe("read_file")
+
+			expect(toolCallDeltaChunks.length).toBe(1)
+			expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}')
+
+			expect(toolCallEndChunks.length).toBe(1)
+			expect(toolCallEndChunks[0].id).toBe("tool-call-1")
+		})
+
+		it("should ignore tool-call events to prevent duplicate tools in UI", async () => {
+			async function* mockFullStream() {
+				yield {
+					type: "tool-call",
+					toolCallId: "tool-call-1",
+					toolName: "read_file",
+					input: { path: "test.ts" },
+				}
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
 			})
 
-			expect(chunks).toContainEqual({
-				type: "tool_call_partial",
-				index: 0,
-				id: undefined,
-				name: undefined,
-				arguments: '"value"}',
+			const mockProviderMetadata = Promise.resolve({})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+				providerMetadata: mockProviderMetadata,
 			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// tool-call events should be ignored (only tool-input-start/delta/end are processed)
+			const toolCallChunks = chunks.filter(
+				(c) => c.type === "tool_call_start" || c.type === "tool_call_delta" || c.type === "tool_call_end",
+			)
+			expect(toolCallChunks.length).toBe(0)
 		})
 
-		it("should set parallel_tool_calls based on metadata", async () => {
-			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
+		it("should pass tools to streamText when provided", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test" }
+			}
 
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						async next() {
-							return { done: true }
-						},
-					}),
-				}
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+				providerMetadata: Promise.resolve({}),
 			})
 
-			const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
-				taskId: "test-task-id",
+			const testTools = [
+				{
+					type: "function" as const,
+					function: {
+						name: "test_tool",
+						description: "A test tool",
+						parameters: {
+							type: "object",
+							properties: {
+								arg1: { type: "string", description: "First argument" },
+							},
+							required: ["arg1"],
+						},
+					},
+				},
+			]
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
 				tools: testTools,
-				parallelToolCalls: true,
+				tool_choice: "auto",
 			})
-			await messageGenerator.next()
 
-			expect(mockCreate).toHaveBeenCalledWith(
+			for await (const _ of stream) {
+				// consume stream
+			}
+
+			expect(mockStreamText).toHaveBeenCalledWith(
 				expect.objectContaining({
-					parallel_tool_calls: true,
+					tools: expect.any(Object),
+					toolChoice: "auto",
 				}),
 			)
 		})
+	})
 
-		it("should yield tool_call_end events when finish_reason is tool_calls", async () => {
-			// Import NativeToolCallParser to set up state
-			const { NativeToolCallParser } = await import("../../../core/assistant-message/NativeToolCallParser")
-
-			// Clear any previous state
-			NativeToolCallParser.clearRawChunkState()
-
-			const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
-
-			mockCreate.mockImplementationOnce(() => {
-				return {
-					[Symbol.asyncIterator]: () => ({
-						next: vi
-							.fn()
-							.mockResolvedValueOnce({
-								done: false,
-								value: {
-									choices: [
-										{
-											delta: {
-												tool_calls: [
-													{
-														index: 0,
-														id: "call_xai_test",
-														function: {
-															name: "test_tool",
-															arguments: '{"arg1":"value"}',
-														},
-													},
-												],
-											},
-										},
-									],
-								},
-							})
-							.mockResolvedValueOnce({
-								done: false,
-								value: {
-									choices: [
-										{
-											delta: {},
-											finish_reason: "tool_calls",
-										},
-									],
-									usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
-								},
-							})
-							.mockResolvedValueOnce({ done: true }),
-					}),
-				}
+	describe("reasoning effort (mini models)", () => {
+		it("should include reasoning effort for grok-3-mini model", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test" }
+			}
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
+				providerMetadata: Promise.resolve({}),
 			})
 
-			const stream = handlerWithTools.createMessage("test prompt", [], {
-				taskId: "test-task-id",
-				tools: testTools,
+			const miniModelHandler = new XAIHandler({
+				xaiApiKey: "test-key",
+				apiModelId: "grok-3-mini",
+				reasoningEffort: "high",
 			})
 
-			const chunks = []
-			for await (const chunk of stream) {
-				// Simulate what Task.ts does: when we receive tool_call_partial,
-				// process it through NativeToolCallParser to populate rawChunkTracker
-				if (chunk.type === "tool_call_partial") {
-					NativeToolCallParser.processRawChunk({
-						index: chunk.index,
-						id: chunk.id,
-						name: chunk.name,
-						arguments: chunk.arguments,
-					})
-				}
-				chunks.push(chunk)
+			const stream = miniModelHandler.createMessage("test prompt", [])
+			for await (const _ of stream) {
+				// consume stream
 			}
 
-			// Should have tool_call_partial and tool_call_end
-			const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
-			const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
-
-			expect(partialChunks).toHaveLength(1)
-			expect(endChunks).toHaveLength(1)
-			expect(endChunks[0].id).toBe("call_xai_test")
+			// Check that provider options are passed for reasoning
+			expect(mockStreamText).toHaveBeenCalledWith(
+				expect.objectContaining({
+					providerOptions: expect.any(Object),
+				}),
+			)
 		})
 	})
 })

+ 133 - 109
src/api/providers/xai.ts

@@ -1,166 +1,190 @@
 import { Anthropic } from "@anthropic-ai/sdk"
-import OpenAI from "openai"
+import { createXai } from "@ai-sdk/xai"
+import { streamText, generateText, ToolSet } from "ai"
 
-import { type XAIModelId, xaiDefaultModelId, xaiModels, ApiProviderError } from "@roo-code/types"
-import { TelemetryService } from "@roo-code/telemetry"
+import { type XAIModelId, xaiDefaultModelId, xaiModels, type ModelInfo } from "@roo-code/types"
 
-import { NativeToolCallParser } from "../../core/assistant-message/NativeToolCallParser"
 import type { ApiHandlerOptions } from "../../shared/api"
 
-import { ApiStream } from "../transform/stream"
-import { convertToOpenAiMessages } from "../transform/openai-format"
+import {
+	convertToAiSdkMessages,
+	convertToolsForAiSdk,
+	processAiSdkStreamPart,
+	mapToolChoice,
+	handleAiSdkError,
+} from "../transform/ai-sdk"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
 import { getModelParams } from "../transform/model-params"
 
 import { DEFAULT_HEADERS } from "./constants"
 import { BaseProvider } from "./base-provider"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
-import { handleOpenAIError } from "./utils/openai-error-handler"
 
 const XAI_DEFAULT_TEMPERATURE = 0
 
+/**
+ * xAI provider using the dedicated @ai-sdk/xai package.
+ * Provides native support for Grok models including reasoning models.
+ */
 export class XAIHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
-	private client: OpenAI
-	private readonly providerName = "xAI"
+	protected provider: ReturnType<typeof createXai>
 
 	constructor(options: ApiHandlerOptions) {
 		super()
 		this.options = options
 
-		const apiKey = this.options.xaiApiKey ?? "not-provided"
-
-		this.client = new OpenAI({
+		// Create the xAI provider using AI SDK
+		this.provider = createXai({
 			baseURL: "https://api.x.ai/v1",
-			apiKey: apiKey,
-			defaultHeaders: DEFAULT_HEADERS,
+			apiKey: options.xaiApiKey ?? "not-provided",
+			headers: DEFAULT_HEADERS,
 		})
 	}
 
-	override getModel() {
+	override getModel(): {
+		id: XAIModelId
+		info: ModelInfo
+		maxTokens?: number
+		temperature?: number
+		reasoning?: any
+	} {
 		const id =
 			this.options.apiModelId && this.options.apiModelId in xaiModels
 				? (this.options.apiModelId as XAIModelId)
 				: xaiDefaultModelId
 
 		const info = xaiModels[id]
-		const params = getModelParams({ format: "openai", modelId: id, model: info, settings: this.options })
+		const params = getModelParams({
+			format: "openai",
+			modelId: id,
+			model: info,
+			settings: this.options,
+			defaultTemperature: XAI_DEFAULT_TEMPERATURE,
+		})
 		return { id, info, ...params }
 	}
 
+	/**
+	 * Get the language model for the configured model ID.
+	 */
+	protected getLanguageModel() {
+		const { id } = this.getModel()
+		return this.provider(id)
+	}
+
+	/**
+	 * Process usage metrics from the AI SDK response.
+	 */
+	protected processUsageMetrics(
+		usage: {
+			inputTokens?: number
+			outputTokens?: number
+			details?: {
+				cachedInputTokens?: number
+				reasoningTokens?: number
+			}
+		},
+		providerMetadata?: {
+			xai?: {
+				cachedPromptTokens?: number
+			}
+		},
+	): ApiStreamUsageChunk {
+		// Extract cache metrics from xAI's providerMetadata if available
+		// xAI supports prompt caching through prompt_tokens_details.cached_tokens
+		const cacheReadTokens = providerMetadata?.xai?.cachedPromptTokens ?? usage.details?.cachedInputTokens
+
+		return {
+			type: "usage",
+			inputTokens: usage.inputTokens || 0,
+			outputTokens: usage.outputTokens || 0,
+			cacheReadTokens,
+			cacheWriteTokens: undefined, // xAI doesn't report cache write tokens separately
+			reasoningTokens: usage.details?.reasoningTokens,
+		}
+	}
+
+	/**
+	 * Get the max tokens parameter to include in the request.
+	 */
+	protected getMaxOutputTokens(): number | undefined {
+		const { info } = this.getModel()
+		return this.options.modelMaxTokens || info.maxTokens || undefined
+	}
+
+	/**
+	 * Create a message stream using the AI SDK.
+	 */
 	override async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		const { id: modelId, info: modelInfo, reasoning } = this.getModel()
-
-		// Use the OpenAI-compatible API.
-		const requestOptions = {
-			model: modelId,
-			max_tokens: modelInfo.maxTokens,
-			temperature: this.options.modelTemperature ?? XAI_DEFAULT_TEMPERATURE,
-			messages: [
-				{ role: "system", content: systemPrompt },
-				...convertToOpenAiMessages(messages),
-			] as OpenAI.Chat.ChatCompletionMessageParam[],
-			stream: true as const,
-			stream_options: { include_usage: true },
-			...(reasoning && reasoning),
-			tools: this.convertToolsForOpenAI(metadata?.tools),
-			tool_choice: metadata?.tool_choice,
-			parallel_tool_calls: metadata?.parallelToolCalls ?? true,
+		const { temperature, reasoning } = this.getModel()
+		const languageModel = this.getLanguageModel()
+
+		// Convert messages to AI SDK format
+		const aiSdkMessages = convertToAiSdkMessages(messages)
+
+		// Convert tools to OpenAI format first, then to AI SDK format
+		const openAiTools = this.convertToolsForOpenAI(metadata?.tools)
+		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined
+
+		// Build the request options
+		const requestOptions: Parameters<typeof streamText>[0] = {
+			model: languageModel,
+			system: systemPrompt,
+			messages: aiSdkMessages,
+			temperature: this.options.modelTemperature ?? temperature ?? XAI_DEFAULT_TEMPERATURE,
+			maxOutputTokens: this.getMaxOutputTokens(),
+			tools: aiSdkTools,
+			toolChoice: mapToolChoice(metadata?.tool_choice),
+			...(reasoning && { providerOptions: { xai: reasoning } }),
 		}
 
-		let stream
-		try {
-			stream = await this.client.chat.completions.create(requestOptions)
-		} catch (error) {
-			const errorMessage = error instanceof Error ? error.message : String(error)
-			const apiError = new ApiProviderError(errorMessage, this.providerName, modelId, "createMessage")
-			TelemetryService.instance.captureException(apiError)
-			throw handleOpenAIError(error, this.providerName)
-		}
+		// Use streamText for streaming responses
+		const result = streamText(requestOptions)
 
-		for await (const chunk of stream) {
-			const delta = chunk.choices[0]?.delta
-			const finishReason = chunk.choices[0]?.finish_reason
-
-			if (delta?.content) {
-				yield {
-					type: "text",
-					text: delta.content,
-				}
-			}
-
-			if (delta && "reasoning_content" in delta && delta.reasoning_content) {
-				yield {
-					type: "reasoning",
-					text: delta.reasoning_content as string,
-				}
-			}
-
-			// Handle tool calls in stream - emit partial chunks for NativeToolCallParser
-			if (delta?.tool_calls) {
-				for (const toolCall of delta.tool_calls) {
-					yield {
-						type: "tool_call_partial",
-						index: toolCall.index,
-						id: toolCall.id,
-						name: toolCall.function?.name,
-						arguments: toolCall.function?.arguments,
-					}
-				}
-			}
-
-			// Process finish_reason to emit tool_call_end events
-			// This ensures tool calls are finalized even if the stream doesn't properly close
-			if (finishReason) {
-				const endEvents = NativeToolCallParser.processFinishReason(finishReason)
-				for (const event of endEvents) {
-					yield event
+		try {
+			// Process the full stream to get all events including reasoning
+			for await (const part of result.fullStream) {
+				for (const chunk of processAiSdkStreamPart(part)) {
+					yield chunk
 				}
 			}
 
-			if (chunk.usage) {
-				// Extract detailed token information if available
-				// First check for prompt_tokens_details structure (real API response)
-				const promptDetails = "prompt_tokens_details" in chunk.usage ? chunk.usage.prompt_tokens_details : null
-				const cachedTokens = promptDetails && "cached_tokens" in promptDetails ? promptDetails.cached_tokens : 0
-
-				// Fall back to direct fields in usage (used in test mocks)
-				const readTokens =
-					cachedTokens ||
-					("cache_read_input_tokens" in chunk.usage ? (chunk.usage as any).cache_read_input_tokens : 0)
-				const writeTokens =
-					"cache_creation_input_tokens" in chunk.usage ? (chunk.usage as any).cache_creation_input_tokens : 0
-
-				yield {
-					type: "usage",
-					inputTokens: chunk.usage.prompt_tokens || 0,
-					outputTokens: chunk.usage.completion_tokens || 0,
-					cacheReadTokens: readTokens,
-					cacheWriteTokens: writeTokens,
-				}
+			// Yield usage metrics at the end, including cache metrics from providerMetadata
+			const usage = await result.usage
+			const providerMetadata = await result.providerMetadata
+			if (usage) {
+				yield this.processUsageMetrics(usage, providerMetadata as any)
 			}
+		} catch (error) {
+			// Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.)
+			throw handleAiSdkError(error, "xAI")
 		}
 	}
 
+	/**
+	 * Complete a prompt using the AI SDK generateText.
+	 */
 	async completePrompt(prompt: string): Promise<string> {
-		const { id: modelId, reasoning } = this.getModel()
+		const { temperature, reasoning } = this.getModel()
+		const languageModel = this.getLanguageModel()
 
 		try {
-			const response = await this.client.chat.completions.create({
-				model: modelId,
-				messages: [{ role: "user", content: prompt }],
-				...(reasoning && reasoning),
+			const { text } = await generateText({
+				model: languageModel,
+				prompt,
+				maxOutputTokens: this.getMaxOutputTokens(),
+				temperature: this.options.modelTemperature ?? temperature ?? XAI_DEFAULT_TEMPERATURE,
+				...(reasoning && { providerOptions: { xai: reasoning } }),
 			})
 
-			return response.choices[0]?.message.content || ""
+			return text
 		} catch (error) {
-			const errorMessage = error instanceof Error ? error.message : String(error)
-			const apiError = new ApiProviderError(errorMessage, this.providerName, modelId, "completePrompt")
-			TelemetryService.instance.captureException(apiError)
-			throw handleOpenAIError(error, this.providerName)
+			throw handleAiSdkError(error, "xAI")
 		}
 	}
 }

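For context on how the rewritten handler is consumed, a short sketch mirroring what the updated spec exercises; the import path and prompt text are illustrative, while the option names (xaiApiKey, apiModelId) and the chunk shapes (text, usage) match what the tests assert:

// Illustrative consumer; the import path assumes a script at the repository root.
import { XAIHandler } from "./src/api/providers/xai"

async function main() {
	const handler = new XAIHandler({
		xaiApiKey: process.env.XAI_API_KEY,
		apiModelId: "grok-code-fast-1",
	})

	// createMessage is an async generator yielding text / reasoning / tool / usage chunks.
	const stream = handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: [{ type: "text" as const, text: "Hello!" }] },
	])

	for await (const chunk of stream) {
		if (chunk.type === "text") process.stdout.write(chunk.text)
		if (chunk.type === "usage") console.log("\ntokens in/out:", chunk.inputTokens, chunk.outputTokens)
	}
}

main()
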
+ 1 - 0
src/package.json

@@ -455,6 +455,7 @@
 		"@ai-sdk/fireworks": "^2.0.26",
 		"@ai-sdk/groq": "^3.0.19",
 		"@ai-sdk/mistral": "^3.0.0",
+		"@ai-sdk/xai": "^3.0.46",
 		"sambanova-ai-provider": "^1.2.2",
 		"@anthropic-ai/bedrock-sdk": "^0.10.2",
 		"@anthropic-ai/sdk": "^0.37.0",