// index.ts (27 KB)
  1. import path from "path"
  2. import { App } from "../app/app"
  3. import { Identifier } from "../id/id"
  4. import { Storage } from "../storage/storage"
  5. import { Log } from "../util/log"
  6. import {
  7. generateText,
  8. LoadAPIKeyError,
  9. convertToCoreMessages,
  10. streamText,
  11. tool,
  12. type Tool as AITool,
  13. type LanguageModelUsage,
  14. type CoreMessage,
  15. type UIMessage,
  16. type ProviderMetadata,
  17. wrapLanguageModel,
  18. } from "ai"
  19. import { z, ZodSchema } from "zod"
  20. import { Decimal } from "decimal.js"
  21. import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
  22. import { Share } from "../share/share"
  23. import { Message } from "./message"
  24. import { Bus } from "../bus"
  25. import { Provider } from "../provider/provider"
  26. import { MCP } from "../mcp"
  27. import { NamedError } from "../util/error"
  28. import type { Tool } from "../tool/tool"
  29. import { SystemPrompt } from "./system"
  30. import { Flag } from "../flag/flag"
  31. import type { ModelsDev } from "../provider/models"
  32. import { Installation } from "../installation"
  33. import { Config } from "../config/config"
  34. import { ProviderTransform } from "../provider/transform"
  /**
   * Session management: creation, persistence, sharing, and the chat/summarize
   * loops that stream model output into stored messages.
   *
   * Session info is cached in per-app state and persisted via `Storage` under
   * `session/info/<id>`; messages are stored under
   * `session/message/<sessionID>/<messageID>`.
   */
  export namespace Session {
    const log = Log.create({ service: "session" })

    /** Schema of a session's persisted metadata. */
    export const Info = z
      .object({
        id: Identifier.schema("session"),
        parentID: Identifier.schema("session").optional(),
        share: z
          .object({
            url: z.string(),
          })
          .optional(),
        title: z.string(),
        version: z.string(),
        time: z.object({
          created: z.number(),
          updated: z.number(),
        }),
      })
      .openapi({
        ref: "session.info",
      })
    export type Info = z.output<typeof Info>

    /** Share record (including its secret) stored under `session/share/<id>`. */
    export const ShareInfo = z.object({
      secret: z.string(),
      url: z.string(),
    })
    export type ShareInfo = z.output<typeof ShareInfo>

    /** Bus events published by this namespace. */
    export const Event = {
      Updated: Bus.event(
        "session.updated",
        z.object({
          info: Info,
        }),
      ),
      Deleted: Bus.event(
        "session.deleted",
        z.object({
          info: Info,
        }),
      ),
      Error: Bus.event(
        "session.error",
        z.object({
          error: Message.Info.shape.metadata.shape.error,
        }),
      ),
    }

    // Per-app state: a session-info cache, a message map (only ever cleared in
    // this file), and one AbortController per session with a request in flight.
    // The second callback (presumably run when the app state is disposed) aborts
    // every pending request.
    const state = App.state(
      "session",
      () => {
        const sessions = new Map<string, Info>()
        const messages = new Map<string, Message.Info[]>()
        const pending = new Map<string, AbortController>()
        return {
          sessions,
          messages,
          pending,
        }
      },
      async (state) => {
        for (const [_, controller] of state.pending) {
          controller.abort()
        }
      },
    )

    /**
     * Creates a session, caches and persists it, and publishes `session.updated`.
     *
     * Root sessions are shared in the background when auto-share is enabled
     * (flag or config). `share()` records the share URL on the session itself,
     * so a failure there is only logged.
     *
     * @param parentID - optional parent session; child sessions are never auto-shared
     * @returns the newly created session info
     */
    export async function create(parentID?: string) {
      const result: Info = {
        id: Identifier.descending("session"),
        version: Installation.VERSION,
        parentID,
        title:
          (parentID ? "Child session - " : "New Session - ") +
          new Date().toISOString(),
        time: {
          created: Date.now(),
          updated: Date.now(),
        },
      }
      log.info("created", result)
      state().sessions.set(result.id, result)
      await Storage.writeJSON("session/info/" + result.id, result)
      const cfg = await Config.get()
      if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.autoshare)) {
        // share() already stores `{ url }` on the session via update(). Writing
        // the returned ShareInfo here as well would persist and publish its secret.
        share(result.id).catch((error) => {
          log.error("auto share failed", { sessionID: result.id, error })
        })
      }
      Bus.publish(Event.Updated, {
        info: result,
      })
      return result
    }

    /**
     * Returns a session's info. On a cache miss it is read from storage and
     * cached. The cached object is shared: `update` edits it in place.
     */
    export async function get(id: string) {
      const cached = state().sessions.get(id)
      if (cached) {
        return cached
      }
      const read = await Storage.readJSON<Info>("session/info/" + id)
      state().sessions.set(id, read)
      return read as Info
    }

    /** Reads the stored share record (URL + secret) for a session. */
    export async function getShare(id: string) {
      return Storage.readJSON<ShareInfo>("session/share/" + id)
    }

    /**
     * Shares a session and syncs its info and all of its messages to the share
     * service.
     *
     * @returns the existing `{ url }` if the session is already shared (no
     *   secret); otherwise the freshly created `ShareInfo`
     */
    export async function share(id: string) {
      const session = await get(id)
      if (session.share) return session.share
      const share = await Share.create(id)
      await update(id, (draft) => {
        draft.share = {
          url: share.url,
        }
      })
      await Storage.writeJSON<ShareInfo>("session/share/" + id, share)
      // `session` is the cached object that update() just edited, so it already
      // carries the share URL.
      await Share.sync("session/info/" + id, session)
      for (const msg of await messages(id)) {
        await Share.sync("session/message/" + id + "/" + msg.id, msg)
      }
      return share
    }

    /** Removes the share record, clears `share` on the session, and unshares it remotely. */
    export async function unshare(id: string) {
      await Storage.remove("session/share/" + id)
      await update(id, (draft) => {
        draft.share = undefined
      })
      await Share.remove(id)
    }

    /**
     * Applies `editor` to the session in place, bumps `time.updated`, then
     * persists the session and publishes `session.updated`.
     *
     * @returns the edited session, or undefined if `get` produced nothing
     */
    export async function update(id: string, editor: (session: Info) => void) {
      const { sessions } = state()
      const session = await get(id)
      if (!session) return
      editor(session)
      session.time.updated = Date.now()
      sessions.set(id, session)
      await Storage.writeJSON("session/info/" + id, session)
      Bus.publish(Event.Updated, {
        info: session,
      })
      return session
    }

    /** Loads every stored message of a session, sorted by message id. */
    export async function messages(sessionID: string) {
      const result = [] as Message.Info[]
      const list = Storage.list("session/message/" + sessionID)
      for await (const p of list) {
        const read = await Storage.readJSON<Message.Info>(p)
        result.push(read)
      }
      result.sort((a, b) => (a.id > b.id ? 1 : -1))
      return result
    }

    /** Reads a single stored message. */
    export async function getMessage(sessionID: string, messageID: string) {
      return Storage.readJSON<Message.Info>(
        "session/message/" + sessionID + "/" + messageID,
      )
    }

    /** Yields the info of every stored session, loading each one through `get`. */
    export async function* list() {
      for await (const item of Storage.list("session/info")) {
        const sessionID = path.basename(item, ".json")
        yield get(sessionID)
      }
    }

    /** Returns the direct children of a session. This scans every stored session. */
    export async function children(parentID: string) {
      const result = [] as Session.Info[]
      for await (const item of Storage.list("session/info")) {
        const sessionID = path.basename(item, ".json")
        const session = await get(sessionID)
        if (session.parentID !== parentID) continue
        result.push(session)
      }
      return result
    }

    /**
     * Aborts the request in flight for a session and releases its lock.
     *
     * @returns false if nothing was pending
     */
    export function abort(sessionID: string) {
      const controller = state().pending.get(sessionID)
      if (!controller) return false
      controller.abort()
      state().pending.delete(sessionID)
      return true
    }

    /**
     * Deletes a session and, recursively, its children. It aborts any pending
     * request, unshares the session, removes the stored info and messages, and
     * clears the caches. `session.deleted` is published only for the top-level
     * call (children pass `emitEvent = false`). Errors are logged, not thrown.
     */
    export async function remove(sessionID: string, emitEvent = true) {
      try {
        abort(sessionID)
        const session = await get(sessionID)
        for (const child of await children(sessionID)) {
          await remove(child.id, false)
        }
        await unshare(sessionID).catch(() => {})
        await Storage.remove(`session/info/${sessionID}`).catch(() => {})
        await Storage.removeDir(`session/message/${sessionID}/`).catch(() => {})
        state().sessions.delete(sessionID)
        state().messages.delete(sessionID)
        if (emitEvent) {
          Bus.publish(Event.Deleted, {
            info: session,
          })
        }
      } catch (e) {
        log.error(e)
      }
    }

    /** Persists a message and publishes the message-updated event. */
    async function updateMessage(msg: Message.Info) {
      await Storage.writeJSON(
        "session/message/" + msg.metadata.sessionID + "/" + msg.id,
        msg,
      )
      Bus.publish(Message.Event.Updated, {
        info: msg,
      })
    }

    /**
     * Converts an error raised while streaming into the object stored in
     * `metadata.error` and published with `session.error`. Every branch returns
     * a `.toObject()` result, so the stored value is never a raw error instance.
     *
     * @param error - anything thrown by the stream or reported via `onError`
     * @param providerID - used to attribute API-key load failures
     */
    function toMessageError(error: any, providerID: string) {
      // The instanceof guard keeps isInstance() from being probed with primitives.
      if (error instanceof Error && Message.OutputLengthError.isInstance(error))
        return error.toObject()
      if (LoadAPIKeyError.isInstance(error))
        return new Provider.AuthError(
          {
            providerID,
            message: error.message,
          },
          { cause: error },
        ).toObject()
      if (error instanceof Error)
        return new NamedError.Unknown(
          { message: error.toString() },
          { cause: error },
        ).toObject()
      return new NamedError.Unknown(
        // JSON.stringify(undefined) yields undefined; fall back to String().
        { message: JSON.stringify(error) ?? String(error) },
        { cause: error },
      ).toObject()
    }

    /**
     * Sends a user message to the model and streams the assistant reply into a
     * new stored message.
     *
     * - If the previous (non-summary) assistant turn used more than 90% of the
     *   usable context window (context minus output limit), the session is
     *   summarized first and the call restarts.
     * - The session lock is held for the rest of the call. `BusyError` is thrown
     *   if another chat or summarize is already running on the session.
     * - Only messages from the latest summary onwards are sent to the model.
     * - For the first message of a root session, a title is generated in the
     *   background.
     * - Errors raised while streaming are recorded on the assistant message and
     *   published as `session.error` instead of being thrown.
     *
     * @returns the completed assistant message
     */
    export async function chat(input: {
      sessionID: string
      providerID: string
      modelID: string
      parts: Message.Part[]
      system?: string[]
      tools?: Tool.Info[]
    }) {
      const l = log.clone().tag("session", input.sessionID)
      l.info("chatting")
      const model = await Provider.getModel(input.providerID, input.modelID)
      let msgs = await messages(input.sessionID)
      const previous = msgs.at(-1)?.metadata.assistant
      // Auto-summarize if too long. A summary message is skipped: its token
      // counts describe the long prompt that was summarized, while the next
      // request only replays the summary onwards.
      if (previous && previous.summary !== true) {
        const tokens =
          previous.tokens.input +
          previous.tokens.cache.read +
          previous.tokens.cache.write +
          previous.tokens.output
        if (
          model.info.limit.context &&
          tokens >
            Math.max(
              (model.info.limit.context - (model.info.limit.output ?? 0)) * 0.9,
              0,
            )
        ) {
          await summarize({
            sessionID: input.sessionID,
            providerID: input.providerID,
            modelID: input.modelID,
          })
          return chat(input)
        }
      }

      using abort = lock(input.sessionID)
      const lastSummary = msgs.findLast(
        (msg) => msg.metadata.assistant?.summary === true,
      )
      if (lastSummary) msgs = msgs.filter((msg) => msg.id >= lastSummary.id)
      const app = App.info()
      const session = await get(input.sessionID)
      if (msgs.length === 0 && !session.parentID) {
        // Best-effort title generation for a new root session. It is never
        // awaited, and failures are ignored.
        generateText({
          maxTokens: input.providerID === "google" ? 1024 : 20,
          providerOptions: model.info.options,
          messages: [
            ...SystemPrompt.title(input.providerID).map(
              (x): CoreMessage => ({
                role: "system",
                content: x,
              }),
            ),
            ...convertToCoreMessages([
              {
                role: "user",
                content: "",
                parts: toParts(input.parts),
              },
            ]),
          ],
          model: model.language,
        })
          .then((result) => {
            if (result.text)
              return Session.update(input.sessionID, (draft) => {
                draft.title = result.text
              })
          })
          .catch(() => {})
      }

      const msg: Message.Info = {
        role: "user",
        id: Identifier.ascending("message"),
        parts: input.parts,
        metadata: {
          time: {
            created: Date.now(),
          },
          sessionID: input.sessionID,
          tool: {},
        },
      }
      await updateMessage(msg)
      msgs.push(msg)

      // Copy, so neither the caller's `input.system` nor the provider's prompt
      // list is mutated by the pushes below.
      const system = [...(input.system ?? SystemPrompt.provider(input.providerID))]
      system.push(...(await SystemPrompt.environment()))
      system.push(...(await SystemPrompt.custom()))

      const next: Message.Info = {
        id: Identifier.ascending("message"),
        role: "assistant",
        parts: [],
        metadata: {
          assistant: {
            system,
            path: {
              cwd: app.path.cwd,
              root: app.path.root,
            },
            cost: 0,
            tokens: {
              input: 0,
              output: 0,
              reasoning: 0,
              cache: { read: 0, write: 0 },
            },
            modelID: input.modelID,
            providerID: input.providerID,
          },
          time: {
            created: Date.now(),
          },
          sessionID: input.sessionID,
          tool: {},
        },
      }
      await updateMessage(next)

      const tools: Record<string, AITool> = {}
      for (const item of await Provider.tools(input.providerID)) {
        tools[item.id.replaceAll(".", "_")] = tool({
          id: item.id as any,
          description: item.description,
          parameters: item.parameters as ZodSchema,
          async execute(args, opts) {
            const start = Date.now()
            try {
              const result = await item.execute(args, {
                sessionID: input.sessionID,
                abort: abort.signal,
                messageID: next.id,
                metadata: async (val) => {
                  next.metadata.tool[opts.toolCallId] = {
                    ...val,
                    time: {
                      start: 0,
                      end: 0,
                    },
                  }
                  await updateMessage(next)
                },
              })
              next.metadata.tool[opts.toolCallId] = {
                ...result.metadata,
                time: {
                  start,
                  end: Date.now(),
                },
              }
              await updateMessage(next)
              return result.output
            } catch (e: any) {
              next.metadata.tool[opts.toolCallId] = {
                error: true,
                message: e.toString(),
                title: e.toString(),
                time: {
                  start,
                  end: Date.now(),
                },
              }
              await updateMessage(next)
              return e.toString()
            }
          },
        })
      }

      for (const [key, item] of Object.entries(await MCP.tools())) {
        const execute = item.execute
        if (!execute) continue
        // Wrap a shallow copy instead of reassigning `item.execute`, so the tool
        // object returned by MCP.tools() is left untouched. If that object is
        // reused across calls, wrappers would otherwise stack up and write into
        // earlier messages.
        const wrapped: AITool = { ...item }
        wrapped.execute = async (args, opts) => {
          const start = Date.now()
          try {
            const result = await execute(args, opts)
            next.metadata.tool[opts.toolCallId] = {
              ...result.metadata,
              time: {
                start,
                end: Date.now(),
              },
            }
            await updateMessage(next)
            return result.content
              .filter((x: any) => x.type === "text")
              .map((x: any) => x.text)
              .join("\n\n")
          } catch (e: any) {
            next.metadata.tool[opts.toolCallId] = {
              error: true,
              message: e.toString(),
              title: "mcp",
              time: {
                start,
                end: Date.now(),
              },
            }
            await updateMessage(next)
            return e.toString()
          }
        }
        tools[key] = wrapped
      }

      let text: Message.TextPart | undefined
      const result = streamText({
        onStepFinish: async (step) => {
          log.info("step finish", { finishReason: step.finishReason })
          const assistant = next.metadata.assistant!
          const usage = getUsage(model.info, step.usage, step.providerMetadata)
          assistant.cost += usage.cost
          assistant.tokens = usage.tokens
          await updateMessage(next)
          if (text) {
            Bus.publish(Message.Event.PartUpdated, {
              part: text,
              messageID: next.id,
              sessionID: next.metadata.sessionID,
            })
          }
          // The next step starts a fresh text part.
          text = undefined
        },
        async onFinish(event) {
          log.info("message finish", {
            reason: event.finishReason,
          })
          const assistant = next.metadata.assistant!
          const usage = getUsage(model.info, event.usage, event.providerMetadata)
          assistant.cost = usage.cost
          await updateMessage(next)
        },
        onError(err) {
          log.error("callback error", err)
          next.metadata.error = toMessageError(err.error, input.providerID)
          Bus.publish(Event.Error, {
            error: next.metadata.error,
          })
        },
        toolCallStreaming: true,
        abortSignal: abort.signal,
        maxSteps: 1000,
        providerOptions: model.info.options,
        messages: [
          ...system.map(
            (x): CoreMessage => ({
              role: "system",
              content: x,
            }),
          ),
          ...convertToCoreMessages(
            msgs.map(toUIMessage).filter((x) => x.parts.length > 0),
          ),
        ],
        temperature: model.info.temperature ? 0 : undefined,
        tools: model.info.tool_call === false ? undefined : tools,
        model: wrapLanguageModel({
          model: model.language,
          middleware: [
            {
              async transformParams(args) {
                if (args.type === "stream") {
                  args.params.prompt = ProviderTransform.message(
                    args.params.prompt,
                    input.providerID,
                    input.modelID,
                  )
                }
                return args.params
              },
            },
          ],
        }),
      })

      try {
        for await (const value of result.fullStream) {
          l.info("part", {
            type: value.type,
          })
          switch (value.type) {
            case "step-start":
              next.parts.push({
                type: "step-start",
              })
              break
            case "text-delta":
              if (!text) {
                text = {
                  type: "text",
                  text: value.textDelta,
                }
                next.parts.push(text)
              } else text.text += value.textDelta
              break
            case "tool-call": {
              const [match] = next.parts.flatMap((p) =>
                p.type === "tool-invocation" &&
                p.toolInvocation.toolCallId === value.toolCallId
                  ? [p]
                  : [],
              )
              if (!match) break
              match.toolInvocation.args = value.args
              match.toolInvocation.state = "call"
              Bus.publish(Message.Event.PartUpdated, {
                part: match,
                messageID: next.id,
                sessionID: next.metadata.sessionID,
              })
              break
            }
            case "tool-call-streaming-start":
              next.parts.push({
                type: "tool-invocation",
                toolInvocation: {
                  state: "partial-call",
                  toolName: value.toolName,
                  toolCallId: value.toolCallId,
                  args: {},
                },
              })
              Bus.publish(Message.Event.PartUpdated, {
                part: next.parts[next.parts.length - 1],
                messageID: next.id,
                sessionID: next.metadata.sessionID,
              })
              break
            case "tool-call-delta":
              // Argument deltas are not persisted; skip the write below.
              continue
            // for some reason ai sdk claims to not send this part but it does
            // @ts-expect-error
            case "tool-result": {
              const match = next.parts.find(
                (p) =>
                  p.type === "tool-invocation" &&
                  // @ts-expect-error
                  p.toolInvocation.toolCallId === value.toolCallId,
              )
              if (match && match.type === "tool-invocation") {
                match.toolInvocation = {
                  // @ts-expect-error
                  args: value.args,
                  // @ts-expect-error
                  toolCallId: value.toolCallId,
                  // @ts-expect-error
                  toolName: value.toolName,
                  state: "result",
                  // @ts-expect-error
                  result: value.result as string,
                }
                Bus.publish(Message.Event.PartUpdated, {
                  part: match,
                  messageID: next.id,
                  sessionID: next.metadata.sessionID,
                })
              }
              break
            }
            case "finish": {
              log.info("message finish", {
                reason: value.finishReason,
              })
              const assistant = next.metadata.assistant!
              const usage = getUsage(
                model.info,
                value.usage,
                value.providerMetadata,
              )
              assistant.cost = usage.cost
              await updateMessage(next)
              if (value.finishReason === "length")
                throw new Message.OutputLengthError({})
              break
            }
            default:
              l.info("unhandled", {
                type: value.type,
              })
              continue
          }
          await updateMessage(next)
        }
      } catch (e) {
        log.error("stream error", {
          error: e,
        })
        next.metadata.error = toMessageError(e, input.providerID)
        Bus.publish(Event.Error, {
          error: next.metadata.error,
        })
      }

      next.metadata.time.completed = Date.now()
      // Close out any tool call that never produced a result (e.g. abort or
      // error), so the stored message has no dangling invocations.
      for (const part of next.parts) {
        if (
          part.type === "tool-invocation" &&
          part.toolInvocation.state !== "result"
        ) {
          part.toolInvocation = {
            ...part.toolInvocation,
            state: "result",
            result: "request was aborted",
          }
        }
      }
      await updateMessage(next)
      return next
    }

    /**
     * Streams a summary of the conversation since the previous summary into a
     * new assistant message flagged `summary: true`. `chat` only replays
     * messages from the latest summary onwards. Holds the session lock
     * (throws `BusyError` if the session is busy).
     */
    export async function summarize(input: {
      sessionID: string
      providerID: string
      modelID: string
    }) {
      using abort = lock(input.sessionID)
      const msgs = await messages(input.sessionID)
      const lastSummary = msgs.findLast(
        (msg) => msg.metadata.assistant?.summary === true,
      )?.id
      const filtered = msgs.filter((msg) => !lastSummary || msg.id >= lastSummary)
      const model = await Provider.getModel(input.providerID, input.modelID)
      const app = App.info()
      const system = SystemPrompt.summarize(input.providerID)
      const next: Message.Info = {
        id: Identifier.ascending("message"),
        role: "assistant",
        parts: [],
        metadata: {
          tool: {},
          sessionID: input.sessionID,
          assistant: {
            system,
            path: {
              cwd: app.path.cwd,
              root: app.path.root,
            },
            summary: true,
            cost: 0,
            modelID: input.modelID,
            providerID: input.providerID,
            tokens: {
              input: 0,
              output: 0,
              reasoning: 0,
              cache: { read: 0, write: 0 },
            },
          },
          time: {
            created: Date.now(),
          },
        },
      }
      await updateMessage(next)
      let text: Message.TextPart | undefined
      const result = streamText({
        abortSignal: abort.signal,
        model: model.language,
        messages: [
          ...system.map(
            (x): CoreMessage => ({
              role: "system",
              content: x,
            }),
          ),
          ...convertToCoreMessages(filtered.map(toUIMessage)),
          {
            role: "user",
            content: [
              {
                type: "text",
                text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
              },
            ],
          },
        ],
        onStepFinish: async (step) => {
          const assistant = next.metadata.assistant!
          const usage = getUsage(model.info, step.usage, step.providerMetadata)
          assistant.cost += usage.cost
          assistant.tokens = usage.tokens
          await updateMessage(next)
          if (text) {
            Bus.publish(Message.Event.PartUpdated, {
              part: text,
              messageID: next.id,
              sessionID: next.metadata.sessionID,
            })
          }
          text = undefined
        },
        async onFinish(event) {
          const assistant = next.metadata.assistant!
          const usage = getUsage(model.info, event.usage, event.providerMetadata)
          assistant.cost = usage.cost
          assistant.tokens = usage.tokens
          next.metadata.time.completed = Date.now()
          await updateMessage(next)
        },
      })
      for await (const value of result.fullStream) {
        switch (value.type) {
          case "text-delta":
            if (!text) {
              text = {
                type: "text",
                text: value.textDelta,
              }
              next.parts.push(text)
            } else text.text += value.textDelta
            await updateMessage(next)
            break
        }
      }
    }

    /**
     * Marks a session as busy and returns a disposable that carries its abort
     * signal. Disposal releases the lock (only if this call still owns it) and
     * spawns any configured `experimental.hook.session_completed` commands.
     *
     * @throws BusyError if the session already has a pending request
     */
    function lock(sessionID: string) {
      log.info("locking", { sessionID })
      if (state().pending.has(sessionID)) throw new BusyError(sessionID)
      const controller = new AbortController()
      state().pending.set(sessionID, controller)
      return {
        signal: controller.signal,
        [Symbol.dispose]() {
          log.info("unlocking", { sessionID })
          // abort() may already have released this lock, and a newer request
          // may have locked the session since. Only release the lock we own.
          if (state().pending.get(sessionID) === controller)
            state().pending.delete(sessionID)
          Config.get()
            .then((cfg) => {
              if (cfg.experimental?.hook?.session_completed) {
                for (const item of cfg.experimental.hook.session_completed) {
                  Bun.spawn({
                    cmd: item.command,
                    cwd: App.info().path.cwd,
                    env: item.environment,
                  })
                }
              }
            })
            .catch((error) => {
              log.error("session_completed hook failed", { sessionID, error })
            })
        },
      }
    }

    /**
     * Normalizes AI SDK usage (plus Anthropic cache metadata) into stored token
     * counts. It also computes the cost from the model's per-million-token
     * prices.
     */
    function getUsage(
      model: ModelsDev.Model,
      usage: LanguageModelUsage,
      metadata?: ProviderMetadata,
    ) {
      // `??` alone does not catch NaN, and one NaN count would make the
      // accumulated cost NaN for good, so non-finite or non-numeric values
      // count as 0.
      const count = (value: unknown) =>
        typeof value === "number" && Number.isFinite(value) ? value : 0
      const tokens = {
        input: count(usage.promptTokens),
        output: count(usage.completionTokens),
        reasoning: 0,
        cache: {
          write: count(metadata?.["anthropic"]?.["cacheCreationInputTokens"]),
          read: count(metadata?.["anthropic"]?.["cacheReadInputTokens"]),
        },
      }
      return {
        cost: new Decimal(0)
          .add(new Decimal(tokens.input).mul(model.cost.input).div(1_000_000))
          .add(new Decimal(tokens.output).mul(model.cost.output).div(1_000_000))
          .add(
            new Decimal(tokens.cache.read)
              .mul(model.cost.cache_read ?? 0)
              .div(1_000_000),
          )
          .add(
            new Decimal(tokens.cache.write)
              .mul(model.cost.cache_write ?? 0)
              .div(1_000_000),
          )
          .toNumber(),
        tokens,
      }
    }

    /** Thrown when a session already has a chat or summarize in flight. */
    export class BusyError extends Error {
      constructor(public readonly sessionID: string) {
        super(`Session ${sessionID} is busy`)
      }
    }

    /**
     * Runs the project-initialization prompt (with `${path}` replaced by the app
     * root) as a chat turn, then calls `App.initialize()`.
     */
    export async function initialize(input: {
      sessionID: string
      modelID: string
      providerID: string
    }) {
      const app = App.info()
      await Session.chat({
        sessionID: input.sessionID,
        providerID: input.providerID,
        modelID: input.modelID,
        parts: [
          {
            type: "text",
            text: PROMPT_INITIALIZE.replace("${path}", app.path.root),
          },
        ],
      })
      await App.initialize()
    }
  }
  /**
   * Converts a stored message into the AI SDK `UIMessage` shape, with empty
   * `content` and parts mapped via `toParts`. Only user and assistant
   * messages are supported; any other role throws.
   */
  function toUIMessage(msg: Message.Info): UIMessage {
    switch (msg.role) {
      case "assistant":
      case "user":
        return {
          id: msg.id,
          role: msg.role,
          content: "",
          parts: toParts(msg.parts),
        }
    }
    throw new Error("not implemented")
  }
  /**
   * Maps stored message parts onto AI SDK `UIMessage` parts. Text, file,
   * tool-invocation and step-start parts are translated (a file's `url`
   * becomes `data` and its `mediaType` becomes `mimeType`). Every other part
   * type is dropped.
   */
  function toParts(parts: Message.Part[]): UIMessage["parts"] {
    return parts.flatMap((part): UIMessage["parts"] => {
      if (part.type === "text") return [{ type: "text", text: part.text }]
      if (part.type === "file")
        return [{ type: "file", data: part.url, mimeType: part.mediaType }]
      if (part.type === "tool-invocation")
        return [{ type: "tool-invocation", toolInvocation: part.toolInvocation }]
      if (part.type === "step-start") return [{ type: "step-start" }]
      return []
    })
  }