Просмотр исходного кода

Handle custom tool use similarly to MCP tools for ipc schema purposes (#10364)

John Richmond 1 день назад
Родитель
Сommit
a42387e0af

+ 1 - 0
packages/types/src/tool.ts

@@ -37,6 +37,7 @@ export const toolNames = [
 	"update_todo_list",
 	"run_slash_command",
 	"generate_image",
+	"custom_tool",
 ] as const
 
 export const toolNamesSchema = z.enum(toolNames)

+ 349 - 0
src/core/assistant-message/__tests__/presentAssistantMessage-custom-tool.spec.ts

@@ -0,0 +1,349 @@
+// npx vitest src/core/assistant-message/__tests__/presentAssistantMessage-custom-tool.spec.ts
+
+import { describe, it, expect, beforeEach, vi } from "vitest"
+import { presentAssistantMessage } from "../presentAssistantMessage"
+
+// Mock dependencies
+vi.mock("../../task/Task")
+vi.mock("../../tools/validateToolUse", () => ({
+	validateToolUse: vi.fn(),
+}))
+
+// Mock custom tool registry - must be done inline without external variable references
+vi.mock("@roo-code/core", () => ({
+	customToolRegistry: {
+		has: vi.fn(),
+		get: vi.fn(),
+	},
+}))
+
+vi.mock("@roo-code/telemetry", () => ({
+	TelemetryService: {
+		instance: {
+			captureToolUsage: vi.fn(),
+			captureConsecutiveMistakeError: vi.fn(),
+		},
+	},
+}))
+
+import { TelemetryService } from "@roo-code/telemetry"
+import { customToolRegistry } from "@roo-code/core"
+
+describe("presentAssistantMessage - Custom Tool Recording", () => {
+	let mockTask: any
+
+	beforeEach(() => {
+		// Reset all mocks
+		vi.clearAllMocks()
+
+		// Create a mock Task with minimal properties needed for testing
+		mockTask = {
+			taskId: "test-task-id",
+			instanceId: "test-instance",
+			abort: false,
+			presentAssistantMessageLocked: false,
+			presentAssistantMessageHasPendingUpdates: false,
+			currentStreamingContentIndex: 0,
+			assistantMessageContent: [],
+			userMessageContent: [],
+			didCompleteReadingStream: false,
+			didRejectTool: false,
+			didAlreadyUseTool: false,
+			diffEnabled: false,
+			consecutiveMistakeCount: 0,
+			clineMessages: [],
+			api: {
+				getModel: () => ({ id: "test-model", info: {} }),
+			},
+			browserSession: {
+				closeBrowser: vi.fn().mockResolvedValue(undefined),
+			},
+			recordToolUsage: vi.fn(),
+			recordToolError: vi.fn(),
+			toolRepetitionDetector: {
+				check: vi.fn().mockReturnValue({ allowExecution: true }),
+			},
+			providerRef: {
+				deref: () => ({
+					getState: vi.fn().mockResolvedValue({
+						mode: "code",
+						customModes: [],
+						experiments: {
+							customTools: true, // Enable by default
+						},
+					}),
+				}),
+			},
+			say: vi.fn().mockResolvedValue(undefined),
+			ask: vi.fn().mockResolvedValue({ response: "yesButtonClicked" }),
+		}
+	})
+
+	describe("Custom tool usage recording", () => {
+		it("should record custom tool usage as 'custom_tool' when experiment is enabled", async () => {
+			const toolCallId = "tool_call_custom_123"
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					id: toolCallId,
+					name: "my_custom_tool",
+					params: { value: "test" },
+					partial: false,
+				},
+			]
+
+			// Mock customToolRegistry to recognize this as a custom tool
+			vi.mocked(customToolRegistry.has).mockReturnValue(true)
+			vi.mocked(customToolRegistry.get).mockReturnValue({
+				name: "my_custom_tool",
+				description: "A custom tool",
+				execute: vi.fn().mockResolvedValue("Custom tool result"),
+			})
+
+			await presentAssistantMessage(mockTask)
+
+			// Should record as "custom_tool", not "my_custom_tool"
+			expect(mockTask.recordToolUsage).toHaveBeenCalledWith("custom_tool")
+			expect(TelemetryService.instance.captureToolUsage).toHaveBeenCalledWith(
+				mockTask.taskId,
+				"custom_tool",
+				"native",
+			)
+		})
+
+		it("should record custom tool usage as 'custom_tool' in XML protocol", async () => {
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					// No ID = XML protocol
+					name: "my_custom_tool",
+					params: { value: "test" },
+					partial: false,
+				},
+			]
+
+			vi.mocked(customToolRegistry.has).mockReturnValue(true)
+			vi.mocked(customToolRegistry.get).mockReturnValue({
+				name: "my_custom_tool",
+				description: "A custom tool",
+				execute: vi.fn().mockResolvedValue("Custom tool result"),
+			})
+
+			await presentAssistantMessage(mockTask)
+
+			expect(mockTask.recordToolUsage).toHaveBeenCalledWith("custom_tool")
+			expect(TelemetryService.instance.captureToolUsage).toHaveBeenCalledWith(
+				mockTask.taskId,
+				"custom_tool",
+				"xml",
+			)
+		})
+	})
+
+	describe("Custom tool error recording", () => {
+		it("should record custom tool error as 'custom_tool'", async () => {
+			const toolCallId = "tool_call_custom_error_123"
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					id: toolCallId,
+					name: "failing_custom_tool",
+					params: {},
+					partial: false,
+				},
+			]
+
+			// Mock customToolRegistry with a tool that throws an error
+			vi.mocked(customToolRegistry.has).mockReturnValue(true)
+			vi.mocked(customToolRegistry.get).mockReturnValue({
+				name: "failing_custom_tool",
+				description: "A failing custom tool",
+				execute: vi.fn().mockRejectedValue(new Error("Custom tool execution failed")),
+			})
+
+			await presentAssistantMessage(mockTask)
+
+			// Should record error as "custom_tool", not "failing_custom_tool"
+			expect(mockTask.recordToolError).toHaveBeenCalledWith("custom_tool", "Custom tool execution failed")
+			expect(mockTask.consecutiveMistakeCount).toBe(1)
+		})
+	})
+
+	describe("Regular tool recording", () => {
+		it("should record regular tool usage with actual tool name", async () => {
+			const toolCallId = "tool_call_read_file_123"
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					id: toolCallId,
+					name: "read_file",
+					params: { path: "test.txt" },
+					partial: false,
+				},
+			]
+
+			// read_file is not a custom tool
+			vi.mocked(customToolRegistry.has).mockReturnValue(false)
+
+			await presentAssistantMessage(mockTask)
+
+			// Should record as "read_file", not "custom_tool"
+			expect(mockTask.recordToolUsage).toHaveBeenCalledWith("read_file")
+			expect(TelemetryService.instance.captureToolUsage).toHaveBeenCalledWith(
+				mockTask.taskId,
+				"read_file",
+				"native",
+			)
+		})
+
+		it("should record MCP tool usage as 'use_mcp_tool' (not custom_tool)", async () => {
+			const toolCallId = "tool_call_mcp_123"
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					id: toolCallId,
+					name: "use_mcp_tool",
+					params: {
+						server_name: "test-server",
+						tool_name: "test-tool",
+						arguments: "{}",
+					},
+					partial: false,
+				},
+			]
+
+			vi.mocked(customToolRegistry.has).mockReturnValue(false)
+
+			// Mock MCP hub for use_mcp_tool
+			mockTask.providerRef = {
+				deref: () => ({
+					getState: vi.fn().mockResolvedValue({
+						mode: "code",
+						customModes: [],
+						experiments: {
+							customTools: true,
+						},
+					}),
+					getMcpHub: () => ({
+						findServerNameBySanitizedName: () => "test-server",
+						executeToolCall: vi.fn().mockResolvedValue({ content: [{ type: "text", text: "result" }] }),
+					}),
+				}),
+			}
+
+			await presentAssistantMessage(mockTask)
+
+			// Should record as "use_mcp_tool", not "custom_tool"
+			expect(mockTask.recordToolUsage).toHaveBeenCalledWith("use_mcp_tool")
+			expect(TelemetryService.instance.captureToolUsage).toHaveBeenCalledWith(
+				mockTask.taskId,
+				"use_mcp_tool",
+				"native",
+			)
+		})
+	})
+
+	describe("Custom tool experiment gate", () => {
+		it("should treat custom tool as unknown when experiment is disabled", async () => {
+			const toolCallId = "tool_call_disabled_123"
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					id: toolCallId,
+					name: "my_custom_tool",
+					params: {},
+					partial: false,
+				},
+			]
+
+			// Mock provider state with customTools experiment DISABLED
+			mockTask.providerRef = {
+				deref: () => ({
+					getState: vi.fn().mockResolvedValue({
+						mode: "code",
+						customModes: [],
+						experiments: {
+							customTools: false, // Disabled
+						},
+					}),
+				}),
+			}
+
+			// Even if registry recognizes it, experiment gate should prevent execution
+			vi.mocked(customToolRegistry.has).mockReturnValue(true)
+			vi.mocked(customToolRegistry.get).mockReturnValue({
+				name: "my_custom_tool",
+				description: "A custom tool",
+				execute: vi.fn().mockResolvedValue("Should not execute"),
+			})
+
+			await presentAssistantMessage(mockTask)
+
+			// Should be treated as unknown tool (not executed)
+			expect(mockTask.say).toHaveBeenCalledWith("error", "unknownToolError")
+			expect(mockTask.consecutiveMistakeCount).toBe(1)
+
+			// Custom tool should NOT have been executed
+			const getMock = vi.mocked(customToolRegistry.get)
+			if (getMock.mock.results.length > 0) {
+				const customTool = getMock.mock.results[0].value
+				if (customTool) {
+					expect(customTool.execute).not.toHaveBeenCalled()
+				}
+			}
+		})
+
+		it("should not call customToolRegistry.has() when experiment is disabled", async () => {
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					id: "tool_call_123",
+					name: "some_tool",
+					params: {},
+					partial: false,
+				},
+			]
+
+			// Disable experiment
+			mockTask.providerRef = {
+				deref: () => ({
+					getState: vi.fn().mockResolvedValue({
+						mode: "code",
+						customModes: [],
+						experiments: {
+							customTools: false,
+						},
+					}),
+				}),
+			}
+
+			await presentAssistantMessage(mockTask)
+
+			// When experiment is off, shouldn't even check the registry
+			// (Code checks stateExperiments?.customTools before calling has())
+			expect(customToolRegistry.has).not.toHaveBeenCalled()
+		})
+	})
+
+	describe("Partial blocks", () => {
+		it("should not record usage for partial custom tool blocks", async () => {
+			mockTask.assistantMessageContent = [
+				{
+					type: "tool_use",
+					id: "tool_call_partial_123",
+					name: "my_custom_tool",
+					params: { value: "test" },
+					partial: true, // Still streaming
+				},
+			]
+
+			vi.mocked(customToolRegistry.has).mockReturnValue(true)
+
+			await presentAssistantMessage(mockTask)
+
+			// Should not record usage for partial blocks
+			expect(mockTask.recordToolUsage).not.toHaveBeenCalled()
+			expect(TelemetryService.instance.captureToolUsage).not.toHaveBeenCalled()
+		})
+	})
+})

+ 7 - 2
src/core/assistant-message/presentAssistantMessage.ts

@@ -695,8 +695,11 @@ export async function presentAssistantMessage(cline: Task) {
 			}
 
 			if (!block.partial) {
-				cline.recordToolUsage(block.name)
-				TelemetryService.instance.captureToolUsage(cline.taskId, block.name, toolProtocol)
+				// Check if this is a custom tool - if so, record as "custom_tool" (like MCP tools)
+				const isCustomTool = stateExperiments?.customTools && customToolRegistry.has(block.name)
+				const recordName = isCustomTool ? "custom_tool" : block.name
+				cline.recordToolUsage(recordName)
+				TelemetryService.instance.captureToolUsage(cline.taskId, recordName, toolProtocol)
 			}
 
 			// Validate tool use before execution - ONLY for complete (non-partial) blocks.
@@ -1091,6 +1094,8 @@ export async function presentAssistantMessage(cline: Task) {
 							cline.consecutiveMistakeCount = 0
 						} catch (executionError: any) {
 							cline.consecutiveMistakeCount++
+							// Record custom tool error with static name
+							cline.recordToolError("custom_tool", executionError.message)
 							await handleError(`executing custom tool "${block.name}"`, executionError)
 						}
 

+ 1 - 0
src/shared/tools.ts

@@ -266,6 +266,7 @@ export const TOOL_DISPLAY_NAMES: Record<ToolName, string> = {
 	update_todo_list: "update todo list",
 	run_slash_command: "run slash command",
 	generate_image: "generate images",
+	custom_tool: "use custom tools",
 } as const
 
 // Define available tool groups.