浏览代码

MCP improvements (#5699)

Rohan Mukherjee 2 月之前
父节点
当前提交
7427b887f9

+ 47 - 5
packages/opencode/src/mcp/auth.ts

@@ -24,6 +24,8 @@ export namespace McpAuth {
     tokens: Tokens.optional(),
     clientInfo: ClientInfo.optional(),
     codeVerifier: z.string().optional(),
+    oauthState: z.string().optional(),
+    serverUrl: z.string().optional(), // Track the URL these credentials are for
   })
   export type Entry = z.infer<typeof Entry>
 
@@ -34,14 +36,35 @@ export namespace McpAuth {
     return data[mcpName]
   }
 
+  /**
+   * Get auth entry and validate it's for the correct URL.
+   * Returns undefined if URL has changed (credentials are invalid).
+   */
+  export async function getForUrl(mcpName: string, serverUrl: string): Promise<Entry | undefined> {
+    const entry = await get(mcpName)
+    if (!entry) return undefined
+
+    // If no serverUrl is stored, this is from an old version - consider it invalid
+    if (!entry.serverUrl) return undefined
+
+    // If URL has changed, credentials are invalid
+    if (entry.serverUrl !== serverUrl) return undefined
+
+    return entry
+  }
+
   export async function all(): Promise<Record<string, Entry>> {
     const file = Bun.file(filepath)
     return file.json().catch(() => ({}))
   }
 
