Jelajahi Sumber

feat(gemini): add allowedFunctionNames support to prevent mode switch errors (#10708)

Co-authored-by: Roo Code <[email protected]>
Hannes Rudolph 1 Minggu lalu
induk
melakukan
9b1c8500d9

+ 9 - 0
src/api/index.ts

@@ -94,6 +94,15 @@ export interface ApiHandlerCreateMessageMetadata {
 	 * Only applies when toolProtocol is "native".
 	 */
 	parallelToolCalls?: boolean
+	/**
+	 * Optional array of tool names that the model is allowed to call.
+	 * When provided, all tool definitions are passed to the model (so it can reference
+	 * historical tool calls), but only the specified tools can actually be invoked.
+	 * This is used when switching modes to prevent model errors from missing tool
+	 * definitions while still restricting callable tools to the current mode's permissions.
+	 * Only applies to providers that support function calling restrictions (e.g., Gemini).
+	 */
+	allowedFunctionNames?: string[]
 }
 
 export interface ApiHandler {

+ 149 - 0
src/api/providers/__tests__/gemini-handler.spec.ts

@@ -1,4 +1,5 @@
 import { t } from "i18next"
+import { FunctionCallingConfigMode } from "@google/genai"
 
 import { GeminiHandler } from "../gemini"
 import type { ApiHandlerOptions } from "../../../shared/api"
@@ -141,4 +142,152 @@ describe("GeminiHandler backend support", () => {
 			}).rejects.toThrow(t("common:errors.gemini.generate_stream", { error: "API rate limit exceeded" }))
 		})
 	})
