index.ts 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063
  1. import path from "path"
  2. import { Decimal } from "decimal.js"
  3. import { z, ZodSchema } from "zod"
  4. import {
  5. generateText,
  6. LoadAPIKeyError,
  7. streamText,
  8. tool,
  9. wrapLanguageModel,
  10. type Tool as AITool,
  11. type LanguageModelUsage,
  12. type ProviderMetadata,
  13. type ModelMessage,
  14. stepCountIs,
  15. } from "ai"
  16. import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
  17. import PROMPT_PLAN from "../session/prompt/plan.txt"
  18. import { App } from "../app/app"
  19. import { Bus } from "../bus"
  20. import { Config } from "../config/config"
  21. import { Flag } from "../flag/flag"
  22. import { Identifier } from "../id/id"
  23. import { Installation } from "../installation"
  24. import { MCP } from "../mcp"
  25. import { Provider } from "../provider/provider"
  26. import { ProviderTransform } from "../provider/transform"
  27. import type { ModelsDev } from "../provider/models"
  28. import { Share } from "../share/share"
  29. import { Snapshot } from "../snapshot"
  30. import { Storage } from "../storage/storage"
  31. import { Log } from "../util/log"
  32. import { NamedError } from "../util/error"
  33. import { SystemPrompt } from "./system"
  34. import { FileTime } from "../file/time"
  35. import { MessageV2 } from "./message-v2"
  36. import { Mode } from "./mode"
  37. import { LSP } from "../lsp"
  38. import { ReadTool } from "../tool/read"
  39. export namespace Session {
  // Logger tagged with the "session" service; shared by every function in this namespace.
  const log = Log.create({ service: "session" })

  // Ceiling on the output tokens requested from a model, applied on top of the model's own output limit.
  const OUTPUT_TOKEN_MAX = 32_000
  /**
   * Schema for the persisted metadata of one session (stored under `session/info/<id>`).
   *
   * - `parentID` is set for child sessions.
   * - `share` holds the public URL once the session has been shared.
   * - `time.created` / `time.updated` are `Date.now()` timestamps.
   * - `revert` marks a pending revert point (message id + part index, plus an optional snapshot id).
   */
  export const Info = z
    .object({
      id: Identifier.schema("session"),
      parentID: Identifier.schema("session").optional(),
      share: z.object({ url: z.string() }).optional(),
      title: z.string(),
      version: z.string(),
      time: z.object({ created: z.number(), updated: z.number() }),
      revert: z
        .object({
          messageID: z.string(),
          part: z.number(),
          snapshot: z.string().optional(),
        })
        .optional(),
    })
    .openapi({ ref: "Session" })
  /** Parsed (output) shape of {@link Info}. */
  export type Info = z.output<typeof Info>
  /**
   * Schema for the share record stored under `session/share/<id>`:
   * the public `url` plus the `secret` that is later passed to `Share.remove`.
   */
  export const ShareInfo = z
    .object({ secret: z.string(), url: z.string() })
    .openapi({ ref: "SessionShare" })
  /** Parsed (output) shape of {@link ShareInfo}. */
  export type ShareInfo = z.output<typeof ShareInfo>
  /**
   * Bus events published by this namespace.
   *
   * - `Updated`: a session's info was created or changed.
   * - `Deleted`: a session was removed.
   * - `Idle`: the per-session lock was released.
   * - `Error`: a model call failed; `sessionID` is optional in the payload.
   */
  export const Event = {
    Updated: Bus.event("session.updated", z.object({ info: Info })),
    Deleted: Bus.event("session.deleted", z.object({ info: Info })),
    Idle: Bus.event("session.idle", z.object({ sessionID: z.string() })),
    Error: Bus.event(
      "session.error",
      z.object({
        sessionID: z.string().optional(),
        error: MessageV2.Assistant.shape.error,
      }),
    ),
  }
  /**
   * App-scoped state for the session service.
   *
   * - `sessions`: cache of session info keyed by session id.
   * - `messages`: per-session message cache keyed by session id.
   * - `pending`: abort controllers of sessions that currently hold the lock.
   *
   * The second callback aborts every pending controller
   * (presumably invoked by `App.state` on teardown — confirm against App).
   */
  const state = App.state(
    "session",
    () => ({
      sessions: new Map<string, Info>(),
      messages: new Map<string, MessageV2.Info[]>(),
      pending: new Map<string, AbortController>(),
    }),
    async (current) => {
      for (const controller of current.pending.values()) controller.abort()
    },
  )
  /**
   * Creates a new session, caches it, persists it and publishes `Event.Updated`.
   *
   * Root sessions (no `parentID`) are additionally shared in the background when
   * auto-share is enabled via `Flag.OPENCODE_AUTO_SHARE` or `share: "auto"` in config.
   *
   * @param parentID id of the parent session when creating a child session
   * @returns the freshly created session info
   */
  export async function create(parentID?: string) {
    const now = Date.now()
    const result: Info = {
      id: Identifier.descending("session"),
      version: Installation.VERSION,
      parentID,
      title: (parentID ? "Child session - " : "New Session - ") + new Date(now).toISOString(),
      time: {
        created: now,
        updated: now,
      },
    }
    log.info("created", result)
    state().sessions.set(result.id, result)
    await Storage.writeJSON("session/info/" + result.id, result)

    const cfg = await Config.get()
    if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.share === "auto")) {
      // Best-effort background share: session creation must not fail because sharing did.
      void share(result.id)
        .then((shared) =>
          // Returning the promise routes a failing update into the catch below.
          update(result.id, (draft) => {
            // Copy only the URL: `share()` may resolve to a ShareInfo, whose
            // `secret` must not end up in the session info.
            draft.share = { url: shared.url }
          }),
        )
        .catch((error) => {
          log.error("auto-share failed", { error })
        })
    }

    Bus.publish(Event.Updated, {
      info: result,
    })
    return result
  }
  /**
   * Returns the info for a session, preferring the in-memory cache and
   * falling back to `session/info/<id>` in storage (the loaded value is then cached).
   */
  export async function get(id: string) {
    const cached = state().sessions.get(id)
    if (cached) return cached

    const loaded = await Storage.readJSON<Info>("session/info/" + id)
    state().sessions.set(id, loaded)
    return loaded as Info
  }
  /** Reads the stored share record (`session/share/<id>`) for a session. */
  export async function getShare(id: string) {
    const key = "session/share/" + id
    return Storage.readJSON<ShareInfo>(key)
  }
  /**
   * Shares a session and syncs its current info and messages to the share backend.
   *
   * @returns the existing `session.share` when the session is already shared,
   *          otherwise the newly created share record
   * @throws Error when sharing is disabled in configuration
   */
  export async function share(id: string) {
    const cfg = await Config.get()
    if (cfg.share === "disabled") throw new Error("Sharing is disabled in configuration")

    const session = await get(id)
    if (session.share) return session.share

    const created = await Share.create(id)
    await update(id, (draft) => {
      draft.share = { url: created.url }
    })
    await Storage.writeJSON<ShareInfo>("session/share/" + id, created)

    // Push the session info first, then every message, one at a time.
    await Share.sync("session/info/" + id, session)
    const history = await messages(id)
    for (const message of history) {
      await Share.sync("session/message/" + id + "/" + message.id, message)
    }
    return created
  }
  /**
   * Removes a session's share: deletes the stored share record, clears
   * `share` on the session info and removes the share remotely using its secret.
   * Does nothing when no share record is found.
   */
  export async function unshare(id: string) {
    const existing = await getShare(id)
    if (!existing) return

    await Storage.remove("session/share/" + id)
    await update(id, (draft) => {
      draft.share = undefined
    })
    await Share.remove(id, existing.secret)
  }
  /**
   * Applies `editor` to a session's info in place, bumps `time.updated`,
   * persists the result and publishes `Event.Updated`.
   *
   * @param editor mutates the session info it is given
   * @returns the updated session info, or `undefined` when the session could not be resolved
   */
  export async function update(id: string, editor: (session: Info) => void) {
    const cache = state().sessions
    const target = await get(id)
    if (!target) return

    editor(target)
    target.time.updated = Date.now()
    cache.set(id, target)
    await Storage.writeJSON("session/info/" + id, target)
    Bus.publish(Event.Updated, { info: target })
    return target
  }
  /**
   * Loads every stored message of a session, ordered by ascending message id.
   */
  export async function messages(sessionID: string) {
    const collected: MessageV2.Info[] = []
    for await (const key of Storage.list("session/message/" + sessionID)) {
      collected.push(await Storage.readJSON<MessageV2.Info>(key))
    }
    return collected.sort((left, right) => (left.id > right.id ? 1 : -1))
  }
  /** Reads a single stored message (`session/message/<sessionID>/<messageID>`). */
  export async function getMessage(sessionID: string, messageID: string) {
    const key = "session/message/" + sessionID + "/" + messageID
    return Storage.readJSON<MessageV2.Info>(key)
  }
  /**
   * Async generator over all sessions found under `session/info`;
   * the session id is taken from each entry's file name without `.json`.
   */
  export async function* list() {
    for await (const entry of Storage.list("session/info")) {
      yield get(path.basename(entry, ".json"))
    }
  }
  /**
   * Returns every session whose `parentID` equals the given id.
   * Scans all entries under `session/info`.
   */
  export async function children(parentID: string) {
    const matches: Session.Info[] = []
    for await (const entry of Storage.list("session/info")) {
      const candidate = await get(path.basename(entry, ".json"))
      if (candidate.parentID === parentID) matches.push(candidate)
    }
    return matches
  }
  /**
   * Aborts the in-flight operation of a session, if there is one.
   *
   * @returns `true` when a pending controller was aborted and removed, `false` otherwise
   */
  export function abort(sessionID: string) {
    const controller = state().pending.get(sessionID)
    if (controller) {
      controller.abort()
      state().pending.delete(sessionID)
      return true
    }
    return false
  }
  /**
   * Deletes a session and, recursively, all of its child sessions.
   *
   * Aborts any in-flight work, unshares, removes stored info and messages,
   * evicts the caches and optionally publishes `Event.Deleted`.
   * Any error is logged, never thrown.
   *
   * @param emitEvent publish `Event.Deleted` for this session (children are removed with `false`)
   */
  export async function remove(sessionID: string, emitEvent = true) {
    // Storage/share cleanup is best-effort; individual failures are ignored.
    const ignore = () => {}
    try {
      abort(sessionID)
      const session = await get(sessionID)

      const descendants = await children(sessionID)
      for (const child of descendants) {
        await remove(child.id, false)
      }

      await unshare(sessionID).catch(ignore)
      await Storage.remove(`session/info/${sessionID}`).catch(ignore)
      await Storage.removeDir(`session/message/${sessionID}/`).catch(ignore)

      state().sessions.delete(sessionID)
      state().messages.delete(sessionID)

      if (!emitEvent) return
      Bus.publish(Event.Deleted, { info: session })
    } catch (e) {
      log.error(e)
    }
  }
  /** Persists a message under `session/message/<sessionID>/<id>` and publishes `MessageV2.Event.Updated`. */
  async function updateMessage(msg: MessageV2.Info) {
    const key = "session/message/" + msg.sessionID + "/" + msg.id
    await Storage.writeJSON(key, msg)
    Bus.publish(MessageV2.Event.Updated, { info: msg })
  }
  /**
   * Sends a user message to the model and streams the assistant reply into storage.
   *
   * In order:
   * 1. Applies a pending `session.revert`: later messages are deleted, the reverted
   *    message is truncated to the parts before the revert point.
   * 2. If the last assistant message used more than 90% of the context that is left
   *    after reserving output tokens, summarizes the conversation and starts over.
   * 3. Takes the per-session lock (released automatically when the function exits).
   * 4. Expands `file:` parts into synthetic text (Read tool output) or inline data URLs.
   * 5. For the first message of a root session, generates a title in the background.
   * 6. Persists the user message and an empty assistant message, then consumes the
   *    model stream, updating the assistant message after every handled stream part.
   *
   * @param input.mode mode name handed to `Mode.get` (defaults to "build");
   *                   "plan" additionally appends the plan prompt to the user parts
   * @param input.parts user parts; note that `input.parts` is replaced/extended in place
   * @returns the assistant message; `error` is set on it when streaming failed
   * @throws BusyError when the session already holds the lock
   */
  export async function chat(input: {
    sessionID: string
    providerID: string
    modelID: string
    mode?: string
    parts: MessageV2.UserPart[]
  }) {
    const l = log.clone().tag("session", input.sessionID)
    l.info("chatting")

    const model = await Provider.getModel(input.providerID, input.modelID)
    let msgs = await messages(input.sessionID)
    const session = await get(input.sessionID)

    if (session.revert) {
      const { messageID: revertID, part: revertPart } = session.revert
      const trimmed: MessageV2.Info[] = []
      for (const msg of msgs) {
        // Everything after the revert point goes away; so does the reverted
        // message itself when the revert starts at its very first part.
        const discard = msg.id > revertID || (msg.id === revertID && revertPart === 0)
        if (discard) {
          await Storage.remove("session/message/" + input.sessionID + "/" + msg.id)
          await Bus.publish(MessageV2.Event.Removed, {
            sessionID: input.sessionID,
            messageID: msg.id,
          })
          continue
        }
        if (msg.id === revertID) {
          msg.parts = msg.parts.slice(0, revertPart)
        }
        trimmed.push(msg)
      }
      msgs = trimmed
      await update(input.sessionID, (draft) => {
        draft.revert = undefined
      })
    }

    const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX

    // auto summarize if too long
    // Only assistant messages carry token usage, so a trailing user message is skipped.
    // A summary message is skipped too: its usage describes the conversation *before*
    // it was condensed, so re-checking it would immediately summarize the summary.
    const previous = msgs.at(-1)
    if (previous?.role === "assistant" && !previous.summary) {
      const tokens =
        previous.tokens.input + previous.tokens.cache.read + previous.tokens.cache.write + previous.tokens.output
      if (model.info.limit.context && tokens > Math.max((model.info.limit.context - outputLimit) * 0.9, 0)) {
        await summarize({
          sessionID: input.sessionID,
          providerID: input.providerID,
          modelID: input.modelID,
        })
        return chat(input)
      }
    }

    // Must come after the summarize branch above: summarize() takes the same lock.
    using abort = lock(input.sessionID)

    // Only the latest summary and what follows it is sent to the model.
    const lastSummary = msgs.findLast((msg) => msg.role === "assistant" && msg.summary === true)
    if (lastSummary) msgs = msgs.filter((msg) => msg.id >= lastSummary.id)

    const app = App.info()
    const expanded = await Promise.all(
      input.parts.map(async (part): Promise<MessageV2.UserPart[]> => {
        if (part.type !== "file") return [part]
        const url = new URL(part.url)
        if (url.protocol !== "file:") return [part]

        // have to normalize, symbol search returns absolute paths
        // Decode the pathname since URL constructor doesn't automatically decode it
        const pathname = decodeURIComponent(url.pathname)
        const relativePath = pathname.replace(app.path.cwd, ".")
        const filePath = path.join(app.path.cwd, relativePath)

        if (part.mime === "text/plain") {
          let offset: number | undefined
          let limit: number | undefined
          const rangeStart = url.searchParams.get("start")
          const rangeEnd = url.searchParams.get("end")
          if (rangeStart != null) {
            const documentURI = part.url.split("?")[0]
            let start = parseInt(rangeStart, 10)
            let end = rangeEnd ? parseInt(rangeEnd, 10) : undefined
            // some LSP servers (eg, gopls) don't give full range in
            // workspace/symbol searches, so we'll try to find the
            // symbol in the document to get the full range
            if (start === end) {
              const symbols = await LSP.documentSymbol(documentURI)
              for (const symbol of symbols) {
                let symbolRange: LSP.Range | undefined
                if ("range" in symbol) {
                  symbolRange = symbol.range
                } else if ("location" in symbol) {
                  symbolRange = symbol.location.range
                }
                // Strict comparison so that a symbol starting on line 0 matches too.
                if (symbolRange && symbolRange.start?.line === start) {
                  start = symbolRange.start.line
                  end = symbolRange.end?.line ?? start
                  break
                }
              }
            }
            // Applies to every requested range, not only to collapsed (start === end)
            // ones; an unparsable start falls back to reading the whole file.
            if (!Number.isNaN(start)) {
              offset = Math.max(start - 2, 0)
              if (end) {
                limit = end - offset + 2
              }
            }
          }
          const args = { filePath, offset, limit }
          const result = await ReadTool.execute(args, {
            sessionID: input.sessionID,
            abort: abort.signal,
            messageID: "", // read tool doesn't use message ID
            metadata: async () => {},
          })
          return [
            {
              type: "text",
              synthetic: true,
              text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
            },
            {
              type: "text",
              synthetic: true,
              text: result.output,
            },
          ]
        }

        // Any other mime type is inlined as a base64 data URL.
        const file = Bun.file(filePath)
        FileTime.read(input.sessionID, filePath)
        return [
          {
            type: "text",
            text: `Called the Read tool with the following input: {"filePath":"${pathname}"}`,
            synthetic: true,
          },
          {
            type: "file",
            url: `data:${part.mime};base64,` + Buffer.from(await file.bytes()).toString("base64"),
            mime: part.mime,
            filename: part.filename!,
          },
        ]
      }),
    )
    input.parts = expanded.flat()

    if (input.mode === "plan") {
      input.parts.push({
        type: "text",
        text: PROMPT_PLAN,
        synthetic: true,
      })
    }

    // First message of a root session: derive a title in the background (best effort).
    if (msgs.length === 0 && !session.parentID) {
      void generateText({
        maxOutputTokens: input.providerID === "google" ? 1024 : 20,
        providerOptions: model.info.options,
        messages: [
          ...SystemPrompt.title(input.providerID).map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          ...MessageV2.toModelMessage([
            {
              id: Identifier.ascending("message"),
              role: "user",
              sessionID: input.sessionID,
              parts: input.parts,
              time: {
                created: Date.now(),
              },
            },
          ]),
        ],
        model: model.language,
      })
        .then((result) => {
          if (result.text)
            return Session.update(input.sessionID, (draft) => {
              draft.title = result.text
            })
        })
        .catch(() => {})
    }

    const msg: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "user",
      sessionID: input.sessionID,
      parts: input.parts,
      time: {
        created: Date.now(),
      },
    }
    await updateMessage(msg)
    msgs.push(msg)

    const mode = await Mode.get(input.mode ?? "build")
    // Build a fresh array instead of pushing into whatever SystemPrompt.provider returned.
    const prompts = [
      ...(mode.prompt ? [mode.prompt] : SystemPrompt.provider(input.providerID, input.modelID)),
      ...(await SystemPrompt.environment()),
      ...(await SystemPrompt.custom()),
    ]
    // max 2 system prompt messages for caching purposes
    const [first, ...rest] = prompts
    const system = [first, rest.join("\n")]

    const next: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      parts: [],
      system,
      path: {
        cwd: app.path.cwd,
        root: app.path.root,
      },
      cost: 0,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.modelID,
      providerID: input.providerID,
      time: {
        created: Date.now(),
      },
      sessionID: input.sessionID,
    }
    await updateMessage(next)

    const tools: Record<string, AITool> = {}

    // Provider tools, minus those the mode disables.
    for (const item of await Provider.tools(input.providerID)) {
      if (mode.tools[item.id] === false) continue
      tools[item.id] = tool({
        id: item.id as any,
        description: item.description,
        inputSchema: item.parameters as ZodSchema,
        async execute(args, opts) {
          const result = await item.execute(args, {
            sessionID: input.sessionID,
            abort: abort.signal,
            messageID: next.id,
            // Lets a running tool publish an interim title/metadata on its part.
            metadata: async (val) => {
              const match = next.parts.find(
                (p): p is MessageV2.ToolPart => p.type === "tool" && p.id === opts.toolCallId,
              )
              if (match && match.state.status === "running") {
                match.state.title = val.title
                match.state.metadata = val.metadata
              }
              await updateMessage(next)
            },
          })
          return result
        },
        toModelOutput(result) {
          return {
            type: "text",
            value: result.output,
          }
        },
      })
    }

    // MCP tools: reduce their result to the joined text content.
    // NOTE(review): this reassigns `execute` on the object returned by MCP.tools();
    // if those objects are reused across calls the wrapper would be applied twice — confirm.
    for (const [key, item] of Object.entries(await MCP.tools())) {
      if (mode.tools[key] === false) continue
      const execute = item.execute
      if (!execute) continue
      item.execute = async (args, opts) => {
        const result = await execute(args, opts)
        const output = result.content
          .filter((x: any) => x.type === "text")
          .map((x: any) => x.text)
          .join("\n\n")
        return {
          output,
        }
      }
      item.toModelOutput = (result) => {
        return {
          type: "text",
          value: result.output,
        }
      }
      tools[key] = item
    }

    let text: MessageV2.TextPart = {
      type: "text",
      text: "",
    }
    const result = streamText({
      onError() {},
      maxRetries: 10,
      maxOutputTokens: outputLimit,
      abortSignal: abort.signal,
      stopWhen: stepCountIs(1000),
      providerOptions: model.info.options,
      messages: [
        ...system.map(
          (x): ModelMessage => ({
            role: "system",
            content: x,
          }),
        ),
        ...MessageV2.toModelMessage(msgs),
      ],
      temperature: model.info.temperature ? 0 : undefined,
      tools: model.info.tool_call === false ? undefined : tools,
      model: wrapLanguageModel({
        model: model.language,
        middleware: [
          {
            async transformParams(args) {
              if (args.type === "stream") {
                // @ts-expect-error
                args.params.prompt = ProviderTransform.message(args.params.prompt, input.providerID, input.modelID)
              }
              return args.params
            },
          },
        ],
      }),
    })

    try {
      for await (const value of result.fullStream) {
        l.info("part", {
          type: value.type,
        })
        switch (value.type) {
          case "start":
            break

          case "tool-input-start":
            next.parts.push({
              type: "tool",
              tool: value.toolName,
              id: value.id,
              state: {
                status: "pending",
              },
            })
            Bus.publish(MessageV2.Event.PartUpdated, {
              part: next.parts[next.parts.length - 1],
              sessionID: next.sessionID,
              messageID: next.id,
            })
            break

          case "tool-input-delta":
            break

          case "tool-call": {
            const match = next.parts.find(
              (p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId,
            )
            if (match) {
              match.state = {
                status: "running",
                input: value.input,
                time: {
                  start: Date.now(),
                },
              }
              Bus.publish(MessageV2.Event.PartUpdated, {
                part: match,
                sessionID: next.sessionID,
                messageID: next.id,
              })
            }
            break
          }

          case "tool-result": {
            const match = next.parts.find(
              (p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId,
            )
            if (match && match.state.status === "running") {
              match.state = {
                status: "completed",
                input: value.input,
                output: value.output.output,
                metadata: value.output.metadata,
                title: value.output.title,
                time: {
                  start: match.state.time.start,
                  end: Date.now(),
                },
              }
              Bus.publish(MessageV2.Event.PartUpdated, {
                part: match,
                sessionID: next.sessionID,
                messageID: next.id,
              })
            }
            break
          }

          case "tool-error": {
            const match = next.parts.find(
              (p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId,
            )
            if (match && match.state.status === "running") {
              match.state = {
                status: "error",
                input: value.input,
                error: (value.error as any).toString(),
                time: {
                  start: match.state.time.start,
                  end: Date.now(),
                },
              }
              Bus.publish(MessageV2.Event.PartUpdated, {
                part: match,
                sessionID: next.sessionID,
                messageID: next.id,
              })
            }
            break
          }

          case "error":
            throw value.error

          case "start-step":
            next.parts.push({
              type: "step-start",
            })
            break

          case "finish-step": {
            const usage = getUsage(model.info, value.usage, value.providerMetadata)
            next.cost += usage.cost
            next.tokens = usage.tokens
            next.parts.push({
              type: "step-finish",
              tokens: usage.tokens,
              cost: usage.cost,
            })
            break
          }

          case "text-start":
            text = {
              type: "text",
              text: "",
            }
            break

          case "text":
            // The part is attached lazily, on its first chunk.
            if (text.text === "") next.parts.push(text)
            text.text += value.text
            break

          case "text-end":
            Bus.publish(MessageV2.Event.PartUpdated, {
              part: text,
              sessionID: next.sessionID,
              messageID: next.id,
            })
            break

          case "finish":
            next.time.completed = Date.now()
            break

          default:
            l.info("unhandled", {
              ...value,
            })
            // Nothing changed on the message, so skip the persist below.
            continue
        }
        await updateMessage(next)
      }
    } catch (e) {
      log.error("", {
        error: e,
      })
      switch (true) {
        case e instanceof DOMException && e.name === "AbortError":
          next.error = new MessageV2.AbortedError(
            { message: e.message },
            {
              cause: e,
            },
          ).toObject()
          break
        case MessageV2.OutputLengthError.isInstance(e):
          next.error = e
          break
        case LoadAPIKeyError.isInstance(e):
          next.error = new Provider.AuthError(
            {
              providerID: input.providerID,
              message: e.message,
            },
            { cause: e },
          ).toObject()
          break
        case e instanceof Error:
          next.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
          break
        default:
          // Serialized like every other branch so the stored message stays plain data.
          next.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e }).toObject()
      }
      Bus.publish(Event.Error, {
        sessionID: next.sessionID,
        error: next.error,
      })
    }

    // Tool calls that never finished (still pending/running) are closed as aborted.
    // Parts already in "error" keep the real error reported by the tool.
    for (const part of next.parts) {
      if (part.type !== "tool") continue
      if (part.state.status === "completed" || part.state.status === "error") continue
      part.state = {
        status: "error",
        error: "Tool execution aborted",
        time: {
          start: Date.now(),
          end: Date.now(),
        },
        input: {},
      }
    }
    next.time.completed = Date.now()
    await updateMessage(next)
    return next
  }
  /**
   * Reverts a session to a given message part.
   *
   * Not implemented yet: the function currently ignores its input and does nothing.
   * The sketch below is the previous implementation, kept for reference; it targets
   * an older message shape ("tool-invocation" parts, `message.metadata`).
   */
  export async function revert(_input: { sessionID: string; messageID: string; part: number }) {
    // TODO
    //
    // const message = await getMessage(input.sessionID, input.messageID)
    // if (!message) return
    // const part = message.parts[input.part]
    // if (!part) return
    // const session = await get(input.sessionID)
    // const snapshot =
    //   session.revert?.snapshot ?? (await Snapshot.create(input.sessionID))
    // const old = (() => {
    //   if (message.role === "assistant") {
    //     const lastTool = message.parts.findLast(
    //       (part, index) =>
    //         part.type === "tool-invocation" && index < input.part,
    //     )
    //     if (lastTool && lastTool.type === "tool-invocation")
    //       return message.metadata.tool[lastTool.toolInvocation.toolCallId]
    //         .snapshot
    //   }
    //   return message.metadata.snapshot
    // })()
    // if (old) await Snapshot.restore(input.sessionID, old)
    // await update(input.sessionID, (draft) => {
    //   draft.revert = {
    //     messageID: input.messageID,
    //     part: input.part,
    //     snapshot,
    //   }
    // })
  }
  /**
   * Cancels a pending revert: restores the snapshot recorded on the revert marker
   * (when there is one) and clears `revert` on the session.
   * Does nothing when the session has no pending revert.
   */
  export async function unrevert(sessionID: string) {
    const session = await get(sessionID)
    if (!session?.revert) return
    if (session.revert.snapshot) await Snapshot.restore(sessionID, session.revert.snapshot)
    // Awaited so the caller only continues once the marker is cleared and persisted,
    // and so a failing write surfaces here instead of as an unhandled rejection.
    await update(sessionID, (draft) => {
      draft.revert = undefined
    })
  }
  /**
   * Asks the model for a summary of the conversation since the last summary and
   * stores it as an assistant message flagged `summary: true`.
   *
   * Holds the per-session lock for the duration of the call. Failures are recorded
   * on the summary message (`error`) and published as `Event.Error`; they are not rethrown.
   *
   * @throws BusyError when the session already holds the lock
   */
  export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {
    using abort = lock(input.sessionID)

    const msgs = await messages(input.sessionID)
    const lastSummary = msgs.findLast((msg) => msg.role === "assistant" && msg.summary === true)?.id
    const filtered = msgs.filter((msg) => !lastSummary || msg.id >= lastSummary)
    const model = await Provider.getModel(input.providerID, input.modelID)
    const app = App.info()
    const system = SystemPrompt.summarize(input.providerID)

    const next: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      parts: [],
      sessionID: input.sessionID,
      system,
      path: {
        cwd: app.path.cwd,
        root: app.path.root,
      },
      summary: true,
      cost: 0,
      modelID: input.modelID,
      providerID: input.providerID,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      time: {
        created: Date.now(),
      },
    }
    await updateMessage(next)

    // Text part of the step currently being streamed; reset after every step.
    let text: MessageV2.TextPart | undefined
    const result = streamText({
      abortSignal: abort.signal,
      model: model.language,
      messages: [
        ...system.map(
          (x): ModelMessage => ({
            role: "system",
            content: x,
          }),
        ),
        ...MessageV2.toModelMessage(filtered),
        {
          role: "user",
          content: [
            {
              type: "text",
              text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
            },
          ],
        },
      ],
      onStepFinish: async (step) => {
        const usage = getUsage(model.info, step.usage, step.providerMetadata)
        next.cost += usage.cost
        next.tokens = usage.tokens
        await updateMessage(next)
        if (text) {
          Bus.publish(MessageV2.Event.PartUpdated, {
            part: text,
            messageID: next.id,
            sessionID: next.sessionID,
          })
        }
        text = undefined
      },
      async onFinish(finished) {
        // Cost is accumulated per step in onStepFinish; adding it here again
        // would bill the final step twice.
        next.tokens = getUsage(model.info, finished.usage, finished.providerMetadata).tokens
        next.time.completed = Date.now()
        await updateMessage(next)
      },
    })

    try {
      for await (const value of result.fullStream) {
        switch (value.type) {
          case "text":
            if (!text) {
              text = {
                type: "text",
                text: value.text,
              }
              next.parts.push(text)
            } else text.text += value.text
            await updateMessage(next)
            break
          case "error":
            // Surface provider errors instead of silently ending with an empty summary.
            throw value.error
        }
      }
    } catch (e) {
      log.error("summarize stream error", {
        error: e,
      })
      switch (true) {
        case e instanceof DOMException && e.name === "AbortError":
          next.error = new MessageV2.AbortedError(
            { message: e.message },
            {
              cause: e,
            },
          ).toObject()
          break
        case MessageV2.OutputLengthError.isInstance(e):
          next.error = e
          break
        case LoadAPIKeyError.isInstance(e):
          next.error = new Provider.AuthError(
            {
              providerID: input.providerID,
              message: e.message,
            },
            { cause: e },
          ).toObject()
          break
        case e instanceof Error:
          next.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
          break
        default:
          next.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e }).toObject()
      }
      Bus.publish(Event.Error, {
        sessionID: next.sessionID,
        error: next.error,
      })
    }
    next.time.completed = Date.now()
    await updateMessage(next)
  }
  /**
   * Marks a session as busy and returns a disposable handle for use with `using`.
   *
   * The handle exposes the abort `signal` of the controller registered in
   * `state().pending`; disposing it unregisters the session and publishes `Event.Idle`.
   *
   * @throws BusyError when the session is already registered as pending
   */
  function lock(sessionID: string) {
    log.info("locking", { sessionID })
    const { pending } = state()
    if (pending.has(sessionID)) throw new BusyError(sessionID)

    const controller = new AbortController()
    pending.set(sessionID, controller)

    const release = () => {
      log.info("unlocking", { sessionID })
      state().pending.delete(sessionID)
      Bus.publish(Event.Idle, { sessionID })
    }
    return {
      signal: controller.signal,
      [Symbol.dispose]: release,
    }
  }
  /**
   * Converts the SDK's usage report into the token breakdown stored on messages
   * and prices it with the model's per-million-token rates.
   *
   * Cache-write tokens are not part of `usage`; they are read from provider metadata
   * (`anthropic.cacheCreationInputTokens`, else `bedrock.usage.cacheWriteInputTokens`).
   * Reasoning tokens are recorded but not priced separately.
   *
   * @returns `{ cost, tokens }`; `cost` is in the currency unit of `model.cost`
   */
  function getUsage(model: ModelsDev.Model, usage: LanguageModelUsage, metadata?: ProviderMetadata) {
    const tokens = {
      input: usage.inputTokens ?? 0,
      output: usage.outputTokens ?? 0,
      reasoning: usage.reasoningTokens ?? 0,
      cache: {
        write: (metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
          // @ts-expect-error
          metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
          0) as number,
        read: usage.cachedInputTokens ?? 0,
      },
    }

    // Prices are quoted per one million tokens; Decimal avoids float drift when summing.
    const priced = (count: number, perMillion: number) => new Decimal(count).mul(perMillion).div(1_000_000)

    return {
      cost: new Decimal(0)
        .add(priced(tokens.input, model.cost.input))
        .add(priced(tokens.output, model.cost.output))
        .add(priced(tokens.cache.read, model.cost.cache_read ?? 0))
        .add(priced(tokens.cache.write, model.cost.cache_write ?? 0))
        .toNumber(),
      tokens,
    }
  }
  /**
   * Thrown when an operation needs the per-session lock but the session
   * is already processing another request.
   */
  export class BusyError extends Error {
    /**
     * @param sessionID id of the session that is busy
     */
    constructor(public readonly sessionID: string) {
      super(`Session ${sessionID} is busy`)
      // Without this, logs and toString() report the generic "Error".
      this.name = "BusyError"
    }
  }
  /**
   * Runs the initialization prompt in the given session (with `${path}` replaced by
   * the app root) and then calls `App.initialize()`.
   */
  export async function initialize(input: { sessionID: string; modelID: string; providerID: string }) {
    const { sessionID, providerID, modelID } = input
    const root = App.info().path.root
    const prompt = PROMPT_INITIALIZE.replace("${path}", root)

    await Session.chat({
      sessionID,
      providerID,
      modelID,
      parts: [{ type: "text", text: prompt }],
    })
    await App.initialize()
  }
  999. }