index.ts 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077
  1. import path from "path"
  2. import { Decimal } from "decimal.js"
  3. import { z, ZodSchema } from "zod"
  4. import {
  5. generateText,
  6. LoadAPIKeyError,
  7. streamText,
  8. tool,
  9. wrapLanguageModel,
  10. type Tool as AITool,
  11. type LanguageModelUsage,
  12. type ProviderMetadata,
  13. type ModelMessage,
  14. stepCountIs,
  15. type StreamTextResult,
  16. } from "ai"
  17. import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
  18. import PROMPT_PLAN from "../session/prompt/plan.txt"
  19. import { App } from "../app/app"
  20. import { Bus } from "../bus"
  21. import { Config } from "../config/config"
  22. import { Flag } from "../flag/flag"
  23. import { Identifier } from "../id/id"
  24. import { Installation } from "../installation"
  25. import { MCP } from "../mcp"
  26. import { Provider } from "../provider/provider"
  27. import { ProviderTransform } from "../provider/transform"
  28. import type { ModelsDev } from "../provider/models"
  29. import { Share } from "../share/share"
  30. import { Snapshot } from "../snapshot"
  31. import { Storage } from "../storage/storage"
  32. import { Log } from "../util/log"
  33. import { NamedError } from "../util/error"
  34. import { SystemPrompt } from "./system"
  35. import { FileTime } from "../file/time"
  36. import { MessageV2 } from "./message-v2"
  37. import { Mode } from "./mode"
  38. import { LSP } from "../lsp"
  39. import { ReadTool } from "../tool/read"
  40. export namespace Session {
  /** Hard ceiling on output tokens requested per completion, whatever the model itself allows. */
  const OUTPUT_TOKEN_MAX = 32000
  /** Logger tagged with the session service name. */
  const log = Log.create({ service: "session" })
  /**
   * Metadata for one session, persisted as JSON under `session/info/<id>` and
   * exposed in the OpenAPI spec as `Session`.
   */
  export const Info = z
    .object({
      id: Identifier.schema("session"),
      // Present for sessions spawned from another session.
      parentID: Identifier.schema("session").optional(),
      // Public share link; set by share() and cleared by unshare().
      share: z.object({ url: z.string() }).optional(),
      title: z.string(),
      // Installation version that created the session.
      version: z.string(),
      // Epoch milliseconds (Date.now()).
      time: z.object({ created: z.number(), updated: z.number() }),
      // Pending revert marker; applied and cleared by the next chat() call.
      revert: z
        .object({
          messageID: z.string(),
          part: z.number(),
          snapshot: z.string().optional(),
        })
        .optional(),
    })
    .openapi({ ref: "Session" })
  export type Info = z.output<typeof Info>
  /**
   * Share credentials persisted under `session/share/<id>`. The secret is what
   * unshare() needs to delete the remote share; it must not leave this record.
   */
  export const ShareInfo = z.object({ secret: z.string(), url: z.string() }).openapi({ ref: "SessionShare" })
  export type ShareInfo = z.output<typeof ShareInfo>
  /** Bus events emitted by the session service. */
  export const Event = {
    /** Session metadata was created or changed. */
    Updated: Bus.event("session.updated", z.object({ info: Info })),
    /** Session was removed (published for the root of a removal only). */
    Deleted: Bus.event("session.deleted", z.object({ info: Info })),
    /** The session lock was released. */
    Idle: Bus.event("session.idle", z.object({ sessionID: z.string() })),
    /** A model run failed; `error` has the same shape as the assistant message's error. */
    Error: Bus.event(
      "session.error",
      z.object({
        sessionID: z.string().optional(),
        error: MessageV2.Assistant.shape.error,
      }),
    ),
  }
  /**
   * Per-app in-memory state: cached session info, cached messages, and the
   * abort controller of every session that currently holds the run lock.
   * The second callback (presumably run when the app state is disposed —
   * confirm against App.state) aborts every in-flight run.
   */
  const state = App.state(
    "session",
    () => ({
      sessions: new Map<string, Info>(),
      messages: new Map<string, MessageV2.Info[]>(),
      pending: new Map<string, AbortController>(),
    }),
    async (current) => {
      current.pending.forEach((controller) => controller.abort())
    },
  )
  /**
   * Create a session, persist it and announce it with `session.updated`.
   *
   * Top-level sessions are shared in the background when
   * `Flag.OPENCODE_AUTO_SHARE` is set or the config has `share: "auto"`.
   * Sharing failures are ignored so that creation never fails because of them.
   *
   * @param parentID - when given, the new session is a child of that session
   * @returns the new session info
   */
  export async function create(parentID?: string) {
    const now = Date.now()
    const result: Info = {
      id: Identifier.descending("session"),
      version: Installation.VERSION,
      parentID,
      title: (parentID ? "Child session - " : "New Session - ") + new Date(now).toISOString(),
      time: {
        created: now,
        updated: now,
      },
    }
    log.info("created", result)
    state().sessions.set(result.id, result)
    await Storage.writeJSON("session/info/" + result.id, result)
    const cfg = await Config.get()
    if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.share === "auto")) {
      // share() already records `{ url }` on the session through update().
      // Its return value is the full ShareInfo *including the secret*, so it
      // must not be copied into the session info (which is persisted and
      // broadcast on the bus).
      share(result.id).catch(() => {
        // Silently ignore sharing errors during session creation
      })
    }
    Bus.publish(Event.Updated, {
      info: result,
    })
    return result
  }
  /**
   * Look up a session: the in-memory cache wins, otherwise the stored JSON is
   * read and cached for later calls.
   */
  export async function get(id: string) {
    const cached = state().sessions.get(id)
    if (cached) return cached
    const stored = await Storage.readJSON<Info>(`session/info/${id}`)
    state().sessions.set(id, stored)
    return stored as Info
  }
  /** Read the share credentials stored for a session by `share`. */
  export async function getShare(id: string) {
    const key = `session/share/${id}`
    return Storage.readJSON<ShareInfo>(key)
  }
  /**
   * Publish a session through the share service.
   *
   * If the session already has a share link, that link is returned unchanged.
   * Otherwise this function:
   * 1. creates a share;
   * 2. records `{ url }` on the session;
   * 3. stores the credentials under `session/share/<id>`;
   * 4. uploads the session info and every message and part with `Share.sync`.
   *
   * @returns the share (full credentials when newly created)
   * @throws Error when sharing is disabled in configuration
   */
  export async function share(id: string) {
    const { share: mode } = await Config.get()
    if (mode === "disabled") throw new Error("Sharing is disabled in configuration")

    const session = await get(id)
    if (session.share) return session.share

    const created = await Share.create(id)
    await update(id, (draft) => {
      draft.share = { url: created.url }
    })
    await Storage.writeJSON<ShareInfo>(`session/share/${id}`, created)
    await Share.sync(`session/info/${id}`, session)

    const history = await messages(id)
    for (const { info, parts: messageParts } of history) {
      const base = `${id}/${info.id}`
      await Share.sync(`session/message/${base}`, info)
      for (const part of messageParts) {
        await Share.sync(`session/part/${base}/${part.id}`, part)
      }
    }
    return created
  }
  /**
   * Stop sharing a session. This deletes the stored credentials, clears the
   * share link on the session, then removes the remote share.
   * Does nothing when no credentials are stored.
   */
  export async function unshare(id: string) {
    const credentials = await getShare(id)
    if (credentials) {
      await Storage.remove(`session/share/${id}`)
      await update(id, (draft) => {
        draft.share = undefined
      })
      await Share.remove(id, credentials.secret)
    }
  }
  /**
   * Mutate a session in place through `editor`, then:
   * - bump `time.updated`;
   * - refresh the cache and persist;
   * - publish `session.updated`.
   *
   * @returns the edited session, or undefined when it could not be loaded
   */
  export async function update(id: string, editor: (session: Info) => void) {
    const cache = state().sessions
    const target = await get(id)
    if (!target) return
    editor(target)
    target.time.updated = Date.now()
    cache.set(id, target)
    await Storage.writeJSON(`session/info/${id}`, target)
    Bus.publish(Event.Updated, { info: target })
    return target
  }
  /** Load every message of a session with its parts, sorted by message ID (string order). */
  export async function messages(sessionID: string) {
    const loaded: { info: MessageV2.Info; parts: MessageV2.Part[] }[] = []
    for await (const key of Storage.list("session/message/" + sessionID)) {
      const info = await Storage.readJSON<MessageV2.Info>(key)
      loaded.push({ info, parts: await parts(sessionID, info.id) })
    }
    loaded.sort((a, b) => (a.info.id > b.info.id ? 1 : -1))
    return loaded
  }
  /** Read one message's info (parts are stored separately) from storage. */
  export async function getMessage(sessionID: string, messageID: string) {
    return Storage.readJSON<MessageV2.Info>(["session", "message", sessionID, messageID].join("/"))
  }
  /** Load all parts of one message, sorted by part ID (string order). */
  export async function parts(sessionID: string, messageID: string) {
    const found: MessageV2.Part[] = []
    const dir = `session/part/${sessionID}/${messageID}`
    for await (const key of Storage.list(dir)) {
      found.push(await Storage.readJSON<MessageV2.Part>(key))
    }
    return found.sort((a, b) => (a.id > b.id ? 1 : -1))
  }
  /** Lazily yield every stored session, each one loaded through `get`. */
  export async function* list() {
    for await (const entry of Storage.list("session/info")) {
      yield get(path.basename(entry, ".json"))
    }
  }
  /** Collect the direct children of `parentID` by scanning every stored session. */
  export async function children(parentID: string) {
    const matches: Session.Info[] = []
    for await (const entry of Storage.list("session/info")) {
      const candidate = await get(path.basename(entry, ".json"))
      if (candidate.parentID === parentID) matches.push(candidate)
    }
    return matches
  }
  /**
   * Abort the in-flight run of a session and drop its lock entry.
   *
   * @returns true if a run was aborted, false if none was pending
   */
  export function abort(sessionID: string) {
    const { pending } = state()
    const controller = pending.get(sessionID)
    if (controller === undefined) return false
    controller.abort()
    pending.delete(sessionID)
    return true
  }
  /**
   * Delete a session and, recursively, all of its child sessions.
   *
   * Steps:
   * 1. Abort any in-flight run.
   * 2. Unshare the session.
   * 3. Remove its info, messages and parts from storage.
   * 4. Evict it from the in-memory caches.
   *
   * Unshare and storage failures are ignored, so removal is best-effort.
   * Any other error is logged rather than thrown.
   *
   * @param emitEvent - publish `session.deleted`; recursive child removals pass false
   */
  export async function remove(sessionID: string, emitEvent = true) {
    try {
      abort(sessionID)
      const session = await get(sessionID)
      for (const child of await children(sessionID)) {
        await remove(child.id, false)
      }
      await unshare(sessionID).catch(() => {})
      await Storage.remove(`session/info/${sessionID}`).catch(() => {})
      await Storage.removeDir(`session/message/${sessionID}/`).catch(() => {})
      // Parts are stored in their own tree (session/part/<session>/<message>/<part>);
      // without this they are left orphaned after the messages are gone.
      await Storage.removeDir(`session/part/${sessionID}/`).catch(() => {})
      state().sessions.delete(sessionID)
      state().messages.delete(sessionID)
      if (emitEvent) {
        Bus.publish(Event.Deleted, {
          info: session,
        })
      }
    } catch (e) {
      log.error(e)
    }
  }
  /** Persist a message's info and broadcast `MessageV2.Event.Updated`. */
  async function updateMessage(msg: MessageV2.Info) {
    const key = ["session", "message", msg.sessionID, msg.id].join("/")
    await Storage.writeJSON(key, msg)
    Bus.publish(MessageV2.Event.Updated, { info: msg })
  }
  /**
   * Persist a part under `session/part/<session>/<message>/<part>`, broadcast
   * `MessageV2.Event.PartUpdated`, and hand the same part back.
   */
  async function updatePart(part: MessageV2.Part) {
    await Storage.writeJSON(`session/part/${part.sessionID}/${part.messageID}/${part.id}`, part)
    Bus.publish(MessageV2.Event.PartUpdated, { part })
    return part
  }
  /**
   * Run one user turn against a model and stream the assistant reply.
   *
   * 1. Applies a pending revert: drops messages past the revert point and
   *    truncates the parts of the revert message in the prompt.
   * 2. If the previous assistant turn used more than 90% of the usable context
   *    window, summarizes the session and retries.
   * 3. Takes the session lock and expands `file:` attachments. text/plain
   *    becomes synthetic Read tool output; other types become inline base64
   *    data. It then persists the user message and its parts.
   * 4. For the first message of a top-level session, generates a title in the
   *    background.
   * 5. Builds the system prompt and tool set for the selected mode and streams
   *    the completion through `processStream`.
   *
   * The session lock is held until the stream has been fully processed.
   *
   * @throws BusyError if another run already holds the session lock
   */
  export async function chat(input: {
    sessionID: string
    messageID: string
    providerID: string
    modelID: string
    mode?: string
    parts: (MessageV2.TextPart | MessageV2.FilePart)[]
  }) {
    const l = log.clone().tag("session", input.sessionID)
    l.info("chatting")
    const model = await Provider.getModel(input.providerID, input.modelID)
    let msgs = await messages(input.sessionID)
    const session = await get(input.sessionID)

    if (session.revert) {
      const { messageID: revertID, part: revertPart } = session.revert
      const trimmed: typeof msgs = []
      for (const msg of msgs) {
        // Everything after the revert message goes, and so does the revert
        // message itself when reverting to its first part.
        if (msg.info.id > revertID || (msg.info.id === revertID && revertPart === 0)) {
          await Storage.remove("session/message/" + input.sessionID + "/" + msg.info.id)
          // Parts are stored separately (see updatePart); drop them with their message.
          await Storage.removeDir("session/part/" + input.sessionID + "/" + msg.info.id + "/").catch(() => {})
          await Bus.publish(MessageV2.Event.Removed, {
            sessionID: input.sessionID,
            messageID: msg.info.id,
          })
          continue
        }
        if (msg.info.id === revertID) {
          msg.parts = msg.parts.slice(0, revertPart)
        }
        trimmed.push(msg)
      }
      msgs = trimmed
      await update(input.sessionID, (draft) => {
        draft.revert = undefined
      })
    }

    const previous = msgs.findLast((x) => x.info.role === "assistant")?.info as MessageV2.Assistant | undefined
    const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
    // auto summarize if the last turn used more than 90% of the usable context window
    if (previous) {
      const tokens =
        previous.tokens.input + previous.tokens.cache.read + previous.tokens.cache.write + previous.tokens.output
      if (model.info.limit.context && tokens > Math.max((model.info.limit.context - outputLimit) * 0.9, 0)) {
        await summarize({
          sessionID: input.sessionID,
          providerID: input.providerID,
          modelID: input.modelID,
        })
        return chat(input)
      }
    }

    // Released when this function returns; see the `return await` at the end.
    using sessionLock = lock(input.sessionID)
    const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
    if (lastSummary) msgs = msgs.filter((msg) => msg.info.id >= lastSummary.info.id)

    const userMsg: MessageV2.Info = {
      id: input.messageID,
      role: "user",
      sessionID: input.sessionID,
      time: {
        created: Date.now(),
      },
    }
    const app = App.info()
    const userParts = await Promise.all(
      input.parts.map(async (part): Promise<MessageV2.Part[]> => {
        if (part.type !== "file") return [part]
        const url = new URL(part.url)
        if (url.protocol !== "file:") return [part]

        // have to normalize, symbol search returns absolute paths
        // Decode the pathname since URL constructor doesn't automatically decode it
        const pathname = decodeURIComponent(url.pathname)
        const relativePath = pathname.replace(app.path.cwd, ".")
        const filePath = path.join(app.path.cwd, relativePath)

        if (part.mime === "text/plain") {
          let offset: number | undefined = undefined
          let limit: number | undefined = undefined
          const rangeStart = url.searchParams.get("start")
          const rangeEnd = url.searchParams.get("end")
          const start = rangeStart != null ? parseInt(rangeStart, 10) : NaN
          if (!Number.isNaN(start)) {
            let end = rangeEnd ? parseInt(rangeEnd, 10) : undefined
            // some LSP servers (eg, gopls) don't give full range in
            // workspace/symbol searches, so we'll try to find the
            // symbol in the document to get the full range
            if (start === end) {
              // documentSymbol is given the `file://` URL with the query stripped
              const fileURL = part.url.split("?")[0]
              const symbols = await LSP.documentSymbol(fileURL)
              for (const symbol of symbols) {
                let symbolRange: LSP.Range | undefined
                if ("range" in symbol) {
                  symbolRange = symbol.range
                } else if ("location" in symbol) {
                  symbolRange = symbol.location.range
                }
                // Compare directly so a symbol on line 0 also matches.
                if (symbolRange?.start?.line === start) {
                  end = symbolRange?.end?.line ?? start
                  break
                }
              }
            }
            // Apply the requested range (padded by two lines of context) whether
            // or not it had to be widened through the symbol lookup above.
            offset = Math.max(start - 2, 0)
            if (end) {
              limit = end - offset + 2
            }
          }
          const args = { filePath, offset, limit }
          const result = await ReadTool.execute(args, {
            sessionID: input.sessionID,
            abort: sessionLock.signal,
            messageID: "", // read tool doesn't use message ID
            metadata: async () => {},
          })
          return [
            {
              id: Identifier.ascending("part"),
              messageID: userMsg.id,
              sessionID: input.sessionID,
              type: "text",
              synthetic: true,
              text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
            },
            {
              id: Identifier.ascending("part"),
              messageID: userMsg.id,
              sessionID: input.sessionID,
              type: "text",
              synthetic: true,
              text: result.output,
            },
          ]
        }

        // Non-text files are inlined as a base64 data URL.
        const file = Bun.file(filePath)
        FileTime.read(input.sessionID, filePath)
        return [
          {
            id: Identifier.ascending("part"),
            messageID: userMsg.id,
            sessionID: input.sessionID,
            type: "text",
            text: `Called the Read tool with the following input: {\"filePath\":\"${pathname}\"}`,
            synthetic: true,
          },
          {
            id: Identifier.ascending("part"),
            messageID: userMsg.id,
            sessionID: input.sessionID,
            type: "file",
            url: `data:${part.mime};base64,` + Buffer.from(await file.bytes()).toString("base64"),
            mime: part.mime,
            filename: part.filename!,
          },
        ]
      }),
    ).then((x) => x.flat())

    if (input.mode === "plan")
      userParts.push({
        id: Identifier.ascending("part"),
        messageID: userMsg.id,
        sessionID: input.sessionID,
        type: "text",
        text: PROMPT_PLAN,
        synthetic: true,
      })

    // First message of a top-level session: generate a title in the background.
    if (msgs.length === 0 && !session.parentID) {
      generateText({
        maxOutputTokens: input.providerID === "google" ? 1024 : 20,
        providerOptions: model.info.options,
        messages: [
          ...SystemPrompt.title(input.providerID).map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          ...MessageV2.toModelMessage([
            {
              info: {
                id: Identifier.ascending("message"),
                role: "user",
                sessionID: input.sessionID,
                time: {
                  created: Date.now(),
                },
              },
              parts: userParts,
            },
          ]),
        ],
        model: model.language,
      })
        .then((result) => {
          if (result.text)
            return Session.update(input.sessionID, (draft) => {
              draft.title = result.text
            })
        })
        .catch(() => {})
    }

    await updateMessage(userMsg)
    for (const part of userParts) {
      await updatePart(part)
    }
    msgs.push({ info: userMsg, parts: userParts })

    const mode = await Mode.get(input.mode ?? "build")
    // Copy so the pushes below never mutate an array owned by SystemPrompt.
    let system = mode.prompt ? [mode.prompt] : [...SystemPrompt.provider(input.providerID, input.modelID)]
    system.push(...(await SystemPrompt.environment()))
    system.push(...(await SystemPrompt.custom()))
    // max 2 system prompt messages for caching purposes
    const [first, ...rest] = system
    system = [first, rest.join("\n")]

    const assistantMsg: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      system,
      path: {
        cwd: app.path.cwd,
        root: app.path.root,
      },
      cost: 0,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.modelID,
      providerID: input.providerID,
      time: {
        created: Date.now(),
      },
      sessionID: input.sessionID,
    }
    await updateMessage(assistantMsg)

    const tools: Record<string, AITool> = {}
    for (const item of await Provider.tools(input.providerID)) {
      if (mode.tools[item.id] === false) continue
      tools[item.id] = tool({
        id: item.id as any,
        description: item.description,
        inputSchema: item.parameters as ZodSchema,
        async execute(args) {
          const result = await item.execute(args, {
            sessionID: input.sessionID,
            abort: sessionLock.signal,
            messageID: assistantMsg.id,
            // TODO: stream tool metadata (title/metadata) into the running tool part
            metadata: async () => {},
          })
          return result
        },
        toModelOutput(result) {
          return {
            type: "text",
            value: result.output,
          }
        },
      })
    }
    for (const [key, item] of Object.entries(await MCP.tools())) {
      if (mode.tools[key] === false) continue
      const execute = item.execute
      if (!execute) continue
      // Wrap a shallow copy instead of mutating `item`. If MCP.tools() hands back
      // the same objects on every call, mutating them would stack one wrapper
      // per chat turn, and the outer wrapper would then see `{ output }` where
      // it expects MCP `content`.
      const wrapped: typeof item = { ...item }
      wrapped.execute = async (args, opts) => {
        const result = await execute(args, opts)
        const output = result.content
          .filter((x: any) => x.type === "text")
          .map((x: any) => x.text)
          .join("\n\n")
        return {
          output,
        }
      }
      wrapped.toModelOutput = (result) => {
        return {
          type: "text",
          value: result.output,
        }
      }
      tools[key] = wrapped
    }

    const result = streamText({
      onError() {},
      maxRetries: 10,
      maxOutputTokens: outputLimit,
      abortSignal: sessionLock.signal,
      stopWhen: stepCountIs(1000),
      providerOptions: model.info.options,
      messages: [
        ...system.map(
          (x): ModelMessage => ({
            role: "system",
            content: x,
          }),
        ),
        ...MessageV2.toModelMessage(msgs),
      ],
      temperature: model.info.temperature ? 0 : undefined,
      tools: model.info.tool_call === false ? undefined : tools,
      model: wrapLanguageModel({
        model: model.language,
        middleware: [
          {
            async transformParams(args) {
              if (args.type === "stream") {
                // @ts-expect-error
                args.params.prompt = ProviderTransform.message(args.params.prompt, input.providerID, input.modelID)
              }
              return args.params
            },
          },
        ],
      }),
    })
    // `return await` is required: with a bare `return promise`, the `using`
    // lock would be disposed immediately. The run would then become
    // un-abortable, `session.idle` would fire early, and a concurrent run could
    // start while this stream is still being processed.
    return await processStream(assistantMsg, model.info, result)
  }
  /**
   * Consume a model stream and persist the assistant message and its parts as
   * events arrive: text, tool calls, step boundaries and usage/cost.
   *
   * Errors are not rethrown. They are converted into `assistantMsg.error` and
   * published as `session.error`. Afterwards:
   * - every tool part still pending or running is marked as aborted;
   * - the message is marked completed.
   *
   * @returns the final assistant message and its persisted parts
   */
  async function processStream(
    assistantMsg: MessageV2.Assistant,
    model: ModelsDev.Model,
    stream: StreamTextResult<Record<string, AITool>, never>,
  ) {
    try {
      let currentText: MessageV2.TextPart | undefined
      // In-flight tool parts keyed by tool call ID.
      const toolCalls: Record<string, MessageV2.ToolPart> = {}
      for await (const value of stream.fullStream) {
        log.info("part", {
          type: value.type,
        })
        switch (value.type) {
          case "start":
            break

          case "tool-input-start": {
            const part = await updatePart({
              id: Identifier.ascending("part"),
              messageID: assistantMsg.id,
              sessionID: assistantMsg.sessionID,
              type: "tool",
              tool: value.toolName,
              callID: value.id,
              state: {
                status: "pending",
              },
            })
            toolCalls[value.id] = part as MessageV2.ToolPart
            break
          }

          case "tool-input-delta":
            break

          case "tool-call": {
            const match = toolCalls[value.toolCallId]
            if (match) {
              const part = await updatePart({
                ...match,
                state: {
                  status: "running",
                  input: value.input,
                  time: {
                    start: Date.now(),
                  },
                },
              })
              toolCalls[value.toolCallId] = part as MessageV2.ToolPart
            }
            break
          }

          case "tool-result": {
            const match = toolCalls[value.toolCallId]
            if (match && match.state.status === "running") {
              await updatePart({
                ...match,
                state: {
                  status: "completed",
                  input: value.input,
                  output: value.output.output,
                  metadata: value.output.metadata,
                  title: value.output.title,
                  time: {
                    start: match.state.time.start,
                    end: Date.now(),
                  },
                },
              })
              delete toolCalls[value.toolCallId]
            }
            break
          }

          case "tool-error": {
            const match = toolCalls[value.toolCallId]
            if (match && match.state.status === "running") {
              await updatePart({
                ...match,
                state: {
                  status: "error",
                  input: value.input,
                  error: (value.error as any).toString(),
                  time: {
                    start: match.state.time.start,
                    end: Date.now(),
                  },
                },
              })
              delete toolCalls[value.toolCallId]
            }
            break
          }

          case "error":
            throw value.error

          case "start-step":
            await updatePart({
              id: Identifier.ascending("part"),
              messageID: assistantMsg.id,
              sessionID: assistantMsg.sessionID,
              type: "step-start",
            })
            break

          case "finish-step": {
            const usage = getUsage(model, value.usage, value.providerMetadata)
            // Cost accumulates across steps. Tokens reflect the latest step,
            // which chat() reads to decide on auto-summarization.
            assistantMsg.cost += usage.cost
            assistantMsg.tokens = usage.tokens
            await updatePart({
              id: Identifier.ascending("part"),
              messageID: assistantMsg.id,
              sessionID: assistantMsg.sessionID,
              type: "step-finish",
              tokens: usage.tokens,
              cost: usage.cost,
            })
            await updateMessage(assistantMsg)
            break
          }

          case "text-start":
            currentText = {
              id: Identifier.ascending("part"),
              messageID: assistantMsg.id,
              sessionID: assistantMsg.sessionID,
              type: "text",
              text: "",
              time: {
                start: Date.now(),
              },
            }
            break

          case "text":
            if (currentText) {
              currentText.text += value.text
              await updatePart(currentText)
            }
            break

          case "text-end":
            if (currentText && currentText.text) {
              const now = Date.now()
              currentText.time = {
                // keep the start time recorded at "text-start"
                start: currentText.time?.start ?? now,
                end: now,
              }
              await updatePart(currentText)
            }
            currentText = undefined
            break

          case "finish":
            assistantMsg.time.completed = Date.now()
            await updateMessage(assistantMsg)
            break

          default:
            log.info("unhandled", {
              ...value,
            })
            continue
        }
      }
    } catch (e) {
      log.error("", {
        error: e,
      })
      switch (true) {
        case e instanceof DOMException && e.name === "AbortError":
          assistantMsg.error = new MessageV2.AbortedError(
            { message: e.message },
            {
              cause: e,
            },
          ).toObject()
          break
        case MessageV2.OutputLengthError.isInstance(e):
          assistantMsg.error = e
          break
        case LoadAPIKeyError.isInstance(e):
          assistantMsg.error = new Provider.AuthError(
            {
              // The provider of this run; `model.id` is the model's own ID.
              providerID: assistantMsg.providerID,
              message: e.message,
            },
            { cause: e },
          ).toObject()
          break
        case e instanceof Error:
          assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
          break
        default:
          // Serialize like every other branch so the stored/broadcast error is plain data.
          assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e }).toObject()
      }
      Bus.publish(Event.Error, {
        sessionID: assistantMsg.sessionID,
        error: assistantMsg.error,
      })
    }

    // Tool calls that never produced a result (the stream ended or failed
    // mid-call) are marked as aborted. Parts that already completed or errored
    // keep their state; for running calls the input and start time are
    // preserved. Updates are awaited and the final parts are returned.
    const stored = await parts(assistantMsg.sessionID, assistantMsg.id)
    const finalParts: MessageV2.Part[] = []
    for (const part of stored) {
      if (part.type === "tool" && part.state.status !== "completed" && part.state.status !== "error") {
        const now = Date.now()
        finalParts.push(
          await updatePart({
            ...part,
            state: {
              status: "error",
              error: "Tool execution aborted",
              time: {
                start: part.state.status === "running" ? part.state.time.start : now,
                end: now,
              },
              input: part.state.status === "running" ? part.state.input : {},
            },
          }),
        )
        continue
      }
      finalParts.push(part)
    }
    assistantMsg.time.completed = Date.now()
    await updateMessage(assistantMsg)
    return { info: assistantMsg, parts: finalParts }
  }
  /**
   * Record a revert point inside a message so the next chat() drops everything
   * after it.
   *
   * Not implemented yet: this is currently a no-op. The earlier draft would:
   * - load the target message and part;
   * - snapshot the workspace;
   * - restore the snapshot taken before the last tool call preceding that part;
   * - store `{ messageID, part, snapshot }` as `session.revert`.
   *
   * That draft read `message.parts` and `message.metadata`, but `getMessage`
   * returns only the message info (parts are stored separately), so it needs
   * reworking before it can be enabled.
   */
  export async function revert(_input: { sessionID: string; messageID: string; part: number }): Promise<void> {
    // intentionally empty — see TODO above
  }
  /**
   * Cancel a pending revert: restore `revert.snapshot` (when present) and clear
   * the revert marker. Does nothing when no revert is pending.
   */
  export async function unrevert(sessionID: string) {
    const session = await get(sessionID)
    if (!session) return
    if (!session.revert) return
    if (session.revert.snapshot) await Snapshot.restore(sessionID, session.revert.snapshot)
    // Await so callers observe the cleared marker and persistence errors
    // propagate instead of becoming unhandled rejections.
    await update(sessionID, (draft) => {
      draft.revert = undefined
    })
  }
  /**
   * Ask the model to summarize the conversation since the last summary.
   *
   * The reply is stored as an assistant message flagged `summary: true`. chat()
   * only sends messages from the latest summary onward, so this compacts the
   * prompt context.
   *
   * The session lock is held until the summary stream is fully processed.
   *
   * @throws BusyError if another run already holds the session lock
   */
  export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {
    using abort = lock(input.sessionID)
    const msgs = await messages(input.sessionID)
    const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
    const filtered = msgs.filter((msg) => !lastSummary || msg.info.id >= lastSummary.info.id)
    const model = await Provider.getModel(input.providerID, input.modelID)
    const app = App.info()
    const system = SystemPrompt.summarize(input.providerID)
    const next: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      sessionID: input.sessionID,
      system,
      path: {
        cwd: app.path.cwd,
        root: app.path.root,
      },
      summary: true,
      cost: 0,
      modelID: input.modelID,
      providerID: input.providerID,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      time: {
        created: Date.now(),
      },
    }
    await updateMessage(next)
    const result = streamText({
      abortSignal: abort.signal,
      model: model.language,
      messages: [
        ...system.map(
          (x): ModelMessage => ({
            role: "system",
            content: x,
          }),
        ),
        ...MessageV2.toModelMessage(filtered),
        {
          role: "user",
          content: [
            {
              type: "text",
              text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
            },
          ],
        },
      ],
    })
    // `return await` keeps the `using` lock (and its abort signal) alive until
    // the stream is consumed. A bare `return promise` disposes it immediately,
    // which let chat()'s retry start while the summary was still being written.
    return await processStream(next, model.info, result)
  }
  /**
   * Mark a session as running and return its abort signal.
   *
   * Meant to be used with `using`. Disposing the handle releases the lock and
   * publishes `session.idle`, unless the lock now belongs to a different run.
   * That happens when abort() released this one and a new run has locked the
   * session since.
   *
   * @throws BusyError if the session is already locked
   */
  function lock(sessionID: string) {
    log.info("locking", { sessionID })
    if (state().pending.has(sessionID)) throw new BusyError(sessionID)
    const controller = new AbortController()
    state().pending.set(sessionID, controller)
    return {
      signal: controller.signal,
      [Symbol.dispose]() {
        log.info("unlocking", { sessionID })
        const current = state().pending.get(sessionID)
        // Another run owns the session now; deleting its entry would make it
        // un-abortable and let a third run start concurrently.
        if (current !== undefined && current !== controller) return
        state().pending.delete(sessionID)
        Bus.publish(Event.Idle, {
          sessionID,
        })
      },
    }
  }
  /**
   * Convert the AI SDK usage of one step into our token breakdown and its cost.
   *
   * Prices in `model.cost` are per million tokens. Cache-write counts come from
   * provider metadata (Anthropic `cacheCreationInputTokens`, Bedrock
   * `usage.cacheWriteInputTokens`). Missing counts are treated as 0.
   */
  function getUsage(model: ModelsDev.Model, usage: LanguageModelUsage, metadata?: ProviderMetadata) {
    const tokens = {
      input: usage.inputTokens ?? 0,
      output: usage.outputTokens ?? 0,
      // Reported by the AI SDK when the provider exposes reasoning usage.
      reasoning: usage.reasoningTokens ?? 0,
      cache: {
        write: (metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
          // @ts-expect-error
          metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
          0) as number,
        read: usage.cachedInputTokens ?? 0,
      },
    }
    const priced = (count: number, pricePerMillion: number) => new Decimal(count).mul(pricePerMillion).div(1_000_000)
    return {
      cost: new Decimal(0)
        .add(priced(tokens.input, model.cost.input))
        .add(priced(tokens.output, model.cost.output))
        .add(priced(tokens.cache.read, model.cost.cache_read ?? 0))
        .add(priced(tokens.cache.write, model.cost.cache_write ?? 0))
        .toNumber(),
      tokens,
    }
  }
  /** Thrown by `lock` when a session already has a run in progress. */
  export class BusyError extends Error {
    constructor(public readonly sessionID: string) {
      super(`Session ${sessionID} is busy`)
      // Give the error a distinguishable name in logs and stack traces.
      this.name = "BusyError"
    }
  }
  /**
   * Send the project-initialization prompt as a chat turn, then call
   * `App.initialize()`. The prompt is PROMPT_INITIALIZE with its `${path}`
   * placeholder replaced by the project root.
   */
  export async function initialize(input: {
    sessionID: string
    modelID: string
    providerID: string
    messageID: string
  }) {
    const app = App.info()
    // A replacer function inserts the path literally. A plain string replacement
    // would interpret `$&`, `$'`, `` $` `` and `$$` inside the path.
    const text = PROMPT_INITIALIZE.replace("${path}", () => app.path.root)
    await Session.chat({
      sessionID: input.sessionID,
      messageID: input.messageID,
      providerID: input.providerID,
      modelID: input.modelID,
      parts: [
        {
          id: Identifier.ascending("part"),
          sessionID: input.sessionID,
          messageID: input.messageID,
          type: "text",
          text,
        },
      ],
    })
    await App.initialize()
  }
  1008. }