|
|
@@ -1,7 +1,7 @@
|
|
|
-import type { McpHub as McpHubType, McpConnection } from "../McpHub"
|
|
|
+import type { McpHub as McpHubType, McpConnection, ConnectedMcpConnection, DisconnectedMcpConnection } from "../McpHub"
|
|
|
import type { ClineProvider } from "../../../core/webview/ClineProvider"
|
|
|
import type { ExtensionContext, Uri } from "vscode"
|
|
|
-import { ServerConfigSchema, McpHub } from "../McpHub"
|
|
|
+import { ServerConfigSchema, McpHub, DisableReason } from "../McpHub"
|
|
|
import fs from "fs/promises"
|
|
|
import { vi, Mock } from "vitest"
|
|
|
|
|
|
@@ -33,11 +33,15 @@ vi.mock("fs/promises", () => ({
|
|
|
mkdir: vi.fn().mockResolvedValue(undefined),
|
|
|
}))
|
|
|
|
|
|
+// Import safeWriteJson to use in mocks
|
|
|
+import { safeWriteJson } from "../../../utils/safeWriteJson"
|
|
|
+
|
|
|
// Mock safeWriteJson
|
|
|
vi.mock("../../../utils/safeWriteJson", () => ({
|
|
|
safeWriteJson: vi.fn(async (filePath, data) => {
|
|
|
// Instead of trying to write to the file system, just call fs.writeFile mock
|
|
|
// This avoids the complex file locking and temp file operations
|
|
|
+ const fs = await import("fs/promises")
|
|
|
return fs.writeFile(filePath, JSON.stringify(data), "utf8")
|
|
|
}),
|
|
|
}))
|
|
|
@@ -79,6 +83,16 @@ vi.mock("@modelcontextprotocol/sdk/client/index.js", () => ({
|
|
|
Client: vi.fn(),
|
|
|
}))
|
|
|
|
|
|
+// Mock chokidar
|
|
|
+vi.mock("chokidar", () => ({
|
|
|
+ default: {
|
|
|
+ watch: vi.fn().mockReturnValue({
|
|
|
+ on: vi.fn().mockReturnThis(),
|
|
|
+ close: vi.fn(),
|
|
|
+ }),
|
|
|
+ },
|
|
|
+}))
|
|
|
+
|
|
|
describe("McpHub", () => {
|
|
|
let mcpHub: McpHubType
|
|
|
let mockProvider: Partial<ClineProvider>
|
|
|
@@ -108,6 +122,7 @@ describe("McpHub", () => {
|
|
|
ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
|
|
|
ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
|
|
|
postMessageToWebview: vi.fn(),
|
|
|
+ getState: vi.fn().mockResolvedValue({ mcpEnabled: true }),
|
|
|
context: {
|
|
|
subscriptions: [],
|
|
|
workspaceState: {} as any,
|
|
|
@@ -140,31 +155,612 @@ describe("McpHub", () => {
|
|
|
} as ExtensionContext,
|
|
|
}
|
|
|
|
|
|
- // Mock fs.readFile for initial settings
|
|
|
- vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
- JSON.stringify({
|
|
|
- mcpServers: {
|
|
|
- "test-server": {
|
|
|
- type: "stdio",
|
|
|
- command: "node",
|
|
|
- args: ["test.js"],
|
|
|
- alwaysAllow: ["allowed-tool"],
|
|
|
- disabledTools: ["disabled-tool"],
|
|
|
- },
|
|
|
+ // Mock fs.readFile for initial settings
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "test-server": {
|
|
|
+ type: "stdio",
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ alwaysAllow: ["allowed-tool"],
|
|
|
+ disabledTools: ["disabled-tool"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+ })
|
|
|
+
|
|
|
+ afterEach(() => {
|
|
|
+ // Restore original console methods
|
|
|
+ console.error = originalConsoleError
|
|
|
+ // Restore original platform
|
|
|
+ if (originalPlatform) {
|
|
|
+ Object.defineProperty(process, "platform", originalPlatform)
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ describe("Discriminated union type handling", () => {
|
|
|
+ it("should create connected connections with proper type", async () => {
|
|
|
+ // Mock StdioClientTransport
|
|
|
+ const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js")
|
|
|
+ const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType<typeof vi.fn>
|
|
|
+
|
|
|
+ const mockTransport = {
|
|
|
+ start: vi.fn().mockResolvedValue(undefined),
|
|
|
+ close: vi.fn().mockResolvedValue(undefined),
|
|
|
+ stderr: {
|
|
|
+ on: vi.fn(),
|
|
|
+ },
|
|
|
+ onerror: null,
|
|
|
+ onclose: null,
|
|
|
+ }
|
|
|
+
|
|
|
+ StdioClientTransport.mockImplementation(() => mockTransport)
|
|
|
+
|
|
|
+ // Mock Client
|
|
|
+ const clientModule = await import("@modelcontextprotocol/sdk/client/index.js")
|
|
|
+ const Client = clientModule.Client as ReturnType<typeof vi.fn>
|
|
|
+
|
|
|
+ const mockClient = {
|
|
|
+ connect: vi.fn().mockResolvedValue(undefined),
|
|
|
+ close: vi.fn().mockResolvedValue(undefined),
|
|
|
+ getInstructions: vi.fn().mockReturnValue("test instructions"),
|
|
|
+ request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }),
|
|
|
+ }
|
|
|
+
|
|
|
+ Client.mockImplementation(() => mockClient)
|
|
|
+
|
|
|
+ // Mock the config file read
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "union-test-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ // Create McpHub and let it initialize
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Find the connection
|
|
|
+ const connection = mcpHub.connections.find((conn) => conn.server.name === "union-test-server")
|
|
|
+ expect(connection).toBeDefined()
|
|
|
+
|
|
|
+ // Type guard check - connected connections should have client and transport
|
|
|
+ if (connection && connection.type === "connected") {
|
|
|
+ expect(connection.client).toBeDefined()
|
|
|
+ expect(connection.transport).toBeDefined()
|
|
|
+ expect(connection.server.status).toBe("connected")
|
|
|
+ } else {
|
|
|
+ throw new Error("Connection should be of type 'connected'")
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should create disconnected connections for disabled servers", async () => {
|
|
|
+ // Mock the config file read with a disabled server
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "disabled-union-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ disabled: true,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ // Create McpHub and let it initialize
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Find the connection
|
|
|
+ const connection = mcpHub.connections.find((conn) => conn.server.name === "disabled-union-server")
|
|
|
+ expect(connection).toBeDefined()
|
|
|
+
|
|
|
+ // Type guard check - disconnected connections should have null client and transport
|
|
|
+ if (connection && connection.type === "disconnected") {
|
|
|
+ expect(connection.client).toBeNull()
|
|
|
+ expect(connection.transport).toBeNull()
|
|
|
+ expect(connection.server.status).toBe("disconnected")
|
|
|
+ expect(connection.server.disabled).toBe(true)
|
|
|
+ } else {
|
|
|
+ throw new Error("Connection should be of type 'disconnected'")
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle type narrowing correctly in callTool", async () => {
|
|
|
+ // Mock fs.readFile to return empty config so no servers are initialized
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {},
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ // Create a mock McpHub instance
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+
|
|
|
+ // Wait for initialization
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Clear any connections that might have been created
|
|
|
+ mcpHub.connections = []
|
|
|
+
|
|
|
+ // Directly set up a connected connection
|
|
|
+ const connectedConnection: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
+ server: {
|
|
|
+ name: "test-server",
|
|
|
+ config: JSON.stringify({ command: "node", args: ["test.js"] }),
|
|
|
+ status: "connected",
|
|
|
+ source: "global",
|
|
|
+ errorHistory: [],
|
|
|
+ } as any,
|
|
|
+ client: {
|
|
|
+ request: vi.fn().mockResolvedValue({ result: "success" }),
|
|
|
+ } as any,
|
|
|
+ transport: {} as any,
|
|
|
+ }
|
|
|
+
|
|
|
+ // Add the connected connection
|
|
|
+ mcpHub.connections = [connectedConnection]
|
|
|
+
|
|
|
+ // Call tool should work with connected server
|
|
|
+ const result = await mcpHub.callTool("test-server", "test-tool", {})
|
|
|
+ expect(result).toEqual({ result: "success" })
|
|
|
+ expect(connectedConnection.client.request).toHaveBeenCalled()
|
|
|
+
|
|
|
+ // Now test with a disconnected connection
|
|
|
+ const disconnectedConnection: DisconnectedMcpConnection = {
|
|
|
+ type: "disconnected",
|
|
|
+ server: {
|
|
|
+ name: "disabled-server",
|
|
|
+ config: JSON.stringify({ command: "node", args: ["test.js"], disabled: true }),
|
|
|
+ status: "disconnected",
|
|
|
+ disabled: true,
|
|
|
+ source: "global",
|
|
|
+ errorHistory: [],
|
|
|
+ } as any,
|
|
|
+ client: null,
|
|
|
+ transport: null,
|
|
|
+ }
|
|
|
+
|
|
|
+ // Replace connections with disconnected one
|
|
|
+ mcpHub.connections = [disconnectedConnection]
|
|
|
+
|
|
|
+ // Call tool should fail with disconnected server
|
|
|
+ await expect(mcpHub.callTool("disabled-server", "test-tool", {})).rejects.toThrow(
|
|
|
+ "No connection found for server: disabled-server",
|
|
|
+ )
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ describe("File watcher cleanup", () => {
|
|
|
+ it("should clean up file watchers when server is disabled", async () => {
|
|
|
+ // Get the mocked chokidar
|
|
|
+ const chokidar = (await import("chokidar")).default
|
|
|
+ const mockWatcher = {
|
|
|
+ on: vi.fn().mockReturnThis(),
|
|
|
+ close: vi.fn(),
|
|
|
+ }
|
|
|
+ vi.mocked(chokidar.watch).mockReturnValue(mockWatcher as any)
|
|
|
+
|
|
|
+ // Mock StdioClientTransport
|
|
|
+ const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js")
|
|
|
+ const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType<typeof vi.fn>
|
|
|
+
|
|
|
+ const mockTransport = {
|
|
|
+ start: vi.fn().mockResolvedValue(undefined),
|
|
|
+ close: vi.fn().mockResolvedValue(undefined),
|
|
|
+ stderr: {
|
|
|
+ on: vi.fn(),
|
|
|
+ },
|
|
|
+ onerror: null,
|
|
|
+ onclose: null,
|
|
|
+ }
|
|
|
+
|
|
|
+ StdioClientTransport.mockImplementation(() => mockTransport)
|
|
|
+
|
|
|
+ // Mock Client
|
|
|
+ const clientModule = await import("@modelcontextprotocol/sdk/client/index.js")
|
|
|
+ const Client = clientModule.Client as ReturnType<typeof vi.fn>
|
|
|
+
|
|
|
+ const mockClient = {
|
|
|
+ connect: vi.fn().mockResolvedValue(undefined),
|
|
|
+ close: vi.fn().mockResolvedValue(undefined),
|
|
|
+ getInstructions: vi.fn().mockReturnValue("test instructions"),
|
|
|
+ request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }),
|
|
|
+ }
|
|
|
+
|
|
|
+ Client.mockImplementation(() => mockClient)
|
|
|
+
|
|
|
+ // Create server with watchPaths
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "watcher-test-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ watchPaths: ["/path/to/watch"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Verify watcher was created
|
|
|
+ expect(chokidar.watch).toHaveBeenCalledWith(["/path/to/watch"], expect.any(Object))
|
|
|
+
|
|
|
+ // Now disable the server
|
|
|
+ await mcpHub.toggleServerDisabled("watcher-test-server", true)
|
|
|
+
|
|
|
+ // Verify watcher was closed
|
|
|
+ expect(mockWatcher.close).toHaveBeenCalled()
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should clean up all file watchers when server is deleted", async () => {
|
|
|
+ // Get the mocked chokidar
|
|
|
+ const chokidar = (await import("chokidar")).default
|
|
|
+ const mockWatcher1 = {
|
|
|
+ on: vi.fn().mockReturnThis(),
|
|
|
+ close: vi.fn(),
|
|
|
+ }
|
|
|
+ const mockWatcher2 = {
|
|
|
+ on: vi.fn().mockReturnThis(),
|
|
|
+ close: vi.fn(),
|
|
|
+ }
|
|
|
+
|
|
|
+ // Return different watchers for different paths
|
|
|
+ let watcherIndex = 0
|
|
|
+ vi.mocked(chokidar.watch).mockImplementation(() => {
|
|
|
+ return (watcherIndex++ === 0 ? mockWatcher1 : mockWatcher2) as any
|
|
|
+ })
|
|
|
+
|
|
|
+ // Mock StdioClientTransport
|
|
|
+ const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js")
|
|
|
+ const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType<typeof vi.fn>
|
|
|
+
|
|
|
+ const mockTransport = {
|
|
|
+ start: vi.fn().mockResolvedValue(undefined),
|
|
|
+ close: vi.fn().mockResolvedValue(undefined),
|
|
|
+ stderr: {
|
|
|
+ on: vi.fn(),
|
|
|
+ },
|
|
|
+ onerror: null,
|
|
|
+ onclose: null,
|
|
|
+ }
|
|
|
+
|
|
|
+ StdioClientTransport.mockImplementation(() => mockTransport)
|
|
|
+
|
|
|
+ // Mock Client
|
|
|
+ const clientModule = await import("@modelcontextprotocol/sdk/client/index.js")
|
|
|
+ const Client = clientModule.Client as ReturnType<typeof vi.fn>
|
|
|
+
|
|
|
+ const mockClient = {
|
|
|
+ connect: vi.fn().mockResolvedValue(undefined),
|
|
|
+ close: vi.fn().mockResolvedValue(undefined),
|
|
|
+ getInstructions: vi.fn().mockReturnValue("test instructions"),
|
|
|
+ request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }),
|
|
|
+ }
|
|
|
+
|
|
|
+ Client.mockImplementation(() => mockClient)
|
|
|
+
|
|
|
+ // Create server with multiple watchPaths
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "multi-watcher-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js", "build/index.js"], // This will create a watcher for build/index.js
|
|
|
+ watchPaths: ["/path/to/watch1", "/path/to/watch2"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Verify watchers were created
|
|
|
+ expect(chokidar.watch).toHaveBeenCalled()
|
|
|
+
|
|
|
+ // Delete the connection (this should clean up all watchers)
|
|
|
+ await mcpHub.deleteConnection("multi-watcher-server")
|
|
|
+
|
|
|
+ // Verify all watchers were closed
|
|
|
+ expect(mockWatcher1.close).toHaveBeenCalled()
|
|
|
+ expect(mockWatcher2.close).toHaveBeenCalled()
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should not create file watchers for disabled servers on initialization", async () => {
|
|
|
+ // Get the mocked chokidar
|
|
|
+ const chokidar = (await import("chokidar")).default
|
|
|
+
|
|
|
+ // Create disabled server with watchPaths
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "disabled-watcher-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ watchPaths: ["/path/to/watch"],
|
|
|
+ disabled: true,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ vi.mocked(chokidar.watch).mockClear()
|
|
|
+
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Verify no watcher was created for disabled server
|
|
|
+ expect(chokidar.watch).not.toHaveBeenCalled()
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ describe("DisableReason enum usage", () => {
|
|
|
+ it("should use MCP_DISABLED reason when MCP is globally disabled", async () => {
|
|
|
+ // Mock provider with mcpEnabled: false
|
|
|
+ mockProvider.getState = vi.fn().mockResolvedValue({ mcpEnabled: false })
|
|
|
+
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "mcp-disabled-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Find the connection
|
|
|
+ const connection = mcpHub.connections.find((conn) => conn.server.name === "mcp-disabled-server")
|
|
|
+ expect(connection).toBeDefined()
|
|
|
+ expect(connection?.type).toBe("disconnected")
|
|
|
+ expect(connection?.server.status).toBe("disconnected")
|
|
|
+
|
|
|
+ // The server should not be marked as disabled individually
|
|
|
+ expect(connection?.server.disabled).toBeUndefined()
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should use SERVER_DISABLED reason when server is individually disabled", async () => {
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "server-disabled-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ disabled: true,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Find the connection
|
|
|
+ const connection = mcpHub.connections.find((conn) => conn.server.name === "server-disabled-server")
|
|
|
+ expect(connection).toBeDefined()
|
|
|
+ expect(connection?.type).toBe("disconnected")
|
|
|
+ expect(connection?.server.status).toBe("disconnected")
|
|
|
+ expect(connection?.server.disabled).toBe(true)
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle both disable reasons correctly", async () => {
|
|
|
+ // First test with MCP globally disabled
|
|
|
+ mockProvider.getState = vi.fn().mockResolvedValue({ mcpEnabled: false })
|
|
|
+
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "both-reasons-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ disabled: true, // Server is also individually disabled
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Find the connection
|
|
|
+ const connection = mcpHub.connections.find((conn) => conn.server.name === "both-reasons-server")
|
|
|
+ expect(connection).toBeDefined()
|
|
|
+ expect(connection?.type).toBe("disconnected")
|
|
|
+
|
|
|
+ // When MCP is globally disabled, it takes precedence
|
|
|
+ // The server's individual disabled state should be preserved
|
|
|
+ expect(connection?.server.disabled).toBe(true)
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ describe("Null safety improvements", () => {
|
|
|
+ it("should handle null client safely in disconnected connections", async () => {
|
|
|
+ // Mock fs.readFile to return a disabled server config
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "null-safety-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ disabled: true,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+
|
|
|
+ // Wait for initialization
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // The server should be created as a disconnected connection with null client/transport
|
|
|
+ const connection = mcpHub.connections.find((conn) => conn.server.name === "null-safety-server")
|
|
|
+ expect(connection).toBeDefined()
|
|
|
+ expect(connection?.type).toBe("disconnected")
|
|
|
+
|
|
|
+ // Type guard to ensure it's a disconnected connection
|
|
|
+ if (connection?.type === "disconnected") {
|
|
|
+ expect(connection.client).toBeNull()
|
|
|
+ expect(connection.transport).toBeNull()
|
|
|
+ }
|
|
|
+
|
|
|
+ // Try to call tool on disconnected server
|
|
|
+ await expect(mcpHub.callTool("null-safety-server", "test-tool", {})).rejects.toThrow(
|
|
|
+ "No connection found for server: null-safety-server",
|
|
|
+ )
|
|
|
+
|
|
|
+ // Try to read resource on disconnected server
|
|
|
+ await expect(mcpHub.readResource("null-safety-server", "test-uri")).rejects.toThrow(
|
|
|
+ "No connection found for server: null-safety-server",
|
|
|
+ )
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle connection type checks safely", async () => {
|
|
|
+ // Mock StdioClientTransport
|
|
|
+ const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js")
|
|
|
+ const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType<typeof vi.fn>
|
|
|
+
|
|
|
+ const mockTransport = {
|
|
|
+ start: vi.fn().mockResolvedValue(undefined),
|
|
|
+ close: vi.fn().mockResolvedValue(undefined),
|
|
|
+ stderr: {
|
|
|
+ on: vi.fn(),
|
|
|
+ },
|
|
|
+ onerror: null,
|
|
|
+ onclose: null,
|
|
|
+ }
|
|
|
+
|
|
|
+ StdioClientTransport.mockImplementation(() => mockTransport)
|
|
|
+
|
|
|
+ // Mock Client
|
|
|
+ const clientModule = await import("@modelcontextprotocol/sdk/client/index.js")
|
|
|
+ const Client = clientModule.Client as ReturnType<typeof vi.fn>
|
|
|
+
|
|
|
+ const mockClient = {
|
|
|
+ connect: vi.fn().mockResolvedValue(undefined),
|
|
|
+ close: vi.fn().mockResolvedValue(undefined),
|
|
|
+ getInstructions: vi.fn().mockReturnValue("test instructions"),
|
|
|
+ request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }),
|
|
|
+ }
|
|
|
+
|
|
|
+ Client.mockImplementation(() => mockClient)
|
|
|
+
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "type-check-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Get the connection
|
|
|
+ const connection = mcpHub.connections.find((conn) => conn.server.name === "type-check-server")
|
|
|
+ expect(connection).toBeDefined()
|
|
|
+
|
|
|
+ // Safe type checking
|
|
|
+ if (connection?.type === "connected") {
|
|
|
+ expect(connection.client).toBeDefined()
|
|
|
+ expect(connection.transport).toBeDefined()
|
|
|
+ } else if (connection?.type === "disconnected") {
|
|
|
+ expect(connection.client).toBeNull()
|
|
|
+ expect(connection.transport).toBeNull()
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle missing connections safely", async () => {
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Try operations on non-existent server
|
|
|
+ await expect(mcpHub.callTool("non-existent-server", "test-tool", {})).rejects.toThrow(
|
|
|
+ "No connection found for server: non-existent-server",
|
|
|
+ )
|
|
|
+
|
|
|
+ await expect(mcpHub.readResource("non-existent-server", "test-uri")).rejects.toThrow(
|
|
|
+ "No connection found for server: non-existent-server",
|
|
|
+ )
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle connection deletion safely", async () => {
|
|
|
+ // Mock StdioClientTransport
|
|
|
+ const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js")
|
|
|
+ const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType<typeof vi.fn>
|
|
|
+
|
|
|
+ const mockTransport = {
|
|
|
+ start: vi.fn().mockResolvedValue(undefined),
|
|
|
+ close: vi.fn().mockResolvedValue(undefined),
|
|
|
+ stderr: {
|
|
|
+ on: vi.fn(),
|
|
|
},
|
|
|
- }),
|
|
|
- )
|
|
|
+ onerror: null,
|
|
|
+ onclose: null,
|
|
|
+ }
|
|
|
|
|
|
- mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
- })
|
|
|
+ StdioClientTransport.mockImplementation(() => mockTransport)
|
|
|
|
|
|
- afterEach(() => {
|
|
|
- // Restore original console methods
|
|
|
- console.error = originalConsoleError
|
|
|
- // Restore original platform
|
|
|
- if (originalPlatform) {
|
|
|
- Object.defineProperty(process, "platform", originalPlatform)
|
|
|
- }
|
|
|
+ // Mock Client
|
|
|
+ const clientModule = await import("@modelcontextprotocol/sdk/client/index.js")
|
|
|
+ const Client = clientModule.Client as ReturnType<typeof vi.fn>
|
|
|
+
|
|
|
+ const mockClient = {
|
|
|
+ connect: vi.fn().mockResolvedValue(undefined),
|
|
|
+ close: vi.fn().mockResolvedValue(undefined),
|
|
|
+ getInstructions: vi.fn().mockReturnValue("test instructions"),
|
|
|
+ request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }),
|
|
|
+ }
|
|
|
+
|
|
|
+ Client.mockImplementation(() => mockClient)
|
|
|
+
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "delete-safety-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Delete the connection
|
|
|
+ await mcpHub.deleteConnection("delete-safety-server")
|
|
|
+
|
|
|
+ // Verify connection is removed
|
|
|
+ const connection = mcpHub.connections.find((conn) => conn.server.name === "delete-safety-server")
|
|
|
+ expect(connection).toBeUndefined()
|
|
|
+
|
|
|
+ // Verify transport and client were closed
|
|
|
+ expect(mockTransport.close).toHaveBeenCalled()
|
|
|
+ expect(mockClient.close).toHaveBeenCalled()
|
|
|
+ })
|
|
|
})
|
|
|
|
|
|
describe("toggleToolAlwaysAllow", () => {
|
|
|
@@ -184,7 +780,8 @@ describe("McpHub", () => {
|
|
|
vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig))
|
|
|
|
|
|
// Set up mock connection without alwaysAllow
|
|
|
- const mockConnection: McpConnection = {
|
|
|
+ const mockConnection: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "test-server",
|
|
|
type: "stdio",
|
|
|
@@ -232,7 +829,8 @@ describe("McpHub", () => {
|
|
|
vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig))
|
|
|
|
|
|
// Set up mock connection
|
|
|
- const mockConnection: McpConnection = {
|
|
|
+ const mockConnection: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "test-server",
|
|
|
type: "stdio",
|
|
|
@@ -280,7 +878,8 @@ describe("McpHub", () => {
|
|
|
vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig))
|
|
|
|
|
|
// Set up mock connection
|
|
|
- const mockConnection: McpConnection = {
|
|
|
+ const mockConnection: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "test-server",
|
|
|
type: "stdio",
|
|
|
@@ -325,7 +924,8 @@ describe("McpHub", () => {
|
|
|
}
|
|
|
|
|
|
// Set up mock connection
|
|
|
- const mockConnection: McpConnection = {
|
|
|
+ const mockConnection: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "test-server",
|
|
|
config: "test-server-config",
|
|
|
@@ -372,7 +972,8 @@ describe("McpHub", () => {
|
|
|
}
|
|
|
|
|
|
// Set up mock connection
|
|
|
- const mockConnection: McpConnection = {
|
|
|
+ const mockConnection: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "test-server",
|
|
|
config: "test-server-config",
|
|
|
@@ -418,7 +1019,8 @@ describe("McpHub", () => {
|
|
|
}
|
|
|
|
|
|
// Set up mock connection
|
|
|
- const mockConnection: McpConnection = {
|
|
|
+ const mockConnection: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "test-server",
|
|
|
config: "test-server-config",
|
|
|
@@ -468,7 +1070,8 @@ describe("McpHub", () => {
|
|
|
vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig))
|
|
|
|
|
|
// Set up mock connection
|
|
|
- const mockConnection: McpConnection = {
|
|
|
+ const mockConnection: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "test-server",
|
|
|
type: "stdio",
|
|
|
@@ -500,6 +1103,7 @@ describe("McpHub", () => {
|
|
|
it("should filter out disabled servers from getServers", () => {
|
|
|
const mockConnections: McpConnection[] = [
|
|
|
{
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "enabled-server",
|
|
|
config: "{}",
|
|
|
@@ -508,17 +1112,18 @@ describe("McpHub", () => {
|
|
|
},
|
|
|
client: {} as any,
|
|
|
transport: {} as any,
|
|
|
- },
|
|
|
+ } as ConnectedMcpConnection,
|
|
|
{
|
|
|
+ type: "disconnected",
|
|
|
server: {
|
|
|
name: "disabled-server",
|
|
|
config: "{}",
|
|
|
- status: "connected",
|
|
|
+ status: "disconnected",
|
|
|
disabled: true,
|
|
|
},
|
|
|
- client: {} as any,
|
|
|
- transport: {} as any,
|
|
|
- },
|
|
|
+ client: null,
|
|
|
+ transport: null,
|
|
|
+ } as DisconnectedMcpConnection,
|
|
|
]
|
|
|
|
|
|
mcpHub.connections = mockConnections
|
|
|
@@ -529,44 +1134,64 @@ describe("McpHub", () => {
|
|
|
})
|
|
|
|
|
|
it("should prevent calling tools on disabled servers", async () => {
|
|
|
- const mockConnection: McpConnection = {
|
|
|
- server: {
|
|
|
- name: "disabled-server",
|
|
|
- config: "{}",
|
|
|
- status: "connected",
|
|
|
- disabled: true,
|
|
|
- },
|
|
|
- client: {
|
|
|
- request: vi.fn().mockResolvedValue({ result: "success" }),
|
|
|
- } as any,
|
|
|
- transport: {} as any,
|
|
|
- }
|
|
|
+ // Mock fs.readFile to return a disabled server config
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "disabled-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ disabled: true,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
|
|
|
- mcpHub.connections = [mockConnection]
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+
|
|
|
+ // Wait for initialization
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
|
|
|
+ // The server should be created as a disconnected connection
|
|
|
+ const connection = mcpHub.connections.find((conn) => conn.server.name === "disabled-server")
|
|
|
+ expect(connection).toBeDefined()
|
|
|
+ expect(connection?.type).toBe("disconnected")
|
|
|
+ expect(connection?.server.disabled).toBe(true)
|
|
|
+
|
|
|
+ // Try to call tool on disabled server
|
|
|
await expect(mcpHub.callTool("disabled-server", "some-tool", {})).rejects.toThrow(
|
|
|
- 'Server "disabled-server" is disabled and cannot be used',
|
|
|
+ "No connection found for server: disabled-server",
|
|
|
)
|
|
|
})
|
|
|
|
|
|
it("should prevent reading resources from disabled servers", async () => {
|
|
|
- const mockConnection: McpConnection = {
|
|
|
- server: {
|
|
|
- name: "disabled-server",
|
|
|
- config: "{}",
|
|
|
- status: "connected",
|
|
|
- disabled: true,
|
|
|
- },
|
|
|
- client: {
|
|
|
- request: vi.fn(),
|
|
|
- } as any,
|
|
|
- transport: {} as any,
|
|
|
- }
|
|
|
+ // Mock fs.readFile to return a disabled server config
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "disabled-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ disabled: true,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
|
|
|
- mcpHub.connections = [mockConnection]
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+
|
|
|
+ // Wait for initialization
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
|
|
|
+ // The server should be created as a disconnected connection
|
|
|
+ const connection = mcpHub.connections.find((conn) => conn.server.name === "disabled-server")
|
|
|
+ expect(connection).toBeDefined()
|
|
|
+ expect(connection?.type).toBe("disconnected")
|
|
|
+ expect(connection?.server.disabled).toBe(true)
|
|
|
+
|
|
|
+ // Try to read resource from disabled server
|
|
|
await expect(mcpHub.readResource("disabled-server", "some/uri")).rejects.toThrow(
|
|
|
- 'Server "disabled-server" is disabled',
|
|
|
+ "No connection found for server: disabled-server",
|
|
|
)
|
|
|
})
|
|
|
})
|
|
|
@@ -574,7 +1199,8 @@ describe("McpHub", () => {
|
|
|
describe("callTool", () => {
|
|
|
it("should execute tool successfully", async () => {
|
|
|
// Mock the connection with a minimal client implementation
|
|
|
- const mockConnection: McpConnection = {
|
|
|
+ const mockConnection: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "test-server",
|
|
|
config: JSON.stringify({}),
|
|
|
@@ -595,7 +1221,7 @@ describe("McpHub", () => {
|
|
|
await mcpHub.callTool("test-server", "some-tool", {})
|
|
|
|
|
|
// Verify the request was made with correct parameters
|
|
|
- expect(mockConnection.client.request).toHaveBeenCalledWith(
|
|
|
+ expect(mockConnection.client!.request).toHaveBeenCalledWith(
|
|
|
{
|
|
|
method: "tools/call",
|
|
|
params: {
|
|
|
@@ -637,7 +1263,8 @@ describe("McpHub", () => {
|
|
|
})
|
|
|
|
|
|
it("should use default timeout of 60 seconds if not specified", async () => {
|
|
|
- const mockConnection: McpConnection = {
|
|
|
+ const mockConnection: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "test-server",
|
|
|
config: JSON.stringify({ type: "stdio", command: "test" }), // No timeout specified
|
|
|
@@ -652,7 +1279,7 @@ describe("McpHub", () => {
|
|
|
mcpHub.connections = [mockConnection]
|
|
|
await mcpHub.callTool("test-server", "test-tool")
|
|
|
|
|
|
- expect(mockConnection.client.request).toHaveBeenCalledWith(
|
|
|
+ expect(mockConnection.client!.request).toHaveBeenCalledWith(
|
|
|
expect.anything(),
|
|
|
expect.anything(),
|
|
|
expect.objectContaining({ timeout: 60000 }), // 60 seconds in milliseconds
|
|
|
@@ -660,7 +1287,8 @@ describe("McpHub", () => {
|
|
|
})
|
|
|
|
|
|
it("should apply configured timeout to tool calls", async () => {
|
|
|
- const mockConnection: McpConnection = {
|
|
|
+ const mockConnection: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "test-server",
|
|
|
config: JSON.stringify({ type: "stdio", command: "test", timeout: 120 }), // 2 minutes
|
|
|
@@ -675,7 +1303,7 @@ describe("McpHub", () => {
|
|
|
mcpHub.connections = [mockConnection]
|
|
|
await mcpHub.callTool("test-server", "test-tool")
|
|
|
|
|
|
- expect(mockConnection.client.request).toHaveBeenCalledWith(
|
|
|
+ expect(mockConnection.client!.request).toHaveBeenCalledWith(
|
|
|
expect.anything(),
|
|
|
expect.anything(),
|
|
|
expect.objectContaining({ timeout: 120000 }), // 120 seconds in milliseconds
|
|
|
@@ -700,7 +1328,8 @@ describe("McpHub", () => {
|
|
|
vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig))
|
|
|
|
|
|
// Set up mock connection
|
|
|
- const mockConnection: McpConnection = {
|
|
|
+ const mockConnection: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "test-server",
|
|
|
type: "stdio",
|
|
|
@@ -745,7 +1374,8 @@ describe("McpHub", () => {
|
|
|
vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig))
|
|
|
|
|
|
// Set up mock connection before updating
|
|
|
- const mockConnectionInitial: McpConnection = {
|
|
|
+ const mockConnectionInitial: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "test-server",
|
|
|
type: "stdio",
|
|
|
@@ -768,7 +1398,8 @@ describe("McpHub", () => {
|
|
|
expect(fs.writeFile).toHaveBeenCalled()
|
|
|
|
|
|
// Setup connection with invalid timeout
|
|
|
- const mockConnectionInvalid: McpConnection = {
|
|
|
+ const mockConnectionInvalid: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "test-server",
|
|
|
config: JSON.stringify({
|
|
|
@@ -791,7 +1422,7 @@ describe("McpHub", () => {
|
|
|
await mcpHub.callTool("test-server", "test-tool")
|
|
|
|
|
|
// Verify default timeout was used
|
|
|
- expect(mockConnectionInvalid.client.request).toHaveBeenCalledWith(
|
|
|
+ expect(mockConnectionInvalid.client!.request).toHaveBeenCalledWith(
|
|
|
expect.anything(),
|
|
|
expect.anything(),
|
|
|
expect.objectContaining({ timeout: 60000 }), // Default 60 seconds
|
|
|
@@ -813,7 +1444,8 @@ describe("McpHub", () => {
|
|
|
vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig))
|
|
|
|
|
|
// Set up mock connection
|
|
|
- const mockConnection: McpConnection = {
|
|
|
+ const mockConnection: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "test-server",
|
|
|
type: "stdio",
|
|
|
@@ -852,7 +1484,8 @@ describe("McpHub", () => {
|
|
|
vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig))
|
|
|
|
|
|
// Set up mock connection
|
|
|
- const mockConnection: McpConnection = {
|
|
|
+ const mockConnection: ConnectedMcpConnection = {
|
|
|
+ type: "connected",
|
|
|
server: {
|
|
|
name: "test-server",
|
|
|
type: "stdio",
|
|
|
@@ -877,6 +1510,291 @@ describe("McpHub", () => {
|
|
|
})
|
|
|
})
|
|
|
|
|
|
+ describe("MCP global enable/disable", () => {
|
|
|
+ beforeEach(() => {
|
|
|
+ // Clear all mocks before each test
|
|
|
+ vi.clearAllMocks()
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should disconnect all servers when MCP is toggled from enabled to disabled", async () => {
|
|
|
+ // Mock StdioClientTransport
|
|
|
+ const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js")
|
|
|
+ const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType<typeof vi.fn>
|
|
|
+
|
|
|
+ const mockTransport = {
|
|
|
+ start: vi.fn().mockResolvedValue(undefined),
|
|
|
+ close: vi.fn().mockResolvedValue(undefined),
|
|
|
+ stderr: {
|
|
|
+ on: vi.fn(),
|
|
|
+ },
|
|
|
+ onerror: null,
|
|
|
+ onclose: null,
|
|
|
+ }
|
|
|
+
|
|
|
+ StdioClientTransport.mockImplementation(() => mockTransport)
|
|
|
+
|
|
|
+ // Mock Client
|
|
|
+ const clientModule = await import("@modelcontextprotocol/sdk/client/index.js")
|
|
|
+ const Client = clientModule.Client as ReturnType<typeof vi.fn>
|
|
|
+
|
|
|
+ const mockClient = {
|
|
|
+ connect: vi.fn().mockResolvedValue(undefined),
|
|
|
+ close: vi.fn().mockResolvedValue(undefined),
|
|
|
+ getInstructions: vi.fn().mockReturnValue("test instructions"),
|
|
|
+ request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }),
|
|
|
+ }
|
|
|
+
|
|
|
+ Client.mockImplementation(() => mockClient)
|
|
|
+
|
|
|
+ // Start with MCP enabled
|
|
|
+ mockProvider.getState = vi.fn().mockResolvedValue({ mcpEnabled: true })
|
|
|
+
|
|
|
+ // Mock the config file read
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "toggle-test-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ // Create McpHub and let it initialize with MCP enabled
|
|
|
+ const mcpHub = new McpHub(mockProvider as ClineProvider)
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Verify server is connected
|
|
|
+ const connectedServer = mcpHub.connections.find((conn) => conn.server.name === "toggle-test-server")
|
|
|
+ expect(connectedServer).toBeDefined()
|
|
|
+ expect(connectedServer!.server.status).toBe("connected")
|
|
|
+ expect(connectedServer!.client).toBeDefined()
|
|
|
+ expect(connectedServer!.transport).toBeDefined()
|
|
|
+
|
|
|
+ // Now simulate toggling MCP to disabled
|
|
|
+ mockProvider.getState = vi.fn().mockResolvedValue({ mcpEnabled: false })
|
|
|
+
|
|
|
+ // Manually trigger what would happen when MCP is disabled
|
|
|
+ // (normally this would be triggered by the webview message handler)
|
|
|
+ const existingConnections = [...mcpHub.connections]
|
|
|
+ for (const conn of existingConnections) {
|
|
|
+ await mcpHub.deleteConnection(conn.server.name, conn.server.source)
|
|
|
+ }
|
|
|
+ await mcpHub.refreshAllConnections()
|
|
|
+
|
|
|
+ // Verify server is now tracked but disconnected
|
|
|
+ const disconnectedServer = mcpHub.connections.find((conn) => conn.server.name === "toggle-test-server")
|
|
|
+ expect(disconnectedServer).toBeDefined()
|
|
|
+ expect(disconnectedServer!.server.status).toBe("disconnected")
|
|
|
+ expect(disconnectedServer!.client).toBeNull()
|
|
|
+ expect(disconnectedServer!.transport).toBeNull()
|
|
|
+
|
|
|
+ // Verify close was called on the original client and transport
|
|
|
+ expect(mockClient.close).toHaveBeenCalled()
|
|
|
+ expect(mockTransport.close).toHaveBeenCalled()
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should not connect to servers when MCP is globally disabled", async () => {
|
|
|
+ // Mock provider with mcpEnabled: false
|
|
|
+ const disabledMockProvider = {
|
|
|
+ ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
|
|
|
+ ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
|
|
|
+ postMessageToWebview: vi.fn(),
|
|
|
+ getState: vi.fn().mockResolvedValue({ mcpEnabled: false }),
|
|
|
+ context: mockProvider.context,
|
|
|
+ }
|
|
|
+
|
|
|
+ // Mock the config file read with a different server name to avoid conflicts
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "disabled-test-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ // Create a new McpHub instance with disabled MCP
|
|
|
+ const mcpHub = new McpHub(disabledMockProvider as unknown as ClineProvider)
|
|
|
+
|
|
|
+ // Wait for initialization
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Find the disabled-test-server
|
|
|
+ const disabledServer = mcpHub.connections.find((conn) => conn.server.name === "disabled-test-server")
|
|
|
+
|
|
|
+ // Verify that the server is tracked but not connected
|
|
|
+ expect(disabledServer).toBeDefined()
|
|
|
+ expect(disabledServer!.server.status).toBe("disconnected")
|
|
|
+ expect(disabledServer!.client).toBeNull()
|
|
|
+ expect(disabledServer!.transport).toBeNull()
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should connect to servers when MCP is globally enabled", async () => {
|
|
|
+ // Clear all mocks
|
|
|
+ vi.clearAllMocks()
|
|
|
+
|
|
|
+ // Mock StdioClientTransport
|
|
|
+ const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js")
|
|
|
+ const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType<typeof vi.fn>
|
|
|
+
|
|
|
+ const mockTransport = {
|
|
|
+ start: vi.fn().mockResolvedValue(undefined),
|
|
|
+ close: vi.fn().mockResolvedValue(undefined),
|
|
|
+ stderr: {
|
|
|
+ on: vi.fn(),
|
|
|
+ },
|
|
|
+ onerror: null,
|
|
|
+ onclose: null,
|
|
|
+ }
|
|
|
+
|
|
|
+ StdioClientTransport.mockImplementation(() => mockTransport)
|
|
|
+
|
|
|
+ // Mock Client
|
|
|
+ const clientModule = await import("@modelcontextprotocol/sdk/client/index.js")
|
|
|
+ const Client = clientModule.Client as ReturnType<typeof vi.fn>
|
|
|
+
|
|
|
+ Client.mockImplementation(() => ({
|
|
|
+ connect: vi.fn().mockResolvedValue(undefined),
|
|
|
+ close: vi.fn().mockResolvedValue(undefined),
|
|
|
+ getInstructions: vi.fn().mockReturnValue("test instructions"),
|
|
|
+ request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }),
|
|
|
+ }))
|
|
|
+
|
|
|
+ // Mock provider with mcpEnabled: true
|
|
|
+ const enabledMockProvider = {
|
|
|
+ ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
|
|
|
+ ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
|
|
|
+ postMessageToWebview: vi.fn(),
|
|
|
+ getState: vi.fn().mockResolvedValue({ mcpEnabled: true }),
|
|
|
+ context: mockProvider.context,
|
|
|
+ }
|
|
|
+
|
|
|
+ // Mock the config file read with a different server name
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "enabled-test-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ // Create a new McpHub instance with enabled MCP
|
|
|
+ const mcpHub = new McpHub(enabledMockProvider as unknown as ClineProvider)
|
|
|
+
|
|
|
+ // Wait for initialization
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Find the enabled-test-server
|
|
|
+ const enabledServer = mcpHub.connections.find((conn) => conn.server.name === "enabled-test-server")
|
|
|
+
|
|
|
+ // Verify that the server is connected
|
|
|
+ expect(enabledServer).toBeDefined()
|
|
|
+ expect(enabledServer!.server.status).toBe("connected")
|
|
|
+ expect(enabledServer!.client).toBeDefined()
|
|
|
+ expect(enabledServer!.transport).toBeDefined()
|
|
|
+
|
|
|
+ // Verify StdioClientTransport was called
|
|
|
+ expect(StdioClientTransport).toHaveBeenCalled()
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle refreshAllConnections when MCP is disabled", async () => {
|
|
|
+ // Mock provider with mcpEnabled: false
|
|
|
+ const disabledMockProvider = {
|
|
|
+ ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
|
|
|
+ ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
|
|
|
+ postMessageToWebview: vi.fn(),
|
|
|
+ getState: vi.fn().mockResolvedValue({ mcpEnabled: false }),
|
|
|
+ context: mockProvider.context,
|
|
|
+ }
|
|
|
+
|
|
|
+ // Mock the config file read
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "refresh-test-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ // Create McpHub with disabled MCP
|
|
|
+ const mcpHub = new McpHub(disabledMockProvider as unknown as ClineProvider)
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Clear previous calls
|
|
|
+ vi.clearAllMocks()
|
|
|
+
|
|
|
+ // Call refreshAllConnections
|
|
|
+ await mcpHub.refreshAllConnections()
|
|
|
+
|
|
|
+ // Verify that servers are tracked but not connected
|
|
|
+ const server = mcpHub.connections.find((conn) => conn.server.name === "refresh-test-server")
|
|
|
+ expect(server).toBeDefined()
|
|
|
+ expect(server!.server.status).toBe("disconnected")
|
|
|
+ expect(server!.client).toBeNull()
|
|
|
+ expect(server!.transport).toBeNull()
|
|
|
+
|
|
|
+ // Verify postMessageToWebview was called to update the UI
|
|
|
+ expect(disabledMockProvider.postMessageToWebview).toHaveBeenCalledWith(
|
|
|
+ expect.objectContaining({
|
|
|
+ type: "mcpServers",
|
|
|
+ }),
|
|
|
+ )
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should skip restarting connection when MCP is disabled", async () => {
|
|
|
+ // Mock provider with mcpEnabled: false
|
|
|
+ const disabledMockProvider = {
|
|
|
+ ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
|
|
|
+ ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"),
|
|
|
+ postMessageToWebview: vi.fn(),
|
|
|
+ getState: vi.fn().mockResolvedValue({ mcpEnabled: false }),
|
|
|
+ context: mockProvider.context,
|
|
|
+ }
|
|
|
+
|
|
|
+ // Mock the config file read
|
|
|
+ vi.mocked(fs.readFile).mockResolvedValue(
|
|
|
+ JSON.stringify({
|
|
|
+ mcpServers: {
|
|
|
+ "restart-test-server": {
|
|
|
+ command: "node",
|
|
|
+ args: ["test.js"],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }),
|
|
|
+ )
|
|
|
+
|
|
|
+ // Create McpHub with disabled MCP
|
|
|
+ const mcpHub = new McpHub(disabledMockProvider as unknown as ClineProvider)
|
|
|
+ await new Promise((resolve) => setTimeout(resolve, 100))
|
|
|
+
|
|
|
+ // Set isConnecting to false to ensure it's properly reset
|
|
|
+ mcpHub.isConnecting = false
|
|
|
+
|
|
|
+ // Try to restart a connection
|
|
|
+ await mcpHub.restartConnection("restart-test-server")
|
|
|
+
|
|
|
+ // Verify that isConnecting was reset to false
|
|
|
+ expect(mcpHub.isConnecting).toBe(false)
|
|
|
+
|
|
|
+ // Verify that the server remains disconnected
|
|
|
+ const server = mcpHub.connections.find((conn) => conn.server.name === "restart-test-server")
|
|
|
+ expect(server).toBeDefined()
|
|
|
+ expect(server!.server.status).toBe("disconnected")
|
|
|
+ expect(server!.client).toBeNull()
|
|
|
+ expect(server!.transport).toBeNull()
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
describe("Windows command wrapping", () => {
|
|
|
let StdioClientTransport: ReturnType<typeof vi.fn>
|
|
|
let Client: ReturnType<typeof vi.fn>
|