index.ts 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159
  1. import path from "path"
  2. import { Decimal } from "decimal.js"
  3. import { z, ZodSchema } from "zod"
  4. import {
  5. generateText,
  6. LoadAPIKeyError,
  7. streamText,
  8. tool,
  9. wrapLanguageModel,
  10. type Tool as AITool,
  11. type LanguageModelUsage,
  12. type ProviderMetadata,
  13. type ModelMessage,
  14. stepCountIs,
  15. type StreamTextResult,
  16. } from "ai"
  17. import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
  18. import PROMPT_PLAN from "../session/prompt/plan.txt"
  19. import PROMPT_ANTHROPIC_SPOOF from "../session/prompt/anthropic_spoof.txt"
  20. import { App } from "../app/app"
  21. import { Bus } from "../bus"
  22. import { Config } from "../config/config"
  23. import { Flag } from "../flag/flag"
  24. import { Identifier } from "../id/id"
  25. import { Installation } from "../installation"
  26. import { MCP } from "../mcp"
  27. import { Provider } from "../provider/provider"
  28. import { ProviderTransform } from "../provider/transform"
  29. import type { ModelsDev } from "../provider/models"
  30. import { Share } from "../share/share"
  31. import { Snapshot } from "../snapshot"
  32. import { Storage } from "../storage/storage"
  33. import { Log } from "../util/log"
  34. import { NamedError } from "../util/error"
  35. import { SystemPrompt } from "./system"
  36. import { FileTime } from "../file/time"
  37. import { MessageV2 } from "./message-v2"
  38. import { Mode } from "./mode"
  39. import { LSP } from "../lsp"
  40. import { ReadTool } from "../tool/read"
  41. export namespace Session {
  // Logger for this namespace; created with the "session" service tag.
  const log = Log.create({ service: "session" })
  // Cap applied in chat() on top of the model's own output-token limit.
  const OUTPUT_TOKEN_MAX = 32_000
  /**
   * Schema of a session record. Objects of this shape are cached in memory,
   * written to storage under "session/info/<id>" and carried by the
   * session.updated / session.deleted events.
   */
  export const Info = z
    .object({
      id: Identifier.schema("session"),
      // Only set for sessions created with a parent (see create()).
      parentID: Identifier.schema("session").optional(),
      // Set by share(), cleared by unshare().
      share: z
        .object({
          url: z.string(),
        })
        .optional(),
      title: z.string(),
      version: z.string(),
      // Both fields are filled from Date.now().
      time: z.object({
        created: z.number(),
        updated: z.number(),
      }),
      // Pending revert marker; chat() trims messages past it and then clears it.
      revert: z
        .object({
          messageID: z.string(),
          part: z.number(),
          snapshot: z.string().optional(),
        })
        .optional(),
    })
    .openapi({
      ref: "Session",
    })
  /** Parsed (output) type of the Info schema. */
  export type Info = z.output<typeof Info>
  /**
   * Schema of the share record stored under "session/share/<id>".
   * The secret is what unshare() passes to Share.remove(); the url is the
   * value copied onto the session's `share` field.
   */
  export const ShareInfo = z
    .object({
      secret: z.string(),
      url: z.string(),
    })
    .openapi({
      ref: "SessionShare",
    })
  /** Parsed (output) type of the ShareInfo schema. */
  export type ShareInfo = z.output<typeof ShareInfo>
  /** Bus events published by this namespace, keyed by a short name. */
  export const Event = {
    // Published by create() and update() with the current session record.
    Updated: Bus.event(
      "session.updated",
      z.object({
        info: Info,
      }),
    ),
    // Published by remove() unless the caller passed emitEvent = false.
    Deleted: Bus.event(
      "session.deleted",
      z.object({
        info: Info,
      }),
    ),
    // Not published in the portion of the file visible here.
    Idle: Bus.event(
      "session.idle",
      z.object({
        sessionID: z.string(),
      }),
    ),
    // Published by the stream processor when processing an assistant turn throws.
    Error: Bus.event(
      "session.error",
      z.object({
        sessionID: z.string().optional(),
        error: MessageV2.Assistant.shape.error,
      }),
    ),
  }
  /**
   * App-scoped state for sessions, registered under the "session" key:
   * a cache of session records, a map of messages per session, and the
   * AbortController of each in-flight request. The second callback aborts
   * every pending controller.
   */
  const state = App.state(
    "session",
    () => ({
      sessions: new Map<string, Info>(),
      messages: new Map<string, MessageV2.Info[]>(),
      pending: new Map<string, AbortController>(),
    }),
    async (current) => {
      current.pending.forEach((controller) => controller.abort())
    },
  )
  /**
   * Creates a new session, caches it, persists it under "session/info/<id>"
   * and publishes Event.Updated.
   *
   * Top-level sessions (no parent) are additionally shared in the background
   * when auto-sharing is enabled via the OPENCODE_AUTO_SHARE flag or
   * `share: "auto"` in the config. Sharing failures are ignored on purpose so
   * they never block session creation.
   *
   * @param parentID id of the parent session when creating a child session
   * @returns the newly created session record
   */
  export async function create(parentID?: string) {
    // Sample the clock once so title, created and updated all agree.
    const now = Date.now()
    const result: Info = {
      id: Identifier.descending("session"),
      version: Installation.VERSION,
      parentID,
      title: (parentID ? "Child session - " : "New Session - ") + new Date(now).toISOString(),
      time: {
        created: now,
        updated: now,
      },
    }
    log.info("created", result)
    state().sessions.set(result.id, result)
    await Storage.writeJSON("session/info/" + result.id, result)
    const cfg = await Config.get()
    if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.share === "auto")) {
      share(result.id)
        .then((shared) =>
          // Return the promise so a failing update is also covered by the
          // catch below, and keep only the public url: the value resolved by
          // share() may carry the share secret, which must not end up in the
          // session record that is persisted and published.
          update(result.id, (draft) => {
            draft.share = { url: shared.url }
          }),
        )
        .catch(() => {
          // Silently ignore sharing errors during session creation
        })
    }
    Bus.publish(Event.Updated, {
      info: result,
    })
    return result
  }
  /**
   * Returns the session with the given id, preferring the in-memory cache
   * and falling back to "session/info/<id>" in storage. The value read from
   * storage is placed in the cache before it is returned.
   */
  export async function get(id: string) {
    const cached = state().sessions.get(id)
    if (cached) return cached

    const loaded = await Storage.readJSON<Info>("session/info/" + id)
    state().sessions.set(id, loaded)
    return loaded as Info
  }
  /** Reads the stored share record ("session/share/<id>") of a session. */
  export async function getShare(id: string) {
    const key = `session/share/${id}`
    return Storage.readJSON<ShareInfo>(key)
  }
  /**
   * Shares a session and returns its share information.
   *
   * If the session already has a `share` field, that field is returned
   * unchanged. Otherwise a share is created, its url is stored on the
   * session, the share record is written to "session/share/<id>", and the
   * session info plus every existing message and part are synced one by one.
   *
   * @throws Error when sharing is disabled in the configuration
   */
  export async function share(id: string) {
    const cfg = await Config.get()
    if (cfg.share === "disabled") throw new Error("Sharing is disabled in configuration")

    const session = await get(id)
    if (session.share) return session.share

    const created = await Share.create(id)
    await update(id, (draft) => {
      draft.share = { url: created.url }
    })
    await Storage.writeJSON<ShareInfo>(`session/share/${id}`, created)
    await Share.sync(`session/info/${id}`, session)

    const history = await messages(id)
    for (const entry of history) {
      const messageID = entry.info.id
      await Share.sync(`session/message/${id}/${messageID}`, entry.info)
      for (const part of entry.parts) {
        await Share.sync(`session/part/${id}/${messageID}/${part.id}`, part)
      }
    }
    return created
  }
  /**
   * Removes the share of a session, if it has one: deletes the stored share
   * record, clears `share` on the session and removes the remote share using
   * the stored secret. Does nothing when no share record exists.
   */
  export async function unshare(id: string) {
    const existing = await getShare(id)
    if (existing) {
      await Storage.remove(`session/share/${id}`)
      await update(id, (draft) => {
        draft.share = undefined
      })
      await Share.remove(id, existing.secret)
    }
  }
  /**
   * Applies `editor` to the session in place, bumps `time.updated`, then
   * re-caches, persists and publishes the record.
   *
   * @param id     session id
   * @param editor callback that mutates the session record
   * @returns the updated session, or undefined when no session was found
   */
  export async function update(id: string, editor: (session: Info) => void) {
    const cache = state().sessions
    const target = await get(id)
    if (!target) return

    editor(target)
    target.time.updated = Date.now()
    cache.set(id, target)
    await Storage.writeJSON(`session/info/${id}`, target)
    Bus.publish(Event.Updated, { info: target })
    return target
  }
  /**
   * Loads every message of a session together with its parts, ordered by
   * ascending message id.
   */
  export async function messages(sessionID: string) {
    const collected: { info: MessageV2.Info; parts: MessageV2.Part[] }[] = []
    const keys = await Storage.list("session/message/" + sessionID)
    for (const key of keys) {
      const info = await Storage.readJSON<MessageV2.Info>(key)
      const messageParts = await parts(sessionID, info.id)
      collected.push({ info, parts: messageParts })
    }
    collected.sort((left, right) => (left.info.id > right.info.id ? 1 : -1))
    return collected
  }
  /** Reads a single message record from "session/message/<sessionID>/<messageID>". */
  export async function getMessage(sessionID: string, messageID: string) {
    const key = `session/message/${sessionID}/${messageID}`
    return Storage.readJSON<MessageV2.Info>(key)
  }
  /** Loads all parts of one message, ordered by ascending part id. */
  export async function parts(sessionID: string, messageID: string) {
    const found: MessageV2.Part[] = []
    const keys = await Storage.list(`session/part/${sessionID}/${messageID}`)
    for (const key of keys) {
      found.push(await Storage.readJSON<MessageV2.Part>(key))
    }
    found.sort((left, right) => (left.id > right.id ? 1 : -1))
    return found
  }
  /**
   * Async generator over all stored sessions. Session ids are derived from
   * the entries listed under "session/info" by stripping the ".json" suffix.
   */
  export async function* list() {
    const entries = await Storage.list("session/info")
    for (const entry of entries) {
      yield get(path.basename(entry, ".json"))
    }
  }
  /** Returns every stored session whose parentID equals the given id. */
  export async function children(parentID: string) {
    const matches: Session.Info[] = []
    for (const entry of await Storage.list("session/info")) {
      const candidate = await get(path.basename(entry, ".json"))
      if (candidate.parentID === parentID) matches.push(candidate)
    }
    return matches
  }
  /**
   * Aborts the in-flight request of a session, if there is one.
   *
   * @returns true when a pending controller was found and aborted, else false
   */
  export function abort(sessionID: string) {
    const controller = state().pending.get(sessionID)
    if (controller) {
      controller.abort()
      state().pending.delete(sessionID)
      return true
    }
    return false
  }
  /**
   * Deletes a session and, recursively, all of its child sessions.
   *
   * Any in-flight request is aborted first. The share, the session record,
   * the message directory and the part directory are then removed; each of
   * these steps is best-effort and ignores its own failure. Errors from the
   * remaining steps are logged rather than thrown.
   *
   * @param sessionID id of the session to delete
   * @param emitEvent when false, Event.Deleted is not published (used for the
   *                  recursive removal of children)
   */
  export async function remove(sessionID: string, emitEvent = true) {
    try {
      abort(sessionID)
      const session = await get(sessionID)
      for (const child of await children(sessionID)) {
        await remove(child.id, false)
      }
      await unshare(sessionID).catch(() => {})
      await Storage.remove(`session/info/${sessionID}`).catch(() => {})
      await Storage.removeDir(`session/message/${sessionID}/`).catch(() => {})
      // Parts are stored separately under session/part/<sessionID>/...;
      // without this they would be left behind after the session is gone.
      await Storage.removeDir(`session/part/${sessionID}/`).catch(() => {})
      state().sessions.delete(sessionID)
      state().messages.delete(sessionID)
      if (emitEvent) {
        Bus.publish(Event.Deleted, {
          info: session,
        })
      }
    } catch (e) {
      log.error(e)
    }
  }
  /** Persists a message record and publishes MessageV2.Event.Updated for it. */
  async function updateMessage(msg: MessageV2.Info) {
    const key = `session/message/${msg.sessionID}/${msg.id}`
    await Storage.writeJSON(key, msg)
    Bus.publish(MessageV2.Event.Updated, { info: msg })
  }
  /**
   * Persists a message part, publishes MessageV2.Event.PartUpdated for it
   * and hands the same part back to the caller.
   */
  async function updatePart(part: MessageV2.Part) {
    const segments = ["session", "part", part.sessionID, part.messageID, part.id]
    await Storage.writeJSON(segments.join("/"), part)
    Bus.publish(MessageV2.Event.PartUpdated, { part })
    return part
  }
  /**
   * Input schema of chat(): the target session, the provider/model to use,
   * an optional mode name, an optional per-tool enable map and the user's
   * message parts.
   */
  export const ChatInput = z.object({
    sessionID: Identifier.schema("session"),
    // When omitted, chat() generates an ascending message id.
    messageID: Identifier.schema("message").optional(),
    providerID: z.string(),
    modelID: z.string(),
    // chat() falls back to "build" when no mode is given.
    mode: z.string().optional(),
    // A tool mapped to false is left out of the request.
    tools: z.record(z.boolean()).optional(),
    // Text and file parts without messageID/sessionID (chat() fills them in);
    // the part id is optional.
    parts: z.array(
      z.discriminatedUnion("type", [
        MessageV2.TextPart.omit({
          messageID: true,
          sessionID: true,
        })
          .partial({
            id: true,
          })
          .openapi({
            ref: "TextPartInput",
          }),
        MessageV2.FilePart.omit({
          messageID: true,
          sessionID: true,
        })
          .partial({
            id: true,
          })
          .openapi({
            ref: "FilePartInput",
          }),
      ]),
    ),
  })
  /**
   * Runs one user turn against the model and returns the processed
   * assistant message.
   *
   * Steps, in order:
   *  1. Apply a pending revert: messages after the revert point are removed
   *     from storage, the revert message is cut to the kept parts, and the
   *     marker is cleared.
   *  2. If the previous assistant message used more than 90% of the context
   *     that remains after reserving the output limit, summarize the session
   *     and start over.
   *  3. Take the session lock, drop everything before the last summary and
   *     build the user message. `file:` parts are expanded: text files are
   *     read through the Read tool (optionally restricted to the line range
   *     given by the `start`/`end` query parameters), other files are
   *     inlined as base64 data URLs.
   *  4. For the first message of a top-level session, generate a title in
   *     the background (failures are ignored).
   *  5. Assemble the system prompt (at most two entries), the tool set
   *     (provider tools plus MCP tools, minus those disabled by the mode or
   *     by `input.tools`) and stream the response through the processor.
   *
   * @param input see ChatInput
   * @returns the assistant message info and its parts, as produced by the
   *          stream processor
   */
  export async function chat(input: z.infer<typeof ChatInput>) {
    const l = log.clone().tag("session", input.sessionID)
    l.info("chatting")
    const model = await Provider.getModel(input.providerID, input.modelID)
    let msgs = await messages(input.sessionID)
    const session = await get(input.sessionID)

    const revert = session.revert
    if (revert) {
      const trimmed: typeof msgs = []
      for (const msg of msgs) {
        // Everything after the revert point goes away, as does the revert
        // message itself when none of its parts are kept.
        if (msg.info.id > revert.messageID || (msg.info.id === revert.messageID && revert.part === 0)) {
          await Storage.remove("session/message/" + input.sessionID + "/" + msg.info.id)
          await Bus.publish(MessageV2.Event.Removed, {
            sessionID: input.sessionID,
            messageID: msg.info.id,
          })
          continue
        }
        if (msg.info.id === revert.messageID) {
          msg.parts = msg.parts.slice(0, revert.part)
        }
        trimmed.push(msg)
      }
      msgs = trimmed
      await update(input.sessionID, (draft) => {
        draft.revert = undefined
      })
    }

    const previous = msgs.filter((x) => x.info.role === "assistant").at(-1)?.info as MessageV2.Assistant
    // `||` (not `??`) on purpose: a limit of 0 falls back to the default cap.
    const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX

    // auto summarize if too long
    if (previous && previous.tokens) {
      const tokens =
        previous.tokens.input + previous.tokens.cache.read + previous.tokens.cache.write + previous.tokens.output
      if (model.info.limit.context && tokens > Math.max((model.info.limit.context - outputLimit) * 0.9, 0)) {
        await summarize({
          sessionID: input.sessionID,
          providerID: input.providerID,
          modelID: input.modelID,
        })
        return chat(input)
      }
    }

    using abort = lock(input.sessionID)

    // Only the last summary and what follows it is sent to the model.
    const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
    if (lastSummary) msgs = msgs.filter((msg) => msg.info.id >= lastSummary.info.id)

    const userMsg: MessageV2.Info = {
      id: input.messageID ?? Identifier.ascending("message"),
      role: "user",
      sessionID: input.sessionID,
      time: {
        created: Date.now(),
      },
    }

    const app = App.info()
    const expandedParts = await Promise.all(
      input.parts.map(async (part): Promise<MessageV2.Part[]> => {
        if (part.type === "file") {
          const url = new URL(part.url)
          switch (url.protocol) {
            case "file:": {
              // have to normalize, symbol search returns absolute paths
              // Decode the pathname since URL constructor doesn't automatically decode it
              const pathname = decodeURIComponent(url.pathname)
              const relativePath = pathname.replace(app.path.cwd, ".")
              const filePath = path.join(app.path.cwd, relativePath)

              if (part.mime === "text/plain") {
                let offset: number | undefined = undefined
                let limit: number | undefined = undefined
                const range = {
                  start: url.searchParams.get("start"),
                  end: url.searchParams.get("end"),
                }
                if (range.start != null) {
                  const documentURI = part.url.split("?")[0]
                  const start = parseInt(range.start, 10)
                  let end = range.end ? parseInt(range.end, 10) : undefined
                  // some LSP servers (eg, gopls) don't give full range in
                  // workspace/symbol searches, so we'll try to find the
                  // symbol in the document to get the full range
                  if (start === end) {
                    const symbols = await LSP.documentSymbol(documentURI)
                    for (const symbol of symbols) {
                      let symbolRange: LSP.Range | undefined
                      if ("range" in symbol) {
                        symbolRange = symbol.range
                      } else if ("location" in symbol) {
                        symbolRange = symbol.location.range
                      }
                      // Strict comparison so a symbol on line 0 can match too.
                      if (symbolRange?.start?.line === start) {
                        end = symbolRange?.end?.line ?? start
                        break
                      }
                    }
                  }
                  // Apply the window for every valid range, not only for the
                  // collapsed (start === end) case handled above. A start
                  // that is not a number leaves the read unrestricted.
                  if (!Number.isNaN(start)) {
                    offset = Math.max(start - 2, 0)
                    if (end) {
                      limit = end - offset + 2
                    }
                  }
                }
                const args = { filePath, offset, limit }
                const result = await ReadTool.execute(args, {
                  sessionID: input.sessionID,
                  abort: abort.signal,
                  messageID: userMsg.id,
                  metadata: async () => {},
                })
                return [
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
                  },
                  {
                    id: Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                    type: "text",
                    synthetic: true,
                    text: result.output,
                  },
                  {
                    ...part,
                    id: part.id ?? Identifier.ascending("part"),
                    messageID: userMsg.id,
                    sessionID: input.sessionID,
                  },
                ]
              }

              // Non-text files are inlined as a base64 data URL.
              const file = Bun.file(filePath)
              FileTime.read(input.sessionID, filePath)
              return [
                {
                  id: Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "text",
                  text: `Called the Read tool with the following input: {\"filePath\":\"${pathname}\"}`,
                  synthetic: true,
                },
                {
                  id: part.id ?? Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "file",
                  url: `data:${part.mime};base64,` + Buffer.from(await file.bytes()).toString("base64"),
                  mime: part.mime,
                  filename: part.filename!,
                  source: part.source,
                },
              ]
            }
          }
        }
        // Text parts and files with any other protocol are passed through.
        return [
          {
            id: Identifier.ascending("part"),
            ...part,
            messageID: userMsg.id,
            sessionID: input.sessionID,
          },
        ]
      }),
    )
    const userParts = expandedParts.flat()

    if (input.mode === "plan")
      userParts.push({
        id: Identifier.ascending("part"),
        messageID: userMsg.id,
        sessionID: input.sessionID,
        type: "text",
        text: PROMPT_PLAN,
        synthetic: true,
      })

    // First message of a top-level session: generate a title in the
    // background. This is fire-and-forget; failures are ignored.
    if (msgs.length === 0 && !session.parentID) {
      const small = (await Provider.getSmallModel(input.providerID)) ?? model
      generateText({
        maxOutputTokens: small.info.reasoning ? 1024 : 20,
        providerOptions: {
          [input.providerID]: small.info.options,
        },
        messages: [
          ...SystemPrompt.title(input.providerID).map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          ...MessageV2.toModelMessage([
            {
              info: {
                id: Identifier.ascending("message"),
                role: "user",
                sessionID: input.sessionID,
                time: {
                  created: Date.now(),
                },
              },
              parts: userParts,
            },
          ]),
        ],
        model: small.language,
      })
        .then((result) => {
          if (result.text)
            return Session.update(input.sessionID, (draft) => {
              draft.title = result.text
            })
        })
        .catch(() => {})
    }

    await updateMessage(userMsg)
    for (const part of userParts) {
      await updatePart(part)
    }
    msgs.push({ info: userMsg, parts: userParts })

    const mode = await Mode.get(input.mode ?? "build")
    let system = input.providerID === "anthropic" ? [PROMPT_ANTHROPIC_SPOOF.trim()] : []
    system.push(...(mode.prompt ? [mode.prompt] : SystemPrompt.provider(input.modelID)))
    system.push(...(await SystemPrompt.environment()))
    system.push(...(await SystemPrompt.custom()))
    // max 2 system prompt messages for caching purposes
    const [first, ...rest] = system
    system = [first, rest.join("\n")]

    const assistantMsg: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      system,
      path: {
        cwd: app.path.cwd,
        root: app.path.root,
      },
      cost: 0,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.modelID,
      providerID: input.providerID,
      time: {
        created: Date.now(),
      },
      sessionID: input.sessionID,
    }
    await updateMessage(assistantMsg)

    const tools: Record<string, AITool> = {}
    const processor = createProcessor(assistantMsg, model.info)

    for (const item of await Provider.tools(input.providerID)) {
      if (mode.tools[item.id] === false) continue
      if (input.tools?.[item.id] === false) continue
      // Child sessions are not offered the "task" tool.
      if (session.parentID && item.id === "task") continue
      tools[item.id] = tool({
        id: item.id as any,
        description: item.description,
        inputSchema: item.parameters as ZodSchema,
        async execute(args, options) {
          const result = await item.execute(args, {
            sessionID: input.sessionID,
            abort: abort.signal,
            messageID: assistantMsg.id,
            metadata: async (val) => {
              const match = processor.partFromToolCall(options.toolCallId)
              if (match && match.state.status === "running") {
                await updatePart({
                  ...match,
                  state: {
                    title: val.title,
                    metadata: val.metadata,
                    status: "running",
                    input: args,
                    time: {
                      // Keep the original start so progress updates do not
                      // shift the recorded start time of the call.
                      start: match.state.time.start,
                    },
                  },
                })
              }
            },
          })
          return result
        },
        toModelOutput(result) {
          return {
            type: "text",
            value: result.output,
          }
        },
      })
    }

    for (const [key, item] of Object.entries(await MCP.tools())) {
      if (mode.tools[key] === false) continue
      // Honor the per-request tool switches for MCP tools as well.
      if (input.tools?.[key] === false) continue
      const execute = item.execute
      if (!execute) continue
      // Reduce the MCP result to the joined text content.
      item.execute = async (args, opts) => {
        const result = await execute(args, opts)
        const output = result.content
          .filter((x: any) => x.type === "text")
          .map((x: any) => x.text)
          .join("\n\n")
        return {
          output,
        }
      }
      item.toModelOutput = (result) => {
        return {
          type: "text",
          value: result.output,
        }
      }
      tools[key] = item
    }

    const stream = streamText({
      // Stream errors surface as "error" parts and are handled by the processor.
      onError() {},
      maxRetries: 10,
      maxOutputTokens: outputLimit,
      abortSignal: abort.signal,
      stopWhen: stepCountIs(1000),
      providerOptions: {
        [input.providerID]: model.info.options,
      },
      messages: [
        ...system.map(
          (x): ModelMessage => ({
            role: "system",
            content: x,
          }),
        ),
        ...MessageV2.toModelMessage(msgs),
      ],
      temperature: model.info.temperature ? 0 : undefined,
      tools: model.info.tool_call === false ? undefined : tools,
      model: wrapLanguageModel({
        model: model.language,
        middleware: [
          {
            async transformParams(args) {
              if (args.type === "stream") {
                // @ts-expect-error
                args.params.prompt = ProviderTransform.message(args.params.prompt, input.providerID, input.modelID)
              }
              return args.params
            },
          },
        ],
      }),
    })
    const result = await processor.process(stream)
    return result
  }
  /**
   * Builds the processor that turns a model stream into persisted message
   * parts for one assistant message.
   *
   * @param assistantMsg the assistant message being filled in; its cost,
   *                     tokens, error and completion time are mutated
   * @param model        model metadata, passed to getUsage() for each step
   * @returns an object with
   *   - `partFromToolCall(id)`: the tool part currently tracked for a call id
   *   - `process(stream)`: consumes the stream and resolves to the assistant
   *     message info and its parts
   */
  function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
    // Tool parts that have started but not yet finished, keyed by tool call id.
    const toolCalls: Record<string, MessageV2.ToolPart> = {}

    // Takes a snapshot for the session and stores it as a part when one was created.
    async function recordSnapshot() {
      const snapshot = await Snapshot.create(assistantMsg.sessionID)
      if (snapshot)
        await updatePart({
          id: Identifier.ascending("part"),
          messageID: assistantMsg.id,
          sessionID: assistantMsg.sessionID,
          type: "snapshot",
          snapshot,
        })
    }

    return {
      partFromToolCall(toolCallID: string) {
        return toolCalls[toolCallID]
      },
      async process(stream: StreamTextResult<Record<string, AITool>, never>) {
        try {
          let currentText: MessageV2.TextPart | undefined
          for await (const value of stream.fullStream) {
            log.info("part", {
              type: value.type,
            })
            switch (value.type) {
              case "start": {
                await recordSnapshot()
                break
              }
              case "tool-input-start": {
                const part = await updatePart({
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "tool",
                  tool: value.toolName,
                  callID: value.id,
                  state: {
                    status: "pending",
                  },
                })
                toolCalls[value.id] = part as MessageV2.ToolPart
                break
              }
              case "tool-input-delta":
                break
              case "tool-call": {
                const match = toolCalls[value.toolCallId]
                if (match) {
                  const part = await updatePart({
                    ...match,
                    state: {
                      status: "running",
                      input: value.input,
                      time: {
                        start: Date.now(),
                      },
                    },
                  })
                  toolCalls[value.toolCallId] = part as MessageV2.ToolPart
                }
                break
              }
              case "tool-result": {
                const match = toolCalls[value.toolCallId]
                if (match && match.state.status === "running") {
                  await updatePart({
                    ...match,
                    state: {
                      status: "completed",
                      input: value.input,
                      output: value.output.output,
                      metadata: value.output.metadata,
                      title: value.output.title,
                      time: {
                        start: match.state.time.start,
                        end: Date.now(),
                      },
                    },
                  })
                  delete toolCalls[value.toolCallId]
                  await recordSnapshot()
                }
                break
              }
              case "tool-error": {
                const match = toolCalls[value.toolCallId]
                if (match && match.state.status === "running") {
                  await updatePart({
                    ...match,
                    state: {
                      status: "error",
                      input: value.input,
                      error: (value.error as any).toString(),
                      time: {
                        start: match.state.time.start,
                        end: Date.now(),
                      },
                    },
                  })
                  delete toolCalls[value.toolCallId]
                  await recordSnapshot()
                }
                break
              }
              case "error":
                throw value.error
              case "start-step":
                await updatePart({
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "step-start",
                })
                break
              case "finish-step": {
                const usage = getUsage(model, value.usage, value.providerMetadata)
                // Cost accumulates over steps; tokens hold the latest step's usage.
                assistantMsg.cost += usage.cost
                assistantMsg.tokens = usage.tokens
                await updatePart({
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "step-finish",
                  tokens: usage.tokens,
                  cost: usage.cost,
                })
                await updateMessage(assistantMsg)
                break
              }
              case "text-start":
                currentText = {
                  id: Identifier.ascending("part"),
                  messageID: assistantMsg.id,
                  sessionID: assistantMsg.sessionID,
                  type: "text",
                  text: "",
                  time: {
                    start: Date.now(),
                  },
                }
                break
              case "text":
                if (currentText) {
                  currentText.text += value.text
                  await updatePart(currentText)
                }
                break
              case "text-end":
                if (currentText && currentText.text) {
                  const finishedAt = Date.now()
                  currentText.time = {
                    // Keep the start recorded at "text-start" instead of
                    // overwriting it with the end time.
                    start: currentText.time?.start ?? finishedAt,
                    end: finishedAt,
                  }
                  await updatePart(currentText)
                }
                currentText = undefined
                break
              case "finish":
                assistantMsg.time.completed = Date.now()
                await updateMessage(assistantMsg)
                break
              default:
                log.info("unhandled", {
                  ...value,
                })
                continue
            }
          }
        } catch (e) {
          log.error("", {
            error: e,
          })
          // Map the failure onto one of the serializable error shapes.
          if (e instanceof DOMException && e.name === "AbortError") {
            assistantMsg.error = new MessageV2.AbortedError(
              { message: e.message },
              {
                cause: e,
              },
            ).toObject()
          } else if (MessageV2.OutputLengthError.isInstance(e)) {
            assistantMsg.error = e
          } else if (LoadAPIKeyError.isInstance(e)) {
            assistantMsg.error = new MessageV2.AuthError(
              {
                // `model` is the model metadata, so its id is a model id;
                // the provider is recorded on the assistant message.
                providerID: assistantMsg.providerID,
                message: e.message,
              },
              { cause: e },
            ).toObject()
          } else if (e instanceof Error) {
            assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
          } else {
            // Serialize like the other branches so a plain object is stored.
            assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e }).toObject()
          }
          Bus.publish(Event.Error, {
            sessionID: assistantMsg.sessionID,
            error: assistantMsg.error,
          })
        }

        // Any tool part still pending or running at this point never
        // finished; mark it as aborted. Parts that already carry an error
        // keep it, so the real failure is not replaced by the generic one.
        const p = await parts(assistantMsg.sessionID, assistantMsg.id)
        for (let index = 0; index < p.length; index++) {
          const part = p[index]
          if (part.type !== "tool") continue
          if (part.state.status === "completed" || part.state.status === "error") continue
          const endedAt = Date.now()
          // Awaited so the write has finished before returning, and stored
          // back so the returned parts reflect what was persisted.
          p[index] = await updatePart({
            ...part,
            state: {
              status: "error",
              error: "Tool execution aborted",
              time: {
                start: endedAt,
                end: endedAt,
              },
              input: {},
            },
          })
        }
        assistantMsg.time.completed = Date.now()
        await updateMessage(assistantMsg)
        return { info: assistantMsg, parts: p }
      },
    }
  }
  /**
   * Placeholder for reverting a session to a given message part.
   * The function currently does nothing and ignores its input; the block
   * comment below holds an earlier draft that is kept for reference. The
   * draft refers to message fields (`message.parts`, `message.metadata`)
   * that are not used elsewhere in this file — NOTE(review): presumably it
   * predates the current message format; confirm before reviving it.
   */
  export async function revert(_input: { sessionID: string; messageID: string; part: number }) {
    // TODO
    /*
    const message = await getMessage(input.sessionID, input.messageID)
    if (!message) return
    const part = message.parts[input.part]
    if (!part) return
    const session = await get(input.sessionID)
    const snapshot =
      session.revert?.snapshot ?? (await Snapshot.create(input.sessionID))
    const old = (() => {
      if (message.role === "assistant") {
        const lastTool = message.parts.findLast(
          (part, index) =>
            part.type === "tool-invocation" && index < input.part,
        )
        if (lastTool && lastTool.type === "tool-invocation")
          return message.metadata.tool[lastTool.toolInvocation.toolCallId]
            .snapshot
      }
      return message.metadata.snapshot
    })()
    if (old) await Snapshot.restore(input.sessionID, old)
    await update(input.sessionID, (draft) => {
      draft.revert = {
        messageID: input.messageID,
        part: input.part,
        snapshot,
      }
    })
    */
  }
  /**
   * Cancels a pending revert on a session.
   *
   * If the session carries a `revert` marker, the snapshot recorded on it (when
   * present) is restored and the marker is cleared. Sessions that do not exist
   * or have no pending revert are left untouched.
   *
   * @param sessionID - Identifier of the session to un-revert.
   */
  export async function unrevert(sessionID: string) {
    const session = await get(sessionID)
    if (!session) return
    if (!session.revert) return
    if (session.revert.snapshot) await Snapshot.restore(sessionID, session.revert.snapshot)
    // Await the update so the caller only continues once the marker is cleared,
    // and so a failed write rejects this call instead of becoming an unhandled rejection.
    await update(sessionID, (draft) => {
      draft.revert = undefined
    })
  }
  /**
   * Asks the given model to summarize a session and stores the answer as a new
   * assistant message flagged with `summary: true`.
   *
   * Only the conversation from the most recent summary message onward is sent
   * (the whole history when no summary exists yet). The session is locked for
   * the duration of the call, so `lock` throws `BusyError` if it is already busy.
   *
   * @param input - Session to summarize and the provider/model to use.
   * @returns Whatever the processor returns for the streamed response.
   */
  export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {
    const { sessionID, providerID, modelID } = input
    // Disposed when this function exits, which releases the session lock.
    using abort = lock(sessionID)

    const history = await messages(sessionID)
    const previous = history.findLast((entry) => entry.info.role === "assistant" && entry.info.summary === true)
    const relevant = history.filter((entry) => (previous ? entry.info.id >= previous.info.id : true))

    const model = await Provider.getModel(providerID, modelID)
    const app = App.info()
    const system = SystemPrompt.summarize(providerID)

    const summaryMsg: MessageV2.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      sessionID,
      system,
      path: {
        cwd: app.path.cwd,
        root: app.path.root,
      },
      summary: true,
      cost: 0,
      modelID,
      providerID,
      tokens: {
        input: 0,
        output: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      time: { created: Date.now() },
    }
    await updateMessage(summaryMsg)

    const processor = createProcessor(summaryMsg, model.info)
    const systemMessages = system.map((text): ModelMessage => ({ role: "system", content: text }))
    const request: ModelMessage = {
      role: "user",
      content: [
        {
          type: "text",
          text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
        },
      ],
    }
    const stream = streamText({
      maxRetries: 10,
      abortSignal: abort.signal,
      model: model.language,
      messages: [...systemMessages, ...MessageV2.toModelMessage(relevant), request],
    })

    // Keep the await inside this scope so the lock is held until processing ends.
    return await processor.process(stream)
  }
  /**
   * Registers a session as busy and hands back a disposable handle for it.
   *
   * @param sessionID - Session to lock.
   * @returns An object exposing the abort `signal` of the controller stored in
   *   the pending map, plus a `Symbol.dispose` hook that removes the entry and
   *   publishes `Event.Idle` for the session.
   * @throws BusyError when the session already has an entry in the pending map.
   */
  function lock(sessionID: string) {
    log.info("locking", { sessionID })

    const alreadyPending = state().pending.has(sessionID)
    if (alreadyPending) {
      throw new BusyError(sessionID)
    }

    const controller = new AbortController()
    state().pending.set(sessionID, controller)

    // Runs on disposal (e.g. at the end of a `using` scope).
    const release = (): void => {
      log.info("unlocking", { sessionID })
      state().pending.delete(sessionID)
      Bus.publish(Event.Idle, { sessionID })
    }

    return {
      signal: controller.signal,
      [Symbol.dispose]: release,
    }
  }
  /**
   * Builds the token breakdown for a model call and prices it.
   *
   * Each token count is multiplied by the matching rate in `model.cost` and
   * divided by 1,000,000; missing cache rates are treated as 0.
   *
   * @param model - Model description supplying the `cost` rates.
   * @param usage - Usage reported for the call; missing counts default to 0.
   * @param metadata - Optional provider metadata. Cache-write tokens are read
   *   from `anthropic.cacheCreationInputTokens`, falling back to
   *   `bedrock.usage.cacheWriteInputTokens`, then 0.
   * @returns `{ cost, tokens }` where `tokens.reasoning` is always 0.
   */
  function getUsage(model: ModelsDev.Model, usage: LanguageModelUsage, metadata?: ProviderMetadata) {
    // Provider metadata is untyped JSON, so narrow at runtime instead of
    // asserting `as number`: anything that is not a finite number is ignored.
    const asCount = (value: unknown): number | undefined =>
      typeof value === "number" && Number.isFinite(value) ? value : undefined

    const bedrockUsage: unknown = metadata?.["bedrock"]?.["usage"]
    const bedrockCacheWrite: unknown =
      typeof bedrockUsage === "object" && bedrockUsage !== null && "cacheWriteInputTokens" in bedrockUsage
        ? bedrockUsage.cacheWriteInputTokens
        : undefined

    const tokens = {
      input: usage.inputTokens ?? 0,
      output: usage.outputTokens ?? 0,
      reasoning: 0,
      cache: {
        write: asCount(metadata?.["anthropic"]?.["cacheCreationInputTokens"]) ?? asCount(bedrockCacheWrite) ?? 0,
        read: usage.cachedInputTokens ?? 0,
      },
    }
    return {
      cost: new Decimal(0)
        .add(new Decimal(tokens.input).mul(model.cost.input).div(1_000_000))
        .add(new Decimal(tokens.output).mul(model.cost.output).div(1_000_000))
        .add(new Decimal(tokens.cache.read).mul(model.cost.cache_read ?? 0).div(1_000_000))
        .add(new Decimal(tokens.cache.write).mul(model.cost.cache_write ?? 0).div(1_000_000))
        .toNumber(),
      tokens,
    }
  }
  /**
   * Raised when an operation is requested on a session that is already busy;
   * `lock` throws it when the session has a pending entry.
   *
   * The offending session is available as `sessionID`.
   */
  export class BusyError extends Error {
    constructor(public readonly sessionID: string) {
      super(`Session ${sessionID} is busy`)
      // Subclasses inherit the name "Error" unless it is set explicitly; naming
      // the error makes logs and `toString()` output identify it.
      this.name = "BusyError"
    }
  }
  /**
   * Sends the initialization prompt to a session and then runs `App.initialize()`.
   *
   * The prompt is `PROMPT_INITIALIZE` with its first `${path}` placeholder
   * replaced by the application root path, delivered as a single text part
   * through `Session.chat`.
   *
   * @param input - Target session, the message ID to use, and the provider/model
   *   that should handle the prompt.
   */
  export async function initialize(input: {
    sessionID: string
    modelID: string
    providerID: string
    messageID: string
  }) {
    const app = App.info()
    await Session.chat({
      sessionID: input.sessionID,
      messageID: input.messageID,
      providerID: input.providerID,
      modelID: input.modelID,
      parts: [
        {
          id: Identifier.ascending("part"),
          type: "text",
          // Use a replacer function so the path is inserted verbatim: a plain
          // string replacement would interpret `$&`, `$$`, etc. inside the path.
          text: PROMPT_INITIALIZE.replace("${path}", () => app.path.root),
        },
      ],
    })
    await App.initialize()
  }
  1088. }