@@ -1,59 +1,36 @@
-// Mock TelemetryService - must come before other imports
-const mockCaptureException = vi.hoisted(() => vi.fn())
-vi.mock("@roo-code/telemetry", () => ({
-	TelemetryService: {
-		instance: {
-			captureException: mockCaptureException,
-		},
-	},
+// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls
+const { mockStreamText, mockGenerateText, mockCreateMistral } = vi.hoisted(() => ({
+	mockStreamText: vi.fn(),
+	mockGenerateText: vi.fn(),
+	mockCreateMistral: vi.fn(() => {
+		// Return a function that returns a mock language model
+		return vi.fn(() => ({
+			modelId: "codestral-latest",
+			provider: "mistral",
+		}))
+	}),
 }))
 
-// Mock Mistral client - must come before other imports
-const mockCreate = vi.fn()
-const mockComplete = vi.fn()
-vi.mock("@mistralai/mistralai", () => {
+vi.mock("ai", async (importOriginal) => {
+	const actual = await importOriginal<typeof import("ai")>()
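+	// Spread the real module below so that only streamText and generateText are
+	// replaced; every other "ai" export keeps its real implementation.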
 	return {
-		Mistral: vi.fn().mockImplementation(() => ({
-			chat: {
-				stream: mockCreate.mockImplementation(async (_options) => {
-					const stream = {
-						[Symbol.asyncIterator]: async function* () {
-							yield {
-								data: {
-									choices: [
-										{
-											delta: { content: "Test response" },
-											index: 0,
-										},
-									],
-								},
-							}
-						},
-					}
-					return stream
-				}),
-				complete: mockComplete.mockImplementation(async (_options) => {
-					return {
-						choices: [
-							{
-								message: {
-									content: "Test response",
-								},
-							},
-						],
-					}
-				}),
-			},
-		})),
+		...actual,
+		streamText: mockStreamText,
+		generateText: mockGenerateText,
 	}
 })
 
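+// Mock the provider factory itself so tests can assert on the options
+// (e.g. baseURL) that the handler passes to createMistral.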
+vi.mock("@ai-sdk/mistral", () => ({
|
|
|
+ createMistral: mockCreateMistral,
|
|
|
+}))
|
|
|
+
|
|
|
import type { Anthropic } from "@anthropic-ai/sdk"
|
|
|
-import type OpenAI from "openai"
|
|
|
-import { MistralHandler } from "../mistral"
|
|
|
+
|
|
|
+import { mistralDefaultModelId, mistralModels, type MistralModelId } from "@roo-code/types"
|
|
|
+
|
|
|
import type { ApiHandlerOptions } from "../../../shared/api"
|
|
|
-import type { ApiHandlerCreateMessageMetadata } from "../../index"
|
|
|
-import type { ApiStreamTextChunk, ApiStreamReasoningChunk, ApiStreamToolCallPartialChunk } from "../../transform/stream"
|
|
|
+
|
|
|
+import { MistralHandler } from "../mistral"
|
|
|
|
|
|
describe("MistralHandler", () => {
|
|
|
let handler: MistralHandler
|
|
|
@@ -61,15 +38,11 @@ describe("MistralHandler", () => {
|
|
|
|
|
|
beforeEach(() => {
|
|
|
mockOptions = {
|
|
|
- apiModelId: "codestral-latest", // Update to match the actual model ID
|
|
|
mistralApiKey: "test-api-key",
|
|
|
- includeMaxTokens: true,
|
|
|
- modelTemperature: 0,
|
|
|
+ apiModelId: "codestral-latest" as MistralModelId,
|
|
|
}
|
|
|
handler = new MistralHandler(mockOptions)
|
|
|
- mockCreate.mockClear()
|
|
|
- mockComplete.mockClear()
|
|
|
- mockCaptureException.mockClear()
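+		// Resets call history on every hoisted mock at once, replacing the
+		// per-mock mockClear() calls used with the old client mock.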
+		vi.clearAllMocks()
 	})
 
 	describe("constructor", () => {
@@ -78,32 +51,53 @@ describe("MistralHandler", () => {
 			expect(handler.getModel().id).toBe(mockOptions.apiModelId)
 		})
 
-		it("should throw error if API key is missing", () => {
-			expect(() => {
-				new MistralHandler({
-					...mockOptions,
-					mistralApiKey: undefined,
-				})
-			}).toThrow("Mistral API key is required")
-		})
-
-		it("should use custom base URL if provided", () => {
-			const customBaseUrl = "https://custom.mistral.ai/v1"
-			const handlerWithCustomUrl = new MistralHandler({
+		it("should use default model ID if not provided", () => {
+			const handlerWithoutModel = new MistralHandler({
 				...mockOptions,
-				mistralCodestralUrl: customBaseUrl,
+				apiModelId: undefined,
 			})
-			expect(handlerWithCustomUrl).toBeInstanceOf(MistralHandler)
+			expect(handlerWithoutModel.getModel().id).toBe(mistralDefaultModelId)
 		})
 	})
 
 	describe("getModel", () => {
-		it("should return correct model info", () => {
+		it("should return model info for valid model ID", () => {
 			const model = handler.getModel()
 			expect(model.id).toBe(mockOptions.apiModelId)
 			expect(model.info).toBeDefined()
+			expect(model.info.maxTokens).toBe(8192)
+			expect(model.info.contextWindow).toBe(256_000)
+			expect(model.info.supportsImages).toBe(false)
 			expect(model.info.supportsPromptCache).toBe(false)
 		})
+
+		it("should return provided model ID with default model info if model does not exist", () => {
+			const handlerWithInvalidModel = new MistralHandler({
+				...mockOptions,
+				apiModelId: "invalid-model",
+			})
+			const model = handlerWithInvalidModel.getModel()
+			expect(model.id).toBe("invalid-model") // Returns provided ID
+			expect(model.info).toBeDefined()
+			// Should have the same base properties as the default model
+			expect(model.info.contextWindow).toBe(mistralModels[mistralDefaultModelId].contextWindow)
+		})
+
+		it("should return default model if no model ID is provided", () => {
+			const handlerWithoutModel = new MistralHandler({
+				...mockOptions,
+				apiModelId: undefined,
+			})
+			const model = handlerWithoutModel.getModel()
+			expect(model.id).toBe(mistralDefaultModelId)
+			expect(model.info).toBeDefined()
+		})
+
+		it("should include model parameters from getModelParams", () => {
+			const model = handler.getModel()
+			expect(model).toHaveProperty("temperature")
+			expect(model).toHaveProperty("maxTokens")
+		})
 	})
 
 	describe("createMessage", () => {
@@ -111,389 +105,446 @@ describe("MistralHandler", () => {
 		const messages: Anthropic.Messages.MessageParam[] = [
 			{
 				role: "user",
-				content: [{ type: "text", text: "Hello!" }],
+				content: [
+					{
+						type: "text" as const,
+						text: "Hello!",
+					},
+				],
 			},
 		]
 
-		it("should create message successfully", async () => {
-			const iterator = handler.createMessage(systemPrompt, messages)
-			const result = await iterator.next()
+		it("should handle streaming responses", async () => {
+			// Mock the fullStream async generator
+			async function* mockFullStream() {
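+				// The yielded shapes mimic AI SDK fullStream parts; only the fields the handler reads are included.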
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			// Mock usage promise
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBeGreaterThan(0)
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(1)
+			expect(textChunks[0].text).toBe("Test response")
+		})
+
+		it("should include usage information", async () => {
+			async function* mockFullStream() {
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks.length).toBeGreaterThan(0)
+			expect(usageChunks[0].inputTokens).toBe(10)
+			expect(usageChunks[0].outputTokens).toBe(5)
+		})
+
+		it("should handle reasoning content in streaming responses", async () => {
+			// Mock the fullStream async generator with reasoning content
+			async function* mockFullStream() {
+				yield { type: "reasoning", text: "Let me think about this..." }
+				yield { type: "reasoning", text: " I'll analyze step by step." }
+				yield { type: "text-delta", text: "Test response" }
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
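+				// Reasoning token counts are surfaced under usage.details (see the processUsageMetrics tests below).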
+				details: {
+					reasoningTokens: 15,
+				},
+			})
 
-			expect(mockCreate).toHaveBeenCalledWith(
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Should have reasoning chunks
+			const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning")
+			expect(reasoningChunks.length).toBe(2)
+			expect(reasoningChunks[0].text).toBe("Let me think about this...")
+			expect(reasoningChunks[1].text).toBe(" I'll analyze step by step.")
+
+			// Should also have text chunks
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks.length).toBe(1)
+			expect(textChunks[0].text).toBe("Test response")
+		})
+	})
+
+	describe("completePrompt", () => {
+		it("should complete a prompt using generateText", async () => {
+			mockGenerateText.mockResolvedValue({
+				text: "Test completion",
+			})
+
+			const result = await handler.completePrompt("Test prompt")
+
+			expect(result).toBe("Test completion")
+			expect(mockGenerateText).toHaveBeenCalledWith(
 				expect.objectContaining({
-					model: mockOptions.apiModelId,
-					messages: expect.any(Array),
-					maxTokens: expect.any(Number),
-					temperature: 0,
-					// Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS)
-					tools: expect.any(Array),
-					toolChoice: "any",
+					prompt: "Test prompt",
 				}),
 			)
-
-			expect(result.value).toBeDefined()
-			expect(result.done).toBe(false)
 		})
+	})
 
-		it("should handle streaming response correctly", async () => {
-			const iterator = handler.createMessage(systemPrompt, messages)
-			const results: ApiStreamTextChunk[] = []
-
-			for await (const chunk of iterator) {
-				if ("text" in chunk) {
-					results.push(chunk as ApiStreamTextChunk)
+	describe("processUsageMetrics", () => {
+		it("should correctly process usage metrics", () => {
+			// We need to access the protected method, so we'll create a test subclass
+			class TestMistralHandler extends MistralHandler {
+				public testProcessUsageMetrics(usage: any) {
+					return this.processUsageMetrics(usage)
 				}
 			}
 
-			expect(results.length).toBeGreaterThan(0)
-			expect(results[0].text).toBe("Test response")
-		})
+			const testHandler = new TestMistralHandler(mockOptions)
+
+			const usage = {
+				inputTokens: 100,
+				outputTokens: 50,
+				details: {
+					cachedInputTokens: 20,
+					reasoningTokens: 30,
+				},
+			}
+
+			const result = testHandler.testProcessUsageMetrics(usage)
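+			// Expected mapping: details.cachedInputTokens → cacheReadTokens, details.reasoningTokens → reasoningTokens.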
 
-		it("should handle errors gracefully", async () => {
-			mockCreate.mockRejectedValueOnce(new Error("API Error"))
-			await expect(handler.createMessage(systemPrompt, messages).next()).rejects.toThrow("API Error")
+			expect(result.type).toBe("usage")
+			expect(result.inputTokens).toBe(100)
+			expect(result.outputTokens).toBe(50)
+			expect(result.cacheReadTokens).toBe(20)
+			expect(result.reasoningTokens).toBe(30)
 		})
 
-		it("should handle thinking content as reasoning chunks", async () => {
-			// Mock stream with thinking content matching new SDK structure
-			mockCreate.mockImplementationOnce(async (_options) => {
-				const stream = {
-					[Symbol.asyncIterator]: async function* () {
-						yield {
-							data: {
-								choices: [
-									{
-										delta: {
-											content: [
-												{
-													type: "thinking",
-													thinking: [{ type: "text", text: "Let me think about this..." }],
-												},
-												{ type: "text", text: "Here's the answer" },
-											],
-										},
-										index: 0,
-									},
-								],
-							},
-						}
-					},
+		it("should handle missing cache metrics gracefully", () => {
+			class TestMistralHandler extends MistralHandler {
+				public testProcessUsageMetrics(usage: any) {
+					return this.processUsageMetrics(usage)
 				}
-				return stream
-			})
+			}
+
+			const testHandler = new TestMistralHandler(mockOptions)
 
-			const iterator = handler.createMessage(systemPrompt, messages)
-			const results: (ApiStreamTextChunk | ApiStreamReasoningChunk)[] = []
+			const usage = {
+				inputTokens: 100,
+				outputTokens: 50,
+			}
+
+			const result = testHandler.testProcessUsageMetrics(usage)
+
+			expect(result.type).toBe("usage")
+			expect(result.inputTokens).toBe(100)
+			expect(result.outputTokens).toBe(50)
+			expect(result.cacheReadTokens).toBeUndefined()
+			expect(result.reasoningTokens).toBeUndefined()
+		})
+	})
 
-			for await (const chunk of iterator) {
-				if ("text" in chunk) {
-					results.push(chunk as ApiStreamTextChunk | ApiStreamReasoningChunk)
+	describe("getMaxOutputTokens", () => {
+		it("should return maxTokens from model info", () => {
+			class TestMistralHandler extends MistralHandler {
+				public testGetMaxOutputTokens() {
+					return this.getMaxOutputTokens()
 				}
 			}
 
-			expect(results).toHaveLength(2)
-			expect(results[0]).toEqual({ type: "reasoning", text: "Let me think about this..." })
-			expect(results[1]).toEqual({ type: "text", text: "Here's the answer" })
+			const testHandler = new TestMistralHandler(mockOptions)
+			const result = testHandler.testGetMaxOutputTokens()
+
+			// codestral-latest maxTokens is 8192
+			expect(result).toBe(8192)
 		})
 
-		it("should handle mixed content arrays correctly", async () => {
-			// Mock stream with mixed content matching new SDK structure
-			mockCreate.mockImplementationOnce(async (_options) => {
-				const stream = {
-					[Symbol.asyncIterator]: async function* () {
-						yield {
-							data: {
-								choices: [
-									{
-										delta: {
-											content: [
-												{ type: "text", text: "First text" },
-												{
-													type: "thinking",
-													thinking: [{ type: "text", text: "Some reasoning" }],
-												},
-												{ type: "text", text: "Second text" },
-											],
-										},
-										index: 0,
-									},
-								],
-							},
-						}
-					},
+		it("should use modelMaxTokens when provided", () => {
+			class TestMistralHandler extends MistralHandler {
+				public testGetMaxOutputTokens() {
+					return this.getMaxOutputTokens()
 				}
-				return stream
+			}
+
+			const customMaxTokens = 5000
+			const testHandler = new TestMistralHandler({
+				...mockOptions,
+				modelMaxTokens: customMaxTokens,
 			})
 
-			const iterator = handler.createMessage(systemPrompt, messages)
-			const results: (ApiStreamTextChunk | ApiStreamReasoningChunk)[] = []
+			const result = testHandler.testGetMaxOutputTokens()
+			expect(result).toBe(customMaxTokens)
+		})
 
-			for await (const chunk of iterator) {
-				if ("text" in chunk) {
-					results.push(chunk as ApiStreamTextChunk | ApiStreamReasoningChunk)
+		it("should fall back to modelInfo.maxTokens when modelMaxTokens is not provided", () => {
+			class TestMistralHandler extends MistralHandler {
+				public testGetMaxOutputTokens() {
+					return this.getMaxOutputTokens()
 				}
 			}
 
-			expect(results).toHaveLength(3)
-			expect(results[0]).toEqual({ type: "text", text: "First text" })
-			expect(results[1]).toEqual({ type: "reasoning", text: "Some reasoning" })
-			expect(results[2]).toEqual({ type: "text", text: "Second text" })
+			const testHandler = new TestMistralHandler(mockOptions)
+			const result = testHandler.testGetMaxOutputTokens()
+
+			// codestral-latest has maxTokens of 8192
+			expect(result).toBe(8192)
 		})
 	})
 
-	describe("native tool calling", () => {
+	describe("tool handling", () => {
 		const systemPrompt = "You are a helpful assistant."
 		const messages: Anthropic.Messages.MessageParam[] = [
 			{
 				role: "user",
-				content: [{ type: "text", text: "What's the weather?" }],
+				content: [{ type: "text" as const, text: "Hello!" }],
 			},
 		]
 
-		const mockTools: OpenAI.Chat.ChatCompletionTool[] = [
-			{
-				type: "function",
-				function: {
-					name: "get_weather",
-					description: "Get the current weather",
-					parameters: {
-						type: "object",
-						properties: {
-							location: { type: "string" },
+		it("should handle tool calls in streaming", async () => {
+			async function* mockFullStream() {
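+				// tool-input-start/delta/end mimic the AI SDK's streamed tool-input parts,
+				// which the handler maps to tool_call_start/delta/end chunks.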
+				yield {
+					type: "tool-input-start",
+					id: "tool-call-1",
+					toolName: "read_file",
+				}
+				yield {
+					type: "tool-input-delta",
+					id: "tool-call-1",
+					delta: '{"path":"test.ts"}',
+				}
+				yield {
+					type: "tool-input-end",
+					id: "tool-call-1",
+				}
+			}
+
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: [
+					{
+						type: "function",
+						function: {
+							name: "read_file",
+							description: "Read a file",
+							parameters: {
+								type: "object",
+								properties: { path: { type: "string" } },
+								required: ["path"],
+							},
 						},
-						required: ["location"],
 					},
-				},
-			},
-		]
+				],
+			})
 
-		it("should include tools in request by default (native is default)", async () => {
-			const metadata: ApiHandlerCreateMessageMetadata = {
-				taskId: "test-task",
-				tools: mockTools,
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
 			}
 
-			const iterator = handler.createMessage(systemPrompt, messages, metadata)
-			await iterator.next()
+			const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start")
+			const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta")
+			const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end")
 
-			expect(mockCreate).toHaveBeenCalledWith(
-				expect.objectContaining({
-					tools: expect.arrayContaining([
-						expect.objectContaining({
-							type: "function",
-							function: expect.objectContaining({
-								name: "get_weather",
-								description: "Get the current weather",
-								parameters: expect.any(Object),
-							}),
-						}),
-					]),
-					toolChoice: "any",
-				}),
-			)
+			expect(toolCallStartChunks.length).toBe(1)
+			expect(toolCallStartChunks[0].id).toBe("tool-call-1")
+			expect(toolCallStartChunks[0].name).toBe("read_file")
+
+			expect(toolCallDeltaChunks.length).toBe(1)
+			expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}')
+
+			expect(toolCallEndChunks.length).toBe(1)
+			expect(toolCallEndChunks[0].id).toBe("tool-call-1")
 		})
 
-		it("should always include tools in request (tools are always present after PR #10841)", async () => {
-			const metadata: ApiHandlerCreateMessageMetadata = {
-				taskId: "test-task",
+		it("should ignore tool-call events to prevent duplicate tools in UI", async () => {
+			// tool-call events are intentionally ignored because tool-input-start/delta/end
+			// already provide complete tool call information. Emitting tool-call would cause
+			// duplicate tools in the UI for AI SDK providers.
+			async function* mockFullStream() {
+				yield {
+					type: "tool-call",
+					toolCallId: "tool-call-1",
+					toolName: "read_file",
+					input: { path: "test.ts" },
+				}
 			}
 
-			const iterator = handler.createMessage(systemPrompt, messages, metadata)
-			await iterator.next()
+			const mockUsage = Promise.resolve({
+				inputTokens: 10,
+				outputTokens: 5,
+			})
 
-			// Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS)
-			expect(mockCreate).toHaveBeenCalledWith(
-				expect.objectContaining({
-					tools: expect.any(Array),
-					toolChoice: "any",
-				}),
-			)
-		})
+			mockStreamText.mockReturnValue({
+				fullStream: mockFullStream(),
+				usage: mockUsage,
+			})
 
-		it("should handle tool calls in streaming response", async () => {
-			// Mock stream with tool calls
-			mockCreate.mockImplementationOnce(async (_options) => {
-				const stream = {
-					[Symbol.asyncIterator]: async function* () {
-						yield {
-							data: {
-								choices: [
-									{
-										delta: {
-											toolCalls: [
-												{
-													id: "call_123",
-													type: "function",
-													function: {
-														name: "get_weather",
-														arguments: '{"location":"New York"}',
-													},
-												},
-											],
-										},
-										index: 0,
-									},
-								],
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: [
+					{
+						type: "function",
+						function: {
+							name: "read_file",
+							description: "Read a file",
+							parameters: {
+								type: "object",
+								properties: { path: { type: "string" } },
+								required: ["path"],
							},
-						}
+						},
 					},
-				}
-				return stream
+				],
			})
 
-			const metadata: ApiHandlerCreateMessageMetadata = {
-				taskId: "test-task",
-				tools: mockTools,
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
 			}
 
-			const iterator = handler.createMessage(systemPrompt, messages, metadata)
-			const results: ApiStreamToolCallPartialChunk[] = []
+			// tool-call events are ignored, so no tool_call chunks should be emitted
+			const toolCallChunks = chunks.filter((c) => c.type === "tool_call")
+			expect(toolCallChunks.length).toBe(0)
+		})
+	})
 
-			for await (const chunk of iterator) {
-				if (chunk.type === "tool_call_partial") {
-					results.push(chunk)
+	describe("mapToolChoice", () => {
+		it("should handle string tool choices", () => {
+			class TestMistralHandler extends MistralHandler {
+				public testMapToolChoice(toolChoice: any) {
+					return this.mapToolChoice(toolChoice)
 				}
 			}
 
-			expect(results).toHaveLength(1)
-			expect(results[0]).toEqual({
-				type: "tool_call_partial",
-				index: 0,
-				id: "call_123",
-				name: "get_weather",
-				arguments: '{"location":"New York"}',
-			})
+			const testHandler = new TestMistralHandler(mockOptions)
+
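+			// Mistral's "any" should map to the AI SDK's "required"; unknown strings fall back to "auto".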
+			expect(testHandler.testMapToolChoice("auto")).toBe("auto")
+			expect(testHandler.testMapToolChoice("none")).toBe("none")
+			expect(testHandler.testMapToolChoice("required")).toBe("required")
+			expect(testHandler.testMapToolChoice("any")).toBe("required")
+			expect(testHandler.testMapToolChoice("unknown")).toBe("auto")
 		})
 
-		it("should handle multiple tool calls in a single response", async () => {
-			// Mock stream with multiple tool calls
-			mockCreate.mockImplementationOnce(async (_options) => {
-				const stream = {
-					[Symbol.asyncIterator]: async function* () {
-						yield {
-							data: {
-								choices: [
-									{
-										delta: {
-											toolCalls: [
-												{
-													id: "call_1",
-													type: "function",
-													function: {
-														name: "get_weather",
-														arguments: '{"location":"NYC"}',
-													},
-												},
-												{
-													id: "call_2",
-													type: "function",
-													function: {
-														name: "get_weather",
-														arguments: '{"location":"LA"}',
-													},
-												},
-											],
-										},
-										index: 0,
-									},
-								],
-							},
-						}
-					},
+		it("should handle object tool choice with function name", () => {
+			class TestMistralHandler extends MistralHandler {
+				public testMapToolChoice(toolChoice: any) {
+					return this.mapToolChoice(toolChoice)
 				}
-				return stream
-			})
-
-			const metadata: ApiHandlerCreateMessageMetadata = {
-				taskId: "test-task",
-				tools: mockTools,
 			}
 
-			const iterator = handler.createMessage(systemPrompt, messages, metadata)
-			const results: ApiStreamToolCallPartialChunk[] = []
+			const testHandler = new TestMistralHandler(mockOptions)
 
-			for await (const chunk of iterator) {
-				if (chunk.type === "tool_call_partial") {
-					results.push(chunk)
+			const result = testHandler.testMapToolChoice({
+				type: "function",
+				function: { name: "my_tool" },
+			})
+
+			expect(result).toEqual({ type: "tool", toolName: "my_tool" })
+		})
+
+		it("should return undefined for null or undefined", () => {
+			class TestMistralHandler extends MistralHandler {
+				public testMapToolChoice(toolChoice: any) {
+					return this.mapToolChoice(toolChoice)
 				}
 			}
 
-			expect(results).toHaveLength(2)
-			expect(results[0]).toEqual({
-				type: "tool_call_partial",
-				index: 0,
-				id: "call_1",
-				name: "get_weather",
-				arguments: '{"location":"NYC"}',
-			})
-			expect(results[1]).toEqual({
-				type: "tool_call_partial",
-				index: 1,
-				id: "call_2",
-				name: "get_weather",
-				arguments: '{"location":"LA"}',
-			})
+			const testHandler = new TestMistralHandler(mockOptions)
+
+			expect(testHandler.testMapToolChoice(null)).toBeUndefined()
+			expect(testHandler.testMapToolChoice(undefined)).toBeUndefined()
 		})
+	})
 
-		it("should always set toolChoice to 'any' when tools are provided", async () => {
-			// Even if tool_choice is provided in metadata, we override it to "any"
-			const metadata: ApiHandlerCreateMessageMetadata = {
-				taskId: "test-task",
-				tools: mockTools,
-				tool_choice: "auto", // This should be ignored
-			}
+	describe("Codestral URL handling", () => {
+		beforeEach(() => {
+			mockCreateMistral.mockClear()
+		})
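+
+		// Codestral models should default to the dedicated codestral.mistral.ai endpoint.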
 
-			const iterator = handler.createMessage(systemPrompt, messages, metadata)
-			await iterator.next()
+		it("should use default Codestral URL for codestral models", () => {
+			new MistralHandler({
+				...mockOptions,
+				apiModelId: "codestral-latest",
+			})
 
-			expect(mockCreate).toHaveBeenCalledWith(
+			expect(mockCreateMistral).toHaveBeenCalledWith(
 				expect.objectContaining({
-					toolChoice: "any",
+					baseURL: "https://codestral.mistral.ai/v1",
 				}),
 			)
 		})
-	})
-
-	describe("completePrompt", () => {
-		it("should complete prompt successfully", async () => {
-			const prompt = "Test prompt"
-			const result = await handler.completePrompt(prompt)
-
-			expect(mockComplete).toHaveBeenCalledWith({
-				model: mockOptions.apiModelId,
-				messages: [{ role: "user", content: prompt }],
-				temperature: 0,
+
+		it("should use custom Codestral URL when provided", () => {
+			new MistralHandler({
+				...mockOptions,
+				apiModelId: "codestral-latest",
+				mistralCodestralUrl: "https://custom.codestral.url/v1",
 			})
 
-			expect(result).toBe("Test response")
+			expect(mockCreateMistral).toHaveBeenCalledWith(
+				expect.objectContaining({
+					baseURL: "https://custom.codestral.url/v1",
+				}),
+			)
 		})
 
-		it("should filter out thinking content in completePrompt", async () => {
-			mockComplete.mockImplementationOnce(async (_options) => {
-				return {
-					choices: [
-						{
-							message: {
-								content: [
-									{ type: "thinking", text: "Let me think..." },
-									{ type: "text", text: "Answer part 1" },
-									{ type: "text", text: "Answer part 2" },
-								],
-							},
-						},
-					],
-				}
+		it("should use default Mistral URL for non-codestral models", () => {
+			new MistralHandler({
+				...mockOptions,
+				apiModelId: "mistral-large-latest",
 			})
 
-			const prompt = "Test prompt"
-			const result = await handler.completePrompt(prompt)
-
-			expect(result).toBe("Answer part 1Answer part 2")
-		})
-
-		it("should handle errors in completePrompt", async () => {
-			mockComplete.mockRejectedValueOnce(new Error("API Error"))
-			await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Mistral completion error: API Error")
+			expect(mockCreateMistral).toHaveBeenCalledWith(
+				expect.objectContaining({
+					baseURL: "https://api.mistral.ai/v1",
+				}),
+			)
 		})
 	})
 })