+
+	describe("allowedFunctionNames support", () => {
+		const testTools = [ // three function-type tool definitions shared by every test in this suite
+			{
+				type: "function" as const,
+				function: {
+					name: "read_file",
+					description: "Read a file",
+					parameters: { type: "object", properties: {} },
+				},
+			},
+			{
+				type: "function" as const,
+				function: {
+					name: "write_to_file",
+					description: "Write to a file",
+					parameters: { type: "object", properties: {} },
+				},
+			},
+			{
+				type: "function" as const,
+				function: {
+					name: "execute_command",
+					description: "Execute a command",
+					parameters: { type: "object", properties: {} },
+				},
+			},
+		]
+
+		it("should pass allowedFunctionNames to toolConfig when provided", async () => {
+			const options = {
+				apiProvider: "gemini",
+			} as ApiHandlerOptions
+			const handler = new GeminiHandler(options)
+			const stub = vi.fn().mockReturnValue((async function* () {})()) // empty async generator: the stubbed stream yields nothing
+			// @ts-ignore access private client
+			handler["client"].models.generateContentStream = stub
+
+			await handler
+				.createMessage("test", [] as any, {
+					taskId: "test-task",
+					tools: testTools,
+					allowedFunctionNames: ["read_file", "write_to_file"],
+				})
+				.next() // advance the generator once; the assertions below read the resulting stub call
+
+			const config = stub.mock.calls[0][0].config // `config` field of the first argument of the first stub call
+			expect(config.toolConfig).toEqual({
+				functionCallingConfig: {
+					mode: FunctionCallingConfigMode.ANY,
+					allowedFunctionNames: ["read_file", "write_to_file"],
+				},
+			})
+		})
+
+		it("should include all tools but restrict callable functions via allowedFunctionNames", async () => {
+			const options = {
+				apiProvider: "gemini",
+			} as ApiHandlerOptions
+			const handler = new GeminiHandler(options)
+			const stub = vi.fn().mockReturnValue((async function* () {})())
+			// @ts-ignore access private client
+			handler["client"].models.generateContentStream = stub
+
+			await handler
+				.createMessage("test", [] as any, {
+					taskId: "test-task",
+					tools: testTools,
+					allowedFunctionNames: ["read_file"],
+				})
+				.next()
+
+			const config = stub.mock.calls[0][0].config
+			// All three tool definitions are still sent to the model as function declarations
+			expect(config.tools[0].functionDeclarations).toHaveLength(3)
+			// ...while the function calling config lists read_file as the only callable function
+			expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toEqual(["read_file"])
+		})
+
+		it("should take precedence over tool_choice when allowedFunctionNames is provided", async () => {
+			const options = {
+				apiProvider: "gemini",
+			} as ApiHandlerOptions
+			const handler = new GeminiHandler(options)
+			const stub = vi.fn().mockReturnValue((async function* () {})())
+			// @ts-ignore access private client
+			handler["client"].models.generateContentStream = stub
+
+			await handler
+				.createMessage("test", [] as any, {
+					taskId: "test-task",
+					tools: testTools,
+					tool_choice: "auto",
+					allowedFunctionNames: ["read_file"],
+				})
+				.next()
+
+			const config = stub.mock.calls[0][0].config
+			// tool_choice "auto" is overridden: a non-empty allowedFunctionNames wins, so the mode is ANY rather than AUTO
+			expect(config.toolConfig.functionCallingConfig.mode).toBe(FunctionCallingConfigMode.ANY)
+			expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toEqual(["read_file"])
+		})
+
+		it("should fall back to tool_choice when allowedFunctionNames is empty", async () => {
+			const options = {
+				apiProvider: "gemini",
+			} as ApiHandlerOptions
+			const handler = new GeminiHandler(options)
+			const stub = vi.fn().mockReturnValue((async function* () {})())
+			// @ts-ignore access private client
+			handler["client"].models.generateContentStream = stub
+
+			await handler
+				.createMessage("test", [] as any, {
+					taskId: "test-task",
+					tools: testTools,
+					tool_choice: "auto",
+					allowedFunctionNames: [],
+				})
+				.next()
+
+			const config = stub.mock.calls[0][0].config
+			// An empty allowedFunctionNames array counts as "not provided": tool_choice "auto" decides the mode and no allow-list is set
+			expect(config.toolConfig.functionCallingConfig.mode).toBe(FunctionCallingConfigMode.AUTO)
+			expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toBeUndefined()
+		})
+
+		it("should not set toolConfig when allowedFunctionNames is undefined and no tool_choice", async () => {
+			const options = {
+				apiProvider: "gemini",
+			} as ApiHandlerOptions
+			const handler = new GeminiHandler(options)
+			const stub = vi.fn().mockReturnValue((async function* () {})())
+			// @ts-ignore access private client
+			handler["client"].models.generateContentStream = stub
+
+			await handler
+				.createMessage("test", [] as any, {
+					taskId: "test-task",
+					tools: testTools,
+				})
+				.next()
+
+			const config = stub.mock.calls[0][0].config
+			// With neither allowedFunctionNames nor tool_choice in the metadata, the handler must leave toolConfig unset
+			expect(config.toolConfig).toBeUndefined()
+		})
+	})
 })

+ 13 - 1
src/api/providers/gemini.ts

@@ -172,7 +172,19 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 			...(tools.length > 0 ? { tools } : {}),
 		}
 
-		if (metadata?.tool_choice) {
+		// Handle allowedFunctionNames for mode-restricted tool access.
+		// When provided, all tool definitions are passed to the model (so it can reference
+		// historical tool calls in conversation), but only the specified tools can be invoked.
+		// This takes precedence over tool_choice to ensure mode restrictions are honored.
+		if (metadata?.allowedFunctionNames && metadata.allowedFunctionNames.length > 0) {
+			config.toolConfig = {
+				functionCallingConfig: {
+					// ANY forces the model to respond with a function call; allowedFunctionNames limits that call to the listed functions
+					mode: FunctionCallingConfigMode.ANY,
+					allowedFunctionNames: metadata.allowedFunctionNames,
+				},
+			}
+		} else if (metadata?.tool_choice) {
 			const choice = metadata.tool_choice
 			let mode: FunctionCallingConfigMode
 			let allowedFunctionNames: string[] | undefined

+ 52 - 1
src/core/prompts/tools/__tests__/filter-tools-for-mode.spec.ts

@@ -1,7 +1,12 @@
 import { describe, it, expect, beforeEach, afterEach } from "vitest"
 import type OpenAI from "openai"
 import type { ModeConfig, ModelInfo } from "@roo-code/types"
-import { filterNativeToolsForMode, filterMcpToolsForMode, applyModelToolCustomization } from "../filter-tools-for-mode"
+import {
+	filterNativeToolsForMode,
+	filterMcpToolsForMode,
+	applyModelToolCustomization,
+	resolveToolAlias,
+} from "../filter-tools-for-mode"
 import * as toolsModule from "../../../../shared/tools"
 
 describe("filterNativeToolsForMode", () => {
@@ -859,3 +864,49 @@ describe("filterMcpToolsForMode", () => {
 		})
 	})
 })
