index.ts 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322
  1. import path from "path"
  2. import { Decimal } from "decimal.js"
  3. import { z, ZodSchema } from "zod"
  4. import {
  5. generateText,
  6. LoadAPIKeyError,
  7. streamText,
  8. tool,
  9. wrapLanguageModel,
  10. type Tool as AITool,
  11. type LanguageModelUsage,
  12. type ProviderMetadata,
  13. type ModelMessage,
  14. stepCountIs,
  15. type StreamTextResult,
  16. } from "ai"
  17. import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
  18. import PROMPT_PLAN from "../session/prompt/plan.txt"
  19. import { App } from "../app/app"
  20. import { Bus } from "../bus"
  21. import { Config } from "../config/config"
  22. import { Flag } from "../flag/flag"
  23. import { Identifier } from "../id/id"
  24. import { Installation } from "../installation"
  25. import { MCP } from "../mcp"
  26. import { Provider } from "../provider/provider"
  27. import { ProviderTransform } from "../provider/transform"
  28. import type { ModelsDev } from "../provider/models"
  29. import { Share } from "../share/share"
  30. import { Snapshot } from "../snapshot"
  31. import { Storage } from "../storage/storage"
  32. import { Log } from "../util/log"
  33. import { NamedError } from "../util/error"
  34. import { SystemPrompt } from "./system"
  35. import { FileTime } from "../file/time"
  36. import { MessageV2 } from "./message-v2"
  37. import { Mode } from "./mode"
  38. import { LSP } from "../lsp"
  39. import { ReadTool } from "../tool/read"
  40. import { mergeDeep, pipe, splitWhen } from "remeda"
  41. import { ToolRegistry } from "../tool/registry"
  42. export namespace Session {
  // Upper bound for `maxOutputTokens` on every model call made by chat(). It is also the
  // fallback when a model reports an output limit of 0.
  const OUTPUT_TOKEN_MAX = 32000
  const log = Log.create({ service: "session" })
  // Session metadata. Persisted at "session/info/<id>" and cached in state().sessions by get().
  45. export const Info = z
  46. .object({
  47. id: Identifier.schema("session"),
  // Set for child sessions created via create(parentID).
  48. parentID: Identifier.schema("session").optional(),
  // Public share URL recorded by share(). The share secret is stored separately as ShareInfo.
  49. share: z
  50. .object({
  51. url: z.string(),
  52. })
  53. .optional(),
  54. title: z.string(),
  55. version: z.string(),
  // Epoch milliseconds (Date.now()). `updated` is bumped by update() on every edit.
  56. time: z.object({
  57. created: z.number(),
  58. updated: z.number(),
  59. }),
  // When set, chat() discards stored messages from `messageID` onward. When `partID` is also
  // set, it trims the parts of the last kept message from that part onward.
  // `snapshot` is not read in this part of the file.
  60. revert: z
  61. .object({
  62. messageID: z.string(),
  63. partID: z.string().optional(),
  64. snapshot: z.string().optional(),
  65. })
  66. .optional(),
  67. })
  68. .openapi({
  69. ref: "Session",
  70. })
  71. export type Info = z.output<typeof Info>
  // Share credentials returned by Share.create() and stored at "session/share/<id>".
  // unshare() passes `secret` to Share.remove(). share() copies only `url` onto Info.share.
  72. export const ShareInfo = z
  73. .object({
  74. secret: z.string(),
  75. url: z.string(),
  76. })
  77. .openapi({
  78. ref: "SessionShare",
  79. })
  80. export type ShareInfo = z.output<typeof ShareInfo>
  // Bus events for sessions. In this part of the file, `Updated` is published by create() and
  // update(), and `Deleted` by remove(). `Idle` and `Error` are presumably published elsewhere
  // (not visible here).
  81. export const Event = {
  82. Updated: Bus.event(
  83. "session.updated",
  84. z.object({
  85. info: Info,
  86. }),
  87. ),
  88. Deleted: Bus.event(
  89. "session.deleted",
  90. z.object({
  91. info: Info,
  92. }),
  93. ),
  94. Idle: Bus.event(
  95. "session.idle",
  96. z.object({
  97. sessionID: z.string(),
  98. }),
  99. ),
  // `error` reuses the error schema of assistant messages.
  100. Error: Bus.event(
  101. "session.error",
  102. z.object({
  103. sessionID: z.string().optional(),
  104. error: MessageV2.Assistant.shape.error,
  105. }),
  106. ),
  107. }
  // Lazily created in-memory session state:
  //  - sessions: cached session infos.
  //  - messages: message cache.
  //  - pending: abort controllers of running chats.
  //  - queued: chat requests waiting for the running turn.
  // The second callback (presumably App's disposal hook) aborts every pending chat.
  const state = App.state(
    "session",
    () => {
      const queued = new Map<
        string,
        {
          input: ChatInput
          message: MessageV2.User
          parts: MessageV2.Part[]
          processed: boolean
          callback: (input: { info: MessageV2.Assistant; parts: MessageV2.Part[] }) => void
        }[]
      >()
      return {
        sessions: new Map<string, Info>(),
        messages: new Map<string, MessageV2.Info[]>(),
        pending: new Map<string, AbortController>(),
        queued,
      }
    },
    async (current) => {
      for (const controller of current.pending.values()) controller.abort()
    },
  )
  /**
   * Create a new session, persist it, and publish `session.updated`.
   *
   * When `parentID` is given, the session is created as a child of that session.
   * Top-level sessions are shared in the background when the OPENCODE_AUTO_SHARE flag is set
   * or the config has `share: "auto"`. Sharing failures are ignored.
   *
   * @param parentID optional id of the parent session
   * @returns the new session info
   */
  export async function create(parentID?: string) {
    // One timestamp so `created`, `updated` and the title all agree.
    const now = Date.now()
    const result: Info = {
      id: Identifier.descending("session"),
      version: Installation.VERSION,
      parentID,
      title: (parentID ? "Child session - " : "New Session - ") + new Date(now).toISOString(),
      time: {
        created: now,
        updated: now,
      },
    }
    log.info("created", result)
    state().sessions.set(result.id, result)
    await Storage.writeJSON("session/info/" + result.id, result)
    const cfg = await Config.get()
    if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.share === "auto")) {
      // share() already records `{ url }` on the session through update().
      // Its return value is the full ShareInfo, so assigning it to draft.share would persist
      // `secret` in session info and broadcast it in session.updated events.
      // Dropping the extra update() also removes a rejection that escaped the .catch below.
      void share(result.id).catch(() => {
        // Sharing is best-effort during session creation.
      })
    }
    Bus.publish(Event.Updated, {
      info: result,
    })
    return result
  }
  /**
   * Look up a session by id.
   * Uses the in-memory cache first; otherwise reads "session/info/<id>" from storage and caches it.
   */
  export async function get(id: string) {
    const cached = state().sessions.get(id)
    if (cached) return cached
    const loaded = await Storage.readJSON<Info>(`session/info/${id}`)
    state().sessions.set(id, loaded)
    return loaded as Info
  }
  /** Read the stored share credentials (url and secret) for a session. */
  export async function getShare(id: string) {
    const key = `session/share/${id}`
    return Storage.readJSON<ShareInfo>(key)
  }
  /**
   * Publish a session.
   *
   * Creates the share and records its url on the session. It then stores the credentials and
   * syncs the session info, every message and every part to the share service.
   *
   * @throws Error when sharing is disabled in configuration
   * @returns the existing `{ url }` when already shared; otherwise the new ShareInfo
   *          (which includes `secret`)
   */
  export async function share(id: string) {
    const cfg = await Config.get()
    if (cfg.share === "disabled") throw new Error("Sharing is disabled in configuration")
    const session = await get(id)
    if (session.share) return session.share
    const created = await Share.create(id)
    await update(id, (draft) => {
      draft.share = { url: created.url }
    })
    await Storage.writeJSON<ShareInfo>(`session/share/${id}`, created)
    await Share.sync(`session/info/${id}`, session)
    for (const { info, parts } of await messages(id)) {
      await Share.sync(`session/message/${id}/${info.id}`, info)
      for (const part of parts) {
        await Share.sync(`session/part/${id}/${info.id}/${part.id}`, part)
      }
    }
    return created
  }
  /**
   * Stop sharing a session.
   * Deletes the stored credentials, clears Info.share and revokes the share remotely using its
   * secret. Does nothing when the session has no stored share.
   */
  export async function unshare(id: string) {
    const existing = await getShare(id)
    if (!existing) return
    await Storage.remove(`session/share/${id}`)
    await update(id, (draft) => {
      draft.share = undefined
    })
    await Share.remove(id, existing.secret)
  }
  /**
   * Apply `editor` to a session in place.
   * Then bumps `time.updated`, re-caches and persists the session, and publishes
   * `session.updated`.
   *
   * @returns the edited session, or undefined if get() yields nothing
   */
  export async function update(id: string, editor: (session: Info) => void) {
    const sessions = state().sessions
    const target = await get(id)
    if (!target) return
    editor(target)
    target.time.updated = Date.now()
    sessions.set(id, target)
    await Storage.writeJSON(`session/info/${id}`, target)
    Bus.publish(Event.Updated, { info: target })
    return target
  }
  /** Load every stored message of a session with its parts, sorted by message id. */
  export async function messages(sessionID: string) {
    const collected: { info: MessageV2.Info; parts: MessageV2.Part[] }[] = []
    const keys = await Storage.list(`session/message/${sessionID}`)
    for (const key of keys) {
      const info = await Storage.readJSON<MessageV2.Info>(key)
      const parts = await getParts(sessionID, info.id)
      collected.push({ info, parts })
    }
    collected.sort((a, b) => (a.info.id > b.info.id ? 1 : -1))
    return collected
  }
  /** Read a single stored message of a session. */
  export async function getMessage(sessionID: string, messageID: string) {
    const key = `session/message/${sessionID}/${messageID}`
    return Storage.readJSON<MessageV2.Info>(key)
  }
  /** Load every stored part of a message, sorted by part id. */
  export async function getParts(sessionID: string, messageID: string) {
    const keys = await Storage.list(`session/part/${sessionID}/${messageID}`)
    const parts: MessageV2.Part[] = []
    for (const key of keys) {
      parts.push(await Storage.readJSON<MessageV2.Part>(key))
    }
    parts.sort((a, b) => (a.id > b.id ? 1 : -1))
    return parts
  }
  /** Yield every stored session, loading each through get(). */
  export async function* list() {
    const items = await Storage.list("session/info")
    for (const item of items) {
      yield get(path.basename(item, ".json"))
    }
  }
  /** Return every stored session whose `parentID` equals the given id. */
  export async function children(parentID: string) {
    const found: Info[] = []
    for (const item of await Storage.list("session/info")) {
      const candidate = await get(path.basename(item, ".json"))
      if (candidate.parentID === parentID) found.push(candidate)
    }
    return found
  }
  /**
   * Abort the running chat of a session, if any.
   * @returns true when a pending controller was found and aborted, false otherwise
   */
  export function abort(sessionID: string) {
    const { pending } = state()
    const controller = pending.get(sessionID)
    if (controller === undefined) return false
    controller.abort()
    pending.delete(sessionID)
    return true
  }
  /**
   * Delete a session and everything stored for it.
   *
   * Aborts its running chat and recursively deletes child sessions. Then unshares it and
   * removes its info, messages and parts, and drops it from the in-memory caches.
   * Any error is logged instead of thrown.
   *
   * @param sessionID session to delete
   * @param emitEvent publish `session.deleted` for this session (child sessions are removed
   *                  with `false`, so no event is published for them)
   */
  export async function remove(sessionID: string, emitEvent = true) {
    try {
      abort(sessionID)
      const session = await get(sessionID)
      for (const child of await children(sessionID)) {
        await remove(child.id, false)
      }
      await unshare(sessionID).catch(() => {})
      await Storage.remove(`session/info/${sessionID}`).catch(() => {})
      await Storage.removeDir(`session/message/${sessionID}/`).catch(() => {})
      // Parts are written under their own prefix (session/part/<sessionID>/<messageID>/<partID>),
      // so deleting the message directory leaves them orphaned. Remove them explicitly.
      await Storage.removeDir(`session/part/${sessionID}/`).catch(() => {})
      state().sessions.delete(sessionID)
      state().messages.delete(sessionID)
      if (emitEvent) {
        Bus.publish(Event.Deleted, {
          info: session,
        })
      }
    } catch (e) {
      log.error(e)
    }
  }
  /** Persist a message at session/message/<sessionID>/<id> and publish MessageV2.Event.Updated. */
  async function updateMessage(msg: MessageV2.Info) {
    const key = `session/message/${msg.sessionID}/${msg.id}`
    await Storage.writeJSON(key, msg)
    Bus.publish(MessageV2.Event.Updated, { info: msg })
  }
  /**
   * Persist a part at session/part/<sessionID>/<messageID>/<id> and publish
   * MessageV2.Event.PartUpdated.
   * @returns the same part object
   */
  async function updatePart(part: MessageV2.Part) {
    const key = `session/part/${part.sessionID}/${part.messageID}/${part.id}`
    await Storage.writeJSON(key, part)
    Bus.publish(MessageV2.Event.PartUpdated, { part })
    return part
  }
  // Request accepted by chat().
  //  - messageID: optional. chat() generates an ascending message id when omitted.
  //  - mode: chat() defaults it to "build". "plan" also appends the plan prompt as a synthetic part.
  //  - system: replaces the mode/provider prompt. The header, environment and custom prompts
  //    are still added.
  //  - tools: per-request toggles merged over the mode's settings. Registry tools mapped to false
  //    are skipped. MCP tools only consult mode.tools.
  //  - parts: text/file parts without messageID/sessionID (chat() fills those in). `id` is optional.
  308. export const ChatInput = z.object({
  309. sessionID: Identifier.schema("session"),
  310. messageID: Identifier.schema("message").optional(),
  311. providerID: z.string(),
  312. modelID: z.string(),
  313. mode: z.string().optional(),
  314. system: z.string().optional(),
  315. tools: z.record(z.boolean()).optional(),
  316. parts: z.array(
  317. z.discriminatedUnion("type", [
  318. MessageV2.TextPart.omit({
  319. messageID: true,
  320. sessionID: true,
  321. })
  322. .partial({
  323. id: true,
  324. })
  325. .openapi({
  326. ref: "TextPartInput",
  327. }),
  328. MessageV2.FilePart.omit({
  329. messageID: true,
  330. sessionID: true,
  331. })
  332. .partial({
  333. id: true,
  334. })
  335. .openapi({
  336. ref: "FilePartInput",
  337. }),
  338. ]),
  339. ),
  340. })
  341. export type ChatInput = z.infer<typeof ChatInput>
  /**
   * Append a user message to a session and run the assistant turn that answers it.
   *
   * Flow:
   *  1. Convert the input parts into stored parts.
   *     - `text/plain` file parts become synthetic "Read tool" text parts plus the file content.
   *     - Other `file:` parts are embedded as base64 data URLs.
   *  2. Persist the user message and its parts, and bump the session's `time.updated`.
   *  3. If the session is locked (a turn is already running), queue the message. The returned
   *     promise resolves with the result of the turn that finally drains the queue.
   *  4. Otherwise:
   *     - apply a pending revert;
   *     - auto-summarize when the previous turn's token count nears the context window;
   *     - start background title generation for a top-level session's first message;
   *     - stream the model response through `createProcessor`.
   *
   * @param input request validated against `ChatInput`
   * @returns the final assistant message and its parts
   */
  export async function chat(
    input: z.infer<typeof ChatInput>,
  ): Promise<{ info: MessageV2.Assistant; parts: MessageV2.Part[] }> {
    const l = log.clone().tag("session", input.sessionID)
    l.info("chatting")
    const inputMode = input.mode ?? "build"
    const userMsg: MessageV2.Info = {
      id: input.messageID ?? Identifier.ascending("message"),
      role: "user",
      sessionID: input.sessionID,
      time: {
        created: Date.now(),
      },
    }
    const app = App.info()
    const userParts = await Promise.all(
      input.parts.map(async (part): Promise<MessageV2.Part[]> => {
        if (part.type === "file") {
          const url = new URL(part.url)
          switch (url.protocol) {
            case "data:":
              if (part.mime === "text/plain") {
                // A data URL is `data:<mediatype>[;base64],<payload>`. Only the payload after the
                // first comma is file content. Decoding the whole URL (scheme and media type
                // included) as base64 would garble the text.
                const comma = part.url.indexOf(",")
                const header = comma === -1 ? "" : part.url.slice(0, comma)
                const payload = part.url.slice(comma + 1)
                let text = payload
                if (header.endsWith(";base64")) {
                  text = Buffer.from(payload, "base64").toString()
                } else {
                  try {
                    text = decodeURIComponent(payload)
                  } catch {
                    // Malformed percent-encoding: keep the raw payload.
                  }
                }
                return [
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: `Called the Read tool with the following input: ${JSON.stringify({ filePath: part.filename })}`,
                  },
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text,
                  },
                  {
                    ...part,
                    id: part.id ?? Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                  },
                ]
              }
              break
            case "file:": {
              // Symbol search returns absolute paths, so normalize against cwd. The URL
              // constructor leaves the pathname percent-encoded, so decode it explicitly.
              const pathname = decodeURIComponent(url.pathname)
              const relativePath = pathname.replace(app.path.cwd, ".")
              const filePath = path.join(app.path.cwd, relativePath)
              if (part.mime === "text/plain") {
                let offset: number | undefined = undefined
                let limit: number | undefined = undefined
                const rangeStart = url.searchParams.get("start")
                const rangeEnd = url.searchParams.get("end")
                if (rangeStart != null) {
                  // The part URL without its query string is what gets passed to LSP.
                  const documentUrl = part.url.split("?")[0]
                  const start = parseInt(rangeStart)
                  let end = rangeEnd ? parseInt(rangeEnd) : undefined
                  // Some LSP servers (e.g. gopls) don't give the full range in workspace/symbol
                  // searches, so look the symbol up in the document to recover its full range.
                  if (start === end) {
                    const symbols = await LSP.documentSymbol(documentUrl)
                    for (const symbol of symbols) {
                      let symbolRange: LSP.Range | undefined
                      if ("range" in symbol) {
                        symbolRange = symbol.range
                      } else if ("location" in symbol) {
                        symbolRange = symbol.location.range
                      }
                      // Compare explicitly (no truthiness test) so a symbol on line 0 still matches.
                      if (symbolRange && symbolRange.start?.line === start) {
                        end = symbolRange.end?.line ?? start
                        break
                      }
                    }
                  }
                  // Window the read for every explicit range, not only collapsed ones.
                  // Previously a real start..end selection skipped this and read the whole file.
                  if (!Number.isNaN(start)) {
                    offset = Math.max(start - 2, 0)
                    if (end) {
                      limit = end - offset + 2
                    }
                  }
                }
                const args = { filePath, offset, limit }
                const readTool = await ReadTool.init()
                const result = await readTool.execute(args, {
                  sessionID: input.sessionID,
                  abort: new AbortController().signal,
                  messageID: userMsg.id,
                  metadata: async () => {},
                })
                return [
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
                  },
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: result.output,
                  },
                  {
                    ...part,
                    id: part.id ?? Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                  },
                ]
              }
              // Non-text files are embedded inline as a base64 data URL.
              const file = Bun.file(filePath)
              FileTime.read(input.sessionID, filePath)
              return [
                {
                  id: Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "text",
                  // JSON.stringify escapes quotes/backslashes that raw interpolation would not.
                  text: `Called the Read tool with the following input: ${JSON.stringify({ filePath: pathname })}`,
                  synthetic: true,
                },
                {
                  id: part.id ?? Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "file",
                  url: `data:${part.mime};base64,` + Buffer.from(await file.bytes()).toString("base64"),
                  mime: part.mime,
                  filename: part.filename!,
                  source: part.source,
                },
              ]
            }
          }
        }
        return [
          {
            id: Identifier.ascending("part"),
            ...part,
            messageID: userMsg.id,
            sessionID: input.sessionID,
          },
        ]
      }),
    ).then((x) => x.flat())
    if (inputMode === "plan") {
      userParts.push({
        id: Identifier.ascending("part"),
        messageID: userMsg.id,
        sessionID: input.sessionID,
        type: "text",
        text: PROMPT_PLAN,
        synthetic: true,
      })
    }
    await updateMessage(userMsg)
    for (const part of userParts) {
      await updatePart(part)
    }
    // Mark the session as updated since a message has been added to it.
    await update(input.sessionID, (_draft) => {})
    if (isLocked(input.sessionID)) {
      // A turn is already running: queue this message. The running turn picks it up
      // (prepareStep below, or a follow-up chat() at the end) and resolves `callback`.
      return new Promise((resolve) => {
        const queue = state().queued.get(input.sessionID) ?? []
        queue.push({
          input: input,
          message: userMsg,
          parts: userParts,
          processed: false,
          callback: resolve,
        })
        state().queued.set(input.sessionID, queue)
      })
    }
    const model = await Provider.getModel(input.providerID, input.modelID)
    let msgs = await messages(input.sessionID)
    const session = await get(input.sessionID)
    if (session.revert) {
      const messageID = session.revert.messageID
      const [preserve, discarded] = splitWhen(msgs, (x) => x.info.id === messageID)
      // The user message persisted above is newer than the revert point, so it lands in
      // `discarded`. Keep it; otherwise the prompt being answered would be deleted.
      msgs = [...preserve, ...discarded.filter((x) => x.info.id === userMsg.id)]
      for (const msg of discarded) {
        if (msg.info.id === userMsg.id) continue
        await Storage.remove(`session/message/${input.sessionID}/${msg.info.id}`)
        // Parts live under their own prefix. Remove them with their message so they are not orphaned.
        await Storage.removeDir(`session/part/${input.sessionID}/${msg.info.id}/`).catch(() => {})
        await Bus.publish(MessageV2.Event.Removed, { sessionID: input.sessionID, messageID: msg.info.id })
      }
      const last = preserve.at(-1)
      if (session.revert.partID && last) {
        const partID = session.revert.partID
        const [preserveParts, removeParts] = splitWhen(last.parts, (x) => x.id === partID)
        last.parts = preserveParts
        for (const part of removeParts) {
          await Storage.remove(`session/part/${input.sessionID}/${last.info.id}/${part.id}`)
          await Bus.publish(MessageV2.Event.PartRemoved, {
            messageID: last.info.id,
            partID: part.id,
          })
        }
      }
    }
    const previous = msgs.filter((x) => x.info.role === "assistant").at(-1)?.info as MessageV2.Assistant
    const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
    // Auto-summarize when the previous turn used more than 90% of the context left after
    // reserving room for output.
    if (previous && previous.tokens) {
      const tokens =
        previous.tokens.input + previous.tokens.cache.read + previous.tokens.cache.write + previous.tokens.output
      if (model.info.limit.context && tokens > Math.max((model.info.limit.context - outputLimit) * 0.9, 0)) {
        await summarize({
          sessionID: input.sessionID,
          providerID: input.providerID,
          modelID: input.modelID,
        })
        return chat(input)
      }
    }
    using abort = lock(input.sessionID)
    const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
    if (lastSummary) msgs = msgs.filter((msg) => msg.info.id >= lastSummary.info.id)
    if (msgs.length === 1 && !session.parentID) {
      // First message of a top-level session: generate a title in the background (best effort).
      const small = (await Provider.getSmallModel(input.providerID)) ?? model
      generateText({
        maxOutputTokens: small.info.reasoning ? 1024 : 20,
        providerOptions: {
          [input.providerID]: small.info.options,
        },
        messages: [
          ...SystemPrompt.title(input.providerID).map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          ...MessageV2.toModelMessage([
            {
              info: {
                id: Identifier.ascending("message"),
                role: "user",
                sessionID: input.sessionID,
                time: {
                  created: Date.now(),
                },
              },
              parts: userParts,
            },
          ]),
        ],
        model: small.language,
      })
        .then((result) => {
          if (result.text)
            return Session.update(input.sessionID, (draft) => {
              draft.title = result.text
            })
        })
        .catch(() => {})
    }
    const mode = await Mode.get(inputMode)
    let system = SystemPrompt.header(input.providerID)
    system.push(
      ...(() => {
        if (input.system) return [input.system]
        if (mode.prompt) return [mode.prompt]
        return SystemPrompt.provider(input.modelID)
      })(),
    )
    system.push(...(await SystemPrompt.environment()))
    system.push(...(await SystemPrompt.custom()))
    // At most 2 system prompt messages, for caching purposes.
    const [first, ...rest] = system
    system = [first, rest.join("\n")]
    const assistantMsg: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      system,
      mode: inputMode,
      path: {
        cwd: app.path.cwd,
        root: app.path.root,
      },
      cost: 0,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.modelID,
      providerID: input.providerID,
      time: {
        created: Date.now(),
      },
      sessionID: input.sessionID,
    }
    await updateMessage(assistantMsg)
    const tools: Record<string, AITool> = {}
    const processor = createProcessor(assistantMsg, model.info)
    const enabledTools = pipe(
      mode.tools,
      mergeDeep(ToolRegistry.enabled(input.providerID, input.modelID)),
      mergeDeep(input.tools ?? {}),
    )
    for (const item of await ToolRegistry.tools(input.providerID, input.modelID)) {
      if (enabledTools[item.id] === false) continue
      tools[item.id] = tool({
        id: item.id as any,
        description: item.description,
        inputSchema: item.parameters as ZodSchema,
        async execute(args, options) {
          await processor.track(options.toolCallId)
          const result = await item.execute(args, {
            sessionID: input.sessionID,
            abort: abort.signal,
            messageID: assistantMsg.id,
            metadata: async (val) => {
              const match = processor.partFromToolCall(options.toolCallId)
              if (match && match.state.status === "running") {
                await updatePart({
                  ...match,
                  state: {
                    title: val.title,
                    metadata: val.metadata,
                    status: "running",
                    input: args,
                    time: {
                      start: Date.now(),
                    },
                  },
                })
              }
            },
          })
          return result
        },
        toModelOutput(result) {
          return {
            type: "text",
            value: result.output,
          }
        },
      })
    }
    for (const [key, item] of Object.entries(await MCP.tools())) {
      if (mode.tools[key] === false) continue
      const execute = item.execute
      if (!execute) continue
      item.execute = async (args, opts) => {
        await processor.track(opts.toolCallId)
        const result = await execute(args, opts)
        const output = result.content
          .filter((x: any) => x.type === "text")
          .map((x: any) => x.text)
          .join("\n\n")
        return {
          output,
        }
      }
      item.toModelOutput = (result) => {
        return {
          type: "text",
          value: result.output,
        }
      }
      tools[key] = item
    }
    const stream = streamText({
      // Stream errors also arrive as "error" parts, which the processor rethrows.
      onError() {},
      async prepareStep({ messages }) {
        // Fold messages queued while this turn was running into the next step. The current
        // assistant message is closed and a fresh one is started for the continuation.
        const queue = (state().queued.get(input.sessionID) ?? []).filter((x) => !x.processed)
        if (queue.length) {
          for (const item of queue) {
            messages.push(
              ...MessageV2.toModelMessage([
                {
                  info: item.message,
                  parts: item.parts,
                },
              ]),
            )
            item.processed = true
          }
          assistantMsg.time.completed = Date.now()
          await updateMessage(assistantMsg)
          Object.assign(assistantMsg, {
            id: Identifier.ascending("message"),
            role: "assistant",
            system,
            path: {
              cwd: app.path.cwd,
              root: app.path.root,
            },
            cost: 0,
            tokens: {
              input: 0,
              output: 0,
              reasoning: 0,
              cache: { read: 0, write: 0 },
            },
            modelID: input.modelID,
            providerID: input.providerID,
            mode: inputMode,
            time: {
              created: Date.now(),
            },
            sessionID: input.sessionID,
          })
          await updateMessage(assistantMsg)
        }
        return {
          messages,
        }
      },
      maxRetries: 10,
      maxOutputTokens: outputLimit,
      abortSignal: abort.signal,
      stopWhen: stepCountIs(1000),
      providerOptions: {
        [input.providerID]: model.info.options,
      },
      messages: [
        ...system.map(
          (x): ModelMessage => ({
            role: "system",
            content: x,
          }),
        ),
        ...MessageV2.toModelMessage(msgs),
      ],
      temperature: model.info.temperature
        ? (mode.temperature ?? ProviderTransform.temperature(input.providerID, input.modelID))
        : undefined,
      tools: model.info.tool_call === false ? undefined : tools,
      model: wrapLanguageModel({
        model: model.language,
        middleware: [
          {
            async transformParams(args) {
              if (args.type === "stream") {
                // @ts-expect-error
                args.params.prompt = ProviderTransform.message(args.params.prompt, input.providerID, input.modelID)
              }
              return args.params
            },
          },
        ],
      }),
    })
    const result = await processor.process(stream)
    const queued = state().queued.get(input.sessionID) ?? []
    const unprocessed = queued.find((x) => !x.processed)
    if (unprocessed) {
      // A queued message was never folded into a step: run another turn for it.
      unprocessed.processed = true
      return chat(unprocessed.input)
    }
    for (const item of queued) {
      item.callback(result)
    }
    state().queued.delete(input.sessionID)
    return result
  }
  812. function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
  813. const toolCalls: Record<string, MessageV2.ToolPart> = {}
  814. const snapshots: Record<string, string> = {}
  815. return {
  816. async track(toolCallID: string) {
  817. const hash = await Snapshot.track()
  818. if (hash) snapshots[toolCallID] = hash
  819. },
  820. partFromToolCall(toolCallID: string) {
  821. return toolCalls[toolCallID]
  822. },
  /**
   * Drains `stream.fullStream`, mirroring each stream event into persisted parts of
   * `assistantMsg`:
   * - text parts;
   * - tool parts (pending -> running -> completed/error);
   * - step-start / step-finish markers;
   * - "patch" parts for files changed since a tool call's snapshot.
   *
   * A failure while streaming is not rethrown. It is converted into a serialized error on
   * `assistantMsg.error` and published as `Event.Error`. Afterwards, every tool part that
   * never reached "completed" or "error" is closed out as aborted. The message is then
   * stamped completed and saved.
   *
   * @param stream result of `streamText` whose full event stream is consumed
   * @returns the assistant message info and its parts as persisted at the end
   */
  async process(stream: StreamTextResult<Record<string, AITool>, never>) {
    // Emits a "patch" part for the files changed since the snapshot recorded for this tool
    // call. Does nothing when no snapshot was recorded or when no file changed.
    const recordToolPatch = async (toolCallID: string) => {
      const snapshot = snapshots[toolCallID]
      if (!snapshot) return
      const patch = await Snapshot.patch(snapshot)
      if (!patch.files.length) return
      await updatePart({
        id: Identifier.ascending("part"),
        messageID: assistantMsg.id,
        sessionID: assistantMsg.sessionID,
        type: "patch",
        hash: patch.hash,
        files: patch.files,
      })
    }

    try {
      let currentText: MessageV2.TextPart | undefined
      for await (const value of stream.fullStream) {
        log.info("part", {
          type: value.type,
        })
        switch (value.type) {
          case "start":
            break

          case "tool-input-start": {
            // Register the tool call as soon as the model starts emitting its input.
            const part = await updatePart({
              id: Identifier.ascending("part"),
              messageID: assistantMsg.id,
              sessionID: assistantMsg.sessionID,
              type: "tool",
              tool: value.toolName,
              callID: value.id,
              state: {
                status: "pending",
              },
            })
            toolCalls[value.id] = part as MessageV2.ToolPart
            break
          }

          case "tool-input-delta":
            break

          case "tool-input-end":
            break

          case "tool-call": {
            const match = toolCalls[value.toolCallId]
            if (match) {
              const part = await updatePart({
                ...match,
                state: {
                  status: "running",
                  input: value.input,
                  time: {
                    start: Date.now(),
                  },
                },
              })
              toolCalls[value.toolCallId] = part as MessageV2.ToolPart
            }
            break
          }

          case "tool-result": {
            const match = toolCalls[value.toolCallId]
            if (match && match.state.status === "running") {
              await updatePart({
                ...match,
                state: {
                  status: "completed",
                  input: value.input,
                  output: value.output.output,
                  metadata: value.output.metadata,
                  title: value.output.title,
                  time: {
                    start: match.state.time.start,
                    end: Date.now(),
                  },
                },
              })
              delete toolCalls[value.toolCallId]
              await recordToolPatch(value.toolCallId)
            }
            break
          }

          case "tool-error": {
            const match = toolCalls[value.toolCallId]
            if (match && match.state.status === "running") {
              await updatePart({
                ...match,
                state: {
                  status: "error",
                  input: value.input,
                  error: String(value.error),
                  time: {
                    start: match.state.time.start,
                    end: Date.now(),
                  },
                },
              })
              delete toolCalls[value.toolCallId]
              // Record file changes even when the tool call failed.
              await recordToolPatch(value.toolCallId)
            }
            break
          }

          case "error":
            throw value.error

          case "start-step":
            await updatePart({
              id: Identifier.ascending("part"),
              messageID: assistantMsg.id,
              sessionID: assistantMsg.sessionID,
              type: "step-start",
            })
            break

          case "finish-step": {
            const usage = getUsage(model, value.usage, value.providerMetadata)
            // Cost accumulates across steps; tokens are replaced with the latest step's numbers.
            assistantMsg.cost += usage.cost
            assistantMsg.tokens = usage.tokens
            await updatePart({
              id: Identifier.ascending("part"),
              messageID: assistantMsg.id,
              sessionID: assistantMsg.sessionID,
              type: "step-finish",
              tokens: usage.tokens,
              cost: usage.cost,
            })
            await updateMessage(assistantMsg)
            break
          }

          case "text-start":
            currentText = {
              id: Identifier.ascending("part"),
              messageID: assistantMsg.id,
              sessionID: assistantMsg.sessionID,
              type: "text",
              text: "",
              time: {
                start: Date.now(),
              },
            }
            break

          case "text":
            if (currentText) {
              currentText.text += value.text
              await updatePart(currentText)
            }
            break

          case "text-end":
            if (currentText && currentText.text) {
              // Keep the start time captured at "text-start"; only stamp the end.
              currentText.time = {
                start: currentText.time?.start ?? Date.now(),
                end: Date.now(),
              }
              currentText.text = currentText.text.trimEnd()
              await updatePart(currentText)
            }
            currentText = undefined
            break

          case "finish":
            assistantMsg.time.completed = Date.now()
            await updateMessage(assistantMsg)
            break

          default:
            log.info("unhandled", {
              ...value,
            })
            continue
        }
      }
    } catch (e) {
      log.error("", {
        error: e,
      })
      switch (true) {
        case e instanceof DOMException && e.name === "AbortError":
          assistantMsg.error = new MessageV2.AbortedError(
            { message: e.message },
            {
              cause: e,
            },
          ).toObject()
          break
        case MessageV2.OutputLengthError.isInstance(e):
          assistantMsg.error = e
          break
        case LoadAPIKeyError.isInstance(e):
          assistantMsg.error = new MessageV2.AuthError(
            {
              providerID: model.id,
              message: e.message,
            },
            { cause: e },
          ).toObject()
          break
        case e instanceof Error:
          assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
          break
        default:
          // Serialize like every other branch; a raw error instance is not a plain error object.
          assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e }).toObject()
      }
      Bus.publish(Event.Error, {
        sessionID: assistantMsg.sessionID,
        error: assistantMsg.error,
      })
    }

    // Close out tool calls that never reached a terminal state (the stream was aborted, failed,
    // or ended early). Parts that already completed or errored keep their real outcome.
    const parts = await getParts(assistantMsg.sessionID, assistantMsg.id)
    for (const [index, part] of parts.entries()) {
      if (part.type !== "tool") continue
      if (part.state.status === "completed" || part.state.status === "error") continue
      const now = Date.now()
      const aborted: MessageV2.ToolPart = {
        ...part,
        state: {
          status: "error",
          error: "Tool execution aborted",
          time: {
            start: part.state.status === "running" ? part.state.time.start : now,
            end: now,
          },
          input: part.state.status === "running" ? part.state.input : {},
        },
      }
      await updatePart(aborted)
      // Return what was persisted rather than the stale pre-abort part.
      parts[index] = aborted
    }
    assistantMsg.time.completed = Date.now()
    await updateMessage(assistantMsg)
    return { info: assistantMsg, parts }
  },
  1055. }
  1056. }
  /**
   * Input accepted by `revert`: the session, the message to revert at, and optionally a
   * specific part within that message.
   */
  export const RevertInput = z
    .object({
      sessionID: Identifier.schema("session"),
      messageID: Identifier.schema("message"),
    })
    .extend({
      partID: Identifier.schema("part").optional(),
    })
  export type RevertInput = z.infer<typeof RevertInput>
  /**
   * Marks the session as reverted at `input.messageID`, or at `input.partID` within it.
   *
   * Messages are walked in order until the target is found. Every "patch" part that comes
   * after the target part (in later parts of the same message and in later messages) is
   * collected, and `Snapshot.revert` undoes those file changes. A snapshot is stored on the
   * session so `unrevert` can restore it. If the session already has a revert snapshot, that
   * one is kept; otherwise `Snapshot.track()` takes a new one.
   *
   * When no text/tool part precedes the target within its message, the whole message is
   * treated as reverted: `partID` is dropped and the revert point becomes the latest user
   * message seen so far, which may be this message.
   *
   * @returns the updated session, or the unchanged session when the target was not found
   */
  export async function revert(input: RevertInput) {
    const all = await messages(input.sessionID)
    const session = await get(input.sessionID)
    let lastUser: MessageV2.User | undefined
    let revert: Info["revert"]
    const patches: Snapshot.Patch[] = []
    for (const msg of all) {
      if (msg.info.role === "user") lastUser = msg.info
      // Whether this message has a text/tool part before the current position.
      let hasUsefulPart = false
      for (const part of msg.parts) {
        if (revert) {
          // Everything past the revert point is discarded; collect its file changes to undo.
          if (part.type === "patch") patches.push(part)
          continue
        }
        const isTarget = (msg.info.id === input.messageID && !input.partID) || part.id === input.partID
        if (isTarget) {
          // If no useful parts would be left in the message, treat it as reverting the whole message.
          const partID = hasUsefulPart ? input.partID : undefined
          revert = {
            messageID: !partID && lastUser ? lastUser.id : msg.info.id,
            partID,
          }
        }
        if (part.type === "text" || part.type === "tool") hasUsefulPart = true
      }
    }
    if (!revert) return session
    // `session` was loaded above; nothing in between modifies it, so no second fetch is needed.
    revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
    await Snapshot.revert(patches)
    return update(input.sessionID, (draft) => {
      draft.revert = revert
    })
  }
  /**
   * Undoes a previous `revert`. It restores the snapshot recorded with that revert (if any)
   * and clears the session's revert marker. Does nothing when the session is not reverted.
   *
   * @returns the updated session, or the unchanged session when there was nothing to undo
   */
  export async function unrevert(input: { sessionID: string }) {
    log.info("unreverting", input)
    const session = await get(input.sessionID)
    const { revert: applied } = session
    if (!applied) return session
    const snapshot = applied.snapshot
    if (snapshot) {
      await Snapshot.restore(snapshot)
    }
    return await update(input.sessionID, (draft) => {
      draft.revert = undefined
    })
  }
  /**
   * Asks the model to summarize the conversation and streams the answer into a new assistant
   * message flagged `summary: true`. Only the most recent earlier summary and the messages
   * after it are sent (all messages when no summary exists yet).
   *
   * The session lock is held for the whole run, so this throws `BusyError` if the session is
   * busy. The lock's abort signal is passed to the model request.
   *
   * @returns the processed summary message and its parts
   */
  export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {
    using abort = lock(input.sessionID)
    const history = await messages(input.sessionID)
    const previousSummary = history.findLast(
      (entry) => entry.info.role === "assistant" && entry.info.summary === true,
    )
    const recent = previousSummary
      ? history.filter((entry) => entry.info.id >= previousSummary.info.id)
      : [...history]
    const model = await Provider.getModel(input.providerID, input.modelID)
    const appInfo = App.info()

    const summarizePrompt = SystemPrompt.summarize(input.providerID)
    const environmentPrompt = await SystemPrompt.environment()
    const customPrompt = await SystemPrompt.custom()
    const system = [...summarizePrompt, ...environmentPrompt, ...customPrompt]

    const summaryMessage: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      sessionID: input.sessionID,
      system,
      mode: "build",
      path: {
        cwd: appInfo.path.cwd,
        root: appInfo.path.root,
      },
      summary: true,
      cost: 0,
      modelID: input.modelID,
      providerID: input.providerID,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      time: {
        created: Date.now(),
      },
    }
    await updateMessage(summaryMessage)

    const processor = createProcessor(summaryMessage, model.info)
    const instruction =
      "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
    const stream = streamText({
      maxRetries: 10,
      abortSignal: abort.signal,
      model: model.language,
      messages: [
        ...system.map((content): ModelMessage => ({ role: "system", content })),
        ...MessageV2.toModelMessage(recent),
        { role: "user", content: [{ type: "text", text: instruction }] },
      ],
    })
    // Await here so the `using` lock is released only after processing has finished.
    return await processor.process(stream)
  }
  /** Reports whether `lock` currently holds the given session. */
  function isLocked(sessionID: string) {
    const { pending } = state()
    return pending.has(sessionID)
  }
  /**
   * Acquires the per-session lock by registering a fresh AbortController for the session.
   *
   * The returned handle exposes the controller's `signal` and a `Symbol.dispose` hook (for
   * `using`). The hook removes the session from the pending map and publishes `Event.Idle`;
   * it does not abort the signal.
   *
   * @throws BusyError when the session is already locked
   */
  function lock(sessionID: string) {
    log.info("locking", { sessionID })
    if (state().pending.has(sessionID)) {
      throw new BusyError(sessionID)
    }
    const abortController = new AbortController()
    state().pending.set(sessionID, abortController)
    const release = () => {
      log.info("unlocking", { sessionID })
      state().pending.delete(sessionID)
      Bus.publish(Event.Idle, { sessionID })
    }
    return {
      signal: abortController.signal,
      [Symbol.dispose]: release,
    }
  }
  /**
   * Converts the AI SDK usage report for one step into a token breakdown and its cost.
   *
   * Cache-write counts are not part of `LanguageModelUsage`. They are read from provider
   * metadata instead: Anthropic's `cacheCreationInputTokens`, otherwise Bedrock's
   * `usage.cacheWriteInputTokens`, otherwise 0. Reasoning tokens are reported but are not
   * priced separately here.
   *
   * @returns `cost` as a plain number (model rates are per million tokens) and the `tokens` breakdown
   */
  function getUsage(model: ModelsDev.Model, usage: LanguageModelUsage, metadata?: ProviderMetadata) {
    const tokens = {
      input: usage.inputTokens ?? 0,
      output: usage.outputTokens ?? 0,
      reasoning: usage.reasoningTokens ?? 0,
      cache: {
        write: (metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
          // @ts-expect-error Bedrock nests usage in a JSON value that the ProviderMetadata type cannot index into
          metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
          0) as number,
        read: usage.cachedInputTokens ?? 0,
      },
    }
    // Price `count` tokens at a per-million-token `rate`, using Decimal arithmetic.
    const perMillion = (count: number, rate: number) => new Decimal(count).mul(rate).div(1_000_000)
    return {
      cost: new Decimal(0)
        .add(perMillion(tokens.input, model.cost.input))
        .add(perMillion(tokens.output, model.cost.output))
        .add(perMillion(tokens.cache.read, model.cost.cache_read ?? 0))
        .add(perMillion(tokens.cache.write, model.cost.cache_write ?? 0))
        .toNumber(),
      tokens,
    }
  }
  /** Thrown by `lock` when the session is already locked. */
  export class BusyError extends Error {
    constructor(public readonly sessionID: string) {
      super(`Session ${sessionID} is busy`)
      // Give logs and stack traces a recognizable name instead of the generic "Error".
      this.name = "BusyError"
    }
  }
  /**
   * Sends the project-initialization prompt as a user message in the given session, with
   * `${path}` filled in with the app root. Then calls `App.initialize()`.
   */
  export async function initialize(input: {
    sessionID: string
    modelID: string
    providerID: string
    messageID: string
  }) {
    const app = App.info()
    // A replacer function inserts the path verbatim. With a plain string replacement,
    // sequences such as "$&" or "$'" inside the path would be expanded as patterns.
    const prompt = PROMPT_INITIALIZE.replace("${path}", () => app.path.root)
    await Session.chat({
      sessionID: input.sessionID,
      messageID: input.messageID,
      providerID: input.providerID,
      modelID: input.modelID,
      parts: [
        {
          id: Identifier.ascending("part"),
          type: "text",
          text: prompt,
        },
      ],
    })
    await App.initialize()
  }
  1245. }