index.ts 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253
  1. import path from "path"
  2. import { Decimal } from "decimal.js"
  3. import { z, ZodSchema } from "zod"
  4. import {
  5. generateText,
  6. LoadAPIKeyError,
  7. streamText,
  8. tool,
  9. wrapLanguageModel,
  10. type Tool as AITool,
  11. type LanguageModelUsage,
  12. type ProviderMetadata,
  13. type ModelMessage,
  14. stepCountIs,
  15. type StreamTextResult,
  16. } from "ai"
  17. import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
  18. import PROMPT_PLAN from "../session/prompt/plan.txt"
  19. import PROMPT_ANTHROPIC_SPOOF from "../session/prompt/anthropic_spoof.txt"
  20. import { App } from "../app/app"
  21. import { Bus } from "../bus"
  22. import { Config } from "../config/config"
  23. import { Flag } from "../flag/flag"
  24. import { Identifier } from "../id/id"
  25. import { Installation } from "../installation"
  26. import { MCP } from "../mcp"
  27. import { Provider } from "../provider/provider"
  28. import { ProviderTransform } from "../provider/transform"
  29. import type { ModelsDev } from "../provider/models"
  30. import { Share } from "../share/share"
  31. import { Snapshot } from "../snapshot"
  32. import { Storage } from "../storage/storage"
  33. import { Log } from "../util/log"
  34. import { NamedError } from "../util/error"
  35. import { SystemPrompt } from "./system"
  36. import { FileTime } from "../file/time"
  37. import { MessageV2 } from "./message-v2"
  38. import { Mode } from "./mode"
  39. import { LSP } from "../lsp"
  40. import { ReadTool } from "../tool/read"
  41. import { splitWhen } from "remeda"
  42. export namespace Session {
  // Module logger; `chat` derives a per-session child logger from it.
  const log = Log.create({ service: "session" })
  // Hard cap on requested output tokens; `chat` uses the smaller of this and the model's own output limit.
  const OUTPUT_TOKEN_MAX = 32000
  /**
   * Persisted metadata for one session, stored under `session/info/<id>`.
   */
  export const Info = z
    .object({
      id: Identifier.schema("session"),
      // Present when the session was created as a child of another session.
      parentID: Identifier.schema("session").optional(),
      // Public share link; only set while the session is shared.
      share: z.object({ url: z.string() }).optional(),
      title: z.string(),
      // Installation version that created the session.
      version: z.string(),
      time: z.object({ created: z.number(), updated: z.number() }),
      // Pending revert point, consumed by `chat` on the next user message.
      revert: z
        .object({
          messageID: z.string(),
          partID: z.string().optional(),
          snapshot: z.string().optional(),
        })
        .optional(),
    })
    .openapi({ ref: "Session" })
  export type Info = z.output<typeof Info>
  /**
   * Share credentials returned by `Share.create` and stored under `session/share/<id>`.
   * The `secret` is needed to remove the share later.
   */
  export const ShareInfo = z.object({ secret: z.string(), url: z.string() }).openapi({ ref: "SessionShare" })
  export type ShareInfo = z.output<typeof ShareInfo>
  /** Bus events published by the session module. */
  export const Event = {
    // Published by `create` and after every `update`.
    Updated: Bus.event("session.updated", z.object({ info: Info })),
    // Published by `remove` for the top-level session only (children are removed silently).
    Deleted: Bus.event("session.deleted", z.object({ info: Info })),
    Idle: Bus.event("session.idle", z.object({ sessionID: z.string() })),
    // Reuses the error shape stored on assistant messages.
    Error: Bus.event(
      "session.error",
      z.object({
        sessionID: z.string().optional(),
        error: MessageV2.Assistant.shape.error,
      }),
    ),
  }
  /**
   * In-memory session state registered with `App.state` under the "session" key.
   * On shutdown every pending abort controller is triggered.
   */
  const state = App.state(
    "session",
    () => ({
      // Session metadata cache keyed by session ID (filled by `get`).
      sessions: new Map<string, Info>(),
      messages: new Map<string, MessageV2.Info[]>(),
      // Abort controllers keyed by session ID (consumed by `abort`).
      pending: new Map<string, AbortController>(),
      // Chat requests received while the session was busy; `callback` resolves the waiting caller.
      queued: new Map<
        string,
        {
          input: ChatInput
          message: MessageV2.User
          parts: MessageV2.Part[]
          processed: boolean
          callback: (input: { info: MessageV2.Assistant; parts: MessageV2.Part[] }) => void
        }[]
      >(),
    }),
    async (current) => {
      for (const controller of current.pending.values()) controller.abort()
    },
  )
  /**
   * Creates, caches, and persists a new session, optionally as a child of `parentID`.
   *
   * Top-level sessions are shared automatically when `Flag.OPENCODE_AUTO_SHARE` is
   * set or the config has `share: "auto"`. Sharing runs in the background, and a
   * failure is logged rather than failing creation.
   *
   * @param parentID - ID of the parent session when creating a child session.
   * @returns The new session info (also published via `Event.Updated`).
   */
  export async function create(parentID?: string) {
    const now = Date.now()
    const result: Info = {
      id: Identifier.descending("session"),
      version: Installation.VERSION,
      parentID,
      title: (parentID ? "Child session - " : "New Session - ") + new Date(now).toISOString(),
      time: {
        created: now,
        updated: now,
      },
    }
    log.info("created", result)
    state().sessions.set(result.id, result)
    await Storage.writeJSON("session/info/" + result.id, result)
    const cfg = await Config.get()
    if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.share === "auto")) {
      // `share` already records `{ url }` on the session through `update`. Its return
      // value is the full `Share.create` result including `secret`, which must stay in
      // `session/share/<id>` and never be copied into the persisted/broadcast session info.
      share(result.id).catch((error) => {
        // best-effort: sharing problems must not fail session creation
        log.error("auto-share failed", { sessionID: result.id, error })
      })
    }
    Bus.publish(Event.Updated, {
      info: result,
    })
    return result
  }
  /**
   * Returns session info for `id`, served from the in-memory cache when present,
   * otherwise read from `session/info/<id>` and cached.
   */
  export async function get(id: string) {
    const cached = state().sessions.get(id)
    if (cached) return cached
    const loaded = await Storage.readJSON<Info>("session/info/" + id)
    state().sessions.set(id, loaded)
    return loaded as Info
  }
  /** Reads the stored share credentials (URL and secret) for a session. */
  export async function getShare(id: string) {
    const key = "session/share/" + id
    return Storage.readJSON<ShareInfo>(key)
  }
  /**
   * Shares a session. The steps are:
   * 1. Create the remote share and record its URL on the session.
   * 2. Store the credentials.
   * 3. Sync the session info, every message, and every part.
   *
   * @throws Error when sharing is disabled in configuration.
   * @returns For a newly shared session, the full `Share.create` result. If the session
   *   is already shared, its existing `share` field (URL only) is returned unchanged.
   */
  export async function share(id: string) {
    const { share: shareMode } = await Config.get()
    if (shareMode === "disabled") throw new Error("Sharing is disabled in configuration")
    const session = await get(id)
    if (session.share) return session.share
    const created = await Share.create(id)
    await update(id, (draft) => {
      draft.share = { url: created.url }
    })
    await Storage.writeJSON<ShareInfo>("session/share/" + id, created)
    await Share.sync("session/info/" + id, session)
    const history = await messages(id)
    for (const { info, parts } of history) {
      const messageKey = "session/message/" + id + "/" + info.id
      await Share.sync(messageKey, info)
      for (const part of parts) {
        await Share.sync("session/part/" + id + "/" + info.id + "/" + part.id, part)
      }
    }
    return created
  }
  /**
   * Stops sharing a session: deletes stored credentials, clears the session's `share`
   * field, and removes the remote share. No-op when no credentials are stored.
   */
  export async function unshare(id: string) {
    const credentials = await getShare(id)
    if (credentials) {
      await Storage.remove("session/share/" + id)
      await update(id, (draft) => {
        draft.share = undefined
      })
      await Share.remove(id, credentials.secret)
    }
  }
  /**
   * Mutates a session in place via `editor`, bumps `time.updated`, persists it,
   * and publishes `Event.Updated`.
   *
   * @returns The updated session, or undefined when `get` yields nothing.
   */
  export async function update(id: string, editor: (session: Info) => void) {
    const cache = state().sessions
    const target = await get(id)
    if (!target) return undefined
    editor(target)
    target.time.updated = Date.now()
    cache.set(id, target)
    await Storage.writeJSON("session/info/" + id, target)
    Bus.publish(Event.Updated, { info: target })
    return target
  }
  /**
   * Loads every message of a session together with its parts, sorted by message ID.
   */
  export async function messages(sessionID: string) {
    const entries: { info: MessageV2.Info; parts: MessageV2.Part[] }[] = []
    const keys = await Storage.list("session/message/" + sessionID)
    for (const key of keys) {
      const info = await Storage.readJSON<MessageV2.Info>(key)
      const parts = await getParts(sessionID, info.id)
      entries.push({ info, parts })
    }
    entries.sort((left, right) => (left.info.id > right.info.id ? 1 : -1))
    return entries
  }
  /** Reads one message's info (without its parts) from storage. */
  export async function getMessage(sessionID: string, messageID: string) {
    return Storage.readJSON<MessageV2.Info>(`session/message/${sessionID}/${messageID}`)
  }
  /** Loads all parts of one message, sorted by part ID. */
  export async function getParts(sessionID: string, messageID: string) {
    const keys = await Storage.list(`session/part/${sessionID}/${messageID}`)
    const parts: MessageV2.Part[] = []
    for (const key of keys) {
      parts.push(await Storage.readJSON<MessageV2.Part>(key))
    }
    return parts.sort((a, b) => (a.id > b.id ? 1 : -1))
  }
  /** Yields the info of every stored session, in storage listing order. */
  export async function* list() {
    const keys = await Storage.list("session/info")
    for (const key of keys) {
      yield get(path.basename(key, ".json"))
    }
  }
  /** Returns every stored session whose `parentID` equals `parentID`. */
  export async function children(parentID: string) {
    const found: Session.Info[] = []
    for (const key of await Storage.list("session/info")) {
      const candidate = await get(path.basename(key, ".json"))
      if (candidate.parentID === parentID) found.push(candidate)
    }
    return found
  }
  /**
   * Aborts the in-flight request for a session, if any.
   * @returns true when a pending controller was found and aborted.
   */
  export function abort(sessionID: string) {
    const { pending } = state()
    const controller = pending.get(sessionID)
    if (controller === undefined) return false
    controller.abort()
    pending.delete(sessionID)
    return true
  }
  /**
   * Deletes a session, recursively deleting its child sessions first.
   *
   * Aborts any in-flight request and unshares the session. It then deletes the
   * session info, messages, and parts from storage and evicts it from the in-memory
   * caches. Individual storage failures are ignored; any other error is logged, not thrown.
   *
   * @param sessionID - Session to delete.
   * @param emitEvent - Publish `Event.Deleted`; false for the recursive child removals.
   */
  export async function remove(sessionID: string, emitEvent = true) {
    try {
      abort(sessionID)
      const session = await get(sessionID)
      for (const child of await children(sessionID)) {
        await remove(child.id, false)
      }
      await unshare(sessionID).catch(() => {})
      await Storage.remove(`session/info/${sessionID}`).catch(() => {})
      await Storage.removeDir(`session/message/${sessionID}/`).catch(() => {})
      // Parts are stored in their own tree (see `updatePart`), so they must be
      // removed separately or they are leaked on disk.
      await Storage.removeDir(`session/part/${sessionID}/`).catch(() => {})
      state().sessions.delete(sessionID)
      state().messages.delete(sessionID)
      if (emitEvent) {
        Bus.publish(Event.Deleted, {
          info: session,
        })
      }
    } catch (e) {
      log.error(e)
    }
  }
  /** Persists a message's info and publishes `MessageV2.Event.Updated`. */
  async function updateMessage(msg: MessageV2.Info) {
    const key = ["session", "message", msg.sessionID, msg.id].join("/")
    await Storage.writeJSON(key, msg)
    Bus.publish(MessageV2.Event.Updated, { info: msg })
  }
  /** Persists a message part, publishes `MessageV2.Event.PartUpdated`, and returns the part. */
  async function updatePart(part: MessageV2.Part) {
    const key = `session/part/${part.sessionID}/${part.messageID}/${part.id}`
    await Storage.writeJSON(key, part)
    Bus.publish(MessageV2.Event.PartUpdated, { part })
    return part
  }
  /** Input accepted by `chat`: target session, model selection, and the user's message parts. */
  export const ChatInput = z.object({
    sessionID: Identifier.schema("session"),
    // Optional caller-chosen ID for the user message; `chat` generates one when omitted.
    messageID: Identifier.schema("message").optional(),
    providerID: z.string(),
    modelID: z.string(),
    // Mode name; `chat` falls back to "build".
    mode: z.string().optional(),
    // Per-tool overrides; a tool mapped to `false` is not offered to the model.
    tools: z.record(z.boolean()).optional(),
    // Text/file parts without message/session IDs; part `id` is optional.
    parts: z.array(
      z.discriminatedUnion("type", [
        MessageV2.TextPart.omit({ messageID: true, sessionID: true })
          .partial({ id: true })
          .openapi({ ref: "TextPartInput" }),
        MessageV2.FilePart.omit({ messageID: true, sessionID: true })
          .partial({ id: true })
          .openapi({ ref: "FilePartInput" }),
      ]),
    ),
  })
  export type ChatInput = z.infer<typeof ChatInput>
  /**
   * Sends a user message to the model and streams the assistant reply.
   *
   * Steps:
   * 1. Apply any pending revert.
   * 2. Persist the user message and its parts. `file:` URLs are expanded by reading
   *    the file, optionally restricted to a `start`/`end` line range.
   * 3. Queue the request if the session is already busy.
   * 4. Auto-summarize when the previous turn came close to the context limit.
   * 5. Stream the reply with the session's provider and MCP tools.
   *
   * @param input - Validated `ChatInput`.
   * @returns The final assistant message and its parts. A queued request resolves with
   *   the result of the run that eventually processes it.
   */
  export async function chat(
    input: z.infer<typeof ChatInput>,
  ): Promise<{ info: MessageV2.Assistant; parts: MessageV2.Part[] }> {
    const l = log.clone().tag("session", input.sessionID)
    l.info("chatting")
    const inputMode = input.mode ?? "build"

    // Apply a pending revert BEFORE persisting the new user message. The new message's
    // ID sorts after the revert point, so trimming afterwards would delete the very
    // message we are about to answer. Clearing `revert` afterwards keeps the trim from
    // being re-applied on later turns.
    const session = await get(input.sessionID)
    if (session.revert) {
      const messageID = session.revert.messageID
      const [preserve, remove] = splitWhen(await messages(input.sessionID), (x) => x.info.id === messageID)
      for (const msg of remove) {
        await Storage.remove(`session/message/${input.sessionID}/${msg.info.id}`)
        // parts live in their own tree (see updatePart); drop them with their message
        await Storage.removeDir(`session/part/${input.sessionID}/${msg.info.id}/`).catch(() => {})
        await Bus.publish(MessageV2.Event.Removed, { sessionID: input.sessionID, messageID: msg.info.id })
      }
      const last = preserve.at(-1)
      if (session.revert.partID && last) {
        const partID = session.revert.partID
        const [, removeParts] = splitWhen(last.parts, (x) => x.id === partID)
        for (const part of removeParts) {
          await Storage.remove(`session/part/${input.sessionID}/${last.info.id}/${part.id}`)
          await Bus.publish(MessageV2.Event.PartRemoved, {
            messageID: last.info.id,
            partID: part.id,
          })
        }
      }
      await update(input.sessionID, (draft) => {
        draft.revert = undefined
      })
    }

    const userMsg: MessageV2.Info = {
      id: input.messageID ?? Identifier.ascending("message"),
      role: "user",
      sessionID: input.sessionID,
      time: {
        created: Date.now(),
      },
    }
    const app = App.info()
    const userParts = await Promise.all(
      input.parts.map(async (part): Promise<MessageV2.Part[]> => {
        if (part.type === "file") {
          const url = new URL(part.url)
          switch (url.protocol) {
            case "file:": {
              // have to normalize, symbol search returns absolute paths
              // Decode the pathname since URL constructor doesn't automatically decode it
              const pathname = decodeURIComponent(url.pathname)
              const relativePath = pathname.replace(app.path.cwd, ".")
              const filePath = path.join(app.path.cwd, relativePath)
              if (part.mime === "text/plain") {
                let offset: number | undefined = undefined
                let limit: number | undefined = undefined
                const range = {
                  start: url.searchParams.get("start"),
                  end: url.searchParams.get("end"),
                }
                if (range.start != null) {
                  // LSP is queried with the original URI (query string stripped)
                  const fileURI = part.url.split("?")[0]
                  const start = parseInt(range.start)
                  let end = range.end ? parseInt(range.end) : undefined
                  // some LSP servers (eg, gopls) don't give full range in
                  // workspace/symbol searches, so we'll try to find the
                  // symbol in the document to get the full range
                  if (start === end) {
                    const symbols = await LSP.documentSymbol(fileURI)
                    for (const symbol of symbols) {
                      let symbolRange: LSP.Range | undefined
                      if ("range" in symbol) {
                        symbolRange = symbol.range
                      } else if ("location" in symbol) {
                        symbolRange = symbol.location.range
                      }
                      // compare directly: a truthiness check would skip symbols on line 0
                      if (symbolRange && symbolRange.start?.line === start) {
                        end = symbolRange.end?.line ?? start
                        break
                      }
                    }
                  }
                  // Apply the requested (possibly widened) range in every case; an explicit
                  // start..end or start-only range must not fall back to reading the whole file.
                  offset = Math.max(start - 2, 0)
                  if (end) {
                    limit = end - offset + 2
                  }
                }
                const args = { filePath, offset, limit }
                const result = await ReadTool.execute(args, {
                  sessionID: input.sessionID,
                  abort: new AbortController().signal,
                  messageID: userMsg.id,
                  metadata: async () => {},
                })
                return [
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
                  },
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: result.output,
                  },
                  {
                    ...part,
                    id: part.id ?? Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                  },
                ]
              }
              // Non-text files are inlined as a base64 data URL.
              const file = Bun.file(filePath)
              FileTime.read(input.sessionID, filePath)
              return [
                {
                  id: Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "text",
                  // JSON.stringify escapes quotes/backslashes that hand-built JSON would not
                  text: `Called the Read tool with the following input: ${JSON.stringify({ filePath: pathname })}`,
                  synthetic: true,
                },
                {
                  id: part.id ?? Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "file",
                  url: `data:${part.mime};base64,` + Buffer.from(await file.bytes()).toString("base64"),
                  mime: part.mime,
                  filename: part.filename!,
                  source: part.source,
                },
              ]
            }
          }
        }
        return [
          {
            id: Identifier.ascending("part"),
            ...part,
            messageID: userMsg.id,
            sessionID: input.sessionID,
          },
        ]
      }),
    ).then((x) => x.flat())
    if (inputMode === "plan")
      userParts.push({
        id: Identifier.ascending("part"),
        messageID: userMsg.id,
        sessionID: input.sessionID,
        type: "text",
        text: PROMPT_PLAN,
        synthetic: true,
      })
    await updateMessage(userMsg)
    for (const part of userParts) {
      await updatePart(part)
    }
    // mark session as updated since a message has been added to it
    await update(input.sessionID, (_draft) => {})
    // A run is already streaming for this session: park the request; the running
    // stream picks it up in `prepareStep` and resolves `callback` when it finishes.
    if (isLocked(input.sessionID)) {
      return new Promise((resolve) => {
        const queue = state().queued.get(input.sessionID) ?? []
        queue.push({
          input: input,
          message: userMsg,
          parts: userParts,
          processed: false,
          callback: resolve,
        })
        state().queued.set(input.sessionID, queue)
      })
    }
    const model = await Provider.getModel(input.providerID, input.modelID)
    let msgs = await messages(input.sessionID)
    const previous = msgs.filter((x) => x.info.role === "assistant").at(-1)?.info as MessageV2.Assistant
    const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
    // auto summarize if too long
    if (previous && previous.tokens) {
      const tokens =
        previous.tokens.input + previous.tokens.cache.read + previous.tokens.cache.write + previous.tokens.output
      if (model.info.limit.context && tokens > Math.max((model.info.limit.context - outputLimit) * 0.9, 0)) {
        await summarize({
          sessionID: input.sessionID,
          providerID: input.providerID,
          modelID: input.modelID,
        })
        // NOTE(review): re-entering chat persists the user message again (with a new ID
        // unless input.messageID was given) so it sorts after the summary — confirm the
        // leftover pre-summary copy is intended.
        return chat(input)
      }
    }
    using abort = lock(input.sessionID)
    // Only send history from the most recent summary onwards.
    const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
    if (lastSummary) msgs = msgs.filter((msg) => msg.info.id >= lastSummary.info.id)
    // First message of a top-level session: generate a title in the background.
    if (msgs.length === 1 && !session.parentID) {
      const small = (await Provider.getSmallModel(input.providerID)) ?? model
      generateText({
        maxOutputTokens: small.info.reasoning ? 1024 : 20,
        providerOptions: {
          [input.providerID]: small.info.options,
        },
        messages: [
          ...SystemPrompt.title(input.providerID).map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          ...MessageV2.toModelMessage([
            {
              info: {
                id: Identifier.ascending("message"),
                role: "user",
                sessionID: input.sessionID,
                time: {
                  created: Date.now(),
                },
              },
              parts: userParts,
            },
          ]),
        ],
        model: small.language,
      })
        .then((result) => {
          if (result.text)
            return Session.update(input.sessionID, (draft) => {
              draft.title = result.text
            })
        })
        .catch(() => {})
    }
    const mode = await Mode.get(inputMode)
    let system = input.providerID === "anthropic" ? [PROMPT_ANTHROPIC_SPOOF.trim()] : []
    system.push(...(mode.prompt ? [mode.prompt] : SystemPrompt.provider(input.modelID)))
    system.push(...(await SystemPrompt.environment()))
    system.push(...(await SystemPrompt.custom()))
    // max 2 system prompt messages for caching purposes
    const [first, ...rest] = system
    system = [first, rest.join("\n")]
    const assistantMsg: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      system,
      mode: inputMode,
      path: {
        cwd: app.path.cwd,
        root: app.path.root,
      },
      cost: 0,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.modelID,
      providerID: input.providerID,
      time: {
        created: Date.now(),
      },
      sessionID: input.sessionID,
    }
    await updateMessage(assistantMsg)
    const tools: Record<string, AITool> = {}
    const processor = createProcessor(assistantMsg, model.info)
    for (const item of await Provider.tools(input.providerID)) {
      if (mode.tools[item.id] === false) continue
      if (input.tools?.[item.id] === false) continue
      // child sessions may not spawn further tasks
      if (session.parentID && item.id === "task") continue
      tools[item.id] = tool({
        id: item.id as any,
        description: item.description,
        inputSchema: item.parameters as ZodSchema,
        async execute(args, options) {
          const result = await item.execute(args, {
            sessionID: input.sessionID,
            abort: abort.signal,
            messageID: assistantMsg.id,
            metadata: async (val) => {
              const match = processor.partFromToolCall(options.toolCallId)
              if (match && match.state.status === "running") {
                await updatePart({
                  ...match,
                  state: {
                    title: val.title,
                    metadata: val.metadata,
                    status: "running",
                    input: args,
                    time: {
                      start: Date.now(),
                    },
                  },
                })
              }
            },
          })
          return result
        },
        toModelOutput(result) {
          return {
            type: "text",
            value: result.output,
          }
        },
      })
    }
    for (const [key, item] of Object.entries(await MCP.tools())) {
      if (mode.tools[key] === false) continue
      const execute = item.execute
      if (!execute) continue
      // Flatten MCP content into plain text so it matches the built-in tool output shape.
      item.execute = async (args, opts) => {
        const result = await execute(args, opts)
        const output = result.content
          .filter((x: any) => x.type === "text")
          .map((x: any) => x.text)
          .join("\n\n")
        return {
          output,
        }
      }
      item.toModelOutput = (result) => {
        return {
          type: "text",
          value: result.output,
        }
      }
      tools[key] = item
    }
    const stream = streamText({
      onError() {},
      // Between steps, splice in user messages queued while this run was streaming,
      // closing the current assistant message and starting a fresh one after them.
      async prepareStep({ messages }) {
        const queue = (state().queued.get(input.sessionID) ?? []).filter((x) => !x.processed)
        if (queue.length) {
          for (const item of queue) {
            if (item.processed) continue
            messages.push(
              ...MessageV2.toModelMessage([
                {
                  info: item.message,
                  parts: item.parts,
                },
              ]),
            )
            item.processed = true
          }
          assistantMsg.time.completed = Date.now()
          await updateMessage(assistantMsg)
          Object.assign(assistantMsg, {
            id: Identifier.ascending("message"),
            role: "assistant",
            system,
            path: {
              cwd: app.path.cwd,
              root: app.path.root,
            },
            cost: 0,
            tokens: {
              input: 0,
              output: 0,
              reasoning: 0,
              cache: { read: 0, write: 0 },
            },
            modelID: input.modelID,
            providerID: input.providerID,
            time: {
              created: Date.now(),
            },
            sessionID: input.sessionID,
          })
          await updateMessage(assistantMsg)
        }
        return {
          messages,
        }
      },
      maxRetries: 10,
      maxOutputTokens: outputLimit,
      abortSignal: abort.signal,
      stopWhen: stepCountIs(1000),
      providerOptions: {
        [input.providerID]: model.info.options,
      },
      messages: [
        ...system.map(
          (x): ModelMessage => ({
            role: "system",
            content: x,
          }),
        ),
        ...MessageV2.toModelMessage(msgs),
      ],
      temperature: model.info.temperature ? 0 : undefined,
      tools: model.info.tool_call === false ? undefined : tools,
      model: wrapLanguageModel({
        model: model.language,
        middleware: [
          {
            async transformParams(args) {
              if (args.type === "stream") {
                // @ts-expect-error
                args.params.prompt = ProviderTransform.message(args.params.prompt, input.providerID, input.modelID)
              }
              return args.params
            },
          },
        ],
      }),
    })
    const result = await processor.process(stream)
    const queued = state().queued.get(input.sessionID) ?? []
    const unprocessed = queued.find((x) => !x.processed)
    if (unprocessed) {
      unprocessed.processed = true
      // NOTE(review): this re-runs chat on the queued input, which persists that user
      // message and its parts a second time (parts without an explicit id get fresh IDs)
      // — confirm whether the duplicate is intended.
      return chat(unprocessed.input)
    }
    for (const item of queued) {
      item.callback(result)
    }
    state().queued.delete(input.sessionID)
    return result
  }
  767. function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
  768. const toolCalls: Record<string, MessageV2.ToolPart> = {}
  769. return {
  770. partFromToolCall(toolCallID: string) {
  771. return toolCalls[toolCallID]
  772. },
  773. async process(stream: StreamTextResult<Record<string, AITool>, never>) {
  774. try {
  775. let currentText: MessageV2.TextPart | undefined
  776. for await (const value of stream.fullStream) {
  777. log.info("part", {
  778. type: value.type,
  779. })
  780. switch (value.type) {
  781. case "start":
  782. const snapshot = await Snapshot.create()
  783. if (snapshot)
  784. await updatePart({
  785. id: Identifier.ascending("part"),
  786. messageID: assistantMsg.id,
  787. sessionID: assistantMsg.sessionID,
  788. type: "snapshot",
  789. snapshot,
  790. })
  791. break
  792. case "tool-input-start":
  793. const part = await updatePart({
  794. id: Identifier.ascending("part"),
  795. messageID: assistantMsg.id,
  796. sessionID: assistantMsg.sessionID,
  797. type: "tool",
  798. tool: value.toolName,
  799. callID: value.id,
  800. state: {
  801. status: "pending",
  802. },
  803. })
  804. toolCalls[value.id] = part as MessageV2.ToolPart
  805. break
  806. case "tool-input-delta":
  807. break
  808. case "tool-call": {
  809. const match = toolCalls[value.toolCallId]
  810. if (match) {
  811. const part = await updatePart({
  812. ...match,
  813. state: {
  814. status: "running",
  815. input: value.input,
  816. time: {
  817. start: Date.now(),
  818. },
  819. },
  820. })
  821. toolCalls[value.toolCallId] = part as MessageV2.ToolPart
  822. }
  823. break
  824. }
  825. case "tool-result": {
  826. const match = toolCalls[value.toolCallId]
  827. if (match && match.state.status === "running") {
  828. await updatePart({
  829. ...match,
  830. state: {
  831. status: "completed",
  832. input: value.input,
  833. output: value.output.output,
  834. metadata: value.output.metadata,
  835. title: value.output.title,
  836. time: {
  837. start: match.state.time.start,
  838. end: Date.now(),
  839. },
  840. },
  841. })
  842. delete toolCalls[value.toolCallId]
  843. const snapshot = await Snapshot.create()
  844. if (snapshot)
  845. await updatePart({
  846. id: Identifier.ascending("part"),
  847. messageID: assistantMsg.id,
  848. sessionID: assistantMsg.sessionID,
  849. type: "snapshot",
  850. snapshot,
  851. })
  852. }
  853. break
  854. }
  855. case "tool-error": {
  856. const match = toolCalls[value.toolCallId]
  857. if (match && match.state.status === "running") {
  858. await updatePart({
  859. ...match,
  860. state: {
  861. status: "error",
  862. input: value.input,
  863. error: (value.error as any).toString(),
  864. time: {
  865. start: match.state.time.start,
  866. end: Date.now(),
  867. },
  868. },
  869. })
  870. delete toolCalls[value.toolCallId]
  871. const snapshot = await Snapshot.create()
  872. if (snapshot)
  873. await updatePart({
  874. id: Identifier.ascending("part"),
  875. messageID: assistantMsg.id,
  876. sessionID: assistantMsg.sessionID,
  877. type: "snapshot",
  878. snapshot,
  879. })
  880. }
  881. break
  882. }
  883. case "error":
  884. throw value.error
  885. case "start-step":
  886. await updatePart({
  887. id: Identifier.ascending("part"),
  888. messageID: assistantMsg.id,
  889. sessionID: assistantMsg.sessionID,
  890. type: "step-start",
  891. })
  892. break
  893. case "finish-step":
  894. const usage = getUsage(model, value.usage, value.providerMetadata)
  895. assistantMsg.cost += usage.cost
  896. assistantMsg.tokens = usage.tokens
  897. await updatePart({
  898. id: Identifier.ascending("part"),
  899. messageID: assistantMsg.id,
  900. sessionID: assistantMsg.sessionID,
  901. type: "step-finish",
  902. tokens: usage.tokens,
  903. cost: usage.cost,
  904. })
  905. await updateMessage(assistantMsg)
  906. break
  907. case "text-start":
  908. currentText = {
  909. id: Identifier.ascending("part"),
  910. messageID: assistantMsg.id,
  911. sessionID: assistantMsg.sessionID,
  912. type: "text",
  913. text: "",
  914. time: {
  915. start: Date.now(),
  916. },
  917. }
  918. break
  919. case "text":
  920. if (currentText) {
  921. currentText.text += value.text
  922. await updatePart(currentText)
  923. }
  924. break
  925. case "text-end":
  926. if (currentText && currentText.text) {
  927. currentText.time = {
  928. start: Date.now(),
  929. end: Date.now(),
  930. }
  931. await updatePart(currentText)
  932. }
  933. currentText = undefined
  934. break
  935. case "finish":
  936. assistantMsg.time.completed = Date.now()
  937. await updateMessage(assistantMsg)
  938. break
  939. default:
  940. log.info("unhandled", {
  941. ...value,
  942. })
  943. continue
  944. }
  945. }
  946. } catch (e) {
  947. log.error("", {
  948. error: e,
  949. })
  950. switch (true) {
  951. case e instanceof DOMException && e.name === "AbortError":
  952. assistantMsg.error = new MessageV2.AbortedError(
  953. { message: e.message },
  954. {
  955. cause: e,
  956. },
  957. ).toObject()
  958. break
  959. case MessageV2.OutputLengthError.isInstance(e):
  960. assistantMsg.error = e
  961. break
  962. case LoadAPIKeyError.isInstance(e):
  963. assistantMsg.error = new MessageV2.AuthError(
  964. {
  965. providerID: model.id,
  966. message: e.message,
  967. },
  968. { cause: e },
  969. ).toObject()
  970. break
  971. case e instanceof Error:
  972. assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
  973. break
  974. default:
  975. assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
  976. }
  977. Bus.publish(Event.Error, {
  978. sessionID: assistantMsg.sessionID,
  979. error: assistantMsg.error,
  980. })
  981. }
  982. const p = await getParts(assistantMsg.sessionID, assistantMsg.id)
  983. for (const part of p) {
  984. if (part.type === "tool" && part.state.status !== "completed") {
  985. updatePart({
  986. ...part,
  987. state: {
  988. status: "error",
  989. error: "Tool execution aborted",
  990. time: {
  991. start: Date.now(),
  992. end: Date.now(),
  993. },
  994. input: {},
  995. },
  996. })
  997. }
  998. }
  999. assistantMsg.time.completed = Date.now()
  1000. await updateMessage(assistantMsg)
  1001. return { info: assistantMsg, parts: p }
  1002. },
  1003. }
  1004. }
  /**
   * Field shape for `RevertInput`: the session being reverted, the message the
   * revert targets, and optionally a single part inside that message.
   */
  const revertInputShape = {
    sessionID: Identifier.schema("session"),
    messageID: Identifier.schema("message"),
    partID: Identifier.schema("part").optional(),
  }

  /** Validated arguments for `revert`; omitting `partID` targets the whole message. */
  export const RevertInput = z.object(revertInputShape)
  export type RevertInput = z.infer<typeof RevertInput>
  /**
   * Reverts a session to a message, or to a specific part of a message.
   *
   * History is walked in order while remembering the latest user message and the
   * latest snapshot part seen so far. The target is the first part of
   * `input.messageID` when no `partID` is given, otherwise the part whose id equals
   * `input.partID`. On reaching it:
   *  - if no text/tool part precedes the target inside its message, the revert is
   *    widened to the whole message;
   *  - the snapshot stored by an earlier, still-active revert is reused, otherwise a
   *    snapshot of the current state is created, so `unrevert` can restore it;
   *  - the latest snapshot seen up to the target (if any) is restored;
   *  - `revert` is recorded on the session, pointing whole-message reverts at the
   *    latest user message seen.
   *
   * @returns the updated session (as returned by `update`), or `undefined` when the
   *   target is not found.
   */
  export async function revert(input: RevertInput) {
    const history = await messages(input.sessionID)
    const session = await get(input.sessionID)
    const meaningfulTypes = new Set<string>(["text", "tool"])
    let latestUser: MessageV2.User | undefined
    let latestSnapshot: MessageV2.SnapshotPart | undefined

    for (const message of history) {
      if (message.info.role === "user") latestUser = message.info
      // Tracks whether a text/tool part appears before the current part in this message.
      let hasMeaningfulBefore = false

      for (const part of message.parts) {
        if (part.type === "snapshot") latestSnapshot = part

        const wholeMessageTarget = !input.partID && message.info.id === input.messageID
        if (!wholeMessageTarget && part.id !== input.partID) {
          if (meaningfulTypes.has(part.type)) hasMeaningfulBefore = true
          continue
        }

        // Nothing worth keeping precedes the target: treat it as reverting the whole message.
        const partID = hasMeaningfulBefore ? input.partID : undefined
        // Keep the pre-revert snapshot from an active revert so unrevert returns to the original state.
        const snapshot = session.revert?.snapshot ?? (await Snapshot.create())
        log.info("revert snapshot", { snapshot })
        if (latestSnapshot) await Snapshot.restore(latestSnapshot.snapshot)
        return await update(input.sessionID, (draft) => {
          draft.revert = {
            // Whole-message reverts jump back to the latest user message seen.
            messageID: !partID && latestUser ? latestUser.id : message.info.id,
            partID,
            snapshot,
          }
        })
      }
    }
  }
  /**
   * Undoes an active revert. It restores the snapshot recorded when the revert was
   * made (if one was recorded) and clears `revert` on the session.
   *
   * @returns the session untouched when nothing is reverted, otherwise the updated session.
   */
  export async function unrevert(input: { sessionID: string }) {
    log.info("unreverting", input)
    const session = await get(input.sessionID)
    const active = session.revert
    if (!active) return session
    if (active.snapshot) await Snapshot.restore(active.snapshot)
    return await update(input.sessionID, (draft) => {
      draft.revert = undefined
    })
  }
  /**
   * Asks the model to summarize the conversation. The answer is streamed into a new
   * assistant message flagged `summary: true`.
   *
   * Only history from the most recent summary onward (inclusive) is sent, so turns
   * before it are not replayed. The session lock is held for the whole call and is
   * released when this function's scope exits.
   *
   * @throws BusyError if the session is already locked.
   * @returns the result of `processor.process` for the summary stream.
   */
  export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {
    using abort = lock(input.sessionID)
    const history = await messages(input.sessionID)
    const previousSummary = history.findLast(
      (entry) => entry.info.role === "assistant" && entry.info.summary === true,
    )
    const relevant = history.filter(
      (entry) => previousSummary === undefined || entry.info.id >= previousSummary.info.id,
    )
    const model = await Provider.getModel(input.providerID, input.modelID)
    const app = App.info()
    const system = SystemPrompt.summarize(input.providerID)
    const next: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      sessionID: input.sessionID,
      system,
      mode: "build",
      path: {
        cwd: app.path.cwd,
        root: app.path.root,
      },
      summary: true,
      cost: 0,
      modelID: input.modelID,
      providerID: input.providerID,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      time: {
        created: Date.now(),
      },
    }
    await updateMessage(next)
    const processor = createProcessor(next, model.info)

    const instruction =
      "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
    const prompt: ModelMessage[] = [
      ...system.map((content): ModelMessage => ({ role: "system", content })),
      ...MessageV2.toModelMessage(relevant),
      { role: "user", content: [{ type: "text", text: instruction }] },
    ]
    const stream = streamText({
      maxRetries: 10,
      abortSignal: abort.signal,
      model: model.language,
      messages: prompt,
    })
    // `return await` keeps the `using` lock held until processing has finished.
    return await processor.process(stream)
  }
  /** Reports whether a lock entry currently exists for `sessionID` in the pending registry. */
  function isLocked(sessionID: string) {
    const { pending } = state()
    return pending.has(sessionID)
  }
  /**
   * Acquires the per-session lock and returns a disposable handle meant for `using`.
   *
   * A fresh AbortController is registered in `state().pending` under `sessionID`, and
   * the handle exposes its `signal`. Disposing the handle releases the lock and
   * publishes `Event.Idle`. It only does so while the registry entry is absent or still
   * this handle's own controller. If the entry was removed elsewhere and the session
   * was re-locked by another caller, a stale handle must not release or announce the
   * newer lock.
   *
   * @throws BusyError if the session is already locked.
   */
  function lock(sessionID: string) {
    log.info("locking", { sessionID })
    const { pending } = state()
    if (pending.has(sessionID)) throw new BusyError(sessionID)
    const controller = new AbortController()
    pending.set(sessionID, controller)
    return {
      signal: controller.signal,
      [Symbol.dispose]() {
        log.info("unlocking", { sessionID })
        const registry = state().pending
        const holder = registry.get(sessionID)
        // The controller lives in a shared registry so other code can reach it; if a
        // different controller now owns the session, that owner releases it (and
        // publishes Idle) when it finishes.
        if (holder !== undefined && holder !== controller) return
        registry.delete(sessionID)
        Bus.publish(Event.Idle, {
          sessionID,
        })
      },
    }
  }
  /**
   * Converts the AI SDK usage of one step into the session's token record and its
   * cost. Prices come from `model.cost` and are applied per million tokens.
   *
   * - Missing counts are treated as 0.
   * - Reasoning tokens are recorded from `usage.reasoningTokens` but are not priced
   *   separately.
   * - Cache-write tokens come only from provider metadata (Anthropic
   *   `cacheCreationInputTokens`, Bedrock `usage.cacheWriteInputTokens`); a value that
   *   is not a finite number counts as 0.
   *
   * NOTE(review): if a provider's `inputTokens` already includes cached reads, those
   * tokens are priced twice (input + cache_read). Confirm per provider.
   */
  function getUsage(model: ModelsDev.Model, usage: LanguageModelUsage, metadata?: ProviderMetadata) {
    const cacheWrite =
      metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
      // @ts-expect-error -- `usage` is typed as a generic JSON value, so indexing into it needs this escape hatch
      metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"]
    const tokens = {
      input: usage.inputTokens ?? 0,
      output: usage.outputTokens ?? 0,
      reasoning: usage.reasoningTokens ?? 0,
      cache: {
        write: typeof cacheWrite === "number" && Number.isFinite(cacheWrite) ? cacheWrite : 0,
        read: usage.cachedInputTokens ?? 0,
      },
    }
    // Decimal avoids float drift when summing many tiny per-token prices.
    const perMillion = (count: number, price: number) => new Decimal(count).mul(price).div(1_000_000)
    return {
      cost: new Decimal(0)
        .add(perMillion(tokens.input, model.cost.input))
        .add(perMillion(tokens.output, model.cost.output))
        .add(perMillion(tokens.cache.read, model.cost.cache_read ?? 0))
        .add(perMillion(tokens.cache.write, model.cost.cache_write ?? 0))
        .toNumber(),
      tokens,
    }
  }
  /**
   * Signals that a session is already locked; `lock` throws it when the session has
   * a pending entry. `sessionID` identifies the busy session.
   */
  export class BusyError extends Error {
    constructor(public readonly sessionID: string) {
      super(`Session ${sessionID} is busy`)
      // Without this, `name` stays "Error", and logs / toString() cannot tell this error apart.
      this.name = "BusyError"
    }
  }
  /**
   * Sends the initialization prompt as a chat message in the given session, then
   * calls `App.initialize()`. The prompt is `PROMPT_INITIALIZE` with every `${path}`
   * placeholder replaced by the app root.
   */
  export async function initialize(input: {
    sessionID: string
    modelID: string
    providerID: string
    messageID: string
  }) {
    const app = App.info()
    // A replacer function inserts the path verbatim. A plain string replacement would
    // expand `$&`, `$$`, `` $` `` and `$'` sequences occurring inside the path.
    const text = PROMPT_INITIALIZE.replaceAll("${path}", () => app.path.root)
    await Session.chat({
      sessionID: input.sessionID,
      messageID: input.messageID,
      providerID: input.providerID,
      modelID: input.modelID,
      parts: [
        {
          id: Identifier.ascending("part"),
          type: "text",
          text,
        },
      ],
    })
    await App.initialize()
  }
  1180. }