Browse Source

support project-level mcp config

aheizi 11 months ago
parent
commit
4106ca4a19

+ 22 - 0
src/core/webview/ClineProvider.ts

@@ -1114,6 +1114,28 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 						}
 						}
 						break
 						break
 					}
 					}
+					case "openProjectMcpSettings": {
+						if (!vscode.workspace.workspaceFolders?.length) {
+							vscode.window.showErrorMessage("Please open a project folder first")
+							return
+						}
+
+						const workspaceFolder = vscode.workspace.workspaceFolders[0]
+						const rooDir = path.join(workspaceFolder.uri.fsPath, ".roo")
+						const mcpPath = path.join(rooDir, "mcp.json")
+
+						try {
+							await fs.mkdir(rooDir, { recursive: true })
+							const exists = await fileExistsAtPath(mcpPath)
+							if (!exists) {
+								await fs.writeFile(mcpPath, JSON.stringify({ mcpServers: {} }, null, 2))
+							}
+							await openFile(mcpPath)
+						} catch (error) {
+							vscode.window.showErrorMessage(`Failed to create or open .roo/mcp.json: ${error}`)
+						}
+						break
+					}
 					case "openCustomModesSettings": {
 					case "openCustomModesSettings": {
 						const customModesFilePath = await this.customModesManager.getCustomModesFilePath()
 						const customModesFilePath = await this.customModesManager.getCustomModesFilePath()
 						if (customModesFilePath) {
 						if (customModesFilePath) {

+ 124 - 0
src/core/webview/__tests__/ClineProvider.test.ts

@@ -1950,6 +1950,130 @@ describe("ClineProvider", () => {
 	})
 	})
 })
 })
 
 
+describe("Project MCP Settings", () => {
+	let provider: ClineProvider
+	let mockContext: vscode.ExtensionContext
+	let mockOutputChannel: vscode.OutputChannel
+	let mockWebviewView: vscode.WebviewView
+	let mockPostMessage: jest.Mock
+
+	beforeEach(() => {
+		jest.clearAllMocks()
+
+		mockContext = {
+			extensionPath: "/test/path",
+			extensionUri: {} as vscode.Uri,
+			globalState: {
+				get: jest.fn(),
+				update: jest.fn(),
+				keys: jest.fn().mockReturnValue([]),
+			},
+			secrets: {
+				get: jest.fn(),
+				store: jest.fn(),
+				delete: jest.fn(),
+			},
+			subscriptions: [],
+			extension: {
+				packageJSON: { version: "1.0.0" },
+			},
+			globalStorageUri: {
+				fsPath: "/test/storage/path",
+			},
+		} as unknown as vscode.ExtensionContext
+
+		mockOutputChannel = {
+			appendLine: jest.fn(),
+			clear: jest.fn(),
+			dispose: jest.fn(),
+		} as unknown as vscode.OutputChannel
+
+		mockPostMessage = jest.fn()
+		mockWebviewView = {
+			webview: {
+				postMessage: mockPostMessage,
+				html: "",
+				options: {},
+				onDidReceiveMessage: jest.fn(),
+				asWebviewUri: jest.fn(),
+			},
+			visible: true,
+			onDidDispose: jest.fn(),
+			onDidChangeVisibility: jest.fn(),
+		} as unknown as vscode.WebviewView
+
+		provider = new ClineProvider(mockContext, mockOutputChannel)
+	})
+
+	test("handles openProjectMcpSettings message", async () => {
+		await provider.resolveWebviewView(mockWebviewView)
+		const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0]
+
+		// Mock workspace folders
+		;(vscode.workspace as any).workspaceFolders = [{ uri: { fsPath: "/test/workspace" } }]
+
+		// Mock fs functions
+		const fs = require("fs/promises")
+		fs.mkdir.mockResolvedValue(undefined)
+		fs.writeFile.mockResolvedValue(undefined)
+
+		// Trigger openProjectMcpSettings
+		await messageHandler({
+			type: "openProjectMcpSettings",
+		})
+
+		// Verify directory was created
+		expect(fs.mkdir).toHaveBeenCalledWith(
+			expect.stringContaining(".roo"),
+			expect.objectContaining({ recursive: true }),
+		)
+
+		// Verify file was created with default content
+		expect(fs.writeFile).toHaveBeenCalledWith(
+			expect.stringContaining("mcp.json"),
+			JSON.stringify({ mcpServers: {} }, null, 2),
+		)
+	})
+
+	test("handles openProjectMcpSettings when workspace is not open", async () => {
+		await provider.resolveWebviewView(mockWebviewView)
+		const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0]
+
+		// Mock no workspace folders
+		;(vscode.workspace as any).workspaceFolders = []
+
+		// Trigger openProjectMcpSettings
+		await messageHandler({
+			type: "openProjectMcpSettings",
+		})
+
+		// Verify error message was shown
+		expect(vscode.window.showErrorMessage).toHaveBeenCalledWith("Please open a project folder first")
+	})
+
+	test("handles openProjectMcpSettings file creation error", async () => {
+		await provider.resolveWebviewView(mockWebviewView)
+		const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0]
+
+		// Mock workspace folders
+		;(vscode.workspace as any).workspaceFolders = [{ uri: { fsPath: "/test/workspace" } }]
+
+		// Mock fs functions to fail
+		const fs = require("fs/promises")
+		fs.mkdir.mockRejectedValue(new Error("Failed to create directory"))
+
+		// Trigger openProjectMcpSettings
+		await messageHandler({
+			type: "openProjectMcpSettings",
+		})
+
+		// Verify error message was shown
+		expect(vscode.window.showErrorMessage).toHaveBeenCalledWith(
+			expect.stringContaining("Failed to create or open .roo/mcp.json"),
+		)
+	})
+})
+
 describe("ContextProxy integration", () => {
 describe("ContextProxy integration", () => {
 	let provider: ClineProvider
 	let provider: ClineProvider
 	let mockContext: vscode.ExtensionContext
 	let mockContext: vscode.ExtensionContext

+ 115 - 7
src/services/mcp/McpHub.ts

@@ -73,6 +73,7 @@ export class McpHub {
 	private providerRef: WeakRef<ClineProvider>
 	private providerRef: WeakRef<ClineProvider>
 	private disposables: vscode.Disposable[] = []
 	private disposables: vscode.Disposable[] = []
 	private settingsWatcher?: vscode.FileSystemWatcher
 	private settingsWatcher?: vscode.FileSystemWatcher
+	private projectMcpWatcher?: vscode.FileSystemWatcher
 	private fileWatchers: Map<string, FSWatcher> = new Map()
 	private fileWatchers: Map<string, FSWatcher> = new Map()
 	private isDisposed: boolean = false
 	private isDisposed: boolean = false
 	connections: McpConnection[] = []
 	connections: McpConnection[] = []
@@ -81,9 +82,55 @@ export class McpHub {
 	constructor(provider: ClineProvider) {
 	constructor(provider: ClineProvider) {
 		this.providerRef = new WeakRef(provider)
 		this.providerRef = new WeakRef(provider)
 		this.watchMcpSettingsFile()
 		this.watchMcpSettingsFile()
+		this.watchProjectMcpFile()
+		this.setupWorkspaceFoldersWatcher()
 		this.initializeMcpServers()
 		this.initializeMcpServers()
 	}
 	}
 
 
+	private setupWorkspaceFoldersWatcher(): void {
+		this.disposables.push(
+			vscode.workspace.onDidChangeWorkspaceFolders(async () => {
+				await this.updateProjectMcpServers()
+				this.watchProjectMcpFile()
+			}),
+		)
+	}
+
+	private watchProjectMcpFile(): void {
+		this.projectMcpWatcher?.dispose()
+
+		this.projectMcpWatcher = vscode.workspace.createFileSystemWatcher("**/.roo/mcp.json", false, false, false)
+
+		this.disposables.push(
+			this.projectMcpWatcher.onDidChange(async () => {
+				await this.updateProjectMcpServers()
+			}),
+			this.projectMcpWatcher.onDidCreate(async () => {
+				await this.updateProjectMcpServers()
+			}),
+			this.projectMcpWatcher.onDidDelete(async () => {
+				await this.cleanupProjectMcpServers()
+			}),
+		)
+
+		this.disposables.push(this.projectMcpWatcher)
+	}
+
+	private async updateProjectMcpServers(): Promise<void> {
+		await this.cleanupProjectMcpServers()
+		await this.initializeProjectMcpServers()
+	}
+
+	private async cleanupProjectMcpServers(): Promise<void> {
+		const projectServers = this.connections.filter((conn) => conn.server.source === "project")
+
+		for (const conn of projectServers) {
+			await this.deleteConnection(conn.server.name)
+		}
+
+		await this.notifyWebviewOfServerChanges()
+	}
+
 	getServers(): McpServer[] {
 	getServers(): McpServer[] {
 		// Only return enabled servers
 		// Only return enabled servers
 		return this.connections.filter((conn) => !conn.server.disabled).map((conn) => conn.server)
 		return this.connections.filter((conn) => !conn.server.disabled).map((conn) => conn.server)
@@ -158,16 +205,68 @@ export class McpHub {
 
 
 	private async initializeMcpServers(): Promise<void> {
 	private async initializeMcpServers(): Promise<void> {
 		try {
 		try {
+			// 1. Initialize global MCP servers
 			const settingsPath = await this.getMcpSettingsFilePath()
 			const settingsPath = await this.getMcpSettingsFilePath()
 			const content = await fs.readFile(settingsPath, "utf-8")
 			const content = await fs.readFile(settingsPath, "utf-8")
 			const config = JSON.parse(content)
 			const config = JSON.parse(content)
-			await this.updateServerConnections(config.mcpServers || {})
+			await this.updateServerConnections(config.mcpServers || {}, "global")
+
+			// 2. Initialize project-level MCP servers
+			await this.initializeProjectMcpServers()
 		} catch (error) {
 		} catch (error) {
 			console.error("Failed to initialize MCP servers:", error)
 			console.error("Failed to initialize MCP servers:", error)
 		}
 		}
 	}
 	}
 
 
-	private async connectToServer(name: string, config: z.infer<typeof ServerConfigSchema>): Promise<void> {
+	// Get project-level MCP configuration path
+	private async getProjectMcpPath(): Promise<string | null> {
+		if (!vscode.workspace.workspaceFolders?.length) {
+			return null
+		}
+
+		const workspaceFolder = vscode.workspace.workspaceFolders[0]
+		const projectMcpDir = path.join(workspaceFolder.uri.fsPath, ".roo")
+		const projectMcpPath = path.join(projectMcpDir, "mcp.json")
+
+		try {
+			await fs.access(projectMcpPath)
+			return projectMcpPath
+		} catch {
+			return null
+		}
+	}
+
+	// Initialize project-level MCP servers
+	private async initializeProjectMcpServers(): Promise<void> {
+		const projectMcpPath = await this.getProjectMcpPath()
+		if (!projectMcpPath) {
+			return
+		}
+
+		try {
+			const content = await fs.readFile(projectMcpPath, "utf-8")
+			const config = JSON.parse(content)
+
+			// Validate configuration structure
+			const result = McpSettingsSchema.safeParse(config)
+			if (!result.success) {
+				vscode.window.showErrorMessage("项目 MCP 配置格式无效")
+				return
+			}
+
+			// Update server connections
+			await this.updateServerConnections(result.data.mcpServers || {}, "project")
+		} catch (error) {
+			console.error("Failed to initialize project MCP servers:", error)
+			vscode.window.showErrorMessage(`初始化项目 MCP 服务器失败: ${error}`)
+		}
+	}
+
+	private async connectToServer(
+		name: string,
+		config: z.infer<typeof ServerConfigSchema>,
+		source: "global" | "project" = "global",
+	): Promise<void> {
 		// Remove existing connection if it exists
 		// Remove existing connection if it exists
 		await this.deleteConnection(name)
 		await this.deleteConnection(name)
 
 
@@ -272,6 +371,8 @@ export class McpHub {
 					config: JSON.stringify(config),
 					config: JSON.stringify(config),
 					status: "connecting",
 					status: "connecting",
 					disabled: config.disabled,
 					disabled: config.disabled,
+					source,
+					projectPath: source === "project" ? vscode.workspace.workspaceFolders?.[0]?.uri.fsPath : undefined,
 				},
 				},
 				client,
 				client,
 				transport,
 				transport,
@@ -366,10 +467,17 @@ export class McpHub {
 		}
 		}
 	}
 	}
 
 
-	async updateServerConnections(newServers: Record<string, any>): Promise<void> {
+	async updateServerConnections(
+		newServers: Record<string, any>,
+		source: "global" | "project" = "global",
+	): Promise<void> {
 		this.isConnecting = true
 		this.isConnecting = true
 		this.removeAllFileWatchers()
 		this.removeAllFileWatchers()
-		const currentNames = new Set(this.connections.map((conn) => conn.server.name))
+		// Filter connections by source
+		const currentConnections = this.connections.filter(
+			(conn) => conn.server.source === source || (!conn.server.source && source === "global"),
+		)
+		const currentNames = new Set(currentConnections.map((conn) => conn.server.name))
 		const newNames = new Set(Object.keys(newServers))
 		const newNames = new Set(Object.keys(newServers))
 
 
 		// Delete removed servers
 		// Delete removed servers
@@ -388,7 +496,7 @@ export class McpHub {
 				// New server
 				// New server
 				try {
 				try {
 					this.setupFileWatcher(name, config)
 					this.setupFileWatcher(name, config)
-					await this.connectToServer(name, config)
+					await this.connectToServer(name, config, source)
 				} catch (error) {
 				} catch (error) {
 					console.error(`Failed to connect to new MCP server ${name}:`, error)
 					console.error(`Failed to connect to new MCP server ${name}:`, error)
 				}
 				}
@@ -397,8 +505,8 @@ export class McpHub {
 				try {
 				try {
 					this.setupFileWatcher(name, config)
 					this.setupFileWatcher(name, config)
 					await this.deleteConnection(name)
 					await this.deleteConnection(name)
-					await this.connectToServer(name, config)
-					console.log(`Reconnected MCP server with updated config: ${name}`)
+					await this.connectToServer(name, config, source)
+					console.log(`Reconnected ${source} MCP server with updated config: ${name}`)
 				} catch (error) {
 				} catch (error) {
 					console.error(`Failed to reconnect MCP server ${name}:`, error)
 					console.error(`Failed to reconnect MCP server ${name}:`, error)
 				}
 				}

+ 1 - 0
src/shared/WebviewMessage.ts

@@ -59,6 +59,7 @@ export interface WebviewMessage {
 		| "screenshotQuality"
 		| "screenshotQuality"
 		| "remoteBrowserHost"
 		| "remoteBrowserHost"
 		| "openMcpSettings"
 		| "openMcpSettings"
+		| "openProjectMcpSettings"
 		| "restartMcpServer"
 		| "restartMcpServer"
 		| "toggleToolAlwaysAllow"
 		| "toggleToolAlwaysAllow"
 		| "toggleMcpServer"
 		| "toggleMcpServer"

+ 2 - 0
src/shared/mcp.ts

@@ -8,6 +8,8 @@ export type McpServer = {
 	resourceTemplates?: McpResourceTemplate[]
 	resourceTemplates?: McpResourceTemplate[]
 	disabled?: boolean
 	disabled?: boolean
 	timeout?: number
 	timeout?: number
+	source?: "global" | "project"
+	projectPath?: string
 }
 }
 
 
 export type McpTool = {
 export type McpTool = {

+ 29 - 5
webview-ui/src/components/mcp/McpView.tsx

@@ -94,16 +94,25 @@ const McpView = ({ onDone }: McpViewProps) => {
 							</div>
 							</div>
 						)}
 						)}
 
 
-						{/* Edit Settings Button */}
-						<div style={{ marginTop: "10px", width: "100%" }}>
+						{/* Edit Settings Buttons */}
+						<div style={{ marginTop: "10px", width: "100%", display: "flex", gap: "10px" }}>
 							<VSCodeButton
 							<VSCodeButton
 								appearance="secondary"
 								appearance="secondary"
-								style={{ width: "100%" }}
+								style={{ flex: 1 }}
 								onClick={() => {
 								onClick={() => {
 									vscode.postMessage({ type: "openMcpSettings" })
 									vscode.postMessage({ type: "openMcpSettings" })
 								}}>
 								}}>
 								<span className="codicon codicon-edit" style={{ marginRight: "6px" }}></span>
 								<span className="codicon codicon-edit" style={{ marginRight: "6px" }}></span>
-								Edit MCP Settings
+								Edit Global MCP
+							</VSCodeButton>
+							<VSCodeButton
+								appearance="secondary"
+								style={{ flex: 1 }}
+								onClick={() => {
+									vscode.postMessage({ type: "openProjectMcpSettings" })
+								}}>
+								<span className="codicon codicon-edit" style={{ marginRight: "6px" }}></span>
+								Edit Project MCP
 							</VSCodeButton>
 							</VSCodeButton>
 						</div>
 						</div>
 					</>
 					</>
@@ -193,7 +202,22 @@ const ServerRow = ({ server, alwaysAllowMcp }: { server: McpServer; alwaysAllowM
 						style={{ marginRight: "8px" }}
 						style={{ marginRight: "8px" }}
 					/>
 					/>
 				)}
 				)}
-				<span style={{ flex: 1 }}>{server.name}</span>
+				<span style={{ flex: 1 }}>
+					{server.name}
+					{server.source && (
+						<span
+							style={{
+								marginLeft: "8px",
+								padding: "1px 6px",
+								fontSize: "11px",
+								borderRadius: "4px",
+								background: "var(--vscode-badge-background)",
+								color: "var(--vscode-badge-foreground)",
+							}}>
+							{server.source}
+						</span>
+					)}
+				</span>
 				<div
 				<div
 					style={{ display: "flex", alignItems: "center", marginRight: "8px" }}
 					style={{ display: "flex", alignItems: "center", marginRight: "8px" }}
 					onClick={(e) => e.stopPropagation()}>
 					onClick={(e) => e.stopPropagation()}>