|
|
@@ -39,9 +39,11 @@ vi.mock("@mistralai/mistralai", () => {
|
|
|
})
|
|
|
|
|
|
import type { Anthropic } from "@anthropic-ai/sdk"
|
|
|
+import type OpenAI from "openai"
|
|
|
import { MistralHandler } from "../mistral"
|
|
|
import type { ApiHandlerOptions } from "../../../shared/api"
|
|
|
-import type { ApiStreamTextChunk, ApiStreamReasoningChunk } from "../../transform/stream"
|
|
|
+import type { ApiHandlerCreateMessageMetadata } from "../../index"
|
|
|
+import type { ApiStreamTextChunk, ApiStreamReasoningChunk, ApiStreamToolCallPartialChunk } from "../../transform/stream"
|
|
|
|
|
|
describe("MistralHandler", () => {
|
|
|
let handler: MistralHandler
|
|
|
@@ -223,6 +225,223 @@ describe("MistralHandler", () => {
|
|
|
})
|
|
|
})
|
|
|
|
|
|
+ describe("native tool calling", () => {
|
|
|
+ const systemPrompt = "You are a helpful assistant."
|
|
|
+ const messages: Anthropic.Messages.MessageParam[] = [
|
|
|
+ {
|
|
|
+ role: "user",
|
|
|
+ content: [{ type: "text", text: "What's the weather?" }],
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ const mockTools: OpenAI.Chat.ChatCompletionTool[] = [
|
|
|
+ {
|
|
|
+ type: "function",
|
|
|
+ function: {
|
|
|
+ name: "get_weather",
|
|
|
+ description: "Get the current weather",
|
|
|
+ parameters: {
|
|
|
+ type: "object",
|
|
|
+ properties: {
|
|
|
+ location: { type: "string" },
|
|
|
+ },
|
|
|
+ required: ["location"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ]
|
|
|
+
|
|
|
+ it("should include tools in request when toolProtocol is native", async () => {
|
|
|
+ const metadata: ApiHandlerCreateMessageMetadata = {
|
|
|
+ taskId: "test-task",
|
|
|
+ tools: mockTools,
|
|
|
+ toolProtocol: "native",
|
|
|
+ }
|
|
|
+
|
|
|
+ const iterator = handler.createMessage(systemPrompt, messages, metadata)
|
|
|
+ await iterator.next()
|
|
|
+
|
|
|
+ expect(mockCreate).toHaveBeenCalledWith(
|
|
|
+ expect.objectContaining({
|
|
|
+ tools: expect.arrayContaining([
|
|
|
+ expect.objectContaining({
|
|
|
+ type: "function",
|
|
|
+ function: expect.objectContaining({
|
|
|
+ name: "get_weather",
|
|
|
+ description: "Get the current weather",
|
|
|
+ parameters: expect.any(Object),
|
|
|
+ }),
|
|
|
+ }),
|
|
|
+ ]),
|
|
|
+ toolChoice: "any",
|
|
|
+ }),
|
|
|
+ )
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should not include tools when toolProtocol is xml", async () => {
|
|
|
+ const metadata: ApiHandlerCreateMessageMetadata = {
|
|
|
+ taskId: "test-task",
|
|
|
+ tools: mockTools,
|
|
|
+ toolProtocol: "xml",
|
|
|
+ }
|
|
|
+
|
|
|
+ const iterator = handler.createMessage(systemPrompt, messages, metadata)
|
|
|
+ await iterator.next()
|
|
|
+
|
|
|
+ expect(mockCreate).toHaveBeenCalledWith(
|
|
|
+ expect.not.objectContaining({
|
|
|
+ tools: expect.anything(),
|
|
|
+ }),
|
|
|
+ )
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle tool calls in streaming response", async () => {
|
|
|
+ // Mock stream with tool calls
|
|
|
+ mockCreate.mockImplementationOnce(async (_options) => {
|
|
|
+ const stream = {
|
|
|
+ [Symbol.asyncIterator]: async function* () {
|
|
|
+ yield {
|
|
|
+ data: {
|
|
|
+ choices: [
|
|
|
+ {
|
|
|
+ delta: {
|
|
|
+ toolCalls: [
|
|
|
+ {
|
|
|
+ id: "call_123",
|
|
|
+ type: "function",
|
|
|
+ function: {
|
|
|
+ name: "get_weather",
|
|
|
+ arguments: '{"location":"New York"}',
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ },
|
|
|
+ index: 0,
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ },
|
|
|
+ }
|
|
|
+ },
|
|
|
+ }
|
|
|
+ return stream
|
|
|
+ })
|
|
|
+
|
|
|
+ const metadata: ApiHandlerCreateMessageMetadata = {
|
|
|
+ taskId: "test-task",
|
|
|
+ tools: mockTools,
|
|
|
+ toolProtocol: "native",
|
|
|
+ }
|
|
|
+
|
|
|
+ const iterator = handler.createMessage(systemPrompt, messages, metadata)
|
|
|
+ const results: ApiStreamToolCallPartialChunk[] = []
|
|
|
+
|
|
|
+ for await (const chunk of iterator) {
|
|
|
+ if (chunk.type === "tool_call_partial") {
|
|
|
+ results.push(chunk)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ expect(results).toHaveLength(1)
|
|
|
+ expect(results[0]).toEqual({
|
|
|
+ type: "tool_call_partial",
|
|
|
+ index: 0,
|
|
|
+ id: "call_123",
|
|
|
+ name: "get_weather",
|
|
|
+ arguments: '{"location":"New York"}',
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle multiple tool calls in a single response", async () => {
|
|
|
+ // Mock stream with multiple tool calls
|
|
|
+ mockCreate.mockImplementationOnce(async (_options) => {
|
|
|
+ const stream = {
|
|
|
+ [Symbol.asyncIterator]: async function* () {
|
|
|
+ yield {
|
|
|
+ data: {
|
|
|
+ choices: [
|
|
|
+ {
|
|
|
+ delta: {
|
|
|
+ toolCalls: [
|
|
|
+ {
|
|
|
+ id: "call_1",
|
|
|
+ type: "function",
|
|
|
+ function: {
|
|
|
+ name: "get_weather",
|
|
|
+ arguments: '{"location":"NYC"}',
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ id: "call_2",
|
|
|
+ type: "function",
|
|
|
+ function: {
|
|
|
+ name: "get_weather",
|
|
|
+ arguments: '{"location":"LA"}',
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ },
|
|
|
+ index: 0,
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ },
|
|
|
+ }
|
|
|
+ },
|
|
|
+ }
|
|
|
+ return stream
|
|
|
+ })
|
|
|
+
|
|
|
+ const metadata: ApiHandlerCreateMessageMetadata = {
|
|
|
+ taskId: "test-task",
|
|
|
+ tools: mockTools,
|
|
|
+ toolProtocol: "native",
|
|
|
+ }
|
|
|
+
|
|
|
+ const iterator = handler.createMessage(systemPrompt, messages, metadata)
|
|
|
+ const results: ApiStreamToolCallPartialChunk[] = []
|
|
|
+
|
|
|
+ for await (const chunk of iterator) {
|
|
|
+ if (chunk.type === "tool_call_partial") {
|
|
|
+ results.push(chunk)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ expect(results).toHaveLength(2)
|
|
|
+ expect(results[0]).toEqual({
|
|
|
+ type: "tool_call_partial",
|
|
|
+ index: 0,
|
|
|
+ id: "call_1",
|
|
|
+ name: "get_weather",
|
|
|
+ arguments: '{"location":"NYC"}',
|
|
|
+ })
|
|
|
+ expect(results[1]).toEqual({
|
|
|
+ type: "tool_call_partial",
|
|
|
+ index: 1,
|
|
|
+ id: "call_2",
|
|
|
+ name: "get_weather",
|
|
|
+ arguments: '{"location":"LA"}',
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should always set toolChoice to 'any' when tools are provided", async () => {
|
|
|
+ // Even if tool_choice is provided in metadata, we override it to "any"
|
|
|
+ const metadata: ApiHandlerCreateMessageMetadata = {
|
|
|
+ taskId: "test-task",
|
|
|
+ tools: mockTools,
|
|
|
+ toolProtocol: "native",
|
|
|
+ tool_choice: "auto", // This should be ignored
|
|
|
+ }
|
|
|
+
|
|
|
+ const iterator = handler.createMessage(systemPrompt, messages, metadata)
|
|
|
+ await iterator.next()
|
|
|
+
|
|
|
+ expect(mockCreate).toHaveBeenCalledWith(
|
|
|
+ expect.objectContaining({
|
|
|
+ toolChoice: "any",
|
|
|
+ }),
|
|
|
+ )
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
describe("completePrompt", () => {
|
|
|
it("should complete prompt successfully", async () => {
|
|
|
const prompt = "Test prompt"
|