index.ts 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654
  1. import { dynamicTool, type Tool, jsonSchema, type JSONSchema7 } from "ai"
  2. import { Client } from "@modelcontextprotocol/sdk/client/index.js"
  3. import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"
  4. import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"
  5. import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"
  6. import { UnauthorizedError } from "@modelcontextprotocol/sdk/client/auth.js"
  7. import type { Tool as MCPToolDef } from "@modelcontextprotocol/sdk/types.js"
  8. import { Config } from "../config/config"
  9. import { Log } from "../util/log"
  10. import { NamedError } from "@opencode-ai/util/error"
  11. import z from "zod/v4"
  12. import { Instance } from "../project/instance"
  13. import { Installation } from "../installation"
  14. import { withTimeout } from "@/util/timeout"
  15. import { McpOAuthProvider } from "./oauth-provider"
  16. import { McpOAuthCallback } from "./oauth-callback"
  17. import { McpAuth } from "./auth"
  18. import { Bus } from "@/bus"
  19. import { TuiEvent } from "@/cli/cmd/tui/event"
  20. import open from "open"
  21. export namespace MCP {
  22. const log = Log.create({ service: "mcp" })
  23. export const Failed = NamedError.create(
  24. "MCPFailed",
  25. z.object({
  26. name: z.string(),
  27. }),
  28. )
  29. type MCPClient = Client
  30. export const Status = z
  31. .discriminatedUnion("status", [
  32. z
  33. .object({
  34. status: z.literal("connected"),
  35. })
  36. .meta({
  37. ref: "MCPStatusConnected",
  38. }),
  39. z
  40. .object({
  41. status: z.literal("disabled"),
  42. })
  43. .meta({
  44. ref: "MCPStatusDisabled",
  45. }),
  46. z
  47. .object({
  48. status: z.literal("failed"),
  49. error: z.string(),
  50. })
  51. .meta({
  52. ref: "MCPStatusFailed",
  53. }),
  54. z
  55. .object({
  56. status: z.literal("needs_auth"),
  57. })
  58. .meta({
  59. ref: "MCPStatusNeedsAuth",
  60. }),
  61. z
  62. .object({
  63. status: z.literal("needs_client_registration"),
  64. error: z.string(),
  65. })
  66. .meta({
  67. ref: "MCPStatusNeedsClientRegistration",
  68. }),
  69. ])
  70. .meta({
  71. ref: "MCPStatus",
  72. })
  73. export type Status = z.infer<typeof Status>
  74. // Convert MCP tool definition to AI SDK Tool type
  75. function convertMcpTool(mcpTool: MCPToolDef, client: MCPClient): Tool {
  76. const inputSchema = mcpTool.inputSchema
  77. // Spread first, then override type to ensure it's always "object"
  78. const schema: JSONSchema7 = {
  79. ...(inputSchema as JSONSchema7),
  80. type: "object",
  81. properties: (inputSchema.properties ?? {}) as JSONSchema7["properties"],
  82. additionalProperties: false,
  83. }
  84. return dynamicTool({
  85. description: mcpTool.description ?? "",
  86. inputSchema: jsonSchema(schema),
  87. execute: async (args: unknown) => {
  88. return client.callTool({
  89. name: mcpTool.name,
  90. arguments: args as Record<string, unknown>,
  91. })
  92. },
  93. })
  94. }
  95. // Store transports for OAuth servers to allow finishing auth
  96. type TransportWithAuth = StreamableHTTPClientTransport | SSEClientTransport
  97. const pendingOAuthTransports = new Map<string, TransportWithAuth>()
  98. const state = Instance.state(
  99. async () => {
  100. const cfg = await Config.get()
  101. const config = cfg.mcp ?? {}
  102. const clients: Record<string, MCPClient> = {}
  103. const status: Record<string, Status> = {}
  104. await Promise.all(
  105. Object.entries(config).map(async ([key, mcp]) => {
  106. // If disabled by config, mark as disabled without trying to connect
  107. if (mcp.enabled === false) {
  108. status[key] = { status: "disabled" }
  109. return
  110. }
  111. const result = await create(key, mcp).catch(() => undefined)
  112. if (!result) return
  113. status[key] = result.status
  114. if (result.mcpClient) {
  115. clients[key] = result.mcpClient
  116. }
  117. }),
  118. )
  119. return {
  120. status,
  121. clients,
  122. }
  123. },
  124. async (state) => {
  125. await Promise.all(
  126. Object.values(state.clients).map((client) =>
  127. client.close().catch((error) => {
  128. log.error("Failed to close MCP client", {
  129. error,
  130. })
  131. }),
  132. ),
  133. )
  134. pendingOAuthTransports.clear()
  135. },
  136. )
  137. export async function add(name: string, mcp: Config.Mcp) {
  138. const s = await state()
  139. const result = await create(name, mcp)
  140. if (!result) {
  141. const status = {
  142. status: "failed" as const,
  143. error: "unknown error",
  144. }
  145. s.status[name] = status
  146. return {
  147. status,
  148. }
  149. }
  150. if (!result.mcpClient) {
  151. s.status[name] = result.status
  152. return {
  153. status: s.status,
  154. }
  155. }
  156. s.clients[name] = result.mcpClient
  157. s.status[name] = result.status
  158. return {
  159. status: s.status,
  160. }
  161. }
  162. async function create(key: string, mcp: Config.Mcp) {
  163. if (mcp.enabled === false) {
  164. log.info("mcp server disabled", { key })
  165. return {
  166. mcpClient: undefined,
  167. status: { status: "disabled" as const },
  168. }
  169. }
  170. log.info("found", { key, type: mcp.type })
  171. let mcpClient: MCPClient | undefined
  172. let status: Status | undefined = undefined
  173. if (mcp.type === "remote") {
  174. // OAuth is enabled by default for remote servers unless explicitly disabled with oauth: false
  175. const oauthDisabled = mcp.oauth === false
  176. const oauthConfig = typeof mcp.oauth === "object" ? mcp.oauth : undefined
  177. let authProvider: McpOAuthProvider | undefined
  178. if (!oauthDisabled) {
  179. authProvider = new McpOAuthProvider(
  180. key,
  181. mcp.url,
  182. {
  183. clientId: oauthConfig?.clientId,
  184. clientSecret: oauthConfig?.clientSecret,
  185. scope: oauthConfig?.scope,
  186. },
  187. {
  188. onRedirect: async (url) => {
  189. log.info("oauth redirect requested", { key, url: url.toString() })
  190. // Store the URL - actual browser opening is handled by startAuth
  191. },
  192. },
  193. )
  194. }
  195. const transports: Array<{ name: string; transport: TransportWithAuth }> = [
  196. {
  197. name: "StreamableHTTP",
  198. transport: new StreamableHTTPClientTransport(new URL(mcp.url), {
  199. authProvider,
  200. requestInit: mcp.headers ? { headers: mcp.headers } : undefined,
  201. }),
  202. },
  203. {
  204. name: "SSE",
  205. transport: new SSEClientTransport(new URL(mcp.url), {
  206. authProvider,
  207. requestInit: mcp.headers ? { headers: mcp.headers } : undefined,
  208. }),
  209. },
  210. ]
  211. let lastError: Error | undefined
  212. for (const { name, transport } of transports) {
  213. try {
  214. const client = new Client({
  215. name: "opencode",
  216. version: Installation.VERSION,
  217. })
  218. await client.connect(transport)
  219. mcpClient = client
  220. log.info("connected", { key, transport: name })
  221. status = { status: "connected" }
  222. break
  223. } catch (error) {
  224. lastError = error instanceof Error ? error : new Error(String(error))
  225. // Handle OAuth-specific errors
  226. if (error instanceof UnauthorizedError) {
  227. log.info("mcp server requires authentication", { key, transport: name })
  228. // Check if this is a "needs registration" error
  229. if (lastError.message.includes("registration") || lastError.message.includes("client_id")) {
  230. status = {
  231. status: "needs_client_registration" as const,
  232. error: "Server does not support dynamic client registration. Please provide clientId in config.",
  233. }
  234. // Show toast for needs_client_registration
  235. Bus.publish(TuiEvent.ToastShow, {
  236. title: "MCP Authentication Required",
  237. message: `Server "${key}" requires a pre-registered client ID. Add clientId to your config.`,
  238. variant: "warning",
  239. duration: 8000,
  240. }).catch((e) => log.debug("failed to show toast", { error: e }))
  241. } else {
  242. // Store transport for later finishAuth call
  243. pendingOAuthTransports.set(key, transport)
  244. status = { status: "needs_auth" as const }
  245. // Show toast for needs_auth
  246. Bus.publish(TuiEvent.ToastShow, {
  247. title: "MCP Authentication Required",
  248. message: `Server "${key}" requires authentication. Run: opencode mcp auth ${key}`,
  249. variant: "warning",
  250. duration: 8000,
  251. }).catch((e) => log.debug("failed to show toast", { error: e }))
  252. }
  253. break
  254. }
  255. log.debug("transport connection failed", {
  256. key,
  257. transport: name,
  258. url: mcp.url,
  259. error: lastError.message,
  260. })
  261. status = {
  262. status: "failed" as const,
  263. error: lastError.message,
  264. }
  265. }
  266. }
  267. }
  268. if (mcp.type === "local") {
  269. const [cmd, ...args] = mcp.command
  270. const transport = new StdioClientTransport({
  271. stderr: "ignore",
  272. command: cmd,
  273. args,
  274. env: {
  275. ...process.env,
  276. ...(cmd === "opencode" ? { BUN_BE_BUN: "1" } : {}),
  277. ...mcp.environment,
  278. },
  279. })
  280. try {
  281. const client = new Client({
  282. name: "opencode",
  283. version: Installation.VERSION,
  284. })
  285. await client.connect(transport)
  286. mcpClient = client
  287. status = {
  288. status: "connected",
  289. }
  290. } catch (error) {
  291. log.error("local mcp startup failed", {
  292. key,
  293. command: mcp.command,
  294. error: error instanceof Error ? error.message : String(error),
  295. })
  296. status = {
  297. status: "failed" as const,
  298. error: error instanceof Error ? error.message : String(error),
  299. }
  300. }
  301. }
  302. if (!status) {
  303. status = {
  304. status: "failed" as const,
  305. error: "Unknown error",
  306. }
  307. }
  308. if (!mcpClient) {
  309. return {
  310. mcpClient: undefined,
  311. status,
  312. }
  313. }
  314. const result = await withTimeout(mcpClient.listTools(), mcp.timeout ?? 5000).catch((err) => {
  315. log.error("failed to get tools from client", { key, error: err })
  316. return undefined
  317. })
  318. if (!result) {
  319. await mcpClient.close().catch((error) => {
  320. log.error("Failed to close MCP client", {
  321. error,
  322. })
  323. })
  324. status = {
  325. status: "failed",
  326. error: "Failed to get tools",
  327. }
  328. return {
  329. mcpClient: undefined,
  330. status: {
  331. status: "failed" as const,
  332. error: "Failed to get tools",
  333. },
  334. }
  335. }
  336. log.info("create() successfully created client", { key, toolCount: result.tools.length })
  337. return {
  338. mcpClient,
  339. status,
  340. }
  341. }
  342. export async function status() {
  343. const s = await state()
  344. const cfg = await Config.get()
  345. const config = cfg.mcp ?? {}
  346. const result: Record<string, Status> = {}
  347. // Include all MCPs from config, not just connected ones
  348. for (const key of Object.keys(config)) {
  349. result[key] = s.status[key] ?? { status: "disabled" }
  350. }
  351. return result
  352. }
  353. export async function clients() {
  354. return state().then((state) => state.clients)
  355. }
  356. export async function connect(name: string) {
  357. const cfg = await Config.get()
  358. const config = cfg.mcp ?? {}
  359. const mcp = config[name]
  360. if (!mcp) {
  361. log.error("MCP config not found", { name })
  362. return
  363. }
  364. const result = await create(name, { ...mcp, enabled: true })
  365. if (!result) {
  366. const s = await state()
  367. s.status[name] = {
  368. status: "failed",
  369. error: "Unknown error during connection",
  370. }
  371. return
  372. }
  373. const s = await state()
  374. s.status[name] = result.status
  375. if (result.mcpClient) {
  376. s.clients[name] = result.mcpClient
  377. }
  378. }
  379. export async function disconnect(name: string) {
  380. const s = await state()
  381. const client = s.clients[name]
  382. if (client) {
  383. await client.close().catch((error) => {
  384. log.error("Failed to close MCP client", { name, error })
  385. })
  386. delete s.clients[name]
  387. }
  388. s.status[name] = { status: "disabled" }
  389. }
  390. export async function tools() {
  391. const result: Record<string, Tool> = {}
  392. const s = await state()
  393. const clientsSnapshot = await clients()
  394. for (const [clientName, client] of Object.entries(clientsSnapshot)) {
  395. // Only include tools from connected MCPs (skip disabled ones)
  396. if (s.status[clientName]?.status !== "connected") {
  397. continue
  398. }
  399. const toolsResult = await client.listTools().catch((e) => {
  400. log.error("failed to get tools", { clientName, error: e.message })
  401. const failedStatus = {
  402. status: "failed" as const,
  403. error: e instanceof Error ? e.message : String(e),
  404. }
  405. s.status[clientName] = failedStatus
  406. delete s.clients[clientName]
  407. return undefined
  408. })
  409. if (!toolsResult) {
  410. continue
  411. }
  412. for (const mcpTool of toolsResult.tools) {
  413. const sanitizedClientName = clientName.replace(/[^a-zA-Z0-9_-]/g, "_")
  414. const sanitizedToolName = mcpTool.name.replace(/[^a-zA-Z0-9_-]/g, "_")
  415. result[sanitizedClientName + "_" + sanitizedToolName] = convertMcpTool(mcpTool, client)
  416. }
  417. }
  418. return result
  419. }
  420. /**
  421. * Start OAuth authentication flow for an MCP server.
  422. * Returns the authorization URL that should be opened in a browser.
  423. */
  424. export async function startAuth(mcpName: string): Promise<{ authorizationUrl: string }> {
  425. const cfg = await Config.get()
  426. const mcpConfig = cfg.mcp?.[mcpName]
  427. if (!mcpConfig) {
  428. throw new Error(`MCP server not found: ${mcpName}`)
  429. }
  430. if (mcpConfig.type !== "remote") {
  431. throw new Error(`MCP server ${mcpName} is not a remote server`)
  432. }
  433. if (mcpConfig.oauth === false) {
  434. throw new Error(`MCP server ${mcpName} has OAuth explicitly disabled`)
  435. }
  436. // Start the callback server
  437. await McpOAuthCallback.ensureRunning()
  438. // Generate and store a cryptographically secure state parameter BEFORE creating the provider
  439. // The SDK will call provider.state() to read this value
  440. const oauthState = Array.from(crypto.getRandomValues(new Uint8Array(32)))
  441. .map((b) => b.toString(16).padStart(2, "0"))
  442. .join("")
  443. await McpAuth.updateOAuthState(mcpName, oauthState)
  444. // Create a new auth provider for this flow
  445. // OAuth config is optional - if not provided, we'll use auto-discovery
  446. const oauthConfig = typeof mcpConfig.oauth === "object" ? mcpConfig.oauth : undefined
  447. let capturedUrl: URL | undefined
  448. const authProvider = new McpOAuthProvider(
  449. mcpName,
  450. mcpConfig.url,
  451. {
  452. clientId: oauthConfig?.clientId,
  453. clientSecret: oauthConfig?.clientSecret,
  454. scope: oauthConfig?.scope,
  455. },
  456. {
  457. onRedirect: async (url) => {
  458. capturedUrl = url
  459. },
  460. },
  461. )
  462. // Create transport with auth provider
  463. const transport = new StreamableHTTPClientTransport(new URL(mcpConfig.url), {
  464. authProvider,
  465. })
  466. // Try to connect - this will trigger the OAuth flow
  467. try {
  468. const client = new Client({
  469. name: "opencode",
  470. version: Installation.VERSION,
  471. })
  472. await client.connect(transport)
  473. // If we get here, we're already authenticated
  474. return { authorizationUrl: "" }
  475. } catch (error) {
  476. if (error instanceof UnauthorizedError && capturedUrl) {
  477. // Store transport for finishAuth
  478. pendingOAuthTransports.set(mcpName, transport)
  479. return { authorizationUrl: capturedUrl.toString() }
  480. }
  481. throw error
  482. }
  483. }
  484. /**
  485. * Complete OAuth authentication after user authorizes in browser.
  486. * Opens the browser and waits for callback.
  487. */
  488. export async function authenticate(mcpName: string): Promise<Status> {
  489. const { authorizationUrl } = await startAuth(mcpName)
  490. if (!authorizationUrl) {
  491. // Already authenticated
  492. const s = await state()
  493. return s.status[mcpName] ?? { status: "connected" }
  494. }
  495. // Get the state that was already generated and stored in startAuth()
  496. const oauthState = await McpAuth.getOAuthState(mcpName)
  497. if (!oauthState) {
  498. throw new Error("OAuth state not found - this should not happen")
  499. }
  500. // The SDK has already added the state parameter to the authorization URL
  501. // We just need to open the browser
  502. log.info("opening browser for oauth", { mcpName, url: authorizationUrl, state: oauthState })
  503. await open(authorizationUrl)
  504. // Wait for callback using the OAuth state parameter
  505. const code = await McpOAuthCallback.waitForCallback(oauthState)
  506. // Validate and clear the state
  507. const storedState = await McpAuth.getOAuthState(mcpName)
  508. if (storedState !== oauthState) {
  509. await McpAuth.clearOAuthState(mcpName)
  510. throw new Error("OAuth state mismatch - potential CSRF attack")
  511. }
  512. await McpAuth.clearOAuthState(mcpName)
  513. // Finish auth
  514. return finishAuth(mcpName, code)
  515. }
  516. /**
  517. * Complete OAuth authentication with the authorization code.
  518. */
  519. export async function finishAuth(mcpName: string, authorizationCode: string): Promise<Status> {
  520. const transport = pendingOAuthTransports.get(mcpName)
  521. if (!transport) {
  522. throw new Error(`No pending OAuth flow for MCP server: ${mcpName}`)
  523. }
  524. try {
  525. // Call finishAuth on the transport
  526. await transport.finishAuth(authorizationCode)
  527. // Clear the code verifier after successful auth
  528. await McpAuth.clearCodeVerifier(mcpName)
  529. // Now try to reconnect
  530. const cfg = await Config.get()
  531. const mcpConfig = cfg.mcp?.[mcpName]
  532. if (!mcpConfig) {
  533. throw new Error(`MCP server not found: ${mcpName}`)
  534. }
  535. // Re-add the MCP server to establish connection
  536. pendingOAuthTransports.delete(mcpName)
  537. const result = await add(mcpName, mcpConfig)
  538. const statusRecord = result.status as Record<string, Status>
  539. return statusRecord[mcpName] ?? { status: "failed", error: "Unknown error after auth" }
  540. } catch (error) {
  541. log.error("failed to finish oauth", { mcpName, error })
  542. return {
  543. status: "failed",
  544. error: error instanceof Error ? error.message : String(error),
  545. }
  546. }
  547. }
  548. /**
  549. * Remove OAuth credentials for an MCP server.
  550. */
  551. export async function removeAuth(mcpName: string): Promise<void> {
  552. await McpAuth.remove(mcpName)
  553. McpOAuthCallback.cancelPending(mcpName)
  554. pendingOAuthTransports.delete(mcpName)
  555. await McpAuth.clearOAuthState(mcpName)
  556. log.info("removed oauth credentials", { mcpName })
  557. }
  558. /**
  559. * Check if an MCP server supports OAuth (remote servers support OAuth by default unless explicitly disabled).
  560. */
  561. export async function supportsOAuth(mcpName: string): Promise<boolean> {
  562. const cfg = await Config.get()
  563. const mcpConfig = cfg.mcp?.[mcpName]
  564. return mcpConfig?.type === "remote" && mcpConfig.oauth !== false
  565. }
  566. /**
  567. * Check if an MCP server has stored OAuth tokens.
  568. */
  569. export async function hasStoredTokens(mcpName: string): Promise<boolean> {
  570. const entry = await McpAuth.get(mcpName)
  571. return !!entry?.tokens
  572. }
  573. export type AuthStatus = "authenticated" | "expired" | "not_authenticated"
  574. /**
  575. * Get the authentication status for an MCP server.
  576. */
  577. export async function getAuthStatus(mcpName: string): Promise<AuthStatus> {
  578. const hasTokens = await hasStoredTokens(mcpName)
  579. if (!hasTokens) return "not_authenticated"
  580. const expired = await McpAuth.isTokenExpired(mcpName)
  581. return expired ? "expired" : "authenticated"
  582. }
  583. }