Просмотр исходного кода

feat (acp): mcp server support, file diffs, some default slash commands (/init, /compact), show todos properly (#3490)

The mcp server support does not mean acp didn't allow u to use mcp servers previously, it means that now you can connect new servers via ACP instead of relying on the opencode defined ones
Aiden Cline 3 месяцев назад
Родитель
Сommit
982954cc1b

+ 178 - 60
packages/opencode/src/acp/agent.ts

@@ -1,16 +1,20 @@
-import type {
-  Agent as ACPAgent,
-  AgentSideConnection,
-  AuthenticateRequest,
-  CancelNotification,
-  InitializeRequest,
-  LoadSessionRequest,
-  NewSessionRequest,
-  PermissionOption,
-  PromptRequest,
-  SetSessionModelRequest,
-  SetSessionModeRequest,
-  SetSessionModeResponse,
+import {
+  sessionModeSchema,
+  type Agent as ACPAgent,
+  type AgentSideConnection,
+  type AuthenticateRequest,
+  type CancelNotification,
+  type InitializeRequest,
+  type LoadSessionRequest,
+  type NewSessionRequest,
+  type PermissionOption,
+  type PlanEntry,
+  type PromptRequest,
+  type SetSessionModelRequest,
+  type SetSessionModeRequest,
+  type SetSessionModeResponse,
+  type ToolCallContent,
+  type ToolKind,
 } from "@agentclientprotocol/sdk"
 import { Log } from "../util/log"
 import { ACPSessionManager } from "./session"
@@ -25,24 +29,17 @@ import { Storage } from "@/storage/storage"
 import { Command } from "@/command"
 import { Agent as Agents } from "@/agent/agent"
 import { Permission } from "@/permission"
+import { Session } from "@/session"
+import { Identifier } from "@/id/id"
+import { SessionCompaction } from "@/session/compaction"
+import type { Config } from "@/config/config"
+import { MCP } from "@/mcp"
+import { Todo } from "@/session/todo"
+import { z } from "zod"
 
 export namespace ACP {
   const log = Log.create({ service: "acp-agent" })
 
-  // TODO: mcp servers?
-
-  type ToolKind =
-    | "read"
-    | "edit"
-    | "delete"
-    | "move"
-    | "search"
-    | "execute"
-    | "think"
-    | "fetch"
-    | "switch_mode"
-    | "other"
-
   export class Agent implements ACPAgent {
     private sessionManager = new ACPSessionManager()
     private connection: AgentSideConnection
@@ -157,6 +154,62 @@ export namespace ACP {
                 })
               break
             case "completed":
+              const kind = toToolKind(part.tool)
+              const content: ToolCallContent[] = [
+                {
+                  type: "content",
+                  content: {
+                    type: "text",
+                    text: part.state.output,
+                  },
+                },
+              ]
+
+              if (kind === "edit") {
+                const input = part.state.input
+                const filePath = typeof input["filePath"] === "string" ? input["filePath"] : ""
+                const oldText = typeof input["oldString"] === "string" ? input["oldString"] : ""
+                const newText =
+                  typeof input["newString"] === "string"
+                    ? input["newString"]
+                    : typeof input["content"] === "string"
+                      ? input["content"]
+                      : ""
+                content.push({
+                  type: "diff",
+                  path: filePath,
+                  oldText,
+                  newText,
+                })
+              }
+
+              if (part.tool === "todowrite") {
+                const parsedTodos = z.array(Todo.Info).safeParse(JSON.parse(part.state.output))
+                if (parsedTodos.success) {
+                  await this.connection
+                    .sessionUpdate({
+                      sessionId: acpSession.id,
+                      update: {
+                        sessionUpdate: "plan",
+                        entries: parsedTodos.data.map((todo) => {
+                          const status: PlanEntry["status"] =
+                            todo.status === "cancelled" ? "completed" : (todo.status as PlanEntry["status"])
+                          return {
+                            priority: "medium",
+                            status,
+                            content: todo.content,
+                          }
+                        }),
+                      },
+                    })
+                    .catch((err) => {
+                      log.error("failed to send session update for todo", { error: err })
+                    })
+                } else {
+                  log.error("failed to parse todo output", { error: parsedTodos.error })
+                }
+              }
+
               await this.connection
                 .sessionUpdate({
                   sessionId: acpSession.id,
@@ -164,15 +217,8 @@ export namespace ACP {
                     sessionUpdate: "tool_call_update",
                     toolCallId: part.callID,
                     status: "completed",
-                    content: [
-                      {
-                        type: "content",
-                        content: {
-                          type: "text",
-                          text: part.state.output,
-                        },
-                      },
-                    ],
+                    kind,
+                    content,
                     title: part.state.title,
                     rawOutput: {
                       output: part.state.output,
@@ -258,11 +304,14 @@ export namespace ACP {
         protocolVersion: 1,
         agentCapabilities: {
           loadSession: true,
-          // TODO: map acp mcp
-          // mcpCapabilities: {
-          //   http: true,
-          //   sse: true,
-          // },
+          mcpCapabilities: {
+            http: true,
+            sse: true,
+          },
+          promptCapabilities: {
+            embeddedContext: true,
+            image: true,
+          },
         },
         authMethods: [
           {
@@ -287,6 +336,7 @@ export namespace ACP {
       const model = await defaultModel(this.config)
       const session = await this.sessionManager.create(params.cwd, params.mcpServers, model)
 
+      log.info("creating_session", { mcpServers: params.mcpServers.length })
       const load = await this.loadSession({
         cwd: params.cwd,
         mcpServers: params.mcpServers,
@@ -325,6 +375,17 @@ export namespace ACP {
         name: command.name,
         description: command.description ?? "",
       }))
+      const names = new Set(availableCommands.map((c) => c.name))
+      if (!names.has("init"))
+        availableCommands.push({
+          name: "init",
+          description: "create/update a AGENTS.md",
+        })
+      if (!names.has("compact"))
+        availableCommands.push({
+          name: "compact",
+          description: "compact the session",
+        })
 
       setTimeout(() => {
         this.connection.sessionUpdate({
@@ -346,6 +407,35 @@ export namespace ACP {
 
       const currentModeId = availableModes.find((m) => m.name === "build")?.id ?? availableModes[0].id
 
+      const mcpServers: Record<string, Config.Mcp> = {}
+      for (const server of params.mcpServers) {
+        if ("type" in server) {
+          mcpServers[server.name] = {
+            url: server.url,
+            headers: server.headers.reduce<Record<string, string>>((acc, { name, value }) => {
+              acc[name] = value
+              return acc
+            }, {}),
+            type: "remote",
+          }
+        } else {
+          mcpServers[server.name] = {
+            type: "local",
+            command: [server.command, ...server.args],
+            environment: server.env.reduce<Record<string, string>>((acc, { name, value }) => {
+              acc[name] = value
+              return acc
+            }, {}),
+          }
+        }
+      }
+
+      await Promise.all(
+        Object.entries(mcpServers).map(async ([key, mcp]) => {
+          await MCP.add(key, mcp)
+        }),
+      )
+
       return {
         sessionId,
         models: {
@@ -452,25 +542,25 @@ export namespace ACP {
 
       log.info("parts", { parts })
 
-      const cmd = await (async () => {
-        const text = parts.filter((part) => part.type === "text").join("")
-        const match = text.match(/^\/(\w+)\s*(.*)$/)
-        if (!match) return
+      const cmd = (() => {
+        const text = parts
+          .filter((p) => p.type === "text")
+          .map((p) => p.text)
+          .join("")
+          .trim()
 
-        const [c, args] = match.slice(1)
-        const command = await Command.get(c)
-        if (!command) return
-        return { command, args }
+        if (!text.startsWith("/")) return
+
+        const [name, ...rest] = text.slice(1).split(/\s+/)
+        return { name, args: rest.join(" ").trim() }
       })()
 
-      if (cmd) {
-        await SessionPrompt.command({
-          sessionID,
-          command: cmd.command.name,
-          arguments: cmd.args,
-          agent,
-        })
-      } else {
+      const done = {
+        stopReason: "end_turn" as const,
+        _meta: {},
+      }
+
+      if (!cmd) {
         await SessionPrompt.prompt({
           sessionID,
           model: {
@@ -480,12 +570,40 @@ export namespace ACP {
           parts,
           agent,
         })
+        return done
       }
 
-      return {
-        stopReason: "end_turn" as const,
-        _meta: {},
+      const command = await Command.get(cmd.name)
+      if (command) {
+        await SessionPrompt.command({
+          sessionID,
+          command: command.name,
+          arguments: cmd.args,
+          model: model.providerID + "/" + model.modelID,
+          agent,
+        })
+        return done
       }
+
+      switch (cmd.name) {
+        case "init":
+          await Session.initialize({
+            sessionID,
+            messageID: Identifier.ascending("message"),
+            providerID: model.providerID,
+            modelID: model.modelID,
+          })
+          break
+        case "compact":
+          await SessionCompaction.run({
+            sessionID,
+            providerID: model.providerID,
+            modelID: model.modelID,
+          })
+          break
+      }
+
+      return done
     }
 
     async cancel(params: CancelNotification) {

+ 136 - 109
packages/opencode/src/mcp/index.ts

@@ -26,122 +26,22 @@ export namespace MCP {
   const state = Instance.state(
     async () => {
       const cfg = await Config.get()
+      const config = cfg.mcp ?? {}
       const clients: {
         [name: string]: MCPClient
       } = {}
-      for (const [key, mcp] of Object.entries(cfg.mcp ?? {})) {
-        if (mcp.enabled === false) {
-          log.info("mcp server disabled", { key })
-          continue
-        }
-        log.info("found", { key, type: mcp.type })
-        if (mcp.type === "remote") {
-          const transports = [
-            {
-              name: "StreamableHTTP",
-              transport: new StreamableHTTPClientTransport(new URL(mcp.url), {
-                requestInit: {
-                  headers: mcp.headers,
-                },
-              }),
-            },
-            {
-              name: "SSE",
-              transport: new SSEClientTransport(new URL(mcp.url), {
-                requestInit: {
-                  headers: mcp.headers,
-                },
-              }),
-            },
-          ]
-          let lastError: Error | undefined
-          for (const { name, transport } of transports) {
-            const client = await experimental_createMCPClient({
-              name: "opencode",
-              transport,
-            }).catch((error) => {
-              lastError = error instanceof Error ? error : new Error(String(error))
-              log.debug("transport connection failed", {
-                key,
-                transport: name,
-                url: mcp.url,
-                error: lastError.message,
-              })
-              return null
-            })
-            if (client) {
-              log.debug("transport connection succeeded", { key, transport: name })
-              clients[key] = client
-              break
-            }
-          }
-          if (!clients[key]) {
-            const errorMessage = lastError
-              ? `MCP server ${key} failed to connect: ${lastError.message}`
-              : `MCP server ${key} failed to connect to ${mcp.url}`
-            log.error("remote mcp connection failed", { key, url: mcp.url, error: lastError?.message })
-            Bus.publish(Session.Event.Error, {
-              error: {
-                name: "UnknownError",
-                data: {
-                  message: errorMessage,
-                },
-              },
-            })
-          }
-        }
-
-        if (mcp.type === "local") {
-          const [cmd, ...args] = mcp.command
-          const client = await experimental_createMCPClient({
-            name: "opencode",
-            transport: new StdioClientTransport({
-              stderr: "ignore",
-              command: cmd,
-              args,
-              env: {
-                ...process.env,
-                ...(cmd === "opencode" ? { BUN_BE_BUN: "1" } : {}),
-                ...mcp.environment,
-              },
-            }),
-          }).catch((error) => {
-            const errorMessage =
-              error instanceof Error
-                ? `MCP server ${key} failed to start: ${error.message}`
-                : `MCP server ${key} failed to start`
-            log.error("local mcp startup failed", {
-              key,
-              command: mcp.command,
-              error: error instanceof Error ? error.message : String(error),
-            })
-            Bus.publish(Session.Event.Error, {
-              error: {
-                name: "UnknownError",
-                data: {
-                  message: errorMessage,
-                },
-              },
-            })
-            return null
-          })
-          if (client) {
-            clients[key] = client
-          }
-        }
-      }
 
-      for (const [key, client] of Object.entries(clients)) {
-        const result = await withTimeout(client.tools(), 5000).catch(() => {})
-        if (!result) {
-          log.warn("mcp client verification failed, removing client", { key })
-          delete clients[key]
-        }
-      }
+      await Promise.all(
+        Object.entries(config).map(async ([key, mcp]) => {
+          const result = await create(key, mcp).catch(() => undefined)
+          if (!result) return
+          clients[key] = result.client
+        }),
+      )
 
       return {
         clients,
-        config: cfg.mcp ?? {},
+        config,
       }
     },
     async (state) => {
@@ -151,6 +51,133 @@ export namespace MCP {
     },
   )
 
+  export async function add(name: string, mcp: Config.Mcp) {
+    const s = await state()
+    const result = await create(name, mcp)
+    if (!result) return
+    s.clients[name] = result.client
+  }
+
+  async function create(name: string, mcp: Config.Mcp) {
+    if (mcp.enabled === false) {
+      log.info("mcp server disabled", { name })
+      return
+    }
+    log.info("found", { name, type: mcp.type })
+
+    let mcpClient: MCPClient | undefined
+
+    if (mcp.type === "remote") {
+      const transports = [
+        {
+          name: "StreamableHTTP",
+          transport: new StreamableHTTPClientTransport(new URL(mcp.url), {
+            requestInit: {
+              headers: mcp.headers,
+            },
+          }),
+        },
+        {
+          name: "SSE",
+          transport: new SSEClientTransport(new URL(mcp.url), {
+            requestInit: {
+              headers: mcp.headers,
+            },
+          }),
+        },
+      ]
+      let lastError: Error | undefined
+      for (const { name, transport } of transports) {
+        const client = await experimental_createMCPClient({
+          name: "opencode",
+          transport,
+        }).catch((error) => {
+          lastError = error instanceof Error ? error : new Error(String(error))
+          log.debug("transport connection failed", {
+            name,
+            transport: name,
+            url: mcp.url,
+            error: lastError.message,
+          })
+          return null
+        })
+        if (client) {
+          log.debug("transport connection succeeded", { name, transport: name })
+          mcpClient = client
+          break
+        }
+      }
+      if (!mcpClient) {
+        const errorMessage = lastError
+          ? `MCP server ${name} failed to connect: ${lastError.message}`
+          : `MCP server ${name} failed to connect to ${mcp.url}`
+        log.error("remote mcp connection failed", { name, url: mcp.url, error: lastError?.message })
+        Bus.publish(Session.Event.Error, {
+          error: {
+            name: "UnknownError",
+            data: {
+              message: errorMessage,
+            },
+          },
+        })
+      }
+    }
+
+    if (mcp.type === "local") {
+      const [cmd, ...args] = mcp.command
+      const client = await experimental_createMCPClient({
+        name: "opencode",
+        transport: new StdioClientTransport({
+          stderr: "ignore",
+          command: cmd,
+          args,
+          env: {
+            ...process.env,
+            ...(cmd === "opencode" ? { BUN_BE_BUN: "1" } : {}),
+            ...mcp.environment,
+          },
+        }),
+      }).catch((error) => {
+        const errorMessage =
+          error instanceof Error
+            ? `MCP server ${name} failed to start: ${error.message}`
+            : `MCP server ${name} failed to start`
+        log.error("local mcp startup failed", {
+          name,
+          command: mcp.command,
+          error: error instanceof Error ? error.message : String(error),
+        })
+        Bus.publish(Session.Event.Error, {
+          error: {
+            name: "UnknownError",
+            data: {
+              message: errorMessage,
+            },
+          },
+        })
+        return null
+      })
+      if (client) {
+        mcpClient = client
+      }
+    }
+
+    if (!mcpClient) {
+      log.warn("mcp client not initialized", { name })
+      return
+    }
+
+    const result = await withTimeout(mcpClient.tools(), 5000).catch(() => {})
+    if (!result) {
+      log.warn("mcp client verification failed, dropping client", { name })
+      return
+    }
+
+    return {
+      client: mcpClient,
+    }
+  }
+
   export async function status() {
     return state().then((state) => {
       const result: Record<string, "connected" | "failed" | "disabled"> = {}

+ 5 - 1
packages/opencode/src/session/compaction.ts

@@ -189,7 +189,11 @@ export namespace SessionCompaction {
             case "text-delta":
               part.text += value.text
               if (value.providerMetadata) part.metadata = value.providerMetadata
-              if (part.text) await Session.updatePart(part)
+              if (part.text)
+                await Session.updatePart({
+                  part,
+                  delta: value.text,
+                })
               continue
             case "text-end": {
               part.text = part.text.trimEnd()