index.ts 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891
  1. import path from "path"
  2. import { App } from "../app/app"
  3. import { Identifier } from "../id/id"
  4. import { Storage } from "../storage/storage"
  5. import { Log } from "../util/log"
  6. import {
  7. generateText,
  8. LoadAPIKeyError,
  9. convertToCoreMessages,
  10. streamText,
  11. tool,
  12. type Tool as AITool,
  13. type LanguageModelUsage,
  14. type CoreMessage,
  15. type UIMessage,
  16. type ProviderMetadata,
  17. } from "ai"
  18. import { z, ZodSchema } from "zod"
  19. import { Decimal } from "decimal.js"
  20. import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
  21. import { Share } from "../share/share"
  22. import { Message } from "./message"
  23. import { Bus } from "../bus"
  24. import { Provider } from "../provider/provider"
  25. import { MCP } from "../mcp"
  26. import { NamedError } from "../util/error"
  27. import type { Tool } from "../tool/tool"
  28. import { SystemPrompt } from "./system"
  29. import { Flag } from "../flag/flag"
  30. import type { ModelsDev } from "../provider/models"
  31. import { Installation } from "../installation"
  32. import { Config } from "../config/config"
  33. export namespace Session {
  34. const log = Log.create({ service: "session" })
  35. export const Info = z
  36. .object({
  37. id: Identifier.schema("session"),
  38. parentID: Identifier.schema("session").optional(),
  39. share: z
  40. .object({
  41. url: z.string(),
  42. })
  43. .optional(),
  44. title: z.string(),
  45. version: z.string(),
  46. time: z.object({
  47. created: z.number(),
  48. updated: z.number(),
  49. }),
  50. })
  51. .openapi({
  52. ref: "session.info",
  53. })
  54. export type Info = z.output<typeof Info>
  55. export const ShareInfo = z.object({
  56. secret: z.string(),
  57. url: z.string(),
  58. })
  59. export type ShareInfo = z.output<typeof ShareInfo>
  60. export const Event = {
  61. Updated: Bus.event(
  62. "session.updated",
  63. z.object({
  64. info: Info,
  65. }),
  66. ),
  67. Error: Bus.event(
  68. "session.error",
  69. z.object({
  70. error: Message.Info.shape.metadata.shape.error,
  71. }),
  72. ),
  73. }
  74. const state = App.state(
  75. "session",
  76. () => {
  77. const sessions = new Map<string, Info>()
  78. const messages = new Map<string, Message.Info[]>()
  79. const pending = new Map<string, AbortController>()
  80. return {
  81. sessions,
  82. messages,
  83. pending,
  84. }
  85. },
  86. async (state) => {
  87. for (const [_, controller] of state.pending) {
  88. controller.abort()
  89. }
  90. },
  91. )
  92. export async function create(parentID?: string) {
  93. const result: Info = {
  94. id: Identifier.descending("session"),
  95. version: Installation.VERSION,
  96. parentID,
  97. title:
  98. (parentID ? "Child session - " : "New Session - ") +
  99. new Date().toISOString(),
  100. time: {
  101. created: Date.now(),
  102. updated: Date.now(),
  103. },
  104. }
  105. log.info("created", result)
  106. state().sessions.set(result.id, result)
  107. await Storage.writeJSON("session/info/" + result.id, result)
  108. const cfg = await Config.get()
  109. if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.autoshare))
  110. share(result.id).then((share) => {
  111. update(result.id, (draft) => {
  112. draft.share = share
  113. })
  114. })
  115. Bus.publish(Event.Updated, {
  116. info: result,
  117. })
  118. return result
  119. }
  120. export async function get(id: string) {
  121. const result = state().sessions.get(id)
  122. if (result) {
  123. return result
  124. }
  125. const read = await Storage.readJSON<Info>("session/info/" + id)
  126. state().sessions.set(id, read)
  127. return read as Info
  128. }
  129. export async function getShare(id: string) {
  130. return Storage.readJSON<ShareInfo>("session/share/" + id)
  131. }
  132. export async function share(id: string) {
  133. const session = await get(id)
  134. if (session.share) return session.share
  135. const share = await Share.create(id)
  136. await update(id, (draft) => {
  137. draft.share = {
  138. url: share.url,
  139. }
  140. })
  141. await Storage.writeJSON<ShareInfo>("session/share/" + id, share)
  142. await Share.sync("session/info/" + id, session)
  143. for (const msg of await messages(id)) {
  144. await Share.sync("session/message/" + id + "/" + msg.id, msg)
  145. }
  146. return share
  147. }
  148. export async function update(id: string, editor: (session: Info) => void) {
  149. const { sessions } = state()
  150. const session = await get(id)
  151. if (!session) return
  152. editor(session)
  153. session.time.updated = Date.now()
  154. sessions.set(id, session)
  155. await Storage.writeJSON("session/info/" + id, session)
  156. Bus.publish(Event.Updated, {
  157. info: session,
  158. })
  159. return session
  160. }
  161. export async function messages(sessionID: string) {
  162. const result = [] as Message.Info[]
  163. const list = Storage.list("session/message/" + sessionID)
  164. for await (const p of list) {
  165. const read = await Storage.readJSON<Message.Info>(p)
  166. result.push(read)
  167. }
  168. result.sort((a, b) => (a.id > b.id ? 1 : -1))
  169. return result
  170. }
  171. export async function getMessage(sessionID: string, messageID: string) {
  172. return Storage.readJSON<Message.Info>(
  173. "session/message/" + sessionID + "/" + messageID,
  174. )
  175. }
  176. export async function* list() {
  177. for await (const item of Storage.list("session/info")) {
  178. const sessionID = path.basename(item, ".json")
  179. yield get(sessionID)
  180. }
  181. }
  182. export function abort(sessionID: string) {
  183. const controller = state().pending.get(sessionID)
  184. if (!controller) return false
  185. controller.abort()
  186. state().pending.delete(sessionID)
  187. return true
  188. }
  189. async function updateMessage(msg: Message.Info) {
  190. await Storage.writeJSON(
  191. "session/message/" + msg.metadata.sessionID + "/" + msg.id,
  192. msg,
  193. )
  194. Bus.publish(Message.Event.Updated, {
  195. info: msg,
  196. })
  197. }
  198. export async function chat(input: {
  199. sessionID: string
  200. providerID: string
  201. modelID: string
  202. parts: Message.Part[]
  203. system?: string[]
  204. tools?: Tool.Info[]
  205. }) {
  206. const l = log.clone().tag("session", input.sessionID)
  207. l.info("chatting")
  208. const model = await Provider.getModel(input.providerID, input.modelID)
  209. let msgs = await messages(input.sessionID)
  210. const previous = msgs.at(-1)
  211. // auto summarize if too long
  212. if (previous?.metadata.assistant) {
  213. const tokens =
  214. previous.metadata.assistant.tokens.input +
  215. previous.metadata.assistant.tokens.cache.read +
  216. previous.metadata.assistant.tokens.cache.write +
  217. previous.metadata.assistant.tokens.output
  218. if (
  219. model.info.limit.context &&
  220. tokens >
  221. (model.info.limit.context - (model.info.limit.output ?? 0)) * 0.9
  222. ) {
  223. await summarize({
  224. sessionID: input.sessionID,
  225. providerID: input.providerID,
  226. modelID: input.modelID,
  227. })
  228. return chat(input)
  229. }
  230. }
  231. using abort = lock(input.sessionID)
  232. const lastSummary = msgs.findLast(
  233. (msg) => msg.metadata.assistant?.summary === true,
  234. )
  235. if (lastSummary) msgs = msgs.filter((msg) => msg.id >= lastSummary.id)
  236. const app = App.info()
  237. const session = await get(input.sessionID)
  238. if (msgs.length === 0 && !session.parentID) {
  239. generateText({
  240. maxTokens: input.providerID === "google" ? 1024 : 20,
  241. messages: [
  242. ...SystemPrompt.title(input.providerID).map(
  243. (x): CoreMessage => ({
  244. role: "system",
  245. content: x,
  246. providerOptions: {
  247. ...(input.providerID === "anthropic"
  248. ? {
  249. anthropic: {
  250. cacheControl: { type: "ephemeral" },
  251. },
  252. }
  253. : {}),
  254. },
  255. }),
  256. ),
  257. ...convertToCoreMessages([
  258. {
  259. role: "user",
  260. content: "",
  261. parts: toParts(input.parts),
  262. },
  263. ]),
  264. ],
  265. model: model.language,
  266. })
  267. .then((result) => {
  268. if (result.text)
  269. return Session.update(input.sessionID, (draft) => {
  270. draft.title = result.text
  271. })
  272. })
  273. .catch(() => {})
  274. }
  275. const msg: Message.Info = {
  276. role: "user",
  277. id: Identifier.ascending("message"),
  278. parts: input.parts,
  279. metadata: {
  280. time: {
  281. created: Date.now(),
  282. },
  283. sessionID: input.sessionID,
  284. tool: {},
  285. },
  286. }
  287. await updateMessage(msg)
  288. msgs.push(msg)
  289. const system = input.system ?? SystemPrompt.provider(input.providerID)
  290. system.push(...(await SystemPrompt.environment()))
  291. system.push(...(await SystemPrompt.custom()))
  292. const next: Message.Info = {
  293. id: Identifier.ascending("message"),
  294. role: "assistant",
  295. parts: [],
  296. metadata: {
  297. assistant: {
  298. system,
  299. path: {
  300. cwd: app.path.cwd,
  301. root: app.path.root,
  302. },
  303. cost: 0,
  304. tokens: {
  305. input: 0,
  306. output: 0,
  307. reasoning: 0,
  308. cache: { read: 0, write: 0 },
  309. },
  310. modelID: input.modelID,
  311. providerID: input.providerID,
  312. },
  313. time: {
  314. created: Date.now(),
  315. },
  316. sessionID: input.sessionID,
  317. tool: {},
  318. },
  319. }
  320. await updateMessage(next)
  321. const tools: Record<string, AITool> = {}
  322. for (const item of await Provider.tools(input.providerID)) {
  323. tools[item.id.replaceAll(".", "_")] = tool({
  324. id: item.id as any,
  325. description: item.description,
  326. parameters: item.parameters as ZodSchema,
  327. async execute(args, opts) {
  328. const start = Date.now()
  329. try {
  330. const result = await item.execute(args, {
  331. sessionID: input.sessionID,
  332. abort: abort.signal,
  333. messageID: next.id,
  334. metadata: async (val) => {
  335. next.metadata.tool[opts.toolCallId] = {
  336. ...val,
  337. time: {
  338. start: 0,
  339. end: 0,
  340. },
  341. }
  342. await updateMessage(next)
  343. },
  344. })
  345. next.metadata!.tool![opts.toolCallId] = {
  346. ...result.metadata,
  347. time: {
  348. start,
  349. end: Date.now(),
  350. },
  351. }
  352. await updateMessage(next)
  353. return result.output
  354. } catch (e: any) {
  355. next.metadata!.tool![opts.toolCallId] = {
  356. error: true,
  357. message: e.toString(),
  358. title: e.toString(),
  359. time: {
  360. start,
  361. end: Date.now(),
  362. },
  363. }
  364. await updateMessage(next)
  365. return e.toString()
  366. }
  367. },
  368. })
  369. }
  370. for (const [key, item] of Object.entries(await MCP.tools())) {
  371. const execute = item.execute
  372. if (!execute) continue
  373. item.execute = async (args, opts) => {
  374. const start = Date.now()
  375. try {
  376. const result = await execute(args, opts)
  377. next.metadata!.tool![opts.toolCallId] = {
  378. ...result.metadata,
  379. time: {
  380. start,
  381. end: Date.now(),
  382. },
  383. }
  384. await updateMessage(next)
  385. return result.content
  386. .filter((x: any) => x.type === "text")
  387. .map((x: any) => x.text)
  388. .join("\n\n")
  389. } catch (e: any) {
  390. next.metadata!.tool![opts.toolCallId] = {
  391. error: true,
  392. message: e.toString(),
  393. title: "mcp",
  394. time: {
  395. start,
  396. end: Date.now(),
  397. },
  398. }
  399. await updateMessage(next)
  400. return e.toString()
  401. }
  402. }
  403. tools[key] = item
  404. }
  405. let text: Message.TextPart | undefined
  406. await Bun.write(
  407. "/tmp/message.json",
  408. JSON.stringify(
  409. [
  410. ...system.map(
  411. (x): CoreMessage => ({
  412. role: "system",
  413. content: x,
  414. }),
  415. ),
  416. ...convertToCoreMessages(
  417. msgs.map(toUIMessage).filter((x) => x.parts.length > 0),
  418. ),
  419. ],
  420. null,
  421. 2,
  422. ),
  423. )
  424. const result = streamText({
  425. onStepFinish: async (step) => {
  426. log.info("step finish", { finishReason: step.finishReason })
  427. const assistant = next.metadata!.assistant!
  428. const usage = getUsage(model.info, step.usage, step.providerMetadata)
  429. assistant.cost += usage.cost
  430. assistant.tokens = usage.tokens
  431. await updateMessage(next)
  432. if (text) {
  433. Bus.publish(Message.Event.PartUpdated, {
  434. part: text,
  435. messageID: next.id,
  436. sessionID: next.metadata.sessionID,
  437. })
  438. }
  439. text = undefined
  440. },
  441. async onFinish(input) {
  442. log.info("message finish", {
  443. reason: input.finishReason,
  444. })
  445. const assistant = next.metadata!.assistant!
  446. const usage = getUsage(model.info, input.usage, input.providerMetadata)
  447. assistant.cost = usage.cost
  448. await updateMessage(next)
  449. },
  450. onError(err) {
  451. log.error("callback error", err)
  452. switch (true) {
  453. case LoadAPIKeyError.isInstance(err.error):
  454. next.metadata.error = new Provider.AuthError(
  455. {
  456. providerID: input.providerID,
  457. message: err.error.message,
  458. },
  459. { cause: err.error },
  460. ).toObject()
  461. break
  462. case err.error instanceof Error:
  463. next.metadata.error = new NamedError.Unknown(
  464. { message: err.error.toString() },
  465. { cause: err.error },
  466. ).toObject()
  467. break
  468. default:
  469. next.metadata.error = new NamedError.Unknown(
  470. { message: JSON.stringify(err.error) },
  471. { cause: err.error },
  472. )
  473. }
  474. Bus.publish(Event.Error, {
  475. error: next.metadata.error,
  476. })
  477. },
  478. // async prepareStep(step) {
  479. // next.parts.push({
  480. // type: "step-start",
  481. // })
  482. // await updateMessage(next)
  483. // return step
  484. // },
  485. toolCallStreaming: true,
  486. abortSignal: abort.signal,
  487. maxSteps: 1000,
  488. messages: [
  489. ...system.map(
  490. (x, index): CoreMessage => ({
  491. role: "system",
  492. content: x,
  493. providerOptions: {
  494. ...(input.providerID === "anthropic" && index < 4
  495. ? {
  496. anthropic: {
  497. cacheControl: { type: "ephemeral" },
  498. },
  499. }
  500. : {}),
  501. },
  502. }),
  503. ),
  504. ...convertToCoreMessages(
  505. msgs.map(toUIMessage).filter((x) => x.parts.length > 0),
  506. ),
  507. ],
  508. temperature: model.info.temperature ? 0 : undefined,
  509. tools: {
  510. ...tools,
  511. },
  512. model: model.language,
  513. })
  514. try {
  515. for await (const value of result.fullStream) {
  516. l.info("part", {
  517. type: value.type,
  518. })
  519. switch (value.type) {
  520. case "step-start":
  521. next.parts.push({
  522. type: "step-start",
  523. })
  524. break
  525. case "text-delta":
  526. if (!text) {
  527. text = {
  528. type: "text",
  529. text: value.textDelta,
  530. }
  531. next.parts.push(text)
  532. break
  533. } else text.text += value.textDelta
  534. break
  535. case "tool-call": {
  536. const [match] = next.parts.flatMap((p) =>
  537. p.type === "tool-invocation" &&
  538. p.toolInvocation.toolCallId === value.toolCallId
  539. ? [p]
  540. : [],
  541. )
  542. if (!match) break
  543. match.toolInvocation.args = value.args
  544. match.toolInvocation.state = "call"
  545. Bus.publish(Message.Event.PartUpdated, {
  546. part: match,
  547. messageID: next.id,
  548. sessionID: next.metadata.sessionID,
  549. })
  550. break
  551. }
  552. case "tool-call-streaming-start":
  553. next.parts.push({
  554. type: "tool-invocation",
  555. toolInvocation: {
  556. state: "partial-call",
  557. toolName: value.toolName,
  558. toolCallId: value.toolCallId,
  559. args: {},
  560. },
  561. })
  562. Bus.publish(Message.Event.PartUpdated, {
  563. part: next.parts[next.parts.length - 1],
  564. messageID: next.id,
  565. sessionID: next.metadata.sessionID,
  566. })
  567. break
  568. case "tool-call-delta":
  569. break
  570. // for some reason ai sdk claims to not send this part but it does
  571. // @ts-expect-error
  572. case "tool-result":
  573. const match = next.parts.find(
  574. (p) =>
  575. p.type === "tool-invocation" &&
  576. // @ts-expect-error
  577. p.toolInvocation.toolCallId === value.toolCallId,
  578. )
  579. if (match && match.type === "tool-invocation") {
  580. match.toolInvocation = {
  581. // @ts-expect-error
  582. args: value.args,
  583. // @ts-expect-error
  584. toolCallId: value.toolCallId,
  585. // @ts-expect-error
  586. toolName: value.toolName,
  587. state: "result",
  588. // @ts-expect-error
  589. result: value.result as string,
  590. }
  591. Bus.publish(Message.Event.PartUpdated, {
  592. part: match,
  593. messageID: next.id,
  594. sessionID: next.metadata.sessionID,
  595. })
  596. }
  597. break
  598. default:
  599. l.info("unhandled", {
  600. type: value.type,
  601. })
  602. }
  603. await updateMessage(next)
  604. }
  605. } catch (e: any) {
  606. log.error("stream error", {
  607. error: e,
  608. })
  609. switch (true) {
  610. case LoadAPIKeyError.isInstance(e):
  611. next.metadata.error = new Provider.AuthError(
  612. {
  613. providerID: input.providerID,
  614. message: e.message,
  615. },
  616. { cause: e },
  617. ).toObject()
  618. break
  619. case e instanceof Error:
  620. next.metadata.error = new NamedError.Unknown(
  621. { message: e.toString() },
  622. { cause: e },
  623. ).toObject()
  624. break
  625. default:
  626. next.metadata.error = new NamedError.Unknown(
  627. { message: JSON.stringify(e) },
  628. { cause: e },
  629. )
  630. }
  631. Bus.publish(Event.Error, {
  632. error: next.metadata.error,
  633. })
  634. }
  635. next.metadata!.time.completed = Date.now()
  636. for (const part of next.parts) {
  637. if (
  638. part.type === "tool-invocation" &&
  639. part.toolInvocation.state !== "result"
  640. ) {
  641. part.toolInvocation = {
  642. ...part.toolInvocation,
  643. state: "result",
  644. result: "request was aborted",
  645. }
  646. }
  647. }
  648. await updateMessage(next)
  649. return next
  650. }
  651. export async function summarize(input: {
  652. sessionID: string
  653. providerID: string
  654. modelID: string
  655. }) {
  656. using abort = lock(input.sessionID)
  657. const msgs = await messages(input.sessionID)
  658. const lastSummary = msgs.findLast(
  659. (msg) => msg.metadata.assistant?.summary === true,
  660. )?.id
  661. const filtered = msgs.filter((msg) => !lastSummary || msg.id >= lastSummary)
  662. const model = await Provider.getModel(input.providerID, input.modelID)
  663. const app = App.info()
  664. const system = SystemPrompt.summarize(input.providerID)
  665. const next: Message.Info = {
  666. id: Identifier.ascending("message"),
  667. role: "assistant",
  668. parts: [],
  669. metadata: {
  670. tool: {},
  671. sessionID: input.sessionID,
  672. assistant: {
  673. system,
  674. path: {
  675. cwd: app.path.cwd,
  676. root: app.path.root,
  677. },
  678. summary: true,
  679. cost: 0,
  680. modelID: input.modelID,
  681. providerID: input.providerID,
  682. tokens: {
  683. input: 0,
  684. output: 0,
  685. reasoning: 0,
  686. cache: { read: 0, write: 0 },
  687. },
  688. },
  689. time: {
  690. created: Date.now(),
  691. },
  692. },
  693. }
  694. await updateMessage(next)
  695. const result = await generateText({
  696. abortSignal: abort.signal,
  697. model: model.language,
  698. messages: [
  699. ...system.map(
  700. (x): CoreMessage => ({
  701. role: "system",
  702. content: x,
  703. }),
  704. ),
  705. ...convertToCoreMessages(filtered.map(toUIMessage)),
  706. {
  707. role: "user",
  708. content: [
  709. {
  710. type: "text",
  711. text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
  712. },
  713. ],
  714. },
  715. ],
  716. })
  717. next.parts.push({
  718. type: "text",
  719. text: result.text,
  720. })
  721. const assistant = next.metadata!.assistant!
  722. const usage = getUsage(model.info, result.usage, result.providerMetadata)
  723. assistant.cost = usage.cost
  724. assistant.tokens = usage.tokens
  725. await updateMessage(next)
  726. }
  727. function lock(sessionID: string) {
  728. log.info("locking", { sessionID })
  729. if (state().pending.has(sessionID)) throw new BusyError(sessionID)
  730. const controller = new AbortController()
  731. state().pending.set(sessionID, controller)
  732. return {
  733. signal: controller.signal,
  734. [Symbol.dispose]() {
  735. log.info("unlocking", { sessionID })
  736. state().pending.delete(sessionID)
  737. },
  738. }
  739. }
  740. function getUsage(
  741. model: ModelsDev.Model,
  742. usage: LanguageModelUsage,
  743. metadata?: ProviderMetadata,
  744. ) {
  745. const tokens = {
  746. input: usage.promptTokens ?? 0,
  747. output: usage.completionTokens ?? 0,
  748. reasoning: 0,
  749. cache: {
  750. write: (metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
  751. 0) as number,
  752. read: (metadata?.["anthropic"]?.["cacheReadInputTokens"] ??
  753. 0) as number,
  754. },
  755. }
  756. return {
  757. cost: new Decimal(0)
  758. .add(new Decimal(tokens.input).mul(model.cost.input).div(1_000_000))
  759. .add(new Decimal(tokens.output).mul(model.cost.output).div(1_000_000))
  760. .add(
  761. new Decimal(tokens.cache.read)
  762. .mul(model.cost.cache_read ?? 0)
  763. .div(1_000_000),
  764. )
  765. .add(
  766. new Decimal(tokens.cache.write)
  767. .mul(model.cost.cache_write ?? 0)
  768. .div(1_000_000),
  769. )
  770. .toNumber(),
  771. tokens,
  772. }
  773. }
  774. export class BusyError extends Error {
  775. constructor(public readonly sessionID: string) {
  776. super(`Session ${sessionID} is busy`)
  777. }
  778. }
  779. export async function initialize(input: {
  780. sessionID: string
  781. modelID: string
  782. providerID: string
  783. }) {
  784. const app = App.info()
  785. await Session.chat({
  786. sessionID: input.sessionID,
  787. providerID: input.providerID,
  788. modelID: input.modelID,
  789. parts: [
  790. {
  791. type: "text",
  792. text: PROMPT_INITIALIZE.replace("${path}", app.path.root),
  793. },
  794. ],
  795. })
  796. await App.initialize()
  797. }
  798. }
  799. function toUIMessage(msg: Message.Info): UIMessage {
  800. if (msg.role === "assistant") {
  801. return {
  802. id: msg.id,
  803. role: "assistant",
  804. content: "",
  805. parts: toParts(msg.parts),
  806. }
  807. }
  808. if (msg.role === "user") {
  809. return {
  810. id: msg.id,
  811. role: "user",
  812. content: "",
  813. parts: toParts(msg.parts),
  814. }
  815. }
  816. throw new Error("not implemented")
  817. }
  818. function toParts(parts: Message.Part[]): UIMessage["parts"] {
  819. const result: UIMessage["parts"] = []
  820. for (const part of parts) {
  821. switch (part.type) {
  822. case "text":
  823. result.push({ type: "text", text: part.text })
  824. break
  825. case "file":
  826. result.push({
  827. type: "file",
  828. data: part.url,
  829. mimeType: part.mediaType,
  830. })
  831. break
  832. case "tool-invocation":
  833. result.push({
  834. type: "tool-invocation",
  835. toolInvocation: part.toolInvocation,
  836. })
  837. break
  838. case "step-start":
  839. result.push({
  840. type: "step-start",
  841. })
  842. break
  843. default:
  844. break
  845. }
  846. }
  847. return result
  848. }