2
0
Эх сурвалжийг харах

fix: propagate abort signal to inline read tool (#21584)

Kit Langton 1 долоо хоног өмнө
parent
commit
2bdd279467

+ 41 - 40
packages/opencode/src/session/prompt.ts

@@ -559,7 +559,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
       }) {
         const { task, model, lastUser, sessionID, session, msgs } = input
         const ctx = yield* InstanceState.context
-        const taskTool = yield* registry.fromID(TaskTool.id)
+        const { task: taskTool } = yield* registry.named()
         const taskModel = task.model ? yield* getModel(task.model.providerID, task.model.modelID, sessionID) : model
         const assistantMessage: MessageV2.Assistant = yield* sessions.updateMessage({
           id: MessageID.ascending(),
@@ -1080,6 +1080,21 @@ NOTE: At any point in time through this workflow you should feel free to ask the
                 const filepath = fileURLToPath(part.url)
                 if (yield* fsys.isDir(filepath)) part.mime = "application/x-directory"
 
+                const { read } = yield* registry.named()
+                const execRead = (args: Parameters<typeof read.execute>[0], extra?: Tool.Context["extra"]) =>
+                  Effect.promise((signal: AbortSignal) =>
+                    read.execute(args, {
+                      sessionID: input.sessionID,
+                      abort: signal,
+                      agent: input.agent!,
+                      messageID: info.id,
+                      extra: { bypassCwdCheck: true, ...extra },
+                      messages: [],
+                      metadata: async () => {},
+                      ask: async () => {},
+                    }),
+                  )
+
                 if (part.mime === "text/plain") {
                   let offset: number | undefined
                   let limit: number | undefined
@@ -1116,29 +1131,12 @@ NOTE: At any point in time through this workflow you should feel free to ask the
                       text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
                     },
                   ]
-                  const read = yield* registry.fromID("read").pipe(
-                    Effect.flatMap((t) =>
-                      provider.getModel(info.model.providerID, info.model.modelID).pipe(
-                        Effect.flatMap((mdl) =>
-                          Effect.promise(() =>
-                            t.execute(args, {
-                              sessionID: input.sessionID,
-                              abort: new AbortController().signal,
-                              agent: input.agent!,
-                              messageID: info.id,
-                              extra: { bypassCwdCheck: true, model: mdl },
-                              messages: [],
-                              metadata: async () => {},
-                              ask: async () => {},
-                            }),
-                          ),
-                        ),
-                      ),
-                    ),
+                  const exit = yield* provider.getModel(info.model.providerID, info.model.modelID).pipe(
+                    Effect.flatMap((mdl) => execRead(args, { model: mdl })),
                     Effect.exit,
                   )
-                  if (Exit.isSuccess(read)) {
-                    const result = read.value
+                  if (Exit.isSuccess(exit)) {
+                    const result = exit.value
                     pieces.push({
                       messageID: info.id,
                       sessionID: input.sessionID,
@@ -1160,7 +1158,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
                       pieces.push({ ...part, messageID: info.id, sessionID: input.sessionID })
                     }
                   } else {
-                    const error = Cause.squash(read.cause)
+                    const error = Cause.squash(exit.cause)
                     log.error("failed to read file", { error })
                     const message = error instanceof Error ? error.message : String(error)
                     yield* bus.publish(Session.Event.Error, {
@@ -1180,22 +1178,25 @@ NOTE: At any point in time through this workflow you should feel free to ask the
 
                 if (part.mime === "application/x-directory") {
                   const args = { filePath: filepath }
-                  const result = yield* registry.fromID("read").pipe(
-                    Effect.flatMap((t) =>
-                      Effect.promise(() =>
-                        t.execute(args, {
-                          sessionID: input.sessionID,
-                          abort: new AbortController().signal,
-                          agent: input.agent!,
-                          messageID: info.id,
-                          extra: { bypassCwdCheck: true },
-                          messages: [],
-                          metadata: async () => {},
-                          ask: async () => {},
-                        }),
-                      ),
-                    ),
-                  )
+                  const exit = yield* execRead(args).pipe(Effect.exit)
+                  if (Exit.isFailure(exit)) {
+                    const error = Cause.squash(exit.cause)
+                    log.error("failed to read directory", { error })
+                    const message = error instanceof Error ? error.message : String(error)
+                    yield* bus.publish(Session.Event.Error, {
+                      sessionID: input.sessionID,
+                      error: new NamedError.Unknown({ message }).toObject(),
+                    })
+                    return [
+                      {
+                        messageID: info.id,
+                        sessionID: input.sessionID,
+                        type: "text",
+                        synthetic: true,
+                        text: `Read tool failed to read ${filepath} with the following error: ${message}`,
+                      },
+                    ]
+                  }
                   return [
                     {
                       messageID: info.id,
@@ -1209,7 +1210,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
                       sessionID: input.sessionID,
                       type: "text",
                       synthetic: true,
-                      text: result.output,
+                      text: exit.value.output,
                     },
                     { ...part, messageID: info.id, sessionID: input.sessionID },
                   ]

+ 14 - 13
packages/opencode/src/tool/registry.ts

@@ -42,24 +42,25 @@ import { Agent } from "../agent/agent"
 export namespace ToolRegistry {
   const log = Log.create({ service: "tool.registry" })
 
+  type TaskDef = Tool.InferDef<typeof TaskTool>
+  type ReadDef = Tool.InferDef<typeof ReadTool>
+
   type State = {
     custom: Tool.Def[]
     builtin: Tool.Def[]
+    task: TaskDef
+    read: ReadDef
   }
 
   export interface Interface {
     readonly ids: () => Effect.Effect<string[]>
     readonly all: () => Effect.Effect<Tool.Def[]>
-    readonly named: {
-      task: Tool.Info
-      read: Tool.Info
-    }
+    readonly named: () => Effect.Effect<{ task: TaskDef; read: ReadDef }>
     readonly tools: (model: {
       providerID: ProviderID
       modelID: ModelID
       agent: Agent.Info
     }) => Effect.Effect<Tool.Def[]>
-    readonly fromID: (id: string) => Effect.Effect<Tool.Def>
   }
 
   export class Service extends ServiceMap.Service<Service, Interface>()("@opencode/ToolRegistry") {}
@@ -183,6 +184,8 @@ export namespace ToolRegistry {
               ...(Flag.OPENCODE_EXPERIMENTAL_LSP_TOOL ? [tool.lsp] : []),
               ...(Flag.OPENCODE_EXPERIMENTAL_PLAN_MODE && Flag.OPENCODE_CLIENT === "cli" ? [tool.plan] : []),
             ],
+            task: tool.task,
+            read: tool.read,
           }
         }),
       )
@@ -192,13 +195,6 @@ export namespace ToolRegistry {
         return [...s.builtin, ...s.custom] as Tool.Def[]
       })
 
-      const fromID: Interface["fromID"] = Effect.fn("ToolRegistry.fromID")(function* (id: string) {
-        const tools = yield* all()
-        const match = tools.find((tool) => tool.id === id)
-        if (!match) return yield* Effect.die(`Tool not found: ${id}`)
-        return match
-      })
-
       const ids: Interface["ids"] = Effect.fn("ToolRegistry.ids")(function* () {
         return (yield* all()).map((tool) => tool.id)
       })
@@ -245,7 +241,12 @@ export namespace ToolRegistry {
         )
       })
 
-      return Service.of({ ids, all, named: { task, read }, tools, fromID })
+      const named: Interface["named"] = Effect.fn("ToolRegistry.named")(function* () {
+        const s = yield* InstanceState.get(state)
+        return { task: s.task, read: s.read }
+      })
+
+      return Service.of({ ids, all, named, tools })
     }),
   )
 

