index.ts 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import { experimental_createMCPClient, type Tool } from "ai"
  2. import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"
  3. import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"
  4. import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"
  5. import { Config } from "../config/config"
  6. import { Log } from "../util/log"
  7. import { NamedError } from "../util/error"
  8. import z from "zod/v4"
  9. import { Session } from "../session"
  10. import { Bus } from "../bus"
  11. import { Instance } from "../project/instance"
  /**
   * MCP (Model Context Protocol) client management.
   *
   * Lazily (via `Instance.state`) connects to every enabled server listed in the
   * config's `mcp` section, exposes the connected clients, and flattens all of
   * their tools into one map keyed by `<server>_<tool>`. Servers that fail to
   * connect are logged and reported on the bus; they never abort the others.
   */
  export namespace MCP {
    const log = Log.create({ service: "mcp" })

    export const Failed = NamedError.create(
      "MCPFailed",
      z.object({
        name: z.string(),
      }),
    )

    /** A connected client, as returned by `experimental_createMCPClient`. */
    type MCPClient = Awaited<ReturnType<typeof experimental_createMCPClient>>

    /** One entry of the config's `mcp` map (derived so it tracks the config schema). */
    type McpConfig = NonNullable<Awaited<ReturnType<typeof Config.get>>["mcp"]>[string]
    type RemoteConfig = Extract<McpConfig, { type: "remote" }>
    type LocalConfig = Extract<McpConfig, { type: "local" }>

    /** Human-readable message for an arbitrary thrown value. */
    function describeError(error: unknown): string {
      return error instanceof Error ? error.message : String(error)
    }

    /** Surface an MCP startup failure to the session as an `UnknownError` event. */
    function publishError(message: string) {
      Bus.publish(Session.Event.Error, {
        error: {
          name: "UnknownError",
          data: {
            message,
          },
        },
      })
    }

    /** Parse `raw` as a URL, returning `undefined` instead of throwing on malformed input. */
    function parseUrl(raw: string): URL | undefined {
      try {
        return new URL(raw)
      } catch {
        return undefined
      }
    }

    /**
     * Connect to a remote MCP server, trying Streamable HTTP first and then SSE.
     *
     * @returns the first client that connects, or `undefined` after logging and
     *   publishing an error when the URL is invalid or every transport fails.
     */
    async function connectRemote(key: string, mcp: RemoteConfig): Promise<MCPClient | undefined> {
      const url = parseUrl(mcp.url)
      if (!url) {
        // A malformed URL would otherwise throw out of the state initializer
        // and prevent every other configured server from loading.
        log.error("remote mcp connection failed", { key, url: mcp.url, error: "invalid URL" })
        publishError(`MCP server ${key} failed to connect: invalid URL ${mcp.url}`)
        return undefined
      }

      // Transports are built lazily so an unused fallback is never constructed
      // and a constructor failure is handled like a connection failure.
      const transports = [
        {
          name: "StreamableHTTP",
          create: () => new StreamableHTTPClientTransport(url, { requestInit: { headers: mcp.headers } }),
        },
        {
          name: "SSE",
          create: () => new SSEClientTransport(url, { requestInit: { headers: mcp.headers } }),
        },
      ]

      let lastError: Error | undefined
      for (const { name, create } of transports) {
        try {
          const client = await experimental_createMCPClient({
            name: "opencode",
            transport: create(),
          })
          log.debug("transport connection succeeded", { key, transport: name })
          return client
        } catch (error) {
          lastError = error instanceof Error ? error : new Error(String(error))
          log.debug("transport connection failed", {
            key,
            transport: name,
            url: mcp.url,
            error: lastError.message,
          })
        }
      }

      const message = lastError
        ? `MCP server ${key} failed to connect: ${lastError.message}`
        : `MCP server ${key} failed to connect to ${mcp.url}`
      log.error("remote mcp connection failed", { key, url: mcp.url, error: lastError?.message })
      publishError(message)
      return undefined
    }

    /**
     * Start a local MCP server as a child process speaking stdio.
     *
     * The child inherits `process.env`, overlaid with `BUN_BE_BUN=1` when the
     * command is `opencode`, then with the server's configured `environment`.
     *
     * @returns the connected client, or `undefined` after logging and publishing
     *   an error when the command is empty or the server fails to start.
     */
    async function connectLocal(key: string, mcp: LocalConfig): Promise<MCPClient | undefined> {
      const [cmd, ...args] = mcp.command
      if (!cmd) {
        log.error("local mcp startup failed", { key, command: mcp.command, error: "empty command" })
        publishError(`MCP server ${key} failed to start: empty command`)
        return undefined
      }
      try {
        return await experimental_createMCPClient({
          name: "opencode",
          transport: new StdioClientTransport({
            stderr: "ignore",
            command: cmd,
            args,
            env: {
              ...process.env,
              ...(cmd === "opencode" ? { BUN_BE_BUN: "1" } : {}),
              ...mcp.environment,
            },
          }),
        })
      } catch (error) {
        log.error("local mcp startup failed", {
          key,
          command: mcp.command,
          error: describeError(error),
        })
        publishError(
          error instanceof Error
            ? `MCP server ${key} failed to start: ${error.message}`
            : `MCP server ${key} failed to start`,
        )
        return undefined
      }
    }

    const state = Instance.state(
      async () => {
        const cfg = await Config.get()
        const clients: Record<string, MCPClient> = {}
        // Servers are connected one at a time, in config order.
        for (const [key, mcp] of Object.entries(cfg.mcp ?? {})) {
          if (mcp.enabled === false) {
            log.info("mcp server disabled", { key })
            continue
          }
          log.info("found", { key, type: mcp.type })
          let client: MCPClient | undefined
          if (mcp.type === "remote") client = await connectRemote(key, mcp)
          if (mcp.type === "local") client = await connectLocal(key, mcp)
          if (client) clients[key] = client
        }
        return {
          clients,
        }
      },
      async (state) => {
        // Await every close so teardown really finishes, and log failures
        // instead of leaving them as unhandled promise rejections.
        await Promise.all(
          Object.entries(state.clients).map(([key, client]) =>
            client.close().catch((error) => {
              log.error("failed to close mcp client", { key, error: describeError(error) })
            }),
          ),
        )
      },
    )

    /** Connected MCP clients keyed by their config name. */
    export async function clients() {
      const current = await state()
      return current.clients
    }

    /**
     * All tools from every connected client, keyed `<server>_<tool>`.
     *
     * Whitespace in the server name and runs of `-`/whitespace in the tool name
     * are replaced with `_`. On a key collision the later client/tool wins.
     * Rejects if any client's tool listing rejects.
     */
    export async function tools() {
      const result: Record<string, Tool> = {}
      // Fetch every client's tool list concurrently; results keep client order.
      const perClient = await Promise.all(
        Object.entries(await clients()).map(async ([clientName, client]) => ({
          prefix: clientName.replace(/\s+/g, "_"),
          toolSet: await client.tools(),
        })),
      )
      for (const { prefix, toolSet } of perClient) {
        for (const [toolName, tool] of Object.entries(toolSet)) {
          const sanitizedToolName = toolName.replace(/[-\s]+/g, "_")
          result[prefix + "_" + sanitizedToolName] = tool
        }
      }
      return result
    }
  }