index.ts 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116
  1. import path from "path"
  2. import { Decimal } from "decimal.js"
  3. import { z, ZodSchema } from "zod"
  4. import {
  5. generateText,
  6. LoadAPIKeyError,
  7. streamText,
  8. tool,
  9. wrapLanguageModel,
  10. type Tool as AITool,
  11. type LanguageModelUsage,
  12. type ProviderMetadata,
  13. type ModelMessage,
  14. stepCountIs,
  15. type StreamTextResult,
  16. } from "ai"
  17. import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
  18. import PROMPT_PLAN from "../session/prompt/plan.txt"
  19. import PROMPT_ANTHROPIC_SPOOF from "../session/prompt/anthropic_spoof.txt"
  20. import { App } from "../app/app"
  21. import { Bus } from "../bus"
  22. import { Config } from "../config/config"
  23. import { Flag } from "../flag/flag"
  24. import { Identifier } from "../id/id"
  25. import { Installation } from "../installation"
  26. import { MCP } from "../mcp"
  27. import { Provider } from "../provider/provider"
  28. import { ProviderTransform } from "../provider/transform"
  29. import type { ModelsDev } from "../provider/models"
  30. import { Share } from "../share/share"
  31. import { Snapshot } from "../snapshot"
  32. import { Storage } from "../storage/storage"
  33. import { Log } from "../util/log"
  34. import { NamedError } from "../util/error"
  35. import { SystemPrompt } from "./system"
  36. import { FileTime } from "../file/time"
  37. import { MessageV2 } from "./message-v2"
  38. import { Mode } from "./mode"
  39. import { LSP } from "../lsp"
  40. import { ReadTool } from "../tool/read"
  41. export namespace Session {
  /** Logger tagged with the session service name. */
  const log = Log.create({ service: "session" })

  /** Ceiling applied to the output-token budget requested from a model. */
  const OUTPUT_TOKEN_MAX = 32_000

  /**
   * Schema of a persisted session record.
   *
   * `parentID` is set for child sessions, `share` holds the public URL once
   * the session has been shared, and `revert` marks a pending revert point
   * (message id, part index and an optional snapshot reference).
   */
  export const Info = z
    .object({
      id: Identifier.schema("session"),
      parentID: Identifier.schema("session").optional(),
      share: z.object({ url: z.string() }).optional(),
      title: z.string(),
      version: z.string(),
      time: z.object({ created: z.number(), updated: z.number() }),
      revert: z
        .object({
          messageID: z.string(),
          part: z.number(),
          snapshot: z.string().optional(),
        })
        .optional(),
    })
    .openapi({ ref: "Session" })
  export type Info = z.output<typeof Info>

  /** Schema of the share record stored alongside a session: public URL plus secret. */
  export const ShareInfo = z
    .object({ secret: z.string(), url: z.string() })
    .openapi({ ref: "SessionShare" })
  export type ShareInfo = z.output<typeof ShareInfo>
  /**
   * Bus events emitted by the session module.
   *
   * - `Updated` / `Deleted` carry the full session record.
   * - `Idle` carries only the session id.
   * - `Error` carries an optional session id and the assistant error payload.
   */
  export const Event = {
    Updated: Bus.event("session.updated", z.object({ info: Info })),
    Deleted: Bus.event("session.deleted", z.object({ info: Info })),
    Idle: Bus.event("session.idle", z.object({ sessionID: z.string() })),
    Error: Bus.event(
      "session.error",
      z.object({
        sessionID: z.string().optional(),
        error: MessageV2.Assistant.shape.error,
      }),
    ),
  }
  /**
   * App-scoped session state: a cache of session records, a message map, and
   * the abort controllers of sessions that currently hold a lock.
   * The second callback aborts every pending controller.
   */
  const state = App.state(
    "session",
    () => ({
      sessions: new Map<string, Info>(),
      messages: new Map<string, MessageV2.Info[]>(),
      pending: new Map<string, AbortController>(),
    }),
    async (current) => {
      for (const controller of current.pending.values()) {
        controller.abort()
      }
    },
  )
  /**
   * Creates a new session, caches it, persists it and publishes `Event.Updated`.
   *
   * Root sessions (no `parentID`) are shared in the background when the
   * `OPENCODE_AUTO_SHARE` flag is set or the config has `share: "auto"`.
   * Sharing is best-effort: failures never fail session creation.
   *
   * @param parentID Optional id of the parent session; marks the new session as a child.
   * @returns The newly created session record.
   */
  export async function create(parentID?: string) {
    const result: Info = {
      id: Identifier.descending("session"),
      version: Installation.VERSION,
      parentID,
      title: (parentID ? "Child session - " : "New Session - ") + new Date().toISOString(),
      time: {
        created: Date.now(),
        updated: Date.now(),
      },
    }
    log.info("created", result)
    state().sessions.set(result.id, result)
    await Storage.writeJSON("session/info/" + result.id, result)
    const cfg = await Config.get()
    if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.share === "auto")) {
      void share(result.id)
        .then((shared) =>
          // Store only the public URL: the share secret must not end up in the
          // session record, which is published on the bus and synced.
          // Returning the promise keeps a failed update inside the catch below.
          update(result.id, (draft) => {
            draft.share = { url: shared.url }
          }),
        )
        .catch(() => {
          // Silently ignore sharing errors during session creation
        })
    }
    Bus.publish(Event.Updated, {
      info: result,
    })
    return result
  }
  /**
   * Returns the session record for `id`, preferring the in-memory cache and
   * falling back to storage (the loaded value is then cached).
   */
  export async function get(id: string) {
    const cached = state().sessions.get(id)
    if (cached) return cached

    const loaded = await Storage.readJSON<Info>(`session/info/${id}`)
    state().sessions.set(id, loaded)
    return loaded as Info
  }
  /** Reads the stored share record (URL and secret) for the given session id. */
  export async function getShare(id: string) {
    const key = `session/share/${id}`
    return Storage.readJSON<ShareInfo>(key)
  }
  /**
   * Shares a session and syncs its current contents to the share backend.
   *
   * If the session already has a share, that existing `{ url }` is returned.
   * Otherwise a share is created, the URL is written to the session record,
   * the share record is persisted, and the session info, every message and
   * every part are synced one by one.
   *
   * @throws Error when sharing is disabled in configuration.
   */
  export async function share(id: string) {
    const cfg = await Config.get()
    if (cfg.share === "disabled") throw new Error("Sharing is disabled in configuration")

    const session = await get(id)
    if (session.share) return session.share

    const created = await Share.create(id)
    await update(id, (draft) => {
      draft.share = { url: created.url }
    })
    await Storage.writeJSON<ShareInfo>(`session/share/${id}`, created)
    await Share.sync(`session/info/${id}`, session)

    const history = await messages(id)
    for (const entry of history) {
      const messageKey = `${id}/${entry.info.id}`
      await Share.sync(`session/message/${messageKey}`, entry.info)
      for (const part of entry.parts) {
        await Share.sync(`session/part/${messageKey}/${part.id}`, part)
      }
    }
    return created
  }
  /**
   * Removes the share for a session: deletes the stored share record, clears
   * `share` on the session record, then removes the remote share using its secret.
   * Does nothing when no share record is found.
   */
  export async function unshare(id: string) {
    const existing = await getShare(id)
    if (!existing) return

    await Storage.remove(`session/share/${id}`)
    await update(id, (draft) => {
      draft.share = undefined
    })
    await Share.remove(id, existing.secret)
  }
  /**
   * Applies `editor` to the session record in place, bumps `time.updated`,
   * re-caches and persists the record, and publishes `Event.Updated`.
   *
   * @param id Session id.
   * @param editor Callback that mutates the session record.
   * @returns The updated record, or `undefined` when the session was not found.
   */
  export async function update(id: string, editor: (session: Info) => void) {
    const cache = state().sessions
    const current = await get(id)
    if (!current) return

    editor(current)
    current.time.updated = Date.now()
    cache.set(id, current)
    await Storage.writeJSON(`session/info/${id}`, current)
    Bus.publish(Event.Updated, { info: current })
    return current
  }
  /**
   * Loads every message of a session together with its parts.
   *
   * @returns `{ info, parts }` entries ordered by ascending message id.
   */
  export async function messages(sessionID: string) {
    const collected: { info: MessageV2.Info; parts: MessageV2.Part[] }[] = []
    const keys = await Storage.list(`session/message/${sessionID}`)
    for (const key of keys) {
      const info = await Storage.readJSON<MessageV2.Info>(key)
      const messageParts = await parts(sessionID, info.id)
      collected.push({ info, parts: messageParts })
    }
    collected.sort((left, right) => (left.info.id > right.info.id ? 1 : -1))
    return collected
  }
  /** Reads a single message record of a session from storage. */
  export async function getMessage(sessionID: string, messageID: string) {
    const key = `session/message/${sessionID}/${messageID}`
    return Storage.readJSON<MessageV2.Info>(key)
  }
  /**
   * Loads all parts stored for one message.
   *
   * @returns The parts ordered by ascending part id.
   */
  export async function parts(sessionID: string, messageID: string) {
    const loaded: MessageV2.Part[] = []
    const keys = await Storage.list(`session/part/${sessionID}/${messageID}`)
    for (const key of keys) {
      loaded.push(await Storage.readJSON<MessageV2.Part>(key))
    }
    loaded.sort((left, right) => (left.id > right.id ? 1 : -1))
    return loaded
  }
  /**
   * Async generator over every stored session. The session id is derived from
   * each storage entry's basename with the `.json` extension stripped.
   */
  export async function* list() {
    const entries = await Storage.list("session/info")
    for (const entry of entries) {
      const id = path.basename(entry, ".json")
      yield get(id)
    }
  }
  /**
   * Returns every stored session whose `parentID` equals the given id.
   * Scans all session records to find them.
   */
  export async function children(parentID: string) {
    const matches: Session.Info[] = []
    const entries = await Storage.list("session/info")
    for (const entry of entries) {
      const candidate = await get(path.basename(entry, ".json"))
      if (candidate.parentID === parentID) matches.push(candidate)
    }
    return matches
  }
  /**
   * Aborts the in-flight work of a session, if any.
   *
   * @returns `true` when a pending controller was found, aborted and
   * unregistered; `false` when the session had nothing pending.
   */
  export function abort(sessionID: string) {
    const pending = state().pending
    const controller = pending.get(sessionID)
    if (controller) {
      controller.abort()
      state().pending.delete(sessionID)
      return true
    }
    return false
  }
  /**
   * Deletes a session and, recursively, all of its child sessions.
   *
   * Aborts any in-flight work, removes the share, the session record, the
   * stored messages and the stored parts, evicts the in-memory caches and
   * optionally publishes `Event.Deleted`. Storage and unshare failures are
   * ignored individually so one missing artefact does not block the rest;
   * any other error is logged and swallowed.
   *
   * @param sessionID Session to delete.
   * @param emitEvent Whether to publish `Event.Deleted` (children are removed with `false`).
   */
  export async function remove(sessionID: string, emitEvent = true) {
    try {
      abort(sessionID)
      const session = await get(sessionID)
      for (const child of await children(sessionID)) {
        await remove(child.id, false)
      }
      await unshare(sessionID).catch(() => {})
      await Storage.remove(`session/info/${sessionID}`).catch(() => {})
      await Storage.removeDir(`session/message/${sessionID}/`).catch(() => {})
      // Parts are persisted under their own prefix (see updatePart); without
      // this they would be left behind after the session is gone.
      await Storage.removeDir(`session/part/${sessionID}/`).catch(() => {})
      state().sessions.delete(sessionID)
      state().messages.delete(sessionID)
      if (emitEvent) {
        Bus.publish(Event.Deleted, {
          info: session,
        })
      }
    } catch (e) {
      log.error(e)
    }
  }
  /** Persists a message record and publishes `MessageV2.Event.Updated` for it. */
  async function updateMessage(msg: MessageV2.Info) {
    const key = `session/message/${msg.sessionID}/${msg.id}`
    await Storage.writeJSON(key, msg)
    Bus.publish(MessageV2.Event.Updated, { info: msg })
  }
  /**
   * Persists a message part and publishes `MessageV2.Event.PartUpdated` for it.
   *
   * @returns The same part that was passed in.
   */
  async function updatePart(part: MessageV2.Part) {
    const key = `session/part/${part.sessionID}/${part.messageID}/${part.id}`
    await Storage.writeJSON(key, part)
    Bus.publish(MessageV2.Event.PartUpdated, { part })
    return part
  }
  /**
   * Sends a user message to the model and streams the assistant reply into storage.
   *
   * In order, this function:
   *  1. Applies a pending revert: messages past the revert point are removed
   *     from storage, the message at the revert point is truncated, and the
   *     revert marker is cleared.
   *  2. Summarizes the session and restarts when the previous assistant
   *     message used more than 90% of (context window - output budget).
   *  3. Takes the session lock (`lock` throws when the session is busy).
   *  4. Expands `file:` attachments: `text/plain` files are read through
   *     `ReadTool` (honouring `?start=&end=` line ranges), other files are
   *     inlined as base64 data URLs.
   *  5. For the first message of a root session, generates a title in the
   *     background (best effort).
   *  6. Persists the user message, builds the system prompt and tool set, and
   *     streams the model response through a processor.
   *
   * @param input.sessionID Session to chat in.
   * @param input.messageID Id assigned to the new user message.
   * @param input.providerID Provider used to resolve the model and its tools.
   * @param input.modelID Model to run.
   * @param input.mode Mode name passed to `Mode.get` (defaults to `"build"`);
   *   `"plan"` additionally appends the plan prompt as a synthetic part.
   * @param input.parts Text and file parts that make up the user message.
   * @returns The assistant message info and its stored parts.
   */
  export async function chat(input: {
    sessionID: string
    messageID: string
    providerID: string
    modelID: string
    mode?: string
    parts: (MessageV2.TextPart | MessageV2.FilePart)[]
  }) {
    const l = log.clone().tag("session", input.sessionID)
    l.info("chatting")
    const model = await Provider.getModel(input.providerID, input.modelID)
    let msgs = await messages(input.sessionID)
    const session = await get(input.sessionID)

    if (session.revert) {
      const revert = session.revert
      const trimmed = []
      for (const msg of msgs) {
        // Everything after the revert point goes away entirely; so does the
        // revert message itself when it is reverted from its first part.
        if (msg.info.id > revert.messageID || (msg.info.id === revert.messageID && revert.part === 0)) {
          await Storage.remove("session/message/" + input.sessionID + "/" + msg.info.id)
          await Bus.publish(MessageV2.Event.Removed, {
            sessionID: input.sessionID,
            messageID: msg.info.id,
          })
          continue
        }
        if (msg.info.id === revert.messageID) {
          msg.parts = msg.parts.slice(0, revert.part)
        }
        trimmed.push(msg)
      }
      msgs = trimmed
      await update(input.sessionID, (draft) => {
        draft.revert = undefined
      })
    }

    const previous = msgs.filter((x) => x.info.role === "assistant").at(-1)?.info as MessageV2.Assistant
    const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX

    // auto summarize if too long
    if (previous && previous.tokens) {
      const tokens =
        previous.tokens.input + previous.tokens.cache.read + previous.tokens.cache.write + previous.tokens.output
      if (model.info.limit.context && tokens > Math.max((model.info.limit.context - outputLimit) * 0.9, 0)) {
        await summarize({
          sessionID: input.sessionID,
          providerID: input.providerID,
          modelID: input.modelID,
        })
        return chat(input)
      }
    }

    using abort = lock(input.sessionID)

    // Only the latest summary and what follows it is sent to the model.
    const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
    if (lastSummary) msgs = msgs.filter((msg) => msg.info.id >= lastSummary.info.id)

    const userMsg: MessageV2.Info = {
      id: input.messageID,
      role: "user",
      sessionID: input.sessionID,
      time: {
        created: Date.now(),
      },
    }

    const app = App.info()
    const userParts = await Promise.all(
      input.parts.map(async (part): Promise<MessageV2.Part[]> => {
        if (part.type === "file") {
          const url = new URL(part.url)
          switch (url.protocol) {
            case "file:": {
              // have to normalize, symbol search returns absolute paths
              // Decode the pathname since URL constructor doesn't automatically decode it
              const pathname = decodeURIComponent(url.pathname)
              const relativePath = pathname.replace(app.path.cwd, ".")
              const filePath = path.join(app.path.cwd, relativePath)

              if (part.mime === "text/plain") {
                let offset: number | undefined = undefined
                let limit: number | undefined = undefined
                const range = {
                  start: url.searchParams.get("start"),
                  end: url.searchParams.get("end"),
                }
                if (range.start != null) {
                  const fileURI = part.url.split("?")[0]
                  const start = parseInt(range.start, 10)
                  let end = range.end ? parseInt(range.end, 10) : undefined
                  // some LSP servers (eg, gopls) don't give full range in
                  // workspace/symbol searches, so we'll try to find the
                  // symbol in the document to get the full range
                  if (start === end) {
                    const symbols = await LSP.documentSymbol(fileURI)
                    for (const symbol of symbols) {
                      let symbolRange: LSP.Range | undefined
                      if ("range" in symbol) {
                        symbolRange = symbol.range
                      } else if ("location" in symbol) {
                        symbolRange = symbol.location.range
                      }
                      // Strict comparison (not truthiness) so a symbol that
                      // starts on line 0 can match as well.
                      if (symbolRange?.start?.line === start) {
                        end = symbolRange?.end?.line ?? start
                        break
                      }
                    }
                  }
                  // Apply the window for every valid range, not only for
                  // collapsed ones; otherwise an explicit start/end pair
                  // would be ignored and the whole file read.
                  if (!Number.isNaN(start)) {
                    offset = Math.max(start - 2, 0)
                    if (end !== undefined && !Number.isNaN(end)) {
                      limit = end - offset + 2
                    }
                  }
                }
                const args = { filePath, offset, limit }
                const result = await ReadTool.execute(args, {
                  sessionID: input.sessionID,
                  abort: abort.signal,
                  messageID: userMsg.id,
                  metadata: async () => {},
                })
                return [
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
                  },
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: result.output,
                  },
                ]
              }

              // Non-text files are inlined as a base64 data URL.
              const file = Bun.file(filePath)
              FileTime.read(input.sessionID, filePath)
              return [
                {
                  id: Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "text",
                  text: `Called the Read tool with the following input: {\"filePath\":\"${pathname}\"}`,
                  synthetic: true,
                },
                {
                  id: Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "file",
                  url: `data:${part.mime};base64,` + Buffer.from(await file.bytes()).toString("base64"),
                  mime: part.mime,
                  filename: part.filename!,
                },
              ]
            }
          }
        }
        return [part]
      }),
    ).then((x) => x.flat())

    if (input.mode === "plan")
      userParts.push({
        id: Identifier.ascending("part"),
        messageID: userMsg.id,
        sessionID: input.sessionID,
        type: "text",
        text: PROMPT_PLAN,
        synthetic: true,
      })

    // First message of a root session: generate a title in the background.
    if (msgs.length === 0 && !session.parentID) {
      const small = (await Provider.getSmallModel(input.providerID)) ?? model
      generateText({
        maxOutputTokens: input.providerID === "google" ? 1024 : 20,
        providerOptions: small.info.options,
        messages: [
          ...SystemPrompt.title(input.providerID).map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          ...MessageV2.toModelMessage([
            {
              info: {
                id: Identifier.ascending("message"),
                role: "user",
                sessionID: input.sessionID,
                time: {
                  created: Date.now(),
                },
              },
              parts: userParts,
            },
          ]),
        ],
        model: small.language,
      })
        .then((result) => {
          if (result.text)
            return Session.update(input.sessionID, (draft) => {
              draft.title = result.text
            })
        })
        .catch(() => {})
    }

    await updateMessage(userMsg)
    for (const part of userParts) {
      await updatePart(part)
    }
    msgs.push({ info: userMsg, parts: userParts })

    const mode = await Mode.get(input.mode ?? "build")
    let system = input.providerID === "anthropic" ? [PROMPT_ANTHROPIC_SPOOF.trim()] : []
    system.push(...(mode.prompt ? [mode.prompt] : SystemPrompt.provider(input.modelID)))
    system.push(...(await SystemPrompt.environment()))
    system.push(...(await SystemPrompt.custom()))
    // max 2 system prompt messages for caching purposes
    const [first, ...rest] = system
    system = [first, rest.join("\n")]

    const assistantMsg: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      system,
      path: {
        cwd: app.path.cwd,
        root: app.path.root,
      },
      cost: 0,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.modelID,
      providerID: input.providerID,
      time: {
        created: Date.now(),
      },
      sessionID: input.sessionID,
    }
    await updateMessage(assistantMsg)

    const tools: Record<string, AITool> = {}
    const processor = createProcessor(assistantMsg, model.info)

    // Built-in tools, minus those disabled by the mode; child sessions never get "task".
    for (const item of await Provider.tools(input.providerID)) {
      if (mode.tools[item.id] === false) continue
      if (session.parentID && item.id === "task") continue
      tools[item.id] = tool({
        id: item.id as any,
        description: item.description,
        inputSchema: item.parameters as ZodSchema,
        async execute(args, options) {
          const result = await item.execute(args, {
            sessionID: input.sessionID,
            abort: abort.signal,
            messageID: assistantMsg.id,
            // Lets a running tool push title/metadata updates onto its part.
            metadata: async (val) => {
              const match = processor.partFromToolCall(options.toolCallId)
              if (match && match.state.status === "running") {
                await updatePart({
                  ...match,
                  state: {
                    title: val.title,
                    metadata: val.metadata,
                    status: "running",
                    input: args,
                    time: {
                      start: Date.now(),
                    },
                  },
                })
              }
            },
          })
          return result
        },
        toModelOutput(result) {
          return {
            type: "text",
            value: result.output,
          }
        },
      })
    }

    // MCP tools: wrap execute so only the text content is returned as `output`.
    // NOTE(review): this mutates the objects returned by MCP.tools(); if those
    // are cached across calls the wrapper would be applied repeatedly — confirm.
    for (const [key, item] of Object.entries(await MCP.tools())) {
      if (mode.tools[key] === false) continue
      const execute = item.execute
      if (!execute) continue
      item.execute = async (args, opts) => {
        const result = await execute(args, opts)
        const output = result.content
          .filter((x: any) => x.type === "text")
          .map((x: any) => x.text)
          .join("\n\n")
        return {
          output,
        }
      }
      item.toModelOutput = (result) => {
        return {
          type: "text",
          value: result.output,
        }
      }
      tools[key] = item
    }

    const stream = streamText({
      // Errors surface through the stream's "error" part and are handled by the processor.
      onError() {},
      maxRetries: 10,
      maxOutputTokens: outputLimit,
      abortSignal: abort.signal,
      stopWhen: stepCountIs(1000),
      providerOptions: model.info.options,
      messages: [
        ...system.map(
          (x): ModelMessage => ({
            role: "system",
            content: x,
          }),
        ),
        ...MessageV2.toModelMessage(msgs),
      ],
      temperature: model.info.temperature ? 0 : undefined,
      tools: model.info.tool_call === false ? undefined : tools,
      model: wrapLanguageModel({
        model: model.language,
        middleware: [
          {
            async transformParams(args) {
              if (args.type === "stream") {
                // @ts-expect-error
                args.params.prompt = ProviderTransform.message(args.params.prompt, input.providerID, input.modelID)
              }
              return args.params
            },
          },
        ],
      }),
    })
    const result = await processor.process(stream)
    return result
  }
  /**
   * Builds a stream processor bound to one assistant message.
   *
   * The processor consumes a `streamText` result, turning stream events into
   * persisted message parts (snapshots, tool calls, steps, text) and keeping
   * the assistant message's cost, tokens, error and completion time up to date.
   *
   * @param assistantMsg Assistant message the parts are attached to; mutated in place.
   * @param model Model metadata passed to `getUsage` to compute token usage and cost.
   * @returns `partFromToolCall` to look up the part of an in-flight tool call,
   *   and `process` to consume a stream.
   */
  function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
    // In-flight tool parts keyed by tool call id; entries are dropped on result/error.
    const toolCalls: Record<string, MessageV2.ToolPart> = {}

    /** Stores a snapshot part when `Snapshot.create` produces a snapshot. */
    async function recordSnapshot() {
      const snapshot = await Snapshot.create(assistantMsg.sessionID)
      if (!snapshot) return
      await updatePart({
        id: Identifier.ascending("part"),
        messageID: assistantMsg.id,
        sessionID: assistantMsg.sessionID,
        type: "snapshot",
        snapshot,
      })
    }

    return {
      /** Returns the tracked tool part for a tool call id, if it is still in flight. */
      partFromToolCall(toolCallID: string) {
        return toolCalls[toolCallID]
      },

      /**
       * Consumes the stream to completion. Never rejects because of stream
       * errors: they are converted into `assistantMsg.error` and published as
       * `Event.Error`.
       *
       * @returns The final assistant message info and its stored parts.
       */
      async process(stream: StreamTextResult<Record<string, AITool>, never>) {
        try {
          let currentText: MessageV2.TextPart | undefined
          for await (const value of stream.fullStream) {
            log.info("part", {
              type: value.type,
            })
            switch (value.type) {
              case "start": {
                await recordSnapshot()
                break
              }
              case "tool-input-start": {
                const part = await updatePart({
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "tool",
                  tool: value.toolName,
                  callID: value.id,
                  state: {
                    status: "pending",
                  },
                })
                toolCalls[value.id] = part as MessageV2.ToolPart
                break
              }
              case "tool-input-delta":
                break
              case "tool-call": {
                const match = toolCalls[value.toolCallId]
                if (match) {
                  const part = await updatePart({
                    ...match,
                    state: {
                      status: "running",
                      input: value.input,
                      time: {
                        start: Date.now(),
                      },
                    },
                  })
                  toolCalls[value.toolCallId] = part as MessageV2.ToolPart
                }
                break
              }
              case "tool-result": {
                const match = toolCalls[value.toolCallId]
                if (match && match.state.status === "running") {
                  await updatePart({
                    ...match,
                    state: {
                      status: "completed",
                      input: value.input,
                      output: value.output.output,
                      metadata: value.output.metadata,
                      title: value.output.title,
                      time: {
                        start: match.state.time.start,
                        end: Date.now(),
                      },
                    },
                  })
                  delete toolCalls[value.toolCallId]
                  await recordSnapshot()
                }
                break
              }
              case "tool-error": {
                const match = toolCalls[value.toolCallId]
                if (match && match.state.status === "running") {
                  await updatePart({
                    ...match,
                    state: {
                      status: "error",
                      input: value.input,
                      error: (value.error as any).toString(),
                      time: {
                        start: match.state.time.start,
                        end: Date.now(),
                      },
                    },
                  })
                  delete toolCalls[value.toolCallId]
                  await recordSnapshot()
                }
                break
              }
              case "error":
                throw value.error
              case "start-step":
                await updatePart({
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "step-start",
                })
                break
              case "finish-step": {
                const usage = getUsage(model, value.usage, value.providerMetadata)
                // Cost accumulates across steps; tokens reflect the latest step.
                assistantMsg.cost += usage.cost
                assistantMsg.tokens = usage.tokens
                await updatePart({
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "step-finish",
                  tokens: usage.tokens,
                  cost: usage.cost,
                })
                await updateMessage(assistantMsg)
                break
              }
              case "text-start":
                currentText = {
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "text",
                  text: "",
                  time: {
                    start: Date.now(),
                  },
                }
                break
              case "text":
                if (currentText) {
                  currentText.text += value.text
                  await updatePart(currentText)
                }
                break
              case "text-end":
                if (currentText && currentText.text) {
                  const ended = Date.now()
                  // Keep the start recorded at text-start so the part's
                  // duration is preserved.
                  currentText.time = {
                    start: currentText.time?.start ?? ended,
                    end: ended,
                  }
                  await updatePart(currentText)
                }
                currentText = undefined
                break
              case "finish":
                assistantMsg.time.completed = Date.now()
                await updateMessage(assistantMsg)
                break
              default:
                log.info("unhandled", {
                  ...value,
                })
                continue
            }
          }
        } catch (e) {
          log.error("", {
            error: e,
          })
          if (e instanceof DOMException && e.name === "AbortError") {
            assistantMsg.error = new MessageV2.AbortedError({ message: e.message }, { cause: e }).toObject()
          } else if (MessageV2.OutputLengthError.isInstance(e)) {
            assistantMsg.error = e
          } else if (LoadAPIKeyError.isInstance(e)) {
            assistantMsg.error = new MessageV2.AuthError(
              {
                // The provider of this message; `model.id` identifies the model, not the provider.
                providerID: assistantMsg.providerID,
                message: e.message,
              },
              { cause: e },
            ).toObject()
          } else if (e instanceof Error) {
            assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
          } else {
            // Serialize like every other branch so the stored/published error is a plain object.
            assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e }).toObject()
          }
          Bus.publish(Event.Error, {
            sessionID: assistantMsg.sessionID,
            error: assistantMsg.error,
          })
        }

        // Any tool part that never reached a terminal state is marked as
        // aborted. Parts that already failed keep their own error, and the
        // updates are awaited so the returned parts match what is stored.
        const stored = await parts(assistantMsg.sessionID, assistantMsg.id)
        const finalParts: MessageV2.Part[] = []
        for (const part of stored) {
          if (part.type === "tool" && part.state.status !== "completed" && part.state.status !== "error") {
            const aborted = await updatePart({
              ...part,
              state: {
                status: "error",
                error: "Tool execution aborted",
                time: {
                  start: Date.now(),
                  end: Date.now(),
                },
                input: {},
              },
            })
            finalParts.push(aborted)
            continue
          }
          finalParts.push(part)
        }
        assistantMsg.time.completed = Date.now()
        await updateMessage(assistantMsg)
        return { info: assistantMsg, parts: finalParts }
      },
    }
  }
  /**
   * Placeholder for reverting a session to a given message part.
   * Currently a no-op: the input is ignored and nothing is returned.
   *
   * @param _input Session id, message id and part index of the revert point (unused).
   */
  export async function revert(_input: { sessionID: string; messageID: string; part: number }) {
    // TODO: implement. The commented-out draft below is kept for reference only.
    /*
    const message = await getMessage(input.sessionID, input.messageID)
    if (!message) return
    const part = message.parts[input.part]
    if (!part) return
    const session = await get(input.sessionID)
    const snapshot =
      session.revert?.snapshot ?? (await Snapshot.create(input.sessionID))
    const old = (() => {
      if (message.role === "assistant") {
        const lastTool = message.parts.findLast(
          (part, index) =>
            part.type === "tool-invocation" && index < input.part,
        )
        if (lastTool && lastTool.type === "tool-invocation")
          return message.metadata.tool[lastTool.toolInvocation.toolCallId]
            .snapshot
      }
      return message.metadata.snapshot
    })()
    if (old) await Snapshot.restore(input.sessionID, old)
    await update(input.sessionID, (draft) => {
      draft.revert = {
        messageID: input.messageID,
        part: input.part,
        snapshot,
      }
    })
    */
  }
  /**
   * Cancels a pending revert: restores the snapshot recorded on the revert
   * marker (when there is one) and clears the marker from the session record.
   * Does nothing when the session is missing or has no pending revert.
   *
   * @param sessionID Session whose revert marker should be cleared.
   */
  export async function unrevert(sessionID: string) {
    const session = await get(sessionID)
    if (!session) return
    if (!session.revert) return
    if (session.revert.snapshot) await Snapshot.restore(sessionID, session.revert.snapshot)
    // Await the write so callers observe the cleared marker and a storage
    // failure rejects this call instead of becoming an unhandled rejection.
    await update(sessionID, (draft) => {
      draft.revert = undefined
    })
  }
  /**
   * Asks the model to summarize the conversation and stores the answer as an
   * assistant message flagged `summary: true`.
   *
   * Only the latest existing summary and the messages after it are sent,
   * followed by a fixed user prompt requesting the summary. The session lock
   * is held for the duration of the call.
   *
   * @param input Session to summarize and the provider/model to use.
   * @returns The summary message info and its stored parts.
   */
  export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {
    using handle = lock(input.sessionID)

    const history = await messages(input.sessionID)
    const previousSummary = history.findLast(
      (entry) => entry.info.role === "assistant" && entry.info.summary === true,
    )
    const relevant = previousSummary
      ? history.filter((entry) => entry.info.id >= previousSummary.info.id)
      : history.slice()

    const model = await Provider.getModel(input.providerID, input.modelID)
    const { cwd, root } = App.info().path
    const system = SystemPrompt.summarize(input.providerID)

    const summaryMsg: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      sessionID: input.sessionID,
      system,
      path: { cwd, root },
      summary: true,
      cost: 0,
      modelID: input.modelID,
      providerID: input.providerID,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      time: { created: Date.now() },
    }
    await updateMessage(summaryMsg)

    const processor = createProcessor(summaryMsg, model.info)
    const systemMessages = system.map((content): ModelMessage => ({ role: "system", content }))
    const request: ModelMessage = {
      role: "user",
      content: [
        {
          type: "text",
          text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
        },
      ],
    }
    const stream = streamText({
      maxRetries: 10,
      abortSignal: handle.signal,
      model: model.language,
      messages: [...systemMessages, ...MessageV2.toModelMessage(relevant), request],
    })
    return await processor.process(stream)
  }
  /**
   * Mark a session as busy and return a disposable handle for it.
   *
   * A fresh AbortController is registered under the session id in the pending
   * map. Disposing the handle removes that entry and publishes `Event.Idle`
   * for the session.
   *
   * @param sessionID session to lock
   * @returns an object exposing the controller's `signal` and a
   *          `Symbol.dispose` method that releases the lock
   * @throws BusyError when the session already has a pending entry
   */
  function lock(sessionID: string) {
    log.info("locking", { sessionID })

    const alreadyLocked = state().pending.has(sessionID)
    if (alreadyLocked) throw new BusyError(sessionID)

    const controller = new AbortController()
    state().pending.set(sessionID, controller)

    // Runs when the handle is disposed, e.g. at the end of a `using` scope.
    const release = (): void => {
      log.info("unlocking", { sessionID })
      state().pending.delete(sessionID)
      Bus.publish(Event.Idle, { sessionID })
    }

    return {
      signal: controller.signal,
      [Symbol.dispose]: release,
    }
  }
  /**
   * Turn a usage report into token counts and a price for one model call.
   *
   * Token counts default to 0 when the report omits them; `reasoning` is always
   * reported as 0 here. The cost is the sum of input, output, cache-read and
   * cache-write tokens, each multiplied by the model's rate and divided by
   * 1,000,000 (rates are applied per million tokens). Missing cache rates count
   * as 0.
   *
   * @param model    model whose `cost` rates are applied
   * @param usage    token usage reported for the call
   * @param metadata optional provider metadata, consulted for cache-write tokens
   * @returns `{ cost, tokens }` with `cost` as a plain number
   */
  function getUsage(model: ModelsDev.Model, usage: LanguageModelUsage, metadata?: ProviderMetadata) {
    // Cache-write counts are looked up in provider metadata: the anthropic
    // entry first, then the bedrock usage entry, else 0.
    const cacheWrite = (metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
      // @ts-expect-error
      metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
      0) as number

    const tokens = {
      input: usage.inputTokens ?? 0,
      output: usage.outputTokens ?? 0,
      reasoning: 0,
      cache: {
        write: cacheWrite,
        read: usage.cachedInputTokens ?? 0,
      },
    }

    const rates = model.cost
    // Price of `count` tokens at a per-million-token `rate`.
    const charge = (count: number, rate: number) => new Decimal(count).mul(rate).div(1_000_000)

    const cost = [
      charge(tokens.input, rates.input),
      charge(tokens.output, rates.output),
      charge(tokens.cache.read, rates.cache_read ?? 0),
      charge(tokens.cache.write, rates.cache_write ?? 0),
    ]
      .reduce((sum, term) => sum.add(term), new Decimal(0))
      .toNumber()

    return { cost, tokens }
  }
  /**
   * Raised by `lock` when a session already has an operation in flight.
   *
   * `sessionID` identifies the busy session so callers can report or retry.
   */
  export class BusyError extends Error {
    constructor(public readonly sessionID: string) {
      super(`Session ${sessionID} is busy`)
      // Without this, logs and `String(err)` show a generic "Error".
      this.name = "BusyError"
    }
  }
  /**
   * Run the project initialization prompt in a session, then mark the app as
   * initialized.
   *
   * Sends a single text part built from `PROMPT_INITIALIZE`, with its
   * `${path}` placeholder replaced by the project root, through
   * `Session.chat`, and calls `App.initialize()` once that has completed.
   *
   * @param input session, message, provider and model ids for the chat request
   */
  export async function initialize(input: {
    sessionID: string
    modelID: string
    providerID: string
    messageID: string
  }) {
    const app = App.info()
    await Session.chat({
      sessionID: input.sessionID,
      messageID: input.messageID,
      providerID: input.providerID,
      modelID: input.modelID,
      parts: [
        {
          id: Identifier.ascending("part"),
          sessionID: input.sessionID,
          messageID: input.messageID,
          type: "text",
          // A replacer function inserts the path verbatim. A plain string
          // replacement would interpret sequences such as `$$` or `$&`, which
          // corrupts any project root that contains `$`.
          text: PROMPT_INITIALIZE.replace("${path}", () => app.path.root),
        },
      ],
    })
    await App.initialize()
  }
  1046. }