+ 8 - 1
packages/opencode/src/tool/tool.ts

@@ -60,6 +60,13 @@ export namespace Tool {
   export type InferMetadata<T> =
     T extends Info<any, infer M> ? M : T extends Effect.Effect<Info<any, infer M>, any, any> ? M : never
 
+  export type InferDef<T> =
+    T extends Info<infer P, infer M>
+      ? Def<P, M>
+      : T extends Effect.Effect<Info<infer P, infer M>, any, any>
+        ? Def<P, M>
+        : never
+
   function wrap<Parameters extends z.ZodType, Result extends Metadata>(
     id: string,
     init: (() => Promise<DefWithoutID<Parameters, Result>>) | DefWithoutID<Parameters, Result>,
@@ -118,7 +125,7 @@ export namespace Tool {
     )
   }
 
-  export function init(info: Info): Effect.Effect<Def> {
+  export function init<P extends z.ZodType, M extends Metadata>(info: Info<P, M>): Effect.Effect<Def<P, M>> {
     return Effect.gen(function* () {
       const init = yield* Effect.promise(() => info.init())
       return {

+ 122 - 25
packages/opencode/test/session/prompt-effect.test.ts

@@ -631,31 +631,22 @@ it.live(
           const ready = defer<void>()
           const aborted = defer<void>()
           const registry = yield* ToolRegistry.Service
-          const init = registry.named.task.init
-          registry.named.task.init = async () => ({
-            description: "task",
-            parameters: z.object({
-              description: z.string(),
-              prompt: z.string(),
-              subagent_type: z.string(),
-              task_id: z.string().optional(),
-              command: z.string().optional(),
-            }),
-            execute: async (_args, ctx) => {
-              ready.resolve()
-              ctx.abort.addEventListener("abort", () => aborted.resolve(), { once: true })
-              await new Promise<void>(() => {})
-              return {
-                title: "",
-                metadata: {
-                  sessionId: SessionID.make("task"),
-                  model: ref,
-                },
-                output: "",
-              }
-            },
-          })
-          yield* Effect.addFinalizer(() => Effect.sync(() => void (registry.named.task.init = init)))
+          const { task } = yield* registry.named()
+          const original = task.execute
+          task.execute = async (_args, ctx) => {
+            ready.resolve()
+            ctx.abort.addEventListener("abort", () => aborted.resolve(), { once: true })
+            await new Promise<void>(() => {})
+            return {
+              title: "",
+              metadata: {
+                sessionId: SessionID.make("task"),
+                model: ref,
+              },
+              output: "",
+            }
+          }
+          yield* Effect.addFinalizer(() => Effect.sync(() => void (task.execute = original)))
 
           const { prompt, chat } = yield* boot()
           const msg = yield* user(chat.id, "hello")
@@ -1240,3 +1231,109 @@ unix(
     ),
   30_000,
 )
+
+// Abort signal propagation tests for inline tool execution
+
+/** Override a tool's execute to hang until aborted. Returns ready/aborted defers and a finalizer. */
+function hangUntilAborted(tool: { execute: (...args: any[]) => any }) {
+  const ready = defer<void>()
+  const aborted = defer<void>()
+  const original = tool.execute
+  tool.execute = async (_args: any, ctx: any) => {
+    ready.resolve()
+    ctx.abort.addEventListener("abort", () => aborted.resolve(), { once: true })
+    await new Promise<void>(() => {})
+    return { title: "", metadata: {}, output: "" }
+  }
+  const restore = Effect.addFinalizer(() => Effect.sync(() => void (tool.execute = original)))
+  return { ready, aborted, restore }
+}
+
+it.live(
+  "interrupt propagates abort signal to read tool via file part (text/plain)",
+  () =>
+    provideTmpdirInstance(
+      (dir) =>
+        Effect.gen(function* () {
+          const registry = yield* ToolRegistry.Service
+          const { read } = yield* registry.named()
+          const { ready, aborted, restore } = hangUntilAborted(read)
+          yield* restore
+
+          const prompt = yield* SessionPrompt.Service
+          const sessions = yield* Session.Service
+          const chat = yield* sessions.create({ title: "Abort Test" })
+
+          const testFile = path.join(dir, "test.txt")
+          yield* Effect.promise(() => Bun.write(testFile, "hello world"))
+
+          const fiber = yield* prompt
+            .prompt({
+              sessionID: chat.id,
+              agent: "build",
+              parts: [
+                { type: "text", text: "read this" },
+                { type: "file", url: `file://${testFile}`, filename: "test.txt", mime: "text/plain" },
+              ],
+            })
+            .pipe(Effect.forkChild)
+
+          yield* Effect.promise(() => ready.promise)
+          yield* Fiber.interrupt(fiber)
+
+          yield* Effect.promise(() =>
+            Promise.race([
+              aborted.promise,
+              new Promise<void>((_, reject) =>
+                setTimeout(() => reject(new Error("abort signal not propagated within 2s")), 2_000),
+              ),
+            ]),
+          )
+        }),
+      { git: true, config: cfg },
+    ),
+  30_000,
+)
+
+it.live(
+  "interrupt propagates abort signal to read tool via file part (directory)",
+  () =>
+    provideTmpdirInstance(
+      (dir) =>
+        Effect.gen(function* () {
+          const registry = yield* ToolRegistry.Service
+          const { read } = yield* registry.named()
+          const { ready, aborted, restore } = hangUntilAborted(read)
+          yield* restore
+
+          const prompt = yield* SessionPrompt.Service
+          const sessions = yield* Session.Service
+          const chat = yield* sessions.create({ title: "Abort Test" })
+
+          const fiber = yield* prompt
+            .prompt({
+              sessionID: chat.id,
+              agent: "build",
+              parts: [
+                { type: "text", text: "read this" },
+                { type: "file", url: `file://${dir}`, filename: "dir", mime: "application/x-directory" },
+              ],
+            })
+            .pipe(Effect.forkChild)
+
+          yield* Effect.promise(() => ready.promise)
+          yield* Fiber.interrupt(fiber)
+
+          yield* Effect.promise(() =>
+            Promise.race([
+              aborted.promise,
+              new Promise<void>((_, reject) =>
+                setTimeout(() => reject(new Error("abort signal not propagated within 2s")), 2_000),
+              ),
+            ]),
+          )
+        }),
+      { git: true, config: cfg },
+    ),
+  30_000,
+)