Browse Source

feat(vscode-lm): add native tool support (#10191)

Daniel 3 weeks ago
parent
commit
157a097032
2 changed files with 208 additions and 18 deletions
  1. 154 1
      src/api/providers/__tests__/vscode-lm.spec.ts
  2. 54 17
      src/api/providers/vscode-lm.ts

+ 154 - 1
src/api/providers/__tests__/vscode-lm.spec.ts

@@ -180,7 +180,7 @@ describe("VsCodeLmHandler", () => {
 			})
 		})
 
-		it("should handle tool calls", async () => {
+		it("should handle tool calls as text when not using native tool protocol", async () => {
 			const systemPrompt = "You are a helpful assistant"
 			const messages: Anthropic.Messages.MessageParam[] = [
 				{
@@ -223,6 +223,139 @@ describe("VsCodeLmHandler", () => {
 			})
 		})
 
+		it("should handle native tool calls when using native tool protocol", async () => {
+			const systemPrompt = "You are a helpful assistant"
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user" as const,
+					content: "Calculate 2+2",
+				},
+			]
+
+			const toolCallData = {
+				name: "calculator",
+				arguments: { operation: "add", numbers: [2, 2] },
+				callId: "call-1",
+			}
+
+			const tools = [
+				{
+					type: "function" as const,
+					function: {
+						name: "calculator",
+						description: "A simple calculator",
+						parameters: {
+							type: "object",
+							properties: {
+								operation: { type: "string" },
+								numbers: { type: "array", items: { type: "number" } },
+							},
+						},
+					},
+				},
+			]
+
+			mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
+				stream: (async function* () {
+					yield new vscode.LanguageModelToolCallPart(
+						toolCallData.callId,
+						toolCallData.name,
+						toolCallData.arguments,
+					)
+					return
+				})(),
+				text: (async function* () {
+					yield JSON.stringify({ type: "tool_call", ...toolCallData })
+					return
+				})(),
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				toolProtocol: "native",
+				tools,
+			})
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks).toHaveLength(2) // Tool call chunk + usage chunk
+			expect(chunks[0]).toEqual({
+				type: "tool_call",
+				id: toolCallData.callId,
+				name: toolCallData.name,
+				arguments: JSON.stringify(toolCallData.arguments),
+			})
+		})
+
+		it("should pass tools to request options when using native tool protocol", async () => {
+			const systemPrompt = "You are a helpful assistant"
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user" as const,
+					content: "Calculate 2+2",
+				},
+			]
+
+			const tools = [
+				{
+					type: "function" as const,
+					function: {
+						name: "calculator",
+						description: "A simple calculator",
+						parameters: {
+							type: "object",
+							properties: {
+								operation: { type: "string" },
+							},
+						},
+					},
+				},
+			]
+
+			mockLanguageModelChat.sendRequest.mockResolvedValueOnce({
+				stream: (async function* () {
+					yield new vscode.LanguageModelTextPart("Result: 4")
+					return
+				})(),
+				text: (async function* () {
+					yield "Result: 4"
+					return
+				})(),
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages, {
+				taskId: "test-task",
+				toolProtocol: "native",
+				tools,
+			})
+			const chunks = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Verify sendRequest was called with tools in options
+			expect(mockLanguageModelChat.sendRequest).toHaveBeenCalledWith(
+				expect.any(Array),
+				expect.objectContaining({
+					tools: [
+						{
+							name: "calculator",
+							description: "A simple calculator",
+							inputSchema: {
+								type: "object",
+								properties: {
+									operation: { type: "string" },
+								},
+							},
+						},
+					],
+				}),
+				expect.anything(),
+			)
+		})
+
 		it("should handle errors", async () => {
 			const systemPrompt = "You are a helpful assistant"
 			const messages: Anthropic.Messages.MessageParam[] = [
@@ -259,6 +392,26 @@ describe("VsCodeLmHandler", () => {
 			expect(model.id).toBe("test-vendor/test-family")
 			expect(model.info).toBeDefined()
 		})
+
+		it("should return supportsNativeTools and defaultToolProtocol in model info", async () => {
+			const mockModel = { ...mockLanguageModelChat }
+			;(vscode.lm.selectChatModels as Mock).mockResolvedValueOnce([mockModel])
+
+			// Initialize client
+			await handler["getClient"]()
+
+			const model = handler.getModel()
+			expect(model.info.supportsNativeTools).toBe(true)
+			expect(model.info.defaultToolProtocol).toBe("native")
+		})
+
+		it("should return supportsNativeTools and defaultToolProtocol in fallback model info", () => {
+			// Clear the client first
+			handler["client"] = null
+			const model = handler.getModel()
+			expect(model.info.supportsNativeTools).toBe(true)
+			expect(model.info.defaultToolProtocol).toBe("native")
+		})
 	})
 
 	describe("completePrompt", () => {

+ 54 - 17
src/api/providers/vscode-lm.ts

@@ -1,5 +1,6 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import * as vscode from "vscode"
+import OpenAI from "openai"
 
 import { type ModelInfo, openAiModelInfoSaneDefaults } from "@roo-code/types"
 
@@ -12,6 +13,21 @@ import { convertToVsCodeLmMessages, extractTextCountFromMessage } from "../trans
 import { BaseProvider } from "./base-provider"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 
+/**
+ * Converts OpenAI-format tools to VSCode Language Model tools.
+ * @param tools Array of OpenAI ChatCompletionTool definitions
+ * @returns Array of VSCode LanguageModelChatTool definitions
+ */
+function convertToVsCodeLmTools(tools: OpenAI.Chat.ChatCompletionTool[]): vscode.LanguageModelChatTool[] {
+	return tools
+		.filter((tool) => tool.type === "function")
+		.map((tool) => ({
+			name: tool.function.name,
+			description: tool.function.description || "",
+			inputSchema: tool.function.parameters as Record<string, unknown> | undefined,
+		}))
+}
+
 /**
  * Handles interaction with VS Code's Language Model API for chat-based operations.
  * This handler extends BaseProvider to provide VS Code LM specific functionality.
@@ -360,14 +376,19 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 		// Accumulate the text and count at the end of the stream to reduce token counting overhead.
 		let accumulatedText: string = ""
 
+		// Determine if we're using native tool protocol
+		const useNativeTools = metadata?.toolProtocol === "native" && metadata?.tools && metadata.tools.length > 0
+
 		try {
-			// Create the response stream with minimal required options
+			// Create the response stream with required options
 			const requestOptions: vscode.LanguageModelChatRequestOptions = {
 				justification: `Roo Code would like to use '${client.name}' from '${client.vendor}', Click 'Allow' to proceed.`,
 			}
 
-			// Note: Tool support is currently provided by the VSCode Language Model API directly
-			// Extensions can register tools using vscode.lm.registerTool()
+			// Add tools to request options when using native tool protocol
+			if (useNativeTools && metadata?.tools) {
+				requestOptions.tools = convertToVsCodeLmTools(metadata.tools)
+			}
 
 			const response: vscode.LanguageModelChatResponse = await client.sendRequest(
 				vsCodeLmMessages,
@@ -408,17 +429,6 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 							continue
 						}
 
-						// Convert tool calls to text format with proper error handling
-						const toolCall = {
-							type: "tool_call",
-							name: chunk.name,
-							arguments: chunk.input,
-							callId: chunk.callId,
-						}
-
-						const toolCallText = JSON.stringify(toolCall)
-						accumulatedText += toolCallText
-
 						// Log tool call for debugging
 						console.debug("Roo Code <Language Model API>: Processing tool call:", {
 							name: chunk.name,
@@ -426,9 +436,32 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 							inputSize: JSON.stringify(chunk.input).length,
 						})
 
-						yield {
-							type: "text",
-							text: toolCallText,
+						// Yield native tool_call chunk when using native tool protocol
+						if (useNativeTools) {
+							const argumentsString = JSON.stringify(chunk.input)
+							accumulatedText += argumentsString
+							yield {
+								type: "tool_call",
+								id: chunk.callId,
+								name: chunk.name,
+								arguments: argumentsString,
+							}
+						} else {
+							// Fallback: Convert tool calls to text format for XML tool protocol
+							const toolCall = {
+								type: "tool_call",
+								name: chunk.name,
+								arguments: chunk.input,
+								callId: chunk.callId,
+							}
+
+							const toolCallText = JSON.stringify(toolCall)
+							accumulatedText += toolCallText
+
+							yield {
+								type: "text",
+								text: toolCallText,
+							}
 						}
 					} catch (error) {
 						console.error("Roo Code <Language Model API>: Failed to process tool call:", error)
@@ -512,6 +545,8 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 						: openAiModelInfoSaneDefaults.contextWindow,
 				supportsImages: false, // VSCode Language Model API currently doesn't support image inputs
 				supportsPromptCache: true,
+				supportsNativeTools: true, // VSCode Language Model API supports native tool calling
+				defaultToolProtocol: "native", // Use native tool protocol by default
 				inputPrice: 0,
 				outputPrice: 0,
 				description: `VSCode Language Model: ${modelId}`,
@@ -531,6 +566,8 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 			id: fallbackId,
 			info: {
 				...openAiModelInfoSaneDefaults,
+				supportsNativeTools: true, // VSCode Language Model API supports native tool calling
+				defaultToolProtocol: "native", // Use native tool protocol by default
 				description: `VSCode Language Model (Fallback): ${fallbackId}`,
 			},
 		}