index.ts 60 KB


  1. import os from "os"
  2. import path from "path"
  3. import fs from "fs/promises"
  4. import { spawn } from "child_process"
  5. import { Decimal } from "decimal.js"
  6. import { z, ZodSchema } from "zod"
  7. import {
  8. generateText,
  9. LoadAPIKeyError,
  10. streamText,
  11. tool,
  12. wrapLanguageModel,
  13. type Tool as AITool,
  14. type LanguageModelUsage,
  15. type ProviderMetadata,
  16. type ModelMessage,
  17. type StreamTextResult,
  18. } from "ai"
  19. import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
  20. import PROMPT_PLAN from "../session/prompt/plan.txt"
  21. import BUILD_SWITCH from "../session/prompt/build-switch.txt"
  22. import { Bus } from "../bus"
  23. import { Config } from "../config/config"
  24. import { Flag } from "../flag/flag"
  25. import { Identifier } from "../id/id"
  26. import { Installation } from "../installation"
  27. import { MCP } from "../mcp"
  28. import { Provider } from "../provider/provider"
  29. import { ProviderTransform } from "../provider/transform"
  30. import type { ModelsDev } from "../provider/models"
  31. import { Share } from "../share/share"
  32. import { Snapshot } from "../snapshot"
  33. import { Storage } from "../storage/storage"
  34. import { Log } from "../util/log"
  35. import { NamedError } from "../util/error"
  36. import { SystemPrompt } from "./system"
  37. import { FileTime } from "../file/time"
  38. import { MessageV2 } from "./message-v2"
  39. import { LSP } from "../lsp"
  40. import { ReadTool } from "../tool/read"
  41. import { mergeDeep, pipe, splitWhen } from "remeda"
  42. import { ToolRegistry } from "../tool/registry"
  43. import { Plugin } from "../plugin"
  44. import { Project } from "../project/project"
  45. import { Instance } from "../project/instance"
  46. import { Agent } from "../agent/agent"
  47. import { Permission } from "../permission"
  48. import { Wildcard } from "../util/wildcard"
  49. import { ulid } from "ulid"
  50. import { defer } from "../util/defer"
  51. import { Command } from "../command"
  52. import { $ } from "bun"
  53. import { ListTool } from "../tool/ls"
  54. import { Token } from "../util/token"
  55. export namespace Session {
  /** Logger used by the functions in this namespace; created with the "session" service tag. */
  const log = Log.create({ service: "session" })

  /** Ceiling applied to a model's advertised output-token limit when a request is built. */
  const OUTPUT_TOKEN_MAX = 32_000
  /** Prefix of the auto-generated title given to top-level sessions. */
  const parentSessionTitlePrefix = "New session - "
  /** Prefix of the auto-generated title given to sessions created with a parent. */
  const childSessionTitlePrefix = "Child session - "
  /** Shape of the `Date.prototype.toISOString()` suffix appended by `createDefaultTitle`. */
  const defaultTitleTimestamp = /^[+-]?\d{4,6}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/

  /**
   * Builds the placeholder title for a freshly created session.
   *
   * @param isChild when true the child-session prefix is used instead of the parent one
   * @returns the prefix followed by the current time in ISO-8601 form
   */
  function createDefaultTitle(isChild = false) {
    const prefix = isChild ? childSessionTitlePrefix : parentSessionTitlePrefix
    return prefix + new Date().toISOString()
  }

  /**
   * Tells whether `title` is still the placeholder produced by `createDefaultTitle()`
   * for a top-level session.
   *
   * The prefix alone is not enough: a user-chosen title such as
   * "New session - refactor storage" must not be mistaken for a placeholder,
   * so the remainder has to be an ISO timestamp as well. Child-session
   * placeholders are not matched, exactly as before.
   */
  function isDefaultTitle(title: string) {
    if (!title.startsWith(parentSessionTitlePrefix)) return false
    return defaultTitleTimestamp.test(title.slice(parentSessionTitlePrefix.length))
  }
  /**
   * Schema of a stored session record; registered in the OpenAPI document
   * under the name `Session`.
   */
  export const Info = z
    .object({
      id: Identifier.schema("session"),
      projectID: z.string(),
      directory: z.string(),
      // Only present on sessions that were created with a parent.
      parentID: Identifier.schema("session").optional(),
      // Written by share() and cleared again by unshare().
      share: z.object({ url: z.string() }).optional(),
      title: z.string(),
      // createNext() stores Installation.VERSION here.
      version: z.string(),
      compaction: z
        .object({
          full: z.string().optional(),
          micro: z.string().optional(),
        })
        .optional(),
      // created / updated are filled from Date.now().
      time: z.object({
        created: z.number(),
        updated: z.number(),
        compacting: z.number().optional(),
      }),
      // Pending revert marker; cleanupRevert() consumes and clears it.
      revert: z
        .object({
          messageID: z.string(),
          partID: z.string().optional(),
          snapshot: z.string().optional(),
          diff: z.string().optional(),
        })
        .optional(),
    })
    .openapi({ ref: "Session" })
  export type Info = z.output<typeof Info>
  /**
   * Schema of the share record kept under `["share", sessionID]`; registered
   * in the OpenAPI document as `SessionShare`. unshare() passes `secret` to
   * `Share.remove`.
   */
  export const ShareInfo = z
    .object({ secret: z.string(), url: z.string() })
    .openapi({ ref: "SessionShare" })
  export type ShareInfo = z.output<typeof ShareInfo>
  /** Bus events published for sessions, keyed by a short name. */
  export const Event = {
    // Carries the full session record after a create or update.
    Updated: Bus.event("session.updated", z.object({ info: Info })),
    // Carries the record of the session that was removed.
    Deleted: Bus.event("session.deleted", z.object({ info: Info })),
    Idle: Bus.event("session.idle", z.object({ sessionID: z.string() })),
    // The error payload reuses the assistant message's `error` field schema.
    Error: Bus.event(
      "session.error",
      z.object({
        sessionID: z.string().optional(),
        error: MessageV2.Assistant.shape.error,
      }),
    ),
    Compacted: Bus.event("session.compacted", z.object({ sessionID: z.string() })),
  }
  /**
   * Per-instance bookkeeping registered through `Instance.state`.
   *
   * - `pending`: abort controllers keyed by session id (see abort()).
   * - `queued`: prompts that arrived while their session was locked, keyed by
   *   session id; prompt() fills and drains these lists.
   *
   * The second callback aborts every controller still registered in `pending`.
   */
  const state = Instance.state(
    () => ({
      pending: new Map<string, AbortController>(),
      queued: new Map<
        string,
        {
          input: ChatInput
          message: MessageV2.User
          parts: MessageV2.Part[]
          processed: boolean
          callback: (input: { info: MessageV2.Assistant; parts: MessageV2.Part[] }) => void
        }[]
      >(),
    }),
    async (current) => {
      for (const controller of current.pending.values()) {
        controller.abort()
      }
    },
  )
  /**
   * Creates a session rooted at the current instance directory.
   *
   * @param parentID id of the parent session, if this is a child session
   * @param title explicit title; a default one is generated when omitted
   * @returns the stored session record
   */
  export async function create(parentID?: string, title?: string) {
    const directory = Instance.directory
    return createNext({ parentID, directory, title })
  }
  /**
   * Builds, stores and announces a new session record.
   *
   * When the session has no parent and auto-sharing is enabled (through the
   * `OPENCODE_AUTO_SHARE` flag or `share: "auto"` in the config), sharing is
   * started in the background; failures there never fail session creation.
   *
   * @param input optional id / title / parentID plus the working directory
   * @returns the session record that was written to storage
   */
  export async function createNext(input: { id?: string; title?: string; parentID?: string; directory: string }) {
    // One clock reading so `created` and `updated` are guaranteed to agree.
    const now = Date.now()
    const result: Info = {
      id: Identifier.descending("session", input.id),
      version: Installation.VERSION,
      projectID: Instance.project.id,
      directory: input.directory,
      parentID: input.parentID,
      title: input.title ?? createDefaultTitle(!!input.parentID),
      time: {
        created: now,
        updated: now,
      },
    }
    log.info("created", result)
    await Storage.write(["session", Instance.project.id, result.id], result)
    const cfg = await Config.get()
    if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.share === "auto")) {
      void share(result.id)
        .then((shared) =>
          // Return the promise so a failing update is covered by the catch
          // below, and copy only the URL: share() may resolve with the full
          // share record, whose `secret` must not end up in the session info.
          update(result.id, (draft) => {
            draft.share = { url: shared.url }
          }),
        )
        .catch(() => {
          // Silently ignore sharing errors during session creation
        })
    }
    Bus.publish(Event.Updated, {
      info: result,
    })
    return result
  }
  /**
   * Reads a session of the current project from storage.
   *
   * @param id session id
   * @returns the stored record, typed as `Info`
   */
  export async function get(id: string) {
    const key = ["session", Instance.project.id, id]
    return (await Storage.read<Info>(key)) as Info
  }
  /**
   * Reads the share record stored for a session.
   *
   * @param id session id
   */
  export async function getShare(id: string) {
    const key = ["share", id]
    return Storage.read<ShareInfo>(key)
  }
  /**
   * Shares a session and pushes its current content to the share backend.
   *
   * If the session already carries share information, that is returned
   * unchanged. Otherwise a share is created, its URL is recorded on the
   * session, the share record is stored, and the session info plus every
   * message and part are synced one by one.
   *
   * @param id session id
   * @throws Error when sharing is disabled in the configuration
   */
  export async function share(id: string) {
    const { share: mode } = await Config.get()
    if (mode === "disabled") {
      throw new Error("Sharing is disabled in configuration")
    }

    const session = await get(id)
    if (session.share) return session.share

    const created = await Share.create(id)
    await update(id, (draft) => {
      draft.share = { url: created.url }
    })
    await Storage.write(["share", id], created)

    await Share.sync(`session/info/${id}`, session)
    const history = await messages(id)
    for (const { info, parts } of history) {
      await Share.sync(`session/message/${id}/${info.id}`, info)
      for (const part of parts) {
        await Share.sync(`session/part/${id}/${info.id}/${part.id}`, part)
      }
    }
    return created
  }
  /**
   * Reverses share(): deletes the stored share record, clears the session's
   * share field and removes the share remotely. Does nothing when no share
   * record exists.
   *
   * @param id session id
   */
  export async function unshare(id: string) {
    const existing = await getShare(id)
    if (!existing) return

    await Storage.remove(["share", id])
    await update(id, (draft) => {
      draft.share = undefined
    })
    await Share.remove(id, existing.secret)
  }
  /**
   * Applies `editor` to the stored session, stamps `time.updated` with the
   * current time and publishes `Event.Updated` with the result.
   *
   * @param id session id
   * @param editor mutates the draft in place
   * @returns the record returned by `Storage.update`
   */
  export async function update(id: string, editor: (session: Info) => void) {
    const key = ["session", Instance.project.id, id]
    const updated = await Storage.update<Info>(key, (draft) => {
      editor(draft)
      draft.time.updated = Date.now()
    })
    Bus.publish(Event.Updated, { info: updated })
    return updated
  }
  /**
   * Loads every message of a session together with its parts.
   *
   * @param sessionID session id
   * @returns the messages ordered by ascending message id
   */
  export async function messages(sessionID: string) {
    const collected: { info: MessageV2.Info; parts: MessageV2.Part[] }[] = []
    const keys = await Storage.list(["message", sessionID])
    for (const key of keys) {
      const info = await Storage.read<MessageV2.Info>(key)
      const parts = await getParts(info.id)
      collected.push({ info, parts })
    }
    collected.sort((a, b) => (a.info.id > b.info.id ? 1 : -1))
    return collected
  }
  /**
   * Loads a single message and its parts.
   *
   * @param sessionID id of the session the message belongs to
   * @param messageID id of the message
   */
  export async function getMessage(sessionID: string, messageID: string) {
    const info = await Storage.read<MessageV2.Info>(["message", sessionID, messageID])
    const parts = await getParts(messageID)
    return { info, parts }
  }
  /**
   * Loads the parts stored for a message.
   *
   * @param messageID id of the message
   * @returns the parts ordered by ascending part id
   */
  export async function getParts(messageID: string) {
    const parts: MessageV2.Part[] = []
    const keys = await Storage.list(["part", messageID])
    for (const key of keys) {
      parts.push(await Storage.read<MessageV2.Part>(key))
    }
    parts.sort((a, b) => (a.id > b.id ? 1 : -1))
    return parts
  }
  /**
   * Async generator over the sessions stored for the current project, in the
   * order `Storage.list` returns their keys.
   */
  export async function* list() {
    const keys = await Storage.list(["session", Instance.project.id])
    for (const key of keys) {
      yield Storage.read<Info>(key)
    }
  }
  /**
   * Collects the sessions of the current project whose `parentID` equals the
   * given id. Every stored session is read to find them.
   *
   * @param parentID id of the parent session
   */
  export async function children(parentID: string) {
    const matches: Session.Info[] = []
    const keys = await Storage.list(["session", Instance.project.id])
    for (const key of keys) {
      const candidate = await Storage.read<Info>(key)
      if (candidate.parentID === parentID) matches.push(candidate)
    }
    return matches
  }
  /**
   * Aborts the pending work registered for a session, if any.
   *
   * @param sessionID session id
   * @returns true when a controller was found and aborted, false otherwise
   */
  export function abort(sessionID: string) {
    const { pending } = state()
    const controller = pending.get(sessionID)
    if (!controller) return false

    log.info("aborting", { sessionID })
    controller.abort()
    pending.delete(sessionID)
    return true
  }
  /**
   * Deletes a session: aborts pending work, removes child sessions
   * recursively (without emitting events for them), unshares, deletes all
   * parts and messages and finally the session record itself.
   *
   * Any error raised along the way is logged and swallowed.
   *
   * @param sessionID session id
   * @param emitEvent publish `Event.Deleted` once the session is gone
   */
  export async function remove(sessionID: string, emitEvent = true) {
    const project = Instance.project
    try {
      abort(sessionID)
      const session = await get(sessionID)

      const descendants = await children(sessionID)
      for (const child of descendants) {
        await remove(child.id, false)
      }

      // A failing unshare must not block the deletion.
      await unshare(sessionID).catch(() => {})

      const messageKeys = await Storage.list(["message", sessionID])
      for (const messageKey of messageKeys) {
        // The last key segment is the message id the parts are stored under.
        const messageID = messageKey.at(-1)!
        const partKeys = await Storage.list(["part", messageID])
        for (const partKey of partKeys) {
          await Storage.remove(partKey)
        }
        await Storage.remove(messageKey)
      }

      await Storage.remove(["session", project.id, sessionID])
      if (!emitEvent) return
      Bus.publish(Event.Deleted, { info: session })
    } catch (e) {
      log.error(e)
    }
  }
  /**
   * Writes a message under `["message", sessionID, id]` and publishes
   * `MessageV2.Event.Updated`.
   *
   * @returns the message that was passed in
   */
  async function updateMessage(msg: MessageV2.Info) {
    const key = ["message", msg.sessionID, msg.id]
    await Storage.write(key, msg)
    Bus.publish(MessageV2.Event.Updated, { info: msg })
    return msg
  }
  /**
   * Writes a part under `["part", messageID, id]` and publishes
   * `MessageV2.Event.PartUpdated`.
   *
   * @returns the part that was passed in
   */
  async function updatePart(part: MessageV2.Part) {
    const key = ["part", part.messageID, part.id]
    await Storage.write(key, part)
    Bus.publish(MessageV2.Event.PartUpdated, { part })
    return part
  }
  /**
   * Makes a pending revert permanent.
   *
   * Messages from `revert.messageID` onwards are deleted from storage
   * (together with their parts) and a `Removed` event is published for each.
   * When `revert.partID` is set, parts of the last preserved message from
   * that part onwards are deleted as well. Finally the revert marker on the
   * session is cleared. Does nothing when the session has no revert marker.
   *
   * @param session the session record as currently stored
   */
  async function cleanupRevert(session: Info) {
    const revert = session.revert
    if (!revert) return
    const sessionID = session.id

    const history = await messages(sessionID)
    const [preserve, remove] = splitWhen(history, (x) => x.info.id === revert.messageID)

    for (const msg of remove) {
      // Parts live under their own key; drop them with the message so no
      // orphaned part records stay behind.
      for (const part of msg.parts) {
        await Storage.remove(["part", msg.info.id, part.id])
      }
      await Storage.remove(["message", sessionID, msg.info.id])
      await Bus.publish(MessageV2.Event.Removed, { sessionID: sessionID, messageID: msg.info.id })
    }

    const last = preserve.at(-1)
    if (revert.partID && last) {
      const partID = revert.partID
      const [, removeParts] = splitWhen(last.parts, (x) => x.id === partID)
      for (const part of removeParts) {
        await Storage.remove(["part", last.info.id, part.id])
        await Bus.publish(MessageV2.Event.PartRemoved, {
          sessionID: sessionID,
          messageID: last.info.id,
          partID: part.id,
        })
      }
    }

    await update(sessionID, (draft) => {
      draft.revert = undefined
    })
  }
  /**
   * Input accepted by prompt().
   *
   * `parts` holds text, file or agent parts. Compared with the stored part
   * schemas, `messageID` and `sessionID` are dropped and `id` becomes
   * optional; prompt() fills those in.
   */
  export const PromptInput = z.object({
    sessionID: Identifier.schema("session"),
    messageID: Identifier.schema("message").optional(),
    model: z.object({ providerID: z.string(), modelID: z.string() }).optional(),
    agent: z.string().optional(),
    system: z.string().optional(),
    tools: z.record(z.boolean()).optional(),
    parts: z.array(
      z.discriminatedUnion("type", [
        MessageV2.TextPart.omit({ messageID: true, sessionID: true })
          .partial({ id: true })
          .openapi({ ref: "TextPartInput" }),
        MessageV2.FilePart.omit({ messageID: true, sessionID: true })
          .partial({ id: true })
          .openapi({ ref: "FilePartInput" }),
        MessageV2.AgentPart.omit({ messageID: true, sessionID: true })
          .partial({ id: true })
          .openapi({ ref: "AgentPartInput" }),
      ]),
    ),
  })
  export type ChatInput = z.infer<typeof PromptInput>
  /**
   * Returns the decoded payload of a `data:` URL as text.
   *
   * Only the portion after the first comma is payload; the media-type header
   * in front of it must not be decoded. Base64 payloads are decoded as such,
   * anything else is treated as percent-encoded text.
   */
  function decodeDataUrlText(url: string): string {
    const comma = url.indexOf(",")
    if (comma === -1) return ""
    const payload = url.slice(comma + 1)
    if (/;base64$/i.test(url.slice(0, comma))) {
      return Buffer.from(payload, "base64").toString()
    }
    try {
      return decodeURIComponent(payload)
    } catch {
      // Malformed percent-encoding: fall back to the raw payload.
      return payload
    }
  }

  /**
   * Sends a user prompt to the model and processes the streamed answer.
   *
   * Steps, in order:
   *  1. finish a pending revert on the session;
   *  2. build the user message and expand its parts (attached files and
   *     directories are turned into synthetic "tool was called" text parts);
   *  3. store the message and its parts;
   *  4. if `isLocked` reports the session as locked, queue the prompt and
   *     return a promise that is resolved through the queue callback;
   *  5. otherwise resolve agent and model, compact the history if needed,
   *     start title generation in the background for the first real user
   *     message, assemble system prompt and tools, and stream the response
   *     through the processor;
   *  6. afterwards run any queued prompt that was not picked up during
   *     streaming, or hand the result to all queued callbacks.
   *
   * @param input see `PromptInput`
   * @returns the assistant message and parts returned by the processor
   */
  export async function prompt(
    input: z.infer<typeof PromptInput>,
  ): Promise<{ info: MessageV2.Assistant; parts: MessageV2.Part[] }> {
    const l = log.clone().tag("session", input.sessionID)
    l.info("chatting")

    const inputAgent = input.agent ?? "build"

    // Process revert cleanup first, before creating new messages.
    // Awaited so the deletions cannot interleave with the writes below.
    const session = await get(input.sessionID)
    if (session.revert) {
      await cleanupRevert(session)
    }

    const userMsg: MessageV2.Info = {
      id: input.messageID ?? Identifier.ascending("message"),
      role: "user",
      sessionID: input.sessionID,
      time: {
        created: Date.now(),
      },
    }

    const userParts = await Promise.all(
      input.parts.map(async (part): Promise<MessageV2.Part[]> => {
        if (part.type === "file") {
          const url = new URL(part.url)
          switch (url.protocol) {
            case "data:":
              if (part.mime === "text/plain") {
                return [
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: `Called the Read tool with the following input: ${JSON.stringify({ filePath: part.filename })}`,
                  },
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    // Decode only the payload, not the "data:...;base64," header.
                    text: decodeDataUrlText(part.url),
                  },
                  {
                    ...part,
                    id: part.id ?? Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                  },
                ]
              }
              break
            case "file:": {
              // have to normalize, symbol search returns absolute paths
              // Decode the pathname since URL constructor doesn't automatically decode it
              const filePath = decodeURIComponent(url.pathname)

              if (part.mime === "text/plain") {
                let offset: number | undefined = undefined
                let limit: number | undefined = undefined
                const range = {
                  start: url.searchParams.get("start"),
                  end: url.searchParams.get("end"),
                }
                if (range.start != null) {
                  const filePath = part.url.split("?")[0]
                  let start = parseInt(range.start)
                  let end = range.end ? parseInt(range.end) : undefined
                  // some LSP servers (eg, gopls) don't give full range in
                  // workspace/symbol searches, so we'll try to find the
                  // symbol in the document to get the full range
                  if (start === end) {
                    const symbols = await LSP.documentSymbol(filePath)
                    for (const symbol of symbols) {
                      let range: LSP.Range | undefined
                      if ("range" in symbol) {
                        range = symbol.range
                      } else if ("location" in symbol) {
                        range = symbol.location.range
                      }
                      if (range?.start?.line && range?.start?.line === start) {
                        start = range.start.line
                        end = range?.end?.line ?? start
                        break
                      }
                    }
                  }
                  offset = Math.max(start - 1, 0)
                  if (end) {
                    limit = end - offset
                  }
                }
                const args = { filePath, offset, limit }
                const result = await ReadTool.init().then((t) =>
                  t.execute(args, {
                    sessionID: input.sessionID,
                    abort: new AbortController().signal,
                    // `input.agent` is optional; use the defaulted name.
                    agent: inputAgent,
                    messageID: userMsg.id,
                    extra: { bypassCwdCheck: true },
                    metadata: async () => {},
                  }),
                )
                return [
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
                  },
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: result.output,
                  },
                  {
                    ...part,
                    id: part.id ?? Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                  },
                ]
              }

              if (part.mime === "application/x-directory") {
                const args = { path: filePath }
                const result = await ListTool.init().then((t) =>
                  t.execute(args, {
                    sessionID: input.sessionID,
                    abort: new AbortController().signal,
                    // `input.agent` is optional; use the defaulted name.
                    agent: inputAgent,
                    messageID: userMsg.id,
                    extra: { bypassCwdCheck: true },
                    metadata: async () => {},
                  }),
                )
                return [
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: `Called the list tool with the following input: ${JSON.stringify(args)}`,
                  },
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: result.output,
                  },
                  {
                    ...part,
                    id: part.id ?? Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                  },
                ]
              }

              // Any other mime type: inline the file content as a data URL.
              const file = Bun.file(filePath)
              FileTime.read(input.sessionID, filePath)
              return [
                {
                  id: Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "text",
                  text: `Called the Read tool with the following input: {\"filePath\":\"${filePath}\"}`,
                  synthetic: true,
                },
                {
                  id: part.id ?? Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "file",
                  url: `data:${part.mime};base64,` + Buffer.from(await file.bytes()).toString("base64"),
                  mime: part.mime,
                  filename: part.filename!,
                  source: part.source,
                },
              ]
            }
          }
        }

        if (part.type === "agent") {
          return [
            {
              id: Identifier.ascending("part"),
              ...part,
              messageID: userMsg.id,
              sessionID: input.sessionID,
            },
            {
              id: Identifier.ascending("part"),
              messageID: userMsg.id,
              sessionID: input.sessionID,
              type: "text",
              synthetic: true,
              text:
                "Use the above message and context to generate a prompt and call the task tool with subagent: " +
                part.name,
            },
          ]
        }

        return [
          {
            id: Identifier.ascending("part"),
            ...part,
            messageID: userMsg.id,
            sessionID: input.sessionID,
          },
        ]
      }),
    ).then((x) => x.flat())

    await Plugin.trigger(
      "chat.message",
      {},
      {
        message: userMsg,
        parts: userParts,
      },
    )
    await updateMessage(userMsg)
    for (const part of userParts) {
      await updatePart(part)
    }

    // mark session as updated
    // used for session list sorting (indicates when session was most recently interacted with)
    await update(input.sessionID, (_draft) => {})

    // Session busy: park the prompt; the promise is resolved through `callback`.
    if (isLocked(input.sessionID)) {
      return new Promise((resolve) => {
        const queue = state().queued.get(input.sessionID) ?? []
        queue.push({
          input: input,
          message: userMsg,
          parts: userParts,
          processed: false,
          callback: resolve,
        })
        state().queued.set(input.sessionID, queue)
      })
    }

    const agent = await Agent.get(inputAgent)
    // Model precedence: explicit input, then agent setting, then provider default.
    const model = await (async () => {
      if (input.model) {
        return input.model
      }
      if (agent.model) {
        return agent.model
      }
      return Provider.defaultModel()
    })().then((x) => Provider.getModel(x.providerID, x.modelID))

    // Only the history from the latest summary message onwards is sent.
    let msgs = await messages(input.sessionID)
    const lastSummary = Math.max(
      0,
      msgs.findLastIndex((msg) => msg.info.role === "assistant" && msg.info.summary === true),
    )
    msgs = msgs.slice(lastSummary)

    const lastAssistant = msgs.findLast((msg) => msg.info.role === "assistant")
    if (
      lastAssistant?.info.role === "assistant" &&
      needsCompaction({
        tokens: lastAssistant.info.tokens,
        model: model.info,
      })
    ) {
      const msg = await summarize({
        sessionID: input.sessionID,
        providerID: model.providerID,
        modelID: model.info.id,
      })
      msgs = [msg]
    }

    const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX

    using abort = lock(input.sessionID)

    // Generate a title in the background for the first real (non-synthetic)
    // user message of a top-level session that still has its default title.
    const numRealUserMsgs = msgs.filter(
      (m) => m.info.role === "user" && !m.parts.every((p) => "synthetic" in p && p.synthetic),
    ).length
    if (numRealUserMsgs === 1 && !session.parentID && isDefaultTitle(session.title)) {
      const small = (await Provider.getSmallModel(model.providerID)) ?? model
      const options = {
        ...ProviderTransform.options(small.providerID, small.modelID, input.sessionID),
        ...small.info.options,
      }
      if (small.providerID === "openai") {
        options["reasoningEffort"] = "minimal"
      }
      if (small.providerID === "google") {
        options["thinkingConfig"] = {
          thinkingBudget: 0,
        }
      }
      generateText({
        maxOutputTokens: small.info.reasoning ? 1500 : 20,
        providerOptions: {
          [model.providerID]: options,
        },
        messages: [
          ...SystemPrompt.title(model.providerID).map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          ...MessageV2.toModelMessage([
            {
              info: {
                id: Identifier.ascending("message"),
                role: "user",
                sessionID: input.sessionID,
                time: {
                  created: Date.now(),
                },
              },
              parts: userParts,
            },
          ]),
        ],
        model: small.language,
      })
        .then((result) => {
          if (result.text)
            return Session.update(input.sessionID, (draft) => {
              // Strip <think> blocks and cap the title at 100 characters.
              const cleaned = result.text.replace(/<think>[\s\S]*?<\/think>\s*/g, "")
              const title = cleaned.length > 100 ? cleaned.substring(0, 97) + "..." : cleaned
              draft.title = title.trim()
            })
        })
        .catch((error) => {
          log.error("failed to generate title", { error, model: small.info.id })
        })
    }

    if (agent.name === "plan") {
      msgs.at(-1)?.parts.push({
        id: Identifier.ascending("part"),
        messageID: userMsg.id,
        sessionID: input.sessionID,
        type: "text",
        text: PROMPT_PLAN,
        synthetic: true,
      })
    }

    const wasPlan = msgs.some((msg) => msg.info.role === "assistant" && msg.info.mode === "plan")
    if (wasPlan && agent.name === "build") {
      msgs.at(-1)?.parts.push({
        id: Identifier.ascending("part"),
        messageID: userMsg.id,
        sessionID: input.sessionID,
        type: "text",
        text: BUILD_SWITCH,
        synthetic: true,
      })
    }

    let system = SystemPrompt.header(model.providerID)
    system.push(
      ...(() => {
        if (input.system) return [input.system]
        if (agent.prompt) return [agent.prompt]
        return SystemPrompt.provider(model.modelID)
      })(),
    )
    system.push(...(await SystemPrompt.environment()))
    system.push(...(await SystemPrompt.custom()))
    // max 2 system prompt messages for caching purposes
    const [first, ...rest] = system
    system = [first, rest.join("\n")]

    const processor = await createProcessor({
      sessionID: input.sessionID,
      model: model.info,
      providerID: model.providerID,
      agent: inputAgent,
      system,
    })
    // On scope exit, drop the assistant message if it never completed.
    await using _ = defer(async () => {
      if (processor.message.time.completed) return
      // Same key layout as updateMessage() / cleanupRevert() use for messages.
      await Storage.remove(["message", input.sessionID, processor.message.id])
      await Bus.publish(MessageV2.Event.Removed, { sessionID: input.sessionID, messageID: processor.message.id })
    })

    const tools: Record<string, AITool> = {}

    // Later sources win: agent defaults, then registry, then per-request overrides.
    const enabledTools = pipe(
      agent.tools,
      mergeDeep(await ToolRegistry.enabled(model.providerID, model.modelID, agent)),
      mergeDeep(input.tools ?? {}),
    )
    for (const item of await ToolRegistry.tools(model.providerID, model.modelID)) {
      if (Wildcard.all(item.id, enabledTools) === false) continue
      tools[item.id] = tool({
        id: item.id as any,
        description: item.description,
        inputSchema: item.parameters as ZodSchema,
        async execute(args, options) {
          await Plugin.trigger(
            "tool.execute.before",
            {
              tool: item.id,
              sessionID: input.sessionID,
              callID: options.toolCallId,
            },
            {
              args,
            },
          )
          const result = await item.execute(args, {
            sessionID: input.sessionID,
            abort: options.abortSignal!,
            messageID: processor.message.id,
            callID: options.toolCallId,
            agent: agent.name,
            metadata: async (val) => {
              // Only a tool part that is still running gets its title/metadata refreshed.
              const match = processor.partFromToolCall(options.toolCallId)
              if (match && match.state.status === "running") {
                await updatePart({
                  ...match,
                  state: {
                    title: val.title,
                    metadata: val.metadata,
                    status: "running",
                    input: args,
                    time: {
                      start: Date.now(),
                    },
                  },
                })
              }
            },
          })
          await Plugin.trigger(
            "tool.execute.after",
            {
              tool: item.id,
              sessionID: input.sessionID,
              callID: options.toolCallId,
            },
            result,
          )
          return result
        },
        toModelOutput(result) {
          return {
            type: "text",
            value: result.output,
          }
        },
      })
    }

    // MCP tools are wrapped so plugin hooks fire and only text content is returned.
    for (const [key, item] of Object.entries(await MCP.tools())) {
      if (Wildcard.all(key, enabledTools) === false) continue
      const execute = item.execute
      if (!execute) continue
      item.execute = async (args, opts) => {
        await Plugin.trigger(
          "tool.execute.before",
          {
            tool: key,
            sessionID: input.sessionID,
            callID: opts.toolCallId,
          },
          {
            args,
          },
        )
        const result = await execute(args, opts)
        const output = result.content
          .filter((x: any) => x.type === "text")
          .map((x: any) => x.text)
          .join("\n\n")
        await Plugin.trigger(
          "tool.execute.after",
          {
            tool: key,
            sessionID: input.sessionID,
            callID: opts.toolCallId,
          },
          result,
        )

        return {
          output,
        }
      }
      item.toModelOutput = (result) => {
        return {
          type: "text",
          value: result.output,
        }
      }
      tools[key] = item
    }

    const params = await Plugin.trigger(
      "chat.params",
      {
        model: model.info,
        provider: await Provider.getProvider(model.providerID),
        message: userMsg,
      },
      {
        temperature: model.info.temperature
          ? (agent.temperature ?? ProviderTransform.temperature(model.providerID, model.modelID))
          : undefined,
        topP: agent.topP ?? ProviderTransform.topP(model.providerID, model.modelID),
        options: {
          ...ProviderTransform.options(model.providerID, model.modelID, input.sessionID),
          ...model.info.options,
          ...agent.options,
        },
      },
    )

    // Index into the step messages from which history is sent; moved forward after a mid-stream summary.
    let pointer = 0
    const stream = streamText({
      onError(e) {
        log.error("streamText error", {
          error: e,
        })
      },
      async prepareStep({ messages, steps }) {
        log.info("search", {
          length: messages.length,
        })
        const step = steps.at(-1)
        if (
          step &&
          needsCompaction({
            tokens: getUsage(model.info, step.usage, step.providerMetadata).tokens,
            model: model.info,
          })
        ) {
          await processor.end()
          const msg = await Session.summarize({
            sessionID: input.sessionID,
            providerID: model.providerID,
            modelID: model.info.id,
          })
          await processor.next()
          pointer = messages.length - 1
          messages.push(...MessageV2.toModelMessage([msg]))
        }

        // Add queued messages to the stream
        const queue = (state().queued.get(input.sessionID) ?? []).filter((x) => !x.processed)
        if (queue.length) {
          await processor.end()
          for (const item of queue) {
            if (item.processed) continue
            messages.push(
              ...MessageV2.toModelMessage([
                {
                  info: item.message,
                  parts: item.parts,
                },
              ]),
            )
            item.processed = true
          }
          await processor.next()
        }
        return {
          messages: messages.slice(pointer),
        }
      },
      async experimental_repairToolCall(input) {
        // A tool name that differs only by case is mapped to the registered tool.
        const lower = input.toolCall.toolName.toLowerCase()
        if (lower !== input.toolCall.toolName && tools[lower]) {
          log.info("repairing tool call", {
            tool: input.toolCall.toolName,
            repaired: lower,
          })
          return {
            ...input.toolCall,
            toolName: lower,
          }
        }
        // Everything else is redirected to the "invalid" tool with the error attached.
        return {
          ...input.toolCall,
          input: JSON.stringify({
            tool: input.toolCall.toolName,
            error: input.error.message,
          }),
          toolName: "invalid",
        }
      },
      headers:
        model.providerID === "opencode"
          ? {
              "x-opencode-session": input.sessionID,
              "x-opencode-request": userMsg.id,
            }
          : undefined,
      maxRetries: 3,
      activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
      maxOutputTokens: ProviderTransform.maxOutputTokens(model.providerID, outputLimit, params.options),
      abortSignal: abort.signal,
      stopWhen: async ({ steps }) => {
        if (steps.length >= 1000) {
          return true
        }

        // Check if processor flagged that we should stop
        if (processor.getShouldStop()) {
          return true
        }

        return false
      },
      providerOptions: {
        [model.providerID]: params.options,
      },
      temperature: params.temperature,
      topP: params.topP,
      messages: [
        ...system.map(
          (x): ModelMessage => ({
            role: "system",
            content: x,
          }),
        ),
        // Assistant messages that ended in an error are left out of the history.
        ...MessageV2.toModelMessage(msgs.filter((m) => !(m.info.role === "assistant" && m.info.error))),
      ],
      tools: model.info.tool_call === false ? undefined : tools,
      model: wrapLanguageModel({
        model: model.language,
        middleware: [
          {
            async transformParams(args) {
              if (args.type === "stream") {
                // @ts-expect-error
                args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
              }
              return args.params
            },
          },
        ],
      }),
    })
    const result = await processor.process(stream)

    // A queued prompt that was not picked up during streaming is run next;
    // otherwise every queued caller receives this result.
    const queued = state().queued.get(input.sessionID) ?? []
    const unprocessed = queued.find((x) => !x.processed)
    if (unprocessed) {
      unprocessed.processed = true
      return prompt(unprocessed.input)
    }
    for (const item of queued) {
      item.callback(result)
    }
    state().queued.delete(input.sessionID)
    return result
  }
  /** Input accepted by `shell`: the session to run in, the agent name to record, and the raw command line. */
  export const ShellInput = z.object({
    sessionID: Identifier.schema("session"),
    agent: z.string(),
    command: z.string(),
  })
  export type ShellInput = z.infer<typeof ShellInput>

  /**
   * Runs a user-supplied command in the user's shell and records it in the session as a
   * synthetic user message followed by an assistant message holding a single `bash` tool part.
   *
   * The session lock is held for the duration of the call, so this throws whatever `lock`
   * throws when the session is already busy. Output from stdout and stderr is interleaved
   * into one string and streamed into the tool part while the process is running.
   *
   * @param input - Session, agent label and command line to execute.
   * @returns The assistant message and its single tool part.
   */
  export async function shell(input: ShellInput) {
    using abort = lock(input.sessionID)
    const session = await get(input.sessionID)
    if (session.revert) {
      cleanupRevert(session)
    }
    const userMsg: MessageV2.User = {
      id: Identifier.ascending("message"),
      sessionID: input.sessionID,
      time: {
        created: Date.now(),
      },
      role: "user",
    }
    await updateMessage(userMsg)
    const userPart: MessageV2.Part = {
      type: "text",
      id: Identifier.ascending("part"),
      messageID: userMsg.id,
      sessionID: input.sessionID,
      text: "The following tool was executed by the user",
      synthetic: true,
    }
    await updatePart(userPart)
    const msg: MessageV2.Assistant = {
      id: Identifier.ascending("message"),
      sessionID: input.sessionID,
      system: [],
      mode: input.agent,
      cost: 0,
      path: {
        cwd: Instance.directory,
        root: Instance.worktree,
      },
      time: {
        created: Date.now(),
      },
      role: "assistant",
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: "",
      providerID: "",
    }
    await updateMessage(msg)
    const part: MessageV2.Part = {
      type: "tool",
      id: Identifier.ascending("part"),
      messageID: msg.id,
      sessionID: input.sessionID,
      tool: "bash",
      callID: ulid(),
      state: {
        status: "running",
        time: {
          start: Date.now(),
        },
        input: {
          command: input.command,
        },
      },
    }
    await updatePart(part)

    const shell = process.env["SHELL"] ?? "bash"
    const shellName = path.basename(shell)
    // Single-quote the command so it reaches `eval` verbatim. Interpolating it into
    // double quotes (as before) let any `"` in the command terminate the string early.
    const quoted = `'${input.command.replaceAll("'", `'\\''`)}'`
    const scripts: Record<string, string> = {
      nu: input.command,
      // fish has different quoting rules, so hand it the command as the script itself.
      fish: input.command,
    }
    const script =
      scripts[shellName] ??
      [
        "[[ -f ~/.zshenv ]] && source ~/.zshenv >/dev/null 2>&1 || true",
        '[[ -f "${ZDOTDIR:-$HOME}/.zshrc" ]] && source "${ZDOTDIR:-$HOME}/.zshrc" >/dev/null 2>&1 || true',
        "[[ -f ~/.bashrc ]] && source ~/.bashrc >/dev/null 2>&1 || true",
        `eval ${quoted}`,
      ].join("\n")
    const isFishOrNu = shellName === "fish" || shellName === "nu"
    const args = isFishOrNu ? ["-c", script] : ["-c", "-l", script]

    let output = ""
    const proc = spawn(shell, args, {
      cwd: Instance.directory,
      signal: abort.signal,
      detached: true,
      stdio: ["ignore", "pipe", "pipe"],
      env: {
        ...process.env,
        TERM: "dumb",
      },
    })
    // Listeners are attached synchronously after spawn so no event can be missed.
    const finished = new Promise<void>((resolve) => {
      proc.once("close", () => resolve())
      // Without an 'error' listener, an abort via `signal` or a spawn failure would
      // surface as an unhandled 'error' event.
      proc.on("error", (error) => {
        if (proc.pid !== undefined) {
          // Process was started (e.g. it was aborted): 'close' still settles the promise.
          log.info("shell process error", { error })
          return
        }
        // Process never started: report the failure and stop waiting.
        log.error("shell process failed to start", { error })
        output += error.message
        resolve()
      })
    })
    abort.signal.addEventListener(
      "abort",
      () => {
        if (!proc.pid) return
        try {
          // Negative pid targets the whole process group created by `detached: true`.
          process.kill(-proc.pid)
        } catch (error) {
          // The group may already have exited; that must not escape the listener.
          log.info("shell process group already gone", { error })
        }
      },
      { once: true },
    )
    // stdout and stderr are interleaved into one buffer and mirrored into the running part.
    const append = (chunk: Buffer | string) => {
      output += chunk.toString()
      if (part.state.status === "running") {
        part.state.metadata = {
          output: output,
          description: "",
        }
        updatePart(part)
      }
    }
    proc.stdout?.on("data", append)
    proc.stderr?.on("data", append)
    await finished

    msg.time.completed = Date.now()
    await updateMessage(msg)
    if (part.state.status === "running") {
      part.state = {
        status: "completed",
        time: {
          ...part.state.time,
          end: Date.now(),
        },
        input: part.state.input,
        title: "",
        metadata: {
          output,
          description: "",
        },
        output,
      }
      await updatePart(part)
    }
    return { info: msg, parts: [part] }
  }
  /** Input accepted by `command`: which stored command to run, its argument string, and optional agent/model overrides. */
  export const CommandInput = z.object({
    messageID: Identifier.schema("message").optional(),
    sessionID: Identifier.schema("session"),
    agent: z.string().optional(),
    model: z.string().optional(),
    arguments: z.string(),
    command: z.string(),
  })
  export type CommandInput = z.infer<typeof CommandInput>

  /** Matches inline shell snippets written as !`...` inside a command template. */
  const bashRegex = /!`([^`]+)`/g

  /**
   * Regular expression to match @ file references in text.
   * Matches @ followed by a path, excluding commas, sentence-ending periods, and backticks.
   * Does not match when preceded by word characters or backticks (to avoid email addresses
   * and quoted references). Note that a lone "@" matches with an empty capture group.
   */
  export const fileRegex = /(?<![\w`])@(\.?[^\s`,.]*(?:\.[^\s`,.]+)*)/g

  /**
   * Expands a stored command template and submits it as a prompt.
   *
   * Steps: substitute `$ARGUMENTS`, replace each !`...` snippet with that command's output,
   * turn every `@reference` into a file or agent part, then pick the model
   * (command model, then command agent's model, then `input.model`).
   *
   * @param input - Command name, argument string and optional overrides.
   * @returns Whatever `prompt` returns for the expanded template.
   */
  export async function command(input: CommandInput) {
    log.info("command", input)
    const command = await Command.get(input.command)
    const agent = command.agent ?? input.agent ?? "build"

    // Function replacer: substitutes every occurrence and keeps `$&`, `$$` etc. in the
    // user's arguments literal instead of treating them as replacement patterns.
    let template = command.template.replaceAll("$ARGUMENTS", () => input.arguments)

    const bash = Array.from(template.matchAll(bashRegex))
    if (bash.length > 0) {
      const results = await Promise.all(
        bash.map(async ([, cmd]) => {
          try {
            return await $`${{ raw: cmd }}`.nothrow().text()
          } catch (error) {
            return `Error executing command: ${error instanceof Error ? error.message : String(error)}`
          }
        }),
      )
      let index = 0
      template = template.replace(bashRegex, () => results[index++])
    }

    const parts = [
      {
        type: "text",
        text: template,
      },
    ] as ChatInput["parts"]

    // Resolve references concurrently but append them in template order, so the part
    // order does not depend on which filesystem lookup happens to finish first.
    const references = await Promise.all(
      Array.from(template.matchAll(fileRegex)).map(
        async (match): Promise<ChatInput["parts"][number] | undefined> => {
          const name = match[1]
          // A bare "@" captures an empty name, which would resolve to the worktree root.
          if (!name) return undefined
          const filepath = name.startsWith("~/")
            ? path.join(os.homedir(), name.slice(2))
            : path.resolve(Instance.worktree, name)
          const stats = await fs.stat(filepath).catch(() => undefined)
          if (!stats) {
            // Not a path on disk: fall back to treating the name as an agent reference.
            const agent = await Agent.get(name)
            return agent ? { type: "agent", name: agent.name } : undefined
          }
          return {
            type: "file",
            url: `file://${filepath}`,
            filename: name,
            mime: stats.isDirectory() ? "application/x-directory" : "text/plain",
          }
        },
      ),
    )
    for (const reference of references) {
      if (reference) parts.push(reference)
    }

    const model = await (async () => {
      if (command.model) {
        return Provider.parseModel(command.model)
      }
      if (command.agent) {
        const agent = await Agent.get(command.agent)
        // The lookup above already treats Agent.get as possibly returning nothing.
        if (agent?.model) {
          return agent.model
        }
      }
      if (input.model) {
        return Provider.parseModel(input.model)
      }
      return undefined
    })()

    return prompt({
      sessionID: input.sessionID,
      messageID: input.messageID,
      model,
      agent,
      parts,
    })
  }
  /**
   * Creates the object that turns a model stream into persisted message parts.
   *
   * An assistant message is created immediately. The returned object exposes:
   * - `end()`: stamps the current assistant message as completed and saves it.
   * - `next()`: starts a fresh assistant message for subsequent parts.
   * - `message`: the assistant message currently being written to.
   * - `partFromToolCall(id)`: the tool part tracked for an in-flight tool call, if any.
   * - `getShouldStop()`: true once a tool failed with `Permission.RejectedError`.
   * - `process(stream)`: consumes the stream and returns the message with its parts.
   *
   * @param input - Session, provider/model and prompt context recorded on each assistant message.
   */
  async function createProcessor(input: {
    sessionID: string
    providerID: string
    model: ModelsDev.Model
    system: string[]
    agent: string
  }) {
    // Tool parts for calls that have started but not yet produced a result or error.
    const toolcalls: Record<string, MessageV2.ToolPart> = {}
    let snapshot: string | undefined
    let shouldStop = false

    async function createMessage() {
      const msg: MessageV2.Info = {
        id: Identifier.ascending("message"),
        role: "assistant",
        system: input.system,
        mode: input.agent,
        path: {
          cwd: Instance.directory,
          root: Instance.worktree,
        },
        cost: 0,
        tokens: {
          input: 0,
          output: 0,
          reasoning: 0,
          cache: { read: 0, write: 0 },
        },
        modelID: input.model.id,
        providerID: input.providerID,
        time: {
          created: Date.now(),
        },
        sessionID: input.sessionID,
      }
      await updateMessage(msg)
      return msg
    }

    let assistantMsg = await createMessage()

    const result = {
      async end() {
        if (assistantMsg) {
          assistantMsg.time.completed = Date.now()
          await updateMessage(assistantMsg)
        }
      },
      async next() {
        assistantMsg = await createMessage()
      },
      get message() {
        return assistantMsg
      },
      partFromToolCall(toolCallID: string) {
        return toolcalls[toolCallID]
      },
      getShouldStop() {
        return shouldStop
      },
      /**
       * Consumes every event of `stream`, persisting parts as they arrive. Errors thrown
       * while streaming are converted into `assistantMsg.error` and published on the bus
       * rather than propagated. Tool parts still unfinished afterwards are marked as errored.
       */
      async process(stream: StreamTextResult<Record<string, AITool>, never>) {
        try {
          let currentText: MessageV2.TextPart | undefined
          const reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
          for await (const value of stream.fullStream) {
            log.info("part", {
              type: value.type,
            })
            switch (value.type) {
              case "start":
                break

              case "reasoning-start":
                if (value.id in reasoningMap) {
                  continue
                }
                reasoningMap[value.id] = {
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "reasoning",
                  text: "",
                  time: {
                    start: Date.now(),
                  },
                }
                break

              case "reasoning-delta":
                if (value.id in reasoningMap) {
                  const part = reasoningMap[value.id]
                  part.text += value.text
                  if (part.text) await updatePart(part)
                }
                break

              case "reasoning-end":
                if (value.id in reasoningMap) {
                  const part = reasoningMap[value.id]
                  part.text = part.text.trimEnd()
                  part.metadata = value.providerMetadata
                  part.time = {
                    ...part.time,
                    end: Date.now(),
                  }
                  await updatePart(part)
                  delete reasoningMap[value.id]
                }
                break

              case "tool-input-start": {
                const part = await updatePart({
                  id: toolcalls[value.id]?.id ?? Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "tool",
                  tool: value.toolName,
                  callID: value.id,
                  state: {
                    status: "pending",
                  },
                })
                toolcalls[value.id] = part as MessageV2.ToolPart
                break
              }

              case "tool-input-delta":
                break

              case "tool-input-end":
                break

              case "tool-call": {
                const match = toolcalls[value.toolCallId]
                if (match) {
                  const part = await updatePart({
                    ...match,
                    tool: value.toolName,
                    state: {
                      status: "running",
                      input: value.input,
                      time: {
                        start: Date.now(),
                      },
                    },
                  })
                  toolcalls[value.toolCallId] = part as MessageV2.ToolPart
                }
                break
              }

              case "tool-result": {
                const match = toolcalls[value.toolCallId]
                if (match && match.state.status === "running") {
                  await updatePart({
                    ...match,
                    state: {
                      status: "completed",
                      input: value.input,
                      output: value.output.output,
                      metadata: value.output.metadata,
                      title: value.output.title,
                      time: {
                        start: match.state.time.start,
                        end: Date.now(),
                      },
                    },
                  })
                  delete toolcalls[value.toolCallId]
                }
                break
              }

              case "tool-error": {
                const match = toolcalls[value.toolCallId]
                if (match && match.state.status === "running") {
                  // A rejected permission request asks the caller to stop the run.
                  if (value.error instanceof Permission.RejectedError) {
                    shouldStop = true
                  }
                  await updatePart({
                    ...match,
                    state: {
                      status: "error",
                      input: value.input,
                      error: (value.error as any).toString(),
                      metadata: value.error instanceof Permission.RejectedError ? value.error.metadata : undefined,
                      time: {
                        start: match.state.time.start,
                        end: Date.now(),
                      },
                    },
                  })
                  delete toolcalls[value.toolCallId]
                }
                break
              }

              case "error":
                throw value.error

              case "start-step":
                await updatePart({
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "step-start",
                })
                snapshot = await Snapshot.track()
                break

              case "finish-step": {
                const usage = getUsage(input.model, value.usage, value.providerMetadata)
                assistantMsg.cost += usage.cost
                assistantMsg.tokens = usage.tokens
                await updatePart({
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "step-finish",
                  tokens: usage.tokens,
                  cost: usage.cost,
                })
                await updateMessage(assistantMsg)
                if (snapshot) {
                  // Record a patch part only when the step actually touched files.
                  const patch = await Snapshot.patch(snapshot)
                  if (patch.files.length) {
                    await updatePart({
                      id: Identifier.ascending("part"),
                      messageID: assistantMsg.id,
                      sessionID: assistantMsg.sessionID,
                      type: "patch",
                      hash: patch.hash,
                      files: patch.files,
                    })
                  }
                  snapshot = undefined
                }
                break
              }

              case "text-start":
                currentText = {
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "text",
                  text: "",
                  time: {
                    start: Date.now(),
                  },
                }
                break

              case "text-delta":
                if (currentText) {
                  currentText.text += value.text
                  if (currentText.text) await updatePart(currentText)
                }
                break

              case "text-end":
                if (currentText) {
                  currentText.text = currentText.text.trimEnd()
                  currentText.time = {
                    // Keep the start recorded at text-start instead of overwriting it.
                    start: currentText.time?.start ?? Date.now(),
                    end: Date.now(),
                  }
                  await updatePart(currentText)
                }
                currentText = undefined
                break

              case "finish":
                assistantMsg.time.completed = Date.now()
                await updateMessage(assistantMsg)
                break

              default:
                log.info("unhandled", {
                  ...value,
                })
                continue
            }
          }
        } catch (e) {
          log.error("", {
            error: e,
          })
          switch (true) {
            case e instanceof DOMException && e.name === "AbortError":
              assistantMsg.error = new MessageV2.AbortedError(
                { message: e.message },
                {
                  cause: e,
                },
              ).toObject()
              break
            case MessageV2.OutputLengthError.isInstance(e):
              assistantMsg.error = e
              break
            case LoadAPIKeyError.isInstance(e):
              assistantMsg.error = new MessageV2.AuthError(
                {
                  providerID: input.providerID,
                  message: e.message,
                },
                { cause: e },
              ).toObject()
              break
            case e instanceof Error:
              assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
              break
            default:
              // Store the plain-object form, like every other branch above.
              assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e }).toObject()
          }
          Bus.publish(Event.Error, {
            sessionID: assistantMsg.sessionID,
            error: assistantMsg.error,
          })
        }

        // Any tool part that never reached a terminal state is closed out as aborted.
        const p = await getParts(assistantMsg.id)
        for (const part of p) {
          if (part.type === "tool" && part.state.status !== "completed" && part.state.status !== "error") {
            // Awaited so a failed write is not left as a floating promise.
            await updatePart({
              ...part,
              state: {
                status: "error",
                error: "Tool execution aborted",
                time: {
                  start: Date.now(),
                  end: Date.now(),
                },
                input: {},
              },
            })
          }
        }
        assistantMsg.time.completed = Date.now()
        await updateMessage(assistantMsg)
        return { info: assistantMsg, parts: p }
      },
    }
    return result
  }
  /** Input accepted by `revert`: the session plus the message (and optionally the part) to revert to. */
  export const RevertInput = z.object({
    sessionID: Identifier.schema("session"),
    messageID: Identifier.schema("message"),
    partID: Identifier.schema("part").optional(),
  })
  export type RevertInput = z.infer<typeof RevertInput>

  /**
   * Marks the session as reverted at the given message or part and rolls back the files
   * changed by every patch part that comes after that point.
   *
   * @param input - Session and the message/part to revert to.
   * @returns The updated session, or the unchanged session when no matching part was found.
   */
  export async function revert(input: RevertInput) {
    const all = await messages(input.sessionID)
    let lastUser: MessageV2.User | undefined
    const session = await get(input.sessionID)
    let revert: Info["revert"]
    const patches: Snapshot.Patch[] = []

    for (const msg of all) {
      if (msg.info.role === "user") lastUser = msg.info
      // Parts of this message seen before the revert point.
      const remaining: MessageV2.Part[] = []
      for (const part of msg.parts) {
        if (revert) {
          // Past the revert point: only patches matter, they get rolled back below.
          if (part.type === "patch") {
            patches.push(part)
          }
          continue
        }
        if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
          // if no useful parts left in message, same as reverting whole message
          const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
          revert = {
            messageID: !partID && lastUser ? lastUser.id : msg.info.id,
            partID,
          }
        }
        remaining.push(part)
      }
    }

    if (!revert) return session

    // Reuse the snapshot of an earlier revert if there is one, otherwise take a new one.
    revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
    await Snapshot.revert(patches)
    if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot)
    return update(input.sessionID, (draft) => {
      draft.revert = revert
    })
  }
  /**
   * Clears the session's revert marker, first restoring the snapshot stored on it if any.
   *
   * @param input - The session to un-revert.
   * @returns The session unchanged when it has no revert marker, otherwise the updated session.
   */
  export async function unrevert(input: { sessionID: string }) {
    log.info("unreverting", input)
    const current = await get(input.sessionID)
    const marker = current.revert
    if (!marker) {
      return current
    }
    if (marker.snapshot) {
      await Snapshot.restore(marker.snapshot)
    }
    return await update(input.sessionID, (draft) => {
      draft.revert = undefined
    })
  }
  /**
   * Generates a summary of the conversation since the last summary message and stores it
   * as a new assistant message flagged with `summary = true`.
   *
   * `time.compacting` is set on the session while this runs and cleared again on exit.
   *
   * @param input - Session to summarize and the provider/model to summarize with.
   * @returns The summary message and its single text part.
   */
  export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {
    const { sessionID } = input
    await update(sessionID, (draft) => {
      draft.time.compacting = Date.now()
    })
    await using _ = defer(async () => {
      await update(sessionID, (draft) => {
        draft.time.compacting = undefined
      })
    })

    // Start from the most recent summary message, or from the beginning if there is none.
    const history = await messages(sessionID)
    const lastSummary = history.findLastIndex((msg) => msg.info.role === "assistant" && msg.info.summary === true)
    const start = Math.max(0, lastSummary)
    log.info("summarizing", { start })
    const toSummarize = history.slice(start)

    const model = await Provider.getModel(input.providerID, input.modelID)
    const system = [
      ...SystemPrompt.summarize(model.providerID),
      ...(await SystemPrompt.environment()),
      ...(await SystemPrompt.custom()),
    ]

    const summaryMsg = (await updateMessage({
      id: Identifier.ascending("message"),
      role: "assistant",
      sessionID,
      system,
      mode: "build",
      path: {
        cwd: Instance.directory,
        root: Instance.worktree,
      },
      cost: 0,
      tokens: {
        output: 0,
        input: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.modelID,
      providerID: model.providerID,
      time: {
        created: Date.now(),
      },
    })) as MessageV2.Assistant

    const systemMessages = system.map((content): ModelMessage => ({ role: "system", content }))
    const request: ModelMessage = {
      role: "user",
      content: [
        {
          type: "text",
          text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
        },
      ],
    }
    const generated = await generateText({
      maxRetries: 10,
      model: model.language,
      messages: [...systemMessages, ...MessageV2.toModelMessage(toSummarize), request],
    })

    const { cost, tokens } = getUsage(model.info, generated.usage, generated.providerMetadata)
    summaryMsg.cost += cost
    summaryMsg.tokens = tokens
    summaryMsg.summary = true
    summaryMsg.time.completed = Date.now()
    await updateMessage(summaryMsg)

    const textPart = await updatePart({
      type: "text",
      sessionID,
      messageID: summaryMsg.id,
      id: Identifier.ascending("part"),
      text: generated.text,
      time: {
        start: Date.now(),
        end: Date.now(),
      },
    })
    Bus.publish(Event.Compacted, { sessionID })
    return {
      info: summaryMsg,
      parts: [textPart],
    }
  }
  /**
   * Reports whether the token count (input + cache reads + output) exceeds half of the
   * model's usable context, i.e. the context limit minus the reserved output budget.
   * The output budget is the model's output limit capped at OUTPUT_TOKEN_MAX, falling
   * back to OUTPUT_TOKEN_MAX when that value is 0.
   */
  function needsCompaction(input: { tokens: MessageV2.Assistant["tokens"]; model: ModelsDev.Model }) {
    const { tokens, model } = input
    const used = tokens.input + tokens.cache.read + tokens.output
    const reserved = Math.min(model.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
    const threshold = (model.limit.context - reserved) / 2
    return used > threshold
  }
  /**
   * Walks the session's messages and parts from newest to oldest, accumulating the
   * estimated token size of completed tool outputs.
   *
   * NOTE(review): the branch taken once the total passes 40_000 is empty, so this
   * currently has no effect beyond the estimate calls; it looks unfinished.
   */
  export async function microcompact(input: { sessionID: string }) {
    const history = await messages(input.sessionID)
    let total = 0
    for (let m = history.length - 1; m >= 0; m--) {
      const { parts } = history[m]
      for (let p = parts.length - 1; p >= 0; p--) {
        const part = parts[p]
        if (part.type !== "tool") continue
        if (part.state.status !== "completed") continue
        total += Token.estimate(part.state.output)
        if (total > 40_000) {
          // intentionally empty in the original
        }
      }
    }
  }
  /** Reports whether the session currently has an entry in the pending map, i.e. is locked. */
  function isLocked(sessionID: string) {
    const { pending } = state()
    return pending.has(sessionID)
  }
  /**
   * Claims the session for exclusive use and returns a disposable handle.
   *
   * The handle carries the abort signal of a fresh AbortController registered in the
   * pending map. Disposing it removes the registration and, for sessions without a
   * parent, publishes `Event.Idle`.
   *
   * @param sessionID - The session to lock.
   * @returns An object with `signal` and a `Symbol.dispose` method, for use with `using`.
   * @throws BusyError when the session is already locked.
   */
  function lock(sessionID: string) {
    log.info("locking", { sessionID })
    if (state().pending.has(sessionID)) throw new BusyError(sessionID)
    const controller = new AbortController()
    state().pending.set(sessionID, controller)
    return {
      signal: controller.signal,
      // Synchronous on purpose: `using` does not await the disposer, so an async one
      // would turn a failed session lookup into an unhandled promise rejection.
      [Symbol.dispose]() {
        log.info("unlocking", { sessionID })
        state().pending.delete(sessionID)
        void get(sessionID)
          .then((session) => {
            if (session.parentID) return
            Bus.publish(Event.Idle, {
              sessionID,
            })
          })
          .catch((error) => {
            log.error("failed to publish idle event", { sessionID, error })
          })
      },
    }
  }
  /**
   * Converts the SDK usage report into the token breakdown stored on messages and
   * prices it with the model's per-million-token rates (missing rates count as 0).
   *
   * Cache-write tokens are read from the `anthropic` provider metadata, then from
   * `bedrock`, and default to 0.
   *
   * @returns `{ cost, tokens }` where `cost` covers input, output, cache read and cache write tokens.
   */
  function getUsage(model: ModelsDev.Model, usage: LanguageModelUsage, metadata?: ProviderMetadata) {
    const anthropicWrite = metadata?.["anthropic"]?.["cacheCreationInputTokens"]
    // @ts-expect-error
    const bedrockWrite = metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"]

    const tokens = {
      input: usage.inputTokens ?? 0,
      output: usage.outputTokens ?? 0,
      reasoning: usage?.reasoningTokens ?? 0,
      cache: {
        write: (anthropicWrite ?? bedrockWrite ?? 0) as number,
        read: usage.cachedInputTokens ?? 0,
      },
    }

    // Rates are quoted per one million tokens.
    const price = (count: number, rate: number | undefined) => new Decimal(count).mul(rate ?? 0).div(1_000_000)
    const rates = model.cost
    const total = new Decimal(0)
      .add(price(tokens.input, rates?.input))
      .add(price(tokens.output, rates?.output))
      .add(price(tokens.cache.read, rates?.cache_read))
      .add(price(tokens.cache.write, rates?.cache_write))

    return {
      cost: total.toNumber(),
      tokens,
    }
  }
  /**
   * Thrown when a session is asked to start work while it is already locked.
   * Carries the offending session's id in `sessionID`.
   */
  export class BusyError extends Error {
    constructor(public readonly sessionID: string) {
      super(`Session ${sessionID} is busy`)
      // Without this, logs and serialized errors show the generic name "Error".
      this.name = "BusyError"
    }
  }
  /**
   * Sends the initialization prompt (with `${path}` replaced by the worktree path) to the
   * session using the given model, then marks the current project as initialized.
   *
   * @param input - Session, message id and provider/model to run the prompt with.
   */
  export async function initialize(input: {
    sessionID: string
    modelID: string
    providerID: string
    messageID: string
  }) {
    const { sessionID, messageID, providerID, modelID } = input
    const partID = Identifier.ascending("part")
    const text = PROMPT_INITIALIZE.replace("${path}", Instance.worktree)
    await Session.prompt({
      sessionID,
      messageID,
      model: { providerID, modelID },
      parts: [{ id: partID, type: "text", text }],
    })
    await Project.setInitialized(Instance.project.id)
  }
  1885. }