index.ts 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687
  1. import { dynamicTool, type Tool, jsonSchema, type JSONSchema7 } from "ai"
  2. import { Client } from "@modelcontextprotocol/sdk/client/index.js"
  3. import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"
  4. import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"
  5. import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"
  6. import { UnauthorizedError } from "@modelcontextprotocol/sdk/client/auth.js"
  7. import {
  8. CallToolResultSchema,
  9. type Tool as MCPToolDef,
  10. ToolListChangedNotificationSchema,
  11. } from "@modelcontextprotocol/sdk/types.js"
  12. import { Config } from "../config/config"
  13. import { Log } from "../util/log"
  14. import { NamedError } from "@opencode-ai/util/error"
  15. import z from "zod/v4"
  16. import { Instance } from "../project/instance"
  17. import { Installation } from "../installation"
  18. import { withTimeout } from "@/util/timeout"
  19. import { McpOAuthProvider } from "./oauth-provider"
  20. import { McpOAuthCallback } from "./oauth-callback"
  21. import { McpAuth } from "./auth"
  22. import { BusEvent } from "../bus/bus-event"
  23. import { Bus } from "@/bus"
  24. import { TuiEvent } from "@/cli/cmd/tui/event"
  25. import open from "open"
  26. export namespace MCP {
  const log = Log.create({ service: "mcp" })

  /** Bus event broadcast when an MCP server reports that its tool list changed; carries the server key. */
  export const ToolsChanged = BusEvent.define("mcp.tools.changed", z.object({ server: z.string() }))

  /** Named error "MCPFailed" with a `name` field (presumably the MCP server name — confirm at call sites). */
  export const Failed = NamedError.create("MCPFailed", z.object({ name: z.string() }))

  // Local alias so the rest of the module does not depend directly on the SDK class name.
  type MCPClient = Client
  // Individual status variants; each carries a schema ref so it is emitted as a named component.
  const StatusConnected = z.object({ status: z.literal("connected") }).meta({ ref: "MCPStatusConnected" })
  const StatusDisabled = z.object({ status: z.literal("disabled") }).meta({ ref: "MCPStatusDisabled" })
  const StatusFailed = z
    .object({ status: z.literal("failed"), error: z.string() })
    .meta({ ref: "MCPStatusFailed" })
  const StatusNeedsAuth = z.object({ status: z.literal("needs_auth") }).meta({ ref: "MCPStatusNeedsAuth" })
  const StatusNeedsClientRegistration = z
    .object({ status: z.literal("needs_client_registration"), error: z.string() })
    .meta({ ref: "MCPStatusNeedsClientRegistration" })

  /**
   * Connection status of a single MCP server, discriminated on the `status` field.
   * The "failed" and "needs_client_registration" variants also carry an `error` message.
   */
  export const Status = z
    .discriminatedUnion("status", [
      StatusConnected,
      StatusDisabled,
      StatusFailed,
      StatusNeedsAuth,
      StatusNeedsClientRegistration,
    ])
    .meta({ ref: "MCPStatus" })
  export type Status = z.infer<typeof Status>
  /**
   * Listen for the server's tool-list-changed notification on `client`.
   * Each notification is logged and re-published on the bus as `ToolsChanged` for `serverName`.
   */
  function registerNotificationHandlers(client: MCPClient, serverName: string) {
    const onToolListChanged = async () => {
      log.info("tools list changed notification received", { server: serverName })
      Bus.publish(ToolsChanged, { server: serverName })
    }
    client.setNotificationHandler(ToolListChangedNotificationSchema, onToolListChanged)
  }
  /**
   * Wrap an MCP tool definition as an AI SDK dynamic tool.
   *
   * The server-provided input schema is copied, but `type` is forced to "object", `properties`
   * defaults to `{}`, and `additionalProperties` is forced to false. Executing the tool forwards
   * the arguments to `client.callTool`, using `experimental.mcp_timeout` from the config as the
   * request timeout and resetting that timeout on progress notifications.
   */
  async function convertMcpTool(mcpTool: MCPToolDef, client: MCPClient): Promise<Tool> {
    const { inputSchema } = mcpTool
    const properties = (inputSchema.properties ?? {}) as JSONSchema7["properties"]
    // The overrides come after the spread so they always win over whatever the server declared.
    const schema: JSONSchema7 = {
      ...(inputSchema as JSONSchema7),
      type: "object",
      properties,
      additionalProperties: false,
    }
    const config = await Config.get()

    const execute = async (args: unknown) =>
      client.callTool(
        { name: mcpTool.name, arguments: args as Record<string, unknown> },
        CallToolResultSchema,
        { resetTimeoutOnProgress: true, timeout: config.experimental?.mcp_timeout },
      )

    return dynamicTool({
      description: mcpTool.description ?? "",
      inputSchema: jsonSchema(schema),
      execute,
    })
  }
  // Transports whose connect attempt hit an OAuth challenge, keyed by server name, so that
  // finishAuth() can complete the flow on the same transport instance.
  type TransportWithAuth = StreamableHTTPClientTransport | SSEClientTransport
  const pendingOAuthTransports = new Map<string, TransportWithAuth>()

  /**
   * Per-instance MCP state.
   *
   * Initializer: connects every configured server in parallel. Servers with `enabled: false` are
   * recorded as "disabled" without connecting. Every other server gets the status returned by
   * create(). If create() throws unexpectedly, the server is recorded as "failed".
   *
   * Disposer: closes every live client (close errors are only logged) and drops pending OAuth
   * transports.
   */
  const state = Instance.state(
    async () => {
      const cfg = await Config.get()
      const config = cfg.mcp ?? {}
      const clients: Record<string, MCPClient> = {}
      const status: Record<string, Status> = {}
      await Promise.all(
        Object.entries(config).map(async ([key, mcp]) => {
          // If disabled by config, mark as disabled without trying to connect
          if (mcp.enabled === false) {
            status[key] = { status: "disabled" }
            return
          }
          // create() reports ordinary connection problems through its returned status. An exception
          // here is unexpected (e.g. a malformed URL), so record it as a failure rather than dropping it.
          // Otherwise status() would fall back to reporting the server as "disabled".
          const result = await create(key, mcp).catch((error: unknown) => {
            log.error("failed to create mcp client", { key, error })
            return {
              mcpClient: undefined,
              status: {
                status: "failed" as const,
                error: error instanceof Error ? error.message : String(error),
              },
            }
          })
          status[key] = result.status
          if (result.mcpClient) {
            clients[key] = result.mcpClient
          }
        }),
      )
      return {
        status,
        clients,
      }
    },
    async (state) => {
      await Promise.all(
        Object.values(state.clients).map((client) =>
          client.close().catch((error) => {
            log.error("Failed to close MCP client", {
              error,
            })
          }),
        ),
      )
      pendingOAuthTransports.clear()
    },
  )
  /**
   * Connect (or reconnect) the MCP server `name` using `mcp` and record the outcome in the instance state.
   *
   * Any client already registered under `name` is closed first, so replacing a server does not
   * leak its transport or subprocess.
   *
   * @returns `{ status }`, where `status` is the full server-name → Status map after the update.
   *   Every code path returns the map; previously one path returned a single Status object.
   * @throws whatever create() throws; state is left unchanged in that case.
   */
  export async function add(name: string, mcp: Config.Mcp) {
    const s = await state()
    const result = await create(name, mcp)

    const previous = s.clients[name]
    if (previous && previous !== result.mcpClient) {
      await previous.close().catch((error) => {
        log.error("Failed to close MCP client", { name, error })
      })
      delete s.clients[name]
    }

    s.status[name] = result.status
    if (result.mcpClient) {
      s.clients[name] = result.mcpClient
    }
    return {
      status: s.status,
    }
  }
  /**
   * Build and connect an MCP client for one configured server.
   *
   * - Disabled servers (`enabled: false`) are not contacted.
   * - Remote servers try the StreamableHTTP transport first and fall back to SSE. OAuth is wired in
   *   unless `oauth: false`. An UnauthorizedError yields "needs_auth" (the transport is kept for
   *   finishAuth), or "needs_client_registration" when the error mentions registration/client_id.
   * - Local servers are spawned over stdio in the instance directory with the configured environment.
   *
   * After connecting, the tool list is fetched once, bounded by `mcp.timeout` (default 5000 ms), to
   * check that the server responds. On failure the client is closed and "failed" is reported.
   *
   * @returns the connected client (or `undefined`) together with the resulting status. Connection
   *   problems, including an unparsable remote URL, are reported via `status` rather than thrown.
   */
  async function create(key: string, mcp: Config.Mcp) {
    if (mcp.enabled === false) {
      log.info("mcp server disabled", { key })
      return {
        mcpClient: undefined,
        status: { status: "disabled" as const },
      }
    }
    log.info("found", { key, type: mcp.type })
    let mcpClient: MCPClient | undefined
    let status: Status | undefined = undefined

    if (mcp.type === "remote") {
      // Validate the URL up front. `new URL` throws on malformed input, which would otherwise reject
      // create() instead of reporting a "failed" status like every other connection problem.
      let url: URL
      try {
        url = new URL(mcp.url)
      } catch (error) {
        log.error("invalid mcp url", { key, url: mcp.url, error })
        return {
          mcpClient: undefined,
          status: {
            status: "failed" as const,
            error: `Invalid MCP server URL: ${mcp.url}`,
          },
        }
      }

      // OAuth is enabled by default for remote servers unless explicitly disabled with oauth: false
      const oauthDisabled = mcp.oauth === false
      const oauthConfig = typeof mcp.oauth === "object" ? mcp.oauth : undefined
      let authProvider: McpOAuthProvider | undefined
      if (!oauthDisabled) {
        authProvider = new McpOAuthProvider(
          key,
          mcp.url,
          {
            clientId: oauthConfig?.clientId,
            clientSecret: oauthConfig?.clientSecret,
            scope: oauthConfig?.scope,
          },
          {
            onRedirect: async (redirectUrl) => {
              // Only logged here. The interactive flow (startAuth/authenticate) captures the URL
              // and opens the browser.
              log.info("oauth redirect requested", { key, url: redirectUrl.toString() })
            },
          },
        )
      }

      // Both transports are built up front; only the first that connects (or hits an auth challenge) is used.
      const transports: Array<{ name: string; transport: TransportWithAuth }> = [
        {
          name: "StreamableHTTP",
          transport: new StreamableHTTPClientTransport(new URL(url), {
            authProvider,
            requestInit: mcp.headers ? { headers: mcp.headers } : undefined,
          }),
        },
        {
          name: "SSE",
          transport: new SSEClientTransport(new URL(url), {
            authProvider,
            requestInit: mcp.headers ? { headers: mcp.headers } : undefined,
          }),
        },
      ]

      let lastError: Error | undefined
      for (const { name, transport } of transports) {
        try {
          const client = new Client({
            name: "opencode",
            version: Installation.VERSION,
          })
          await client.connect(transport)
          registerNotificationHandlers(client, key)
          mcpClient = client
          log.info("connected", { key, transport: name })
          status = { status: "connected" }
          break
        } catch (error) {
          lastError = error instanceof Error ? error : new Error(String(error))
          // An auth challenge is final: stop trying further transports and report what the user must do.
          if (error instanceof UnauthorizedError) {
            log.info("mcp server requires authentication", { key, transport: name })
            // Heuristic: the error text mentions registration/client_id when dynamic client registration is unavailable
            if (lastError.message.includes("registration") || lastError.message.includes("client_id")) {
              status = {
                status: "needs_client_registration" as const,
                error: "Server does not support dynamic client registration. Please provide clientId in config.",
              }
              Bus.publish(TuiEvent.ToastShow, {
                title: "MCP Authentication Required",
                message: `Server "${key}" requires a pre-registered client ID. Add clientId to your config.`,
                variant: "warning",
                duration: 8000,
              }).catch((e) => log.debug("failed to show toast", { error: e }))
            } else {
              // Keep the transport so finishAuth() can complete the flow on it
              pendingOAuthTransports.set(key, transport)
              status = { status: "needs_auth" as const }
              Bus.publish(TuiEvent.ToastShow, {
                title: "MCP Authentication Required",
                message: `Server "${key}" requires authentication. Run: opencode mcp auth ${key}`,
                variant: "warning",
                duration: 8000,
              }).catch((e) => log.debug("failed to show toast", { error: e }))
            }
            break
          }
          log.debug("transport connection failed", {
            key,
            transport: name,
            url: mcp.url,
            error: lastError.message,
          })
          // Provisional; overwritten if a later transport connects
          status = {
            status: "failed" as const,
            error: lastError.message,
          }
        }
      }
    }

    if (mcp.type === "local") {
      const [cmd, ...args] = mcp.command
      const cwd = Instance.directory
      const transport = new StdioClientTransport({
        stderr: "ignore",
        command: cmd,
        args,
        cwd,
        env: {
          ...process.env,
          ...(cmd === "opencode" ? { BUN_BE_BUN: "1" } : {}),
          ...mcp.environment,
        },
      })
      try {
        const client = new Client({
          name: "opencode",
          version: Installation.VERSION,
        })
        await client.connect(transport)
        registerNotificationHandlers(client, key)
        mcpClient = client
        status = {
          status: "connected",
        }
      } catch (error) {
        log.error("local mcp startup failed", {
          key,
          command: mcp.command,
          cwd,
          error: error instanceof Error ? error.message : String(error),
        })
        status = {
          status: "failed" as const,
          error: error instanceof Error ? error.message : String(error),
        }
      }
    }

    if (!status) {
      status = {
        status: "failed" as const,
        error: "Unknown error",
      }
    }
    if (!mcpClient) {
      return {
        mcpClient: undefined,
        status,
      }
    }

    // Probe the server with a bounded listTools call before reporting the client as usable.
    const result = await withTimeout(mcpClient.listTools(), mcp.timeout ?? 5000).catch((err) => {
      log.error("failed to get tools from client", { key, error: err })
      return undefined
    })
    if (!result) {
      await mcpClient.close().catch((error) => {
        log.error("Failed to close MCP client", {
          error,
        })
      })
      return {
        mcpClient: undefined,
        status: {
          status: "failed" as const,
          error: "Failed to get tools",
        },
      }
    }
    log.info("create() successfully created client", { key, toolCount: result.tools.length })
    return {
      mcpClient,
      status,
    }
  }
  /**
   * Return a status for every MCP server present in the config, not only the connected ones.
   * A configured server with no recorded status is reported as "disabled".
   */
  export async function status() {
    const s = await state()
    const configured = Object.keys((await Config.get()).mcp ?? {})
    return Object.fromEntries(
      configured.map((key): [string, Status] => [key, s.status[key] ?? { status: "disabled" }]),
    )
  }
  /** Resolve the map of live MCP clients, keyed by server name, from the instance state. */
  export async function clients() {
    const s = await state()
    return s.clients
  }
  /**
   * (Re)connect the configured MCP server `name`, ignoring its `enabled` flag.
   *
   * Any client currently registered under `name` is closed first, so reconnecting does not orphan
   * the old transport or subprocess. If create() throws, the server is marked "failed" and the
   * error is rethrown. Logs and returns without changes if `name` is not in the config.
   */
  export async function connect(name: string) {
    const cfg = await Config.get()
    const mcp = cfg.mcp?.[name]
    if (!mcp) {
      log.error("MCP config not found", { name })
      return
    }

    const s = await state()
    const existing = s.clients[name]
    if (existing) {
      await existing.close().catch((error) => {
        log.error("Failed to close MCP client", { name, error })
      })
      delete s.clients[name]
    }

    try {
      const result = await create(name, { ...mcp, enabled: true })
      s.status[name] = result.status
      if (result.mcpClient) {
        s.clients[name] = result.mcpClient
      }
    } catch (error) {
      // The old client is already gone, so don't leave a stale "connected" status behind.
      s.status[name] = {
        status: "failed",
        error: error instanceof Error ? error.message : String(error),
      }
      throw error
    }
  }
  /**
   * Close and forget the client for `name` (close errors are only logged), then mark the server "disabled".
   * The status is set even when no client was registered.
   */
  export async function disconnect(name: string) {
    const s = await state()
    const existing = s.clients[name]
    if (existing !== undefined) {
      await existing.close().catch((error) => {
        log.error("Failed to close MCP client", { name, error })
      })
      delete s.clients[name]
    }
    s.status[name] = { status: "disabled" }
  }
  /**
   * Collect AI SDK tools from every MCP server whose status is "connected".
   *
   * Tool keys are `<server>_<tool>`, with any character outside `[a-zA-Z0-9_-]` replaced by "_".
   * If a server's listTools call fails, the server is marked "failed", removed from state, and its
   * client is closed so the connection is not leaked.
   */
  export async function tools() {
    const result: Record<string, Tool> = {}
    const s = await state()
    const sanitize = (value: string) => value.replace(/[^a-zA-Z0-9_-]/g, "_")

    // Object.entries snapshots the map, so deleting failed clients below is safe during iteration.
    for (const [clientName, client] of Object.entries(s.clients)) {
      // Only include tools from connected MCPs (skip disabled ones)
      if (s.status[clientName]?.status !== "connected") {
        continue
      }
      const toolsResult = await client.listTools().catch(async (e: unknown) => {
        const message = e instanceof Error ? e.message : String(e)
        log.error("failed to get tools", { clientName, error: message })
        s.status[clientName] = {
          status: "failed",
          error: message,
        }
        delete s.clients[clientName]
        // The client is no longer tracked in state, so the state disposer would never close it.
        await client.close().catch((error) => {
          log.error("Failed to close MCP client", { clientName, error })
        })
        return undefined
      })
      if (!toolsResult) {
        continue
      }
      const prefix = sanitize(clientName)
      for (const mcpTool of toolsResult.tools) {
        result[prefix + "_" + sanitize(mcpTool.name)] = await convertMcpTool(mcpTool, client)
      }
    }
    return result
  }
  /**
   * Start OAuth authentication flow for an MCP server.
   *
   * Generates and stores a fresh OAuth state, ensures the callback server is running, and probes
   * the server over StreamableHTTP with an OAuth provider.
   *
   * @returns the authorization URL to open in a browser. If the server accepts the connection
   *   without authorization, the URL is an empty string.
   * @throws if the server is not configured, is not remote, or has OAuth disabled. Also rethrows
   *   any connection error other than an UnauthorizedError that produced a redirect URL.
   */
  export async function startAuth(mcpName: string): Promise<{ authorizationUrl: string }> {
    const cfg = await Config.get()
    const mcpConfig = cfg.mcp?.[mcpName]
    if (!mcpConfig) {
      throw new Error(`MCP server not found: ${mcpName}`)
    }
    if (mcpConfig.type !== "remote") {
      throw new Error(`MCP server ${mcpName} is not a remote server`)
    }
    if (mcpConfig.oauth === false) {
      throw new Error(`MCP server ${mcpName} has OAuth explicitly disabled`)
    }

    // Start the callback server
    await McpOAuthCallback.ensureRunning()

    // Generate and store a cryptographically secure state parameter BEFORE creating the provider.
    // The SDK will call provider.state() to read this value.
    const oauthState = Array.from(crypto.getRandomValues(new Uint8Array(32)))
      .map((b) => b.toString(16).padStart(2, "0"))
      .join("")
    await McpAuth.updateOAuthState(mcpName, oauthState)

    // OAuth config is optional - if not provided, we'll use auto-discovery
    const oauthConfig = typeof mcpConfig.oauth === "object" ? mcpConfig.oauth : undefined
    let capturedUrl: URL | undefined
    const authProvider = new McpOAuthProvider(
      mcpName,
      mcpConfig.url,
      {
        clientId: oauthConfig?.clientId,
        clientSecret: oauthConfig?.clientSecret,
        scope: oauthConfig?.scope,
      },
      {
        onRedirect: async (url) => {
          capturedUrl = url
        },
      },
    )

    // Send the same custom headers as regular connections so header-gated servers are reachable during auth.
    const transport = new StreamableHTTPClientTransport(new URL(mcpConfig.url), {
      authProvider,
      requestInit: mcpConfig.headers ? { headers: mcpConfig.headers } : undefined,
    })

    // Connecting triggers the OAuth flow; an UnauthorizedError plus a captured redirect means auth is required.
    const client = new Client({
      name: "opencode",
      version: Installation.VERSION,
    })
    try {
      await client.connect(transport)
    } catch (error) {
      if (error instanceof UnauthorizedError && capturedUrl) {
        // Store transport for finishAuth
        pendingOAuthTransports.set(mcpName, transport)
        return { authorizationUrl: capturedUrl.toString() }
      }
      throw error
    }

    // Already authenticated. This probe client is not tracked in state, so close it instead of leaking it.
    await client.close().catch((error) => {
      log.error("Failed to close MCP client", { mcpName, error })
    })
    return { authorizationUrl: "" }
  }
  /**
   * Run the interactive OAuth flow for `mcpName`.
   *
   * Calls startAuth(), opens the returned URL in the browser, and waits for the callback carrying the
   * stored OAuth state. It then clears that state and exchanges the code via finishAuth().
   *
   * @returns the resulting server status. If no authorization was needed, returns the currently
   *   recorded status, falling back to "connected".
   * @throws if the OAuth state is missing after startAuth(), or differs once the callback arrives.
   */
  export async function authenticate(mcpName: string): Promise<Status> {
    const started = await startAuth(mcpName)
    if (!started.authorizationUrl) {
      // Nothing to authorize
      const current = await state()
      return current.status[mcpName] ?? { status: "connected" }
    }

    // startAuth() generated and persisted this value; the SDK already embedded it in the URL.
    const expectedState = await McpAuth.getOAuthState(mcpName)
    if (!expectedState) {
      throw new Error("OAuth state not found - this should not happen")
    }

    log.info("opening browser for oauth", { mcpName, url: started.authorizationUrl, state: expectedState })
    await open(started.authorizationUrl)

    const code = await McpOAuthCallback.waitForCallback(expectedState)

    // The stored state is cleared whether or not it still matches.
    const storedState = await McpAuth.getOAuthState(mcpName)
    await McpAuth.clearOAuthState(mcpName)
    if (storedState !== expectedState) {
      throw new Error("OAuth state mismatch - potential CSRF attack")
    }

    return finishAuth(mcpName, code)
  }
  /**
   * Complete OAuth authentication with the authorization code.
   *
   * Exchanges the code on the transport stored by startAuth()/create() and clears the code verifier.
   * It then re-adds the server via add() so it connects with the new credentials.
   *
   * @returns the server's recorded status after reconnecting. Failures during the exchange or
   *   reconnect are returned as a "failed" status rather than thrown.
   * @throws if no OAuth flow is pending for `mcpName`.
   */
  export async function finishAuth(mcpName: string, authorizationCode: string): Promise<Status> {
    const transport = pendingOAuthTransports.get(mcpName)
    if (!transport) {
      throw new Error(`No pending OAuth flow for MCP server: ${mcpName}`)
    }
    try {
      await transport.finishAuth(authorizationCode)
      // The verifier is single-use; drop it once the exchange succeeded
      await McpAuth.clearCodeVerifier(mcpName)

      const cfg = await Config.get()
      const mcpConfig = cfg.mcp?.[mcpName]
      if (!mcpConfig) {
        throw new Error(`MCP server not found: ${mcpName}`)
      }

      pendingOAuthTransports.delete(mcpName)
      await add(mcpName, mcpConfig)
      // add() records the outcome in shared state. Read it back from there instead of casting add()'s return value.
      const s = await state()
      return s.status[mcpName] ?? { status: "failed", error: "Unknown error after auth" }
    } catch (error) {
      log.error("failed to finish oauth", { mcpName, error })
      return {
        status: "failed",
        error: error instanceof Error ? error.message : String(error),
      }
    }
  }
  /**
   * Forget all OAuth material for `mcpName`.
   * Removes stored credentials, cancels any pending callback wait, drops the in-flight transport,
   * and clears the saved OAuth state, then logs the removal.
   */
  export async function removeAuth(mcpName: string): Promise<void> {
    await McpAuth.remove(mcpName)

    // In-memory flow state: pending callback and stored transport
    McpOAuthCallback.cancelPending(mcpName)
    pendingOAuthTransports.delete(mcpName)

    await McpAuth.clearOAuthState(mcpName)
    log.info("removed oauth credentials", { mcpName })
  }
  /**
   * Whether `mcpName` can use OAuth: it must be a configured remote server without `oauth: false`.
   */
  export async function supportsOAuth(mcpName: string): Promise<boolean> {
    const { mcp } = await Config.get()
    const entry = mcp?.[mcpName]
    if (entry?.type !== "remote") {
      return false
    }
    return entry.oauth !== false
  }
  /** True when the stored auth entry for `mcpName` exists and has a truthy `tokens` field. */
  export async function hasStoredTokens(mcpName: string): Promise<boolean> {
    const entry = await McpAuth.get(mcpName)
    return Boolean(entry?.tokens)
  }
  export type AuthStatus = "authenticated" | "expired" | "not_authenticated"

  /**
   * Classify the stored OAuth credentials for `mcpName`.
   * Returns "not_authenticated" when no tokens are stored; otherwise "expired" or "authenticated",
   * according to McpAuth.isTokenExpired.
   */
  export async function getAuthStatus(mcpName: string): Promise<AuthStatus> {
    if (await hasStoredTokens(mcpName)) {
      const expired = await McpAuth.isTokenExpired(mcpName)
      return expired ? "expired" : "authenticated"
    }
    return "not_authenticated"
  }
  614. }