Просмотр исходного кода

fix: prevent duplicate MCP tools error by deduplicating servers at source (#10096)

Daniel 2 недель назад
Родитель
Сommit
1f3ab2b493

+ 195 - 0
src/core/prompts/tools/native-tools/__tests__/mcp_server.spec.ts

@@ -0,0 +1,195 @@
+import type OpenAI from "openai"
+import { getMcpServerTools } from "../mcp_server"
+import type { McpHub } from "../../../../../services/mcp/McpHub"
+import type { McpServer, McpTool } from "../../../../../shared/mcp"
+
+// Helper type to access function tools
+type FunctionTool = OpenAI.Chat.ChatCompletionTool & { type: "function" }
+
+// Helper to get the function property from a tool
+const getFunction = (tool: OpenAI.Chat.ChatCompletionTool) => (tool as FunctionTool).function
+
+describe("getMcpServerTools", () => {
+	const createMockTool = (name: string, description = "Test tool"): McpTool => ({
+		name,
+		description,
+		inputSchema: {
+			type: "object",
+			properties: {},
+		},
+	})
+
+	const createMockServer = (name: string, tools: McpTool[], source: "global" | "project" = "global"): McpServer => ({
+		name,
+		config: JSON.stringify({ type: "stdio", command: "test" }),
+		status: "connected",
+		source,
+		tools,
+	})
+
+	const createMockMcpHub = (servers: McpServer[]): Partial<McpHub> => ({
+		getServers: vi.fn().mockReturnValue(servers),
+	})
+
+	it("should return empty array when mcpHub is undefined", () => {
+		const result = getMcpServerTools(undefined)
+		expect(result).toEqual([])
+	})
+
+	it("should return empty array when no servers are available", () => {
+		const mockHub = createMockMcpHub([])
+		const result = getMcpServerTools(mockHub as McpHub)
+		expect(result).toEqual([])
+	})
+
+	it("should generate tool definitions for server tools", () => {
+		const server = createMockServer("testServer", [createMockTool("testTool")])
+		const mockHub = createMockMcpHub([server])
+
+		const result = getMcpServerTools(mockHub as McpHub)
+
+		expect(result).toHaveLength(1)
+		expect(result[0].type).toBe("function")
+		expect(getFunction(result[0]).name).toBe("mcp--testServer--testTool")
+		expect(getFunction(result[0]).description).toBe("Test tool")
+	})
+
+	it("should filter out tools with enabledForPrompt set to false", () => {
+		const enabledTool = createMockTool("enabledTool")
+		const disabledTool = { ...createMockTool("disabledTool"), enabledForPrompt: false }
+		const server = createMockServer("testServer", [enabledTool, disabledTool])
+		const mockHub = createMockMcpHub([server])
+
+		const result = getMcpServerTools(mockHub as McpHub)
+
+		expect(result).toHaveLength(1)
+		expect(getFunction(result[0]).name).toBe("mcp--testServer--enabledTool")
+	})
+
+	it("should deduplicate tools when same server exists in both global and project configs", () => {
+		const globalServer = createMockServer(
+			"context7",
+			[createMockTool("resolve-library-id", "Global description")],
+			"global",
+		)
+		const projectServer = createMockServer(
+			"context7",
+			[createMockTool("resolve-library-id", "Project description")],
+			"project",
+		)
+
+		// McpHub.getServers() deduplicates with project servers taking priority
+		// This test simulates the deduplicated result (only project server returned)
+		const mockHub = createMockMcpHub([projectServer])
+
+		const result = getMcpServerTools(mockHub as McpHub)
+
+		// Should only have one tool (from project server)
+		expect(result).toHaveLength(1)
+		expect(getFunction(result[0]).name).toBe("mcp--context7--resolve-library-id")
+		// Project server takes priority
+		expect(getFunction(result[0]).description).toBe("Project description")
+	})
+
+	it("should allow tools with different names from the same server", () => {
+		const server = createMockServer("testServer", [
+			createMockTool("tool1"),
+			createMockTool("tool2"),
+			createMockTool("tool3"),
+		])
+		const mockHub = createMockMcpHub([server])
+
+		const result = getMcpServerTools(mockHub as McpHub)
+
+		expect(result).toHaveLength(3)
+		const toolNames = result.map((t) => getFunction(t).name)
+		expect(toolNames).toContain("mcp--testServer--tool1")
+		expect(toolNames).toContain("mcp--testServer--tool2")
+		expect(toolNames).toContain("mcp--testServer--tool3")
+	})
+
+	it("should allow tools with same name from different servers", () => {
+		const server1 = createMockServer("server1", [createMockTool("commonTool")])
+		const server2 = createMockServer("server2", [createMockTool("commonTool")])
+		const mockHub = createMockMcpHub([server1, server2])
+
+		const result = getMcpServerTools(mockHub as McpHub)
+
+		expect(result).toHaveLength(2)
+		const toolNames = result.map((t) => getFunction(t).name)
+		expect(toolNames).toContain("mcp--server1--commonTool")
+		expect(toolNames).toContain("mcp--server2--commonTool")
+	})
+
+	it("should skip servers without tools", () => {
+		const serverWithTools = createMockServer("withTools", [createMockTool("tool1")])
+		const serverWithoutTools = createMockServer("withoutTools", [])
+		const serverWithUndefinedTools: McpServer = {
+			...createMockServer("undefinedTools", []),
+			tools: undefined,
+		}
+		const mockHub = createMockMcpHub([serverWithTools, serverWithoutTools, serverWithUndefinedTools])
+
+		const result = getMcpServerTools(mockHub as McpHub)
+
+		expect(result).toHaveLength(1)
+		expect(getFunction(result[0]).name).toBe("mcp--withTools--tool1")
+	})
+
+	it("should include required fields from tool schema", () => {
+		const toolWithRequired: McpTool = {
+			name: "toolWithRequired",
+			description: "Tool with required fields",
+			inputSchema: {
+				type: "object",
+				properties: {
+					requiredField: { type: "string" },
+					optionalField: { type: "number" },
+				},
+				required: ["requiredField"],
+			},
+		}
+		const server = createMockServer("testServer", [toolWithRequired])
+		const mockHub = createMockMcpHub([server])
+
+		const result = getMcpServerTools(mockHub as McpHub)
+
+		expect(result).toHaveLength(1)
+		expect(getFunction(result[0]).parameters).toEqual({
+			type: "object",
+			properties: {
+				requiredField: { type: "string" },
+				optionalField: { type: "number" },
+			},
+			additionalProperties: false,
+			required: ["requiredField"],
+		})
+	})
+
+	it("should not include required field when schema has no required fields", () => {
+		const toolWithoutRequired: McpTool = {
+			name: "toolWithoutRequired",
+			description: "Tool without required fields",
+			inputSchema: {
+				type: "object",
+				properties: {
+					optionalField: { type: "string" },
+				},
+			},
+		}
+		const server = createMockServer("testServer", [toolWithoutRequired])
+		const mockHub = createMockMcpHub([server])
+
+		const result = getMcpServerTools(mockHub as McpHub)
+
+		expect(result).toHaveLength(1)
+		expect(getFunction(result[0]).parameters).toEqual({
+			type: "object",
+			properties: {
+				optionalField: { type: "string" },
+			},
+			additionalProperties: false,
+		})
+		expect(getFunction(result[0]).parameters).not.toHaveProperty("required")
+	})
+})

+ 14 - 4
src/core/prompts/tools/native-tools/mcp_server.ts

@@ -4,6 +4,8 @@ import { buildMcpToolName } from "../../../../utils/mcp-name"
 
 /**
  * Dynamically generates native tool definitions for all enabled tools across connected MCP servers.
+ * Tools are deduplicated by name to prevent API errors. When the same server exists in both
+ * global and project configs, project servers take priority (handled by McpHub.getServers()).
  *
  * @param mcpHub The McpHub instance containing connected servers.
  * @returns An array of OpenAI.Chat.ChatCompletionTool definitions.
@@ -15,6 +17,8 @@ export function getMcpServerTools(mcpHub?: McpHub): OpenAI.Chat.ChatCompletionTo
 
 	const servers = mcpHub.getServers()
 	const tools: OpenAI.Chat.ChatCompletionTool[] = []
+	// Track seen tool names to prevent duplicates (e.g., when same server exists in both global and project configs)
+	const seenToolNames = new Set<string>()
 
 	for (const server of servers) {
 		if (!server.tools) {
@@ -26,6 +30,16 @@ export function getMcpServerTools(mcpHub?: McpHub): OpenAI.Chat.ChatCompletionTo
 				continue
 			}
 
+			// Build sanitized tool name for API compliance
+			// The name is sanitized to conform to API requirements (e.g., Gemini's function name restrictions)
+			const toolName = buildMcpToolName(server.name, tool.name)
+
+			// Skip duplicate tool names - first occurrence wins (project servers come before global servers)
+			if (seenToolNames.has(toolName)) {
+				continue
+			}
+			seenToolNames.add(toolName)
+
 			const originalSchema = tool.inputSchema as Record<string, any> | undefined
 			const toolInputProps = originalSchema?.properties ?? {}
 			const toolInputRequired = (originalSchema?.required ?? []) as string[]
@@ -44,10 +58,6 @@ export function getMcpServerTools(mcpHub?: McpHub): OpenAI.Chat.ChatCompletionTo
 				parameters.required = toolInputRequired
 			}
 
-			// Build sanitized tool name for API compliance
-			// The name is sanitized to conform to API requirements (e.g., Gemini's function name restrictions)
-			const toolName = buildMcpToolName(server.name, tool.name)
-
 			const toolDefinition: OpenAI.Chat.ChatCompletionTool = {
 				type: "function",
 				function: {

+ 17 - 2
src/services/mcp/McpHub.ts

@@ -435,8 +435,23 @@ export class McpHub {
 	}
 
 	getServers(): McpServer[] {
-		// Only return enabled servers
-		return this.connections.filter((conn) => !conn.server.disabled).map((conn) => conn.server)
+		// Only return enabled servers, deduplicating by name with project servers taking priority
+		const enabledConnections = this.connections.filter((conn) => !conn.server.disabled)
+
+		// Deduplicate by server name: project servers take priority over global servers
+		const serversByName = new Map<string, McpServer>()
+		for (const conn of enabledConnections) {
+			const existing = serversByName.get(conn.server.name)
+			if (!existing) {
+				serversByName.set(conn.server.name, conn.server)
+			} else if (conn.server.source === "project" && existing.source !== "project") {
+				// Project server overrides global server with the same name
+				serversByName.set(conn.server.name, conn.server)
+			}
+			// If existing is project and current is global, keep existing (project wins)
+		}
+
+		return Array.from(serversByName.values())
 	}
 
 	getAllServers(): McpServer[] {

+ 81 - 0
src/services/mcp/__tests__/McpHub.spec.ts

@@ -1136,6 +1136,87 @@ describe("McpHub", () => {
 			expect(servers[0].name).toBe("enabled-server")
 		})
 
+		it("should deduplicate servers by name with project servers taking priority", () => {
+			const mockConnections: McpConnection[] = [
+				{
+					type: "connected",
+					server: {
+						name: "shared-server",
+						config: '{"source":"global"}',
+						status: "connected",
+						disabled: false,
+						source: "global",
+					},
+					client: {} as any,
+					transport: {} as any,
+				} as ConnectedMcpConnection,
+				{
+					type: "connected",
+					server: {
+						name: "shared-server",
+						config: '{"source":"project"}',
+						status: "connected",
+						disabled: false,
+						source: "project",
+					},
+					client: {} as any,
+					transport: {} as any,
+				} as ConnectedMcpConnection,
+				{
+					type: "connected",
+					server: {
+						name: "unique-global-server",
+						config: "{}",
+						status: "connected",
+						disabled: false,
+						source: "global",
+					},
+					client: {} as any,
+					transport: {} as any,
+				} as ConnectedMcpConnection,
+			]
+
+			mcpHub.connections = mockConnections
+			const servers = mcpHub.getServers()
+
+			// Should have 2 servers: deduplicated "shared-server" + "unique-global-server"
+			expect(servers.length).toBe(2)
+
+			// Find the shared-server - it should be the project version
+			const sharedServer = servers.find((s) => s.name === "shared-server")
+			expect(sharedServer).toBeDefined()
+			expect(sharedServer!.source).toBe("project")
+			expect(sharedServer!.config).toBe('{"source":"project"}')
+
+			// The unique global server should also be present
+			const uniqueServer = servers.find((s) => s.name === "unique-global-server")
+			expect(uniqueServer).toBeDefined()
+		})
+
+		it("should keep global server when no project server with same name exists", () => {
+			const mockConnections: McpConnection[] = [
+				{
+					type: "connected",
+					server: {
+						name: "global-only-server",
+						config: "{}",
+						status: "connected",
+						disabled: false,
+						source: "global",
+					},
+					client: {} as any,
+					transport: {} as any,
+				} as ConnectedMcpConnection,
+			]
+
+			mcpHub.connections = mockConnections
+			const servers = mcpHub.getServers()
+
+			expect(servers.length).toBe(1)
+			expect(servers[0].name).toBe("global-only-server")
+			expect(servers[0].source).toBe("global")
+		})
+
 		it("should prevent calling tools on disabled servers", async () => {
 			// Mock fs.readFile to return a disabled server config
 			vi.mocked(fs.readFile).mockResolvedValue(