index.ts 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135
  1. import path from "node:path"
  2. import { App } from "../app/app"
  3. import { Identifier } from "../id/id"
  4. import { Storage } from "../storage/storage"
  5. import { Log } from "../util/log"
  6. import {
  7. generateText,
  8. LoadAPIKeyError,
  9. convertToCoreMessages,
  10. streamText,
  11. tool,
  12. type Tool as AITool,
  13. type LanguageModelUsage,
  14. type CoreMessage,
  15. type UIMessage,
  16. type ProviderMetadata,
  17. wrapLanguageModel,
  18. type Attachment,
  19. } from "ai"
  20. import { z, ZodSchema } from "zod"
  21. import { Decimal } from "decimal.js"
  22. import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
  23. import { Share } from "../share/share"
  24. import { Message } from "./message"
  25. import { Bus } from "../bus"
  26. import { Provider } from "../provider/provider"
  27. import { MCP } from "../mcp"
  28. import { NamedError } from "../util/error"
  29. import type { Tool } from "../tool/tool"
  30. import { SystemPrompt } from "./system"
  31. import { Flag } from "../flag/flag"
  32. import type { ModelsDev } from "../provider/models"
  33. import { Installation } from "../installation"
  34. import { Config } from "../config/config"
  35. import { ProviderTransform } from "../provider/transform"
  36. import { Snapshot } from "../snapshot"
  37. export namespace Session {
  // Logger shared by every function in this namespace.
  const log = Log.create({ service: "session" })

  // Sub-schemas of the session record, named so the main schema reads flat.
  const sessionShareRef = z.object({ url: z.string() })
  const sessionTimestamps = z.object({
    created: z.number(),
    updated: z.number(),
  })
  const sessionRevertMarker = z.object({
    messageID: z.string(),
    part: z.number(),
    snapshot: z.string().optional(),
  })

  /**
   * Schema of a session record. `parentID`, `share` and `revert` are optional;
   * the schema carries the OpenAPI ref "Session".
   */
  export const Info = z
    .object({
      id: Identifier.schema("session"),
      parentID: Identifier.schema("session").optional(),
      share: sessionShareRef.optional(),
      title: z.string(),
      version: z.string(),
      time: sessionTimestamps,
      revert: sessionRevertMarker.optional(),
    })
    .openapi({ ref: "Session" })
  /** Parsed (output) shape of {@link Info}. */
  export type Info = z.infer<typeof Info>
  /**
   * Schema of the share record stored for a session: the public `url` plus
   * the `secret` that is later passed to `Share.remove`. Carries the OpenAPI
   * ref "SessionShare".
   */
  export const ShareInfo = z
    .object({ secret: z.string(), url: z.string() })
    .openapi({ ref: "SessionShare" })
  /** Parsed (output) shape of {@link ShareInfo}. */
  export type ShareInfo = z.infer<typeof ShareInfo>
  // Builds a fresh `{ info }` payload schema for events that carry a session.
  const sessionInfoPayload = () => z.object({ info: Info })

  /** Bus events published by this namespace. */
  export const Event = {
    // Published when a session record is created or modified.
    Updated: Bus.event("session.updated", sessionInfoPayload()),
    // Published by `remove` unless the caller suppresses it.
    Deleted: Bus.event("session.deleted", sessionInfoPayload()),
    // Published when a session lock is released.
    Idle: Bus.event("session.idle", z.object({ sessionID: z.string() })),
    // Published when a chat turn records an error on the assistant message.
    Error: Bus.event(
      "session.error",
      z.object({ error: Message.Info.shape.metadata.shape.error }),
    ),
  }
  /**
   * App-scoped state for this namespace: cached session records, a message
   * map, and the AbortController of every session that currently holds a lock.
   * The second callback aborts all pending controllers.
   */
  const state = App.state(
    "session",
    () => ({
      sessions: new Map<string, Info>(),
      messages: new Map<string, Message.Info[]>(),
      pending: new Map<string, AbortController>(),
    }),
    async (current) => {
      for (const controller of current.pending.values()) controller.abort()
    },
  )
  /**
   * Creates a session, caches and persists it, and publishes `Event.Updated`.
   *
   * Root sessions (no `parentID`) are shared in the background when the
   * `OPENCODE_AUTO_SHARE` flag or the `autoshare` config option is set.
   *
   * @param parentID - id of the parent session when creating a child session
   * @returns the newly created session record
   */
  export async function create(parentID?: string) {
    const result: Info = {
      id: Identifier.descending("session"),
      version: Installation.VERSION,
      parentID,
      title:
        (parentID ? "Child session - " : "New Session - ") +
        new Date().toISOString(),
      time: {
        created: Date.now(),
        updated: Date.now(),
      },
    }
    log.info("created", result)
    state().sessions.set(result.id, result)
    await Storage.writeJSON("session/info/" + result.id, result)
    const cfg = await Config.get()
    if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.autoshare)) {
      // Fire-and-forget: creation must not wait for (or fail because of) sharing.
      void share(result.id)
        .then((shared) =>
          update(result.id, (draft) => {
            // Copy only the public URL. `share()` may resolve to the full
            // ShareInfo, whose `secret` must not end up in the session record
            // that is persisted and broadcast on the bus.
            draft.share = { url: shared.url }
          }),
        )
        .catch((error) => {
          // Without a handler a failed share would be an unhandled rejection.
          log.error("auto share failed", { error })
        })
    }
    Bus.publish(Event.Updated, {
      info: result,
    })
    return result
  }
  /**
   * Returns the session with the given id, preferring the in-memory cache and
   * otherwise reading it from storage (and caching what was read).
   */
  export async function get(id: string) {
    const cached = state().sessions.get(id)
    if (cached) return cached

    const loaded = await Storage.readJSON<Info>("session/info/" + id)
    state().sessions.set(id, loaded)
    return loaded as Info
  }
  /** Reads the stored share record (url and secret) of a session. */
  export async function getShare(id: string) {
    const key = "session/share/" + id
    return Storage.readJSON<ShareInfo>(key)
  }
  /**
   * Shares a session. If the session already has a share, that share is
   * returned unchanged. Otherwise a share is created, its URL is recorded on
   * the session, the share record is stored, and the session info plus every
   * message are synced.
   *
   * @returns the existing `{ url }` or the newly created share record
   */
  export async function share(id: string) {
    const session = await get(id)
    if (session.share) return session.share

    const created = await Share.create(id)
    await update(id, (draft) => {
      draft.share = { url: created.url }
    })
    await Storage.writeJSON<ShareInfo>("session/share/" + id, created)
    await Share.sync("session/info/" + id, session)

    const history = await messages(id)
    for (const message of history) {
      await Share.sync("session/message/" + id + "/" + message.id, message)
    }
    return created
  }
  /**
   * Removes a session's share: deletes the stored share record, clears
   * `share` on the session, then removes the share using its secret.
   * Does nothing when no share record is found.
   */
  export async function unshare(id: string) {
    const existing = await getShare(id)
    if (!existing) return

    await Storage.remove("session/share/" + id)
    await update(id, (draft) => {
      draft.share = undefined
    })
    await Share.remove(id, existing.secret)
  }
  /**
   * Applies `editor` to a session in place, bumps `time.updated`, re-caches
   * and persists the record, and publishes `Event.Updated`.
   *
   * @param id - session id
   * @param editor - callback that mutates the session record
   * @returns the updated session, or undefined when no session was found
   */
  export async function update(id: string, editor: (session: Info) => void) {
    const cache = state().sessions
    const target = await get(id)
    if (!target) return

    editor(target)
    target.time.updated = Date.now()
    cache.set(id, target)
    await Storage.writeJSON("session/info/" + id, target)
    Bus.publish(Event.Updated, { info: target })
    return target
  }
  /**
   * Loads every stored message of a session and returns them sorted by
   * ascending message id.
   */
  export async function messages(sessionID: string) {
    const collected: Message.Info[] = []
    for await (const key of Storage.list("session/message/" + sessionID)) {
      collected.push(await Storage.readJSON<Message.Info>(key))
    }
    collected.sort((left, right) => (left.id > right.id ? 1 : -1))
    return collected
  }
  /** Reads a single stored message of a session. */
  export async function getMessage(sessionID: string, messageID: string) {
    const key = `session/message/${sessionID}/${messageID}`
    return Storage.readJSON<Message.Info>(key)
  }
  /**
   * Async generator over all stored sessions. The session id is taken from
   * each storage entry's base name with the ".json" extension removed.
   */
  export async function* list() {
    const entries = Storage.list("session/info")
    for await (const entry of entries) {
      yield get(path.basename(entry, ".json"))
    }
  }
  /** Returns every stored session whose `parentID` equals the given id. */
  export async function children(parentID: string) {
    const matches: Session.Info[] = []
    for await (const entry of Storage.list("session/info")) {
      const candidate = await get(path.basename(entry, ".json"))
      if (candidate.parentID === parentID) matches.push(candidate)
    }
    return matches
  }
  /**
   * Aborts the in-flight work of a session, if any.
   *
   * @returns true when a pending controller was found and aborted, else false
   */
  export function abort(sessionID: string) {
    const inFlight = state().pending.get(sessionID)
    if (inFlight) {
      inFlight.abort()
      state().pending.delete(sessionID)
      return true
    }
    return false
  }
  /**
   * Deletes a session and, recursively, its child sessions: aborts pending
   * work, unshares, removes stored info and messages, and clears the caches.
   * Failures of the unshare/storage steps are ignored; any other error is
   * logged rather than thrown.
   *
   * @param sessionID - session to delete
   * @param emitEvent - publish `Event.Deleted` afterwards (children are
   *   removed with this set to false)
   */
  export async function remove(sessionID: string, emitEvent = true) {
    const ignore = () => {}
    try {
      abort(sessionID)
      const session = await get(sessionID)

      const descendants = await children(sessionID)
      for (const child of descendants) {
        await remove(child.id, false)
      }

      await unshare(sessionID).catch(ignore)
      await Storage.remove(`session/info/${sessionID}`).catch(ignore)
      await Storage.removeDir(`session/message/${sessionID}/`).catch(ignore)
      state().sessions.delete(sessionID)
      state().messages.delete(sessionID)

      if (!emitEvent) return
      Bus.publish(Event.Deleted, { info: session })
    } catch (e) {
      log.error(e)
    }
  }
  /** Persists a message under its session and publishes `Message.Event.Updated`. */
  async function updateMessage(msg: Message.Info) {
    const key = "session/message/" + msg.metadata.sessionID + "/" + msg.id
    await Storage.writeJSON(key, msg)
    Bus.publish(Message.Event.Updated, { info: msg })
  }
  /**
   * Runs one user turn against the model and returns the assistant message.
   *
   * Steps, in order:
   *  1. If the session has a pending revert, delete the reverted messages,
   *     truncate the partially reverted one in memory, and clear the marker.
   *  2. If the previous assistant message used more than 90% of the usable
   *     context window, summarize the session and restart the turn.
   *  3. Take the session lock (released, with `Event.Idle`, on exit).
   *  4. Expand `file:` parts into text / data-URL parts.
   *  5. For the first message of a root session, generate a title in the
   *     background.
   *  6. Persist the user message and an empty assistant message, wire up
   *     provider and MCP tools, stream the response, and persist progress.
   *
   * Errors raised while streaming are recorded on the assistant message and
   * published as `Event.Error` instead of being thrown.
   *
   * @param input - session, provider/model ids, the user's parts and an
   *   optional system prompt override (`input.parts` is replaced with the
   *   expanded parts; `input.system` is not modified)
   * @returns the assistant message after the stream ended
   * @throws BusyError when the session is already locked
   */
  export async function chat(input: {
    sessionID: string
    providerID: string
    modelID: string
    parts: Message.MessagePart[]
    system?: string[]
    tools?: Tool.Info[]
  }) {
    const l = log.clone().tag("session", input.sessionID)
    l.info("chatting")
    const model = await Provider.getModel(input.providerID, input.modelID)
    let msgs = await messages(input.sessionID)
    const session = await get(input.sessionID)
    if (session.revert) {
      const trimmed = []
      for (const msg of msgs) {
        // Everything after the revert point (or the revert message itself when
        // reverting to its very first part) is deleted from storage.
        if (
          msg.id > session.revert.messageID ||
          (msg.id === session.revert.messageID && session.revert.part === 0)
        ) {
          await Storage.remove(
            "session/message/" + input.sessionID + "/" + msg.id,
          )
          await Bus.publish(Message.Event.Removed, {
            sessionID: input.sessionID,
            messageID: msg.id,
          })
          continue
        }
        // Here `part` is non-zero: keep only the parts before the revert index.
        if (msg.id === session.revert.messageID) {
          msg.parts = msg.parts.slice(0, session.revert.part)
        }
        trimmed.push(msg)
      }
      msgs = trimmed
      await update(input.sessionID, (draft) => {
        draft.revert = undefined
      })
    }
    const previous = msgs.at(-1)
    // auto summarize if too long
    if (previous?.metadata.assistant) {
      const tokens =
        previous.metadata.assistant.tokens.input +
        previous.metadata.assistant.tokens.cache.read +
        previous.metadata.assistant.tokens.cache.write +
        previous.metadata.assistant.tokens.output
      if (
        model.info.limit.context &&
        tokens >
          Math.max(
            (model.info.limit.context - (model.info.limit.output ?? 0)) * 0.9,
            0,
          )
      ) {
        await summarize({
          sessionID: input.sessionID,
          providerID: input.providerID,
          modelID: input.modelID,
        })
        return chat(input)
      }
    }
    using abort = lock(input.sessionID)
    // Only the latest summary and what follows it is sent to the model.
    const lastSummary = msgs.findLast(
      (msg) => msg.metadata.assistant?.summary === true,
    )
    if (lastSummary) msgs = msgs.filter((msg) => msg.id >= lastSummary.id)
    const app = App.info()
    input.parts = await Promise.all(
      input.parts.map(async (part): Promise<Message.MessagePart[]> => {
        if (part.type === "file") {
          const url = new URL(part.url)
          switch (url.protocol) {
            case "file:": {
              const content = Bun.file(path.join(app.path.cwd, url.pathname))
              if (part.mediaType === "text/plain") {
                let text = await content.text()
                const range = {
                  start: url.searchParams.get("start"),
                  end: url.searchParams.get("end"),
                }
                // Optional `start`/`end` query params select a line slice.
                if (range.start != null) {
                  const lines = text.split("\n")
                  const start = parseInt(range.start)
                  const end = range.end ? parseInt(range.end) : lines.length
                  text = lines.slice(start, end).join("\n")
                }
                return [
                  {
                    type: "text",
                    text: [
                      "Called the Read tool on " + url.pathname,
                      "<results>",
                      text,
                      "</results>",
                    ].join("\n"),
                  },
                ]
              }
              return [
                {
                  type: "text",
                  text: ["Called the Read tool on " + url.pathname].join("\n"),
                },
                {
                  type: "file",
                  // A `;base64,` data URL needs the standard base64 alphabet
                  // with padding; "base64url" produces a different encoding.
                  url:
                    `data:${part.mediaType};base64,` +
                    Buffer.from(await content.bytes()).toString("base64"),
                  mediaType: part.mediaType,
                  filename: path.basename(part.filename!),
                },
              ]
            }
          }
        }
        return [part]
      }),
    ).then((x) => x.flat())
    if (msgs.length === 0 && !session.parentID) {
      // Best-effort title generation; runs in the background and failures are
      // deliberately ignored.
      generateText({
        maxTokens: input.providerID === "google" ? 1024 : 20,
        providerOptions: model.info.options,
        messages: [
          ...SystemPrompt.title(input.providerID).map(
            (x): CoreMessage => ({
              role: "system",
              content: x,
            }),
          ),
          ...convertToCoreMessages([
            {
              role: "user",
              content: "",
              parts: toParts(input.parts).parts,
            },
          ]),
        ],
        model: model.language,
      })
        .then((result) => {
          if (result.text)
            return Session.update(input.sessionID, (draft) => {
              draft.title = result.text
            })
        })
        .catch(() => {})
    }
    const snapshot = await Snapshot.create(input.sessionID)
    const msg: Message.Info = {
      role: "user",
      id: Identifier.ascending("message"),
      parts: input.parts,
      metadata: {
        time: {
          created: Date.now(),
        },
        sessionID: input.sessionID,
        tool: {},
        snapshot,
      },
    }
    await updateMessage(msg)
    msgs.push(msg)
    // Copy before appending so the caller's `input.system` array (or whatever
    // array SystemPrompt.provider returns) is never mutated.
    const system = [
      ...(input.system ?? SystemPrompt.provider(input.providerID)),
    ]
    system.push(...(await SystemPrompt.environment()))
    system.push(...(await SystemPrompt.custom()))
    const next: Message.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      parts: [],
      metadata: {
        snapshot,
        assistant: {
          system,
          path: {
            cwd: app.path.cwd,
            root: app.path.root,
          },
          cost: 0,
          tokens: {
            input: 0,
            output: 0,
            reasoning: 0,
            cache: { read: 0, write: 0 },
          },
          modelID: input.modelID,
          providerID: input.providerID,
        },
        time: {
          created: Date.now(),
        },
        sessionID: input.sessionID,
        tool: {},
      },
    }
    await updateMessage(next)
    const tools: Record<string, AITool> = {}
    // Provider tools: wrap each one so its metadata, timing and a snapshot are
    // recorded on the assistant message; failures are returned to the model as
    // text instead of being thrown.
    for (const item of await Provider.tools(input.providerID)) {
      tools[item.id.replaceAll(".", "_")] = tool({
        id: item.id as any,
        description: item.description,
        parameters: item.parameters as ZodSchema,
        async execute(args, opts) {
          const start = Date.now()
          try {
            const result = await item.execute(args, {
              sessionID: input.sessionID,
              abort: abort.signal,
              messageID: next.id,
              metadata: async (val) => {
                next.metadata.tool[opts.toolCallId] = {
                  ...val,
                  time: {
                    start: 0,
                    end: 0,
                  },
                }
                await updateMessage(next)
              },
            })
            next.metadata!.tool![opts.toolCallId] = {
              ...result.metadata,
              snapshot: await Snapshot.create(input.sessionID),
              time: {
                start,
                end: Date.now(),
              },
            }
            await updateMessage(next)
            return result.output
          } catch (e: any) {
            next.metadata!.tool![opts.toolCallId] = {
              error: true,
              message: e.toString(),
              title: e.toString(),
              snapshot: await Snapshot.create(input.sessionID),
              time: {
                start,
                end: Date.now(),
              },
            }
            await updateMessage(next)
            return e.toString()
          }
        },
      })
    }
    // MCP tools: same bookkeeping; only text content is returned to the model.
    for (const [key, item] of Object.entries(await MCP.tools())) {
      const execute = item.execute
      if (!execute) continue
      item.execute = async (args, opts) => {
        const start = Date.now()
        try {
          const result = await execute(args, opts)
          next.metadata!.tool![opts.toolCallId] = {
            ...result.metadata,
            snapshot: await Snapshot.create(input.sessionID),
            time: {
              start,
              end: Date.now(),
            },
          }
          await updateMessage(next)
          return result.content
            .filter((x: any) => x.type === "text")
            .map((x: any) => x.text)
            .join("\n\n")
        } catch (e: any) {
          next.metadata!.tool![opts.toolCallId] = {
            error: true,
            message: e.toString(),
            snapshot: await Snapshot.create(input.sessionID),
            title: "mcp",
            time: {
              start,
              end: Date.now(),
            },
          }
          await updateMessage(next)
          return e.toString()
        }
      }
      tools[key] = item
    }
    // The text part currently being streamed; reset at the end of each step.
    let text: Message.TextPart | undefined
    const result = streamText({
      onStepFinish: async (step) => {
        log.info("step finish", { finishReason: step.finishReason })
        const assistant = next.metadata!.assistant!
        const usage = getUsage(model.info, step.usage, step.providerMetadata)
        assistant.cost += usage.cost
        assistant.tokens = usage.tokens
        await updateMessage(next)
        if (text) {
          Bus.publish(Message.Event.PartUpdated, {
            part: text,
            messageID: next.id,
            sessionID: next.metadata.sessionID,
          })
        }
        text = undefined
      },
      onError(err) {
        log.error("callback error", err)
        switch (true) {
          case LoadAPIKeyError.isInstance(err.error):
            next.metadata.error = new Provider.AuthError(
              {
                providerID: input.providerID,
                message: err.error.message,
              },
              { cause: err.error },
            ).toObject()
            break
          case err.error instanceof Error:
            next.metadata.error = new NamedError.Unknown(
              { message: err.error.toString() },
              { cause: err.error },
            ).toObject()
            break
          default:
            // Serialize like the branches above so the stored/published error
            // is a plain object rather than an Error instance.
            next.metadata.error = new NamedError.Unknown(
              { message: JSON.stringify(err.error) },
              { cause: err.error },
            ).toObject()
        }
        Bus.publish(Event.Error, {
          error: next.metadata.error,
        })
      },
      // async prepareStep(step) {
      //   next.parts.push({
      //     type: "step-start",
      //   })
      //   await updateMessage(next)
      //   return step
      // },
      toolCallStreaming: true,
      maxRetries: 10,
      maxTokens: Math.max(0, model.info.limit.output) || undefined,
      abortSignal: abort.signal,
      maxSteps: 1000,
      providerOptions: model.info.options,
      messages: [
        ...system.map(
          (x): CoreMessage => ({
            role: "system",
            content: x,
          }),
        ),
        ...convertToCoreMessages(
          msgs.map(toUIMessage).filter((x) => x.parts.length > 0),
        ),
      ],
      temperature: model.info.temperature ? 0 : undefined,
      tools: model.info.tool_call === false ? undefined : tools,
      model: wrapLanguageModel({
        model: model.language,
        middleware: [
          {
            async transformParams(args) {
              if (args.type === "stream") {
                args.params.prompt = ProviderTransform.message(
                  args.params.prompt,
                  input.providerID,
                  input.modelID,
                )
              }
              return args.params
            },
          },
        ],
      }),
    })
    try {
      for await (const value of result.fullStream) {
        l.info("part", {
          type: value.type,
        })
        switch (value.type) {
          case "step-start":
            next.parts.push({
              type: "step-start",
            })
            break
          case "text-delta":
            if (!text) {
              text = {
                type: "text",
                text: value.textDelta,
              }
              next.parts.push(text)
            } else text.text += value.textDelta
            break
          case "tool-call": {
            const [match] = next.parts.flatMap((p) =>
              p.type === "tool-invocation" &&
              p.toolInvocation.toolCallId === value.toolCallId
                ? [p]
                : [],
            )
            if (!match) break
            match.toolInvocation.args = value.args
            match.toolInvocation.state = "call"
            Bus.publish(Message.Event.PartUpdated, {
              part: match,
              messageID: next.id,
              sessionID: next.metadata.sessionID,
            })
            break
          }
          case "tool-call-streaming-start":
            next.parts.push({
              type: "tool-invocation",
              toolInvocation: {
                state: "partial-call",
                toolName: value.toolName,
                toolCallId: value.toolCallId,
                args: {},
              },
            })
            Bus.publish(Message.Event.PartUpdated, {
              part: next.parts[next.parts.length - 1],
              messageID: next.id,
              sessionID: next.metadata.sessionID,
            })
            break
          case "tool-call-delta":
            // Argument deltas are not persisted.
            continue
          // for some reason ai sdk claims to not send this part but it does
          // @ts-expect-error
          case "tool-result":
            const match = next.parts.find(
              (p) =>
                p.type === "tool-invocation" &&
                // @ts-expect-error
                p.toolInvocation.toolCallId === value.toolCallId,
            )
            if (match && match.type === "tool-invocation") {
              match.toolInvocation = {
                // @ts-expect-error
                args: value.args,
                // @ts-expect-error
                toolCallId: value.toolCallId,
                // @ts-expect-error
                toolName: value.toolName,
                state: "result",
                // @ts-expect-error
                result: value.result as string,
              }
              Bus.publish(Message.Event.PartUpdated, {
                part: match,
                messageID: next.id,
                sessionID: next.metadata.sessionID,
              })
            }
            break
          case "finish":
            log.info("message finish", {
              reason: value.finishReason,
            })
            // NOTE(review): onStepFinish already adds each step's cost; confirm
            // that adding the finish usage here does not double count.
            const assistant = next.metadata!.assistant!
            const usage = getUsage(
              model.info,
              value.usage,
              value.providerMetadata,
            )
            assistant.cost += usage.cost
            await updateMessage(next)
            if (value.finishReason === "length")
              throw new Message.OutputLengthError({})
            break
          default:
            l.info("unhandled", {
              type: value.type,
            })
            continue
        }
        await updateMessage(next)
      }
    } catch (e: any) {
      log.error("stream error", {
        error: e,
      })
      switch (true) {
        case Message.OutputLengthError.isInstance(e):
          next.metadata.error = e
          break
        case LoadAPIKeyError.isInstance(e):
          next.metadata.error = new Provider.AuthError(
            {
              providerID: input.providerID,
              message: e.message,
            },
            { cause: e },
          ).toObject()
          break
        case e instanceof Error:
          next.metadata.error = new NamedError.Unknown(
            { message: e.toString() },
            { cause: e },
          ).toObject()
          break
        default:
          // Serialize like the branches above (see onError).
          next.metadata.error = new NamedError.Unknown(
            { message: JSON.stringify(e) },
            { cause: e },
          ).toObject()
      }
      Bus.publish(Event.Error, {
        error: next.metadata.error,
      })
    }
    next.metadata!.time.completed = Date.now()
    // Tool invocations that never produced a result are closed out so the
    // stored message contains no dangling calls.
    for (const part of next.parts) {
      if (
        part.type === "tool-invocation" &&
        part.toolInvocation.state !== "result"
      ) {
        part.toolInvocation = {
          ...part.toolInvocation,
          state: "result",
          result: "request was aborted",
        }
      }
    }
    await updateMessage(next)
    return next
  }
  /**
   * Marks a session as reverted to a given message part and restores the
   * matching file snapshot.
   *
   * The snapshot restored is that of the last tool invocation before
   * `input.part` (assistant messages only), otherwise the message's own
   * snapshot. A snapshot of the pre-revert state is kept on the session so
   * `unrevert` can restore it; an existing one is reused.
   *
   * Does nothing when the message or the part does not exist.
   */
  export async function revert(input: {
    sessionID: string
    messageID: string
    part: number
  }) {
    const message = await getMessage(input.sessionID, input.messageID)
    if (!message) return
    if (!message.parts[input.part]) return
    const session = await get(input.sessionID)
    const snapshot =
      session.revert?.snapshot ?? (await Snapshot.create(input.sessionID))
    const old = (() => {
      if (message.role === "assistant") {
        const lastTool = message.parts.findLast(
          (part, index) =>
            part.type === "tool-invocation" && index < input.part,
        )
        if (lastTool && lastTool.type === "tool-invocation")
          // A tool invocation can lack a metadata entry (e.g. it was closed out
          // as aborted without the tool ever reporting), so guard the lookup
          // instead of throwing on `undefined.snapshot`.
          return message.metadata.tool[lastTool.toolInvocation.toolCallId]
            ?.snapshot
      }
      return message.metadata.snapshot
    })()
    if (old) await Snapshot.restore(input.sessionID, old)
    await update(input.sessionID, (draft) => {
      draft.revert = {
        messageID: input.messageID,
        part: input.part,
        snapshot,
      }
    })
  }
  /**
   * Cancels a pending revert: restores the snapshot captured when the revert
   * was requested (if any) and clears the revert marker on the session.
   * Does nothing when the session has no pending revert.
   */
  export async function unrevert(sessionID: string) {
    const session = await get(sessionID)
    if (!session) return
    if (!session.revert) return
    if (session.revert.snapshot)
      await Snapshot.restore(sessionID, session.revert.snapshot)
    // Awaited so the marker is cleared before callers continue and so a
    // storage failure surfaces here instead of as an unhandled rejection.
    await update(sessionID, (draft) => {
      draft.revert = undefined
    })
  }
  /**
   * Asks the model to summarize the session and stores the answer as an
   * assistant message flagged `summary: true`.
   *
   * Holds the session lock while running. Only the messages from the latest
   * existing summary onwards are sent, followed by a fixed summarization
   * request. The message is persisted on every text delta and when usage is
   * reported.
   *
   * @throws BusyError when the session is already locked
   */
  export async function summarize(input: {
    sessionID: string
    providerID: string
    modelID: string
  }) {
    using abort = lock(input.sessionID)
    const history = await messages(input.sessionID)
    const cutoff = history.findLast(
      (entry) => entry.metadata.assistant?.summary === true,
    )?.id
    const window = history.filter((entry) => !cutoff || entry.id >= cutoff)
    const model = await Provider.getModel(input.providerID, input.modelID)
    const app = App.info()
    const system = SystemPrompt.summarize(input.providerID)

    const next: Message.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      parts: [],
      metadata: {
        tool: {},
        sessionID: input.sessionID,
        assistant: {
          system,
          path: { cwd: app.path.cwd, root: app.path.root },
          summary: true,
          cost: 0,
          modelID: input.modelID,
          providerID: input.providerID,
          tokens: {
            input: 0,
            output: 0,
            reasoning: 0,
            cache: { read: 0, write: 0 },
          },
        },
        time: { created: Date.now() },
      },
    }
    await updateMessage(next)

    // Text part currently being streamed; cleared at the end of each step.
    let current: Message.TextPart | undefined

    // Adds the reported usage cost and replaces the token counts on `next`.
    const recordUsage = (
      usage: LanguageModelUsage,
      providerMetadata?: ProviderMetadata,
    ) => {
      const assistant = next.metadata!.assistant!
      const computed = getUsage(model.info, usage, providerMetadata)
      assistant.cost += computed.cost
      assistant.tokens = computed.tokens
    }

    const systemMessages = system.map(
      (content): CoreMessage => ({ role: "system", content }),
    )

    const result = streamText({
      abortSignal: abort.signal,
      model: model.language,
      messages: [
        ...systemMessages,
        ...convertToCoreMessages(window.map(toUIMessage)),
        {
          role: "user",
          content: [
            {
              type: "text",
              text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
            },
          ],
        },
      ],
      onStepFinish: async (step) => {
        recordUsage(step.usage, step.providerMetadata)
        await updateMessage(next)
        if (current) {
          Bus.publish(Message.Event.PartUpdated, {
            part: current,
            messageID: next.id,
            sessionID: next.metadata.sessionID,
          })
        }
        current = undefined
      },
      async onFinish(event) {
        recordUsage(event.usage, event.providerMetadata)
        next.metadata!.time.completed = Date.now()
        await updateMessage(next)
      },
    })

    for await (const chunk of result.fullStream) {
      if (chunk.type !== "text-delta") continue
      if (current) {
        current.text += chunk.textDelta
      } else {
        current = { type: "text", text: chunk.textDelta }
        next.parts.push(current)
      }
      await updateMessage(next)
    }
  }
  /**
   * Takes the per-session lock by registering a new AbortController as
   * pending for the session.
   *
   * @returns a disposable (for `using`) exposing the controller's `signal`;
   *   disposing it removes the pending entry and publishes `Event.Idle`
   * @throws BusyError when the session already has a pending entry
   */
  function lock(sessionID: string) {
    log.info("locking", { sessionID })
    const alreadyPending = state().pending.has(sessionID)
    if (alreadyPending) throw new BusyError(sessionID)

    const controller = new AbortController()
    state().pending.set(sessionID, controller)

    const release = () => {
      log.info("unlocking", { sessionID })
      state().pending.delete(sessionID)
      Bus.publish(Event.Idle, { sessionID })
    }
    return {
      signal: controller.signal,
      [Symbol.dispose]() {
        release()
      },
    }
  }
  /**
   * Converts the usage reported for a model call into token counts and a cost.
   *
   * Cache read/write counts are taken from the "anthropic" provider metadata,
   * falling back to the "bedrock" usage metadata, then to 0. Each token count
   * is multiplied by the model's matching rate and divided by 1,000,000;
   * missing cache rates count as 0.
   *
   * @returns `{ cost, tokens }`; `tokens.reasoning` is always 0 here
   */
  function getUsage(
    model: ModelsDev.Model,
    usage: LanguageModelUsage,
    metadata?: ProviderMetadata,
  ) {
    const cacheWrite = (metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
      // @ts-expect-error
      metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
      0) as number
    const cacheRead = (metadata?.["anthropic"]?.["cacheReadInputTokens"] ??
      // @ts-expect-error
      metadata?.["bedrock"]?.["usage"]?.["cacheReadInputTokens"] ??
      0) as number

    const tokens = {
      input: usage.promptTokens ?? 0,
      output: usage.completionTokens ?? 0,
      reasoning: 0,
      cache: {
        write: cacheWrite,
        read: cacheRead,
      },
    }

    // Price of `count` tokens at `rate` per million tokens.
    const priced = (count: number, rate: number) =>
      new Decimal(count).mul(rate).div(1_000_000)

    const total = new Decimal(0)
      .add(priced(tokens.input, model.cost.input))
      .add(priced(tokens.output, model.cost.output))
      .add(priced(tokens.cache.read, model.cost.cache_read ?? 0))
      .add(priced(tokens.cache.write, model.cost.cache_write ?? 0))

    return {
      cost: total.toNumber(),
      tokens,
    }
  }
  /**
   * Thrown when a session is locked while it already has pending work.
   * `sessionID` identifies the busy session.
   */
  export class BusyError extends Error {
    constructor(public readonly sessionID: string) {
      super(`Session ${sessionID} is busy`)
      // Without this the error reports itself as a generic "Error" in logs
      // and in `toString()`.
      this.name = "BusyError"
    }
  }
  /**
   * Runs the bundled initialize prompt in the given session, substituting the
   * app root path for the `${path}` placeholder, then calls `App.initialize()`.
   */
  export async function initialize(input: {
    sessionID: string
    modelID: string
    providerID: string
  }) {
    const { sessionID, providerID, modelID } = input
    const root = App.info().path.root
    const prompt = PROMPT_INITIALIZE.replace("${path}", root)
    await Session.chat({
      sessionID,
      providerID,
      modelID,
      parts: [{ type: "text", text: prompt }],
    })
    await App.initialize()
  }
  1023. }
  /**
   * Converts a stored message into the AI SDK's UIMessage shape. `content` is
   * always empty; text, tool and step parts plus file attachments come from
   * `toParts`.
   *
   * @throws Error("not implemented") for roles other than assistant and user
   */
  function toUIMessage(msg: Message.Info): UIMessage {
    switch (msg.role) {
      case "assistant":
      case "user":
        return {
          id: msg.id,
          role: msg.role,
          content: "",
          ...toParts(msg.parts),
        }
      default:
        throw new Error("not implemented")
    }
  }
  /**
   * Splits stored message parts into UI message parts and attachments.
   * Text, tool-invocation and step-start parts are copied into `parts`;
   * file parts become `experimental_attachments`; other part types are
   * dropped.
   */
  function toParts(parts: Message.MessagePart[]) {
    const uiParts: UIMessage["parts"] = []
    const attachments: Attachment[] = []

    for (const item of parts) {
      if (item.type === "text") {
        uiParts.push({ type: "text", text: item.text })
      } else if (item.type === "file") {
        attachments.push({
          url: item.url,
          contentType: item.mediaType,
          name: item.filename,
        })
      } else if (item.type === "tool-invocation") {
        uiParts.push({
          type: "tool-invocation",
          toolInvocation: item.toolInvocation,
        })
      } else if (item.type === "step-start") {
        uiParts.push({ type: "step-start" })
      }
    }

    const result: {
      parts: UIMessage["parts"]
      experimental_attachments: Attachment[]
    } = {
      parts: uiParts,
      experimental_attachments: attachments,
    }
    return result
  }