index.ts 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048
  1. import path from "path"
  2. import { Decimal } from "decimal.js"
  3. import { z, ZodSchema } from "zod"
  4. import {
  5. generateText,
  6. LoadAPIKeyError,
  7. streamText,
  8. tool,
  9. wrapLanguageModel,
  10. type Tool as AITool,
  11. type LanguageModelUsage,
  12. type ProviderMetadata,
  13. type ModelMessage,
  14. stepCountIs,
  15. } from "ai"
  16. import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
  17. import PROMPT_PLAN from "../session/prompt/plan.txt"
  18. import { App } from "../app/app"
  19. import { Bus } from "../bus"
  20. import { Config } from "../config/config"
  21. import { Flag } from "../flag/flag"
  22. import { Identifier } from "../id/id"
  23. import { Installation } from "../installation"
  24. import { MCP } from "../mcp"
  25. import { Provider } from "../provider/provider"
  26. import { ProviderTransform } from "../provider/transform"
  27. import type { ModelsDev } from "../provider/models"
  28. import { Share } from "../share/share"
  29. import { Snapshot } from "../snapshot"
  30. import { Storage } from "../storage/storage"
  31. import { Log } from "../util/log"
  32. import { NamedError } from "../util/error"
  33. import { SystemPrompt } from "./system"
  34. import { FileTime } from "../file/time"
  35. import { MessageV2 } from "./message-v2"
  36. import { Mode } from "./mode"
  37. import { LSP } from "../lsp"
  38. import { ReadTool } from "../tool/read"
  39. export namespace Session {
  const log = Log.create({ service: "session" })

  /**
   * Persisted metadata of one session (stored under `session/info/<id>`).
   * `share` holds only the public URL; the share secret lives separately in
   * ShareInfo. `revert`, when set, marks the point chat() trims history back
   * to before the next request.
   */
  export const Info = z
    .object({
      id: Identifier.schema("session"),
      parentID: Identifier.schema("session").optional(),
      share: z
        .object({
          url: z.string(),
        })
        .optional(),
      title: z.string(),
      version: z.string(),
      time: z.object({
        created: z.number(),
        updated: z.number(),
      }),
      revert: z
        .object({
          messageID: z.string(),
          part: z.number(),
          snapshot: z.string().optional(),
        })
        .optional(),
    })
    .openapi({
      ref: "Session",
    })
  export type Info = z.output<typeof Info>

  /** Credentials of a remote share (stored under `session/share/<id>`). */
  export const ShareInfo = z
    .object({
      secret: z.string(),
      url: z.string(),
    })
    .openapi({
      ref: "SessionShare",
    })
  export type ShareInfo = z.output<typeof ShareInfo>

  /** Bus events emitted by this module. */
  export const Event = {
    // Published whenever session info is created or edited.
    Updated: Bus.event("session.updated", z.object({ info: Info })),
    // Published by remove() for the top-level deleted session.
    Deleted: Bus.event("session.deleted", z.object({ info: Info })),
    // Published when a session's lock is released.
    Idle: Bus.event("session.idle", z.object({ sessionID: z.string() })),
    // Published when a chat or summary stream fails.
    Error: Bus.event(
      "session.error",
      z.object({
        sessionID: z.string().optional(),
        error: MessageV2.Assistant.shape.error,
      }),
    ),
  }
  /**
   * Per-app in-memory session state: cached session info, cached messages,
   * and one AbortController per session that currently holds the lock.
   * The teardown hook handed to App.state (presumably run on shutdown)
   * aborts every in-flight request.
   */
  const state = App.state(
    "session",
    () => ({
      sessions: new Map<string, Info>(),
      messages: new Map<string, MessageV2.Info[]>(),
      pending: new Map<string, AbortController>(),
    }),
    async (current) => {
      for (const controller of current.pending.values()) controller.abort()
    },
  )
  /**
   * Creates a session (a child session when `parentID` is given), caches and
   * persists it, and publishes `session.updated`.
   *
   * Top-level sessions are shared automatically in the background when
   * `Flag.OPENCODE_AUTO_SHARE` or `config.autoshare` is set; a sharing
   * failure is logged and does not affect the returned session.
   */
  export async function create(parentID?: string) {
    const now = Date.now()
    const result: Info = {
      id: Identifier.descending("session"),
      version: Installation.VERSION,
      parentID,
      title: (parentID ? "Child session - " : "New Session - ") + new Date(now).toISOString(),
      time: {
        created: now,
        updated: now,
      },
    }
    log.info("created", result)
    state().sessions.set(result.id, result)
    await Storage.writeJSON("session/info/" + result.id, result)
    const cfg = await Config.get()
    if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.autoshare)) {
      // share() already records `{ url }` on the session. Assigning its return
      // value (which includes the share secret) to `draft.share` would leak the
      // secret into session info, so only handle the rejection here.
      share(result.id).catch((error) => log.error("auto share failed", { error }))
    }
    Bus.publish(Event.Updated, {
      info: result,
    })
    return result
  }
  /**
   * Returns session info by id, serving from the in-memory cache and falling
   * back to storage (the loaded value is cached).
   */
  export async function get(id: string) {
    const { sessions } = state()
    const hit = sessions.get(id)
    if (hit) return hit
    const loaded = await Storage.readJSON<Info>(`session/info/${id}`)
    sessions.set(id, loaded)
    return loaded as Info
  }
  /** Reads the stored share credentials (secret + URL) of a session. */
  export async function getShare(id: string) {
    const key = `session/share/${id}`
    return Storage.readJSON<ShareInfo>(key)
  }
  /**
   * Shares a session. If it is already shared, returns the existing
   * `{ url }` from session info. Otherwise creates a remote share, records its
   * URL on the session, stores the credentials, syncs the session info and
   * every message, and returns the new share (including its secret).
   */
  export async function share(id: string) {
    const session = await get(id)
    if (session.share) return session.share

    const created = await Share.create(id)
    await update(id, (draft) => {
      draft.share = { url: created.url }
    })
    await Storage.writeJSON<ShareInfo>(`session/share/${id}`, created)

    await Share.sync(`session/info/${id}`, session)
    const history = await messages(id)
    for (const message of history) {
      await Share.sync(`session/message/${id}/${message.id}`, message)
    }
    return created
  }
  /**
   * Stops sharing a session: deletes the stored credentials, clears the URL
   * from session info, then revokes the remote share. No-op when the session
   * has no stored share.
   */
  export async function unshare(id: string) {
    const existing = await getShare(id)
    if (!existing) return
    await Storage.remove(`session/share/${id}`)
    await update(id, (draft) => {
      draft.share = undefined
    })
    await Share.remove(id, existing.secret)
  }
  /**
   * Applies `editor` to the session in place, bumps `time.updated`, persists
   * it and publishes `session.updated`. Returns the edited session, or
   * undefined when the session could not be loaded.
   */
  export async function update(id: string, editor: (session: Info) => void) {
    const cache = state().sessions
    const target = await get(id)
    if (!target) return
    editor(target)
    target.time.updated = Date.now()
    cache.set(id, target)
    await Storage.writeJSON(`session/info/${id}`, target)
    Bus.publish(Event.Updated, { info: target })
    return target
  }
  /** Loads every stored message of a session, sorted by ascending message id. */
  export async function messages(sessionID: string) {
    const loaded: MessageV2.Info[] = []
    for await (const key of Storage.list(`session/message/${sessionID}`)) {
      loaded.push(await Storage.readJSON<MessageV2.Info>(key))
    }
    return loaded.sort((a, b) => (a.id > b.id ? 1 : -1))
  }
  /** Reads one stored message of a session. */
  export async function getMessage(sessionID: string, messageID: string) {
    const key = ["session", "message", sessionID, messageID].join("/")
    return Storage.readJSON<MessageV2.Info>(key)
  }
  /** Lazily yields every stored session, loading each through get(). */
  export async function* list() {
    for await (const entry of Storage.list("session/info")) {
      const id = path.basename(entry, ".json")
      yield get(id)
    }
  }
  /** Returns the direct children of `parentID` by scanning every stored session. */
  export async function children(parentID: string) {
    const found: Session.Info[] = []
    for await (const entry of Storage.list("session/info")) {
      const candidate = await get(path.basename(entry, ".json"))
      if (candidate.parentID === parentID) found.push(candidate)
    }
    return found
  }
  /**
   * Aborts the request currently holding the session's lock and frees the
   * lock slot. Returns false when nothing was running.
   */
  export function abort(sessionID: string) {
    const { pending } = state()
    const controller = pending.get(sessionID)
    if (controller === undefined) return false
    controller.abort()
    pending.delete(sessionID)
    return true
  }
  /**
   * Deletes a session and, recursively, its child sessions: aborts any
   * in-flight request, unshares it, removes its stored info and messages and
   * drops it from the cache. `session.deleted` is published only when
   * `emitEvent` is true (children are removed with `false`). Any error is
   * logged rather than thrown.
   */
  export async function remove(sessionID: string, emitEvent = true) {
    try {
      abort(sessionID)
      const session = await get(sessionID)
      for (const child of await children(sessionID)) {
        await remove(child.id, false)
      }
      // Best-effort cleanup: failures in these individual steps are ignored.
      const ignore = () => {}
      await unshare(sessionID).catch(ignore)
      await Storage.remove(`session/info/${sessionID}`).catch(ignore)
      await Storage.removeDir(`session/message/${sessionID}/`).catch(ignore)
      const cache = state()
      cache.sessions.delete(sessionID)
      cache.messages.delete(sessionID)
      if (emitEvent) Bus.publish(Event.Deleted, { info: session })
    } catch (e) {
      log.error(e)
    }
  }
  /** Persists a message and publishes its `updated` event. */
  async function updateMessage(msg: MessageV2.Info) {
    const key = `session/message/${msg.sessionID}/${msg.id}`
    await Storage.writeJSON(key, msg)
    Bus.publish(MessageV2.Event.Updated, { info: msg })
  }
  type ChatInput = {
    sessionID: string
    providerID: string
    modelID: string
    mode?: string
    parts: MessageV2.UserPart[]
  }

  /**
   * Sends `input.parts` as a user message and streams the model's reply into
   * a new assistant message, which is returned when the stream ends. Stream
   * failures are recorded on the returned message's `error` (and published as
   * `session.error`) instead of being thrown.
   *
   * If the last non-summary assistant turn used more than 90% of the model's
   * usable context, the session is summarized first and the request is
   * retried on top of the summary. The session lock is released before that,
   * because summarize() takes the same lock.
   *
   * @throws BusyError if the session already has a request in flight.
   */
  export async function chat(input: ChatInput): Promise<MessageV2.Info> {
    const reply = await runChat(input)
    if (reply) return reply
    // runChat has released the lock at this point, so summarize() can take it.
    await summarize({
      sessionID: input.sessionID,
      providerID: input.providerID,
      modelID: input.modelID,
    })
    return chat(input)
  }

  /**
   * Runs one chat turn while holding the session lock. Returns `undefined`
   * (before the user message is stored) when the context must be summarized
   * first; otherwise the finished assistant message.
   */
  async function runChat(input: ChatInput): Promise<MessageV2.Info | undefined> {
    using abort = lock(input.sessionID)
    const l = log.clone().tag("session", input.sessionID)
    l.info("chatting")
    const model = await Provider.getModel(input.providerID, input.modelID)
    let msgs = await messages(input.sessionID)
    const session = await get(input.sessionID)

    // Commit a pending revert: drop everything after the revert point and
    // trim the revert message itself down to `point.part` parts.
    if (session.revert) {
      const point = session.revert
      const trimmed: MessageV2.Info[] = []
      for (const msg of msgs) {
        if (msg.id > point.messageID || (msg.id === point.messageID && point.part === 0)) {
          await Storage.remove("session/message/" + input.sessionID + "/" + msg.id)
          await Bus.publish(MessageV2.Event.Removed, {
            sessionID: input.sessionID,
            messageID: msg.id,
          })
          continue
        }
        if (msg.id === point.messageID) {
          msg.parts = msg.parts.slice(0, point.part)
          // Persist the trim; otherwise the dropped parts come back the next
          // time messages are loaded from storage.
          await updateMessage(msg)
        }
        trimmed.push(msg)
      }
      msgs = trimmed
      await update(input.sessionID, (draft) => {
        draft.revert = undefined
      })
    }

    // Auto-summarize when the last assistant turn nearly filled the usable
    // context window. The last message may be a user message (e.g. after a
    // revert), so look for the last assistant one. Summary messages are
    // skipped: their input count covers the whole pre-summary history, so
    // checking them would trigger summarization again and again.
    const previous = msgs.findLast((msg) => msg.role === "assistant")
    if (previous?.role === "assistant" && !previous.summary && model.info.limit.context) {
      const tokens =
        previous.tokens.input + previous.tokens.cache.read + previous.tokens.cache.write + previous.tokens.output
      const usable = Math.max((model.info.limit.context - (model.info.limit.output ?? 0)) * 0.9, 0)
      if (tokens > usable) return undefined
    }

    // Only the latest summary and the messages after it are sent to the model.
    const lastSummary = msgs.findLast((msg) => msg.role === "assistant" && msg.summary === true)
    if (lastSummary) msgs = msgs.filter((msg) => msg.id >= lastSummary.id)

    const app = App.info()
    const expanded = await Promise.all(
      input.parts.map((part) => expandUserPart(part, app, input.sessionID, abort.signal)),
    )
    input.parts = expanded.flat()
    if (input.mode === "plan")
      input.parts.push({
        type: "text",
        text: PROMPT_PLAN,
        synthetic: true,
      })

    // First message of a top-level session: generate a title in the
    // background; any failure just keeps the default title.
    if (msgs.length === 0 && !session.parentID) {
      void generateText({
        maxOutputTokens: input.providerID === "google" ? 1024 : 20,
        providerOptions: model.info.options,
        messages: [
          ...SystemPrompt.title(input.providerID).map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          ...MessageV2.toModelMessage([
            {
              id: Identifier.ascending("message"),
              role: "user",
              sessionID: input.sessionID,
              parts: input.parts,
              time: {
                created: Date.now(),
              },
            },
          ]),
        ],
        model: model.language,
      })
        .then((result) => {
          if (result.text)
            return Session.update(input.sessionID, (draft) => {
              draft.title = result.text
            })
        })
        .catch(() => {})
    }

    const msg: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "user",
      sessionID: input.sessionID,
      parts: input.parts,
      time: {
        created: Date.now(),
      },
    }
    await updateMessage(msg)
    msgs.push(msg)

    const mode = await Mode.get(input.mode ?? "build")
    let system = mode.prompt ? [mode.prompt] : SystemPrompt.provider(input.providerID, input.modelID)
    system.push(...(await SystemPrompt.environment()))
    system.push(...(await SystemPrompt.custom()))
    // max 2 system prompt messages for caching purposes
    const [first, ...rest] = system
    system = [first, rest.join("\n")]

    const next: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      parts: [],
      system,
      path: {
        cwd: app.path.cwd,
        root: app.path.root,
      },
      cost: 0,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.modelID,
      providerID: input.providerID,
      time: {
        created: Date.now(),
      },
      sessionID: input.sessionID,
    }
    await updateMessage(next)

    const tools: Record<string, AITool> = {}
    for (const item of await Provider.tools(input.providerID)) {
      if (mode.tools[item.id] === false) continue
      tools[item.id] = tool({
        id: item.id as any,
        description: item.description,
        inputSchema: item.parameters as ZodSchema,
        async execute(args, opts) {
          const result = await item.execute(args, {
            sessionID: input.sessionID,
            abort: abort.signal,
            messageID: next.id,
            metadata: async (val) => {
              // Surface live title/metadata on the matching running tool part.
              const match = next.parts.find(
                (p): p is MessageV2.ToolPart => p.type === "tool" && p.id === opts.toolCallId,
              )
              if (match && match.state.status === "running") {
                match.state.title = val.title
                match.state.metadata = val.metadata
              }
              await updateMessage(next)
            },
          })
          return result
        },
        toModelOutput(result) {
          return {
            type: "text",
            value: result.output,
          }
        },
      })
    }
    for (const [key, item] of Object.entries(await MCP.tools())) {
      if (mode.tools[key] === false) continue
      const execute = item.execute
      if (!execute) continue
      // Wrap a shallow copy instead of mutating `item`. If MCP.tools() hands
      // out cached objects, mutating them would stack another wrapper every
      // turn. NOTE(review): confirm whether MCP.tools() caches its tools.
      const wrapped = { ...item }
      wrapped.execute = async (args, opts) => {
        const result = await execute(args, opts)
        const output = result.content
          .filter((x: any) => x.type === "text")
          .map((x: any) => x.text)
          .join("\n\n")
        return {
          output,
        }
      }
      wrapped.toModelOutput = (result) => {
        return {
          type: "text",
          value: result.output,
        }
      }
      tools[key] = wrapped
    }

    let text: MessageV2.TextPart = {
      type: "text",
      text: "",
    }
    const result = streamText({
      onError() {},
      maxRetries: 10,
      maxOutputTokens: Math.max(0, model.info.limit.output) || undefined,
      abortSignal: abort.signal,
      stopWhen: stepCountIs(1000),
      providerOptions: model.info.options,
      messages: [
        ...system.map(
          (x): ModelMessage => ({
            role: "system",
            content: x,
          }),
        ),
        ...MessageV2.toModelMessage(msgs),
      ],
      temperature: model.info.temperature ? 0 : undefined,
      tools: model.info.tool_call === false ? undefined : tools,
      model: wrapLanguageModel({
        model: model.language,
        middleware: [
          {
            async transformParams(args) {
              if (args.type === "stream") {
                // @ts-expect-error
                args.params.prompt = ProviderTransform.message(args.params.prompt, input.providerID, input.modelID)
              }
              return args.params
            },
          },
        ],
      }),
    })

    try {
      for await (const value of result.fullStream) {
        l.info("part", {
          type: value.type,
        })
        switch (value.type) {
          case "start":
            break
          case "tool-input-start":
            next.parts.push({
              type: "tool",
              tool: value.toolName,
              id: value.id,
              state: {
                status: "pending",
              },
            })
            Bus.publish(MessageV2.Event.PartUpdated, {
              part: next.parts[next.parts.length - 1],
              sessionID: next.sessionID,
              messageID: next.id,
            })
            break
          case "tool-input-delta":
            break
          case "tool-call": {
            const match = next.parts.find(
              (p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId,
            )
            if (match) {
              match.state = {
                status: "running",
                input: value.input,
                time: {
                  start: Date.now(),
                },
              }
              Bus.publish(MessageV2.Event.PartUpdated, {
                part: match,
                sessionID: next.sessionID,
                messageID: next.id,
              })
            }
            break
          }
          case "tool-result": {
            const match = next.parts.find(
              (p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId,
            )
            if (match && match.state.status === "running") {
              match.state = {
                status: "completed",
                input: value.input,
                output: value.output.output,
                metadata: value.output.metadata,
                title: value.output.title,
                time: {
                  start: match.state.time.start,
                  end: Date.now(),
                },
              }
              Bus.publish(MessageV2.Event.PartUpdated, {
                part: match,
                sessionID: next.sessionID,
                messageID: next.id,
              })
            }
            break
          }
          case "tool-error": {
            const match = next.parts.find(
              (p): p is MessageV2.ToolPart => p.type === "tool" && p.id === value.toolCallId,
            )
            if (match && match.state.status === "running") {
              match.state = {
                status: "error",
                input: value.input,
                error: (value.error as any).toString(),
                time: {
                  start: match.state.time.start,
                  end: Date.now(),
                },
              }
              Bus.publish(MessageV2.Event.PartUpdated, {
                part: match,
                sessionID: next.sessionID,
                messageID: next.id,
              })
            }
            break
          }
          case "error":
            throw value.error
          case "start-step":
            next.parts.push({
              type: "step-start",
            })
            break
          case "finish-step": {
            const usage = getUsage(model.info, value.usage, value.providerMetadata)
            next.cost += usage.cost
            next.tokens = usage.tokens
            break
          }
          case "text-start":
            text = {
              type: "text",
              text: "",
            }
            break
          case "text":
            // Attach the current text part once, on its first delta (robust to
            // empty deltas, which previously could push it twice).
            if (!next.parts.includes(text)) next.parts.push(text)
            text.text += value.text
            break
          case "text-end":
            Bus.publish(MessageV2.Event.PartUpdated, {
              part: text,
              sessionID: next.sessionID,
              messageID: next.id,
            })
            break
          case "finish":
            next.time.completed = Date.now()
            break
          default:
            l.info("unhandled", {
              ...value,
            })
            continue
        }
        await updateMessage(next)
      }
    } catch (e) {
      log.error("", {
        error: e,
      })
      next.error = toAssistantError(e, input.providerID)
      Bus.publish(Event.Error, {
        sessionID: next.sessionID,
        error: next.error,
      })
    }

    // Tool calls that never finished (pending/running) are marked aborted.
    // Parts that already completed or failed keep their own result/error.
    for (const part of next.parts) {
      if (part.type !== "tool") continue
      if (part.state.status === "completed" || part.state.status === "error") continue
      const start = part.state.status === "running" ? part.state.time.start : Date.now()
      part.state = {
        status: "error",
        error: "Tool execution aborted",
        time: {
          start,
          end: Date.now(),
        },
        input: {},
      }
    }
    next.time.completed = Date.now()
    await updateMessage(next)
    return next
  }

  /**
   * Expands a `file:` user part. A text/plain file is read through the Read
   * tool, honouring an optional `start`/`end` line range in the URL query,
   * and becomes two synthetic text parts. Any other file becomes a synthetic
   * note plus an inline base64 data-URL file part. All other parts pass
   * through unchanged.
   */
  async function expandUserPart(
    part: MessageV2.UserPart,
    app: ReturnType<typeof App.info>,
    sessionID: string,
    signal: AbortSignal,
  ): Promise<MessageV2.UserPart[]> {
    if (part.type !== "file") return [part]
    const url = new URL(part.url)
    if (url.protocol !== "file:") return [part]

    // have to normalize, symbol search returns absolute paths
    const relativePath = url.pathname.replace(app.path.cwd, ".")
    const filePath = path.join(app.path.cwd, relativePath)

    if (part.mime === "text/plain") {
      let offset: number | undefined = undefined
      let limit: number | undefined = undefined
      const range = {
        start: url.searchParams.get("start"),
        end: url.searchParams.get("end"),
      }
      if (range.start != null) {
        // documentSymbol receives the file URL with the range query stripped.
        const documentUrl = part.url.split("?")[0]
        let start = parseInt(range.start)
        let end = range.end ? parseInt(range.end) : undefined
        // some LSP servers (eg, gopls) don't give full range in
        // workspace/symbol searches, so we'll try to find the
        // symbol in the document to get the full range
        if (start === end) {
          const symbols = await LSP.documentSymbol(documentUrl)
          for (const symbol of symbols) {
            let symbolRange: LSP.Range | undefined
            if ("range" in symbol) {
              symbolRange = symbol.range
            } else if ("location" in symbol) {
              symbolRange = symbol.location.range
            }
            if (symbolRange?.start?.line && symbolRange?.start?.line === start) {
              start = symbolRange.start.line
              end = symbolRange?.end?.line ?? start
              break
            }
          }
        }
        // Applies to explicit ranges as well as ones resolved via LSP above
        // (with two lines of padding on each side); a non-numeric start reads
        // the whole file.
        if (Number.isFinite(start)) {
          offset = Math.max(start - 2, 0)
          if (end) {
            limit = end - offset + 2
          }
        }
      }
      const args = { filePath, offset, limit }
      const result = await ReadTool.execute(args, {
        sessionID,
        abort: signal,
        messageID: "", // read tool doesn't use message ID
        metadata: async () => {},
      })
      return [
        {
          type: "text",
          synthetic: true,
          text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
        },
        {
          type: "text",
          synthetic: true,
          text: result.output,
        },
      ]
    }

    const file = Bun.file(filePath)
    FileTime.read(sessionID, filePath)
    return [
      {
        type: "text",
        text: `Called the Read tool with the following input: {\"filePath\":\"${url.pathname}\"}`,
        synthetic: true,
      },
      {
        type: "file",
        url: `data:${part.mime};base64,` + Buffer.from(await file.bytes()).toString("base64"),
        mime: part.mime,
        filename: part.filename!,
      },
    ]
  }

  /** Maps an error thrown while streaming to the serialized form stored on `message.error`. */
  function toAssistantError(e: unknown, providerID: string): MessageV2.Assistant["error"] {
    if (e instanceof DOMException && e.name === "AbortError") {
      return new MessageV2.AbortedError({ message: e.message }, { cause: e }).toObject()
    }
    if (MessageV2.OutputLengthError.isInstance(e)) return e
    if (LoadAPIKeyError.isInstance(e)) {
      return new Provider.AuthError({ providerID, message: e.message }, { cause: e }).toObject()
    }
    if (e instanceof Error) {
      return new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
    }
    // Serialize like every other branch (previously the raw error object was stored).
    return new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e }).toObject()
  }
  /**
   * Intended to mark a revert point (message + part index) on a session.
   * Currently a no-op: the earlier draft, kept below for reference, uses
   * `tool-invocation` parts and `message.metadata`, which the MessageV2 code
   * in this file no longer uses, and has not been ported yet.
   */
  export async function revert(_input: { sessionID: string; messageID: string; part: number }) {
    // TODO: port to MessageV2. Earlier draft, for reference:
    /*
    const message = await getMessage(input.sessionID, input.messageID)
    if (!message) return
    const part = message.parts[input.part]
    if (!part) return
    const session = await get(input.sessionID)
    const snapshot =
      session.revert?.snapshot ?? (await Snapshot.create(input.sessionID))
    const old = (() => {
      if (message.role === "assistant") {
        const lastTool = message.parts.findLast(
          (part, index) =>
            part.type === "tool-invocation" && index < input.part,
        )
        if (lastTool && lastTool.type === "tool-invocation")
          return message.metadata.tool[lastTool.toolInvocation.toolCallId]
            .snapshot
      }
      return message.metadata.snapshot
    })()
    if (old) await Snapshot.restore(input.sessionID, old)
    await update(input.sessionID, (draft) => {
      draft.revert = {
        messageID: input.messageID,
        part: input.part,
        snapshot,
      }
    })
    */
  }
  /**
   * Cancels a pending revert: restores the snapshot recorded with it (if any)
   * and clears `revert` from the session. No-op when there is no revert.
   */
  export async function unrevert(sessionID: string) {
    const session = await get(sessionID)
    if (!session) return
    if (!session.revert) return
    if (session.revert.snapshot) await Snapshot.restore(sessionID, session.revert.snapshot)
    // Await so callers observe the cleared revert and write failures propagate
    // instead of becoming unhandled rejections.
    await update(sessionID, (draft) => {
      draft.revert = undefined
    })
  }
  /**
   * Summarizes the conversation (from the previous summary onward) into a new
   * assistant message flagged `summary: true`, holding the session lock for
   * the duration. Stream failures are stored on the message's `error` and
   * published as `session.error`.
   *
   * @throws BusyError if the session already has a request in flight.
   */
  export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {
    using abort = lock(input.sessionID)
    const msgs = await messages(input.sessionID)
    // History before the previous summary is already folded into it.
    const lastSummary = msgs.findLast((msg) => msg.role === "assistant" && msg.summary === true)?.id
    const filtered = msgs.filter((msg) => !lastSummary || msg.id >= lastSummary)
    const model = await Provider.getModel(input.providerID, input.modelID)
    const app = App.info()
    const system = SystemPrompt.summarize(input.providerID)
    const next: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      parts: [],
      sessionID: input.sessionID,
      system,
      path: {
        cwd: app.path.cwd,
        root: app.path.root,
      },
      summary: true,
      cost: 0,
      modelID: input.modelID,
      providerID: input.providerID,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      time: {
        created: Date.now(),
      },
    }
    await updateMessage(next)
    let text: MessageV2.TextPart | undefined
    const result = streamText({
      abortSignal: abort.signal,
      model: model.language,
      messages: [
        ...system.map(
          (x): ModelMessage => ({
            role: "system",
            content: x,
          }),
        ),
        ...MessageV2.toModelMessage(filtered),
        {
          role: "user",
          content: [
            {
              type: "text",
              text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
            },
          ],
        },
      ],
      // Usage/cost is accumulated per step only. Do not also add onFinish
      // usage: it covers the same steps and would double-count the cost.
      onStepFinish: async (step) => {
        const usage = getUsage(model.info, step.usage, step.providerMetadata)
        next.cost += usage.cost
        next.tokens = usage.tokens
        await updateMessage(next)
        if (text) {
          Bus.publish(MessageV2.Event.PartUpdated, {
            part: text,
            messageID: next.id,
            sessionID: next.sessionID,
          })
        }
        text = undefined
      },
    })
    try {
      for await (const value of result.fullStream) {
        switch (value.type) {
          case "text":
            if (!text) {
              text = {
                type: "text",
                text: value.text,
              }
              next.parts.push(text)
            } else text.text += value.text
            await updateMessage(next)
            break
        }
      }
    } catch (e: any) {
      log.error("summarize stream error", {
        error: e,
      })
      switch (true) {
        case e instanceof DOMException && e.name === "AbortError":
          next.error = new MessageV2.AbortedError(
            { message: e.message },
            {
              cause: e,
            },
          ).toObject()
          break
        case MessageV2.OutputLengthError.isInstance(e):
          next.error = e
          break
        case LoadAPIKeyError.isInstance(e):
          next.error = new Provider.AuthError(
            {
              providerID: input.providerID,
              message: e.message,
            },
            { cause: e },
          ).toObject()
          break
        case e instanceof Error:
          next.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
          break
        default:
          next.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e }).toObject()
      }
      // Include the session like chat() does so listeners can attribute the error.
      Bus.publish(Event.Error, {
        sessionID: input.sessionID,
        error: next.error,
      })
    }
    next.time.completed = Date.now()
    await updateMessage(next)
  }
  /**
   * Marks a session busy and returns a disposable handle carrying the
   * request's abort signal. Disposing the handle releases the lock and
   * publishes `session.idle`.
   *
   * @throws BusyError if the session is already locked.
   */
  function lock(sessionID: string) {
    log.info("locking", { sessionID })
    const { pending } = state()
    if (pending.has(sessionID)) throw new BusyError(sessionID)
    const controller = new AbortController()
    pending.set(sessionID, controller)
    return {
      signal: controller.signal,
      [Symbol.dispose]() {
        log.info("unlocking", { sessionID })
        // abort() already removes the entry, and a newer request may have
        // locked the session since; only free the slot if it is still ours.
        const current = state().pending
        if (current.get(sessionID) === controller) current.delete(sessionID)
        Bus.publish(Event.Idle, {
          sessionID,
        })
      },
    }
  }
  /**
   * Converts AI SDK usage (plus provider metadata for cache writes) into the
   * token counts stored on assistant messages. Cost is computed from the
   * model's per-million-token prices.
   */
  function getUsage(model: ModelsDev.Model, usage: LanguageModelUsage, metadata?: ProviderMetadata) {
    const tokens = {
      input: usage.inputTokens ?? 0,
      output: usage.outputTokens ?? 0,
      // Recorded for visibility only and not priced separately below
      // (assumes the SDK already counts reasoning within outputTokens — TODO confirm per provider).
      reasoning: usage.reasoningTokens ?? 0,
      cache: {
        // Cache writes are only available from provider-specific metadata.
        write: (metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
          // @ts-expect-error
          metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
          0) as number,
        read: usage.cachedInputTokens ?? 0,
      },
    }
    return {
      cost: new Decimal(0)
        .add(new Decimal(tokens.input).mul(model.cost.input).div(1_000_000))
        .add(new Decimal(tokens.output).mul(model.cost.output).div(1_000_000))
        .add(new Decimal(tokens.cache.read).mul(model.cost.cache_read ?? 0).div(1_000_000))
        .add(new Decimal(tokens.cache.write).mul(model.cost.cache_write ?? 0).div(1_000_000))
        .toNumber(),
      tokens,
    }
  }
  /** Thrown by lock() when a session already has a request in flight. */
  export class BusyError extends Error {
    constructor(public readonly sessionID: string) {
      super(`Session ${sessionID} is busy`)
      // Identify the error in logs and toString() instead of the generic "Error".
      this.name = "BusyError"
    }
  }
  /**
   * Runs the project-initialization prompt as a chat turn and then calls
   * App.initialize(). The prompt is PROMPT_INITIALIZE with `${path}` replaced
   * by the project root.
   */
  export async function initialize(input: { sessionID: string; modelID: string; providerID: string }) {
    const { root } = App.info().path
    const prompt = PROMPT_INITIALIZE.replace("${path}", root)
    await Session.chat({
      sessionID: input.sessionID,
      providerID: input.providerID,
      modelID: input.modelID,
      parts: [{ type: "text", text: prompt }],
    })
    await App.initialize()
  }
  987. }