Просмотр исходного кода

feat: mcp prompts as slash commands (alternative) (#5767)

Co-authored-by: Aiden Cline <[email protected]>
Co-authored-by: Aiden Cline <[email protected]>
Paolo Ricciuti 1 месяц назад
Родитель
Сommit
57a2b5f444

+ 55 - 5
packages/opencode/src/command/index.ts

@@ -6,6 +6,7 @@ import { Instance } from "../project/instance"
 import { Identifier } from "../id/id"
 import PROMPT_INITIALIZE from "./template/initialize.txt"
 import PROMPT_REVIEW from "./template/review.txt"
+import { MCP } from "../mcp"
 
 export namespace Command {
   export const Event = {
@@ -26,13 +27,28 @@ export namespace Command {
       description: z.string().optional(),
       agent: z.string().optional(),
       model: z.string().optional(),
-      template: z.string(),
+      // workaround for zod not supporting async functions natively so we use getters
+      // https://zod.dev/v4/changelog?id=zfunction
+      template: z.promise(z.string()).or(z.string()),
       subtask: z.boolean().optional(),
+      hints: z.array(z.string()),
     })
     .meta({
       ref: "Command",
     })
-  export type Info = z.infer<typeof Info>
+
+  // for some reason zod is inferring `string` for z.promise(z.string()).or(z.string()) so we have to manually override it
+  export type Info = Omit<z.infer<typeof Info>, "template"> & { template: Promise<string> | string }
+
+  export function hints(template: string): string[] {
+    const result: string[] = []
+    const numbered = template.match(/\$\d+/g)
+    if (numbered) {
+      for (const match of [...new Set(numbered)].sort()) result.push(match)
+    }
+    if (template.includes("$ARGUMENTS")) result.push("$ARGUMENTS")
+    return result
+  }
 
   export const Default = {
     INIT: "init",
@@ -46,13 +62,19 @@ export namespace Command {
       [Default.INIT]: {
         name: Default.INIT,
         description: "create/update AGENTS.md",
-        template: PROMPT_INITIALIZE.replace("${path}", Instance.worktree),
+        get template() {
+          return PROMPT_INITIALIZE.replace("${path}", Instance.worktree)
+        },
+        hints: hints(PROMPT_INITIALIZE),
       },
       [Default.REVIEW]: {
         name: Default.REVIEW,
         description: "review changes [commit|branch|pr], defaults to uncommitted",
-        template: PROMPT_REVIEW.replace("${path}", Instance.worktree),
+        get template() {
+          return PROMPT_REVIEW.replace("${path}", Instance.worktree)
+        },
         subtask: true,
+        hints: hints(PROMPT_REVIEW),
       },
     }
 
@@ -62,8 +84,36 @@ export namespace Command {
         agent: command.agent,
         model: command.model,
         description: command.description,
-        template: command.template,
+        get template() {
+          return command.template
+        },
         subtask: command.subtask,
+        hints: hints(command.template),
+      }
+    }
+    for (const [name, prompt] of Object.entries(await MCP.prompts())) {
+      result[name] = {
+        name,
+        description: prompt.description,
+        get template() {
+          // since a getter can't be async we need to manually return a promise here
+          return new Promise<string>(async (resolve, reject) => {
+            const template = await MCP.getPrompt(
+              prompt.client,
+              prompt.name,
+              prompt.arguments
+                ? // substitute each argument with $1, $2, etc.
+                  Object.fromEntries(prompt.arguments?.map((argument, i) => [argument.name, `$${i + 1}`]))
+                : {},
+            ).catch(reject)
+            resolve(
+              template?.messages
+                .map((message) => (message.content.type === "text" ? message.content.text : ""))
+                .join("\n") || "",
+            )
+          })
+        },
+        hints: prompt.arguments?.map((_, i) => `$${i + 1}`) ?? [],
       }
     }
 

+ 75 - 0
packages/opencode/src/mcp/index.ts

@@ -132,6 +132,9 @@ export namespace MCP {
   type TransportWithAuth = StreamableHTTPClientTransport | SSEClientTransport
   const pendingOAuthTransports = new Map<string, TransportWithAuth>()
 
+  // Prompt cache types
+  type PromptInfo = Awaited<ReturnType<MCPClient["listPrompts"]>>["prompts"][number]
+
   const state = Instance.state(
     async () => {
       const cfg = await Config.get()
@@ -176,6 +179,29 @@ export namespace MCP {
     },
   )
 
+  // Helper function to fetch prompts for a specific client
+  async function fetchPromptsForClient(clientName: string, client: Client) {
+    const prompts = await client.listPrompts().catch((e) => {
+      log.error("failed to get prompts", { clientName, error: e.message })
+      return undefined
+    })
+
+    if (!prompts) {
+      return
+    }
+
+    const commands: Record<string, PromptInfo & { client: string }> = {}
+
+    for (const prompt of prompts.prompts) {
+      const sanitizedClientName = clientName.replace(/[^a-zA-Z0-9_-]/g, "_")
+      const sanitizedPromptName = prompt.name.replace(/[^a-zA-Z0-9_-]/g, "_")
+      const key = sanitizedClientName + ":" + sanitizedPromptName + " (MCP)"
+
+      commands[key] = { ...prompt, client: clientName }
+    }
+    return commands
+  }
+
   export async function add(name: string, mcp: Config.Mcp) {
     const s = await state()
     const result = await create(name, mcp)
@@ -492,6 +518,55 @@ export namespace MCP {
     return result
   }
 
+  export async function prompts() {
+    const s = await state()
+    const clientsSnapshot = await clients()
+
+    const prompts = Object.fromEntries<PromptInfo & { client: string }>(
+      (
+        await Promise.all(
+          Object.entries(clientsSnapshot).map(async ([clientName, client]) => {
+            if (s.status[clientName]?.status !== "connected") {
+              return []
+            }
+
+            return Object.entries((await fetchPromptsForClient(clientName, client)) ?? {})
+          }),
+        )
+      ).flat(),
+    )
+
+    return prompts
+  }
+
+  export async function getPrompt(clientName: string, name: string, args?: Record<string, string>) {
+    const clientsSnapshot = await clients()
+    const client = clientsSnapshot[clientName]
+
+    if (!client) {
+      log.warn("client not found for prompt", {
+        clientName,
+      })
+      return undefined
+    }
+
+    const result = await client
+      .getPrompt({
+        name: name,
+        arguments: args,
+      })
+      .catch((e) => {
+        log.error("failed to get prompt from MCP server", {
+          clientName,
+          promptName: name,
+          error: e.message,
+        })
+        return undefined
+      })
+
+    return result
+  }
+
   /**
    * Start OAuth authentication flow for an MCP server.
    * Returns the authorization URL that should be opened in a browser.

+ 4 - 2
packages/opencode/src/session/prompt.ts

@@ -1290,7 +1290,9 @@ export namespace SessionPrompt {
     const raw = input.arguments.match(argsRegex) ?? []
     const args = raw.map((arg) => arg.replace(quoteTrimRegex, ""))
 
-    const placeholders = command.template.match(placeholderRegex) ?? []
+    const templateCommand = await command.template
+
+    const placeholders = templateCommand.match(placeholderRegex) ?? []
     let last = 0
     for (const item of placeholders) {
       const value = Number(item.slice(1))
@@ -1298,7 +1300,7 @@ export namespace SessionPrompt {
     }
 
     // Let the final placeholder swallow any extra arguments so prompts read naturally
-    const withArgs = command.template.replaceAll(placeholderRegex, (_, index) => {
+    const withArgs = templateCommand.replaceAll(placeholderRegex, (_, index) => {
       const position = Number(index)
       const argIndex = position - 1
       if (argIndex >= args.length) return ""

+ 1 - 0
packages/sdk/js/src/v2/gen/types.gen.ts

@@ -1731,6 +1731,7 @@ export type Command = {
   model?: string
   template: string
   subtask?: boolean
+  hints: Array<string>
 }
 
 export type Model = {