| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687 |
- import { dynamicTool, type Tool, jsonSchema, type JSONSchema7 } from "ai"
- import { Client } from "@modelcontextprotocol/sdk/client/index.js"
- import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"
- import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"
- import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"
- import { UnauthorizedError } from "@modelcontextprotocol/sdk/client/auth.js"
- import {
- CallToolResultSchema,
- type Tool as MCPToolDef,
- ToolListChangedNotificationSchema,
- } from "@modelcontextprotocol/sdk/types.js"
- import { Config } from "../config/config"
- import { Log } from "../util/log"
- import { NamedError } from "@opencode-ai/util/error"
- import z from "zod/v4"
- import { Instance } from "../project/instance"
- import { Installation } from "../installation"
- import { withTimeout } from "@/util/timeout"
- import { McpOAuthProvider } from "./oauth-provider"
- import { McpOAuthCallback } from "./oauth-callback"
- import { McpAuth } from "./auth"
- import { BusEvent } from "../bus/bus-event"
- import { Bus } from "@/bus"
- import { TuiEvent } from "@/cli/cmd/tui/event"
- import open from "open"
- export namespace MCP {
const log = Log.create({ service: "mcp" })

/** Bus event published when a server announces that its tool list has changed. */
export const ToolsChanged = BusEvent.define("mcp.tools.changed", z.object({ server: z.string() }))

/** Named error for MCP failures; carries the name of the server involved. */
export const Failed = NamedError.create("MCPFailed", z.object({ name: z.string() }))

type MCPClient = Client

/**
 * Connection status of a configured MCP server, discriminated on the `status` field.
 * Every variant (and the union itself) is tagged with a `ref` through `.meta()`
 * (presumably consumed by schema/OpenAPI generation — confirm with the generator).
 */
export const Status = z
  .discriminatedUnion("status", [
    z.object({ status: z.literal("connected") }).meta({ ref: "MCPStatusConnected" }),
    z.object({ status: z.literal("disabled") }).meta({ ref: "MCPStatusDisabled" }),
    z.object({ status: z.literal("failed"), error: z.string() }).meta({ ref: "MCPStatusFailed" }),
    z.object({ status: z.literal("needs_auth") }).meta({ ref: "MCPStatusNeedsAuth" }),
    z
      .object({ status: z.literal("needs_client_registration"), error: z.string() })
      .meta({ ref: "MCPStatusNeedsClientRegistration" }),
  ])
  .meta({ ref: "MCPStatus" })
export type Status = z.infer<typeof Status>
/**
 * Installs notification handlers on a freshly connected MCP client.
 *
 * Tool-list-changed notifications from the server are logged and republished on the bus as
 * {@link ToolsChanged} (with the server's config name), so consumers can refresh tools.
 *
 * @param client - the connected client to attach handlers to
 * @param serverName - config key of the server, used in logs and in the published event
 */
function registerNotificationHandlers(client: MCPClient, serverName: string) {
  client.setNotificationHandler(ToolListChangedNotificationSchema, async () => {
    log.info("tools list changed notification received", { server: serverName })
    // Never leave the publish promise floating: a rejection would otherwise surface as an
    // unhandled promise rejection instead of a log line.
    await Bus.publish(ToolsChanged, { server: serverName }).catch((e) =>
      log.debug("failed to publish tools changed event", { server: serverName, error: e }),
    )
  })
}
/**
 * Wraps an MCP tool definition as an AI SDK dynamic tool.
 *
 * The resulting tool's `execute` forwards the arguments to the MCP server through `client`,
 * validating the reply with `CallToolResultSchema`. The per-call timeout comes from
 * `experimental.mcp_timeout` in the config loaded here, and it resets whenever the server
 * reports progress.
 */
async function convertMcpTool(mcpTool: MCPToolDef, client: MCPClient): Promise<Tool> {
  const { inputSchema } = mcpTool

  // The server schema is spread first so the fields listed after it always win: the tool
  // input is forced to be an object schema with no additional properties.
  const schema: JSONSchema7 = {
    ...(inputSchema as JSONSchema7),
    type: "object",
    properties: (inputSchema.properties ?? {}) as JSONSchema7["properties"],
    additionalProperties: false,
  }

  const cfg = await Config.get()

  return dynamicTool({
    description: mcpTool.description ?? "",
    inputSchema: jsonSchema(schema),
    execute: async (args: unknown) =>
      client.callTool(
        { name: mcpTool.name, arguments: args as Record<string, unknown> },
        CallToolResultSchema,
        { resetTimeoutOnProgress: true, timeout: cfg.experimental?.mcp_timeout },
      ),
  })
}
/** Remote transports that can be constructed with an OAuth provider. */
type TransportWithAuth = SSEClientTransport | StreamableHTTPClientTransport

/**
 * Transports whose OAuth flow has started but not finished, keyed by MCP server name, so a
 * later `finishAuth` call can complete the flow on the same transport.
 */
const pendingOAuthTransports: Map<string, TransportWithAuth> = new Map()
/**
 * MCP state held through `Instance.state`: connected clients and the last known status of
 * every configured server, both keyed by config name.
 *
 * Initialization connects all enabled servers concurrently. Servers with `enabled: false` are
 * recorded as "disabled" without connecting. The dispose callback closes every client (close
 * errors are logged) and clears pending OAuth transports.
 */
const state = Instance.state(
  async () => {
    const cfg = await Config.get()
    const config = cfg.mcp ?? {}
    const clients: Record<string, MCPClient> = {}
    const status: Record<string, Status> = {}

    await Promise.all(
      Object.entries(config).map(async ([key, mcp]) => {
        // Disabled by config: record it, but don't try to connect.
        if (mcp.enabled === false) {
          status[key] = { status: "disabled" }
          return
        }
        // create() reports most failures through its returned status, but it can still reject
        // (e.g. `new URL` on an invalid url). Record that as "failed" rather than silently
        // dropping the server, which would make status() report it as "disabled".
        const result = await create(key, mcp).catch((error: unknown) => {
          const message = error instanceof Error ? error.message : String(error)
          log.error("failed to initialize mcp server", { key, error: message })
          status[key] = { status: "failed", error: message }
          return undefined
        })
        if (!result) return
        status[key] = result.status
        if (result.mcpClient) {
          clients[key] = result.mcpClient
        }
      }),
    )

    return {
      status,
      clients,
    }
  },
  async (state) => {
    await Promise.all(
      Object.values(state.clients).map((client) =>
        client.close().catch((error) => {
          log.error("Failed to close MCP client", {
            error,
          })
        }),
      ),
    )
    // NOTE(review): pendingOAuthTransports is module-global, so disposing one instance clears
    // pending OAuth flows started under any instance — confirm this is intended.
    pendingOAuthTransports.clear()
  },
)
/**
 * Connects (or reconnects) the MCP server `name` using `mcp` and records the outcome in state.
 *
 * A client previously registered under `name` is closed when it is replaced, or when the new
 * attempt yields no client. This keeps reconnects from leaking connections (and, for local
 * servers, their stdio transport). Close errors are logged, not thrown.
 *
 * Rejects if `create` rejects (e.g. an invalid remote URL).
 *
 * @returns `{ status }`, where `status` is the full name → status record of the current state.
 */
export async function add(name: string, mcp: Config.Mcp) {
  const s = await state()
  const result = await create(name, mcp)

  const previous = s.clients[name]
  if (previous && previous !== result.mcpClient) {
    await previous.close().catch((error) => {
      log.error("Failed to close MCP client", { name, error })
    })
    delete s.clients[name]
  }

  if (result.mcpClient) {
    s.clients[name] = result.mcpClient
  }
  s.status[name] = result.status

  // Always return the whole status record, so callers see one consistent shape.
  return {
    status: s.status,
  }
}
/** Tries the Streamable HTTP transport, then SSE; stops at the first success or auth error. */
async function connectRemote(
  key: string,
  mcp: Extract<Config.Mcp, { type: "remote" }>,
): Promise<{ mcpClient: MCPClient | undefined; status: Status | undefined }> {
  // OAuth is on by default for remote servers; only an explicit `oauth: false` disables it.
  const oauthConfig = typeof mcp.oauth === "object" ? mcp.oauth : undefined
  const authProvider =
    mcp.oauth === false
      ? undefined
      : new McpOAuthProvider(
          key,
          mcp.url,
          {
            clientId: oauthConfig?.clientId,
            clientSecret: oauthConfig?.clientSecret,
            scope: oauthConfig?.scope,
          },
          {
            onRedirect: async (url) => {
              // Only logged here; opening a browser is left to startAuth.
              log.info("oauth redirect requested", { key, url: url.toString() })
            },
          },
        )

  const transports: Array<{ name: string; transport: TransportWithAuth }> = [
    {
      name: "StreamableHTTP",
      transport: new StreamableHTTPClientTransport(new URL(mcp.url), {
        authProvider,
        requestInit: mcp.headers ? { headers: mcp.headers } : undefined,
      }),
    },
    {
      name: "SSE",
      transport: new SSEClientTransport(new URL(mcp.url), {
        authProvider,
        requestInit: mcp.headers ? { headers: mcp.headers } : undefined,
      }),
    },
  ]

  let failure: Status | undefined
  for (const { name, transport } of transports) {
    try {
      const client = new Client({
        name: "opencode",
        version: Installation.VERSION,
      })
      await client.connect(transport)
      registerNotificationHandlers(client, key)
      log.info("connected", { key, transport: name })
      return { mcpClient: client, status: { status: "connected" } }
    } catch (error) {
      const lastError = error instanceof Error ? error : new Error(String(error))

      // An auth error ends the fallback chain: another transport would hit the same wall.
      if (error instanceof UnauthorizedError) {
        log.info("mcp server requires authentication", { key, transport: name })
        // The error text is the only signal that dynamic client registration is unsupported.
        if (lastError.message.includes("registration") || lastError.message.includes("client_id")) {
          Bus.publish(TuiEvent.ToastShow, {
            title: "MCP Authentication Required",
            message: `Server "${key}" requires a pre-registered client ID. Add clientId to your config.`,
            variant: "warning",
            duration: 8000,
          }).catch((e) => log.debug("failed to show toast", { error: e }))
          return {
            mcpClient: undefined,
            status: {
              status: "needs_client_registration",
              error: "Server does not support dynamic client registration. Please provide clientId in config.",
            },
          }
        }
        // Keep the transport so a later finishAuth call can complete the flow on it.
        pendingOAuthTransports.set(key, transport)
        Bus.publish(TuiEvent.ToastShow, {
          title: "MCP Authentication Required",
          message: `Server "${key}" requires authentication. Run: opencode mcp auth ${key}`,
          variant: "warning",
          duration: 8000,
        }).catch((e) => log.debug("failed to show toast", { error: e }))
        return { mcpClient: undefined, status: { status: "needs_auth" } }
      }

      log.debug("transport connection failed", {
        key,
        transport: name,
        url: mcp.url,
        error: lastError.message,
      })
      failure = { status: "failed", error: lastError.message }
    }
  }
  return { mcpClient: undefined, status: failure }
}

/** Spawns `mcp.command` over stdio in the instance directory and connects a client to it. */
async function connectLocal(
  key: string,
  mcp: Extract<Config.Mcp, { type: "local" }>,
): Promise<{ mcpClient: MCPClient | undefined; status: Status | undefined }> {
  const [cmd, ...args] = mcp.command
  const cwd = Instance.directory
  const transport = new StdioClientTransport({
    stderr: "ignore",
    command: cmd,
    args,
    cwd,
    env: {
      ...process.env,
      ...(cmd === "opencode" ? { BUN_BE_BUN: "1" } : {}),
      ...mcp.environment,
    },
  })
  try {
    const client = new Client({
      name: "opencode",
      version: Installation.VERSION,
    })
    await client.connect(transport)
    registerNotificationHandlers(client, key)
    return { mcpClient: client, status: { status: "connected" } }
  } catch (error) {
    const message = error instanceof Error ? error.message : String(error)
    log.error("local mcp startup failed", {
      key,
      command: mcp.command,
      cwd,
      error: message,
    })
    return { mcpClient: undefined, status: { status: "failed", error: message } }
  }
}

/**
 * Builds a client for one configured MCP server and checks that it can list its tools.
 *
 * - `enabled: false` → "disabled", without connecting.
 * - remote → see connectRemote: Streamable HTTP then SSE. An auth error yields "needs_auth"
 *   (transport kept for finishAuth) or "needs_client_registration".
 * - local → see connectLocal.
 *
 * A connected client must answer `listTools()` within `mcp.timeout` (default 5000 ms).
 * Otherwise it is closed and the server is reported as "failed". A client is only returned
 * alongside a non-failed status. May reject if transport construction throws
 * (e.g. an invalid URL).
 */
async function create(
  key: string,
  mcp: Config.Mcp,
): Promise<{ mcpClient: MCPClient | undefined; status: Status }> {
  if (mcp.enabled === false) {
    log.info("mcp server disabled", { key })
    return { mcpClient: undefined, status: { status: "disabled" } }
  }

  log.info("found", { key, type: mcp.type })

  let attempt: { mcpClient: MCPClient | undefined; status: Status | undefined } = {
    mcpClient: undefined,
    status: undefined,
  }
  if (mcp.type === "remote") attempt = await connectRemote(key, mcp)
  if (mcp.type === "local") attempt = await connectLocal(key, mcp)

  const status: Status = attempt.status ?? { status: "failed", error: "Unknown error" }
  const mcpClient = attempt.mcpClient
  if (!mcpClient) {
    return { mcpClient: undefined, status }
  }

  // A server that connects but can't list its tools is unusable; treat it as failed.
  const result = await withTimeout(mcpClient.listTools(), mcp.timeout ?? 5000).catch((err) => {
    log.error("failed to get tools from client", { key, error: err })
    return undefined
  })
  if (!result) {
    await mcpClient.close().catch((error) => {
      log.error("Failed to close MCP client", {
        error,
      })
    })
    return { mcpClient: undefined, status: { status: "failed", error: "Failed to get tools" } }
  }

  log.info("create() successfully created client", { key, toolCount: result.tools.length })
  return { mcpClient, status }
}
/**
 * Reports a status for every MCP server present in the current config (not only the ones
 * that connected). A configured server without a recorded status is reported as "disabled".
 */
export async function status() {
  const { status: recorded } = await state()
  const configured = (await Config.get()).mcp ?? {}

  const byName: Record<string, Status> = {}
  for (const name of Object.keys(configured)) {
    byName[name] = recorded[name] ?? { status: "disabled" }
  }
  return byName
}
/** Returns the live name → client map held in state (the same object, not a copy). */
export async function clients() {
  const { clients: byName } = await state()
  return byName
}
/**
 * (Re)connects the configured MCP server `name` and records the outcome in state.
 * `enabled: true` is forced for this attempt, so config-disabled servers still connect.
 *
 * A client previously registered under `name` is closed when it is replaced, or when the new
 * attempt yields no client, so reconnecting never leaks the old connection. Close errors are
 * logged. If `name` is not in the config, an error is logged and nothing changes.
 * Rejects if `create` rejects (e.g. an invalid remote URL).
 */
export async function connect(name: string) {
  const cfg = await Config.get()
  const mcp = cfg.mcp?.[name]
  if (!mcp) {
    log.error("MCP config not found", { name })
    return
  }

  const result = await create(name, { ...mcp, enabled: true })
  const s = await state()

  const previous = s.clients[name]
  if (previous && previous !== result.mcpClient) {
    await previous.close().catch((error) => {
      log.error("Failed to close MCP client", { name, error })
    })
    delete s.clients[name]
  }

  s.status[name] = result.status
  if (result.mcpClient) {
    s.clients[name] = result.mcpClient
  }
}
/**
 * Closes and forgets the client registered under `name`, if there is one, then marks the
 * server "disabled" whether or not a client existed. Close errors are logged, not thrown.
 */
export async function disconnect(name: string) {
  const { clients: live, status: statuses } = await state()
  const existing = live[name]
  if (existing) {
    await existing.close().catch((error) => {
      log.error("Failed to close MCP client", { name, error })
    })
    delete live[name]
  }
  statuses[name] = { status: "disabled" }
}
/**
 * Collects the tools of every MCP server whose status is "connected", as AI SDK tools.
 *
 * Keys are `<server>_<tool>`, with every character outside `[a-zA-Z0-9_-]` replaced by "_".
 * On a key collision, the later server (in client-map order) wins.
 *
 * Servers are queried concurrently. When a server's `listTools()` fails, it is marked
 * "failed", removed from state, and its client closed so the connection (or local stdio
 * transport) is not leaked.
 */
export async function tools() {
  const result: Record<string, Tool> = {}
  const s = await state()
  const sanitize = (value: string) => value.replace(/[^a-zA-Z0-9_-]/g, "_")

  // Only connected servers contribute tools (disabled/failed/needs_auth ones are skipped).
  const connected = Object.entries(s.clients).filter(
    ([clientName]) => s.status[clientName]?.status === "connected",
  )

  const listings = await Promise.all(
    connected.map(async ([clientName, client]) => {
      const toolsResult = await client.listTools().catch(async (e: unknown) => {
        const message = e instanceof Error ? e.message : String(e)
        log.error("failed to get tools", { clientName, error: message })
        s.status[clientName] = { status: "failed", error: message }
        delete s.clients[clientName]
        await client.close().catch((error) => {
          log.error("Failed to close MCP client", { name: clientName, error })
        })
        return undefined
      })
      return { clientName, client, toolsResult }
    }),
  )

  // Build the map in client order so key-collision behaviour stays deterministic.
  for (const { clientName, client, toolsResult } of listings) {
    if (!toolsResult) continue
    const prefix = sanitize(clientName)
    for (const mcpTool of toolsResult.tools) {
      result[prefix + "_" + sanitize(mcpTool.name)] = await convertMcpTool(mcpTool, client)
    }
  }
  return result
}
/**
 * Start OAuth authentication flow for an MCP server.
 *
 * A fresh random OAuth state is stored for `mcpName` before connecting, because the auth
 * provider reads it when building the authorization URL. A connection is then attempted over
 * Streamable HTTP, using the configured headers. If the server answers with an auth error, the
 * redirect URL captured from the provider is returned, and the transport is kept pending for
 * `finishAuth`. If the connection succeeds, the server is already authorized: the probe client
 * is closed and an empty URL is returned.
 *
 * @returns `{ authorizationUrl }` — the URL to open in a browser, or "" when already authorized
 * @throws if the server is unknown, not remote, has OAuth disabled, or the connection fails
 *   for a reason other than a capturable auth redirect
 */
export async function startAuth(mcpName: string): Promise<{ authorizationUrl: string }> {
  const cfg = await Config.get()
  const mcpConfig = cfg.mcp?.[mcpName]
  if (!mcpConfig) {
    throw new Error(`MCP server not found: ${mcpName}`)
  }
  if (mcpConfig.type !== "remote") {
    throw new Error(`MCP server ${mcpName} is not a remote server`)
  }
  if (mcpConfig.oauth === false) {
    throw new Error(`MCP server ${mcpName} has OAuth explicitly disabled`)
  }

  await McpOAuthCallback.ensureRunning()

  // Generate and store the state BEFORE creating the provider; the provider reads it from the
  // store when the authorization URL is built.
  const oauthState = Array.from(crypto.getRandomValues(new Uint8Array(32)))
    .map((b) => b.toString(16).padStart(2, "0"))
    .join("")
  await McpAuth.updateOAuthState(mcpName, oauthState)

  // OAuth config is optional; without it the provider relies on auto-discovery.
  const oauthConfig = typeof mcpConfig.oauth === "object" ? mcpConfig.oauth : undefined
  let capturedUrl: URL | undefined
  const authProvider = new McpOAuthProvider(
    mcpName,
    mcpConfig.url,
    {
      clientId: oauthConfig?.clientId,
      clientSecret: oauthConfig?.clientSecret,
      scope: oauthConfig?.scope,
    },
    {
      onRedirect: async (url) => {
        capturedUrl = url
      },
    },
  )

  // Send the same configured headers that create() uses, so servers that need them accept us.
  const transport = new StreamableHTTPClientTransport(new URL(mcpConfig.url), {
    authProvider,
    requestInit: mcpConfig.headers ? { headers: mcpConfig.headers } : undefined,
  })

  const client = new Client({
    name: "opencode",
    version: Installation.VERSION,
  })
  try {
    // Connecting triggers the OAuth flow when the server requires authorization.
    await client.connect(transport)
  } catch (error) {
    if (error instanceof UnauthorizedError && capturedUrl) {
      pendingOAuthTransports.set(mcpName, transport)
      return { authorizationUrl: capturedUrl.toString() }
    }
    throw error
  }

  // Already authorized. This client was only a probe, so close it instead of leaking it.
  await client.close().catch((error) => {
    log.error("Failed to close MCP client", { name: mcpName, error })
  })
  return { authorizationUrl: "" }
}
/**
 * Runs the interactive OAuth flow for an MCP server.
 *
 * The flow calls {@link startAuth}, opens the returned authorization URL in the browser, and
 * waits for the callback matching the stored OAuth state. It then completes the exchange with
 * {@link finishAuth}. If the server is already authorized, it returns the recorded status
 * (defaulting to "connected") without opening a browser.
 *
 * Once the browser step is reached, the stored OAuth state is always cleared afterwards, even
 * if opening the browser, waiting for the callback, or validating the state fails. A stale
 * state is therefore never left behind.
 *
 * @throws if startAuth fails, no OAuth state is stored, opening the browser or waiting for the
 *   callback rejects, or the stored state changed while waiting
 */
export async function authenticate(mcpName: string): Promise<Status> {
  const { authorizationUrl } = await startAuth(mcpName)
  if (!authorizationUrl) {
    // Already authenticated
    const s = await state()
    return s.status[mcpName] ?? { status: "connected" }
  }

  // startAuth() generated and stored this state; the SDK already embedded it in the URL.
  const oauthState = await McpAuth.getOAuthState(mcpName)
  if (!oauthState) {
    throw new Error("OAuth state not found - this should not happen")
  }

  log.info("opening browser for oauth", { mcpName, url: authorizationUrl, state: oauthState })

  let code: string
  try {
    await open(authorizationUrl)
    code = await McpOAuthCallback.waitForCallback(oauthState)
    // Reject if the stored state was replaced (e.g. by another startAuth) while we waited.
    const storedState = await McpAuth.getOAuthState(mcpName)
    if (storedState !== oauthState) {
      throw new Error("OAuth state mismatch - potential CSRF attack")
    }
  } finally {
    await McpAuth.clearOAuthState(mcpName)
  }

  return finishAuth(mcpName, code)
}
/**
 * Completes OAuth for an MCP server using the authorization code, then reconnects it.
 *
 * The code is exchanged on the transport kept by the pending flow, and the stored code
 * verifier is cleared. The server is then re-added from config, and its resulting status is
 * read back from state.
 *
 * @returns the server's new status; any failure after the pending-flow lookup is returned as
 *   a "failed" status (and logged) rather than thrown
 * @throws if there is no pending OAuth flow for `mcpName`
 */
export async function finishAuth(mcpName: string, authorizationCode: string): Promise<Status> {
  const transport = pendingOAuthTransports.get(mcpName)
  if (!transport) {
    throw new Error(`No pending OAuth flow for MCP server: ${mcpName}`)
  }

  try {
    await transport.finishAuth(authorizationCode)
    // Clear the code verifier after successful auth
    await McpAuth.clearCodeVerifier(mcpName)

    const cfg = await Config.get()
    const mcpConfig = cfg.mcp?.[mcpName]
    if (!mcpConfig) {
      throw new Error(`MCP server not found: ${mcpName}`)
    }

    // Re-add the server to establish a real connection with the new tokens.
    pendingOAuthTransports.delete(mcpName)
    await add(mcpName, mcpConfig)

    // Read the outcome back from state instead of casting add()'s return value.
    const s = await state()
    return s.status[mcpName] ?? { status: "failed", error: "Unknown error after auth" }
  } catch (error) {
    log.error("failed to finish oauth", { mcpName, error })
    return {
      status: "failed",
      error: error instanceof Error ? error.message : String(error),
    }
  }
}
/**
 * Forgets all OAuth data for an MCP server: stored credentials, the pending transport, any
 * pending callback wait, and the stored OAuth state. Logs once it is done.
 */
export async function removeAuth(mcpName: string): Promise<void> {
  await McpAuth.remove(mcpName)

  // In-memory cleanup (both synchronous) before the persisted OAuth state is cleared.
  pendingOAuthTransports.delete(mcpName)
  McpOAuthCallback.cancelPending(mcpName)

  await McpAuth.clearOAuthState(mcpName)
  log.info("removed oauth credentials", { mcpName })
}
/**
 * Whether OAuth applies to an MCP server. Only remote servers qualify, and for them it is on
 * unless the config sets `oauth: false`. Unknown servers report false.
 */
export async function supportsOAuth(mcpName: string): Promise<boolean> {
  const entry = (await Config.get()).mcp?.[mcpName]
  if (entry?.type !== "remote") {
    return false
  }
  return entry.oauth !== false
}
/** True when the auth store holds OAuth tokens for the given MCP server. */
export async function hasStoredTokens(mcpName: string): Promise<boolean> {
  const stored = await McpAuth.get(mcpName)
  return Boolean(stored?.tokens)
}

export type AuthStatus = "authenticated" | "expired" | "not_authenticated"

/**
 * Summarizes the OAuth state of an MCP server. Returns "not_authenticated" when no tokens are
 * stored, "expired" when the stored token is expired, and "authenticated" otherwise.
 */
export async function getAuthStatus(mcpName: string): Promise<AuthStatus> {
  if (!(await hasStoredTokens(mcpName))) {
    return "not_authenticated"
  }
  if (await McpAuth.isTokenExpired(mcpName)) {
    return "expired"
  }
  return "authenticated"
}
- }
|