+
+describe("resolveToolAlias", () => {
+	it("should resolve known alias to canonical name", () => {
+		// write_file is an alias that must map to the canonical write_to_file (alias table: TOOL_ALIASES)
+		expect(resolveToolAlias("write_file")).toBe("write_to_file")
+	})
+
+	it("should return canonical name unchanged", () => {
+		expect(resolveToolAlias("write_to_file")).toBe("write_to_file")
+		expect(resolveToolAlias("read_file")).toBe("read_file")
+		expect(resolveToolAlias("apply_diff")).toBe("apply_diff")
+	})
+
+	it("should return unknown tool names unchanged", () => {
+		expect(resolveToolAlias("unknown_tool")).toBe("unknown_tool")
+		expect(resolveToolAlias("custom_tool_xyz")).toBe("custom_tool_xyz")
+	})
+
+	it("should ensure allowedFunctionNames are consistent with functionDeclarations", () => {
+		// Regression guard for the Gemini allowedFunctionNames mismatch. It only calls
+		// resolveToolAlias directly; the Gemini handler itself is not exercised here.
+		// Alias names must be resolved to canonical names so allowedFunctionNames matches functionDeclarations.
+		//
+		// Example scenario:
+		// - Model specifies includedTools: ["write_file"] (an alias)
+		// - filterNativeToolsForMode returns tool with name "write_file"
+		// - But allTools (functionDeclarations) contains "write_to_file" (canonical)
+		// - If allowedFunctionNames contains "write_file", Gemini will error
+		// - Resolving aliases ensures consistency: resolveToolAlias("write_file") -> "write_to_file"
+
+		const aliasToolName = "write_file"
+		const canonicalToolName = "write_to_file"
+
+		// Stand-in for a name read from a filtered tool that was renamed to its alias
+		const extractedName = aliasToolName
+
+		// Before the fix: the alias name would have been placed in allowedFunctionNames as-is,
+		// and Gemini would error because "write_file" doesn't exist in functionDeclarations
+
+		// After the fix: the name is resolved to its canonical form first
+		const resolvedName = resolveToolAlias(extractedName)
+
+		// The resolved name matches what's in functionDeclarations (canonical names)
+		expect(resolvedName).toBe(canonicalToolName)
+	})
+})

+ 21 - 3
src/core/task/Task.ts

@@ -95,7 +95,7 @@ import { sanitizeToolUseId } from "../../utils/tool-id"
 // prompts
 import { formatResponse } from "../prompts/responses"
 import { SYSTEM_PROMPT } from "../prompts/system"
-import { buildNativeToolsArray } from "./build-tools"
+import { buildNativeToolsArrayWithRestrictions } from "./build-tools"
 
 // core modules
 import { ToolRepetitionDetector } from "../tools/ToolRepetitionDetector"
@@ -4091,15 +4091,27 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 		const taskProtocol = this._taskToolProtocol ?? "xml"
 		const shouldIncludeTools = taskProtocol === TOOL_PROTOCOL.NATIVE && (modelInfo.supportsNativeTools ?? false)
 