-  export async function set(mcpName: string, entry: Entry): Promise<void> {
+  export async function set(mcpName: string, entry: Entry, serverUrl?: string): Promise<void> {
     const file = Bun.file(filepath)
     const data = await all()
+    // Always update serverUrl if provided
+    if (serverUrl) {
+      entry.serverUrl = serverUrl
+    }
     await Bun.write(file, JSON.stringify({ ...data, [mcpName]: entry }, null, 2))
     await fs.chmod(file.name!, 0o600)
   }
@@ -54,16 +77,16 @@ export namespace McpAuth {
     await fs.chmod(file.name!, 0o600)
   }
 
-  export async function updateTokens(mcpName: string, tokens: Tokens): Promise<void> {
+  export async function updateTokens(mcpName: string, tokens: Tokens, serverUrl?: string): Promise<void> {
     const entry = (await get(mcpName)) ?? {}
     entry.tokens = tokens
-    await set(mcpName, entry)
+    await set(mcpName, entry, serverUrl)
   }
 
-  export async function updateClientInfo(mcpName: string, clientInfo: ClientInfo): Promise<void> {
+  export async function updateClientInfo(mcpName: string, clientInfo: ClientInfo, serverUrl?: string): Promise<void> {
     const entry = (await get(mcpName)) ?? {}
     entry.clientInfo = clientInfo
-    await set(mcpName, entry)
+    await set(mcpName, entry, serverUrl)
   }
 
   export async function updateCodeVerifier(mcpName: string, codeVerifier: string): Promise<void> {
@@ -79,4 +102,23 @@ export namespace McpAuth {
       await set(mcpName, entry)
     }
   }
+
+  export async function updateOAuthState(mcpName: string, oauthState: string): Promise<void> {
+    const entry = (await get(mcpName)) ?? {}
+    entry.oauthState = oauthState
+    await set(mcpName, entry)
+  }
+
+  export async function getOAuthState(mcpName: string): Promise<string | undefined> {
+    const entry = await get(mcpName)
+    return entry?.oauthState
+  }
+
+  export async function clearOAuthState(mcpName: string): Promise<void> {
+    const entry = await get(mcpName)
+    if (entry) {
+      delete entry.oauthState
+      await set(mcpName, entry)
+    }
+  }
 }

+ 26 - 14
packages/opencode/src/mcp/index.ts

@@ -436,6 +436,13 @@ export namespace MCP {
     // Start the callback server
     await McpOAuthCallback.ensureRunning()
 
+    // Generate and store a cryptographically secure state parameter BEFORE creating the provider
+    // The SDK will call provider.state() to read this value
+    const oauthState = Array.from(crypto.getRandomValues(new Uint8Array(32)))
+      .map((b) => b.toString(16).padStart(2, "0"))
+      .join("")
+    await McpAuth.updateOAuthState(mcpName, oauthState)
+
     // Create a new auth provider for this flow
     // OAuth config is optional - if not provided, we'll use auto-discovery
     const oauthConfig = typeof mcpConfig.oauth === "object" ? mcpConfig.oauth : undefined
@@ -491,25 +498,29 @@ export namespace MCP {
       return s.status[mcpName] ?? { status: "connected" }
     }
 
-    // Extract state from authorization URL to use as callback key
-    // If no state parameter, use mcpName as fallback
-    const authUrl = new URL(authorizationUrl)
-    let oauthState = mcpName
-
-    if (authUrl.searchParams.has("state")) {
-      oauthState = authUrl.searchParams.get("state")!
-    } else {
-      log.info("no state parameter in authorization URL, using mcpName as state", { mcpName })
-      authUrl.searchParams.set("state", oauthState)
+    // Get the state that was already generated and stored in startAuth()
+    const oauthState = await McpAuth.getOAuthState(mcpName)
+    if (!oauthState) {
+      throw new Error("OAuth state not found - this should not happen")
     }
 
-    // Open browser
-    log.info("opening browser for oauth", { mcpName, url: authUrl.toString(), state: oauthState })
-    await open(authUrl.toString())
+    // The SDK has already added the state parameter to the authorization URL
+    // We just need to open the browser
+    log.info("opening browser for oauth", { mcpName, url: authorizationUrl, state: oauthState })
+    await open(authorizationUrl)
 
-    // Wait for callback using the OAuth state parameter (or mcpName as fallback)
+    // Wait for callback using the OAuth state parameter
     const code = await McpOAuthCallback.waitForCallback(oauthState)
 
+    // Validate and clear the state
+    const storedState = await McpAuth.getOAuthState(mcpName)
+    if (storedState !== oauthState) {
+      await McpAuth.clearOAuthState(mcpName)
+      throw new Error("OAuth state mismatch - potential CSRF attack")
+    }
+
+    await McpAuth.clearOAuthState(mcpName)
+
     // Finish auth
     return finishAuth(mcpName, code)
   }
@@ -561,6 +572,7 @@ export namespace MCP {
     await McpAuth.remove(mcpName)
     McpOAuthCallback.cancelPending(mcpName)
     pendingOAuthTransports.delete(mcpName)
+    await McpAuth.clearOAuthState(mcpName)
     log.info("removed oauth credentials", { mcpName })
   }
 

+ 22 - 25
packages/opencode/src/mcp/oauth-callback.ts

@@ -81,9 +81,19 @@ export namespace McpOAuthCallback {
 
         log.info("received oauth callback", { hasCode: !!code, state, error })
 
+        // Enforce state parameter presence
+        if (!state) {
+          const errorMsg = "Missing required state parameter - potential CSRF attack"
+          log.error("oauth callback missing state parameter", { url: url.toString() })
+          return new Response(HTML_ERROR(errorMsg), {
+            status: 400,
+            headers: { "Content-Type": "text/html" },
+          })
+        }
+
         if (error) {
           const errorMsg = errorDescription || error
-          if (state && pendingAuths.has(state)) {
+          if (pendingAuths.has(state)) {
             const pending = pendingAuths.get(state)!
             clearTimeout(pending.timeout)
             pendingAuths.delete(state)
@@ -101,33 +111,20 @@ export namespace McpOAuthCallback {
           })
         }
 
-        // Try to find the pending auth by state parameter, or if no state, use the single pending auth
-        let pending: PendingAuth | undefined
-        let pendingKey: string | undefined
-
-        if (state && pendingAuths.has(state)) {
-          pending = pendingAuths.get(state)!
-          pendingKey = state
-        } else if (!state && pendingAuths.size === 1) {
-          // No state parameter but only one pending auth - use it
-          const [key, value] = pendingAuths.entries().next().value as [string, PendingAuth]
-          pending = value
-          pendingKey = key
-          log.info("no state parameter, using single pending auth", { key })
-        }
-
-        if (!pending || !pendingKey) {
-          const errorMsg = !state
-            ? "No state parameter provided and multiple pending authorizations"
-            : "Unknown or expired authorization request"
+        // Validate state parameter
+        if (!pendingAuths.has(state)) {
+          const errorMsg = "Invalid or expired state parameter - potential CSRF attack"
+          log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) })
           return new Response(HTML_ERROR(errorMsg), {
             status: 400,
             headers: { "Content-Type": "text/html" },
           })
         }
 
+        const pending = pendingAuths.get(state)!
+
         clearTimeout(pending.timeout)
-        pendingAuths.delete(pendingKey)
+        pendingAuths.delete(state)
         pending.resolve(code)
 
         return new Response(HTML_SUCCESS, {
@@ -139,16 +136,16 @@ export namespace McpOAuthCallback {
     log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT })
   }
 
-  export function waitForCallback(mcpName: string): Promise<string> {
+  export function waitForCallback(oauthState: string): Promise<string> {
     return new Promise((resolve, reject) => {
       const timeout = setTimeout(() => {
-        if (pendingAuths.has(mcpName)) {
-          pendingAuths.delete(mcpName)
+        if (pendingAuths.has(oauthState)) {
+          pendingAuths.delete(oauthState)
           reject(new Error("OAuth callback timeout - authorization took too long"))
         }
       }, CALLBACK_TIMEOUT_MS)
 
-      pendingAuths.set(mcpName, { resolve, reject, timeout })
+      pendingAuths.set(oauthState, { resolve, reject, timeout })
     })
   }
 

+ 37 - 15
packages/opencode/src/mcp/oauth-provider.ts

@@ -56,7 +56,8 @@ export class McpOAuthProvider implements OAuthClientProvider {
     }
 
     // Check stored client info (from dynamic registration)
-    const entry = await McpAuth.get(this.mcpName)
+    // Use getForUrl to validate credentials are for the current server URL
+    const entry = await McpAuth.getForUrl(this.mcpName, this.serverUrl)
     if (entry?.clientInfo) {
       // Check if client secret has expired
       if (entry.clientInfo.clientSecretExpiresAt && entry.clientInfo.clientSecretExpiresAt < Date.now() / 1000) {
@@ -69,17 +70,21 @@ export class McpOAuthProvider implements OAuthClientProvider {
       }
     }
 
-    // No client info - will trigger dynamic registration
+    // No client info or URL changed - will trigger dynamic registration
     return undefined
   }
 
   async saveClientInformation(info: OAuthClientInformationFull): Promise<void> {
-    await McpAuth.updateClientInfo(this.mcpName, {
-      clientId: info.client_id,
-      clientSecret: info.client_secret,
-      clientIdIssuedAt: info.client_id_issued_at,
-      clientSecretExpiresAt: info.client_secret_expires_at,
-    })
+    await McpAuth.updateClientInfo(
+      this.mcpName,
+      {
+        clientId: info.client_id,
+        clientSecret: info.client_secret,
+        clientIdIssuedAt: info.client_id_issued_at,
+        clientSecretExpiresAt: info.client_secret_expires_at,
+      },
+      this.serverUrl,
+    )
     log.info("saved dynamically registered client", {
       mcpName: this.mcpName,
       clientId: info.client_id,
@@ -87,7 +92,8 @@ export class McpOAuthProvider implements OAuthClientProvider {
   }
 
   async tokens(): Promise<OAuthTokens | undefined> {
-    const entry = await McpAuth.get(this.mcpName)
+    // Use getForUrl to validate tokens are for the current server URL
+    const entry = await McpAuth.getForUrl(this.mcpName, this.serverUrl)
     if (!entry?.tokens) return undefined
 
     return {
@@ -102,12 +108,16 @@ export class McpOAuthProvider implements OAuthClientProvider {
   }
 
   async saveTokens(tokens: OAuthTokens): Promise<void> {
-    await McpAuth.updateTokens(this.mcpName, {
-      accessToken: tokens.access_token,
-      refreshToken: tokens.refresh_token,
-      expiresAt: tokens.expires_in ? Date.now() / 1000 + tokens.expires_in : undefined,
-      scope: tokens.scope,
-    })
+    await McpAuth.updateTokens(
+      this.mcpName,
+      {
+        accessToken: tokens.access_token,
+        refreshToken: tokens.refresh_token,
+        expiresAt: tokens.expires_in ? Date.now() / 1000 + tokens.expires_in : undefined,
+        scope: tokens.scope,
+      },
+      this.serverUrl,
+    )
     log.info("saved oauth tokens", { mcpName: this.mcpName })
   }
 
@@ -127,6 +137,18 @@ export class McpOAuthProvider implements OAuthClientProvider {
     }
     return entry.codeVerifier
   }
+
+  async saveState(state: string): Promise<void> {
+    await McpAuth.updateOAuthState(this.mcpName, state)
+  }
+
+  async state(): Promise<string> {
+    const entry = await McpAuth.get(this.mcpName)
+    if (!entry?.oauthState) {
+      throw new Error(`No OAuth state saved for MCP server: ${this.mcpName}`)
+    }
+    return entry.oauthState
+  }
 }
 
 export { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH }