
Support native tools in the anthropic provider (#9644)

Co-authored-by: Roo Code <[email protected]>
Matt Rubens 1 month ago
Parent
Commit
5b64aa95f6

+ 12 - 0
packages/types/src/providers/anthropic.ts

@@ -11,6 +11,7 @@ export const anthropicModels = {
 		contextWindow: 200_000, // Default 200K, extendable to 1M with beta flag 'context-1m-2025-08-07'
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 3.0, // $3 per million input tokens (≤200K context)
 		outputPrice: 15.0, // $15 per million output tokens (≤200K context)
 		cacheWritesPrice: 3.75, // $3.75 per million tokens
@@ -32,6 +33,7 @@ export const anthropicModels = {
 		contextWindow: 200_000, // Default 200K, extendable to 1M with beta flag 'context-1m-2025-08-07'
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 3.0, // $3 per million input tokens (≤200K context)
 		outputPrice: 15.0, // $15 per million output tokens (≤200K context)
 		cacheWritesPrice: 3.75, // $3.75 per million tokens
@@ -53,6 +55,7 @@ export const anthropicModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 5.0, // $5 per million input tokens
 		outputPrice: 25.0, // $25 per million output tokens
 		cacheWritesPrice: 6.25, // $6.25 per million tokens
@@ -64,6 +67,7 @@ export const anthropicModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 15.0, // $15 per million input tokens
 		outputPrice: 75.0, // $75 per million output tokens
 		cacheWritesPrice: 18.75, // $18.75 per million tokens
@@ -75,6 +79,7 @@ export const anthropicModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 15.0, // $15 per million input tokens
 		outputPrice: 75.0, // $75 per million output tokens
 		cacheWritesPrice: 18.75, // $18.75 per million tokens
@@ -86,6 +91,7 @@ export const anthropicModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 3.0, // $3 per million input tokens
 		outputPrice: 15.0, // $15 per million output tokens
 		cacheWritesPrice: 3.75, // $3.75 per million tokens
@@ -98,6 +104,7 @@ export const anthropicModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 3.0, // $3 per million input tokens
 		outputPrice: 15.0, // $15 per million output tokens
 		cacheWritesPrice: 3.75, // $3.75 per million tokens
@@ -108,6 +115,7 @@ export const anthropicModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 3.0, // $3 per million input tokens
 		outputPrice: 15.0, // $15 per million output tokens
 		cacheWritesPrice: 3.75, // $3.75 per million tokens
@@ -118,6 +126,7 @@ export const anthropicModels = {
 		contextWindow: 200_000,
 		supportsImages: false,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 1.0,
 		outputPrice: 5.0,
 		cacheWritesPrice: 1.25,
@@ -128,6 +137,7 @@ export const anthropicModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 15.0,
 		outputPrice: 75.0,
 		cacheWritesPrice: 18.75,
@@ -138,6 +148,7 @@ export const anthropicModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 0.25,
 		outputPrice: 1.25,
 		cacheWritesPrice: 0.3,
@@ -148,6 +159,7 @@ export const anthropicModels = {
 		contextWindow: 200_000,
 		supportsImages: true,
 		supportsPromptCache: true,
+		supportsNativeTools: true,
 		inputPrice: 1.0,
 		outputPrice: 5.0,
 		cacheWritesPrice: 1.25,

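The supportsNativeTools flag added above lets consumers decide per model whether to use Anthropic's native tool-use API or fall back to the XML protocol. A minimal sketch of how a caller might gate on it (pickToolProtocol is a hypothetical helper, not part of this commit; the model entry shape is reduced to the one field used here):

interface ModelCapabilities {
	supportsNativeTools?: boolean
}

type ToolProtocol = "native" | "xml"

function pickToolProtocol(model: ModelCapabilities): ToolProtocol {
	// Prefer native tool calling only when the model entry opts in;
	// everything else keeps the XML-based protocol.
	return model.supportsNativeTools ? "native" : "xml"
}

// After this change, pickToolProtocol(anthropicModels["claude-sonnet-4-5"])
// would return "native", assuming that id is one of the updated entries
// (the keys are elided from the hunks above but appear in the provider's switch).
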
+ 343 - 0
src/api/providers/__tests__/anthropic.spec.ts

@@ -384,4 +384,347 @@ describe("AnthropicHandler", () => {
 			expect(calledMessages.every((m: any) => m.role === "user")).toBe(true)
 		})
 	})
+
+	describe("native tool calling", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [{ type: "text" as const, text: "What's the weather in London?" }],
+			},
+		]
+
+		const mockTools = [
+			{
+				type: "function" as const,
+				function: {
+					name: "get_weather",
+					description: "Get the current weather",
+					parameters: {
+						type: "object",
+						properties: {
+							location: { type: "string" },
+						},
+						required: ["location"],
+					},
+				},
+			},
+		]
+
+		it("should include tools in request when toolProtocol is native", async () => {
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+			})
+
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+				// Just consume
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tools: expect.arrayContaining([
+						expect.objectContaining({
+							name: "get_weather",
+							description: "Get the current weather",
+							input_schema: expect.objectContaining({
+								type: "object",
+								properties: expect.objectContaining({
+									location: { type: "string" },
+								}),
+							}),
+						}),
+					]),
+				}),
+				expect.anything(),
+			)
+		})
+
+		it("should not include tools when toolProtocol is xml", async () => {
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "xml",
+			})
+
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+				// Just consume
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					tools: expect.anything(),
+				}),
+				expect.anything(),
+			)
+		})
+
+		it("should not include tools when no tools are provided", async () => {
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				toolProtocol: "native",
+			})
+
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+				// Just consume
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					tools: expect.anything(),
+				}),
+				expect.anything(),
+			)
+		})
+
+		it("should convert tool_choice 'auto' to Anthropic format", async () => {
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+				tool_choice: "auto",
+			})
+
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+				// Just consume
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tool_choice: { type: "auto", disable_parallel_tool_use: true },
+				}),
+				expect.anything(),
+			)
+		})
+
+		it("should convert tool_choice 'required' to Anthropic 'any' format", async () => {
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+				tool_choice: "required",
+			})
+
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+				// Just consume
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tool_choice: { type: "any", disable_parallel_tool_use: true },
+				}),
+				expect.anything(),
+			)
+		})
+
+		it("should omit both tools and tool_choice when tool_choice is 'none'", async () => {
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+				tool_choice: "none",
+			})
+
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+				// Just consume
+			}
+
+			// Verify that neither tools nor tool_choice are included in the request
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					tools: expect.anything(),
+				}),
+				expect.anything(),
+			)
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					tool_choice: expect.anything(),
+				}),
+				expect.anything(),
+			)
+		})
+
+		it("should convert specific tool_choice to Anthropic 'tool' format", async () => {
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+				tool_choice: { type: "function" as const, function: { name: "get_weather" } },
+			})
+
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+				// Just consume
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tool_choice: { type: "tool", name: "get_weather", disable_parallel_tool_use: true },
+				}),
+				expect.anything(),
+			)
+		})
+
+		it("should enable parallel tool calls when parallelToolCalls is true", async () => {
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+				tool_choice: "auto",
+				parallelToolCalls: true,
+			})
+
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+				// Just consume
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tool_choice: { type: "auto", disable_parallel_tool_use: false },
+				}),
+				expect.anything(),
+			)
+		})
+
+		it("should handle tool_use blocks in stream and emit tool_call_partial", async () => {
+			mockCreate.mockImplementationOnce(async () => ({
+				async *[Symbol.asyncIterator]() {
+					yield {
+						type: "message_start",
+						message: {
+							usage: {
+								input_tokens: 100,
+								output_tokens: 50,
+							},
+						},
+					}
+					yield {
+						type: "content_block_start",
+						index: 0,
+						content_block: {
+							type: "tool_use",
+							id: "toolu_123",
+							name: "get_weather",
+						},
+					}
+				},
+			}))
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+			})
+
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Find the tool_call_partial chunk
+			const toolCallChunk = chunks.find((chunk) => chunk.type === "tool_call_partial")
+			expect(toolCallChunk).toBeDefined()
+			expect(toolCallChunk).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: "toolu_123",
+				name: "get_weather",
+				arguments: undefined,
+			})
+		})
+
+		it("should handle input_json_delta in stream and emit tool_call_partial arguments", async () => {
+			mockCreate.mockImplementationOnce(async () => ({
+				async *[Symbol.asyncIterator]() {
+					yield {
+						type: "message_start",
+						message: {
+							usage: {
+								input_tokens: 100,
+								output_tokens: 50,
+							},
+						},
+					}
+					yield {
+						type: "content_block_start",
+						index: 0,
+						content_block: {
+							type: "tool_use",
+							id: "toolu_123",
+							name: "get_weather",
+						},
+					}
+					yield {
+						type: "content_block_delta",
+						index: 0,
+						delta: {
+							type: "input_json_delta",
+							partial_json: '{"location":',
+						},
+					}
+					yield {
+						type: "content_block_delta",
+						index: 0,
+						delta: {
+							type: "input_json_delta",
+							partial_json: '"London"}',
+						},
+					}
+					yield {
+						type: "content_block_stop",
+						index: 0,
+					}
+				},
+			}))
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+			})
+
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Find the tool_call_partial chunks
+			const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
+			expect(toolCallChunks).toHaveLength(3)
+
+			// First chunk has id and name
+			expect(toolCallChunks[0]).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: "toolu_123",
+				name: "get_weather",
+				arguments: undefined,
+			})
+
+			// Subsequent chunks have arguments
+			expect(toolCallChunks[1]).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: undefined,
+				name: undefined,
+				arguments: '{"location":',
+			})
+
+			expect(toolCallChunks[2]).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: undefined,
+				name: undefined,
+				arguments: '"London"}',
+			})
+		})
+	})
 })

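The last two tests pin down the streaming contract: for each content block index, the first tool_call_partial chunk carries id and name, and subsequent chunks carry only argument fragments. A consumer can therefore reassemble a complete call by accumulating fragments per index. A minimal sketch under that contract (the ToolCallPartial shape is inferred from the test expectations, not an exported type):

interface ToolCallPartial {
	type: "tool_call_partial"
	index: number
	id?: string
	name?: string
	arguments?: string
}

function assembleToolCalls(chunks: ToolCallPartial[]) {
	const calls = new Map<number, { id?: string; name?: string; args: string }>()
	for (const c of chunks) {
		const entry = calls.get(c.index) ?? { args: "" }
		if (c.id) entry.id = c.id
		if (c.name) entry.name = c.name
		if (c.arguments) entry.args += c.arguments
		calls.set(c.index, entry)
	}
	// Parse the accumulated JSON only once the stream is complete.
	return [...calls.values()].map((c) => ({ ...c, input: JSON.parse(c.args || "{}") }))
}

// For the fixture above this yields
// { id: "toolu_123", name: "get_weather", input: { location: "London" } }.
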
+ 92 - 1
src/api/providers/anthropic.ts

@@ -1,6 +1,7 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import { Stream as AnthropicStream } from "@anthropic-ai/sdk/streaming"
 import { CacheControlEphemeral } from "@anthropic-ai/sdk/resources"
+import OpenAI from "openai"
 
 import {
 	type ModelInfo,
@@ -19,6 +20,7 @@ import { filterNonAnthropicBlocks } from "../transform/anthropic-filter"
 import { BaseProvider } from "./base-provider"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 import { calculateApiCostAnthropic } from "../../shared/cost"
+import { convertOpenAIToolsToAnthropic } from "../../core/prompts/tools/native-tools/converters"
 
 export class AnthropicHandler extends BaseProvider implements SingleCompletionHandler {
 	private options: ApiHandlerOptions
@@ -44,7 +46,13 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 	): ApiStream {
 		let stream: AnthropicStream<Anthropic.Messages.RawMessageStreamEvent>
 		const cacheControl: CacheControlEphemeral = { type: "ephemeral" }
-		let { id: modelId, betas = [], maxTokens, temperature, reasoning: thinking } = this.getModel()
+		let {
+			id: modelId,
+			betas = ["fine-grained-tool-streaming-2025-05-14"],
+			maxTokens,
+			temperature,
+			reasoning: thinking,
+		} = this.getModel()
 
 		// Filter out non-Anthropic blocks (reasoning, thoughtSignature, etc.) before sending to the API
 		const sanitizedMessages = filterNonAnthropicBlocks(messages)
@@ -57,6 +65,21 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 			betas.push("context-1m-2025-08-07")
 		}
 
+		// Prepare native tool parameters if tools are provided and protocol is not XML
+		// Also exclude tools when tool_choice is "none" since that means "don't use tools"
+		const shouldIncludeNativeTools =
+			metadata?.tools &&
+			metadata.tools.length > 0 &&
+			metadata?.toolProtocol !== "xml" &&
+			metadata?.tool_choice !== "none"
+
+		const nativeToolParams = shouldIncludeNativeTools
+			? {
+					tools: convertOpenAIToolsToAnthropic(metadata.tools!),
+					tool_choice: this.convertOpenAIToolChoice(metadata.tool_choice, metadata.parallelToolCalls),
+				}
+			: {}
+
 		switch (modelId) {
 			case "claude-sonnet-4-5":
 			case "claude-sonnet-4-20250514":
@@ -112,6 +135,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 							return message
 						}),
 						stream: true,
+						...nativeToolParams,
 					},
 					(() => {
 						// prompt caching: https://x.com/alexalbert__/status/1823751995901272068
@@ -148,6 +172,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 					system: [{ text: systemPrompt, type: "text" }],
 					messages: sanitizedMessages,
 					stream: true,
+					...nativeToolParams,
 				})) as any
 				break
 			}
@@ -217,6 +242,17 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 
 							yield { type: "text", text: chunk.content_block.text }
 							break
+						case "tool_use": {
+							// Emit initial tool call partial with id and name
+							yield {
+								type: "tool_call_partial",
+								index: chunk.index,
+								id: chunk.content_block.id,
+								name: chunk.content_block.name,
+								arguments: undefined,
+							}
+							break
+						}
 					}
 					break
 				case "content_block_delta":
@@ -227,11 +263,23 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 						case "text_delta":
 							yield { type: "text", text: chunk.delta.text }
 							break
+						case "input_json_delta": {
+							// Emit tool call partial chunks as arguments stream in
+							yield {
+								type: "tool_call_partial",
+								index: chunk.index,
+								id: undefined,
+								name: undefined,
+								arguments: chunk.delta.partial_json,
+							}
+							break
+						}
 					}
 
 					break
 				case "content_block_stop":
 					// Block complete - no action needed for now.
+					// NativeToolCallParser handles tool call completion
 					// Note: Signature for multi-turn thinking would require using stream.finalMessage()
 					// after iteration completes, which requires restructuring the streaming approach.
 					break
@@ -296,6 +344,49 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 		}
 	}
 
+	/**
+	 * Converts OpenAI tool_choice to Anthropic ToolChoice format
+	 * @param toolChoice - OpenAI tool_choice parameter
+	 * @param parallelToolCalls - When true, allows parallel tool calls. When false (default), disables parallel tool calls.
+	 */
+	private convertOpenAIToolChoice(
+		toolChoice: OpenAI.Chat.ChatCompletionCreateParams["tool_choice"],
+		parallelToolCalls?: boolean,
+	): Anthropic.Messages.MessageCreateParams["tool_choice"] | undefined {
+		// Anthropic allows parallel tool calls by default. When parallelToolCalls is false or undefined,
+		// we disable parallel tool use to ensure one tool call at a time.
+		const disableParallelToolUse = !parallelToolCalls
+
+		if (!toolChoice) {
+			// Default to auto with parallel tool use control
+			return { type: "auto", disable_parallel_tool_use: disableParallelToolUse }
+		}
+
+		if (typeof toolChoice === "string") {
+			switch (toolChoice) {
+				case "none":
+					return undefined // Anthropic doesn't have "none", just omit tools
+				case "auto":
+					return { type: "auto", disable_parallel_tool_use: disableParallelToolUse }
+				case "required":
+					return { type: "any", disable_parallel_tool_use: disableParallelToolUse }
+				default:
+					return { type: "auto", disable_parallel_tool_use: disableParallelToolUse }
+			}
+		}
+
+		// Handle object form { type: "function", function: { name: string } }
+		if (typeof toolChoice === "object" && "function" in toolChoice) {
+			return {
+				type: "tool",
+				name: toolChoice.function.name,
+				disable_parallel_tool_use: disableParallelToolUse,
+			}
+		}
+
+		return { type: "auto", disable_parallel_tool_use: disableParallelToolUse }
+	}
+
 	async completePrompt(prompt: string) {
 		let { id: model, temperature } = this.getModel()
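
End to end, a caller opts into native tools via the metadata passed to createMessage; convertOpenAIToolChoice then maps the OpenAI-style tool_choice onto Anthropic's format ("auto" -> auto, "required" -> any, "none" -> tools omitted entirely, named function -> tool). A hedged usage sketch, assuming an already-constructed AnthropicHandler and the OpenAI-style tools array from the test file above:

const stream = handler.createMessage(systemPrompt, messages, {
	taskId: "task-1",
	tools: mockTools, // OpenAI-style function tools, converted via convertOpenAIToolsToAnthropic
	toolProtocol: "native", // opt in to Anthropic native tool use
	tool_choice: "required", // sent as { type: "any", disable_parallel_tool_use: true }
})

for await (const chunk of stream) {
	if (chunk.type === "tool_call_partial") {
		// Feed chunks into an accumulator such as the assembleToolCalls sketch above.
	}
}

Passing parallelToolCalls: true in the same metadata would flip disable_parallel_tool_use to false, matching the dedicated test case.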