فهرست منبع

refactor(effect): resolve built tools through the registry (#20787)

Kit Langton 2 هفته پیش
والد
کامیت
7994dce0f2

+ 2 - 2
packages/opencode/src/cli/cmd/run.ts

@@ -28,13 +28,13 @@ import { BashTool } from "../../tool/bash"
 import { TodoWriteTool } from "../../tool/todo"
 import { Locale } from "../../util/locale"
 
-type ToolProps<T extends Tool.Info> = {
+type ToolProps<T> = {
   input: Tool.InferParameters<T>
   metadata: Tool.InferMetadata<T>
   part: ToolPart
 }
 
-function props<T extends Tool.Info>(part: ToolPart): ToolProps<T> {
+function props<T>(part: ToolPart): ToolProps<T> {
   const state = part.state
   return {
     input: state.input as Tool.InferParameters<T>,

+ 1 - 1
packages/opencode/src/cli/cmd/tui/routes/session/index.tsx

@@ -1572,7 +1572,7 @@ function ToolPart(props: { last: boolean; part: ToolPart; message: AssistantMess
   )
 }
 
-type ToolProps<T extends Tool.Info> = {
+type ToolProps<T> = {
   input: Partial<Tool.InferParameters<T>>
   metadata: Partial<Tool.InferMetadata<T>>
   permission: Record<string, any>

+ 1 - 1
packages/opencode/src/question/index.ts

@@ -198,7 +198,7 @@ export namespace Question {
     }),
   )
 
-  const defaultLayer = layer.pipe(Layer.provide(Bus.layer))
+  export const defaultLayer = layer.pipe(Layer.provide(Bus.layer))
 
   const { runPromise } = makeRuntime(Service, defaultLayer)
 

+ 4 - 4
packages/opencode/src/session/prompt.ts

@@ -560,7 +560,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
       }) {
         const { task, model, lastUser, sessionID, session, msgs } = input
         const ctx = yield* InstanceState.context
-        const taskTool = yield* Effect.promise(() => TaskTool.init())
+        const taskTool = yield* Effect.promise(() => registry.named.task.init())
         const taskModel = task.model ? yield* getModel(task.model.providerID, task.model.modelID, sessionID) : model
         const assistantMessage: MessageV2.Assistant = yield* sessions.updateMessage({
           id: MessageID.ascending(),
@@ -583,7 +583,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
           sessionID: assistantMessage.sessionID,
           type: "tool",
           callID: ulid(),
-          tool: TaskTool.id,
+          tool: registry.named.task.id,
           state: {
             status: "running",
             input: {
@@ -1110,7 +1110,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
                       text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
                     },
                   ]
-                  const read = yield* Effect.promise(() => ReadTool.init()).pipe(
+                  const read = yield* Effect.promise(() => registry.named.read.init()).pipe(
                     Effect.flatMap((t) =>
                       provider.getModel(info.model.providerID, info.model.modelID).pipe(
                         Effect.flatMap((mdl) =>
@@ -1174,7 +1174,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
 
                 if (part.mime === "application/x-directory") {
                   const args = { filePath: filepath }
-                  const result = yield* Effect.promise(() => ReadTool.init()).pipe(
+                  const result = yield* Effect.promise(() => registry.named.read.init()).pipe(
                     Effect.flatMap((t) =>
                       Effect.promise(() =>
                         t.execute(args, {

+ 36 - 23
packages/opencode/src/tool/question.ts

@@ -1,33 +1,46 @@
 import z from "zod"
+import { Effect } from "effect"
 import { Tool } from "./tool"
 import { Question } from "../question"
 import DESCRIPTION from "./question.txt"
 
-export const QuestionTool = Tool.define("question", {
-  description: DESCRIPTION,
-  parameters: z.object({
-    questions: z.array(Question.Info.omit({ custom: true })).describe("Questions to ask"),
-  }),
-  async execute(params, ctx) {
-    const answers = await Question.ask({
-      sessionID: ctx.sessionID,
-      questions: params.questions,
-      tool: ctx.callID ? { messageID: ctx.messageID, callID: ctx.callID } : undefined,
-    })
+const parameters = z.object({
+  questions: z.array(Question.Info.omit({ custom: true })).describe("Questions to ask"),
+})
 
-    function format(answer: Question.Answer | undefined) {
-      if (!answer?.length) return "Unanswered"
-      return answer.join(", ")
-    }
+type Metadata = {
+  answers: Question.Answer[]
+}
 
-    const formatted = params.questions.map((q, i) => `"${q.question}"="${format(answers[i])}"`).join(", ")
+export const QuestionTool = Tool.defineEffect<typeof parameters, Metadata, Question.Service>(
+  "question",
+  Effect.gen(function* () {
+    const question = yield* Question.Service
 
     return {
-      title: `Asked ${params.questions.length} question${params.questions.length > 1 ? "s" : ""}`,
-      output: `User has answered your questions: ${formatted}. You can now continue with the user's answers in mind.`,
-      metadata: {
-        answers,
+      description: DESCRIPTION,
+      parameters,
+      async execute(params: z.infer<typeof parameters>, ctx: Tool.Context<Metadata>) {
+        const answers = await question
+          .ask({
+            sessionID: ctx.sessionID,
+            questions: params.questions,
+            tool: ctx.callID ? { messageID: ctx.messageID, callID: ctx.callID } : undefined,
+          })
+          .pipe(Effect.runPromise)
+
+        const formatted = params.questions
+          .map((q, i) => `"${q.question}"="${answers[i]?.length ? answers[i].join(", ") : "Unanswered"}"`)
+          .join(", ")
+
+        return {
+          title: `Asked ${params.questions.length} question${params.questions.length > 1 ? "s" : ""}`,
+          output: `User has answered your questions: ${formatted}. You can now continue with the user's answers in mind.`,
+          metadata: {
+            answers,
+          },
+        }
       },
-    }
-  },
-})
+    } satisfies Tool.Def<typeof parameters, Metadata>
+  }),
+)

+ 54 - 32
packages/opencode/src/tool/registry.ts

@@ -33,6 +33,7 @@ import { Effect, Layer, ServiceMap } from "effect"
 import { InstanceState } from "@/effect/instance-state"
 import { makeRuntime } from "@/effect/run-service"
 import { Env } from "../env"
+import { Question } from "../question"
 
 export namespace ToolRegistry {
   const log = Log.create({ service: "tool.registry" })
@@ -42,8 +43,11 @@ export namespace ToolRegistry {
   }
 
   export interface Interface {
-    readonly register: (tool: Tool.Info) => Effect.Effect<void>
     readonly ids: () => Effect.Effect<string[]>
+    readonly named: {
+      task: Tool.Info
+      read: Tool.Info
+    }
     readonly tools: (
       model: { providerID: ProviderID; modelID: ModelID },
       agent?: Agent.Info,
@@ -52,12 +56,15 @@ export namespace ToolRegistry {
 
   export class Service extends ServiceMap.Service<Service, Interface>()("@opencode/ToolRegistry") {}
 
-  export const layer: Layer.Layer<Service, never, Config.Service | Plugin.Service> = Layer.effect(
+  export const layer: Layer.Layer<Service, never, Config.Service | Plugin.Service | Question.Service> = Layer.effect(
     Service,
     Effect.gen(function* () {
       const config = yield* Config.Service
       const plugin = yield* Plugin.Service
 
+      const build = <T extends Tool.Info>(tool: T | Effect.Effect<T, never, any>) =>
+        Effect.isEffect(tool) ? tool : Effect.succeed(tool)
+
       const state = yield* InstanceState.make<State>(
         Effect.fn("ToolRegistry.state")(function* (ctx) {
           const custom: Tool.Info[] = []
@@ -112,43 +119,52 @@ export namespace ToolRegistry {
         }),
       )
 
+      const invalid = yield* build(InvalidTool)
+      const ask = yield* build(QuestionTool)
+      const bash = yield* build(BashTool)
+      const read = yield* build(ReadTool)
+      const glob = yield* build(GlobTool)
+      const grep = yield* build(GrepTool)
+      const edit = yield* build(EditTool)
+      const write = yield* build(WriteTool)
+      const task = yield* build(TaskTool)
+      const fetch = yield* build(WebFetchTool)
+      const todo = yield* build(TodoWriteTool)
+      const search = yield* build(WebSearchTool)
+      const code = yield* build(CodeSearchTool)
+      const skill = yield* build(SkillTool)
+      const patch = yield* build(ApplyPatchTool)
+      const lsp = yield* build(LspTool)
+      const batch = yield* build(BatchTool)
+      const plan = yield* build(PlanExitTool)
+
       const all = Effect.fn("ToolRegistry.all")(function* (custom: Tool.Info[]) {
         const cfg = yield* config.get()
         const question = ["app", "cli", "desktop"].includes(Flag.OPENCODE_CLIENT) || Flag.OPENCODE_ENABLE_QUESTION_TOOL
 
         return [
-          InvalidTool,
-          ...(question ? [QuestionTool] : []),
-          BashTool,
-          ReadTool,
-          GlobTool,
-          GrepTool,
-          EditTool,
-          WriteTool,
-          TaskTool,
-          WebFetchTool,
-          TodoWriteTool,
-          WebSearchTool,
-          CodeSearchTool,
-          SkillTool,
-          ApplyPatchTool,
-          ...(Flag.OPENCODE_EXPERIMENTAL_LSP_TOOL ? [LspTool] : []),
-          ...(cfg.experimental?.batch_tool === true ? [BatchTool] : []),
-          ...(Flag.OPENCODE_EXPERIMENTAL_PLAN_MODE && Flag.OPENCODE_CLIENT === "cli" ? [PlanExitTool] : []),
+          invalid,
+          ...(question ? [ask] : []),
+          bash,
+          read,
+          glob,
+          grep,
+          edit,
+          write,
+          task,
+          fetch,
+          todo,
+          search,
+          code,
+          skill,
+          patch,
+          ...(Flag.OPENCODE_EXPERIMENTAL_LSP_TOOL ? [lsp] : []),
+          ...(cfg.experimental?.batch_tool === true ? [batch] : []),
+          ...(Flag.OPENCODE_EXPERIMENTAL_PLAN_MODE && Flag.OPENCODE_CLIENT === "cli" ? [plan] : []),
           ...custom,
         ]
       })
 
-      const register = Effect.fn("ToolRegistry.register")(function* (tool: Tool.Info) {
-        const s = yield* InstanceState.get(state)
-        const idx = s.custom.findIndex((t) => t.id === tool.id)
-        if (idx >= 0) {
-          s.custom.splice(idx, 1, tool)
-          return
-        }
-        s.custom.push(tool)
-      })
-
       const ids = Effect.fn("ToolRegistry.ids")(function* () {
         const s = yield* InstanceState.get(state)
         const tools = yield* all(s.custom)
@@ -196,12 +212,18 @@ export namespace ToolRegistry {
         )
       })
 
-      return Service.of({ register, ids, tools })
+      return Service.of({ ids, named: { task, read }, tools })
     }),
   )
 
   export const defaultLayer = Layer.unwrap(
-    Effect.sync(() => layer.pipe(Layer.provide(Config.defaultLayer), Layer.provide(Plugin.defaultLayer))),
+    Effect.sync(() =>
+      layer.pipe(
+        Layer.provide(Config.defaultLayer),
+        Layer.provide(Plugin.defaultLayer),
+        Layer.provide(Question.defaultLayer),
+      ),
+    ),
   )
 
   const { runPromise } = makeRuntime(Service, defaultLayer)

+ 56 - 36
packages/opencode/src/tool/tool.ts

@@ -1,4 +1,5 @@
 import z from "zod"
+import { Effect } from "effect"
 import type { MessageV2 } from "../session/message-v2"
 import type { Agent } from "../agent/agent"
 import type { Permission } from "../permission"
@@ -45,48 +46,67 @@ export namespace Tool {
     init: (ctx?: InitContext) => Promise<Def<Parameters, M>>
   }
 
-  export type InferParameters<T extends Info> = T extends Info<infer P> ? z.infer<P> : never
-  export type InferMetadata<T extends Info> = T extends Info<any, infer M> ? M : never
+  export type InferParameters<T> =
+    T extends Info<infer P, any>
+      ? z.infer<P>
+      : T extends Effect.Effect<Info<infer P, any>, any, any>
+        ? z.infer<P>
+        : never
+  export type InferMetadata<T> =
+    T extends Info<any, infer M> ? M : T extends Effect.Effect<Info<any, infer M>, any, any> ? M : never
+
+  function wrap<Parameters extends z.ZodType, Result extends Metadata>(
+    id: string,
+    init: ((ctx?: InitContext) => Promise<Def<Parameters, Result>>) | Def<Parameters, Result>,
+  ) {
+    return async (initCtx?: InitContext) => {
+      const toolInfo = init instanceof Function ? await init(initCtx) : { ...init }
+      const execute = toolInfo.execute
+      toolInfo.execute = async (args, ctx) => {
+        try {
+          toolInfo.parameters.parse(args)
+        } catch (error) {
+          if (error instanceof z.ZodError && toolInfo.formatValidationError) {
+            throw new Error(toolInfo.formatValidationError(error), { cause: error })
+          }
+          throw new Error(
+            `The ${id} tool was called with invalid arguments: ${error}.\nPlease rewrite the input so it satisfies the expected schema.`,
+            { cause: error },
+          )
+        }
+        const result = await execute(args, ctx)
+        if (result.metadata.truncated !== undefined) {
+          return result
+        }
+        const truncated = await Truncate.output(result.output, {}, initCtx?.agent)
+        return {
+          ...result,
+          output: truncated.content,
+          metadata: {
+            ...result.metadata,
+            truncated: truncated.truncated,
+            ...(truncated.truncated && { outputPath: truncated.outputPath }),
+          },
+        }
+      }
+      return toolInfo
+    }
+  }
 
   export function define<Parameters extends z.ZodType, Result extends Metadata>(
     id: string,
-    init: Info<Parameters, Result>["init"] | Def<Parameters, Result>,
+    init: ((ctx?: InitContext) => Promise<Def<Parameters, Result>>) | Def<Parameters, Result>,
   ): Info<Parameters, Result> {
     return {
       id,
-      init: async (initCtx) => {
-        const toolInfo = init instanceof Function ? await init(initCtx) : { ...init }
-        const execute = toolInfo.execute
-        toolInfo.execute = async (args, ctx) => {
-          try {
-            toolInfo.parameters.parse(args)
-          } catch (error) {
-            if (error instanceof z.ZodError && toolInfo.formatValidationError) {
-              throw new Error(toolInfo.formatValidationError(error), { cause: error })
-            }
-            throw new Error(
-              `The ${id} tool was called with invalid arguments: ${error}.\nPlease rewrite the input so it satisfies the expected schema.`,
-              { cause: error },
-            )
-          }
-          const result = await execute(args, ctx)
-          // skip truncation for tools that handle it themselves
-          if (result.metadata.truncated !== undefined) {
-            return result
-          }
-          const truncated = await Truncate.output(result.output, {}, initCtx?.agent)
-          return {
-            ...result,
-            output: truncated.content,
-            metadata: {
-              ...result.metadata,
-              truncated: truncated.truncated,
-              ...(truncated.truncated && { outputPath: truncated.outputPath }),
-            },
-          }
-        }
-        return toolInfo
-      },
+      init: wrap(id, init),
     }
   }
+
+  export function defineEffect<Parameters extends z.ZodType, Result extends Metadata, R>(
+    id: string,
+    init: Effect.Effect<((ctx?: InitContext) => Promise<Def<Parameters, Result>>) | Def<Parameters, Result>, never, R>,
+  ): Effect.Effect<Info<Parameters, Result>, never, R> {
+    return Effect.map(init, (next) => ({ id, init: wrap(id, next) }))
+  }
 }

+ 3 - 1
packages/opencode/test/session/prompt-effect.test.ts

@@ -15,6 +15,7 @@ import { Plugin } from "../../src/plugin"
 import { Provider as ProviderSvc } from "../../src/provider/provider"
 import type { Provider } from "../../src/provider/provider"
 import { ModelID, ProviderID } from "../../src/provider/schema"
+import { Question } from "../../src/question"
 import { Session } from "../../src/session"
 import { LLM } from "../../src/session/llm"
 import { MessageV2 } from "../../src/session/message-v2"
@@ -160,7 +161,8 @@ function makeHttp() {
     AppFileSystem.defaultLayer,
     status,
   ).pipe(Layer.provideMerge(infra))
-  const registry = ToolRegistry.layer.pipe(Layer.provideMerge(deps))
+  const question = Question.layer.pipe(Layer.provideMerge(deps))
+  const registry = ToolRegistry.layer.pipe(Layer.provideMerge(question), Layer.provideMerge(deps))
   const trunc = Truncate.layer.pipe(Layer.provideMerge(deps))
   const proc = SessionProcessor.layer.pipe(Layer.provideMerge(deps))
   const compact = SessionCompaction.layer.pipe(Layer.provideMerge(proc), Layer.provideMerge(deps))

+ 3 - 1
packages/opencode/test/session/snapshot-tool-race.test.ts

@@ -38,6 +38,7 @@ import { MCP } from "../../src/mcp"
 import { Permission } from "../../src/permission"
 import { Plugin } from "../../src/plugin"
 import { Provider as ProviderSvc } from "../../src/provider/provider"
+import { Question } from "../../src/question"
 import { SessionCompaction } from "../../src/session/compaction"
 import { Instruction } from "../../src/session/instruction"
 import { SessionProcessor } from "../../src/session/processor"
@@ -124,7 +125,8 @@ function makeHttp() {
     AppFileSystem.defaultLayer,
     status,
   ).pipe(Layer.provideMerge(infra))
-  const registry = ToolRegistry.layer.pipe(Layer.provideMerge(deps))
+  const question = Question.layer.pipe(Layer.provideMerge(deps))
+  const registry = ToolRegistry.layer.pipe(Layer.provideMerge(question), Layer.provideMerge(deps))
   const trunc = Truncate.layer.pipe(Layer.provideMerge(deps))
   const proc = SessionProcessor.layer.pipe(Layer.provideMerge(deps))
   const compact = SessionCompaction.layer.pipe(Layer.provideMerge(proc), Layer.provideMerge(deps))

+ 63 - 45
packages/opencode/test/tool/question.test.ts

@@ -1,8 +1,12 @@
-import { describe, expect, test, spyOn, beforeEach, afterEach } from "bun:test"
-import { z } from "zod"
+import { describe, expect } from "bun:test"
+import { Effect, Fiber, Layer } from "effect"
+import { Tool } from "../../src/tool/tool"
 import { QuestionTool } from "../../src/tool/question"
-import * as QuestionModule from "../../src/question"
+import { Question } from "../../src/question"
 import { SessionID, MessageID } from "../../src/session/schema"
+import * as CrossSpawnSpawner from "../../src/effect/cross-spawn-spawner"
+import { provideTmpdirInstance } from "../fixture/fixture"
+import { testEffect } from "../lib/effect"
 
 const ctx = {
   sessionID: SessionID.make("ses_test-session"),
@@ -15,55 +19,69 @@ const ctx = {
   ask: async () => {},
 }
 
-describe("tool.question", () => {
-  let askSpy: any
-
-  beforeEach(() => {
-    askSpy = spyOn(QuestionModule.Question, "ask").mockImplementation(async () => {
-      return []
-    })
-  })
+const it = testEffect(Layer.mergeAll(Question.defaultLayer, CrossSpawnSpawner.defaultLayer))
 
-  afterEach(() => {
-    askSpy.mockRestore()
-  })
+const pending = Effect.fn("QuestionToolTest.pending")(function* (question: Question.Interface) {
+  for (;;) {
+    const items = yield* question.list()
+    const item = items[0]
+    if (item) return item
+    yield* Effect.sleep("10 millis")
+  }
+})
 
-  test("should successfully execute with valid question parameters", async () => {
-    const tool = await QuestionTool.init()
-    const questions = [
-      {
-        question: "What is your favorite color?",
-        header: "Color",
-        options: [
-          { label: "Red", description: "The color of passion" },
-          { label: "Blue", description: "The color of sky" },
-        ],
-        multiple: false,
-      },
-    ]
+describe("tool.question", () => {
+  it.live("should successfully execute with valid question parameters", () =>
+    provideTmpdirInstance(() =>
+      Effect.gen(function* () {
+        const question = yield* Question.Service
+        const toolInfo = yield* QuestionTool
+        const tool = yield* Effect.promise(() => toolInfo.init())
+        const questions = [
+          {
+            question: "What is your favorite color?",
+            header: "Color",
+            options: [
+              { label: "Red", description: "The color of passion" },
+              { label: "Blue", description: "The color of sky" },
+            ],
+            multiple: false,
+          },
+        ]
 
-    askSpy.mockResolvedValueOnce([["Red"]])
+        const fiber = yield* Effect.promise(() => tool.execute({ questions }, ctx)).pipe(Effect.forkScoped)
+        const item = yield* pending(question)
+        yield* question.reply({ requestID: item.id, answers: [["Red"]] })
 
-    const result = await tool.execute({ questions }, ctx)
-    expect(askSpy).toHaveBeenCalledTimes(1)
-    expect(result.title).toBe("Asked 1 question")
-  })
+        const result = yield* Fiber.join(fiber)
+        expect(result.title).toBe("Asked 1 question")
+      }),
+    ),
+  )
 
-  test("should now pass with a header longer than 12 but less than 30 chars", async () => {
-    const tool = await QuestionTool.init()
-    const questions = [
-      {
-        question: "What is your favorite animal?",
-        header: "This Header is Over 12",
-        options: [{ label: "Dog", description: "Man's best friend" }],
-      },
-    ]
+  it.live("should now pass with a header longer than 12 but less than 30 chars", () =>
+    provideTmpdirInstance(() =>
+      Effect.gen(function* () {
+        const question = yield* Question.Service
+        const toolInfo = yield* QuestionTool
+        const tool = yield* Effect.promise(() => toolInfo.init())
+        const questions = [
+          {
+            question: "What is your favorite animal?",
+            header: "This Header is Over 12",
+            options: [{ label: "Dog", description: "Man's best friend" }],
+          },
+        ]
 
-    askSpy.mockResolvedValueOnce([["Dog"]])
+        const fiber = yield* Effect.promise(() => tool.execute({ questions }, ctx)).pipe(Effect.forkScoped)
+        const item = yield* pending(question)
+        yield* question.reply({ requestID: item.id, answers: [["Dog"]] })
 
-    const result = await tool.execute({ questions }, ctx)
-    expect(result.output).toContain(`"What is your favorite animal?"="Dog"`)
-  })
+        const result = yield* Fiber.join(fiber)
+        expect(result.output).toContain(`"What is your favorite animal?"="Dog"`)
+      }),
+    ),
+  )
 
   // intentionally removed the zod validation due to tool call errors, hoping prompting is gonna be good enough
   //   test("should throw an Error for header exceeding 30 characters", async () => {