index.ts 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343
  1. import path from "path"
  2. import { Decimal } from "decimal.js"
  3. import { z, ZodSchema } from "zod"
  4. import {
  5. generateText,
  6. LoadAPIKeyError,
  7. streamText,
  8. tool,
  9. wrapLanguageModel,
  10. type Tool as AITool,
  11. type LanguageModelUsage,
  12. type ProviderMetadata,
  13. type ModelMessage,
  14. stepCountIs,
  15. type StreamTextResult,
  16. } from "ai"
  17. import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
  18. import PROMPT_PLAN from "../session/prompt/plan.txt"
  19. import { App } from "../app/app"
  20. import { Bus } from "../bus"
  21. import { Config } from "../config/config"
  22. import { Flag } from "../flag/flag"
  23. import { Identifier } from "../id/id"
  24. import { Installation } from "../installation"
  25. import { MCP } from "../mcp"
  26. import { Provider } from "../provider/provider"
  27. import { ProviderTransform } from "../provider/transform"
  28. import type { ModelsDev } from "../provider/models"
  29. import { Share } from "../share/share"
  30. import { Snapshot } from "../snapshot"
  31. import { Storage } from "../storage/storage"
  32. import { Log } from "../util/log"
  33. import { NamedError } from "../util/error"
  34. import { SystemPrompt } from "./system"
  35. import { FileTime } from "../file/time"
  36. import { MessageV2 } from "./message-v2"
  37. import { Mode } from "./mode"
  38. import { LSP } from "../lsp"
  39. import { ReadTool } from "../tool/read"
  40. import { mergeDeep, pipe, splitWhen } from "remeda"
  41. import { ToolRegistry } from "../tool/registry"
  42. export namespace Session {
  /**
   * Upper bound on output tokens requested per model response. It is also the
   * fallback used when a model reports no output limit (see `chat`).
   */
  const OUTPUT_TOKEN_MAX = 32000
  /** Logger scoped to the session service. */
  const log = Log.create({ service: "session" })
  /**
   * Persisted metadata for one chat session, stored under `session/info/<id>`.
   * It is registered in the OpenAPI document as `Session`.
   */
  export const Info = z
    .object({
      id: Identifier.schema("session"),
      /** Present on child sessions; `create(parentID)` fills it in. */
      parentID: Identifier.schema("session").optional(),
      /** Public share link, set by `share()` and cleared by `unshare()`. */
      share: z.object({ url: z.string() }).optional(),
      title: z.string(),
      /** `Installation.VERSION` of the build that created the session. */
      version: z.string(),
      /** Millisecond timestamps (`Date.now()`). */
      time: z.object({
        created: z.number(),
        updated: z.number(),
      }),
      /** Pending revert point; `chat()` applies it and then clears it. */
      revert: z
        .object({
          messageID: z.string(),
          partID: z.string().optional(),
          snapshot: z.string().optional(),
          diff: z.string().optional(),
        })
        .optional(),
    })
    .openapi({ ref: "Session" })
  export type Info = z.output<typeof Info>
  /**
   * Share credentials stored under `session/share/<id>`. Unlike `Info.share`,
   * they include the `secret` that `unshare()` hands to `Share.remove`.
   */
  export const ShareInfo = z.object({ secret: z.string(), url: z.string() }).openapi({ ref: "SessionShare" })
  export type ShareInfo = z.output<typeof ShareInfo>
  /** Bus events declared by the session module. */
  export const Event = {
    /** Published by `create()` and `update()` with the current session info. */
    Updated: Bus.event("session.updated", z.object({ info: Info })),
    /** Published by `remove()` for the session it was called on (not for its children). */
    Deleted: Bus.event("session.deleted", z.object({ info: Info })),
    /** Idle notification keyed by session ID (publisher not visible in this chunk). */
    Idle: Bus.event("session.idle", z.object({ sessionID: z.string() })),
    /** Error notification reusing the assistant message's error schema; the session ID is optional. */
    Error: Bus.event(
      "session.error",
      z.object({
        sessionID: z.string().optional(),
        error: MessageV2.Assistant.shape.error,
      }),
    ),
  }
  /**
   * Per-app in-memory session state:
   * - `sessions`: cache of session info used by `get()` and `update()`.
   * - `messages`: message cache evicted by `remove()`.
   * - `pending`: abort controllers keyed by session ID (see `abort()`).
   * - `queued`: chat requests waiting while a session is locked.
   *
   * The second callback (presumably App.state's dispose hook; confirm) aborts
   * every pending controller.
   */
  const state = App.state(
    "session",
    () => {
      type QueuedChat = {
        input: ChatInput
        message: MessageV2.User
        parts: MessageV2.Part[]
        /** Flipped to true once the request has been handed to a running chat. */
        processed: boolean
        /** Resolves the waiting caller of `chat()`. */
        callback: (input: { info: MessageV2.Assistant; parts: MessageV2.Part[] }) => void
      }
      return {
        sessions: new Map<string, Info>(),
        messages: new Map<string, MessageV2.Info[]>(),
        pending: new Map<string, AbortController>(),
        queued: new Map<string, QueuedChat[]>(),
      }
    },
    async (current) => {
      for (const controller of current.pending.values()) controller.abort()
    },
  )
  /**
   * Creates a session, caches it, persists it under `session/info/<id>` and
   * publishes `Event.Updated`.
   *
   * A top-level session is shared in the background in two cases: when the
   * `OPENCODE_AUTO_SHARE` flag is set, or when config `share` is `"auto"`.
   * Sharing failures are ignored, so they never fail creation.
   *
   * @param parentID - Optional parent session. Child sessions are never auto-shared.
   * @returns The newly created session info.
   */
  export async function create(parentID?: string) {
    const now = Date.now()
    const result: Info = {
      id: Identifier.descending("session"),
      version: Installation.VERSION,
      parentID,
      title: (parentID ? "Child session - " : "New Session - ") + new Date(now).toISOString(),
      time: {
        created: now,
        updated: now,
      },
    }
    log.info("created", result)
    state().sessions.set(result.id, result)
    await Storage.writeJSON("session/info/" + result.id, result)
    const cfg = await Config.get()
    if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.share === "auto")) {
      // share() already records `{ url }` on the session itself. Do not copy its
      // return value (the full ShareInfo, including `secret`) into the session:
      // that would persist and broadcast the secret.
      void share(result.id).catch(() => {
        // Silently ignore sharing errors during session creation
      })
    }
    Bus.publish(Event.Updated, {
      info: result,
    })
    return result
  }
  /**
   * Returns session info from the in-memory cache. On a cache miss it reads
   * `session/info/<id>` from storage and caches the result.
   */
  export async function get(id: string) {
    const cached = state().sessions.get(id)
    if (cached) return cached
    const loaded = await Storage.readJSON<Info>(`session/info/${id}`)
    state().sessions.set(id, loaded)
    return loaded
  }
  /** Reads the share credentials stored under `session/share/<id>`. */
  export async function getShare(id: string) {
    const key = `session/share/${id}`
    return Storage.readJSON<ShareInfo>(key)
  }
  /**
   * Shares a session through `Share.create`. It then:
   * 1. Records the public URL on the session.
   * 2. Stores the credentials under `session/share/<id>`.
   * 3. Syncs the session info, every message and every part to the share
   *    service, one at a time.
   *
   * @throws Error when sharing is disabled in config.
   * @returns The existing `{ url }` if the session is already shared, otherwise
   *   the newly created credentials.
   */
  export async function share(id: string) {
    const { share: shareSetting } = await Config.get()
    if (shareSetting === "disabled") throw new Error("Sharing is disabled in configuration")

    const session = await get(id)
    if (session.share) return session.share

    const created = await Share.create(id)
    await update(id, (draft) => {
      draft.share = { url: created.url }
    })
    await Storage.writeJSON<ShareInfo>(`session/share/${id}`, created)
    await Share.sync(`session/info/${id}`, session)
    for (const { info, parts } of await messages(id)) {
      await Share.sync(`session/message/${id}/${info.id}`, info)
      for (const part of parts) {
        await Share.sync(`session/part/${id}/${info.id}/${part.id}`, part)
      }
    }
    return created
  }
  /**
   * Stops sharing a session. It deletes the stored credentials, clears
   * `Info.share`, then asks the share service to remove the share using the
   * stored secret. Does nothing when no credentials are stored.
   */
  export async function unshare(id: string) {
    const credentials = await getShare(id)
    if (!credentials) return
    await Storage.remove(`session/share/${id}`)
    await update(id, (draft) => {
      draft.share = undefined
    })
    await Share.remove(id, credentials.secret)
  }
  /**
   * Mutates the session in place with `editor` and bumps `time.updated`. It then
   * caches and persists the session and publishes `Event.Updated`.
   *
   * @returns The updated session, or `undefined` when `get` yields nothing.
   */
  export async function update(id: string, editor: (session: Info) => void) {
    const cache = state().sessions
    const target = await get(id)
    if (!target) return
    editor(target)
    target.time.updated = Date.now()
    cache.set(id, target)
    await Storage.writeJSON(`session/info/${id}`, target)
    Bus.publish(Event.Updated, { info: target })
    return target
  }
  /**
   * Loads every stored message of a session together with its parts. Reads
   * happen sequentially, and the result is sorted by message ID (string order).
   */
  export async function messages(sessionID: string) {
    const keys = await Storage.list(`session/message/${sessionID}`)
    const loaded: { info: MessageV2.Info; parts: MessageV2.Part[] }[] = []
    for (const key of keys) {
      const info = await Storage.readJSON<MessageV2.Info>(key)
      const parts = await getParts(sessionID, info.id)
      loaded.push({ info, parts })
    }
    return loaded.sort((a, b) => (a.info.id > b.info.id ? 1 : -1))
  }
  /** Reads a single message and then its parts. */
  export async function getMessage(sessionID: string, messageID: string) {
    const info = await Storage.readJSON<MessageV2.Info>(`session/message/${sessionID}/${messageID}`)
    const parts = await getParts(sessionID, messageID)
    return { info, parts }
  }
  /** Reads every stored part of a message, sorted by part ID (string order). */
  export async function getParts(sessionID: string, messageID: string) {
    const parts: MessageV2.Part[] = []
    for (const key of await Storage.list(`session/part/${sessionID}/${messageID}`)) {
      parts.push(await Storage.readJSON<MessageV2.Part>(key))
    }
    parts.sort((a, b) => (a.id > b.id ? 1 : -1))
    return parts
  }
  /** Lazily yields the info of every stored session, one `get()` per entry. */
  export async function* list() {
    const entries = await Storage.list("session/info")
    for (const entry of entries) {
      yield get(path.basename(entry, ".json"))
    }
  }
  /** Returns every stored session whose `parentID` equals `parentID`. */
  export async function children(parentID: string) {
    const found: Session.Info[] = []
    for (const entry of await Storage.list("session/info")) {
      const candidate = await get(path.basename(entry, ".json"))
      if (candidate.parentID === parentID) found.push(candidate)
    }
    return found
  }
  /**
   * Aborts and forgets the pending controller registered for a session.
   *
   * @returns `true` if a controller was found and aborted, otherwise `false`.
   */
  export function abort(sessionID: string) {
    const { pending } = state()
    const controller = pending.get(sessionID)
    if (!controller) return false
    log.info("aborting", { sessionID })
    controller.abort()
    pending.delete(sessionID)
    return true
  }
  /**
   * Deletes a session and, recursively, all of its children. It aborts any
   * pending chat, unshares the session, removes its info, messages and parts
   * from storage, and evicts it from the in-memory caches.
   *
   * Failures are logged, not thrown.
   *
   * @param emitEvent - Whether to publish `Event.Deleted` for this session.
   *   Recursive calls for children pass `false`.
   */
  export async function remove(sessionID: string, emitEvent = true) {
    try {
      abort(sessionID)
      const session = await get(sessionID)
      for (const child of await children(sessionID)) {
        await remove(child.id, false)
      }
      // Best-effort cleanup: a missing file or directory must not stop the removal.
      await unshare(sessionID).catch(() => {})
      await Storage.remove(`session/info/${sessionID}`).catch(() => {})
      await Storage.removeDir(`session/message/${sessionID}/`).catch(() => {})
      // Parts are stored in their own tree (session/part/<sessionID>/<messageID>/<partID>),
      // so remove it too instead of leaving orphaned files behind.
      await Storage.removeDir(`session/part/${sessionID}/`).catch(() => {})
      state().sessions.delete(sessionID)
      state().messages.delete(sessionID)
      if (emitEvent) {
        Bus.publish(Event.Deleted, {
          info: session,
        })
      }
    } catch (e) {
      log.error("failed to remove session", { sessionID, error: e })
    }
  }
  /** Persists a message under `session/message/<sessionID>/<id>` and publishes `MessageV2.Event.Updated`. */
  async function updateMessage(msg: MessageV2.Info) {
    const key = ["session", "message", msg.sessionID, msg.id].join("/")
    await Storage.writeJSON(key, msg)
    Bus.publish(MessageV2.Event.Updated, { info: msg })
  }
  /**
   * Persists a part under `session/part/<sessionID>/<messageID>/<id>` and
   * publishes `MessageV2.Event.PartUpdated`.
   *
   * @returns The same part object that was passed in.
   */
  async function updatePart(part: MessageV2.Part) {
    const key = `session/part/${part.sessionID}/${part.messageID}/${part.id}`
    await Storage.writeJSON(key, part)
    Bus.publish(MessageV2.Event.PartUpdated, { part })
    return part
  }
  /**
   * Request body accepted by `chat()`.
   *
   * Text and file parts omit `messageID`/`sessionID`, which `chat()` fills in,
   * and may omit `id`. They are registered in OpenAPI as `TextPartInput` and
   * `FilePartInput`.
   */
  export const ChatInput = z.object({
    sessionID: Identifier.schema("session"),
    /** Optional caller-chosen ID for the user message; generated when omitted. */
    messageID: Identifier.schema("message").optional(),
    providerID: z.string(),
    modelID: z.string(),
    /** Mode name; `chat()` falls back to "build". */
    mode: z.string().optional(),
    /** When set, used instead of the mode/provider-specific system prompt section. */
    system: z.string().optional(),
    /** Per-request tool enable/disable overrides, keyed by tool ID. */
    tools: z.record(z.boolean()).optional(),
    parts: z.array(
      z.discriminatedUnion("type", [
        MessageV2.TextPart.omit({ messageID: true, sessionID: true })
          .partial({ id: true })
          .openapi({ ref: "TextPartInput" }),
        MessageV2.FilePart.omit({ messageID: true, sessionID: true })
          .partial({ id: true })
          .openapi({ ref: "FilePartInput" }),
      ]),
    ),
  })
  export type ChatInput = z.infer<typeof ChatInput>
  /**
   * Sends a user message to a session and streams the assistant's reply.
   *
   * In order:
   * 1. Applies a pending `session.revert`. Messages from the revert point on are
   *    deleted, along with their parts and any parts after `revert.partID`.
   *    The revert is then cleared.
   * 2. Builds and persists the user message. File parts are inlined:
   *    - `text/plain` `data:` URLs are decoded to text.
   *    - `text/plain` `file:` URLs are read through `ReadTool`, optionally limited
   *      to the `start`/`end` query range.
   *    - Other `file:` URLs are embedded as base64 `data:` URLs.
   *    - In "plan" mode, the plan prompt is appended as a synthetic part.
   * 3. If the session is locked, queues the request. The returned promise
   *    resolves with the result of the chat that drains the queue; queued
   *    messages are injected into a running stream by `prepareStep`.
   * 4. If the previous turn's token usage approaches the context window, it
   *    summarizes the session and retries.
   * 5. Generates a title in the background for the first message of a
   *    top-level session.
   * 6. Streams the model response with the enabled built-in and MCP tools and
   *    records it through the processor.
   *
   * @returns The final assistant message info and its parts.
   */
  export async function chat(
    input: z.infer<typeof ChatInput>,
  ): Promise<{ info: MessageV2.Assistant; parts: MessageV2.Part[] }> {
    /**
     * Decodes the payload of a `data:` URL as UTF-8 text. Only the part after
     * the first `,` is decoded, so the `data:<mime>;base64,` header never ends
     * up in the text.
     * - `;base64` payloads are base64-decoded (Node's "base64" decoder also
     *   accepts the URL-safe alphabet).
     * - Other payloads are percent-decoded.
     */
    function decodeDataUrlText(dataUrl: string): string {
      const comma = dataUrl.indexOf(",")
      if (comma === -1) return ""
      const meta = dataUrl.slice(0, comma)
      const payload = dataUrl.slice(comma + 1)
      if (meta.toLowerCase().endsWith(";base64")) return Buffer.from(payload, "base64").toString("utf8")
      try {
        return decodeURIComponent(payload)
      } catch {
        // malformed percent-escapes: fall back to the raw payload
        return payload
      }
    }

    const l = log.clone().tag("session", input.sessionID)
    l.info("chatting")
    const inputMode = input.mode ?? "build"

    // Process revert cleanup first, before creating new messages
    const session = await get(input.sessionID)
    if (session.revert) {
      const messageID = session.revert.messageID
      const [preserve, removed] = splitWhen(await messages(input.sessionID), (x) => x.info.id === messageID)
      for (const msg of removed) {
        await Storage.remove(`session/message/${input.sessionID}/${msg.info.id}`)
        // Parts live in a separate tree; drop them so they are not orphaned.
        await Storage.removeDir(`session/part/${input.sessionID}/${msg.info.id}/`).catch(() => {})
        await Bus.publish(MessageV2.Event.Removed, { sessionID: input.sessionID, messageID: msg.info.id })
      }
      const last = preserve.at(-1)
      if (session.revert.partID && last) {
        const partID = session.revert.partID
        const [, removeParts] = splitWhen(last.parts, (x) => x.id === partID)
        for (const part of removeParts) {
          await Storage.remove(`session/part/${input.sessionID}/${last.info.id}/${part.id}`)
          await Bus.publish(MessageV2.Event.PartRemoved, {
            sessionID: input.sessionID,
            messageID: last.info.id,
            partID: part.id,
          })
        }
      }
      await update(input.sessionID, (draft) => {
        draft.revert = undefined
      })
    }

    const userMsg: MessageV2.Info = {
      id: input.messageID ?? Identifier.ascending("message"),
      role: "user",
      sessionID: input.sessionID,
      time: {
        created: Date.now(),
      },
    }
    const app = App.info()
    const userParts = await Promise.all(
      input.parts.map(async (part): Promise<MessageV2.Part[]> => {
        if (part.type === "file") {
          const url = new URL(part.url)
          switch (url.protocol) {
            case "data:":
              if (part.mime === "text/plain") {
                return [
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: `Called the Read tool with the following input: ${JSON.stringify({ filePath: part.filename })}`,
                  },
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: decodeDataUrlText(part.url),
                  },
                  {
                    ...part,
                    id: part.id ?? Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                  },
                ]
              }
              break
            case "file:": {
              // have to normalize, symbol search returns absolute paths
              // Decode the pathname since URL constructor doesn't automatically decode it
              const filePath = decodeURIComponent(url.pathname)
              if (part.mime === "text/plain") {
                let offset: number | undefined = undefined
                let limit: number | undefined = undefined
                const startParam = url.searchParams.get("start")
                const endParam = url.searchParams.get("end")
                if (startParam != null) {
                  // LSP.documentSymbol is given the file URL with its query stripped,
                  // not the decoded filesystem path.
                  const fileUri = part.url.split("?")[0]
                  const start = parseInt(startParam, 10)
                  let end = endParam ? parseInt(endParam, 10) : undefined
                  // some LSP servers (eg, gopls) don't give full range in
                  // workspace/symbol searches, so we'll try to find the
                  // symbol in the document to get the full range
                  if (start === end) {
                    const symbols = await LSP.documentSymbol(fileUri)
                    for (const symbol of symbols) {
                      let symbolRange: LSP.Range | undefined
                      if ("range" in symbol) {
                        symbolRange = symbol.range
                      } else if ("location" in symbol) {
                        symbolRange = symbol.location.range
                      }
                      // Compare directly: a truthiness guard would skip symbols on line 0.
                      if (symbolRange?.start?.line === start) {
                        end = symbolRange?.end?.line ?? start
                        break
                      }
                    }
                  }
                  // Apply the range whether it came from the URL or from the symbol
                  // lookup (an explicit start/end range used to be ignored).
                  // Non-numeric input leaves the whole file.
                  if (Number.isFinite(start)) {
                    offset = Math.max(start - 2, 0)
                    if (end) {
                      limit = end - offset + 2
                    }
                  }
                }
                const args = { filePath, offset, limit }
                const result = await ReadTool.init().then((t) =>
                  t.execute(args, {
                    sessionID: input.sessionID,
                    abort: new AbortController().signal,
                    messageID: userMsg.id,
                    metadata: async () => {},
                  }),
                )
                return [
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
                  },
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: result.output,
                  },
                  {
                    ...part,
                    id: part.id ?? Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                  },
                ]
              }
              const file = Bun.file(filePath)
              FileTime.read(input.sessionID, filePath)
              return [
                {
                  id: Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "text",
                  // JSON.stringify escapes quotes/backslashes in the path.
                  text: `Called the Read tool with the following input: ${JSON.stringify({ filePath })}`,
                  synthetic: true,
                },
                {
                  id: part.id ?? Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "file",
                  url: `data:${part.mime};base64,` + Buffer.from(await file.bytes()).toString("base64"),
                  mime: part.mime,
                  filename: part.filename!,
                  source: part.source,
                },
              ]
            }
          }
        }
        return [
          {
            ...part,
            id: part.id ?? Identifier.ascending("part"),
            messageID: userMsg.id,
            sessionID: input.sessionID,
          },
        ]
      }),
    ).then((x) => x.flat())
    if (inputMode === "plan")
      userParts.push({
        id: Identifier.ascending("part"),
        messageID: userMsg.id,
        sessionID: input.sessionID,
        type: "text",
        text: PROMPT_PLAN,
        synthetic: true,
      })
    await updateMessage(userMsg)
    for (const part of userParts) {
      await updatePart(part)
    }
    // mark session as updated
    // used for session list sorting (indicates when session was most recently interacted with)
    await update(input.sessionID, (_draft) => {})
    if (isLocked(input.sessionID)) {
      return new Promise((resolve) => {
        const queue = state().queued.get(input.sessionID) ?? []
        queue.push({
          input: input,
          message: userMsg,
          parts: userParts,
          processed: false,
          callback: resolve,
        })
        state().queued.set(input.sessionID, queue)
      })
    }
    const model = await Provider.getModel(input.providerID, input.modelID)
    let msgs = await messages(input.sessionID)
    const previous = msgs.findLast((x) => x.info.role === "assistant")?.info as MessageV2.Assistant | undefined
    const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
    // auto summarize if too long
    if (previous?.tokens) {
      const tokens =
        previous.tokens.input + previous.tokens.cache.read + previous.tokens.cache.write + previous.tokens.output
      if (model.info.limit.context && tokens > Math.max((model.info.limit.context - outputLimit) * 0.9, 0)) {
        await summarize({
          sessionID: input.sessionID,
          providerID: input.providerID,
          modelID: input.modelID,
        })
        // NOTE(review): the retry re-runs the whole setup above, so the user
        // message and its parts are persisted again. Confirm duplicates are acceptable.
        return chat(input)
      }
    }
    using abort = lock(input.sessionID)
    const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
    if (lastSummary) msgs = msgs.filter((msg) => msg.info.id >= lastSummary.info.id)
    if (msgs.length === 1 && !session.parentID) {
      const small = (await Provider.getSmallModel(input.providerID)) ?? model
      generateText({
        maxOutputTokens: small.info.reasoning ? 1024 : 20,
        providerOptions: {
          [input.providerID]: small.info.options,
        },
        messages: [
          ...SystemPrompt.title(input.providerID).map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          ...MessageV2.toModelMessage([
            {
              info: {
                id: Identifier.ascending("message"),
                role: "user",
                sessionID: input.sessionID,
                time: {
                  created: Date.now(),
                },
              },
              parts: userParts,
            },
          ]),
        ],
        model: small.language,
      })
        .then((result) => {
          if (result.text)
            return Session.update(input.sessionID, (draft) => {
              const cleaned = result.text.replace(/<think>[\s\S]*?<\/think>\s*/g, "")
              const title = cleaned.length > 100 ? cleaned.substring(0, 97) + "..." : cleaned
              draft.title = title.trim()
            })
        })
        .catch(() => {})
    }
    const mode = await Mode.get(inputMode)
    let system = SystemPrompt.header(input.providerID)
    system.push(
      ...(() => {
        if (input.system) return [input.system]
        if (mode.prompt) return [mode.prompt]
        return SystemPrompt.provider(input.modelID)
      })(),
    )
    system.push(...(await SystemPrompt.environment()))
    system.push(...(await SystemPrompt.custom()))
    // max 2 system prompt messages for caching purposes
    const [first, ...rest] = system
    system = [first, rest.join("\n")]
    const assistantMsg: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      system,
      mode: inputMode,
      path: {
        cwd: app.path.cwd,
        root: app.path.root,
      },
      cost: 0,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.modelID,
      providerID: input.providerID,
      time: {
        created: Date.now(),
      },
      sessionID: input.sessionID,
    }
    await updateMessage(assistantMsg)
    const tools: Record<string, AITool> = {}
    const processor = createProcessor(assistantMsg, model.info)
    // Precedence (later wins): mode defaults, then registry per-model settings, then per-request overrides.
    const enabledTools = pipe(
      mode.tools,
      mergeDeep(ToolRegistry.enabled(input.providerID, input.modelID)),
      mergeDeep(input.tools ?? {}),
    )
    for (const item of await ToolRegistry.tools(input.providerID, input.modelID)) {
      if (enabledTools[item.id] === false) continue
      tools[item.id] = tool({
        id: item.id as any,
        description: item.description,
        inputSchema: item.parameters as ZodSchema,
        async execute(args, options) {
          await processor.track(options.toolCallId)
          const result = await item.execute(args, {
            sessionID: input.sessionID,
            abort: abort.signal,
            messageID: assistantMsg.id,
            toolCallID: options.toolCallId,
            metadata: async (val) => {
              const match = processor.partFromToolCall(options.toolCallId)
              if (match && match.state.status === "running") {
                await updatePart({
                  ...match,
                  state: {
                    title: val.title,
                    metadata: val.metadata,
                    status: "running",
                    input: args,
                    time: {
                      start: Date.now(),
                    },
                  },
                })
              }
            },
          })
          return result
        },
        toModelOutput(result) {
          return {
            type: "text",
            value: result.output,
          }
        },
      })
    }
    for (const [key, item] of Object.entries(await MCP.tools())) {
      // Same enable/disable rules as built-in tools, including per-request overrides.
      if (enabledTools[key] === false) continue
      const execute = item.execute
      if (!execute) continue
      item.execute = async (args, opts) => {
        await processor.track(opts.toolCallId)
        const result = await execute(args, opts)
        const output = result.content
          .filter((x: any) => x.type === "text")
          .map((x: any) => x.text)
          .join("\n\n")
        return {
          output,
        }
      }
      item.toModelOutput = (result) => {
        return {
          type: "text",
          value: result.output,
        }
      }
      tools[key] = item
    }
    const stream = streamText({
      onError(e) {
        log.error("streamText error", {
          error: e,
        })
      },
      async prepareStep({ messages }) {
        // Fold chat requests queued while this session was locked into the running stream.
        const queue = (state().queued.get(input.sessionID) ?? []).filter((x) => !x.processed)
        if (queue.length) {
          for (const item of queue) {
            messages.push(
              ...MessageV2.toModelMessage([
                {
                  info: item.message,
                  parts: item.parts,
                },
              ]),
            )
            item.processed = true
          }
          // Close the current assistant message and continue in a fresh one after the injected user messages.
          assistantMsg.time.completed = Date.now()
          await updateMessage(assistantMsg)
          Object.assign(assistantMsg, {
            id: Identifier.ascending("message"),
            role: "assistant",
            system,
            path: {
              cwd: app.path.cwd,
              root: app.path.root,
            },
            cost: 0,
            tokens: {
              input: 0,
              output: 0,
              reasoning: 0,
              cache: { read: 0, write: 0 },
            },
            modelID: input.modelID,
            providerID: input.providerID,
            mode: inputMode,
            time: {
              created: Date.now(),
            },
            sessionID: input.sessionID,
          })
          await updateMessage(assistantMsg)
        }
        return {
          messages,
        }
      },
      maxRetries: 3,
      maxOutputTokens: outputLimit,
      abortSignal: abort.signal,
      stopWhen: stepCountIs(1000),
      providerOptions: {
        [input.providerID]: model.info.options,
      },
      messages: [
        ...system.map(
          (x): ModelMessage => ({
            role: "system",
            content: x,
          }),
        ),
        ...MessageV2.toModelMessage(msgs),
      ],
      temperature: model.info.temperature
        ? (mode.temperature ?? ProviderTransform.temperature(input.providerID, input.modelID))
        : undefined,
      tools: model.info.tool_call === false ? undefined : tools,
      model: wrapLanguageModel({
        model: model.language,
        middleware: [
          {
            async transformParams(args) {
              if (args.type === "stream") {
                // @ts-expect-error
                args.params.prompt = ProviderTransform.message(args.params.prompt, input.providerID, input.modelID)
              }
              return args.params
            },
          },
        ],
      }),
    })
    const result = await processor.process(stream)
    const queued = state().queued.get(input.sessionID) ?? []
    const unprocessed = queued.find((x) => !x.processed)
    if (unprocessed) {
      unprocessed.processed = true
      return chat(unprocessed.input)
    }
    for (const item of queued) {
      item.callback(result)
    }
    state().queued.delete(input.sessionID)
    return result
  }
  /**
   * Builds a stream processor bound to a single assistant message.
   *
   * The processor consumes an AI SDK `fullStream`, persisting text, tool, step and patch parts
   * as they arrive, accumulating cost/token usage onto `assistantMsg`, and converting any thrown
   * error into a serialized error stored on the message (also published as `Event.Error`).
   *
   * @param assistantMsg - assistant message that receives parts, usage and completion time
   * @param model - model metadata used to price token usage via `getUsage`
   */
  function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
    // In-flight tool parts keyed by tool call id; removed once the call completes or errors.
    const toolCalls: Record<string, MessageV2.ToolPart> = {}
    // Snapshot hashes captured by `track`, keyed by tool call id.
    const snapshots: Record<string, string> = {}

    /**
     * If a snapshot was tracked for this tool call, records a "patch" part listing the files
     * changed since that snapshot. Patches with no changed files are not recorded.
     */
    async function recordPatch(toolCallID: string) {
      const snapshot = snapshots[toolCallID]
      if (!snapshot) return
      delete snapshots[toolCallID]
      const patch = await Snapshot.patch(snapshot)
      if (!patch.files.length) return
      await updatePart({
        id: Identifier.ascending("part"),
        messageID: assistantMsg.id,
        sessionID: assistantMsg.sessionID,
        type: "patch",
        hash: patch.hash,
        files: patch.files,
      })
    }

    return {
      /** Captures a snapshot and associates it with `toolCallID` so a patch can be recorded when the call ends. */
      async track(toolCallID: string) {
        const hash = await Snapshot.track()
        if (hash) snapshots[toolCallID] = hash
      },

      /** Returns the in-flight tool part for `toolCallID`, or undefined once it has finished. */
      partFromToolCall(toolCallID: string) {
        return toolCalls[toolCallID]
      },

      /**
       * Drains `stream.fullStream`, persisting parts as they arrive.
       *
       * Errors are not rethrown: they are stored on `assistantMsg.error` and published. Any
       * tool part left pending/running when the stream ends is marked as an aborted error.
       *
       * @returns the finalized assistant message and its parts as stored after processing
       */
      async process(stream: StreamTextResult<Record<string, AITool>, never>) {
        try {
          let currentText: MessageV2.TextPart | undefined
          for await (const value of stream.fullStream) {
            log.info("part", {
              type: value.type,
            })
            switch (value.type) {
              case "start":
                break

              case "tool-input-start": {
                const part = await updatePart({
                  id: toolCalls[value.id]?.id ?? Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "tool",
                  tool: value.toolName,
                  callID: value.id,
                  state: {
                    status: "pending",
                  },
                })
                toolCalls[value.id] = part as MessageV2.ToolPart
                break
              }

              case "tool-input-delta":
              case "tool-input-end":
                break

              case "tool-call": {
                const match = toolCalls[value.toolCallId]
                if (match) {
                  const part = await updatePart({
                    ...match,
                    state: {
                      status: "running",
                      input: value.input,
                      time: {
                        start: Date.now(),
                      },
                    },
                  })
                  toolCalls[value.toolCallId] = part as MessageV2.ToolPart
                }
                break
              }

              case "tool-result": {
                const match = toolCalls[value.toolCallId]
                if (match && match.state.status === "running") {
                  await updatePart({
                    ...match,
                    state: {
                      status: "completed",
                      input: value.input,
                      output: value.output.output,
                      metadata: value.output.metadata,
                      title: value.output.title,
                      time: {
                        start: match.state.time.start,
                        end: Date.now(),
                      },
                    },
                  })
                  delete toolCalls[value.toolCallId]
                  await recordPatch(value.toolCallId)
                }
                break
              }

              case "tool-error": {
                const match = toolCalls[value.toolCallId]
                if (match && match.state.status === "running") {
                  await updatePart({
                    ...match,
                    state: {
                      status: "error",
                      input: value.input,
                      // String() also copes with null/undefined errors, unlike .toString().
                      error: String(value.error),
                      time: {
                        start: match.state.time.start,
                        end: Date.now(),
                      },
                    },
                  })
                  delete toolCalls[value.toolCallId]
                  await recordPatch(value.toolCallId)
                }
                break
              }

              case "error":
                throw value.error

              case "start-step":
                await updatePart({
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "step-start",
                })
                break

              case "finish-step": {
                const usage = getUsage(model, value.usage, value.providerMetadata)
                // Cost accumulates across steps; tokens reflect the latest step.
                assistantMsg.cost += usage.cost
                assistantMsg.tokens = usage.tokens
                await updatePart({
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "step-finish",
                  tokens: usage.tokens,
                  cost: usage.cost,
                })
                await updateMessage(assistantMsg)
                break
              }

              case "text-start":
                currentText = {
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "text",
                  text: "",
                  time: {
                    start: Date.now(),
                  },
                }
                break

              case "text-delta":
                if (currentText) {
                  currentText.text += value.text
                  if (currentText.text) await updatePart(currentText)
                }
                break

              case "text-end":
                if (currentText) {
                  currentText.text = currentText.text.trimEnd()
                  // Keep the start time recorded at "text-start"; only stamp the end.
                  currentText.time = {
                    start: currentText.time?.start ?? Date.now(),
                    end: Date.now(),
                  }
                  await updatePart(currentText)
                }
                currentText = undefined
                break

              case "finish":
                assistantMsg.time.completed = Date.now()
                await updateMessage(assistantMsg)
                break

              default:
                log.info("unhandled", {
                  ...value,
                })
                continue
            }
          }
        } catch (e) {
          log.error("", {
            error: e,
          })
          // Every branch stores a plain serialized error object on the message.
          if (e instanceof DOMException && e.name === "AbortError") {
            assistantMsg.error = new MessageV2.AbortedError({ message: e.message }, { cause: e }).toObject()
          } else if (MessageV2.OutputLengthError.isInstance(e)) {
            assistantMsg.error = e
          } else if (LoadAPIKeyError.isInstance(e)) {
            assistantMsg.error = new MessageV2.AuthError(
              {
                // The provider that failed to load its key, not the model id.
                providerID: assistantMsg.providerID,
                message: e.message,
              },
              { cause: e },
            ).toObject()
          } else if (e instanceof Error) {
            assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
          } else {
            // JSON.stringify returns undefined for undefined/functions; fall back to String().
            assistantMsg.error = new NamedError.Unknown(
              { message: JSON.stringify(e) ?? String(e) },
              { cause: e },
            ).toObject()
          }
          Bus.publish(Event.Error, {
            sessionID: assistantMsg.sessionID,
            error: assistantMsg.error,
          })
        }

        const parts = await getParts(assistantMsg.sessionID, assistantMsg.id)
        const now = Date.now()
        for (const [index, part] of parts.entries()) {
          if (part.type !== "tool") continue
          const state = part.state
          // Finished tools (completed or errored) keep their real outcome.
          if (state.status === "completed" || state.status === "error") continue
          // The tool never finished: mark it aborted, keeping input/start time if it was running.
          const aborted: MessageV2.ToolPart = {
            ...part,
            state: {
              status: "error",
              error: "Tool execution aborted",
              time: {
                start: state.status === "running" ? state.time.start : now,
                end: now,
              },
              input: state.status === "running" ? state.input : {},
            },
          }
          await updatePart(aborted)
          parts[index] = aborted
        }
        assistantMsg.time.completed = Date.now()
        await updateMessage(assistantMsg)
        return { info: assistantMsg, parts }
      },
    }
  }
  /**
   * Input accepted by `revert`: the session to revert, the message to revert to, and
   * optionally a specific part within that message to revert from.
   */
  export const RevertInput = z.object({
    sessionID: Identifier.schema("session"),
    messageID: Identifier.schema("message"),
    partID: Identifier.schema("part").optional(),
  })
  /** Parsed shape of `RevertInput`. */
  export type RevertInput = z.infer<typeof RevertInput>
  /**
   * Reverts a session to a message, or to a specific part within a message.
   *
   * Walks the session's messages in order and locates the revert point. It collects every
   * "patch" part recorded after that point and undoes those file changes with
   * `Snapshot.revert`. It then stores the revert marker (plus a snapshot and diff) on the
   * session so the change can be undone by `unrevert`.
   *
   * If the targeted part has no text/tool part before it in its message, the whole message is
   * reverted instead. Without an effective partID, the revert point becomes the most recent
   * user message at or before the target, when there is one.
   *
   * @returns the updated session, or the unchanged session if no revert point was found
   */
  export async function revert(input: RevertInput) {
    const all = await messages(input.sessionID)
    const session = await get(input.sessionID)
    let lastUser: MessageV2.User | undefined
    let revert: Info["revert"]
    const patches: Snapshot.Patch[] = []
    for (const msg of all) {
      if (msg.info.role === "user") lastUser = msg.info
      // Parts of this message seen before the revert point.
      const remaining: typeof msg.parts = []
      for (const part of msg.parts) {
        if (revert) {
          // Past the revert point: collect file patches so they can be undone.
          if (part.type === "patch") patches.push(part)
          continue
        }
        if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
          // if no useful parts left in message, same as reverting whole message
          const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
          revert = {
            messageID: !partID && lastUser ? lastUser.id : msg.info.id,
            partID,
          }
        }
        remaining.push(part)
      }
    }
    if (!revert) return session
    // If the session is already reverted, keep its original snapshot so unrevert restores the pre-revert state.
    revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
    await Snapshot.revert(patches)
    if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot)
    return update(input.sessionID, (draft) => {
      draft.revert = revert
    })
  }
  /**
   * Undoes a previous `revert`.
   *
   * Restores files from the snapshot saved at revert time, when one exists, and clears the
   * session's revert marker.
   *
   * @returns the updated session, or the session unchanged if it was not reverted
   */
  export async function unrevert(input: { sessionID: string }) {
    log.info("unreverting", input)
    const current = await get(input.sessionID)
    const reverted = current.revert
    if (!reverted) return current
    if (reverted.snapshot) await Snapshot.restore(reverted.snapshot)
    return await update(input.sessionID, (draft) => {
      draft.revert = undefined
    })
  }
  /**
   * Asks the model to summarize the conversation and stores the answer as a new assistant
   * message flagged `summary: true`.
   *
   * Only messages from the most recent previous summary onward are sent; that summary stands
   * in for the older history. The session lock is held for the whole call.
   *
   * @returns the processed summary message and its parts
   * @throws BusyError if the session is already locked
   */
  export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {
    using sessionLock = lock(input.sessionID)
    const history = await messages(input.sessionID)
    const previousSummary = history.findLast(
      (msg) => msg.info.role === "assistant" && msg.info.summary === true,
    )
    const relevant = history.filter((msg) => !previousSummary || msg.info.id >= previousSummary.info.id)

    const model = await Provider.getModel(input.providerID, input.modelID)
    const { path } = App.info()
    const system = [
      ...SystemPrompt.summarize(input.providerID),
      ...(await SystemPrompt.environment()),
      ...(await SystemPrompt.custom()),
    ]

    const next: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      sessionID: input.sessionID,
      system,
      mode: "build",
      path: {
        cwd: path.cwd,
        root: path.root,
      },
      summary: true,
      cost: 0,
      modelID: input.modelID,
      providerID: input.providerID,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      time: {
        created: Date.now(),
      },
    }
    await updateMessage(next)

    const processor = createProcessor(next, model.info)
    const systemMessages = system.map((content): ModelMessage => ({ role: "system", content }))
    const summaryRequest: ModelMessage = {
      role: "user",
      content: [
        {
          type: "text",
          text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
        },
      ],
    }
    const stream = streamText({
      maxRetries: 10,
      abortSignal: sessionLock.signal,
      model: model.language,
      messages: [...systemMessages, ...MessageV2.toModelMessage(relevant), summaryRequest],
    })
    // Must await here: returning the bare promise would dispose `sessionLock` (release the lock) too early.
    return await processor.process(stream)
  }
  /** Reports whether `lock` is currently held for `sessionID` (an entry exists in the pending map). */
  function isLocked(sessionID: string) {
    const { pending } = state()
    return pending.has(sessionID)
  }
  /**
   * Acquires the per-session processing lock.
   *
   * Registers a fresh AbortController for the session. It returns the controller's signal
   * together with a `Symbol.dispose` hook, so callers can write `using x = lock(id)`.
   * Disposing removes the pending entry and publishes `Event.Idle`.
   *
   * @throws BusyError if the session already holds the lock
   */
  function lock(sessionID: string) {
    log.info("locking", { sessionID })
    const { pending } = state()
    if (pending.has(sessionID)) throw new BusyError(sessionID)
    const controller = new AbortController()
    pending.set(sessionID, controller)

    const release = () => {
      log.info("unlocking", { sessionID })
      state().pending.delete(sessionID)
      Bus.publish(Event.Idle, { sessionID })
    }

    return {
      signal: controller.signal,
      [Symbol.dispose]: release,
    }
  }
  /**
   * Normalizes AI SDK token usage for one step and prices it with the model's rates.
   *
   * Each token count is multiplied by its `model.cost` rate and divided by 1,000,000, so the
   * rates are per million tokens. The sum uses Decimal to avoid floating-point drift.
   * Cache-write counts are not in the SDK usage object and are read from provider metadata:
   * Anthropic `cacheCreationInputTokens` or Bedrock `usage.cacheWriteInputTokens`.
   *
   * @returns `cost` for this step and the `tokens` breakdown
   */
  function getUsage(model: ModelsDev.Model, usage: LanguageModelUsage, metadata?: ProviderMetadata) {
    const tokens = {
      input: usage.inputTokens ?? 0,
      output: usage.outputTokens ?? 0,
      // Previously hard-coded to 0. Recorded for reporting only; pricing below does not use it.
      // NOTE(review): some providers also count these inside outputTokens — confirm before summing.
      reasoning: usage.reasoningTokens ?? 0,
      cache: {
        write: (metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
          // @ts-expect-error
          metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
          0) as number,
        read: usage.cachedInputTokens ?? 0,
      },
    }
    return {
      cost: new Decimal(0)
        .add(new Decimal(tokens.input).mul(model.cost.input).div(1_000_000))
        .add(new Decimal(tokens.output).mul(model.cost.output).div(1_000_000))
        .add(new Decimal(tokens.cache.read).mul(model.cost.cache_read ?? 0).div(1_000_000))
        .add(new Decimal(tokens.cache.write).mul(model.cost.cache_write ?? 0).div(1_000_000))
        .toNumber(),
      tokens,
    }
  }
  /**
   * Thrown by `lock` when a session is already being processed.
   *
   * Carries the offending `sessionID`, and sets `name` so logs and serialized errors identify
   * it as a BusyError rather than a generic Error.
   */
  export class BusyError extends Error {
    constructor(public readonly sessionID: string) {
      super(`Session ${sessionID} is busy`)
      this.name = "BusyError"
    }
  }
  /**
   * Runs the project-initialization prompt in a session, then calls `App.initialize()`.
   *
   * Sends `PROMPT_INITIALIZE` as a single user text part through `Session.chat`, with its
   * `${path}` placeholder filled with the project root. `App.initialize()` runs only after the
   * chat call resolves.
   */
  export async function initialize(input: {
    sessionID: string
    modelID: string
    providerID: string
    messageID: string
  }) {
    const app = App.info()
    // A replacer function inserts the root verbatim. A plain string replacement would expand
    // `$&`, `$'`, `` $` `` and `$n` sequences that can legitimately appear in a path.
    const prompt = PROMPT_INITIALIZE.replace("${path}", () => app.path.root)
    await Session.chat({
      sessionID: input.sessionID,
      messageID: input.messageID,
      providerID: input.providerID,
      modelID: input.modelID,
      parts: [
        {
          id: Identifier.ascending("part"),
          type: "text",
          text: prompt,
        },
      ],
    })
    await App.initialize()
  }
  1265. }