Explorar el Código

feat: ACP - stream bash output and synthetic pending events (#14079)

Co-authored-by: Aiden Cline <[email protected]>
Noam Bressler hace 1 mes
padre
commit
888b123387

+ 88 - 35
packages/opencode/src/acp/agent.ts

@@ -41,7 +41,7 @@ import { Config } from "@/config/config"
 import { Todo } from "@/session/todo"
 import { z } from "zod"
 import { LoadAPIKeyError } from "ai"
-import type { AssistantMessage, Event, OpencodeClient, SessionMessageResponse } from "@opencode-ai/sdk/v2"
+import type { AssistantMessage, Event, OpencodeClient, SessionMessageResponse, ToolPart } from "@opencode-ai/sdk/v2"
 import { applyPatch } from "diff"
 
 type ModeOption = { id: string; name: string; description?: string }
@@ -135,6 +135,8 @@ export namespace ACP {
     private sessionManager: ACPSessionManager
     private eventAbort = new AbortController()
     private eventStarted = false
+    private bashSnapshots = new Map<string, string>()
+    private toolStarts = new Set<string>()
     private permissionQueues = new Map<string, Promise<void>>()
     private permissionOptions: PermissionOption[] = [
       { optionId: "once", kind: "allow_once", name: "Allow once" },
@@ -266,47 +268,68 @@ export namespace ACP {
           const session = this.sessionManager.tryGet(part.sessionID)
           if (!session) return
           const sessionId = session.id
-          const directory = session.cwd
-
-          const message = await this.sdk.session
-            .message(
-              {
-                sessionID: part.sessionID,
-                messageID: part.messageID,
-                directory,
-              },
-              { throwOnError: true },
-            )
-            .then((x) => x.data)
-            .catch((error) => {
-              log.error("unexpected error when fetching message", { error })
-              return undefined
-            })
-
-          if (!message || message.info.role !== "assistant") return
 
           if (part.type === "tool") {
+            if (!this.toolStarts.has(part.callID)) {
+              this.toolStarts.add(part.callID)
+              await this.connection
+                .sessionUpdate({
+                  sessionId,
+                  update: {
+                    sessionUpdate: "tool_call",
+                    toolCallId: part.callID,
+                    title: part.tool,
+                    kind: toToolKind(part.tool),
+                    status: "pending",
+                    locations: [],
+                    rawInput: {},
+                  },
+                })
+                .catch((error) => {
+                  log.error("failed to send tool pending to ACP", { error })
+                })
+            }
+
             switch (part.state.status) {
               case "pending":
-                await this.connection
-                  .sessionUpdate({
-                    sessionId,
-                    update: {
-                      sessionUpdate: "tool_call",
-                      toolCallId: part.callID,
-                      title: part.tool,
-                      kind: toToolKind(part.tool),
-                      status: "pending",
-                      locations: [],
-                      rawInput: {},
-                    },
-                  })
-                  .catch((error) => {
-                    log.error("failed to send tool pending to ACP", { error })
-                  })
+                this.bashSnapshots.delete(part.callID)
                 return
 
               case "running":
+                const output = this.bashOutput(part)
+                const content: ToolCallContent[] = []
+                if (output) {
+                  const hash = String(Bun.hash(output))
+                  if (part.tool === "bash") {
+                    if (this.bashSnapshots.get(part.callID) === hash) {
+                      await this.connection
+                        .sessionUpdate({
+                          sessionId,
+                          update: {
+                            sessionUpdate: "tool_call_update",
+                            toolCallId: part.callID,
+                            status: "in_progress",
+                            kind: toToolKind(part.tool),
+                            title: part.tool,
+                            locations: toLocations(part.tool, part.state.input),
+                            rawInput: part.state.input,
+                          },
+                        })
+                        .catch((error) => {
+                          log.error("failed to send tool in_progress to ACP", { error })
+                        })
+                      return
+                    }
+                    this.bashSnapshots.set(part.callID, hash)
+                  }
+                  content.push({
+                    type: "content",
+                    content: {
+                      type: "text",
+                      text: output,
+                    },
+                  })
+                }
                 await this.connection
                   .sessionUpdate({
                     sessionId,
@@ -318,6 +341,7 @@ export namespace ACP {
                       title: part.tool,
                       locations: toLocations(part.tool, part.state.input),
                       rawInput: part.state.input,
+                      ...(content.length > 0 && { content }),
                     },
                   })
                   .catch((error) => {
@@ -326,6 +350,8 @@ export namespace ACP {
                 return
 
               case "completed": {
+                this.toolStarts.delete(part.callID)
+                this.bashSnapshots.delete(part.callID)
                 const kind = toToolKind(part.tool)
                 const content: ToolCallContent[] = [
                   {
@@ -405,6 +431,8 @@ export namespace ACP {
                 return
               }
               case "error":
+                this.toolStarts.delete(part.callID)
+                this.bashSnapshots.delete(part.callID)
                 await this.connection
                   .sessionUpdate({
                     sessionId,
@@ -426,6 +454,7 @@ export namespace ACP {
                       ],
                       rawOutput: {
                         error: part.state.error,
+                        metadata: part.state.metadata,
                       },
                     },
                   })
@@ -802,6 +831,7 @@ export namespace ACP {
         if (part.type === "tool") {
           switch (part.state.status) {
             case "pending":
+              this.bashSnapshots.delete(part.callID)
               await this.connection
                 .sessionUpdate({
                   sessionId,
@@ -820,6 +850,17 @@ export namespace ACP {
                 })
               break
             case "running":
+              const output = this.bashOutput(part)
+              const runningContent: ToolCallContent[] = []
+              if (output) {
+                runningContent.push({
+                  type: "content",
+                  content: {
+                    type: "text",
+                    text: output,
+                  },
+                })
+              }
               await this.connection
                 .sessionUpdate({
                   sessionId,
@@ -831,6 +872,7 @@ export namespace ACP {
                     title: part.tool,
                     locations: toLocations(part.tool, part.state.input),
                     rawInput: part.state.input,
+                    ...(runningContent.length > 0 && { content: runningContent }),
                   },
                 })
                 .catch((err) => {
@@ -838,6 +880,7 @@ export namespace ACP {
                 })
               break
             case "completed":
+              this.bashSnapshots.delete(part.callID)
               const kind = toToolKind(part.tool)
               const content: ToolCallContent[] = [
                 {
@@ -916,6 +959,7 @@ export namespace ACP {
                 })
               break
             case "error":
+              this.bashSnapshots.delete(part.callID)
               await this.connection
                 .sessionUpdate({
                   sessionId,
@@ -937,6 +981,7 @@ export namespace ACP {
                     ],
                     rawOutput: {
                       error: part.state.error,
+                      metadata: part.state.metadata,
                     },
                   },
                 })
@@ -1063,6 +1108,14 @@ export namespace ACP {
       }
     }
 
+    private bashOutput(part: ToolPart) {
+      if (part.tool !== "bash") return
+      if (!("metadata" in part.state) || !part.state.metadata || typeof part.state.metadata !== "object") return
+      const output = part.state.metadata["output"]
+      if (typeof output !== "string") return
+      return output
+    }
+
     private async loadAvailableModes(directory: string): Promise<ModeOption[]> {
       const agents = await this.config.sdk.app
         .agents(

+ 156 - 2
packages/opencode/test/acp/event-subscription.test.ts

@@ -1,7 +1,7 @@
 import { describe, expect, test } from "bun:test"
 import { ACP } from "../../src/acp/agent"
 import type { AgentSideConnection } from "@agentclientprotocol/sdk"
-import type { Event } from "@opencode-ai/sdk/v2"
+import type { Event, EventMessagePartUpdated, ToolStatePending, ToolStateRunning } from "@opencode-ai/sdk/v2"
 import { Instance } from "../../src/project/instance"
 import { tmpdir } from "../fixture/fixture"
 
@@ -19,6 +19,61 @@ type EventController = {
   close: () => void
 }
 
+function inProgressText(update: SessionUpdateParams["update"]) {
+  if (update.sessionUpdate !== "tool_call_update") return undefined
+  if (update.status !== "in_progress") return undefined
+  if (!update.content || !Array.isArray(update.content)) return undefined
+  const first = update.content[0]
+  if (!first || first.type !== "content") return undefined
+  if (first.content.type !== "text") return undefined
+  return first.content.text
+}
+
+function isToolCallUpdate(
+  update: SessionUpdateParams["update"],
+): update is Extract<SessionUpdateParams["update"], { sessionUpdate: "tool_call_update" }> {
+  return update.sessionUpdate === "tool_call_update"
+}
+
+function toolEvent(
+  sessionId: string,
+  cwd: string,
+  opts: {
+    callID: string
+    tool: string
+    input: Record<string, unknown>
+  } & ({ status: "running"; metadata?: Record<string, unknown> } | { status: "pending"; raw: string }),
+): GlobalEventEnvelope {
+  const state: ToolStatePending | ToolStateRunning =
+    opts.status === "running"
+      ? {
+          status: "running",
+          input: opts.input,
+          ...(opts.metadata && { metadata: opts.metadata }),
+          time: { start: Date.now() },
+        }
+      : {
+          status: "pending",
+          input: opts.input,
+          raw: opts.raw,
+        }
+  const payload: EventMessagePartUpdated = {
+    type: "message.part.updated",
+    properties: {
+      part: {
+        id: `part_${opts.callID}`,
+        sessionID: sessionId,
+        messageID: `msg_${opts.callID}`,
+        type: "tool",
+        callID: opts.callID,
+        tool: opts.tool,
+        state,
+      },
+    },
+  }
+  return { directory: cwd, payload }
+}
+
 function createEventStream() {
   const queue: GlobalEventEnvelope[] = []
   const waiters: Array<(value: GlobalEventEnvelope | undefined) => void> = []
@@ -65,6 +120,7 @@ function createEventStream() {
 function createFakeAgent() {
   const updates = new Map<string, string[]>()
   const chunks = new Map<string, string>()
+  const sessionUpdates: SessionUpdateParams[] = []
   const record = (sessionId: string, type: string) => {
     const list = updates.get(sessionId) ?? []
     list.push(type)
@@ -73,6 +129,7 @@ function createFakeAgent() {
 
   const connection = {
     async sessionUpdate(params: SessionUpdateParams) {
+      sessionUpdates.push(params)
       const update = params.update
       const type = update?.sessionUpdate ?? "unknown"
       record(params.sessionId, type)
@@ -197,7 +254,7 @@ function createFakeAgent() {
     ;(agent as any).eventAbort.abort()
   }
 
-  return { agent, controller, calls, updates, chunks, stop, sdk, connection }
+  return { agent, controller, calls, updates, chunks, sessionUpdates, stop, sdk, connection }
 }
 
 describe("acp.agent event subscription", () => {
@@ -435,4 +492,101 @@ describe("acp.agent event subscription", () => {
       },
     })
   })
+
+  test("streams running bash output snapshots and de-dupes identical snapshots", async () => {
+    await using tmp = await tmpdir()
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        const { agent, controller, sessionUpdates, stop } = createFakeAgent()
+        const cwd = "/tmp/opencode-acp-test"
+        const sessionId = await agent.newSession({ cwd, mcpServers: [] } as any).then((x) => x.sessionId)
+        const input = { command: "echo hello", description: "run command" }
+
+        for (const output of ["a", "a", "ab"]) {
+          controller.push(
+            toolEvent(sessionId, cwd, { callID: "call_1", tool: "bash", status: "running", input, metadata: { output } }),
+          )
+        }
+        await new Promise((r) => setTimeout(r, 20))
+
+        const snapshots = sessionUpdates
+          .filter((u) => u.sessionId === sessionId)
+          .filter((u) => isToolCallUpdate(u.update))
+          .map((u) => inProgressText(u.update))
+
+        expect(snapshots).toEqual(["a", undefined, "ab"])
+        stop()
+      },
+    })
+  })
+
+  test("emits synthetic pending before first running update for any tool", async () => {
+    await using tmp = await tmpdir()
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        const { agent, controller, sessionUpdates, stop } = createFakeAgent()
+        const cwd = "/tmp/opencode-acp-test"
+        const sessionId = await agent.newSession({ cwd, mcpServers: [] } as any).then((x) => x.sessionId)
+
+        controller.push(
+          toolEvent(sessionId, cwd, {
+            callID: "call_bash",
+            tool: "bash",
+            status: "running",
+            input: { command: "echo hi", description: "run command" },
+            metadata: { output: "hi\n" },
+          }),
+        )
+        controller.push(
+          toolEvent(sessionId, cwd, {
+            callID: "call_read",
+            tool: "read",
+            status: "running",
+            input: { filePath: "/tmp/example.txt" },
+          }),
+        )
+        await new Promise((r) => setTimeout(r, 20))
+
+        const types = sessionUpdates
+          .filter((u) => u.sessionId === sessionId)
+          .map((u) => u.update.sessionUpdate)
+          .filter((u) => u === "tool_call" || u === "tool_call_update")
+        expect(types).toEqual(["tool_call", "tool_call_update", "tool_call", "tool_call_update"])
+
+        const pendings = sessionUpdates.filter(
+          (u) => u.sessionId === sessionId && u.update.sessionUpdate === "tool_call",
+        )
+        expect(pendings.every((p) => p.update.sessionUpdate === "tool_call" && p.update.status === "pending")).toBe(true)
+        stop()
+      },
+    })
+  })
+
+  test("clears bash snapshot marker on pending state", async () => {
+    await using tmp = await tmpdir()
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        const { agent, controller, sessionUpdates, stop } = createFakeAgent()
+        const cwd = "/tmp/opencode-acp-test"
+        const sessionId = await agent.newSession({ cwd, mcpServers: [] } as any).then((x) => x.sessionId)
+        const input = { command: "echo hello", description: "run command" }
+
+        controller.push(toolEvent(sessionId, cwd, { callID: "call_1", tool: "bash", status: "running", input, metadata: { output: "a" } }))
+        controller.push(toolEvent(sessionId, cwd, { callID: "call_1", tool: "bash", status: "pending", input, raw: '{"command":"echo hello"}' }))
+        controller.push(toolEvent(sessionId, cwd, { callID: "call_1", tool: "bash", status: "running", input, metadata: { output: "a" } }))
+        await new Promise((r) => setTimeout(r, 20))
+
+        const snapshots = sessionUpdates
+          .filter((u) => u.sessionId === sessionId)
+          .filter((u) => isToolCallUpdate(u.update))
+          .map((u) => inProgressText(u.update))
+
+        expect(snapshots).toEqual(["a", "a"])
+        stop()
+      },
+    })
+  })
 })