prompt.ts 45 KB


  1. import path from "path"
  2. import os from "os"
  3. import fs from "fs/promises"
  4. import z from "zod"
  5. import { Identifier } from "../id/id"
  6. import { MessageV2 } from "./message-v2"
  7. import { Log } from "../util/log"
  8. import { SessionRevert } from "./revert"
  9. import { Session } from "."
  10. import { Agent } from "../agent/agent"
  11. import { Provider } from "../provider/provider"
  12. import {
  13. generateText,
  14. streamText,
  15. type ModelMessage,
  16. type Tool as AITool,
  17. tool,
  18. wrapLanguageModel,
  19. stepCountIs,
  20. jsonSchema,
  21. } from "ai"
  22. import { SessionCompaction } from "./compaction"
  23. import { Instance } from "../project/instance"
  24. import { Bus } from "../bus"
  25. import { ProviderTransform } from "../provider/transform"
  26. import { SystemPrompt } from "./system"
  27. import { Plugin } from "../plugin"
  28. import PROMPT_PLAN from "../session/prompt/plan.txt"
  29. import BUILD_SWITCH from "../session/prompt/build-switch.txt"
  30. import { defer } from "../util/defer"
  31. import { mergeDeep, pipe } from "remeda"
  32. import { ToolRegistry } from "../tool/registry"
  33. import { Wildcard } from "../util/wildcard"
  34. import { MCP } from "../mcp"
  35. import { LSP } from "../lsp"
  36. import { ReadTool } from "../tool/read"
  37. import { ListTool } from "../tool/ls"
  38. import { FileTime } from "../file/time"
  39. import { ulid } from "ulid"
  40. import { spawn } from "child_process"
  41. import { Command } from "../command"
  42. import { $, fileURLToPath } from "bun"
  43. import { ConfigMarkdown } from "../config/markdown"
  44. import { SessionSummary } from "./summary"
  45. import { NamedError } from "@/util/error"
  46. import { fn } from "@/util/fn"
  47. import { SessionProcessor } from "./processor"
  48. import { TaskTool } from "@/tool/task"
  49. import { SessionStatus } from "./status"
  50. // @ts-ignore
  51. globalThis.AI_SDK_LOG_WARNINGS = false
  52. export namespace SessionPrompt {
  53. const log = Log.create({ service: "session.prompt" })
  54. export const OUTPUT_TOKEN_MAX = 32_000
  55. const state = Instance.state(
  56. () => {
  57. const data: Record<
  58. string,
  59. {
  60. abort: AbortController
  61. callbacks: {
  62. resolve(input: MessageV2.WithParts): void
  63. reject(): void
  64. }[]
  65. }
  66. > = {}
  67. return data
  68. },
  69. async (current) => {
  70. for (const item of Object.values(current)) {
  71. item.abort.abort()
  72. }
  73. },
  74. )
  75. export function assertNotBusy(sessionID: string) {
  76. const match = state()[sessionID]
  77. if (match) throw new Session.BusyError(sessionID)
  78. }
/**
 * Request schema accepted by `prompt`.
 *
 * - `messageID`: optional caller-chosen ID for the user message (otherwise generated).
 * - `model` / `agent` / `system`: overrides recorded on the user message.
 * - `noReply`: when true, `prompt` stores the message and returns without running the loop.
 * - `tools`: per-message enable/disable flags, merged last in resolveTools.
 * - `parts`: input parts. Each variant is the corresponding MessageV2 part schema
 *   without `messageID`/`sessionID` (createUserMessage fills those in) and with an
 *   optional `id`.
 */
export const PromptInput = z.object({
  sessionID: Identifier.schema("session"),
  messageID: Identifier.schema("message").optional(),
  model: z
    .object({
      providerID: z.string(),
      modelID: z.string(),
    })
    .optional(),
  agent: z.string().optional(),
  noReply: z.boolean().optional(),
  system: z.string().optional(),
  tools: z.record(z.string(), z.boolean()).optional(),
  parts: z.array(
    // Discriminated on `type`. The `.meta({ ref })` names presumably become the
    // schema names when JSON/OpenAPI schemas are generated — TODO confirm.
    z.discriminatedUnion("type", [
      MessageV2.TextPart.omit({
        messageID: true,
        sessionID: true,
      })
        .partial({
          id: true,
        })
        .meta({
          ref: "TextPartInput",
        }),
      MessageV2.FilePart.omit({
        messageID: true,
        sessionID: true,
      })
        .partial({
          id: true,
        })
        .meta({
          ref: "FilePartInput",
        }),
      MessageV2.AgentPart.omit({
        messageID: true,
        sessionID: true,
      })
        .partial({
          id: true,
        })
        .meta({
          ref: "AgentPartInput",
        }),
      MessageV2.SubtaskPart.omit({
        messageID: true,
        sessionID: true,
      })
        .partial({
          id: true,
        })
        .meta({
          ref: "SubtaskPartInput",
        }),
    ]),
  ),
})
export type PromptInput = z.infer<typeof PromptInput>
  138. export async function resolvePromptParts(template: string): Promise<PromptInput["parts"]> {
  139. const parts: PromptInput["parts"] = [
  140. {
  141. type: "text",
  142. text: template,
  143. },
  144. ]
  145. const files = ConfigMarkdown.files(template)
  146. await Promise.all(
  147. files.map(async (match) => {
  148. const name = match[1]
  149. const filepath = name.startsWith("~/")
  150. ? path.join(os.homedir(), name.slice(2))
  151. : path.resolve(Instance.worktree, name)
  152. const stats = await fs.stat(filepath).catch(() => undefined)
  153. if (!stats) {
  154. const agent = await Agent.get(name)
  155. if (agent) {
  156. parts.push({
  157. type: "agent",
  158. name: agent.name,
  159. })
  160. }
  161. return
  162. }
  163. if (stats.isDirectory()) {
  164. parts.push({
  165. type: "file",
  166. url: `file://${filepath}`,
  167. filename: name,
  168. mime: "application/x-directory",
  169. })
  170. return
  171. }
  172. parts.push({
  173. type: "file",
  174. url: `file://${filepath}`,
  175. filename: name,
  176. mime: "text/plain",
  177. })
  178. }),
  179. )
  180. return parts
  181. }
  182. export const prompt = fn(PromptInput, async (input) => {
  183. const session = await Session.get(input.sessionID)
  184. await SessionRevert.cleanup(session)
  185. const message = await createUserMessage(input)
  186. await Session.touch(input.sessionID)
  187. if (input.noReply === true) {
  188. return message
  189. }
  190. return loop(input.sessionID)
  191. })
  192. function start(sessionID: string) {
  193. const s = state()
  194. if (s[sessionID]) return
  195. const controller = new AbortController()
  196. s[sessionID] = {
  197. abort: controller,
  198. callbacks: [],
  199. }
  200. return controller.signal
  201. }
  202. export function cancel(sessionID: string) {
  203. log.info("cancel", { sessionID })
  204. const s = state()
  205. const match = s[sessionID]
  206. if (!match) return
  207. match.abort.abort()
  208. for (const item of match.callbacks) {
  209. item.reject()
  210. }
  211. delete s[sessionID]
  212. SessionStatus.set(sessionID, { type: "idle" })
  213. return
  214. }
  215. export const loop = fn(Identifier.schema("session"), async (sessionID) => {
  216. const abort = start(sessionID)
  217. if (!abort) {
  218. return new Promise<MessageV2.WithParts>((resolve, reject) => {
  219. const callbacks = state()[sessionID].callbacks
  220. callbacks.push({ resolve, reject })
  221. })
  222. }
  223. using _ = defer(() => cancel(sessionID))
  224. let step = 0
  225. while (true) {
  226. SessionStatus.set(sessionID, { type: "busy" })
  227. log.info("loop", { step, sessionID })
  228. if (abort.aborted) break
  229. let msgs = await MessageV2.filterCompacted(MessageV2.stream(sessionID))
  230. let lastUser: MessageV2.User | undefined
  231. let lastAssistant: MessageV2.Assistant | undefined
  232. let lastFinished: MessageV2.Assistant | undefined
  233. let tasks: (MessageV2.CompactionPart | MessageV2.SubtaskPart)[] = []
  234. for (let i = msgs.length - 1; i >= 0; i--) {
  235. const msg = msgs[i]
  236. if (!lastUser && msg.info.role === "user") lastUser = msg.info as MessageV2.User
  237. if (!lastAssistant && msg.info.role === "assistant") lastAssistant = msg.info as MessageV2.Assistant
  238. if (!lastFinished && msg.info.role === "assistant" && msg.info.finish)
  239. lastFinished = msg.info as MessageV2.Assistant
  240. if (lastUser && lastFinished) break
  241. const task = msg.parts.filter((part) => part.type === "compaction" || part.type === "subtask")
  242. if (task && !lastFinished) {
  243. tasks.push(...task)
  244. }
  245. }
  246. if (!lastUser) throw new Error("No user message found in stream. This should never happen.")
  247. if (
  248. lastAssistant?.finish &&
  249. !["tool-calls", "unknown"].includes(lastAssistant.finish) &&
  250. lastUser.id < lastAssistant.id
  251. ) {
  252. log.info("exiting loop", { sessionID })
  253. break
  254. }
  255. step++
  256. if (step === 1)
  257. ensureTitle({
  258. session: await Session.get(sessionID),
  259. modelID: lastUser.model.modelID,
  260. providerID: lastUser.model.providerID,
  261. message: msgs.find((m) => m.info.role === "user")!,
  262. history: msgs,
  263. })
  264. const model = await Provider.getModel(lastUser.model.providerID, lastUser.model.modelID)
  265. const task = tasks.pop()
  266. // pending subtask
  267. // TODO: centralize "invoke tool" logic
  268. if (task?.type === "subtask") {
  269. const taskTool = await TaskTool.init()
  270. const assistantMessage = (await Session.updateMessage({
  271. id: Identifier.ascending("message"),
  272. role: "assistant",
  273. parentID: lastUser.id,
  274. sessionID,
  275. mode: task.agent,
  276. path: {
  277. cwd: Instance.directory,
  278. root: Instance.worktree,
  279. },
  280. cost: 0,
  281. tokens: {
  282. input: 0,
  283. output: 0,
  284. reasoning: 0,
  285. cache: { read: 0, write: 0 },
  286. },
  287. modelID: model.modelID,
  288. providerID: model.providerID,
  289. time: {
  290. created: Date.now(),
  291. },
  292. })) as MessageV2.Assistant
  293. let part = (await Session.updatePart({
  294. id: Identifier.ascending("part"),
  295. messageID: assistantMessage.id,
  296. sessionID: assistantMessage.sessionID,
  297. type: "tool",
  298. callID: ulid(),
  299. tool: TaskTool.id,
  300. state: {
  301. status: "running",
  302. input: {
  303. prompt: task.prompt,
  304. description: task.description,
  305. subagent_type: task.agent,
  306. },
  307. time: {
  308. start: Date.now(),
  309. },
  310. },
  311. })) as MessageV2.ToolPart
  312. const result = await taskTool
  313. .execute(
  314. {
  315. prompt: task.prompt,
  316. description: task.description,
  317. subagent_type: task.agent,
  318. },
  319. {
  320. agent: task.agent,
  321. messageID: assistantMessage.id,
  322. sessionID: sessionID,
  323. abort,
  324. async metadata(input) {
  325. await Session.updatePart({
  326. ...part,
  327. type: "tool",
  328. state: {
  329. ...part.state,
  330. ...input,
  331. },
  332. } satisfies MessageV2.ToolPart)
  333. },
  334. },
  335. )
  336. .catch(() => {})
  337. assistantMessage.finish = "tool-calls"
  338. assistantMessage.time.completed = Date.now()
  339. await Session.updateMessage(assistantMessage)
  340. if (result && part.state.status === "running") {
  341. await Session.updatePart({
  342. ...part,
  343. state: {
  344. status: "completed",
  345. input: part.state.input,
  346. title: result.title,
  347. metadata: result.metadata,
  348. output: result.output,
  349. attachments: result.attachments,
  350. time: {
  351. ...part.state.time,
  352. end: Date.now(),
  353. },
  354. },
  355. } satisfies MessageV2.ToolPart)
  356. }
  357. if (!result) {
  358. await Session.updatePart({
  359. ...part,
  360. state: {
  361. status: "error",
  362. error: "Tool execution failed",
  363. time: {
  364. start: part.state.status === "running" ? part.state.time.start : Date.now(),
  365. end: Date.now(),
  366. },
  367. metadata: part.metadata,
  368. input: part.state.input,
  369. },
  370. } satisfies MessageV2.ToolPart)
  371. }
  372. continue
  373. }
  374. // pending compaction
  375. if (task?.type === "compaction") {
  376. const result = await SessionCompaction.process({
  377. messages: msgs,
  378. parentID: lastUser.id,
  379. abort,
  380. agent: lastUser.agent,
  381. model: {
  382. providerID: model.providerID,
  383. modelID: model.modelID,
  384. },
  385. sessionID,
  386. })
  387. if (result === "stop") break
  388. continue
  389. }
  390. // context overflow, needs compaction
  391. if (
  392. lastFinished &&
  393. lastFinished.summary !== true &&
  394. SessionCompaction.isOverflow({ tokens: lastFinished.tokens, model: model.info })
  395. ) {
  396. await SessionCompaction.create({
  397. sessionID,
  398. agent: lastUser.agent,
  399. model: lastUser.model,
  400. })
  401. continue
  402. }
  403. // normal processing
  404. const agent = await Agent.get(lastUser.agent)
  405. msgs = insertReminders({
  406. messages: msgs,
  407. agent,
  408. })
  409. const processor = SessionProcessor.create({
  410. assistantMessage: (await Session.updateMessage({
  411. id: Identifier.ascending("message"),
  412. parentID: lastUser.id,
  413. role: "assistant",
  414. mode: agent.name,
  415. path: {
  416. cwd: Instance.directory,
  417. root: Instance.worktree,
  418. },
  419. cost: 0,
  420. tokens: {
  421. input: 0,
  422. output: 0,
  423. reasoning: 0,
  424. cache: { read: 0, write: 0 },
  425. },
  426. modelID: model.modelID,
  427. providerID: model.providerID,
  428. time: {
  429. created: Date.now(),
  430. },
  431. sessionID,
  432. })) as MessageV2.Assistant,
  433. sessionID: sessionID,
  434. model: model.info,
  435. providerID: model.providerID,
  436. abort,
  437. })
  438. const system = await resolveSystemPrompt({
  439. providerID: model.providerID,
  440. modelID: model.info.id,
  441. agent,
  442. system: lastUser.system,
  443. })
  444. const tools = await resolveTools({
  445. agent,
  446. sessionID,
  447. model: lastUser.model,
  448. tools: lastUser.tools,
  449. processor,
  450. })
  451. const params = await Plugin.trigger(
  452. "chat.params",
  453. {
  454. sessionID: sessionID,
  455. agent: lastUser.agent,
  456. model: model.info,
  457. provider: await Provider.getProvider(model.providerID),
  458. message: lastUser,
  459. },
  460. {
  461. temperature: model.info.temperature
  462. ? (agent.temperature ?? ProviderTransform.temperature(model.providerID, model.modelID))
  463. : undefined,
  464. topP: agent.topP ?? ProviderTransform.topP(model.providerID, model.modelID),
  465. options: pipe(
  466. {},
  467. mergeDeep(ProviderTransform.options(model.providerID, model.modelID, model.npm ?? "", sessionID)),
  468. mergeDeep(model.info.options),
  469. mergeDeep(agent.options),
  470. ),
  471. },
  472. )
  473. if (step === 1) {
  474. SessionSummary.summarize({
  475. sessionID: sessionID,
  476. messageID: lastUser.id,
  477. })
  478. }
  479. const result = await processor.process(() =>
  480. streamText({
  481. onError(error) {
  482. log.error("stream error", {
  483. error,
  484. })
  485. },
  486. async experimental_repairToolCall(input) {
  487. const lower = input.toolCall.toolName.toLowerCase()
  488. if (lower !== input.toolCall.toolName && tools[lower]) {
  489. log.info("repairing tool call", {
  490. tool: input.toolCall.toolName,
  491. repaired: lower,
  492. })
  493. return {
  494. ...input.toolCall,
  495. toolName: lower,
  496. }
  497. }
  498. return {
  499. ...input.toolCall,
  500. input: JSON.stringify({
  501. tool: input.toolCall.toolName,
  502. error: input.error.message,
  503. }),
  504. toolName: "invalid",
  505. }
  506. },
  507. headers: {
  508. ...(model.providerID.startsWith("opencode")
  509. ? {
  510. "x-opencode-session": sessionID,
  511. "x-opencode-request": lastUser.id,
  512. }
  513. : undefined),
  514. ...model.info.headers,
  515. },
  516. // set to 0, we handle loop
  517. maxRetries: 0,
  518. activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
  519. maxOutputTokens: ProviderTransform.maxOutputTokens(
  520. model.providerID,
  521. params.options,
  522. model.info.limit.output,
  523. OUTPUT_TOKEN_MAX,
  524. ),
  525. abortSignal: abort,
  526. providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, params.options),
  527. stopWhen: stepCountIs(1),
  528. temperature: params.temperature,
  529. topP: params.topP,
  530. messages: [
  531. ...system.map(
  532. (x): ModelMessage => ({
  533. role: "system",
  534. content: x,
  535. }),
  536. ),
  537. ...MessageV2.toModelMessage(
  538. msgs.filter((m) => {
  539. if (m.info.role !== "assistant" || m.info.error === undefined) {
  540. return true
  541. }
  542. if (
  543. MessageV2.AbortedError.isInstance(m.info.error) &&
  544. m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
  545. ) {
  546. return true
  547. }
  548. return false
  549. }),
  550. ),
  551. ],
  552. tools: model.info.tool_call === false ? undefined : tools,
  553. model: wrapLanguageModel({
  554. model: model.language,
  555. middleware: [
  556. {
  557. async transformParams(args) {
  558. if (args.type === "stream") {
  559. // @ts-expect-error
  560. args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
  561. }
  562. return args.params
  563. },
  564. },
  565. ],
  566. }),
  567. }),
  568. )
  569. if (result === "stop") break
  570. continue
  571. }
  572. SessionCompaction.prune({ sessionID })
  573. for await (const item of MessageV2.stream(sessionID)) {
  574. if (item.info.role === "user") continue
  575. const queued = state()[sessionID]?.callbacks ?? []
  576. for (const q of queued) {
  577. q.resolve(item)
  578. }
  579. return item
  580. }
  581. throw new Error("Impossible")
  582. })
  583. async function lastModel(sessionID: string) {
  584. for await (const item of MessageV2.stream(sessionID)) {
  585. if (item.info.role === "user" && item.info.model) return item.info.model
  586. }
  587. return Provider.defaultModel()
  588. }
  589. async function resolveSystemPrompt(input: {
  590. system?: string
  591. agent: Agent.Info
  592. providerID: string
  593. modelID: string
  594. }) {
  595. let system = SystemPrompt.header(input.providerID)
  596. system.push(
  597. ...(() => {
  598. if (input.system) return [input.system]
  599. if (input.agent.prompt) return [input.agent.prompt]
  600. return SystemPrompt.provider(input.modelID)
  601. })(),
  602. )
  603. system.push(...(await SystemPrompt.environment()))
  604. system.push(...(await SystemPrompt.custom()))
  605. // max 2 system prompt messages for caching purposes
  606. const [first, ...rest] = system
  607. system = [first, rest.join("\n")]
  608. return system
  609. }
  610. async function resolveTools(input: {
  611. agent: Agent.Info
  612. model: {
  613. providerID: string
  614. modelID: string
  615. }
  616. sessionID: string
  617. tools?: Record<string, boolean>
  618. processor: SessionProcessor.Info
  619. }) {
  620. const tools: Record<string, AITool> = {}
  621. const enabledTools = pipe(
  622. input.agent.tools,
  623. mergeDeep(await ToolRegistry.enabled(input.model.providerID, input.model.modelID, input.agent)),
  624. mergeDeep(input.tools ?? {}),
  625. )
  626. for (const item of await ToolRegistry.tools(input.model.providerID, input.model.modelID)) {
  627. if (Wildcard.all(item.id, enabledTools) === false) continue
  628. const schema = ProviderTransform.schema(
  629. input.model.providerID,
  630. input.model.modelID,
  631. z.toJSONSchema(item.parameters),
  632. )
  633. tools[item.id] = tool({
  634. id: item.id as any,
  635. description: item.description,
  636. inputSchema: jsonSchema(schema as any),
  637. async execute(args, options) {
  638. await Plugin.trigger(
  639. "tool.execute.before",
  640. {
  641. tool: item.id,
  642. sessionID: input.sessionID,
  643. callID: options.toolCallId,
  644. },
  645. {
  646. args,
  647. },
  648. )
  649. const result = await item.execute(args, {
  650. sessionID: input.sessionID,
  651. abort: options.abortSignal!,
  652. messageID: input.processor.message.id,
  653. callID: options.toolCallId,
  654. extra: input.model,
  655. agent: input.agent.name,
  656. metadata: async (val) => {
  657. const match = input.processor.partFromToolCall(options.toolCallId)
  658. if (match && match.state.status === "running") {
  659. await Session.updatePart({
  660. ...match,
  661. state: {
  662. title: val.title,
  663. metadata: val.metadata,
  664. status: "running",
  665. input: args,
  666. time: {
  667. start: Date.now(),
  668. },
  669. },
  670. })
  671. }
  672. },
  673. })
  674. await Plugin.trigger(
  675. "tool.execute.after",
  676. {
  677. tool: item.id,
  678. sessionID: input.sessionID,
  679. callID: options.toolCallId,
  680. },
  681. result,
  682. )
  683. return result
  684. },
  685. toModelOutput(result) {
  686. return {
  687. type: "text",
  688. value: result.output,
  689. }
  690. },
  691. })
  692. }
  693. for (const [key, item] of Object.entries(await MCP.tools())) {
  694. if (Wildcard.all(key, enabledTools) === false) continue
  695. const execute = item.execute
  696. if (!execute) continue
  697. item.execute = async (args, opts) => {
  698. await Plugin.trigger(
  699. "tool.execute.before",
  700. {
  701. tool: key,
  702. sessionID: input.sessionID,
  703. callID: opts.toolCallId,
  704. },
  705. {
  706. args,
  707. },
  708. )
  709. const result = await execute(args, opts)
  710. await Plugin.trigger(
  711. "tool.execute.after",
  712. {
  713. tool: key,
  714. sessionID: input.sessionID,
  715. callID: opts.toolCallId,
  716. },
  717. result,
  718. )
  719. const textParts: string[] = []
  720. const attachments: MessageV2.FilePart[] = []
  721. for (const item of result.content) {
  722. if (item.type === "text") {
  723. textParts.push(item.text)
  724. } else if (item.type === "image") {
  725. attachments.push({
  726. id: Identifier.ascending("part"),
  727. sessionID: input.sessionID,
  728. messageID: input.processor.message.id,
  729. type: "file",
  730. mime: item.mimeType,
  731. url: `data:${item.mimeType};base64,${item.data}`,
  732. })
  733. }
  734. // Add support for other types if needed
  735. }
  736. return {
  737. title: "",
  738. metadata: result.metadata ?? {},
  739. output: textParts.join("\n\n"),
  740. attachments,
  741. content: result.content, // directly return content to preserve ordering when outputting to model
  742. }
  743. }
  744. item.toModelOutput = (result) => {
  745. return {
  746. type: "text",
  747. value: result.output,
  748. }
  749. }
  750. tools[key] = item
  751. }
  752. return tools
  753. }
  754. async function createUserMessage(input: PromptInput) {
  755. const agent = await Agent.get(input.agent ?? "build")
  756. const info: MessageV2.Info = {
  757. id: input.messageID ?? Identifier.ascending("message"),
  758. role: "user",
  759. sessionID: input.sessionID,
  760. time: {
  761. created: Date.now(),
  762. },
  763. tools: input.tools,
  764. system: input.system,
  765. agent: agent.name,
  766. model: input.model ?? agent.model ?? (await lastModel(input.sessionID)),
  767. }
  768. const parts = await Promise.all(
  769. input.parts.map(async (part): Promise<MessageV2.Part[]> => {
  770. if (part.type === "file") {
  771. const url = new URL(part.url)
  772. switch (url.protocol) {
  773. case "data:":
  774. if (part.mime === "text/plain") {
  775. return [
  776. {
  777. id: Identifier.ascending("part"),
  778. messageID: info.id,
  779. sessionID: input.sessionID,
  780. type: "text",
  781. synthetic: true,
  782. text: `Called the Read tool with the following input: ${JSON.stringify({ filePath: part.filename })}`,
  783. },
  784. {
  785. id: Identifier.ascending("part"),
  786. messageID: info.id,
  787. sessionID: input.sessionID,
  788. type: "text",
  789. synthetic: true,
  790. text: Buffer.from(part.url, "base64url").toString(),
  791. },
  792. {
  793. ...part,
  794. id: part.id ?? Identifier.ascending("part"),
  795. messageID: info.id,
  796. sessionID: input.sessionID,
  797. },
  798. ]
  799. }
  800. break
  801. case "file:":
  802. log.info("file", { mime: part.mime })
  803. // have to normalize, symbol search returns absolute paths
  804. // Decode the pathname since URL constructor doesn't automatically decode it
  805. const filepath = fileURLToPath(part.url)
  806. const stat = await Bun.file(filepath).stat()
  807. if (stat.isDirectory()) {
  808. part.mime = "application/x-directory"
  809. }
  810. if (part.mime === "text/plain") {
  811. let offset: number | undefined = undefined
  812. let limit: number | undefined = undefined
  813. const range = {
  814. start: url.searchParams.get("start"),
  815. end: url.searchParams.get("end"),
  816. }
  817. if (range.start != null) {
  818. const filePathURI = part.url.split("?")[0]
  819. let start = parseInt(range.start)
  820. let end = range.end ? parseInt(range.end) : undefined
  821. // some LSP servers (eg, gopls) don't give full range in
  822. // workspace/symbol searches, so we'll try to find the
  823. // symbol in the document to get the full range
  824. if (start === end) {
  825. const symbols = await LSP.documentSymbol(filePathURI)
  826. for (const symbol of symbols) {
  827. let range: LSP.Range | undefined
  828. if ("range" in symbol) {
  829. range = symbol.range
  830. } else if ("location" in symbol) {
  831. range = symbol.location.range
  832. }
  833. if (range?.start?.line && range?.start?.line === start) {
  834. start = range.start.line
  835. end = range?.end?.line ?? start
  836. break
  837. }
  838. }
  839. }
  840. offset = Math.max(start - 1, 0)
  841. if (end) {
  842. limit = end - offset
  843. }
  844. }
  845. const args = { filePath: filepath, offset, limit }
  846. const pieces: MessageV2.Part[] = [
  847. {
  848. id: Identifier.ascending("part"),
  849. messageID: info.id,
  850. sessionID: input.sessionID,
  851. type: "text",
  852. synthetic: true,
  853. text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
  854. },
  855. ]
  856. await ReadTool.init()
  857. .then(async (t) => {
  858. const result = await t.execute(args, {
  859. sessionID: input.sessionID,
  860. abort: new AbortController().signal,
  861. agent: input.agent!,
  862. messageID: info.id,
  863. extra: { bypassCwdCheck: true, ...info.model },
  864. metadata: async () => {},
  865. })
  866. pieces.push({
  867. id: Identifier.ascending("part"),
  868. messageID: info.id,
  869. sessionID: input.sessionID,
  870. type: "text",
  871. synthetic: true,
  872. text: result.output,
  873. })
  874. if (result.attachments?.length) {
  875. pieces.push(
  876. ...result.attachments.map((attachment) => ({
  877. ...attachment,
  878. synthetic: true,
  879. filename: attachment.filename ?? part.filename,
  880. messageID: info.id,
  881. sessionID: input.sessionID,
  882. })),
  883. )
  884. } else {
  885. pieces.push({
  886. ...part,
  887. id: part.id ?? Identifier.ascending("part"),
  888. messageID: info.id,
  889. sessionID: input.sessionID,
  890. })
  891. }
  892. })
  893. .catch((error) => {
  894. log.error("failed to read file", { error })
  895. const message = error instanceof Error ? error.message : error.toString()
  896. Bus.publish(Session.Event.Error, {
  897. sessionID: input.sessionID,
  898. error: new NamedError.Unknown({
  899. message,
  900. }).toObject(),
  901. })
  902. pieces.push({
  903. id: Identifier.ascending("part"),
  904. messageID: info.id,
  905. sessionID: input.sessionID,
  906. type: "text",
  907. synthetic: true,
  908. text: `Read tool failed to read ${filepath} with the following error: ${message}`,
  909. })
  910. })
  911. return pieces
  912. }
  913. if (part.mime === "application/x-directory") {
  914. const args = { path: filepath }
  915. const result = await ListTool.init().then((t) =>
  916. t.execute(args, {
  917. sessionID: input.sessionID,
  918. abort: new AbortController().signal,
  919. agent: input.agent!,
  920. messageID: info.id,
  921. extra: { bypassCwdCheck: true },
  922. metadata: async () => {},
  923. }),
  924. )
  925. return [
  926. {
  927. id: Identifier.ascending("part"),
  928. messageID: info.id,
  929. sessionID: input.sessionID,
  930. type: "text",
  931. synthetic: true,
  932. text: `Called the list tool with the following input: ${JSON.stringify(args)}`,
  933. },
  934. {
  935. id: Identifier.ascending("part"),
  936. messageID: info.id,
  937. sessionID: input.sessionID,
  938. type: "text",
  939. synthetic: true,
  940. text: result.output,
  941. },
  942. {
  943. ...part,
  944. id: part.id ?? Identifier.ascending("part"),
  945. messageID: info.id,
  946. sessionID: input.sessionID,
  947. },
  948. ]
  949. }
  950. const file = Bun.file(filepath)
  951. FileTime.read(input.sessionID, filepath)
  952. return [
  953. {
  954. id: Identifier.ascending("part"),
  955. messageID: info.id,
  956. sessionID: input.sessionID,
  957. type: "text",
  958. text: `Called the Read tool with the following input: {\"filePath\":\"${filepath}\"}`,
  959. synthetic: true,
  960. },
  961. {
  962. id: part.id ?? Identifier.ascending("part"),
  963. messageID: info.id,
  964. sessionID: input.sessionID,
  965. type: "file",
  966. url: `data:${part.mime};base64,` + Buffer.from(await file.bytes()).toString("base64"),
  967. mime: part.mime,
  968. filename: part.filename!,
  969. source: part.source,
  970. },
  971. ]
  972. }
  973. }
  974. if (part.type === "agent") {
  975. return [
  976. {
  977. id: Identifier.ascending("part"),
  978. ...part,
  979. messageID: info.id,
  980. sessionID: input.sessionID,
  981. },
  982. {
  983. id: Identifier.ascending("part"),
  984. messageID: info.id,
  985. sessionID: input.sessionID,
  986. type: "text",
  987. synthetic: true,
  988. text:
  989. "Use the above message and context to generate a prompt and call the task tool with subagent: " +
  990. part.name,
  991. },
  992. ]
  993. }
  994. return [
  995. {
  996. id: Identifier.ascending("part"),
  997. ...part,
  998. messageID: info.id,
  999. sessionID: input.sessionID,
  1000. },
  1001. ]
  1002. }),
  1003. ).then((x) => x.flat())
  1004. await Plugin.trigger(
  1005. "chat.message",
  1006. {
  1007. sessionID: input.sessionID,
  1008. agent: input.agent,
  1009. model: input.model,
  1010. messageID: input.messageID,
  1011. },
  1012. {
  1013. message: info,
  1014. parts,
  1015. },
  1016. )
  1017. await Session.updateMessage(info)
  1018. for (const part of parts) {
  1019. await Session.updatePart(part)
  1020. }
  1021. return {
  1022. info,
  1023. parts,
  1024. }
  1025. }
  1026. function insertReminders(input: { messages: MessageV2.WithParts[]; agent: Agent.Info }) {
  1027. const userMessage = input.messages.findLast((msg) => msg.info.role === "user")
  1028. if (!userMessage) return input.messages
  1029. if (input.agent.name === "plan") {
  1030. userMessage.parts.push({
  1031. id: Identifier.ascending("part"),
  1032. messageID: userMessage.info.id,
  1033. sessionID: userMessage.info.sessionID,
  1034. type: "text",
  1035. text: PROMPT_PLAN,
  1036. synthetic: true,
  1037. })
  1038. }
  1039. const wasPlan = input.messages.some((msg) => msg.info.role === "assistant" && msg.info.mode === "plan")
  1040. if (wasPlan && input.agent.name === "build") {
  1041. userMessage.parts.push({
  1042. id: Identifier.ascending("part"),
  1043. messageID: userMessage.info.id,
  1044. sessionID: userMessage.info.sessionID,
  1045. type: "text",
  1046. text: BUILD_SWITCH,
  1047. synthetic: true,
  1048. })
  1049. }
  1050. return input.messages
  1051. }
  1052. export const ShellInput = z.object({
  1053. sessionID: Identifier.schema("session"),
  1054. agent: z.string(),
  1055. model: z
  1056. .object({
  1057. providerID: z.string(),
  1058. modelID: z.string(),
  1059. })
  1060. .optional(),
  1061. command: z.string(),
  1062. })
  1063. export type ShellInput = z.infer<typeof ShellInput>
  /**
   * Runs `input.command` in the user's shell on their behalf and records the run in the session.
   *
   * It writes a synthetic user message ("The following tool was executed by the user"). It then
   * writes an assistant message holding a single `bash` tool part. Combined stdout/stderr output
   * is streamed into that part's metadata while the process runs. When the process closes, the
   * part is marked completed with the full output.
   *
   * The shell comes from `$SHELL`, falling back to `bash` when it is unset or empty. It runs in
   * `Instance.directory` with `TERM=dumb`. If the shell cannot be spawned, the error message is
   * recorded as the tool output instead of crashing the process.
   *
   * @returns the assistant message and its single tool part
   */
  export async function shell(input: ShellInput) {
    const session = await Session.get(input.sessionID)
    if (session.revert) {
      // Await so the revert cleanup finishes before new messages are written.
      // Any failure then surfaces here instead of as an unhandled rejection.
      await SessionRevert.cleanup(session)
    }
    const agent = await Agent.get(input.agent)
    const model = input.model ?? agent.model ?? (await lastModel(input.sessionID))
    const userMsg: MessageV2.User = {
      id: Identifier.ascending("message"),
      sessionID: input.sessionID,
      time: {
        created: Date.now(),
      },
      role: "user",
      agent: input.agent,
      model: {
        providerID: model.providerID,
        modelID: model.modelID,
      },
    }
    await Session.updateMessage(userMsg)
    const userPart: MessageV2.Part = {
      type: "text",
      id: Identifier.ascending("part"),
      messageID: userMsg.id,
      sessionID: input.sessionID,
      text: "The following tool was executed by the user",
      synthetic: true,
    }
    await Session.updatePart(userPart)
    const msg: MessageV2.Assistant = {
      id: Identifier.ascending("message"),
      sessionID: input.sessionID,
      parentID: userMsg.id,
      mode: input.agent,
      cost: 0,
      path: {
        cwd: Instance.directory,
        root: Instance.worktree,
      },
      time: {
        created: Date.now(),
      },
      role: "assistant",
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: model.modelID,
      providerID: model.providerID,
    }
    await Session.updateMessage(msg)
    const part: MessageV2.Part = {
      type: "tool",
      id: Identifier.ascending("part"),
      messageID: msg.id,
      sessionID: input.sessionID,
      tool: "bash",
      callID: ulid(),
      state: {
        status: "running",
        time: {
          start: Date.now(),
        },
        input: {
          command: input.command,
        },
      },
    }
    await Session.updatePart(part)

    // `||` rather than `??`: an empty $SHELL is not a usable executable.
    const shellPath = process.env["SHELL"] || "bash"
    const proc = spawn(shellPath, shellInvocationArgs(path.basename(shellPath), input.command), {
      cwd: Instance.directory,
      detached: true,
      stdio: ["ignore", "pipe", "pipe"],
      env: {
        ...process.env,
        TERM: "dumb",
      },
    })

    let output = ""
    // Decode each stream as UTF-8 through a stream decoder.
    // Multi-byte characters split across chunks then stay intact.
    proc.stdout?.setEncoding("utf8")
    proc.stderr?.setEncoding("utf8")
    const append = (chunk: string) => {
      output += chunk
      if (part.state.status !== "running") return
      part.state.metadata = {
        output,
        description: "",
      }
      // Live progress update is not awaited; the final state is awaited after the process closes.
      Session.updatePart(part)
    }
    proc.stdout?.on("data", append)
    proc.stderr?.on("data", append)

    await new Promise<void>((resolve) => {
      // An EventEmitter "error" with no listener is thrown.
      // A failed spawn (e.g. a missing shell binary) would otherwise crash the process,
      // so record the error in the output and finish normally instead.
      proc.once("error", (error) => {
        output += (output ? "\n" : "") + error.message
        resolve()
      })
      proc.once("close", () => resolve())
    })

    msg.time.completed = Date.now()
    await Session.updateMessage(msg)
    if (part.state.status === "running") {
      part.state = {
        status: "completed",
        time: {
          ...part.state.time,
          end: Date.now(),
        },
        input: part.state.input,
        title: "",
        metadata: {
          output,
          description: "",
        },
        output,
      }
      await Session.updatePart(part)
    }
    return { info: msg, parts: [part] }
  }

  /**
   * Builds the argv passed to the user's shell (identified by its basename) to run `command`.
   *
   * - zsh: login shell (`-l`) that first sources ~/.zshenv and `${ZDOTDIR:-$HOME}/.zshrc` if
   *   present, with their output and failures suppressed.
   * - bash: login shell that first sources ~/.bashrc if present, likewise suppressed.
   * - nu, fish: `-c <command>` only.
   * - anything else: `-c -l <command>`.
   *
   * A switch is used instead of an object lookup so names such as "constructor" cannot resolve to
   * inherited `Object.prototype` members.
   */
  function shellInvocationArgs(shellName: string, command: string): string[] {
    switch (shellName) {
      case "nu":
      case "fish":
        return ["-c", command]
      case "zsh":
        return [
          "-c",
          "-l",
          `
            [[ -f ~/.zshenv ]] && source ~/.zshenv >/dev/null 2>&1 || true
            [[ -f "\${ZDOTDIR:-$HOME}/.zshrc" ]] && source "\${ZDOTDIR:-$HOME}/.zshrc" >/dev/null 2>&1 || true
            ${command}
          `,
        ]
      case "bash":
        return [
          "-c",
          "-l",
          `
            [[ -f ~/.bashrc ]] && source ~/.bashrc >/dev/null 2>&1 || true
            ${command}
          `,
        ]
      default:
        return ["-c", "-l", command]
    }
  }
  /**
   * Input for {@link command}.
   *
   * - `command`: name of the command to run, looked up with `Command.get`.
   * - `arguments`: raw argument string. It is split into positional `$1…$n` values and also
   *   substituted verbatim for `$ARGUMENTS`.
   * - `agent`: fallback agent name, used when the command does not name one (default "build").
   * - `model`: fallback model string (parsed with `Provider.parseModel`). It is used when neither
   *   the command nor the command's agent specifies a model.
   * - `messageID`: optional message ID forwarded to `prompt`.
   */
  export const CommandInput = z.object({
    messageID: Identifier.schema("message").optional(),
    sessionID: Identifier.schema("session"),
    agent: z.string().optional(),
    model: z.string().optional(),
    arguments: z.string(),
    command: z.string(),
  })
  export type CommandInput = z.infer<typeof CommandInput>
  /**
   * Tokenizes a raw argument string. A token is a run of non-space, non-quote characters and/or
   * "…" / '…' groups, concatenated, so `a"b c"` is a single token.
   */
  const argsRegex = /(?:[^\s"']+|"[^"]*"|'[^']*')+/g
  /**
   * Strips one leading and one trailing quote character from a token. Either quote kind is
   * accepted at each end; the two need not match.
   */
  const quoteTrimRegex = /^["']|["']$/g
  /** Positional placeholders `$1`, `$2`, … in a command template; capture group 1 is the number. */
  const placeholderRegex = /\$(\d+)/g
  /** Inline shell snippets written as !`cmd` in a command template; capture group 1 is the command. */
  const bashRegex = /!`([^`]+)`/g
  /**
   * Expands the named command's template with `input.arguments` and submits the result to the
   * session via `prompt`.
   *
   * Template expansion, in order:
   * 1. Positional placeholders `$1`, `$2`, … are replaced with the arguments tokenized by
   *    `argsRegex`, with surrounding quotes stripped.
   *    - The highest-numbered placeholder receives all remaining arguments joined by spaces.
   *    - Placeholders with no matching argument become "".
   *    - `$0` is not a placeholder and is left untouched.
   * 2. `$ARGUMENTS` is replaced with the raw argument string, inserted literally.
   * 3. Inline shell snippets (!`cmd`) reported by `ConfigMarkdown.shell` are run with Bun's `$`.
   *    Each is replaced, in order, with its output (or an "Error executing command" message).
   *    This runs after steps 1–2, so argument text that lands inside a snippet is executed as
   *    shell.
   *
   * Agent selection: the command's agent, else `input.agent`, else "build".
   * Model selection: the command's model, else the command agent's model, else `input.model`,
   * else the session's last model.
   * The prompt is sent as one subtask part when either condition holds:
   * - the agent is a subagent and the command does not set `subtask: false`;
   * - the command sets `subtask: true`.
   * Otherwise the prompt is sent as the parts resolved from the template.
   *
   * Publishes `Command.Event.Executed` and returns the resulting message.
   */
  export async function command(input: CommandInput) {
    log.info("command", input)
    const command = await Command.get(input.command)
    const agentName = command.agent ?? input.agent ?? "build"

    const args = (input.arguments.match(argsRegex) ?? []).map((arg) => arg.replace(quoteTrimRegex, ""))
    const placeholders = command.template.match(placeholderRegex) ?? []
    const last = Math.max(0, ...placeholders.map((item) => Number(item.slice(1))))
    // Let the final placeholder swallow any extra arguments so prompts read naturally
    const withArgs = command.template.replaceAll(placeholderRegex, (match: string, digits: string) => {
      const position = Number(digits)
      // Placeholders are 1-based. Without this guard, `$0` (e.g. in "$0.50") would read args[-1].
      if (position < 1) return match
      const argIndex = position - 1
      if (argIndex >= args.length) return ""
      if (position === last) return args.slice(argIndex).join(" ")
      return args[argIndex] ?? ""
    })
    // Use a function replacer. With a string replacement, `$$`, `$&`, `` $` `` and `$'` in the
    // user's arguments would be treated as replacement patterns instead of inserted literally.
    let template = withArgs.replaceAll("$ARGUMENTS", () => input.arguments)

    const snippets = ConfigMarkdown.shell(template)
    if (snippets.length > 0) {
      const results = await Promise.all(
        snippets.map(async ([, cmd]) => {
          try {
            return await $`${{ raw: cmd }}`.nothrow().text()
          } catch (error) {
            return `Error executing command: ${error instanceof Error ? error.message : String(error)}`
          }
        }),
      )
      let index = 0
      // Outputs are substituted in match order.
      // NOTE(review): this assumes ConfigMarkdown.shell finds exactly the snippets bashRegex
      // matches. A snippet with no corresponding result is left as written instead of becoming
      // "undefined".
      template = template.replace(bashRegex, (match: string) => results[index++] ?? match)
    }
    template = template.trim()

    const agent = await Agent.get(agentName)
    const model = await (async () => {
      if (command.model) return Provider.parseModel(command.model)
      // When command.agent is set it is also agentName, so `agent` is the command's agent.
      if (command.agent && agent.model) return agent.model
      if (input.model) return Provider.parseModel(input.model)
      return await lastModel(input.sessionID)
    })()

    const asSubtask = (agent.mode === "subagent" && command.subtask !== false) || command.subtask === true
    const parts = asSubtask
      ? [
          {
            type: "subtask" as const,
            agent: agent.name,
            description: command.description ?? "",
            // TODO: how can we make task tool accept a more complex input?
            prompt: await resolvePromptParts(template).then((x) => x.find((y) => y.type === "text")?.text ?? ""),
          },
        ]
      : await resolvePromptParts(template)

    const result = (await prompt({
      sessionID: input.sessionID,
      messageID: input.messageID,
      model,
      agent: agentName,
      parts,
    })) as MessageV2.WithParts
    Bus.publish(Command.Event.Executed, {
      name: input.command,
      sessionID: input.sessionID,
      arguments: input.arguments,
      messageID: result.info.id,
    })
    return result
  }
  /**
   * Best-effort title generation for a session.
   *
   * It runs only when all of these hold:
   * - the session is a root session (no `parentID`);
   * - the session still has the default title;
   * - `history` holds exactly one user message with at least one non-synthetic part.
   *
   * The title is generated from `message` by the provider's small model, falling back to the
   * requested model. It is reduced to the first non-empty line of the reply, after removing
   * `<think>…</think>` blocks. Lines over 100 characters are cut to 97 characters plus "...".
   *
   * Failures of the model call or of the session update are logged, not thrown.
   */
  async function ensureTitle(input: {
    session: Session.Info
    message: MessageV2.WithParts
    history: MessageV2.WithParts[]
    providerID: string
    modelID: string
  }) {
    if (input.session.parentID || !Session.isDefaultTitle(input.session.title)) return

    const realUserMessages = input.history.filter(
      (entry) => entry.info.role === "user" && entry.parts.some((p) => !("synthetic" in p && p.synthetic)),
    )
    if (realUserMessages.length !== 1) return

    const small =
      (await Provider.getSmallModel(input.providerID)) ?? (await Provider.getModel(input.providerID, input.modelID))
    const options = pipe(
      {},
      mergeDeep(ProviderTransform.options(small.providerID, small.modelID, small.npm ?? "", input.session.id)),
      mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, modelID: small.modelID })),
      mergeDeep(small.info.options),
    )
    // Build the request before the try block.
    // Errors while assembling it propagate to the caller; only the model call and the update are caught.
    const providerOptions = ProviderTransform.providerOptions(small.npm, small.providerID, options)
    const systemMessages = SystemPrompt.title(small.providerID).map(
      (content): ModelMessage => ({ role: "system", content }),
    )
    const messages: ModelMessage[] = [
      ...systemMessages,
      { role: "user", content: "Generate a title for this conversation:\n" },
      ...MessageV2.toModelMessage([
        {
          info: {
            id: Identifier.ascending("message"),
            role: "user",
            sessionID: input.session.id,
            time: { created: Date.now() },
            agent: input.message.info.role === "user" ? input.message.info.agent : "build",
            model: { providerID: input.providerID, modelID: input.modelID },
          },
          parts: input.message.parts,
        },
      ]),
    ]

    try {
      const result = await generateText({
        // use higher # for reasoning models since reasoning tokens eat up a lot of the budget
        maxOutputTokens: small.info.reasoning ? 3000 : 20,
        providerOptions,
        messages,
        headers: small.info.headers,
        model: small.language,
      })
      if (!result.text) return
      const firstLine = result.text
        .replace(/<think>[\s\S]*?<\/think>\s*/g, "")
        .split("\n")
        .map((line) => line.trim())
        .find((line) => line.length > 0)
      await Session.update(input.session.id, (draft) => {
        if (!firstLine) return
        draft.title = firstLine.length > 100 ? `${firstLine.substring(0, 97)}...` : firstLine
      })
    } catch (error) {
      log.error("failed to generate title", { error, model: small.info.id })
    }
  }
  1399. }