-		// Build complete tools array: native tools + dynamic MCP tools, filtered by mode restrictions
+		// Build complete tools array: native tools + dynamic MCP tools
+		// When includeAllToolsWithRestrictions is true, returns all tools but provides
+		// allowedFunctionNames for providers (like Gemini) that need to see all tool
+		// definitions in history while restricting callable tools for the current mode.
+		// Only Gemini currently supports this - other providers filter tools normally.
 		let allTools: OpenAI.Chat.ChatCompletionTool[] = []
+		let allowedFunctionNames: string[] | undefined
+
+		// Gemini requires all tool definitions to be present for history compatibility,
+		// but uses allowedFunctionNames to restrict which tools can be called.
+		// Other providers (Anthropic, OpenAI, etc.) don't support this feature yet,
+		// so they continue to receive only the filtered tools for the current mode.
+		const supportsAllowedFunctionNames = apiConfiguration?.apiProvider === "gemini"
+
 		if (shouldIncludeTools) {
 			const provider = this.providerRef.deref()
 			if (!provider) {
 				throw new Error("Provider reference lost during tool building")
 			}
 
-			allTools = await buildNativeToolsArray({
+			const toolsResult = await buildNativeToolsArrayWithRestrictions({
 				provider,
 				cwd: this.cwd,
 				mode,
@@ -4111,7 +4123,10 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 				browserToolEnabled: state?.browserToolEnabled ?? true,
 				modelInfo,
 				diffEnabled: this.diffEnabled,
+				includeAllToolsWithRestrictions: supportsAllowedFunctionNames,
 			})
+			allTools = toolsResult.tools
+			allowedFunctionNames = toolsResult.allowedFunctionNames
 		}
 
 		// Parallel tool calls are disabled - feature is on hold
@@ -4129,6 +4144,9 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 						tool_choice: "auto",
 						toolProtocol: taskProtocol,
 						parallelToolCalls: parallelToolCallsEnabled,
+						// When mode restricts tools, provide allowedFunctionNames so providers
+						// like Gemini can see all tools in history but only call allowed ones
+						...(allowedFunctionNames ? { allowedFunctionNames } : {}),
 					}
 				: {}),
 		}

+ 77 - 2
src/core/task/build-tools.ts

@@ -9,7 +9,11 @@ import type { ClineProvider } from "../webview/ClineProvider"
 import { getRooDirectoriesForCwd } from "../../services/roo-config/index.js"
 
 import { getNativeTools, getMcpServerTools } from "../prompts/tools/native-tools"
