|
|
@@ -1,587 +1,731 @@
|
|
|
-// npx vitest api/providers/__tests__/xai.spec.ts
|
|
|
-
|
|
|
-// Mock TelemetryService - must come before other imports
|
|
|
-const mockCaptureException = vitest.hoisted(() => vitest.fn())
|
|
|
-vitest.mock("@roo-code/telemetry", () => ({
|
|
|
- TelemetryService: {
|
|
|
- instance: {
|
|
|
- captureException: mockCaptureException,
|
|
|
- },
|
|
|
- },
|
|
|
-}))
|
|
|
-
|
|
|
-const mockCreate = vitest.fn()
|
|
|
+// npx vitest run api/providers/__tests__/xai.spec.ts
|
|
|
|
|
|
-vitest.mock("openai", () => {
|
|
|
- const mockConstructor = vitest.fn()
|
|
|
+// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls
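+// (vi.mock factories are hoisted above all imports, so anything they reference must come from vi.hoisted rather than a plain top-level const)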
|
|
|
+const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({
|
|
|
+ mockStreamText: vi.fn(),
|
|
|
+ mockGenerateText: vi.fn(),
|
|
|
+}))
|
|
|
|
|
|
+vi.mock("ai", async (importOriginal) => {
|
|
|
+ const actual = await importOriginal<typeof import("ai")>()
|
|
|
return {
|
|
|
- __esModule: true,
|
|
|
- default: mockConstructor.mockImplementation(() => ({ chat: { completions: { create: mockCreate } } })),
|
|
|
+ ...actual,
|
|
|
+ streamText: mockStreamText,
|
|
|
+ generateText: mockGenerateText,
|
|
|
}
|
|
|
})
|
|
|
|
|
|
-import OpenAI from "openai"
|
|
|
+vi.mock("@ai-sdk/xai", () => ({
|
|
|
+ createXai: vi.fn(() => {
|
|
|
+ // Return a function that returns a mock language model
|
|
|
+ return vi.fn(() => ({
|
|
|
+ modelId: "grok-code-fast-1",
|
|
|
+ provider: "xai",
|
|
|
+ }))
|
|
|
+ }),
|
|
|
+}))
|
|
|
+
|
|
|
import type { Anthropic } from "@anthropic-ai/sdk"
|
|
|
|
|
|
-import { xaiDefaultModelId, xaiModels } from "@roo-code/types"
|
|
|
+import { xaiDefaultModelId, xaiModels, type XAIModelId } from "@roo-code/types"
|
|
|
+
|
|
|
+import type { ApiHandlerOptions } from "../../../shared/api"
|
|
|
|
|
|
import { XAIHandler } from "../xai"
|
|
|
|
|
|
describe("XAIHandler", () => {
|
|
|
let handler: XAIHandler
|
|
|
+ let mockOptions: ApiHandlerOptions
|
|
|
|
|
|
beforeEach(() => {
|
|
|
- // Reset all mocks
|
|
|
+ mockOptions = {
|
|
|
+ xaiApiKey: "test-xai-api-key",
|
|
|
+ apiModelId: "grok-code-fast-1",
|
|
|
+ }
|
|
|
+ handler = new XAIHandler(mockOptions)
|
|
|
vi.clearAllMocks()
|
|
|
- mockCreate.mockClear()
|
|
|
- mockCaptureException.mockClear()
|
|
|
-
|
|
|
- // Create handler with mock
|
|
|
- handler = new XAIHandler({})
|
|
|
})
|
|
|
|
|
|
- it("should use the correct X.AI base URL", () => {
|
|
|
- expect(OpenAI).toHaveBeenCalledWith(
|
|
|
- expect.objectContaining({
|
|
|
- baseURL: "https://api.x.ai/v1",
|
|
|
- }),
|
|
|
- )
|
|
|
- })
|
|
|
+ describe("constructor", () => {
|
|
|
+ it("should initialize with provided options", () => {
|
|
|
+ expect(handler).toBeInstanceOf(XAIHandler)
|
|
|
+ expect(handler.getModel().id).toBe(mockOptions.apiModelId)
|
|
|
+ })
|
|
|
|
|
|
- it("should use the provided API key", () => {
|
|
|
- // Clear mocks before this specific test
|
|
|
- vi.clearAllMocks()
|
|
|
+ it("should use default model ID if not provided", () => {
|
|
|
+ const handlerWithoutModel = new XAIHandler({
|
|
|
+ ...mockOptions,
|
|
|
+ apiModelId: undefined,
|
|
|
+ })
|
|
|
+ expect(handlerWithoutModel.getModel().id).toBe(xaiDefaultModelId)
|
|
|
+ })
|
|
|
+ })
|
|
|
|
|
|
- // Create a handler with our API key
|
|
|
- const xaiApiKey = "test-api-key"
|
|
|
- new XAIHandler({ xaiApiKey })
|
|
|
+ describe("getModel", () => {
|
|
|
+ it("should return default model when no model is specified", () => {
|
|
|
+ const handlerWithoutModel = new XAIHandler({
|
|
|
+ xaiApiKey: "test-xai-api-key",
|
|
|
+ })
|
|
|
+ const model = handlerWithoutModel.getModel()
|
|
|
+ expect(model.id).toBe(xaiDefaultModelId)
|
|
|
+ expect(model.info).toEqual(xaiModels[xaiDefaultModelId])
|
|
|
+ })
|
|
|
|
|
|
- // Verify the OpenAI constructor was called with our API key
|
|
|
- expect(OpenAI).toHaveBeenCalledWith(
|
|
|
- expect.objectContaining({
|
|
|
- apiKey: xaiApiKey,
|
|
|
- }),
|
|
|
- )
|
|
|
- })
|
|
|
+ it("should return specified model when valid model is provided", () => {
|
|
|
+ const testModelId: XAIModelId = "grok-3"
|
|
|
+ const handlerWithModel = new XAIHandler({
|
|
|
+ apiModelId: testModelId,
|
|
|
+ xaiApiKey: "test-xai-api-key",
|
|
|
+ })
|
|
|
+ const model = handlerWithModel.getModel()
|
|
|
+ expect(model.id).toBe(testModelId)
|
|
|
+ expect(model.info).toEqual(xaiModels[testModelId])
|
|
|
+ })
|
|
|
|
|
|
- it("should return default model when no model is specified", () => {
|
|
|
- const model = handler.getModel()
|
|
|
- expect(model.id).toBe(xaiDefaultModelId)
|
|
|
- expect(model.info).toEqual(xaiModels[xaiDefaultModelId])
|
|
|
- })
|
|
|
+ it("should return grok-3-mini model with correct configuration", () => {
|
|
|
+ const testModelId: XAIModelId = "grok-3-mini"
|
|
|
+ const handlerWithModel = new XAIHandler({
|
|
|
+ apiModelId: testModelId,
|
|
|
+ xaiApiKey: "test-xai-api-key",
|
|
|
+ })
|
|
|
+ const model = handlerWithModel.getModel()
|
|
|
+ expect(model.id).toBe(testModelId)
|
|
|
+ expect(model.info).toEqual(
|
|
|
+ expect.objectContaining({
|
|
|
+ maxTokens: 8192,
|
|
|
+ contextWindow: 131072,
|
|
|
+ supportsImages: true,
|
|
|
+ supportsPromptCache: true,
|
|
|
+ inputPrice: 0.3,
|
|
|
+ outputPrice: 0.5,
|
|
|
+ }),
|
|
|
+ )
|
|
|
+ })
|
|
|
|
|
|
- test("should return specified model when valid model is provided", () => {
|
|
|
- const testModelId = "grok-3"
|
|
|
- const handlerWithModel = new XAIHandler({ apiModelId: testModelId })
|
|
|
- const model = handlerWithModel.getModel()
|
|
|
+ it("should return grok-4-0709 model with correct configuration", () => {
|
|
|
+ const testModelId: XAIModelId = "grok-4-0709"
|
|
|
+ const handlerWithModel = new XAIHandler({
|
|
|
+ apiModelId: testModelId,
|
|
|
+ xaiApiKey: "test-xai-api-key",
|
|
|
+ })
|
|
|
+ const model = handlerWithModel.getModel()
|
|
|
+ expect(model.id).toBe(testModelId)
|
|
|
+ expect(model.info).toEqual(
|
|
|
+ expect.objectContaining({
|
|
|
+ maxTokens: 8192,
|
|
|
+ contextWindow: 256_000,
|
|
|
+ supportsImages: true,
|
|
|
+ supportsPromptCache: true,
|
|
|
+ inputPrice: 3.0,
|
|
|
+ outputPrice: 15.0,
|
|
|
+ }),
|
|
|
+ )
|
|
|
+ })
|
|
|
|
|
|
- expect(model.id).toBe(testModelId)
|
|
|
- expect(model.info).toEqual(xaiModels[testModelId])
|
|
|
- })
|
|
|
+ it("should fall back to default model for invalid model ID", () => {
|
|
|
+ const handlerWithInvalidModel = new XAIHandler({
|
|
|
+ ...mockOptions,
|
|
|
+ apiModelId: "invalid-model",
|
|
|
+ })
|
|
|
+ const model = handlerWithInvalidModel.getModel()
|
|
|
+ expect(model.id).toBe(xaiDefaultModelId)
|
|
|
+ expect(model.info).toBe(xaiModels[xaiDefaultModelId])
|
|
|
+ })
|
|
|
|
|
|
- it("should include reasoning_effort parameter for mini models", async () => {
|
|
|
- const miniModelHandler = new XAIHandler({
|
|
|
- apiModelId: "grok-3-mini",
|
|
|
- reasoningEffort: "high",
|
|
|
+ it("should include model parameters from getModelParams", () => {
|
|
|
+ const model = handler.getModel()
|
|
|
+ expect(model).toHaveProperty("temperature")
|
|
|
+ expect(model).toHaveProperty("maxTokens")
|
|
|
})
|
|
|
+ })
|
|
|
|
|
|
- // Setup mock for streaming response
|
|
|
- mockCreate.mockImplementationOnce(() => {
|
|
|
- return {
|
|
|
- [Symbol.asyncIterator]: () => ({
|
|
|
- async next() {
|
|
|
- return { done: true }
|
|
|
+ describe("createMessage", () => {
|
|
|
+ const systemPrompt = "You are a helpful assistant."
|
|
|
+ const messages: Anthropic.Messages.MessageParam[] = [
|
|
|
+ {
|
|
|
+ role: "user",
|
|
|
+ content: [
|
|
|
+ {
|
|
|
+ type: "text" as const,
|
|
|
+ text: "Hello!",
|
|
|
},
|
|
|
- }),
|
|
|
+ ],
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ it("should handle streaming responses", async () => {
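+			// Minimal stand-in for streamText's fullStream; the part shapes below are assumed to mirror the AI SDK's streamed part types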
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield { type: "text-delta", text: "Test response from xAI" }
|
|
|
}
|
|
|
- })
|
|
|
|
|
|
- // Start generating a message
|
|
|
- const messageGenerator = miniModelHandler.createMessage("test prompt", [])
|
|
|
- await messageGenerator.next() // Start the generator
|
|
|
+ const mockUsage = Promise.resolve({
|
|
|
+ inputTokens: 10,
|
|
|
+ outputTokens: 5,
|
|
|
+ })
|
|
|
|
|
|
- // Check that reasoning_effort was included
|
|
|
- expect(mockCreate).toHaveBeenCalledWith(
|
|
|
- expect.objectContaining({
|
|
|
- reasoning_effort: "high",
|
|
|
- }),
|
|
|
- )
|
|
|
- })
|
|
|
+ const mockProviderMetadata = Promise.resolve({})
|
|
|
|
|
|
- it("should not include reasoning_effort parameter for non-mini models", async () => {
|
|
|
- const regularModelHandler = new XAIHandler({
|
|
|
- apiModelId: "grok-3",
|
|
|
- reasoningEffort: "high",
|
|
|
- })
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: mockUsage,
|
|
|
+ providerMetadata: mockProviderMetadata,
|
|
|
+ })
|
|
|
|
|
|
- // Setup mock for streaming response
|
|
|
- mockCreate.mockImplementationOnce(() => {
|
|
|
- return {
|
|
|
- [Symbol.asyncIterator]: () => ({
|
|
|
- async next() {
|
|
|
- return { done: true }
|
|
|
- },
|
|
|
- }),
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages)
|
|
|
+ const chunks: any[] = []
|
|
|
+ for await (const chunk of stream) {
|
|
|
+ chunks.push(chunk)
|
|
|
}
|
|
|
+
|
|
|
+ expect(chunks.length).toBeGreaterThan(0)
|
|
|
+ const textChunks = chunks.filter((chunk) => chunk.type === "text")
|
|
|
+ expect(textChunks).toHaveLength(1)
|
|
|
+ expect(textChunks[0].text).toBe("Test response from xAI")
|
|
|
})
|
|
|
|
|
|
- // Start generating a message
|
|
|
- const messageGenerator = regularModelHandler.createMessage("test prompt", [])
|
|
|
- await messageGenerator.next() // Start the generator
|
|
|
+ it("should include usage information", async () => {
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield { type: "text-delta", text: "Test response" }
|
|
|
+ }
|
|
|
|
|
|
- // Check call args for reasoning_effort
|
|
|
- const calls = mockCreate.mock.calls
|
|
|
- const lastCall = calls[calls.length - 1][0]
|
|
|
- expect(lastCall).not.toHaveProperty("reasoning_effort")
|
|
|
- })
|
|
|
+ const mockUsage = Promise.resolve({
|
|
|
+ inputTokens: 10,
|
|
|
+ outputTokens: 20,
|
|
|
+ })
|
|
|
|
|
|
- it("completePrompt method should return text from OpenAI API", async () => {
|
|
|
- const expectedResponse = "This is a test response"
|
|
|
- mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] })
|
|
|
+ const mockProviderMetadata = Promise.resolve({})
|
|
|
|
|
|
- const result = await handler.completePrompt("test prompt")
|
|
|
- expect(result).toBe(expectedResponse)
|
|
|
- })
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: mockUsage,
|
|
|
+ providerMetadata: mockProviderMetadata,
|
|
|
+ })
|
|
|
|
|
|
- it("should handle errors in completePrompt", async () => {
|
|
|
- const errorMessage = "API error"
|
|
|
- mockCreate.mockRejectedValueOnce(new Error(errorMessage))
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages)
|
|
|
+ const chunks: any[] = []
|
|
|
+ for await (const chunk of stream) {
|
|
|
+ chunks.push(chunk)
|
|
|
+ }
|
|
|
|
|
|
- await expect(handler.completePrompt("test prompt")).rejects.toThrow(`xAI completion error: ${errorMessage}`)
|
|
|
- })
|
|
|
+ const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
|
|
|
+ expect(usageChunks.length).toBeGreaterThan(0)
|
|
|
+ expect(usageChunks[0].inputTokens).toBe(10)
|
|
|
+ expect(usageChunks[0].outputTokens).toBe(20)
|
|
|
+ })
|
|
|
|
|
|
- it("createMessage should yield text content from stream", async () => {
|
|
|
- const testContent = "This is test content"
|
|
|
-
|
|
|
- // Setup mock for streaming response
|
|
|
- mockCreate.mockImplementationOnce(() => {
|
|
|
- return {
|
|
|
- [Symbol.asyncIterator]: () => ({
|
|
|
- next: vi
|
|
|
- .fn()
|
|
|
- .mockResolvedValueOnce({
|
|
|
- done: false,
|
|
|
- value: {
|
|
|
- choices: [{ delta: { content: testContent } }],
|
|
|
- },
|
|
|
- })
|
|
|
- .mockResolvedValueOnce({ done: true }),
|
|
|
- }),
|
|
|
+ it("should handle cached tokens in usage data from providerMetadata", async () => {
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield { type: "text-delta", text: "Test response" }
|
|
|
}
|
|
|
- })
|
|
|
|
|
|
- // Create and consume the stream
|
|
|
- const stream = handler.createMessage("system prompt", [])
|
|
|
- const firstChunk = await stream.next()
|
|
|
+ const mockUsage = Promise.resolve({
|
|
|
+ inputTokens: 100,
|
|
|
+ outputTokens: 50,
|
|
|
+ })
|
|
|
|
|
|
- // Verify the content
|
|
|
- expect(firstChunk.done).toBe(false)
|
|
|
- expect(firstChunk.value).toEqual({
|
|
|
- type: "text",
|
|
|
- text: testContent,
|
|
|
- })
|
|
|
- })
|
|
|
+ // xAI provides cache metrics via providerMetadata for supported models
|
|
|
+ const mockProviderMetadata = Promise.resolve({
|
|
|
+ xai: {
|
|
|
+ cachedPromptTokens: 30,
|
|
|
+ },
|
|
|
+ })
|
|
|
|
|
|
- it("createMessage should yield reasoning content from stream", async () => {
|
|
|
- const testReasoning = "Test reasoning content"
|
|
|
-
|
|
|
- // Setup mock for streaming response
|
|
|
- mockCreate.mockImplementationOnce(() => {
|
|
|
- return {
|
|
|
- [Symbol.asyncIterator]: () => ({
|
|
|
- next: vi
|
|
|
- .fn()
|
|
|
- .mockResolvedValueOnce({
|
|
|
- done: false,
|
|
|
- value: {
|
|
|
- choices: [{ delta: { reasoning_content: testReasoning } }],
|
|
|
- },
|
|
|
- })
|
|
|
- .mockResolvedValueOnce({ done: true }),
|
|
|
- }),
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: mockUsage,
|
|
|
+ providerMetadata: mockProviderMetadata,
|
|
|
+ })
|
|
|
+
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages)
|
|
|
+ const chunks: any[] = []
|
|
|
+ for await (const chunk of stream) {
|
|
|
+ chunks.push(chunk)
|
|
|
}
|
|
|
+
|
|
|
+ const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
|
|
|
+ expect(usageChunks.length).toBeGreaterThan(0)
|
|
|
+ expect(usageChunks[0].inputTokens).toBe(100)
|
|
|
+ expect(usageChunks[0].outputTokens).toBe(50)
|
|
|
+ expect(usageChunks[0].cacheReadTokens).toBe(30)
|
|
|
})
|
|
|
|
|
|
- // Create and consume the stream
|
|
|
- const stream = handler.createMessage("system prompt", [])
|
|
|
- const firstChunk = await stream.next()
|
|
|
+		it("should handle usage with details.cachedInputTokens when providerMetadata lacks xAI cache info", async () => {
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield { type: "text-delta", text: "Test response" }
|
|
|
+ }
|
|
|
+
|
|
|
+ const mockUsage = Promise.resolve({
|
|
|
+ inputTokens: 100,
|
|
|
+ outputTokens: 50,
|
|
|
+ details: {
|
|
|
+ cachedInputTokens: 25,
|
|
|
+ },
|
|
|
+ })
|
|
|
+
|
|
|
+ const mockProviderMetadata = Promise.resolve({})
|
|
|
|
|
|
- // Verify the reasoning content
|
|
|
- expect(firstChunk.done).toBe(false)
|
|
|
- expect(firstChunk.value).toEqual({
|
|
|
- type: "reasoning",
|
|
|
- text: testReasoning,
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: mockUsage,
|
|
|
+ providerMetadata: mockProviderMetadata,
|
|
|
+ })
|
|
|
+
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages)
|
|
|
+ const chunks: any[] = []
|
|
|
+ for await (const chunk of stream) {
|
|
|
+ chunks.push(chunk)
|
|
|
+ }
|
|
|
+
|
|
|
+ const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
|
|
|
+ expect(usageChunks.length).toBeGreaterThan(0)
|
|
|
+ expect(usageChunks[0].cacheReadTokens).toBe(25)
|
|
|
+ expect(usageChunks[0].cacheWriteTokens).toBeUndefined()
|
|
|
})
|
|
|
- })
|
|
|
|
|
|
- it("createMessage should yield usage data from stream", async () => {
|
|
|
- // Setup mock for streaming response that includes usage data
|
|
|
- mockCreate.mockImplementationOnce(() => {
|
|
|
- return {
|
|
|
- [Symbol.asyncIterator]: () => ({
|
|
|
- next: vi
|
|
|
- .fn()
|
|
|
- .mockResolvedValueOnce({
|
|
|
- done: false,
|
|
|
- value: {
|
|
|
- choices: [{ delta: {} }], // Needs to have choices array to avoid error
|
|
|
- usage: {
|
|
|
- prompt_tokens: 10,
|
|
|
- completion_tokens: 20,
|
|
|
- cache_read_input_tokens: 5,
|
|
|
- cache_creation_input_tokens: 15,
|
|
|
- },
|
|
|
- },
|
|
|
- })
|
|
|
- .mockResolvedValueOnce({ done: true }),
|
|
|
+ it("should pass correct temperature (0 default) to streamText", async () => {
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield { type: "text-delta", text: "Test" }
|
|
|
+ }
|
|
|
+
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
|
|
|
+ providerMetadata: Promise.resolve({}),
|
|
|
+ })
|
|
|
+
|
|
|
+ const handlerWithDefaultTemp = new XAIHandler({
|
|
|
+ xaiApiKey: "test-key",
|
|
|
+ apiModelId: "grok-code-fast-1",
|
|
|
+ })
|
|
|
+
|
|
|
+ const stream = handlerWithDefaultTemp.createMessage(systemPrompt, messages)
|
|
|
+ for await (const _ of stream) {
|
|
|
+ // consume stream
|
|
|
+ }
|
|
|
+
|
|
|
+ expect(mockStreamText).toHaveBeenCalledWith(
|
|
|
+ expect.objectContaining({
|
|
|
+ temperature: 0,
|
|
|
}),
|
|
|
+ )
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should use user-specified temperature over default", async () => {
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield { type: "text-delta", text: "Test" }
|
|
|
+ }
|
|
|
+
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
|
|
|
+ providerMetadata: Promise.resolve({}),
|
|
|
+ })
|
|
|
+
|
|
|
+ const handlerWithCustomTemp = new XAIHandler({
|
|
|
+ xaiApiKey: "test-key",
|
|
|
+ apiModelId: "grok-3",
|
|
|
+ modelTemperature: 0.7,
|
|
|
+ })
|
|
|
+
|
|
|
+ const stream = handlerWithCustomTemp.createMessage(systemPrompt, messages)
|
|
|
+ for await (const _ of stream) {
|
|
|
+ // consume stream
|
|
|
}
|
|
|
+
|
|
|
+			// The user-specified temperature should take precedence over any provider or model default
|
|
|
+ expect(mockStreamText).toHaveBeenCalledWith(
|
|
|
+ expect.objectContaining({
|
|
|
+ temperature: 0.7,
|
|
|
+ }),
|
|
|
+ )
|
|
|
})
|
|
|
|
|
|
- // Create and consume the stream
|
|
|
- const stream = handler.createMessage("system prompt", [])
|
|
|
- const firstChunk = await stream.next()
|
|
|
-
|
|
|
- // Verify the usage data
|
|
|
- expect(firstChunk.done).toBe(false)
|
|
|
- expect(firstChunk.value).toEqual({
|
|
|
- type: "usage",
|
|
|
- inputTokens: 10,
|
|
|
- outputTokens: 20,
|
|
|
- cacheReadTokens: 5,
|
|
|
- cacheWriteTokens: 15,
|
|
|
+ it("should handle stream with multiple chunks", async () => {
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield { type: "text-delta", text: "Hello" }
|
|
|
+ yield { type: "text-delta", text: " world" }
|
|
|
+ }
|
|
|
+
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: Promise.resolve({ inputTokens: 5, outputTokens: 10 }),
|
|
|
+ providerMetadata: Promise.resolve({}),
|
|
|
+ })
|
|
|
+
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages)
|
|
|
+ const chunks: any[] = []
|
|
|
+ for await (const chunk of stream) {
|
|
|
+ chunks.push(chunk)
|
|
|
+ }
|
|
|
+
|
|
|
+ const textChunks = chunks.filter((c) => c.type === "text")
|
|
|
+ expect(textChunks[0]).toEqual({ type: "text", text: "Hello" })
|
|
|
+ expect(textChunks[1]).toEqual({ type: "text", text: " world" })
|
|
|
+
|
|
|
+ const usageChunks = chunks.filter((c) => c.type === "usage")
|
|
|
+ expect(usageChunks[0]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 })
|
|
|
})
|
|
|
- })
|
|
|
|
|
|
- it("createMessage should pass correct parameters to OpenAI client", async () => {
|
|
|
- // Setup a handler with specific model
|
|
|
- const modelId = "grok-3"
|
|
|
- const modelInfo = xaiModels[modelId]
|
|
|
- const handlerWithModel = new XAIHandler({ apiModelId: modelId })
|
|
|
-
|
|
|
- // Setup mock for streaming response
|
|
|
- mockCreate.mockImplementationOnce(() => {
|
|
|
- return {
|
|
|
- [Symbol.asyncIterator]: () => ({
|
|
|
- async next() {
|
|
|
- return { done: true }
|
|
|
- },
|
|
|
- }),
|
|
|
+ it("should handle reasoning content from stream", async () => {
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield { type: "reasoning-delta", text: "Let me think about this..." }
|
|
|
+ yield { type: "text-delta", text: "Here is my answer" }
|
|
|
+ }
|
|
|
+
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: Promise.resolve({ inputTokens: 10, outputTokens: 20 }),
|
|
|
+ providerMetadata: Promise.resolve({}),
|
|
|
+ })
|
|
|
+
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages)
|
|
|
+ const chunks: any[] = []
|
|
|
+ for await (const chunk of stream) {
|
|
|
+ chunks.push(chunk)
|
|
|
}
|
|
|
+
|
|
|
+ const reasoningChunks = chunks.filter((c) => c.type === "reasoning")
|
|
|
+ expect(reasoningChunks).toHaveLength(1)
|
|
|
+ expect(reasoningChunks[0].text).toBe("Let me think about this...")
|
|
|
+
|
|
|
+ const textChunks = chunks.filter((c) => c.type === "text")
|
|
|
+ expect(textChunks).toHaveLength(1)
|
|
|
+ expect(textChunks[0].text).toBe("Here is my answer")
|
|
|
})
|
|
|
|
|
|
- // System prompt and messages
|
|
|
- const systemPrompt = "Test system prompt"
|
|
|
- const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]
|
|
|
-
|
|
|
- // Start generating a message
|
|
|
- const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
|
|
|
- await messageGenerator.next() // Start the generator
|
|
|
-
|
|
|
- // Check that all parameters were passed correctly
|
|
|
- expect(mockCreate).toHaveBeenCalledWith(
|
|
|
- expect.objectContaining({
|
|
|
- model: modelId,
|
|
|
- max_tokens: modelInfo.maxTokens,
|
|
|
- temperature: 0,
|
|
|
- messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
|
|
|
- stream: true,
|
|
|
- stream_options: { include_usage: true },
|
|
|
- }),
|
|
|
- )
|
|
|
- })
|
|
|
+ it("should handle errors during streaming", async () => {
|
|
|
+ const mockError = new Error("API error")
|
|
|
+ ;(mockError as any).name = "AI_APICallError"
|
|
|
+ ;(mockError as any).status = 500
|
|
|
|
|
|
- describe("Native Tool Calling", () => {
|
|
|
- const testTools = [
|
|
|
- {
|
|
|
- type: "function" as const,
|
|
|
- function: {
|
|
|
- name: "test_tool",
|
|
|
- description: "A test tool",
|
|
|
- parameters: {
|
|
|
- type: "object",
|
|
|
- properties: {
|
|
|
- arg1: { type: "string", description: "First argument" },
|
|
|
- },
|
|
|
- required: ["arg1"],
|
|
|
- },
|
|
|
- },
|
|
|
- },
|
|
|
- ]
|
|
|
+ async function* mockFullStream(): AsyncGenerator<never> {
|
|
|
+				// Yield a placeholder chunk first (satisfies the require-yield lint rule), then throw
|
|
|
+ yield undefined as never
|
|
|
+ throw mockError
|
|
|
+ }
|
|
|
+
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
|
|
|
+ providerMetadata: Promise.resolve({}),
|
|
|
+ })
|
|
|
|
|
|
- it("should include tools in request when model supports native tools and tools are provided (native is default)", async () => {
|
|
|
- const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages)
|
|
|
|
|
|
- mockCreate.mockImplementationOnce(() => {
|
|
|
- return {
|
|
|
- [Symbol.asyncIterator]: () => ({
|
|
|
- async next() {
|
|
|
- return { done: true }
|
|
|
- },
|
|
|
- }),
|
|
|
+ await expect(async () => {
|
|
|
+ for await (const _ of stream) {
|
|
|
+ // consume stream
|
|
|
}
|
|
|
- })
|
|
|
+ }).rejects.toThrow("xAI")
|
|
|
+ })
|
|
|
+ })
|
|
|
|
|
|
- const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
|
|
|
- taskId: "test-task-id",
|
|
|
- tools: testTools,
|
|
|
+ describe("completePrompt", () => {
|
|
|
+ it("should complete a prompt using generateText", async () => {
|
|
|
+ mockGenerateText.mockResolvedValue({
|
|
|
+ text: "Test completion from xAI",
|
|
|
})
|
|
|
- await messageGenerator.next()
|
|
|
|
|
|
- expect(mockCreate).toHaveBeenCalledWith(
|
|
|
+ const result = await handler.completePrompt("Test prompt")
|
|
|
+
|
|
|
+ expect(result).toBe("Test completion from xAI")
|
|
|
+ expect(mockGenerateText).toHaveBeenCalledWith(
|
|
|
expect.objectContaining({
|
|
|
- tools: expect.arrayContaining([
|
|
|
- expect.objectContaining({
|
|
|
- type: "function",
|
|
|
- function: expect.objectContaining({
|
|
|
- name: "test_tool",
|
|
|
- }),
|
|
|
- }),
|
|
|
- ]),
|
|
|
- parallel_tool_calls: true,
|
|
|
+ prompt: "Test prompt",
|
|
|
}),
|
|
|
)
|
|
|
})
|
|
|
|
|
|
- it("should include tool_choice when provided", async () => {
|
|
|
- const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
|
|
|
-
|
|
|
- mockCreate.mockImplementationOnce(() => {
|
|
|
- return {
|
|
|
- [Symbol.asyncIterator]: () => ({
|
|
|
- async next() {
|
|
|
- return { done: true }
|
|
|
- },
|
|
|
- }),
|
|
|
- }
|
|
|
+ it("should use default temperature in completePrompt", async () => {
|
|
|
+ mockGenerateText.mockResolvedValue({
|
|
|
+ text: "Test completion",
|
|
|
})
|
|
|
|
|
|
- const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
|
|
|
- taskId: "test-task-id",
|
|
|
- tools: testTools,
|
|
|
- tool_choice: "auto",
|
|
|
- })
|
|
|
- await messageGenerator.next()
|
|
|
+ await handler.completePrompt("Test prompt")
|
|
|
|
|
|
- expect(mockCreate).toHaveBeenCalledWith(
|
|
|
+ expect(mockGenerateText).toHaveBeenCalledWith(
|
|
|
expect.objectContaining({
|
|
|
- tool_choice: "auto",
|
|
|
+ temperature: 0,
|
|
|
}),
|
|
|
)
|
|
|
})
|
|
|
|
|
|
- it("should always include tools and tool_choice (tools are guaranteed to be present after ALWAYS_AVAILABLE_TOOLS)", async () => {
|
|
|
- const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
|
|
|
+ it("should handle errors in completePrompt", async () => {
|
|
|
+ const mockError = new Error("API error")
|
|
|
+ ;(mockError as any).name = "AI_APICallError"
|
|
|
+ mockGenerateText.mockRejectedValue(mockError)
|
|
|
|
|
|
- mockCreate.mockImplementationOnce(() => {
|
|
|
- return {
|
|
|
- [Symbol.asyncIterator]: () => ({
|
|
|
- async next() {
|
|
|
- return { done: true }
|
|
|
- },
|
|
|
- }),
|
|
|
+ await expect(handler.completePrompt("Test prompt")).rejects.toThrow("xAI")
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ describe("processUsageMetrics", () => {
|
|
|
+ it("should correctly process usage metrics including cache information from providerMetadata", () => {
|
|
|
+ class TestXAIHandler extends XAIHandler {
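+				// Re-expose processUsageMetrics publicly so the test can call it directly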
|
|
|
+ public testProcessUsageMetrics(usage: any, providerMetadata?: any) {
|
|
|
+ return this.processUsageMetrics(usage, providerMetadata)
|
|
|
}
|
|
|
- })
|
|
|
+ }
|
|
|
|
|
|
- const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
|
|
|
- taskId: "test-task-id",
|
|
|
- })
|
|
|
- await messageGenerator.next()
|
|
|
+ const testHandler = new TestXAIHandler(mockOptions)
|
|
|
+
|
|
|
+ const usage = {
|
|
|
+ inputTokens: 100,
|
|
|
+ outputTokens: 50,
|
|
|
+ }
|
|
|
+
|
|
|
+ const providerMetadata = {
|
|
|
+ xai: {
|
|
|
+ cachedPromptTokens: 20,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ const result = testHandler.testProcessUsageMetrics(usage, providerMetadata)
|
|
|
+
|
|
|
+ expect(result.type).toBe("usage")
|
|
|
+ expect(result.inputTokens).toBe(100)
|
|
|
+ expect(result.outputTokens).toBe(50)
|
|
|
+ expect(result.cacheReadTokens).toBe(20)
|
|
|
+ // xAI doesn't report cache write tokens separately
|
|
|
+ expect(result.cacheWriteTokens).toBeUndefined()
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle missing cache metrics gracefully", () => {
|
|
|
+ class TestXAIHandler extends XAIHandler {
|
|
|
+ public testProcessUsageMetrics(usage: any, providerMetadata?: any) {
|
|
|
+ return this.processUsageMetrics(usage, providerMetadata)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ const testHandler = new TestXAIHandler(mockOptions)
|
|
|
+
|
|
|
+ const usage = {
|
|
|
+ inputTokens: 100,
|
|
|
+ outputTokens: 50,
|
|
|
+ }
|
|
|
+
|
|
|
+ const result = testHandler.testProcessUsageMetrics(usage)
|
|
|
+
|
|
|
+ expect(result.type).toBe("usage")
|
|
|
+ expect(result.inputTokens).toBe(100)
|
|
|
+ expect(result.outputTokens).toBe(50)
|
|
|
+ expect(result.cacheWriteTokens).toBeUndefined()
|
|
|
+ expect(result.cacheReadTokens).toBeUndefined()
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should include reasoning tokens when provided", () => {
|
|
|
+ class TestXAIHandler extends XAIHandler {
|
|
|
+ public testProcessUsageMetrics(usage: any, providerMetadata?: any) {
|
|
|
+ return this.processUsageMetrics(usage, providerMetadata)
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- // Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS)
|
|
|
- const callArgs = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0]
|
|
|
- expect(callArgs).toHaveProperty("tools")
|
|
|
- expect(callArgs).toHaveProperty("tool_choice")
|
|
|
- expect(callArgs).toHaveProperty("parallel_tool_calls", true)
|
|
|
+ const testHandler = new TestXAIHandler(mockOptions)
|
|
|
+
|
|
|
+ const usage = {
|
|
|
+ inputTokens: 100,
|
|
|
+ outputTokens: 50,
|
|
|
+ details: {
|
|
|
+ reasoningTokens: 30,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ const result = testHandler.testProcessUsageMetrics(usage)
|
|
|
+
|
|
|
+ expect(result.reasoningTokens).toBe(30)
|
|
|
})
|
|
|
+ })
|
|
|
|
|
|
- it("should yield tool_call_partial chunks during streaming", async () => {
|
|
|
- const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
|
|
|
-
|
|
|
- mockCreate.mockImplementationOnce(() => {
|
|
|
- return {
|
|
|
- [Symbol.asyncIterator]: () => ({
|
|
|
- next: vi
|
|
|
- .fn()
|
|
|
- .mockResolvedValueOnce({
|
|
|
- done: false,
|
|
|
- value: {
|
|
|
- choices: [
|
|
|
- {
|
|
|
- delta: {
|
|
|
- tool_calls: [
|
|
|
- {
|
|
|
- index: 0,
|
|
|
- id: "call_123",
|
|
|
- function: {
|
|
|
- name: "test_tool",
|
|
|
- arguments: '{"arg1":',
|
|
|
- },
|
|
|
- },
|
|
|
- ],
|
|
|
- },
|
|
|
- },
|
|
|
- ],
|
|
|
- },
|
|
|
- })
|
|
|
- .mockResolvedValueOnce({
|
|
|
- done: false,
|
|
|
- value: {
|
|
|
- choices: [
|
|
|
- {
|
|
|
- delta: {
|
|
|
- tool_calls: [
|
|
|
- {
|
|
|
- index: 0,
|
|
|
- function: {
|
|
|
- arguments: '"value"}',
|
|
|
- },
|
|
|
- },
|
|
|
- ],
|
|
|
- },
|
|
|
- },
|
|
|
- ],
|
|
|
- },
|
|
|
- })
|
|
|
- .mockResolvedValueOnce({ done: true }),
|
|
|
- }),
|
|
|
+ describe("tool handling", () => {
|
|
|
+ const systemPrompt = "You are a helpful assistant."
|
|
|
+ const messages: Anthropic.Messages.MessageParam[] = [
|
|
|
+ {
|
|
|
+ role: "user",
|
|
|
+ content: [{ type: "text" as const, text: "Hello!" }],
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ it("should handle tool calls in streaming", async () => {
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield {
|
|
|
+ type: "tool-input-start",
|
|
|
+ id: "tool-call-1",
|
|
|
+ toolName: "read_file",
|
|
|
+ }
|
|
|
+ yield {
|
|
|
+ type: "tool-input-delta",
|
|
|
+ id: "tool-call-1",
|
|
|
+ delta: '{"path":"test.ts"}',
|
|
|
+ }
|
|
|
+ yield {
|
|
|
+ type: "tool-input-end",
|
|
|
+ id: "tool-call-1",
|
|
|
}
|
|
|
+ }
|
|
|
+
|
|
|
+ const mockUsage = Promise.resolve({
|
|
|
+ inputTokens: 10,
|
|
|
+ outputTokens: 5,
|
|
|
})
|
|
|
|
|
|
- const stream = handlerWithTools.createMessage("test prompt", [], {
|
|
|
- taskId: "test-task-id",
|
|
|
- tools: testTools,
|
|
|
+ const mockProviderMetadata = Promise.resolve({})
|
|
|
+
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: mockUsage,
|
|
|
+ providerMetadata: mockProviderMetadata,
|
|
|
+ })
|
|
|
+
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages, {
|
|
|
+ taskId: "test-task",
|
|
|
+ tools: [
|
|
|
+ {
|
|
|
+ type: "function",
|
|
|
+ function: {
|
|
|
+ name: "read_file",
|
|
|
+ description: "Read a file",
|
|
|
+ parameters: {
|
|
|
+ type: "object",
|
|
|
+ properties: { path: { type: "string" } },
|
|
|
+ required: ["path"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ],
|
|
|
})
|
|
|
|
|
|
- const chunks = []
|
|
|
+ const chunks: any[] = []
|
|
|
for await (const chunk of stream) {
|
|
|
chunks.push(chunk)
|
|
|
}
|
|
|
|
|
|
- expect(chunks).toContainEqual({
|
|
|
- type: "tool_call_partial",
|
|
|
- index: 0,
|
|
|
- id: "call_123",
|
|
|
- name: "test_tool",
|
|
|
- arguments: '{"arg1":',
|
|
|
+ const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start")
|
|
|
+ const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta")
|
|
|
+ const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end")
|
|
|
+
|
|
|
+ expect(toolCallStartChunks.length).toBe(1)
|
|
|
+ expect(toolCallStartChunks[0].id).toBe("tool-call-1")
|
|
|
+ expect(toolCallStartChunks[0].name).toBe("read_file")
|
|
|
+
|
|
|
+ expect(toolCallDeltaChunks.length).toBe(1)
|
|
|
+ expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}')
|
|
|
+
|
|
|
+ expect(toolCallEndChunks.length).toBe(1)
|
|
|
+ expect(toolCallEndChunks[0].id).toBe("tool-call-1")
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should ignore tool-call events to prevent duplicate tools in UI", async () => {
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield {
|
|
|
+ type: "tool-call",
|
|
|
+ toolCallId: "tool-call-1",
|
|
|
+ toolName: "read_file",
|
|
|
+ input: { path: "test.ts" },
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ const mockUsage = Promise.resolve({
|
|
|
+ inputTokens: 10,
|
|
|
+ outputTokens: 5,
|
|
|
})
|
|
|
|
|
|
- expect(chunks).toContainEqual({
|
|
|
- type: "tool_call_partial",
|
|
|
- index: 0,
|
|
|
- id: undefined,
|
|
|
- name: undefined,
|
|
|
- arguments: '"value"}',
|
|
|
+ const mockProviderMetadata = Promise.resolve({})
|
|
|
+
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: mockUsage,
|
|
|
+ providerMetadata: mockProviderMetadata,
|
|
|
})
|
|
|
+
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages)
|
|
|
+ const chunks: any[] = []
|
|
|
+ for await (const chunk of stream) {
|
|
|
+ chunks.push(chunk)
|
|
|
+ }
|
|
|
+
|
|
|
+ // tool-call events should be ignored (only tool-input-start/delta/end are processed)
|
|
|
+ const toolCallChunks = chunks.filter(
|
|
|
+ (c) => c.type === "tool_call_start" || c.type === "tool_call_delta" || c.type === "tool_call_end",
|
|
|
+ )
|
|
|
+ expect(toolCallChunks.length).toBe(0)
|
|
|
})
|
|
|
|
|
|
- it("should set parallel_tool_calls based on metadata", async () => {
|
|
|
- const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
|
|
|
+ it("should pass tools to streamText when provided", async () => {
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield { type: "text-delta", text: "Test" }
|
|
|
+ }
|
|
|
|
|
|
- mockCreate.mockImplementationOnce(() => {
|
|
|
- return {
|
|
|
- [Symbol.asyncIterator]: () => ({
|
|
|
- async next() {
|
|
|
- return { done: true }
|
|
|
- },
|
|
|
- }),
|
|
|
- }
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
|
|
|
+ providerMetadata: Promise.resolve({}),
|
|
|
})
|
|
|
|
|
|
- const messageGenerator = handlerWithTools.createMessage("test prompt", [], {
|
|
|
- taskId: "test-task-id",
|
|
|
+ const testTools = [
|
|
|
+ {
|
|
|
+ type: "function" as const,
|
|
|
+ function: {
|
|
|
+ name: "test_tool",
|
|
|
+ description: "A test tool",
|
|
|
+ parameters: {
|
|
|
+ type: "object",
|
|
|
+ properties: {
|
|
|
+ arg1: { type: "string", description: "First argument" },
|
|
|
+ },
|
|
|
+ required: ["arg1"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ const stream = handler.createMessage(systemPrompt, messages, {
|
|
|
+ taskId: "test-task",
|
|
|
tools: testTools,
|
|
|
- parallelToolCalls: true,
|
|
|
+ tool_choice: "auto",
|
|
|
})
|
|
|
- await messageGenerator.next()
|
|
|
|
|
|
- expect(mockCreate).toHaveBeenCalledWith(
|
|
|
+ for await (const _ of stream) {
|
|
|
+ // consume stream
|
|
|
+ }
|
|
|
+
|
|
|
+ expect(mockStreamText).toHaveBeenCalledWith(
|
|
|
expect.objectContaining({
|
|
|
- parallel_tool_calls: true,
|
|
|
+ tools: expect.any(Object),
|
|
|
+ toolChoice: "auto",
|
|
|
}),
|
|
|
)
|
|
|
})
|
|
|
+ })
|
|
|
|
|
|
- it("should yield tool_call_end events when finish_reason is tool_calls", async () => {
|
|
|
- // Import NativeToolCallParser to set up state
|
|
|
- const { NativeToolCallParser } = await import("../../../core/assistant-message/NativeToolCallParser")
|
|
|
-
|
|
|
- // Clear any previous state
|
|
|
- NativeToolCallParser.clearRawChunkState()
|
|
|
-
|
|
|
- const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" })
|
|
|
-
|
|
|
- mockCreate.mockImplementationOnce(() => {
|
|
|
- return {
|
|
|
- [Symbol.asyncIterator]: () => ({
|
|
|
- next: vi
|
|
|
- .fn()
|
|
|
- .mockResolvedValueOnce({
|
|
|
- done: false,
|
|
|
- value: {
|
|
|
- choices: [
|
|
|
- {
|
|
|
- delta: {
|
|
|
- tool_calls: [
|
|
|
- {
|
|
|
- index: 0,
|
|
|
- id: "call_xai_test",
|
|
|
- function: {
|
|
|
- name: "test_tool",
|
|
|
- arguments: '{"arg1":"value"}',
|
|
|
- },
|
|
|
- },
|
|
|
- ],
|
|
|
- },
|
|
|
- },
|
|
|
- ],
|
|
|
- },
|
|
|
- })
|
|
|
- .mockResolvedValueOnce({
|
|
|
- done: false,
|
|
|
- value: {
|
|
|
- choices: [
|
|
|
- {
|
|
|
- delta: {},
|
|
|
- finish_reason: "tool_calls",
|
|
|
- },
|
|
|
- ],
|
|
|
- usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
|
|
|
- },
|
|
|
- })
|
|
|
- .mockResolvedValueOnce({ done: true }),
|
|
|
- }),
|
|
|
- }
|
|
|
+ describe("reasoning effort (mini models)", () => {
|
|
|
+ it("should include reasoning effort for grok-3-mini model", async () => {
|
|
|
+ async function* mockFullStream() {
|
|
|
+ yield { type: "text-delta", text: "Test" }
|
|
|
+ }
|
|
|
+
|
|
|
+ mockStreamText.mockReturnValue({
|
|
|
+ fullStream: mockFullStream(),
|
|
|
+ usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }),
|
|
|
+ providerMetadata: Promise.resolve({}),
|
|
|
})
|
|
|
|
|
|
- const stream = handlerWithTools.createMessage("test prompt", [], {
|
|
|
- taskId: "test-task-id",
|
|
|
- tools: testTools,
|
|
|
+ const miniModelHandler = new XAIHandler({
|
|
|
+ xaiApiKey: "test-key",
|
|
|
+ apiModelId: "grok-3-mini",
|
|
|
+ reasoningEffort: "high",
|
|
|
})
|
|
|
|
|
|
- const chunks = []
|
|
|
- for await (const chunk of stream) {
|
|
|
- // Simulate what Task.ts does: when we receive tool_call_partial,
|
|
|
- // process it through NativeToolCallParser to populate rawChunkTracker
|
|
|
- if (chunk.type === "tool_call_partial") {
|
|
|
- NativeToolCallParser.processRawChunk({
|
|
|
- index: chunk.index,
|
|
|
- id: chunk.id,
|
|
|
- name: chunk.name,
|
|
|
- arguments: chunk.arguments,
|
|
|
- })
|
|
|
- }
|
|
|
- chunks.push(chunk)
|
|
|
+ const stream = miniModelHandler.createMessage("test prompt", [])
|
|
|
+ for await (const _ of stream) {
|
|
|
+ // consume stream
|
|
|
}
|
|
|
|
|
|
- // Should have tool_call_partial and tool_call_end
|
|
|
- const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
|
|
|
- const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end")
|
|
|
-
|
|
|
- expect(partialChunks).toHaveLength(1)
|
|
|
- expect(endChunks).toHaveLength(1)
|
|
|
- expect(endChunks[0].id).toBe("call_xai_test")
|
|
|
+ // Check that provider options are passed for reasoning
|
|
|
+ expect(mockStreamText).toHaveBeenCalledWith(
|
|
|
+ expect.objectContaining({
|
|
|
+ providerOptions: expect.any(Object),
|
|
|
+ }),
|
|
|
+ )
|
|
|
})
|
|
|
})
|
|
|
})
|