prompt.ts 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599
  1. import path from "path"
  2. import os from "os"
  3. import fs from "fs/promises"
  4. import z from "zod/v4"
  5. import { Identifier } from "../id/id"
  6. import { MessageV2 } from "./message-v2"
  7. import { Log } from "../util/log"
  8. import { SessionRevert } from "./revert"
  9. import { Session } from "."
  10. import { Agent } from "../agent/agent"
  11. import { Provider } from "../provider/provider"
  12. import {
  13. generateText,
  14. streamText,
  15. type ModelMessage,
  16. type Tool as AITool,
  17. tool,
  18. wrapLanguageModel,
  19. type StreamTextResult,
  20. LoadAPIKeyError,
  21. stepCountIs,
  22. jsonSchema,
  23. } from "ai"
  24. import { SessionCompaction } from "./compaction"
  25. import { Instance } from "../project/instance"
  26. import { Bus } from "../bus"
  27. import { ProviderTransform } from "../provider/transform"
  28. import { SystemPrompt } from "./system"
  29. import { Plugin } from "../plugin"
  30. import PROMPT_PLAN from "../session/prompt/plan.txt"
  31. import BUILD_SWITCH from "../session/prompt/build-switch.txt"
  32. import { ModelsDev } from "../provider/models"
  33. import { defer } from "../util/defer"
  34. import { mergeDeep, pipe } from "remeda"
  35. import { ToolRegistry } from "../tool/registry"
  36. import { Wildcard } from "../util/wildcard"
  37. import { MCP } from "../mcp"
  38. import { LSP } from "../lsp"
  39. import { ReadTool } from "../tool/read"
  40. import { ListTool } from "../tool/ls"
  41. import { TaskTool } from "../tool/task"
  42. import { FileTime } from "../file/time"
  43. import { Permission } from "../permission"
  44. import { Snapshot } from "../snapshot"
  45. import { NamedError } from "../util/error"
  46. import { ulid } from "ulid"
  47. import { spawn } from "child_process"
  48. import { Command } from "../command"
  49. import { $ } from "bun"
  50. export namespace SessionPrompt {
  /** Logger used throughout this namespace, created with the "session.prompt" service tag. */
  const log = Log.create({ service: "session.prompt" })

  /**
   * Ceiling applied to a model's own output-token limit when a request is built
   * (also used as the fallback when the model reports no usable limit).
   */
  export const OUTPUT_TOKEN_MAX = 32_000

  // Payload schema of the "session.idle" bus event: just the session id.
  const idlePayload = z.object({
    sessionID: z.string(),
  })

  /** Bus events declared by this namespace. */
  export const Event = {
    Idle: Bus.event("session.idle", idlePayload),
  }
  /**
   * Bookkeeping registered through `Instance.state`.
   *
   * - `pending`: abort controllers keyed by string (presumably the session id of a
   *   run in flight — the writer is outside this excerpt, confirm against `lock`).
   * - `queued`: per session id, the callers waiting for the active run to finish;
   *   each entry remembers the user message it created and the resolver to call.
   *
   * The second callback receives the state object and aborts every pending controller.
   */
  const state = Instance.state(
    () => ({
      pending: new Map<string, AbortController>(),
      queued: new Map<
        string,
        {
          messageID: string
          callback: (input: MessageV2.WithParts) => void
        }[]
      >(),
    }),
    async (current) => {
      for (const controller of current.pending.values()) {
        controller.abort()
      }
    },
  )
  /**
   * Schema of the payload accepted by `prompt`.
   *
   * Each entry of `parts` is a text, file or agent part as defined in `MessageV2`,
   * minus `messageID` / `sessionID` and with `id` made optional.
   */
  export const PromptInput = (() => {
    const textPart = MessageV2.TextPart.omit({ messageID: true, sessionID: true })
      .partial({ id: true })
      .meta({ ref: "TextPartInput" })

    const filePart = MessageV2.FilePart.omit({ messageID: true, sessionID: true })
      .partial({ id: true })
      .meta({ ref: "FilePartInput" })

    const agentPart = MessageV2.AgentPart.omit({ messageID: true, sessionID: true })
      .partial({ id: true })
      .meta({ ref: "AgentPartInput" })

    // Explicit model selection; when absent the model is resolved elsewhere.
    const modelRef = z.object({
      providerID: z.string(),
      modelID: z.string(),
    })

    return z.object({
      sessionID: Identifier.schema("session"),
      messageID: Identifier.schema("message").optional(),
      model: modelRef.optional(),
      agent: z.string().optional(),
      system: z.string().optional(),
      tools: z.record(z.string(), z.boolean()).optional(),
      parts: z.array(z.discriminatedUnion("type", [textPart, filePart, agentPart])),
    })
  })()
  export type PromptInput = z.infer<typeof PromptInput>
  /**
   * Stores the user's message in the session and drives the model until it has
   * produced a final assistant message.
   *
   * If the session is already busy, the message is still stored, but the caller is
   * queued and the returned promise resolves with the result of the active run.
   *
   * Otherwise the function resolves agent, model, system prompt, processor, tools
   * and plugin-adjusted parameters once, then loops: every iteration reloads the
   * message history, opens a new assistant message and streams exactly one step.
   * The loop repeats while the model finished with tool calls or while queued
   * messages newer than the latest assistant message exist.
   *
   * @param input validated prompt payload
   * @returns the last processed assistant message with its parts
   */
  export async function prompt(input: PromptInput): Promise<MessageV2.WithParts> {
    const { sessionID } = input
    const logger = log.clone().tag("session", sessionID)
    logger.info("prompt")

    const session = await Session.get(sessionID)
    await SessionRevert.cleanup(session)
    const userMsg = await createUserMessage(input)
    await Session.touch(sessionID)

    // Another run owns this session: park the caller until that run hands over its result.
    if (isBusy(sessionID)) {
      return new Promise((resolve) => {
        const waiting = state().queued.get(sessionID) ?? []
        waiting.push({
          messageID: userMsg.info.id,
          callback: resolve,
        })
        state().queued.set(sessionID, waiting)
      })
    }

    const agent = await Agent.get(input.agent ?? "build")
    const selected = await resolveModel({ agent, model: input.model })
    const model = await Provider.getModel(selected.providerID, selected.modelID)
    const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
    using abort = lock(sessionID)

    const system = await resolveSystemPrompt({
      providerID: model.providerID,
      modelID: model.info.id,
      agent,
      system: input.system,
    })
    const processor = await createProcessor({
      sessionID,
      model: model.info,
      providerID: model.providerID,
      agent: agent.name,
      system,
      abort: abort.signal,
    })
    const tools = await resolveTools({
      agent,
      sessionID,
      modelID: model.modelID,
      providerID: model.providerID,
      tools: input.tools,
      processor,
    })

    const provider = await Provider.getProvider(model.providerID)
    const params = await Plugin.trigger(
      "chat.params",
      {
        model: model.info,
        provider,
        message: userMsg,
      },
      {
        temperature: model.info.temperature
          ? (agent.temperature ?? ProviderTransform.temperature(model.providerID, model.modelID))
          : undefined,
        topP: agent.topP ?? ProviderTransform.topP(model.providerID, model.modelID),
        options: {
          ...ProviderTransform.options(model.providerID, model.modelID, sessionID),
          ...model.info.options,
          ...agent.options,
        },
      },
    )

    // Only the "opencode" provider receives the session/request correlation headers.
    const requestHeaders =
      model.providerID === "opencode"
        ? {
            "x-opencode-session": sessionID,
            "x-opencode-request": userMsg.info.id,
          }
        : undefined

    // Assistant messages that ended in an error are withheld from the model,
    // except when the error is an abort.
    const isSendable = (m: MessageV2.WithParts) =>
      m.info.role !== "assistant" || !m.info.error || MessageV2.AbortedError.isInstance(m.info.error)

    for (let step = 0; ; step++) {
      const history = await getMessages({
        sessionID,
        model: model.info,
        providerID: model.providerID,
      })
      const msgs: MessageV2.WithParts[] = insertReminders({ messages: history, agent })

      // Title generation is started on the first pass only and is not awaited.
      if (step === 0) {
        ensureTitle({
          session,
          history: msgs,
          message: userMsg,
          providerID: model.providerID,
          modelID: model.info.id,
        })
      }

      await processor.next()
      // Guarantees the assistant message is closed however this iteration exits.
      await using _ = defer(async () => {
        await processor.end()
      })

      const stream = streamText({
        onError(error) {
          log.error("stream error", {
            error,
          })
        },
        // Retry a mis-cased tool name in lower case; otherwise route the call to
        // the "invalid" tool together with the original name and error text.
        async experimental_repairToolCall(failed) {
          const { toolCall } = failed
          const lowered = toolCall.toolName.toLowerCase()
          if (lowered !== toolCall.toolName && tools[lowered]) {
            log.info("repairing tool call", {
              tool: toolCall.toolName,
              repaired: lowered,
            })
            return {
              ...toolCall,
              toolName: lowered,
            }
          }
          return {
            ...toolCall,
            input: JSON.stringify({
              tool: toolCall.toolName,
              error: failed.error.message,
            }),
            toolName: "invalid",
          }
        },
        headers: requestHeaders,
        maxRetries: 10,
        activeTools: Object.keys(tools).filter((name) => name !== "invalid"),
        maxOutputTokens: ProviderTransform.maxOutputTokens(model.providerID, outputLimit, params.options),
        abortSignal: abort.signal,
        providerOptions: {
          [model.npm === "@ai-sdk/openai" ? "openai" : model.providerID]: params.options,
        },
        // One step per request; the surrounding loop decides whether to go again.
        stopWhen: stepCountIs(1),
        temperature: params.temperature,
        topP: params.topP,
        messages: [
          ...system.map(
            (content): ModelMessage => ({
              role: "system",
              content,
            }),
          ),
          ...MessageV2.toModelMessage(msgs.filter(isSendable)),
        ],
        tools: model.info.tool_call === false ? undefined : tools,
        model: wrapLanguageModel({
          model: model.language,
          middleware: [
            {
              async transformParams(args) {
                if (args.type === "stream") {
                  // @ts-expect-error
                  args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
                }
                return args.params
              },
            },
          ],
        }),
      })

      const result = await processor.process(stream)
      await processor.end()

      const waiting = state().queued.get(sessionID) ?? []
      const healthy = !result.blocked && !result.info.error
      if (healthy) {
        // The model asked for tools: run another step with their results.
        if ((await stream.finishReason) === "tool-calls") continue
        // A queued message is newer than this answer: answer it as well.
        if (waiting.some((entry) => entry.messageID > result.info.id)) continue
      }

      for (const entry of waiting) {
        entry.callback(result)
      }
      state().queued.delete(sessionID)
      // Session.prune(input)
      return result
    }
  }
  /**
   * Loads the session's messages (passed through `MessageV2.filterSummarized`).
   *
   * When the most recent assistant message reports token usage that
   * `SessionCompaction.isOverflow` flags for the given model, a compaction is run
   * and its resulting message becomes the entire history.
   */
  async function getMessages(input: { sessionID: string; model: ModelsDev.Model; providerID: string }) {
    const history = await Session.messages(input.sessionID).then(MessageV2.filterSummarized)

    const latest = history.findLast((msg) => msg.info.role === "assistant")
    if (latest?.info.role !== "assistant") return history

    const overflowing = SessionCompaction.isOverflow({
      tokens: latest.info.tokens,
      model: input.model,
    })
    if (!overflowing) return history

    const summary = await SessionCompaction.run({
      sessionID: input.sessionID,
      providerID: input.providerID,
      modelID: input.model.id,
    })
    const compacted: typeof history = [summary]
    return compacted
  }
  /**
   * Picks the model reference for a prompt: the explicitly requested model wins,
   * then the agent's configured model, then `Provider.defaultModel()`.
   */
  async function resolveModel(input: { model: PromptInput["model"]; agent: Agent.Info }) {
    return input.model ?? input.agent.model ?? Provider.defaultModel()
  }
  /**
   * Assembles the system prompt as at most two strings.
   *
   * Sections are collected in this order: provider header, then exactly one of
   * (explicit `system` override, agent prompt, provider default for the model),
   * then environment and custom sections. The first section stays separate; all
   * remaining ones are joined with newlines.
   */
  async function resolveSystemPrompt(input: {
    system?: string
    agent: Agent.Info
    providerID: string
    modelID: string
  }) {
    const sections = SystemPrompt.header(input.providerID)

    if (input.system) {
      sections.push(input.system)
    } else if (input.agent.prompt) {
      sections.push(input.agent.prompt)
    } else {
      sections.push(...SystemPrompt.provider(input.modelID))
    }

    sections.push(...(await SystemPrompt.environment()))
    sections.push(...(await SystemPrompt.custom()))

    // max 2 system prompt messages for caching purposes
    const [head, ...tail] = sections
    return [head, tail.join("\n")]
  }
  /**
   * Builds the AI-SDK tool map offered to the model for one prompt.
   *
   * Enablement is the deep merge of the agent's tool settings, the registry's
   * settings for this provider/model/agent, and the per-request `tools` override
   * (later sources win). Registry tools and MCP tools whose name is disabled by
   * `Wildcard.all` are left out.
   *
   * Every tool execution is bracketed by the "tool.execute.before" and
   * "tool.execute.after" plugin hooks, and every tool reports its `output` field
   * to the model as plain text.
   *
   * @returns tool name → AI-SDK tool
   */
  async function resolveTools(input: {
    agent: Agent.Info
    sessionID: string
    modelID: string
    providerID: string
    tools?: Record<string, boolean>
    processor: Processor
  }) {
    const tools: Record<string, AITool> = {}
    const enabledTools = pipe(
      input.agent.tools,
      mergeDeep(await ToolRegistry.enabled(input.providerID, input.modelID, input.agent)),
      mergeDeep(input.tools ?? {}),
    )

    // Built-in tools from the registry.
    for (const item of await ToolRegistry.tools(input.providerID, input.modelID)) {
      if (Wildcard.all(item.id, enabledTools) === false) continue
      const schema = ProviderTransform.schema(input.providerID, input.modelID, z.toJSONSchema(item.parameters))
      tools[item.id] = tool({
        id: item.id as any,
        description: item.description,
        inputSchema: jsonSchema(schema as any),
        async execute(args, options) {
          await Plugin.trigger(
            "tool.execute.before",
            {
              tool: item.id,
              sessionID: input.sessionID,
              callID: options.toolCallId,
            },
            {
              args,
            },
          )
          const result = await item.execute(args, {
            sessionID: input.sessionID,
            abort: options.abortSignal!,
            messageID: input.processor.message.id,
            callID: options.toolCallId,
            agent: input.agent.name,
            // Lets a running tool publish an intermediate title/metadata on its part.
            metadata: async (val) => {
              const match = input.processor.partFromToolCall(options.toolCallId)
              if (match && match.state.status === "running") {
                await Session.updatePart({
                  ...match,
                  state: {
                    title: val.title,
                    metadata: val.metadata,
                    status: "running",
                    input: args,
                    // Keep the start time recorded when the call began running;
                    // a metadata update must not restart the tool's clock.
                    time: {
                      start: match.state.time.start,
                    },
                  },
                })
              }
            },
          })
          await Plugin.trigger(
            "tool.execute.after",
            {
              tool: item.id,
              sessionID: input.sessionID,
              callID: options.toolCallId,
            },
            result,
          )
          return result
        },
        toModelOutput(result) {
          return {
            type: "text",
            value: result.output,
          }
        },
      })
    }

    // MCP tools: wrap the server-provided execute with the same plugin hooks.
    // NOTE(review): the wrapper is installed by mutating the object returned from
    // MCP.tools(); if that call ever returns cached objects the wrapper would be
    // stacked on repeated calls — confirm it returns fresh objects.
    for (const [key, item] of Object.entries(await MCP.tools())) {
      if (Wildcard.all(key, enabledTools) === false) continue
      const execute = item.execute
      if (!execute) continue
      item.execute = async (args, opts) => {
        await Plugin.trigger(
          "tool.execute.before",
          {
            tool: key,
            sessionID: input.sessionID,
            callID: opts.toolCallId,
          },
          {
            args,
          },
        )
        const result = await execute(args, opts)
        // Only the text items of the MCP result are forwarded, separated by blank lines.
        const output = result.content
          .filter((x: any) => x.type === "text")
          .map((x: any) => x.text)
          .join("\n\n")
        await Plugin.trigger(
          "tool.execute.after",
          {
            tool: key,
            sessionID: input.sessionID,
            callID: opts.toolCallId,
          },
          result,
        )
        return {
          output,
        }
      }
      item.toModelOutput = (result) => {
        return {
          type: "text",
          value: result.output,
        }
      }
      tools[key] = item
    }
    return tools
  }
  /**
   * Extracts the textual payload of a `data:` URL.
   *
   * Only the part after the first comma is decoded: as base64 when the header
   * ends in `;base64`, otherwise as percent-encoded text (returned verbatim if it
   * is not valid percent-encoding). A URL without a comma yields an empty string.
   */
  function decodeDataUrlText(url: string): string {
    const comma = url.indexOf(",")
    if (comma === -1) return ""
    const header = url.slice(0, comma)
    const payload = url.slice(comma + 1)
    if (/;base64$/i.test(header)) {
      // "base64url" decoding also accepts the regular base64 alphabet.
      return Buffer.from(payload, "base64url").toString()
    }
    try {
      return decodeURIComponent(payload)
    } catch {
      return payload
    }
  }

  /**
   * Creates and stores the user message for a prompt.
   *
   * Input parts are expanded into stored parts:
   * - `data:` file with mime text/plain → synthetic "Read tool" note, the decoded
   *   text, and the original file part.
   * - `file:` text/plain → the Read tool is executed (optionally limited to the
   *   `start`/`end` query range) and its output is attached the same way.
   * - `file:` application/x-directory → same, using the list tool.
   * - any other `file:` → synthetic note plus the file inlined as a base64 data URL.
   * - agent part → the part plus a synthetic instruction to call the task tool.
   * - everything else is stored as given.
   *
   * The "chat.message" plugin hook runs before the message and its parts are
   * written to the session.
   *
   * @returns the stored message info together with its parts
   */
  async function createUserMessage(input: PromptInput) {
    const info: MessageV2.Info = {
      id: input.messageID ?? Identifier.ascending("message"),
      role: "user",
      sessionID: input.sessionID,
      time: {
        created: Date.now(),
      },
    }

    // Same default agent that prompt() applies when none is given.
    const agentName = input.agent ?? "build"

    // Synthetic text part belonging to this message.
    const syntheticText = (text: string): MessageV2.TextPart => ({
      id: Identifier.ascending("part"),
      messageID: info.id,
      sessionID: input.sessionID,
      type: "text",
      synthetic: true,
      text,
    })

    // Context for tools that are run here on the user's behalf (no cancellation, no metadata sink).
    const toolContext = () => ({
      sessionID: input.sessionID,
      abort: new AbortController().signal,
      agent: agentName,
      messageID: info.id,
      extra: { bypassCwdCheck: true },
      metadata: async () => {},
    })

    const expanded = await Promise.all(
      input.parts.map(async (part): Promise<MessageV2.Part[]> => {
        if (part.type === "file") {
          const url = new URL(part.url)
          switch (url.protocol) {
            case "data:": {
              if (part.mime !== "text/plain") break
              return [
                syntheticText(
                  `Called the Read tool with the following input: ${JSON.stringify({ filePath: part.filename })}`,
                ),
                // Decode only the payload; the "data:...;base64," header is not content.
                syntheticText(decodeDataUrlText(part.url)),
                {
                  ...part,
                  id: part.id ?? Identifier.ascending("part"),
                  messageID: info.id,
                  sessionID: input.sessionID,
                },
              ]
            }
            case "file:": {
              // have to normalize, symbol search returns absolute paths
              // Decode the pathname since URL constructor doesn't automatically decode it
              const filePath = decodeURIComponent(url.pathname)

              if (part.mime === "text/plain") {
                let offset: number | undefined = undefined
                let limit: number | undefined = undefined
                const rangeStart = url.searchParams.get("start")
                const rangeEnd = url.searchParams.get("end")
                if (rangeStart != null) {
                  // The symbol lookup is given the URL without its query string.
                  const uri = part.url.split("?")[0]
                  let start = parseInt(rangeStart)
                  let end = rangeEnd ? parseInt(rangeEnd) : undefined
                  // some LSP servers (eg, gopls) don't give full range in
                  // workspace/symbol searches, so we'll try to find the
                  // symbol in the document to get the full range
                  if (start === end) {
                    const symbols = await LSP.documentSymbol(uri)
                    for (const symbol of symbols) {
                      let symbolRange: LSP.Range | undefined
                      if ("range" in symbol) {
                        symbolRange = symbol.range
                      } else if ("location" in symbol) {
                        symbolRange = symbol.location.range
                      }
                      // Direct comparison so that a symbol on line 0 can match too.
                      if (symbolRange?.start?.line === start) {
                        end = symbolRange?.end?.line ?? start
                        break
                      }
                    }
                  }
                  offset = Math.max(start - 1, 0)
                  if (end) {
                    limit = end - offset
                  }
                }
                const args = { filePath, offset, limit }
                const reader = await ReadTool.init()
                const result = await reader.execute(args, toolContext())
                return [
                  syntheticText(`Called the Read tool with the following input: ${JSON.stringify(args)}`),
                  syntheticText(result.output),
                  {
                    ...part,
                    id: part.id ?? Identifier.ascending("part"),
                    messageID: info.id,
                    sessionID: input.sessionID,
                  },
                ]
              }

              if (part.mime === "application/x-directory") {
                const args = { path: filePath }
                const lister = await ListTool.init()
                const result = await lister.execute(args, toolContext())
                return [
                  syntheticText(`Called the list tool with the following input: ${JSON.stringify(args)}`),
                  syntheticText(result.output),
                  {
                    ...part,
                    id: part.id ?? Identifier.ascending("part"),
                    messageID: info.id,
                    sessionID: input.sessionID,
                  },
                ]
              }

              // Any other mime type: inline the file bytes as a base64 data URL.
              const file = Bun.file(filePath)
              FileTime.read(input.sessionID, filePath)
              return [
                // JSON.stringify escapes quotes/backslashes in the path.
                syntheticText(`Called the Read tool with the following input: ${JSON.stringify({ filePath })}`),
                {
                  id: part.id ?? Identifier.ascending("part"),
                  messageID: info.id,
                  sessionID: input.sessionID,
                  type: "file",
                  url: `data:${part.mime};base64,` + Buffer.from(await file.bytes()).toString("base64"),
                  mime: part.mime,
                  filename: part.filename!,
                  source: part.source,
                },
              ]
            }
          }
          // Other protocols, and data: URLs that are not text/plain, are stored unchanged below.
        }

        if (part.type === "agent") {
          return [
            {
              id: Identifier.ascending("part"),
              ...part,
              messageID: info.id,
              sessionID: input.sessionID,
            },
            syntheticText(
              "Use the above message and context to generate a prompt and call the task tool with subagent: " +
                part.name,
            ),
          ]
        }

        return [
          {
            id: Identifier.ascending("part"),
            ...part,
            messageID: info.id,
            sessionID: input.sessionID,
          },
        ]
      }),
    )
    const parts = expanded.flat()

    await Plugin.trigger(
      "chat.message",
      {},
      {
        message: info,
        parts,
      },
    )
    await Session.updateMessage(info)
    for (const part of parts) {
      await Session.updatePart(part)
    }
    return {
      info,
      parts,
    }
  }
  /**
   * Appends synthetic reminder text to the most recent user message (in place)
   * and returns the same message list.
   *
   * - agent "plan": the plan prompt is appended.
   * - agent "build" after any assistant message produced in "plan" mode: the
   *   build-switch prompt is appended.
   *
   * Without a user message the list is returned untouched.
   */
  function insertReminders(input: { messages: MessageV2.WithParts[]; agent: Agent.Info }) {
    const { messages, agent } = input
    const target = messages.findLast((msg) => msg.info.role === "user")
    if (!target) return messages

    const remind = (text: string) => {
      target.parts.push({
        id: Identifier.ascending("part"),
        messageID: target.info.id,
        sessionID: target.info.sessionID,
        type: "text",
        text,
        synthetic: true,
      })
    }

    if (agent.name === "plan") remind(PROMPT_PLAN)

    const plannedBefore = messages.some((msg) => msg.info.role === "assistant" && msg.info.mode === "plan")
    if (plannedBefore && agent.name === "build") remind(BUILD_SWITCH)

    return messages
  }
  731. export type Processor = Awaited<ReturnType<typeof createProcessor>>
  732. async function createProcessor(input: {
  733. sessionID: string
  734. providerID: string
  735. model: ModelsDev.Model
  736. system: string[]
  737. agent: string
  738. abort: AbortSignal
  739. }) {
  740. const toolcalls: Record<string, MessageV2.ToolPart> = {}
  741. let snapshot: string | undefined
  742. let blocked = false
  743. async function createMessage() {
  744. const msg: MessageV2.Info = {
  745. id: Identifier.ascending("message"),
  746. role: "assistant",
  747. system: input.system,
  748. mode: input.agent,
  749. path: {
  750. cwd: Instance.directory,
  751. root: Instance.worktree,
  752. },
  753. cost: 0,
  754. tokens: {
  755. input: 0,
  756. output: 0,
  757. reasoning: 0,
  758. cache: { read: 0, write: 0 },
  759. },
  760. modelID: input.model.id,
  761. providerID: input.providerID,
  762. time: {
  763. created: Date.now(),
  764. },
  765. sessionID: input.sessionID,
  766. }
  767. await Session.updateMessage(msg)
  768. return msg
  769. }
  770. let assistantMsg: MessageV2.Assistant | undefined
  771. const result = {
  772. async end() {
  773. if (assistantMsg) {
  774. assistantMsg.time.completed = Date.now()
  775. await Session.updateMessage(assistantMsg)
  776. assistantMsg = undefined
  777. }
  778. },
  779. async next() {
  780. if (assistantMsg) {
  781. throw new Error("end previous assistant message first")
  782. }
  783. assistantMsg = await createMessage()
  784. return assistantMsg
  785. },
  786. get message() {
  787. if (!assistantMsg) throw new Error("call next() first before accessing message")
  788. return assistantMsg
  789. },
  790. partFromToolCall(toolCallID: string) {
  791. return toolcalls[toolCallID]
  792. },
  793. async process(stream: StreamTextResult<Record<string, AITool>, never>) {
  794. log.info("process")
  795. if (!assistantMsg) throw new Error("call next() first before processing")
  796. try {
  797. let currentText: MessageV2.TextPart | undefined
  798. let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
  799. for await (const value of stream.fullStream) {
  800. input.abort.throwIfAborted()
  801. log.info("part", {
  802. type: value.type,
  803. })
  804. switch (value.type) {
  805. case "start":
  806. break
  807. case "reasoning-start":
  808. if (value.id in reasoningMap) {
  809. continue
  810. }
  811. reasoningMap[value.id] = {
  812. id: Identifier.ascending("part"),
  813. messageID: assistantMsg.id,
  814. sessionID: assistantMsg.sessionID,
  815. type: "reasoning",
  816. text: "",
  817. time: {
  818. start: Date.now(),
  819. },
  820. }
  821. break
  822. case "reasoning-delta":
  823. if (value.id in reasoningMap) {
  824. const part = reasoningMap[value.id]
  825. part.text += value.text
  826. if (value.providerMetadata) part.metadata = value.providerMetadata
  827. if (part.text) await Session.updatePart(part)
  828. }
  829. break
  830. case "reasoning-end":
  831. if (value.id in reasoningMap) {
  832. const part = reasoningMap[value.id]
  833. part.text = part.text.trimEnd()
  834. part.time = {
  835. ...part.time,
  836. end: Date.now(),
  837. }
  838. await Session.updatePart(part)
  839. delete reasoningMap[value.id]
  840. }
  841. break
  842. case "tool-input-start":
  843. const part = await Session.updatePart({
  844. id: toolcalls[value.id]?.id ?? Identifier.ascending("part"),
  845. messageID: assistantMsg.id,
  846. sessionID: assistantMsg.sessionID,
  847. type: "tool",
  848. tool: value.toolName,
  849. callID: value.id,
  850. state: {
  851. status: "pending",
  852. },
  853. })
  854. toolcalls[value.id] = part as MessageV2.ToolPart
  855. break
  856. case "tool-input-delta":
  857. break
  858. case "tool-input-end":
  859. break
  860. case "tool-call": {
  861. const match = toolcalls[value.toolCallId]
  862. if (match) {
  863. const part = await Session.updatePart({
  864. ...match,
  865. tool: value.toolName,
  866. state: {
  867. status: "running",
  868. input: value.input,
  869. time: {
  870. start: Date.now(),
  871. },
  872. },
  873. })
  874. toolcalls[value.toolCallId] = part as MessageV2.ToolPart
  875. }
  876. break
  877. }
  878. case "tool-result": {
  879. const match = toolcalls[value.toolCallId]
  880. if (match && match.state.status === "running") {
  881. await Session.updatePart({
  882. ...match,
  883. state: {
  884. status: "completed",
  885. input: value.input,
  886. output: value.output.output,
  887. metadata: value.output.metadata,
  888. title: value.output.title,
  889. time: {
  890. start: match.state.time.start,
  891. end: Date.now(),
  892. },
  893. },
  894. })
  895. delete toolcalls[value.toolCallId]
  896. }
  897. break
  898. }
  899. case "tool-error": {
  900. const match = toolcalls[value.toolCallId]
  901. if (match && match.state.status === "running") {
  902. await Session.updatePart({
  903. ...match,
  904. state: {
  905. status: "error",
  906. input: value.input,
  907. error: (value.error as any).toString(),
  908. metadata: value.error instanceof Permission.RejectedError ? value.error.metadata : undefined,
  909. time: {
  910. start: match.state.time.start,
  911. end: Date.now(),
  912. },
  913. },
  914. })
  915. if (value.error instanceof Permission.RejectedError) {
  916. blocked = true
  917. }
  918. delete toolcalls[value.toolCallId]
  919. }
  920. break
  921. }
  922. case "error":
  923. throw value.error
  924. case "start-step":
  925. await Session.updatePart({
  926. id: Identifier.ascending("part"),
  927. messageID: assistantMsg.id,
  928. sessionID: assistantMsg.sessionID,
  929. type: "step-start",
  930. })
  931. snapshot = await Snapshot.track()
  932. break
  933. case "finish-step":
  934. const usage = Session.getUsage(input.model, value.usage, value.providerMetadata)
  935. assistantMsg.cost += usage.cost
  936. assistantMsg.tokens = usage.tokens
  937. await Session.updatePart({
  938. id: Identifier.ascending("part"),
  939. messageID: assistantMsg.id,
  940. sessionID: assistantMsg.sessionID,
  941. type: "step-finish",
  942. tokens: usage.tokens,
  943. cost: usage.cost,
  944. })
  945. await Session.updateMessage(assistantMsg)
  946. if (snapshot) {
  947. const patch = await Snapshot.patch(snapshot)
  948. if (patch.files.length) {
  949. await Session.updatePart({
  950. id: Identifier.ascending("part"),
  951. messageID: assistantMsg.id,
  952. sessionID: assistantMsg.sessionID,
  953. type: "patch",
  954. hash: patch.hash,
  955. files: patch.files,
  956. })
  957. }
  958. snapshot = undefined
  959. }
  960. break
  961. case "text-start":
  962. currentText = {
  963. id: Identifier.ascending("part"),
  964. messageID: assistantMsg.id,
  965. sessionID: assistantMsg.sessionID,
  966. type: "text",
  967. text: "",
  968. time: {
  969. start: Date.now(),
  970. },
  971. }
  972. break
  973. case "text-delta":
  974. if (currentText) {
  975. currentText.text += value.text
  976. if (currentText.text) await Session.updatePart(currentText)
  977. }
  978. break
  979. case "text-end":
  980. if (currentText) {
  981. currentText.text = currentText.text.trimEnd()
  982. currentText.time = {
  983. start: Date.now(),
  984. end: Date.now(),
  985. }
  986. await Session.updatePart(currentText)
  987. }
  988. currentText = undefined
  989. break
  990. case "finish":
  991. assistantMsg.time.completed = Date.now()
  992. await Session.updateMessage(assistantMsg)
  993. break
  994. default:
  995. log.info("unhandled", {
  996. ...value,
  997. })
  998. continue
  999. }
  1000. }
  1001. } catch (e) {
  1002. log.error("process", {
  1003. error: e,
  1004. })
  1005. switch (true) {
  1006. case e instanceof DOMException && e.name === "AbortError":
  1007. assistantMsg.error = new MessageV2.AbortedError(
  1008. { message: e.message },
  1009. {
  1010. cause: e,
  1011. },
  1012. ).toObject()
  1013. break
  1014. case MessageV2.OutputLengthError.isInstance(e):
  1015. assistantMsg.error = e
  1016. break
  1017. case LoadAPIKeyError.isInstance(e):
  1018. assistantMsg.error = new MessageV2.AuthError(
  1019. {
  1020. providerID: input.providerID,
  1021. message: e.message,
  1022. },
  1023. { cause: e },
  1024. ).toObject()
  1025. break
  1026. case e instanceof Error:
  1027. assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
  1028. break
  1029. default:
  1030. assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
  1031. }
  1032. Bus.publish(Session.Event.Error, {
  1033. sessionID: assistantMsg.sessionID,
  1034. error: assistantMsg.error,
  1035. })
  1036. }
  1037. const p = await Session.getParts(assistantMsg.id)
  1038. for (const part of p) {
  1039. if (part.type === "tool" && part.state.status !== "completed" && part.state.status !== "error") {
  1040. Session.updatePart({
  1041. ...part,
  1042. state: {
  1043. status: "error",
  1044. error: "Tool execution aborted",
  1045. time: {
  1046. start: Date.now(),
  1047. end: Date.now(),
  1048. },
  1049. input: {},
  1050. },
  1051. })
  1052. }
  1053. }
  1054. assistantMsg.time.completed = Date.now()
  1055. await Session.updateMessage(assistantMsg)
  1056. return { info: assistantMsg, parts: p, blocked }
  1057. },
  1058. }
  1059. return result
  1060. }
  /**
   * Reports whether a session currently has an entry in the pending map.
   *
   * @param sessionID Session to look up.
   * @returns `true` when `state().pending` contains the session id.
   */
  function isBusy(sessionID: string) {
    const { pending } = state()
    return pending.has(sessionID)
  }
  /**
   * Cancels the work registered for a session, if any.
   *
   * Looks up the session's `AbortController` in `state().pending`, aborts it
   * and removes the entry.
   *
   * @param sessionID Session whose pending work should be cancelled.
   * @returns `true` when a controller was found and aborted, `false` otherwise.
   */
  export function abort(sessionID: string) {
    const controller = state().pending.get(sessionID)
    if (controller) {
      log.info("aborting", {
        sessionID,
      })
      controller.abort()
      state().pending.delete(sessionID)
      return true
    }
    return false
  }
  /**
   * Marks a session as busy and returns a disposable handle for the work.
   *
   * A fresh `AbortController` is registered in `state().pending` under the
   * session id so that `abort(sessionID)` can cancel the work. Disposing the
   * handle (normally via a `using` declaration) releases the session and
   * publishes `Event.Idle`, unless the session has a `parentID`.
   *
   * @param sessionID Session to lock.
   * @returns An object exposing the controller's `signal` and a
   *   `Symbol.dispose` method that releases the lock.
   * @throws Error when the session already has a pending controller.
   */
  function lock(sessionID: string) {
    log.info("locking", { sessionID })
    if (state().pending.has(sessionID)) throw new Error(`Session ${sessionID} is busy`)
    const controller = new AbortController()
    state().pending.set(sessionID, controller)
    return {
      signal: controller.signal,
      async [Symbol.dispose]() {
        log.info("unlocking", { sessionID })
        // Only remove the entry this lock created: once `abort()` has cleared
        // it, a newer lock may already be registered under the same session id
        // and must not be released by the older one.
        const pending = state().pending
        if (pending.get(sessionID) === controller) pending.delete(sessionID)
        // `using` calls this method synchronously and discards the returned
        // promise, so failures have to be handled here rather than surfacing
        // as an unhandled rejection.
        try {
          const session = await Session.get(sessionID)
          if (session.parentID) return
          Bus.publish(Event.Idle, {
            sessionID,
          })
        } catch (error) {
          log.error("unlock", { sessionID, error })
        }
      },
    }
  }
  /**
   * Schema for the argument of `shell`.
   *
   * - `sessionID`: identifier validated by `Identifier.schema("session")`.
   * - `agent`: stored as `mode` on the assistant message that `shell` creates.
   * - `command`: the command line handed to the shell.
   */
  export const ShellInput = z.object({
    sessionID: Identifier.schema("session"),
    agent: z.string(),
    command: z.string(),
  })

  /** Static type inferred from the `ShellInput` schema. */
  export type ShellInput = z.infer<typeof ShellInput>
  /**
   * Runs a user-supplied command in the user's shell and records it in the
   * session.
   *
   * The run is stored as a synthetic user message followed by an assistant
   * message that carries a single `bash` tool part. While the process runs, its
   * combined stdout/stderr is written to the part's metadata; when it ends the
   * part is moved to `completed` (or to `error` when the shell could not be
   * started at all).
   *
   * The session is locked for the duration of the call, and aborting the lock
   * kills the spawned process group.
   *
   * @param input Session id, agent name and the command to run.
   * @returns The assistant message (`info`) and its tool part (`parts`).
   * @throws Error when the session is already locked (raised by `lock`).
   */
  export async function shell(input: ShellInput) {
    using abort = lock(input.sessionID)
    const session = await Session.get(input.sessionID)
    if (session.revert) {
      SessionRevert.cleanup(session)
    }
    const userMsg: MessageV2.User = {
      id: Identifier.ascending("message"),
      sessionID: input.sessionID,
      time: {
        created: Date.now(),
      },
      role: "user",
    }
    await Session.updateMessage(userMsg)
    const userPart: MessageV2.Part = {
      type: "text",
      id: Identifier.ascending("part"),
      messageID: userMsg.id,
      sessionID: input.sessionID,
      text: "The following tool was executed by the user",
      synthetic: true,
    }
    await Session.updatePart(userPart)
    const msg: MessageV2.Assistant = {
      id: Identifier.ascending("message"),
      sessionID: input.sessionID,
      system: [],
      mode: input.agent,
      cost: 0,
      path: {
        cwd: Instance.directory,
        root: Instance.worktree,
      },
      time: {
        created: Date.now(),
      },
      role: "assistant",
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: "",
      providerID: "",
    }
    await Session.updateMessage(msg)
    const part: MessageV2.Part = {
      type: "tool",
      id: Identifier.ascending("part"),
      messageID: msg.id,
      sessionID: input.sessionID,
      tool: "bash",
      callID: ulid(),
      state: {
        status: "running",
        time: {
          start: Date.now(),
        },
        input: {
          command: input.command,
        },
      },
    }
    await Session.updatePart(part)

    const shellPath = process.env["SHELL"] ?? "bash"
    const shellName = path.basename(shellPath)
    const isFishOrNu = shellName === "fish" || shellName === "nu"
    // Single-quote the command for `eval` so that quotes, `$` and backticks in
    // it reach the shell untouched. Interpolating it between double quotes
    // mangles any command that itself contains a double quote.
    const quoted = `'${input.command.replace(/'/g, "'\\''")}'`
    // fish and nu evaluate the `-c` argument themselves, so they get the
    // command as-is; other shells first source the usual rc files.
    const script = isFishOrNu
      ? input.command
      : [
          "[[ -f ~/.zshenv ]] && source ~/.zshenv >/dev/null 2>&1 || true",
          '[[ -f "${ZDOTDIR:-$HOME}/.zshrc" ]] && source "${ZDOTDIR:-$HOME}/.zshrc" >/dev/null 2>&1 || true',
          "[[ -f ~/.bashrc ]] && source ~/.bashrc >/dev/null 2>&1 || true",
          `eval ${quoted}`,
        ].join("\n")
    const args = isFishOrNu ? ["-c", script] : ["-c", "-l", script]
    const proc = spawn(shellPath, args, {
      cwd: Instance.directory,
      signal: abort.signal,
      detached: true,
      stdio: ["ignore", "pipe", "pipe"],
      env: {
        ...process.env,
        TERM: "dumb",
      },
    })

    // The child is detached, so a negative pid signals its whole process group.
    const killGroup = () => {
      if (!proc.pid) return
      try {
        process.kill(-proc.pid)
      } catch {
        // The group has already exited; there is nothing left to kill.
      }
    }
    abort.signal.addEventListener("abort", killGroup, { once: true })

    let output = ""
    // stdout and stderr are interleaved into one buffer and mirrored into the
    // part's metadata while the command is still running.
    const collect = (chunk: { toString(): string }) => {
      output += chunk.toString()
      if (part.state.status === "running") {
        part.state.metadata = {
          output: output,
          description: "",
        }
        Session.updatePart(part)
      }
    }
    proc.stdout?.on("data", collect)
    proc.stderr?.on("data", collect)

    // Resolves with the spawn failure when the shell never started, otherwise
    // with `undefined` once all stdio streams have closed.
    const spawnError = await new Promise<Error | undefined>((resolve) => {
      proc.on("error", (error) => {
        log.error("shell", { error })
        // Without a pid the process never started, so "close" may not follow.
        // After a successful start (e.g. an abort) "close" still arrives and
        // lets the remaining output drain first.
        if (proc.pid === undefined) resolve(error)
      })
      proc.on("close", () => {
        resolve(undefined)
      })
    })
    abort.signal.removeEventListener("abort", killGroup)

    msg.time.completed = Date.now()
    await Session.updateMessage(msg)
    if (part.state.status === "running") {
      const time = {
        ...part.state.time,
        end: Date.now(),
      }
      part.state = spawnError
        ? {
            status: "error",
            time,
            input: part.state.input,
            error: spawnError.toString(),
          }
        : {
            status: "completed",
            time,
            input: part.state.input,
            title: "",
            metadata: {
              output,
              description: "",
            },
            output,
          }
      await Session.updatePart(part)
    }
    return { info: msg, parts: [part] }
  }
  /**
   * Schema for the argument of `command`.
   *
   * - `messageID`: optional message identifier, forwarded to `prompt`.
   * - `sessionID`: session the command runs in.
   * - `agent`: optional agent name, used when the command does not name one.
   * - `model`: optional model string, parsed with `Provider.parseModel`.
   * - `arguments`: text substituted for `$ARGUMENTS` in the command template.
   * - `command`: name passed to `Command.get`.
   */
  export const CommandInput = z.object({
    messageID: Identifier.schema("message").optional(),
    sessionID: Identifier.schema("session"),
    agent: z.string().optional(),
    model: z.string().optional(),
    arguments: z.string(),
    command: z.string(),
  })

  /** Static type inferred from the `CommandInput` schema. */
  export type CommandInput = z.infer<typeof CommandInput>
  /**
   * Matches inline shell snippets written as an exclamation mark followed by a
   * backtick-quoted, non-empty command. Capture group 1 holds the command text
   * between the backticks.
   */
  const bashRegex = /!`([^`]+)`/g
  /**
   * Finds `@` references (file paths or agent names) inside free text.
   *
   * Capture group 1 is the text after the `@`: an optional leading dot, then
   * characters up to the next whitespace, comma or backtick. A dot is only kept
   * when more reference characters follow it, so a sentence-ending period is
   * not captured. An `@` directly preceded by a word character or a backtick is
   * ignored, which skips email addresses and backtick-quoted references.
   */
  export const fileRegex = /(?<![\w`])@(\.?[^\s`,.]*(?:\.[^\s`,.]+)*)/g
  /**
   * Expands a named command template and runs it in a session.
   *
   * Steps:
   * 1. Every `$ARGUMENTS` placeholder in the template is replaced with
   *    `input.arguments`.
   * 2. Each inline shell snippet matched by `bashRegex` is executed and replaced
   *    by its output (or by an error note when execution throws).
   * 3. Each `@` reference matched by `fileRegex` becomes a file part when the
   *    path exists, or an agent part when an agent of that name exists.
   * 4. The model is taken from the command, then the command's agent, then
   *    `input.model`, then `Provider.defaultModel()`.
   * 5. When the agent is a subagent or the command is flagged as a subtask, the
   *    template is run through the task tool under a session lock; otherwise
   *    the parts are forwarded to `prompt`.
   *
   * @param input Command name, arguments and optional agent/model overrides.
   * @returns The assistant message and tool part of the subtask, or whatever
   *   `prompt` returns.
   * @throws Error when the resolved agent does not exist, when the session is
   *   already locked (subtask path), or when the task tool fails.
   */
  export async function command(input: CommandInput) {
    log.info("command", input)
    const definition = await Command.get(input.command)
    const agentName = definition.agent ?? input.agent ?? "build"

    // split/join substitutes every placeholder and inserts the arguments
    // literally. `replace` with a string pattern only touches the first
    // occurrence and expands `$&`, `$1`, `$$`, ... found in the arguments.
    let template = definition.template.split("$ARGUMENTS").join(input.arguments)

    const bash = Array.from(template.matchAll(bashRegex))
    if (bash.length > 0) {
      const results = await Promise.all(
        bash.map(async ([, cmd]) => {
          try {
            return await $`${{ raw: cmd }}`.nothrow().text()
          } catch (error) {
            return `Error executing command: ${error instanceof Error ? error.message : String(error)}`
          }
        }),
      )
      let index = 0
      template = template.replace(bashRegex, () => results[index++])
    }

    const parts = [
      {
        type: "text",
        text: template,
      },
    ] as PromptInput["parts"]

    // Resolve all references concurrently, but append them in the order they
    // appear in the template so the result does not depend on stat timing.
    const mentions = Array.from(template.matchAll(fileRegex))
    const attachments = await Promise.all(
      mentions.map(async (match): Promise<PromptInput["parts"][number] | undefined> => {
        const name = match[1]
        // A bare "@" captures an empty name, which would otherwise resolve to
        // the worktree directory itself.
        if (!name) return undefined
        const filepath = name.startsWith("~/")
          ? path.join(os.homedir(), name.slice(2))
          : path.resolve(Instance.worktree, name)
        const stats = await fs.stat(filepath).catch(() => undefined)
        if (!stats) {
          // Not a path on disk: treat the reference as an agent name.
          const mentioned = await Agent.get(name)
          if (!mentioned) return undefined
          return {
            type: "agent",
            name: mentioned.name,
          }
        }
        return {
          type: "file",
          url: `file://${filepath}`,
          filename: name,
          mime: stats.isDirectory() ? "application/x-directory" : "text/plain",
        }
      }),
    )
    for (const attachment of attachments) {
      if (attachment) parts.push(attachment)
    }

    const model = await (async () => {
      if (definition.model) {
        return Provider.parseModel(definition.model)
      }
      if (definition.agent) {
        const cmdAgent = await Agent.get(definition.agent)
        // The agent named by the command may not exist; fall through then.
        if (cmdAgent?.model) {
          return cmdAgent.model
        }
      }
      if (input.model) {
        return Provider.parseModel(input.model)
      }
      return await Provider.defaultModel()
    })()

    const agent = await Agent.get(agentName)
    if (!agent) throw new Error(`Agent not found: ${agentName}`)

    if (agent.mode === "subagent" || definition.subtask) {
      using abort = lock(input.sessionID)
      const userMsg: MessageV2.User = {
        id: Identifier.ascending("message"),
        sessionID: input.sessionID,
        time: {
          created: Date.now(),
        },
        role: "user",
      }
      await Session.updateMessage(userMsg)
      const userPart: MessageV2.Part = {
        type: "text",
        id: Identifier.ascending("part"),
        messageID: userMsg.id,
        sessionID: input.sessionID,
        text: "The following tool was executed by the user",
        synthetic: true,
      }
      await Session.updatePart(userPart)
      const assistantMsg: MessageV2.Assistant = {
        id: Identifier.ascending("message"),
        sessionID: input.sessionID,
        system: [],
        mode: agentName,
        cost: 0,
        path: {
          cwd: Instance.directory,
          root: Instance.worktree,
        },
        time: {
          created: Date.now(),
        },
        role: "assistant",
        tokens: {
          input: 0,
          output: 0,
          reasoning: 0,
          cache: { read: 0, write: 0 },
        },
        modelID: model.modelID,
        providerID: model.providerID,
      }
      await Session.updateMessage(assistantMsg)
      const args = {
        description: "Consulting " + agent.name,
        subagent_type: agent.name,
        prompt: template,
      }
      const toolPart: MessageV2.ToolPart = {
        type: "tool",
        id: Identifier.ascending("part"),
        messageID: assistantMsg.id,
        sessionID: input.sessionID,
        tool: "task",
        callID: ulid(),
        state: {
          status: "running",
          time: {
            start: Date.now(),
          },
          input: {
            description: args.description,
            subagent_type: args.subagent_type,
            // truncate prompt to preserve context
            prompt: args.prompt.length > 100 ? args.prompt.substring(0, 97) + "..." : args.prompt,
          },
        },
      }
      await Session.updatePart(toolPart)
      const result = await TaskTool.init()
        .then((t) =>
          t.execute(args, {
            sessionID: input.sessionID,
            abort: abort.signal,
            agent: agent.name,
            messageID: assistantMsg.id,
            extra: {},
            metadata: async (metadata) => {
              if (toolPart.state.status === "running") {
                toolPart.state.metadata = metadata.metadata
                toolPart.state.title = metadata.title
                await Session.updatePart(toolPart)
              }
            },
          }),
        )
        .catch(async (error: unknown) => {
          // Do not leave the message open and the tool part "running" when the
          // task fails; record the failure, then let the error propagate.
          assistantMsg.time.completed = Date.now()
          await Session.updateMessage(assistantMsg)
          if (toolPart.state.status === "running") {
            toolPart.state = {
              status: "error",
              time: {
                start: toolPart.state.time.start,
                end: Date.now(),
              },
              input: toolPart.state.input,
              error: error instanceof Error ? error.toString() : String(error),
            }
            await Session.updatePart(toolPart)
          }
          throw error
        })
      assistantMsg.time.completed = Date.now()
      await Session.updateMessage(assistantMsg)
      if (toolPart.state.status === "running") {
        toolPart.state = {
          status: "completed",
          time: {
            ...toolPart.state.time,
            end: Date.now(),
          },
          input: toolPart.state.input,
          title: "",
          metadata: result.metadata,
          output: result.output,
        }
        await Session.updatePart(toolPart)
      }
      return { info: assistantMsg, parts: [toolPart] }
    }

    return prompt({
      sessionID: input.sessionID,
      messageID: input.messageID,
      model,
      agent: agentName,
      parts,
    })
  }
  /**
   * Generates a title for a session from its first real user message.
   *
   * Nothing happens for sessions that have a `parentID`, or when the history
   * does not contain exactly one user message with at least one non-synthetic
   * part. Otherwise a title is requested from the provider's small model
   * (falling back to the given model) and written to the session.
   *
   * The generation request is started but not awaited: the function returns
   * once the request has been issued, and failures are only logged.
   *
   * @param input Session, the triggering message, the message history and the
   *   provider/model to fall back to.
   */
  async function ensureTitle(input: {
    session: Session.Info
    message: MessageV2.WithParts
    history: MessageV2.WithParts[]
    providerID: string
    modelID: string
  }) {
    if (input.session.parentID) return

    // A user message counts as "real" when at least one part is not synthetic.
    const realUserMessages = input.history.filter(
      (m) => m.info.role === "user" && m.parts.some((p) => !("synthetic" in p && p.synthetic)),
    )
    if (realUserMessages.length !== 1) return

    const preferred = await Provider.getSmallModel(input.providerID)
    const small = preferred ?? (await Provider.getModel(input.providerID, input.modelID))

    const options = {
      ...ProviderTransform.options(small.providerID, small.modelID, input.session.id),
      ...small.info.options,
    }
    // Keep reasoning effort to a minimum for providers that expose a knob.
    switch (small.providerID) {
      case "openai":
        options["reasoningEffort"] = "minimal"
        break
      case "google":
        options["thinkingConfig"] = {
          thinkingBudget: 0,
        }
        break
    }

    const systemMessages = SystemPrompt.title(small.providerID).map(
      (content): ModelMessage => ({
        role: "system",
        content,
      }),
    )
    const userMessages = MessageV2.toModelMessage([
      {
        info: {
          id: Identifier.ascending("message"),
          role: "user",
          sessionID: input.session.id,
          time: {
            created: Date.now(),
          },
        },
        parts: input.message.parts,
      },
    ])

    // Deliberately not awaited: title generation runs in the background.
    generateText({
      maxOutputTokens: small.info.reasoning ? 1500 : 20,
      providerOptions: {
        [small.providerID]: options,
      },
      messages: [...systemMessages, ...userMessages],
      model: small.language,
    })
      .then((result) => {
        if (!result.text) return
        // Drop any <think>…</think> blocks, then cap the title at 100 characters.
        const cleaned = result.text.replace(/<think>[\s\S]*?<\/think>\s*/g, "")
        const title = cleaned.length > 100 ? cleaned.substring(0, 97) + "..." : cleaned
        return Session.update(input.session.id, (draft) => {
          draft.title = title.trim()
        })
      })
      .catch((error) => {
        log.error("failed to generate title", { error, model: small.info.id })
      })
  }
  1509. }