-import { filterNativeToolsForMode, filterMcpToolsForMode } from "../prompts/tools/filter-tools-for-mode"
+import {
+	filterNativeToolsForMode,
+	filterMcpToolsForMode,
+	resolveToolAlias,
+} from "../prompts/tools/filter-tools-for-mode"
 
 interface BuildToolsOptions {
 	provider: ClineProvider
@@ -23,6 +27,35 @@ interface BuildToolsOptions {
 	browserToolEnabled: boolean
 	modelInfo?: ModelInfo
 	diffEnabled: boolean
+	/**
+	 * If true, returns all tools without mode filtering, but also includes
+	 * the list of allowed tool names for use with allowedFunctionNames.
+	 * This enables providers that support function call restrictions (e.g., Gemini)
+	 * to pass all tool definitions while restricting callable tools.
+	 */
+	includeAllToolsWithRestrictions?: boolean
+}
+
+interface BuildToolsResult {
+	/**
+	 * The tools to pass to the model.
+	 * If includeAllToolsWithRestrictions is true, this includes ALL tools
+	 * (unfiltered native tools, all MCP tools, and custom tools).
+	 * Otherwise, it includes only mode-filtered tools.
+	 */
+	tools: OpenAI.Chat.ChatCompletionTool[]
+	/**
+	 * The names of tools that are allowed to be called based on mode restrictions,
+	 * with alias names resolved to canonical names via resolveToolAlias.
+	 * Only populated when includeAllToolsWithRestrictions is true; omitted otherwise.
+	 * Use this with allowedFunctionNames in providers that support it.
+	 * Note: the Gemini handler treats an empty array as "not provided" and falls
+	 * back to tool_choice, so an empty list does not block all tool calls there.
+	 */
+	allowedFunctionNames?: string[]
+}
+
+/**
+ * Extracts the function name from a tool definition.
+ *
+ * The tool is cast to a function tool without a runtime check, so this assumes
+ * every entry carries a `function` field.
+ * NOTE(review): confirm that no non-function (e.g. custom-type) tool definitions
+ * can reach this helper; such an entry would make `.function.name` throw.
+ *
+ * @param tool - Tool definition to read the name from
+ * @returns The value of `tool.function.name`
+ */
+function getToolName(tool: OpenAI.Chat.ChatCompletionTool): string {
+	return (tool as OpenAI.Chat.ChatCompletionFunctionTool).function.name
 }
 
 /**
@@ -33,6 +66,23 @@ interface BuildToolsOptions {
  * @returns Array of filtered native and MCP tools
  */
 export async function buildNativeToolsArray(options: BuildToolsOptions): Promise<OpenAI.Chat.ChatCompletionTool[]> {
+	// This entry point is documented to return mode-filtered tools only, and it has no
+	// way to hand allowedFunctionNames back to the caller. Force the restriction flag
+	// off so that a caller passing includeAllToolsWithRestrictions: true cannot receive
+	// the unfiltered tool list with the accompanying allow-list silently discarded.
+	// Callers that need all tools plus the allow-list must call
+	// buildNativeToolsArrayWithRestrictions directly.
+	const { tools } = await buildNativeToolsArrayWithRestrictions({
+		...options,
+		includeAllToolsWithRestrictions: false,
+	})
+	return tools
+}
+
+/**
+ * Builds the complete tools array for native protocol requests with optional mode restrictions.
+ * When includeAllToolsWithRestrictions is true, returns ALL tools but also provides
+ * the list of allowed tool names for use with allowedFunctionNames.
+ *
+ * This enables providers like Gemini to pass all tool definitions to the model
+ * (so it can reference historical tool calls) while restricting which tools
+ * can actually be invoked via allowedFunctionNames in toolConfig.
+ *
+ * @param options - Configuration options for building the tools
+ * @returns BuildToolsResult with tools array and optional allowedFunctionNames
+ */
+export async function buildNativeToolsArrayWithRestrictions(options: BuildToolsOptions): Promise<BuildToolsResult> {
 	const {
 		provider,
 		cwd,
@@ -45,6 +95,7 @@ export async function buildNativeToolsArray(options: BuildToolsOptions): Promise
 		browserToolEnabled,
 		modelInfo,
 		diffEnabled,
+		includeAllToolsWithRestrictions,
 	} = options
 
 	const mcpHub = provider.getMcpHub()
@@ -102,5 +153,29 @@ export async function buildNativeToolsArray(options: BuildToolsOptions): Promise
 		}
 	}
 
-	return [...filteredNativeTools, ...filteredMcpTools, ...nativeCustomTools]
+	// Combine filtered tools (for backward compatibility and for allowedFunctionNames)
+	const filteredTools = [...filteredNativeTools, ...filteredMcpTools, ...nativeCustomTools]
+
+	// If includeAllToolsWithRestrictions is true, return ALL tools but provide
+	// allowed names based on mode filtering
+	if (includeAllToolsWithRestrictions) {
+		// Combine ALL tools (unfiltered native + all MCP + custom)
+		const allTools = [...nativeTools, ...mcpTools, ...nativeCustomTools]
+
+		// Extract names of tools that are allowed based on mode filtering.
+		// Resolve any alias names to canonical names to ensure consistency with allTools
+		// (which uses canonical names). This prevents Gemini errors when tools are renamed
+		// to aliases in filteredTools but allTools contains the original canonical names.
+		const allowedFunctionNames = filteredTools.map((tool) => resolveToolAlias(getToolName(tool)))
+
+		return {
+			tools: allTools,
+			allowedFunctionNames,
+		}
+	}
+
+	// Default behavior: return only filtered tools
+	return {
+		tools: filteredTools,
+	}
 }