index.ts 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import { experimental_createMCPClient, type Tool } from "ai"
  2. import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"
  3. import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"
  4. import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"
  5. import { App } from "../app/app"
  6. import { Config } from "../config/config"
  7. import { Log } from "../util/log"
  8. import { NamedError } from "../util/error"
  9. import { z } from "zod"
  10. import { Session } from "../session"
  11. import { Bus } from "../bus"
  12. export namespace MCP {
  13. const log = Log.create({ service: "mcp" })
  14. export const Failed = NamedError.create(
  15. "MCPFailed",
  16. z.object({
  17. name: z.string(),
  18. }),
  19. )
  20. const state = App.state(
  21. "mcp",
  22. async () => {
  23. const cfg = await Config.get()
  24. const clients: {
  25. [name: string]: Awaited<ReturnType<typeof experimental_createMCPClient>>
  26. } = {}
  27. for (const [key, mcp] of Object.entries(cfg.mcp ?? {})) {
  28. if (mcp.enabled === false) {
  29. log.info("mcp server disabled", { key })
  30. continue
  31. }
  32. log.info("found", { key, type: mcp.type })
  33. if (mcp.type === "remote") {
  34. const transports = [
  35. {
  36. name: "StreamableHTTP",
  37. transport: new StreamableHTTPClientTransport(new URL(mcp.url), {
  38. requestInit: {
  39. headers: mcp.headers,
  40. },
  41. }),
  42. },
  43. {
  44. name: "SSE",
  45. transport: new SSEClientTransport(new URL(mcp.url), {
  46. requestInit: {
  47. headers: mcp.headers,
  48. },
  49. }),
  50. },
  51. ]
  52. let lastError: Error | undefined
  53. for (const { name, transport } of transports) {
  54. const client = await experimental_createMCPClient({
  55. name: key,
  56. transport,
  57. }).catch((error) => {
  58. lastError = error instanceof Error ? error : new Error(String(error))
  59. log.debug("transport connection failed", {
  60. key,
  61. transport: name,
  62. url: mcp.url,
  63. error: lastError.message,
  64. })
  65. return null
  66. })
  67. if (client) {
  68. log.debug("transport connection succeeded", { key, transport: name })
  69. clients[key] = client
  70. break
  71. }
  72. }
  73. if (!clients[key]) {
  74. const errorMessage = lastError
  75. ? `MCP server ${key} failed to connect: ${lastError.message}`
  76. : `MCP server ${key} failed to connect to ${mcp.url}`
  77. log.error("remote mcp connection failed", { key, url: mcp.url, error: lastError?.message })
  78. Bus.publish(Session.Event.Error, {
  79. error: {
  80. name: "UnknownError",
  81. data: {
  82. message: errorMessage,
  83. },
  84. },
  85. })
  86. }
  87. }
  88. if (mcp.type === "local") {
  89. const [cmd, ...args] = mcp.command
  90. const client = await experimental_createMCPClient({
  91. name: key,
  92. transport: new StdioClientTransport({
  93. stderr: "ignore",
  94. command: cmd,
  95. args,
  96. env: {
  97. ...process.env,
  98. ...(cmd === "opencode" ? { BUN_BE_BUN: "1" } : {}),
  99. ...mcp.environment,
  100. },
  101. }),
  102. }).catch((error) => {
  103. const errorMessage =
  104. error instanceof Error
  105. ? `MCP server ${key} failed to start: ${error.message}`
  106. : `MCP server ${key} failed to start`
  107. log.error("local mcp startup failed", {
  108. key,
  109. command: mcp.command,
  110. error: error instanceof Error ? error.message : String(error),
  111. })
  112. Bus.publish(Session.Event.Error, {
  113. error: {
  114. name: "UnknownError",
  115. data: {
  116. message: errorMessage,
  117. },
  118. },
  119. })
  120. return null
  121. })
  122. if (client) {
  123. clients[key] = client
  124. }
  125. }
  126. }
  127. return {
  128. clients,
  129. }
  130. },
  131. async (state) => {
  132. for (const client of Object.values(state.clients)) {
  133. client.close()
  134. }
  135. },
  136. )
  137. export async function clients() {
  138. return state().then((state) => state.clients)
  139. }
  140. export async function tools() {
  141. const result: Record<string, Tool> = {}
  142. for (const [clientName, client] of Object.entries(await clients())) {
  143. for (const [toolName, tool] of Object.entries(await client.tools())) {
  144. const sanitizedClientName = clientName.replace(/\s+/g, "_")
  145. const sanitizedToolName = toolName.replace(/[-\s]+/g, "_")
  146. result[sanitizedClientName + "_" + sanitizedToolName] = tool
  147. }
  148. }
  149. return result
  150. }
  151. }