| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322 |
- import path from "path"
- import { Decimal } from "decimal.js"
- import { z, ZodSchema } from "zod"
- import {
- generateText,
- LoadAPIKeyError,
- streamText,
- tool,
- wrapLanguageModel,
- type Tool as AITool,
- type LanguageModelUsage,
- type ProviderMetadata,
- type ModelMessage,
- stepCountIs,
- type StreamTextResult,
- } from "ai"
- import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
- import PROMPT_PLAN from "../session/prompt/plan.txt"
- import { App } from "../app/app"
- import { Bus } from "../bus"
- import { Config } from "../config/config"
- import { Flag } from "../flag/flag"
- import { Identifier } from "../id/id"
- import { Installation } from "../installation"
- import { MCP } from "../mcp"
- import { Provider } from "../provider/provider"
- import { ProviderTransform } from "../provider/transform"
- import type { ModelsDev } from "../provider/models"
- import { Share } from "../share/share"
- import { Snapshot } from "../snapshot"
- import { Storage } from "../storage/storage"
- import { Log } from "../util/log"
- import { NamedError } from "../util/error"
- import { SystemPrompt } from "./system"
- import { FileTime } from "../file/time"
- import { MessageV2 } from "./message-v2"
- import { Mode } from "./mode"
- import { LSP } from "../lsp"
- import { ReadTool } from "../tool/read"
- import { mergeDeep, pipe, splitWhen } from "remeda"
- import { ToolRegistry } from "../tool/registry"
- export namespace Session {
// Logger shared by every function in this namespace, tagged with the "session" service.
const log = Log.create({ service: "session" })
// Ceiling on the output tokens requested from a model. `chat` uses the smaller
// of this and the model's own output limit (falling back to this value when
// the model reports no limit).
const OUTPUT_TOKEN_MAX = 32000
/**
 * Persisted metadata for a session, stored at `session/info/<id>` and
 * published on the bus with `Event.Updated`.
 */
export const Info = z
  .object({
    id: Identifier.schema("session"),
    // Set for child sessions (see `create(parentID)`).
    parentID: Identifier.schema("session").optional(),
    // Present while the session is shared; only the public URL is kept here.
    share: z
      .object({
        url: z.string(),
      })
      .optional(),
    title: z.string(),
    // Installation version that created the session.
    version: z.string(),
    // Epoch milliseconds (populated from Date.now()).
    time: z.object({
      created: z.number(),
      updated: z.number(),
    }),
    // Pending revert point set by `revert()` and cleared by `unrevert()`.
    revert: z
      .object({
        messageID: z.string(),
        partID: z.string().optional(),
        // Snapshot restored by `unrevert()`.
        snapshot: z.string().optional(),
      })
      .optional(),
  })
  .openapi({
    ref: "Session",
  })
export type Info = z.output<typeof Info>
/**
 * Share credentials for a session, stored separately from `Info` at
 * `session/share/<id>`. The `secret` is what `unshare` passes to
 * `Share.remove`, so it must never be copied into the public session info.
 */
export const ShareInfo = z
  .object({
    secret: z.string(),
    url: z.string(),
  })
  .openapi({
    ref: "SessionShare",
  })
export type ShareInfo = z.output<typeof ShareInfo>
/** Bus events emitted for session lifecycle changes. */
export const Event = {
  // Published when a session is created or its info is changed via `update`.
  Updated: Bus.event(
    "session.updated",
    z.object({
      info: Info,
    }),
  ),
  // Published by `remove` when `emitEvent` is true (not for recursively removed children).
  Deleted: Bus.event(
    "session.deleted",
    z.object({
      info: Info,
    }),
  ),
  // NOTE(review): not published in the visible code; presumably emitted when a session's run finishes — confirm.
  Idle: Bus.event(
    "session.idle",
    z.object({
      sessionID: z.string(),
    }),
  ),
  // Published when processing a model stream fails; carries the serialized error.
  Error: Bus.event(
    "session.error",
    z.object({
      sessionID: z.string().optional(),
      error: MessageV2.Assistant.shape.error,
    }),
  ),
}
/**
 * In-memory, per-app session state:
 * - `sessions`: cache of session info keyed by session id
 * - `messages`: per-session message cache
 * - `pending`: abort controller registered for a busy session
 * - `queued`: chat requests received while the session was busy, with the
 *   callback that resolves each caller once the running request finishes
 * When the app state is disposed, every pending controller is aborted.
 */
const state = App.state(
  "session",
  () => ({
    sessions: new Map<string, Info>(),
    messages: new Map<string, MessageV2.Info[]>(),
    pending: new Map<string, AbortController>(),
    queued: new Map<
      string,
      {
        input: ChatInput
        message: MessageV2.User
        parts: MessageV2.Part[]
        processed: boolean
        callback: (input: { info: MessageV2.Assistant; parts: MessageV2.Part[] }) => void
      }[]
    >(),
  }),
  async (current) => {
    for (const controller of current.pending.values()) controller.abort()
  },
)
/**
 * Create and persist a new session, optionally as a child of `parentID`.
 *
 * A top-level session is shared in the background when `OPENCODE_AUTO_SHARE`
 * is set or the config has `share: "auto"`. `share()` records the public URL on
 * the session itself. Its return value also carries the share secret, so it is
 * deliberately not copied into the session info. Sharing failures are ignored.
 *
 * @param parentID id of the parent session, if any
 * @returns the newly created session info
 */
export async function create(parentID?: string) {
  // One timestamp so `created` and `updated` agree exactly.
  const now = Date.now()
  const result: Info = {
    id: Identifier.descending("session"),
    version: Installation.VERSION,
    parentID,
    title: (parentID ? "Child session - " : "New Session - ") + new Date(now).toISOString(),
    time: {
      created: now,
      updated: now,
    },
  }
  log.info("created", result)
  state().sessions.set(result.id, result)
  await Storage.writeJSON("session/info/" + result.id, result)
  const cfg = await Config.get()
  if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.share === "auto")) {
    share(result.id).catch(() => {
      // Silently ignore sharing errors during session creation
    })
  }
  Bus.publish(Event.Updated, {
    info: result,
  })
  return result
}
/**
 * Fetch a session's info, serving it from the in-memory cache when present
 * and otherwise loading it from storage and caching it.
 */
export async function get(id: string) {
  const { sessions } = state()
  const cached = sessions.get(id)
  if (cached) return cached
  const loaded = await Storage.readJSON<Info>("session/info/" + id)
  sessions.set(id, loaded)
  return loaded as Info
}
/** Read the stored share credentials (`ShareInfo`) for a session. */
export async function getShare(id: string) {
  const key = `session/share/${id}`
  return Storage.readJSON<ShareInfo>(key)
}
/**
 * Share a session.
 *
 * If the session already has a share, it is returned as-is. Otherwise this:
 * - creates a new share and records its URL on the session;
 * - persists the share credentials;
 * - syncs the session info, every message and every part to the share service.
 *
 * @returns the existing `{ url }` share, or the newly created credentials
 * @throws Error when sharing is disabled in the configuration
 */
export async function share(id: string) {
  const { share: sharingMode } = await Config.get()
  if (sharingMode === "disabled") throw new Error("Sharing is disabled in configuration")
  const session = await get(id)
  if (session.share) return session.share
  const created = await Share.create(id)
  await update(id, (draft) => {
    draft.share = { url: created.url }
  })
  await Storage.writeJSON<ShareInfo>("session/share/" + id, created)
  await Share.sync("session/info/" + id, session)
  const history = await messages(id)
  for (const { info, parts } of history) {
    await Share.sync("session/message/" + id + "/" + info.id, info)
    for (const part of parts) {
      await Share.sync("session/part/" + id + "/" + info.id + "/" + part.id, part)
    }
  }
  return created
}
/**
 * Stop sharing a session: delete the stored credentials, clear the URL on the
 * session, then remove the share remotely. Does nothing if the session has no
 * stored share.
 */
export async function unshare(id: string) {
  const credentials = await getShare(id)
  if (!credentials) return
  await Storage.remove(`session/share/${id}`)
  await update(id, (draft) => {
    draft.share = undefined
  })
  await Share.remove(id, credentials.secret)
}
/**
 * Mutate a session in place with `editor`, then:
 * - bump `time.updated`;
 * - persist the session;
 * - publish `Event.Updated`.
 *
 * @returns the updated session, or undefined when it could not be loaded
 */
export async function update(id: string, editor: (session: Info) => void) {
  const cache = state().sessions
  const target = await get(id)
  if (!target) return
  editor(target)
  target.time.updated = Date.now()
  cache.set(id, target)
  await Storage.writeJSON(`session/info/${id}`, target)
  Bus.publish(Event.Updated, { info: target })
  return target
}
/** Load every message of a session together with its parts, sorted by message id. */
export async function messages(sessionID: string) {
  const keys = await Storage.list("session/message/" + sessionID)
  const collected: { info: MessageV2.Info; parts: MessageV2.Part[] }[] = []
  for (const key of keys) {
    const info = await Storage.readJSON<MessageV2.Info>(key)
    const parts = await getParts(sessionID, info.id)
    collected.push({ info, parts })
  }
  return collected.sort((a, b) => (a.info.id > b.info.id ? 1 : -1))
}
/** Read one message's info (without its parts) from storage. */
export async function getMessage(sessionID: string, messageID: string) {
  const key = ["session", "message", sessionID, messageID].join("/")
  return Storage.readJSON<MessageV2.Info>(key)
}
/** Load all parts of a message, sorted by part id. */
export async function getParts(sessionID: string, messageID: string) {
  const parts: MessageV2.Part[] = []
  const keys = await Storage.list(`session/part/${sessionID}/${messageID}`)
  for (const key of keys) parts.push(await Storage.readJSON<MessageV2.Part>(key))
  parts.sort((left, right) => (left.id > right.id ? 1 : -1))
  return parts
}
/** Yield every stored session, one at a time, based on the `session/info` listing. */
export async function* list() {
  const entries = await Storage.list("session/info")
  for (const entry of entries) {
    yield get(path.basename(entry, ".json"))
  }
}
/** Return the direct child sessions of `parentID` by scanning every stored session. */
export async function children(parentID: string) {
  const found: Session.Info[] = []
  for (const entry of await Storage.list("session/info")) {
    const candidate = await get(path.basename(entry, ".json"))
    if (candidate.parentID === parentID) found.push(candidate)
  }
  return found
}
/**
 * Abort and forget the pending controller registered for a session.
 * @returns true if a controller was aborted, false if none was registered
 */
export function abort(sessionID: string) {
  const { pending } = state()
  const controller = pending.get(sessionID)
  if (controller === undefined) return false
  controller.abort()
  pending.delete(sessionID)
  return true
}
/**
 * Delete a session and, recursively, all of its child sessions.
 *
 * Aborts pending work, unshares the session, and removes its info, messages
 * and parts from storage and from the in-memory caches. Individual storage
 * failures are ignored. Any other error is logged, not thrown.
 *
 * @param sessionID session to delete
 * @param emitEvent publish `Event.Deleted`; recursive child removals pass false
 */
export async function remove(sessionID: string, emitEvent = true) {
  try {
    abort(sessionID)
    const session = await get(sessionID)
    for (const child of await children(sessionID)) {
      await remove(child.id, false)
    }
    await unshare(sessionID).catch(() => {})
    await Storage.remove(`session/info/${sessionID}`).catch(() => {})
    await Storage.removeDir(`session/message/${sessionID}/`).catch(() => {})
    // Parts live under their own prefix (see updatePart); without this they
    // would be orphaned in storage after the session is deleted.
    await Storage.removeDir(`session/part/${sessionID}/`).catch(() => {})
    state().sessions.delete(sessionID)
    state().messages.delete(sessionID)
    if (emitEvent) {
      Bus.publish(Event.Deleted, {
        info: session,
      })
    }
  } catch (e) {
    log.error(e)
  }
}
/** Persist a message's info and publish `MessageV2.Event.Updated`. */
async function updateMessage(msg: MessageV2.Info) {
  const key = ["session", "message", msg.sessionID, msg.id].join("/")
  await Storage.writeJSON(key, msg)
  Bus.publish(MessageV2.Event.Updated, { info: msg })
}
/** Persist a part, publish `MessageV2.Event.PartUpdated`, and return the same part object. */
async function updatePart(part: MessageV2.Part) {
  const { sessionID, messageID, id } = part
  await Storage.writeJSON(`session/part/${sessionID}/${messageID}/${id}`, part)
  Bus.publish(MessageV2.Event.PartUpdated, { part })
  return part
}
/**
 * Input accepted by `chat`. Parts are text or file parts without
 * `messageID`/`sessionID` (both are assigned by `chat`); their `id` is optional.
 */
export const ChatInput = z.object({
  sessionID: Identifier.schema("session"),
  // Caller-chosen id for the user message; `chat` generates one when omitted.
  messageID: Identifier.schema("message").optional(),
  providerID: z.string(),
  modelID: z.string(),
  // Mode name; `chat` falls back to "build".
  mode: z.string().optional(),
  // When set, used instead of the mode/provider system prompt.
  system: z.string().optional(),
  // Per-tool overrides merged over the mode's tool settings; false disables a tool.
  tools: z.record(z.boolean()).optional(),
  parts: z.array(
    z.discriminatedUnion("type", [
      MessageV2.TextPart.omit({
        messageID: true,
        sessionID: true,
      })
        .partial({
          id: true,
        })
        .openapi({
          ref: "TextPartInput",
        }),
      MessageV2.FilePart.omit({
        messageID: true,
        sessionID: true,
      })
        .partial({
          id: true,
        })
        .openapi({
          ref: "FilePartInput",
        }),
    ]),
  ),
})
export type ChatInput = z.infer<typeof ChatInput>
/**
 * Decode the payload of a `data:` URL as UTF-8 text (RFC 2397).
 * Only the part after the first comma is decoded. `;base64` payloads are
 * base64-decoded; any other payload is percent-decoded, and kept verbatim if
 * the percent-encoding is malformed.
 */
function decodeDataURLText(url: string) {
  const comma = url.indexOf(",")
  if (comma === -1) return ""
  const header = url.slice(0, comma)
  const payload = url.slice(comma + 1)
  if (/;base64$/i.test(header)) return Buffer.from(payload, "base64").toString()
  try {
    return decodeURIComponent(payload)
  } catch {
    return payload
  }
}

/**
 * Send a user message to a session and run the model until it stops.
 *
 * 1. Input parts are expanded into stored user parts. Text `data:` URLs and
 *    text `file:` URLs are inlined as synthetic "Read tool" text parts. Other
 *    `file:` URLs are embedded as base64 data URLs.
 * 2. The user message and parts are persisted. If the session is locked, the
 *    request is queued and resolves when the running request completes.
 * 3. Before the model runs:
 *    - a pending revert drops stored messages/parts after the revert point;
 *    - the session is auto-summarized when the previous assistant turn
 *      approached the context limit;
 *    - a title is generated for the first message of a top-level session.
 * 4. The model response is streamed with the enabled registry and MCP tools.
 *    Queued user messages are folded in between steps.
 *
 * @returns the final assistant message and its parts
 */
export async function chat(
  input: z.infer<typeof ChatInput>,
): Promise<{ info: MessageV2.Assistant; parts: MessageV2.Part[] }> {
  const l = log.clone().tag("session", input.sessionID)
  l.info("chatting")
  const inputMode = input.mode ?? "build"
  const userMsg: MessageV2.Info = {
    id: input.messageID ?? Identifier.ascending("message"),
    role: "user",
    sessionID: input.sessionID,
    time: {
      created: Date.now(),
    },
  }
  const app = App.info()
  const userParts = await Promise.all(
    input.parts.map(async (part): Promise<MessageV2.Part[]> => {
      if (part.type === "file") {
        const url = new URL(part.url)
        switch (url.protocol) {
          case "data:":
            if (part.mime === "text/plain") {
              return [
                {
                  id: Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "text",
                  synthetic: true,
                  text: `Called the Read tool with the following input: ${JSON.stringify({ filePath: part.filename })}`,
                },
                {
                  id: Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "text",
                  synthetic: true,
                  // Decode only the payload; decoding the whole URL would also
                  // decode the "data:text/plain;base64," header as content.
                  text: decodeDataURLText(part.url),
                },
                {
                  ...part,
                  id: part.id ?? Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                },
              ]
            }
            break
          case "file:":
            // have to normalize, symbol search returns absolute paths
            // Decode the pathname since URL constructor doesn't automatically decode it
            const pathname = decodeURIComponent(url.pathname)
            const relativePath = pathname.replace(app.path.cwd, ".")
            const filePath = path.join(app.path.cwd, relativePath)
            if (part.mime === "text/plain") {
              let offset: number | undefined = undefined
              let limit: number | undefined = undefined
              const range = {
                start: url.searchParams.get("start"),
                end: url.searchParams.get("end"),
              }
              if (range.start != null) {
                const filePath = part.url.split("?")[0]
                let start = parseInt(range.start)
                let end = range.end ? parseInt(range.end) : undefined
                // some LSP servers (eg, gopls) don't give full range in
                // workspace/symbol searches, so we'll try to find the
                // symbol in the document to get the full range
                if (start === end) {
                  const symbols = await LSP.documentSymbol(filePath)
                  for (const symbol of symbols) {
                    let range: LSP.Range | undefined
                    if ("range" in symbol) {
                      range = symbol.range
                    } else if ("location" in symbol) {
                      range = symbol.location.range
                    }
                    // Compare directly so a symbol starting on line 0 can match.
                    if (range && range.start?.line === start) {
                      start = range.start.line
                      end = range?.end?.line ?? start
                      break
                    }
                  }
                }
                // Apply the requested range in every case, not only when it
                // was widened by the symbol lookup above.
                offset = Math.max(start - 2, 0)
                if (end) {
                  limit = end - offset + 2
                }
              }
              const args = { filePath, offset, limit }
              const result = await ReadTool.init().then((t) =>
                t.execute(args, {
                  sessionID: input.sessionID,
                  abort: new AbortController().signal,
                  messageID: userMsg.id,
                  metadata: async () => {},
                }),
              )
              return [
                {
                  id: Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "text",
                  synthetic: true,
                  text: `Called the Read tool with the following input: ${JSON.stringify(args)}`,
                },
                {
                  id: Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                  type: "text",
                  synthetic: true,
                  text: result.output,
                },
                {
                  ...part,
                  id: part.id ?? Identifier.ascending("part"),
                  messageID: userMsg.id,
                  sessionID: input.sessionID,
                },
              ]
            }
            const file = Bun.file(filePath)
            FileTime.read(input.sessionID, filePath)
            return [
              {
                id: Identifier.ascending("part"),
                messageID: userMsg.id,
                sessionID: input.sessionID,
                type: "text",
                text: `Called the Read tool with the following input: {\"filePath\":\"${pathname}\"}`,
                synthetic: true,
              },
              {
                id: part.id ?? Identifier.ascending("part"),
                messageID: userMsg.id,
                sessionID: input.sessionID,
                type: "file",
                url: `data:${part.mime};base64,` + Buffer.from(await file.bytes()).toString("base64"),
                mime: part.mime,
                filename: part.filename!,
                source: part.source,
              },
            ]
        }
      }
      return [
        {
          id: Identifier.ascending("part"),
          ...part,
          messageID: userMsg.id,
          sessionID: input.sessionID,
        },
      ]
    }),
  ).then((x) => x.flat())
  if (inputMode === "plan")
    userParts.push({
      id: Identifier.ascending("part"),
      messageID: userMsg.id,
      sessionID: input.sessionID,
      type: "text",
      text: PROMPT_PLAN,
      synthetic: true,
    })
  await updateMessage(userMsg)
  for (const part of userParts) {
    await updatePart(part)
  }
  // mark session as updated since a message has been added to it
  await update(input.sessionID, (_draft) => {})
  if (isLocked(input.sessionID)) {
    // Another request is running; it folds this message in (prepareStep) and
    // resolves this promise through `callback` when it finishes.
    return new Promise((resolve) => {
      const queue = state().queued.get(input.sessionID) ?? []
      queue.push({
        input: input,
        message: userMsg,
        parts: userParts,
        processed: false,
        callback: resolve,
      })
      state().queued.set(input.sessionID, queue)
    })
  }
  const model = await Provider.getModel(input.providerID, input.modelID)
  let msgs = await messages(input.sessionID)
  const session = await get(input.sessionID)
  if (session.revert) {
    // Sending a new message commits a pending revert: delete everything after
    // the revert point.
    const messageID = session.revert.messageID
    const [preserve, remove] = splitWhen(msgs, (x) => x.info.id === messageID)
    msgs = preserve
    for (const msg of remove) {
      await Storage.remove(`session/message/${input.sessionID}/${msg.info.id}`)
      await Bus.publish(MessageV2.Event.Removed, { sessionID: input.sessionID, messageID: msg.info.id })
    }
    const last = preserve.at(-1)
    if (session.revert.partID && last) {
      const partID = session.revert.partID
      const [preserveParts, removeParts] = splitWhen(last.parts, (x) => x.id === partID)
      last.parts = preserveParts
      for (const part of removeParts) {
        await Storage.remove(`session/part/${input.sessionID}/${last.info.id}/${part.id}`)
        await Bus.publish(MessageV2.Event.PartRemoved, {
          messageID: last.info.id,
          partID: part.id,
        })
      }
    }
  }
  const previous = msgs.filter((x) => x.info.role === "assistant").at(-1)?.info as MessageV2.Assistant
  const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
  // auto summarize if too long
  if (previous && previous.tokens) {
    const tokens =
      previous.tokens.input + previous.tokens.cache.read + previous.tokens.cache.write + previous.tokens.output
    if (model.info.limit.context && tokens > Math.max((model.info.limit.context - outputLimit) * 0.9, 0)) {
      await summarize({
        sessionID: input.sessionID,
        providerID: input.providerID,
        modelID: input.modelID,
      })
      // NOTE(review): the user message/parts above are already stored; when
      // input.messageID is unset this re-entry stores them again under a new id — confirm intended.
      return chat(input)
    }
  }
  using abort = lock(input.sessionID)
  // Only messages from the most recent summary onwards are sent to the model.
  const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
  if (lastSummary) msgs = msgs.filter((msg) => msg.info.id >= lastSummary.info.id)
  if (msgs.length === 1 && !session.parentID) {
    // Fire-and-forget title generation for the first message of a top-level session.
    const small = (await Provider.getSmallModel(input.providerID)) ?? model
    generateText({
      maxOutputTokens: small.info.reasoning ? 1024 : 20,
      providerOptions: {
        [input.providerID]: small.info.options,
      },
      messages: [
        ...SystemPrompt.title(input.providerID).map(
          (x): ModelMessage => ({
            role: "system",
            content: x,
          }),
        ),
        ...MessageV2.toModelMessage([
          {
            info: {
              id: Identifier.ascending("message"),
              role: "user",
              sessionID: input.sessionID,
              time: {
                created: Date.now(),
              },
            },
            parts: userParts,
          },
        ]),
      ],
      model: small.language,
    })
      .then((result) => {
        if (result.text)
          return Session.update(input.sessionID, (draft) => {
            draft.title = result.text
          })
      })
      .catch(() => {})
  }
  const mode = await Mode.get(inputMode)
  let system = SystemPrompt.header(input.providerID)
  system.push(
    ...(() => {
      if (input.system) return [input.system]
      if (mode.prompt) return [mode.prompt]
      return SystemPrompt.provider(input.modelID)
    })(),
  )
  system.push(...(await SystemPrompt.environment()))
  system.push(...(await SystemPrompt.custom()))
  // max 2 system prompt messages for caching purposes
  const [first, ...rest] = system
  system = [first, rest.join("\n")]
  const assistantMsg: MessageV2.Info = {
    id: Identifier.ascending("message"),
    role: "assistant",
    system,
    mode: inputMode,
    path: {
      cwd: app.path.cwd,
      root: app.path.root,
    },
    cost: 0,
    tokens: {
      input: 0,
      output: 0,
      reasoning: 0,
      cache: { read: 0, write: 0 },
    },
    modelID: input.modelID,
    providerID: input.providerID,
    time: {
      created: Date.now(),
    },
    sessionID: input.sessionID,
  }
  await updateMessage(assistantMsg)
  const tools: Record<string, AITool> = {}
  const processor = createProcessor(assistantMsg, model.info)
  // Mode defaults, then registry defaults for this model, then caller overrides.
  const enabledTools = pipe(
    mode.tools,
    mergeDeep(ToolRegistry.enabled(input.providerID, input.modelID)),
    mergeDeep(input.tools ?? {}),
  )
  for (const item of await ToolRegistry.tools(input.providerID, input.modelID)) {
    if (enabledTools[item.id] === false) continue
    tools[item.id] = tool({
      id: item.id as any,
      description: item.description,
      inputSchema: item.parameters as ZodSchema,
      async execute(args, options) {
        // Snapshot before running so file changes can be recorded as a patch part.
        await processor.track(options.toolCallId)
        const result = await item.execute(args, {
          sessionID: input.sessionID,
          abort: abort.signal,
          messageID: assistantMsg.id,
          metadata: async (val) => {
            const match = processor.partFromToolCall(options.toolCallId)
            if (match && match.state.status === "running") {
              await updatePart({
                ...match,
                state: {
                  title: val.title,
                  metadata: val.metadata,
                  status: "running",
                  input: args,
                  time: {
                    start: Date.now(),
                  },
                },
              })
            }
          },
        })
        return result
      },
      toModelOutput(result) {
        return {
          type: "text",
          value: result.output,
        }
      },
    })
  }
  for (const [key, item] of Object.entries(await MCP.tools())) {
    if (mode.tools[key] === false) continue
    const execute = item.execute
    if (!execute) continue
    item.execute = async (args, opts) => {
      await processor.track(opts.toolCallId)
      const result = await execute(args, opts)
      const output = result.content
        .filter((x: any) => x.type === "text")
        .map((x: any) => x.text)
        .join("\n\n")
      return {
        output,
      }
    }
    item.toModelOutput = (result) => {
      return {
        type: "text",
        value: result.output,
      }
    }
    tools[key] = item
  }
  const stream = streamText({
    onError() {},
    async prepareStep({ messages }) {
      // Fold user messages queued while this request was running into the next
      // step. The current assistant message is closed and a fresh one started,
      // so the response appears after the queued messages.
      const queue = (state().queued.get(input.sessionID) ?? []).filter((x) => !x.processed)
      if (queue.length) {
        for (const item of queue) {
          if (item.processed) continue
          messages.push(
            ...MessageV2.toModelMessage([
              {
                info: item.message,
                parts: item.parts,
              },
            ]),
          )
          item.processed = true
        }
        assistantMsg.time.completed = Date.now()
        await updateMessage(assistantMsg)
        Object.assign(assistantMsg, {
          id: Identifier.ascending("message"),
          role: "assistant",
          system,
          path: {
            cwd: app.path.cwd,
            root: app.path.root,
          },
          cost: 0,
          tokens: {
            input: 0,
            output: 0,
            reasoning: 0,
            cache: { read: 0, write: 0 },
          },
          modelID: input.modelID,
          providerID: input.providerID,
          mode: inputMode,
          time: {
            created: Date.now(),
          },
          sessionID: input.sessionID,
        })
        await updateMessage(assistantMsg)
      }
      return {
        messages,
      }
    },
    maxRetries: 10,
    maxOutputTokens: outputLimit,
    abortSignal: abort.signal,
    stopWhen: stepCountIs(1000),
    providerOptions: {
      [input.providerID]: model.info.options,
    },
    messages: [
      ...system.map(
        (x): ModelMessage => ({
          role: "system",
          content: x,
        }),
      ),
      ...MessageV2.toModelMessage(msgs),
    ],
    temperature: model.info.temperature
      ? (mode.temperature ?? ProviderTransform.temperature(input.providerID, input.modelID))
      : undefined,
    tools: model.info.tool_call === false ? undefined : tools,
    model: wrapLanguageModel({
      model: model.language,
      middleware: [
        {
          async transformParams(args) {
            if (args.type === "stream") {
              // @ts-expect-error
              args.params.prompt = ProviderTransform.message(args.params.prompt, input.providerID, input.modelID)
            }
            return args.params
          },
        },
      ],
    }),
  })
  const result = await processor.process(stream)
  const queued = state().queued.get(input.sessionID) ?? []
  const unprocessed = queued.find((x) => !x.processed)
  if (unprocessed) {
    unprocessed.processed = true
    // NOTE(review): re-entering chat stores the queued user message again — confirm intended.
    return chat(unprocessed.input)
  }
  for (const item of queued) {
    item.callback(result)
  }
  state().queued.delete(input.sessionID)
  return result
}
/**
 * Build a stream processor bound to one assistant message.
 *
 * `process` consumes a `streamText` result and records it with updatePart /
 * updateMessage:
 * - each tool call becomes a tool part moving pending → running →
 *   completed or error;
 * - text deltas accumulate into text parts;
 * - each step adds step-start and step-finish parts with usage and cost;
 * - files changed by a tool (diffed against the snapshot taken in `track`)
 *   become a patch part.
 *
 * Stream errors are stored as a serialized error on the assistant message and
 * published as `Event.Error` rather than rethrown.
 *
 * @param assistantMsg assistant message being produced; mutated in place
 * @param model model metadata used for usage/cost accounting
 */
function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
  const toolCalls: Record<string, MessageV2.ToolPart> = {}
  const snapshots: Record<string, string> = {}
  return {
    /** Snapshot the working tree before a tool runs so its changes can be diffed afterwards. */
    async track(toolCallID: string) {
      const hash = await Snapshot.track()
      if (hash) snapshots[toolCallID] = hash
    },
    /** The tool part currently tracked for a tool call id, if any. */
    partFromToolCall(toolCallID: string) {
      return toolCalls[toolCallID]
    },
    async process(stream: StreamTextResult<Record<string, AITool>, never>) {
      try {
        let currentText: MessageV2.TextPart | undefined
        for await (const value of stream.fullStream) {
          log.info("part", {
            type: value.type,
          })
          switch (value.type) {
            case "start":
              break
            case "tool-input-start": {
              const part = await updatePart({
                id: Identifier.ascending("part"),
                messageID: assistantMsg.id,
                sessionID: assistantMsg.sessionID,
                type: "tool",
                tool: value.toolName,
                callID: value.id,
                state: {
                  status: "pending",
                },
              })
              toolCalls[value.id] = part as MessageV2.ToolPart
              break
            }
            case "tool-input-delta":
              break
            case "tool-input-end":
              break
            case "tool-call": {
              const match = toolCalls[value.toolCallId]
              if (match) {
                const part = await updatePart({
                  ...match,
                  state: {
                    status: "running",
                    input: value.input,
                    time: {
                      start: Date.now(),
                    },
                  },
                })
                toolCalls[value.toolCallId] = part as MessageV2.ToolPart
              }
              break
            }
            case "tool-result": {
              const match = toolCalls[value.toolCallId]
              if (match && match.state.status === "running") {
                await updatePart({
                  ...match,
                  state: {
                    status: "completed",
                    input: value.input,
                    output: value.output.output,
                    metadata: value.output.metadata,
                    title: value.output.title,
                    time: {
                      start: match.state.time.start,
                      end: Date.now(),
                    },
                  },
                })
                delete toolCalls[value.toolCallId]
                const snapshot = snapshots[value.toolCallId]
                if (snapshot) {
                  const patch = await Snapshot.patch(snapshot)
                  if (patch.files.length) {
                    await updatePart({
                      id: Identifier.ascending("part"),
                      messageID: assistantMsg.id,
                      sessionID: assistantMsg.sessionID,
                      type: "patch",
                      hash: patch.hash,
                      files: patch.files,
                    })
                  }
                }
              }
              break
            }
            case "tool-error": {
              const match = toolCalls[value.toolCallId]
              if (match && match.state.status === "running") {
                await updatePart({
                  ...match,
                  state: {
                    status: "error",
                    input: value.input,
                    error: (value.error as any).toString(),
                    time: {
                      start: match.state.time.start,
                      end: Date.now(),
                    },
                  },
                })
                delete toolCalls[value.toolCallId]
                const snapshot = snapshots[value.toolCallId]
                if (snapshot) {
                  const patch = await Snapshot.patch(snapshot)
                  // Same as tool-result: only record a patch when files changed.
                  if (patch.files.length) {
                    await updatePart({
                      id: Identifier.ascending("part"),
                      messageID: assistantMsg.id,
                      sessionID: assistantMsg.sessionID,
                      type: "patch",
                      hash: patch.hash,
                      files: patch.files,
                    })
                  }
                }
              }
              break
            }
            case "error":
              throw value.error
            case "start-step":
              await updatePart({
                id: Identifier.ascending("part"),
                messageID: assistantMsg.id,
                sessionID: assistantMsg.sessionID,
                type: "step-start",
              })
              break
            case "finish-step": {
              // Cost accumulates across steps; tokens reflect the latest step.
              const usage = getUsage(model, value.usage, value.providerMetadata)
              assistantMsg.cost += usage.cost
              assistantMsg.tokens = usage.tokens
              await updatePart({
                id: Identifier.ascending("part"),
                messageID: assistantMsg.id,
                sessionID: assistantMsg.sessionID,
                type: "step-finish",
                tokens: usage.tokens,
                cost: usage.cost,
              })
              await updateMessage(assistantMsg)
              break
            }
            case "text-start":
              currentText = {
                id: Identifier.ascending("part"),
                messageID: assistantMsg.id,
                sessionID: assistantMsg.sessionID,
                type: "text",
                text: "",
                time: {
                  start: Date.now(),
                },
              }
              break
            case "text":
              if (currentText) {
                currentText.text += value.text
                await updatePart(currentText)
              }
              break
            case "text-end":
              if (currentText && currentText.text) {
                // Keep the start time recorded at text-start.
                currentText.time = {
                  start: currentText.time?.start ?? Date.now(),
                  end: Date.now(),
                }
                currentText.text = currentText.text.trimEnd()
                await updatePart(currentText)
              }
              currentText = undefined
              break
            case "finish":
              assistantMsg.time.completed = Date.now()
              await updateMessage(assistantMsg)
              break
            default:
              log.info("unhandled", {
                ...value,
              })
              continue
          }
        }
      } catch (e) {
        log.error("", {
          error: e,
        })
        switch (true) {
          case e instanceof DOMException && e.name === "AbortError":
            assistantMsg.error = new MessageV2.AbortedError(
              { message: e.message },
              {
                cause: e,
              },
            ).toObject()
            break
          case MessageV2.OutputLengthError.isInstance(e):
            assistantMsg.error = e
            break
          case LoadAPIKeyError.isInstance(e):
            assistantMsg.error = new MessageV2.AuthError(
              {
                // `model.id` is the model's id; the provider id lives on the message.
                providerID: assistantMsg.providerID,
                message: e.message,
              },
              { cause: e },
            ).toObject()
            break
          case e instanceof Error:
            assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
            break
          default:
            // Serialize like every other branch so the stored/published error is plain data.
            assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e }).toObject()
        }
        Bus.publish(Event.Error, {
          sessionID: assistantMsg.sessionID,
          error: assistantMsg.error,
        })
      }
      // Any tool part that never reached a terminal state was interrupted.
      // Parts that already completed or errored keep their recorded state.
      const p = await getParts(assistantMsg.sessionID, assistantMsg.id)
      for (const [index, part] of p.entries()) {
        if (part.type !== "tool") continue
        if (part.state.status === "completed" || part.state.status === "error") continue
        p[index] = await updatePart({
          ...part,
          state: {
            status: "error",
            error: "Tool execution aborted",
            time: {
              start: part.state.status === "running" ? part.state.time.start : Date.now(),
              end: Date.now(),
            },
            input: {},
          },
        })
      }
      assistantMsg.time.completed = Date.now()
      await updateMessage(assistantMsg)
      return { info: assistantMsg, parts: p }
    },
  }
}
/**
 * Input for `revert`: the message to revert to and, optionally, a part within
 * that message to revert from.
 */
export const RevertInput = z.object({
  sessionID: Identifier.schema("session"),
  messageID: Identifier.schema("message"),
  // When omitted, the whole message is the revert target.
  partID: Identifier.schema("part").optional(),
})
export type RevertInput = z.infer<typeof RevertInput>
/**
 * Mark a session as reverted to `input.messageID` (optionally `input.partID`)
 * and undo the file changes recorded by every patch part after that point.
 *
 * When no text or tool part of the target message would survive, the revert
 * point moves to the preceding user message instead. A snapshot of the current
 * state is stored (unless one already exists from an earlier revert) so that
 * `unrevert` can restore it.
 *
 * @returns the updated session, or the unchanged session when the target was not found
 */
export async function revert(input: RevertInput) {
  const all = await messages(input.sessionID)
  const session = await get(input.sessionID)
  let lastUser: MessageV2.User | undefined
  let revert: Info["revert"]
  const patches: Snapshot.Patch[] = []
  for (const msg of all) {
    if (msg.info.role === "user") lastUser = msg.info
    // Parts of this message that precede the revert target.
    const remaining: MessageV2.Part[] = []
    for (const part of msg.parts) {
      if (revert) {
        // Everything after the revert point is undone through its patches.
        if (part.type === "patch") patches.push(part)
        continue
      }
      if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
        // if no useful parts left in message, same as reverting whole message
        const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
        revert = {
          messageID: !partID && lastUser ? lastUser.id : msg.info.id,
          partID,
        }
      }
      remaining.push(part)
    }
  }
  if (!revert) return session
  revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
  await Snapshot.revert(patches)
  return update(input.sessionID, (draft) => {
    draft.revert = revert
  })
}
/**
 * Cancels a pending revert on a session.
 *
 * If the revert captured a snapshot, that snapshot is restored before the session's
 * revert marker is cleared. When the session has no revert, it is returned as-is.
 */
export async function unrevert(input: { sessionID: string }) {
  log.info("unreverting", input)
  const session = await get(input.sessionID)
  const activeRevert = session.revert
  if (!activeRevert) return session
  if (activeRevert.snapshot) {
    await Snapshot.restore(activeRevert.snapshot)
  }
  return update(input.sessionID, (draft) => {
    draft.revert = undefined
  })
}
/**
 * Asks the model to summarize the session and records the reply in a new assistant
 * message flagged `summary: true`.
 *
 * Only the history from the latest existing summary message onward (inclusive) is
 * sent to the model, followed by a fixed user instruction requesting the summary.
 * The session lock is held for the whole run.
 *
 * @returns the value `processor.process` resolves to for the summary message.
 */
export async function summarize(input: { sessionID: string; providerID: string; modelID: string }) {
  using abort = lock(input.sessionID)
  const history = await messages(input.sessionID)
  const previousSummary = history.findLast(
    (entry) => entry.info.role === "assistant" && entry.info.summary === true,
  )
  const relevant = history.filter(
    (entry) => previousSummary === undefined || entry.info.id >= previousSummary.info.id,
  )
  const model = await Provider.getModel(input.providerID, input.modelID)
  const appInfo = App.info()
  const systemPrompts = [
    ...SystemPrompt.summarize(input.providerID),
    ...(await SystemPrompt.environment()),
    ...(await SystemPrompt.custom()),
  ]
  const summaryMessage: MessageV2.Info = {
    id: Identifier.ascending("message"),
    role: "assistant",
    sessionID: input.sessionID,
    system: systemPrompts,
    mode: "build",
    path: {
      cwd: appInfo.path.cwd,
      root: appInfo.path.root,
    },
    summary: true,
    cost: 0,
    modelID: input.modelID,
    providerID: input.providerID,
    tokens: {
      input: 0,
      output: 0,
      reasoning: 0,
      cache: { read: 0, write: 0 },
    },
    time: {
      created: Date.now(),
    },
  }
  await updateMessage(summaryMessage)
  const processor = createProcessor(summaryMessage, model.info)
  const summaryRequest =
    "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
  const stream = streamText({
    maxRetries: 10,
    abortSignal: abort.signal,
    model: model.language,
    messages: [
      ...systemPrompts.map((content): ModelMessage => ({ role: "system", content })),
      ...MessageV2.toModelMessage(relevant),
      { role: "user", content: [{ type: "text", text: summaryRequest }] },
    ],
  })
  // `return await` (not a bare return) keeps the `using` lock held until processing settles.
  return await processor.process(stream)
}
/** True while some caller holds the `lock` for `sessionID`. */
function isLocked(sessionID: string) {
  const { pending } = state()
  return pending.has(sessionID)
}
/**
 * Claims exclusive use of a session.
 *
 * Throws `BusyError` when the session is already claimed. The returned handle
 * exposes an `AbortSignal` for the run. Disposing the handle (e.g. via `using`)
 * removes the claim and publishes `Event.Idle`.
 */
function lock(sessionID: string) {
  log.info("locking", { sessionID })
  const pending = state().pending
  if (pending.has(sessionID)) throw new BusyError(sessionID)
  const controller = new AbortController()
  pending.set(sessionID, controller)
  const release = () => {
    log.info("unlocking", { sessionID })
    state().pending.delete(sessionID)
    Bus.publish(Event.Idle, { sessionID })
  }
  return {
    signal: controller.signal,
    [Symbol.dispose]: release,
  }
}
/**
 * Converts AI SDK usage numbers into the session's token record and prices them
 * with the model's per-million-token rates.
 *
 * Cache-write counts are taken from provider metadata: Anthropic's
 * `cacheCreationInputTokens`, falling back to Bedrock's `usage.cacheWriteInputTokens`,
 * and finally 0.
 */
function getUsage(model: ModelsDev.Model, usage: LanguageModelUsage, metadata?: ProviderMetadata) {
  const anthropicCacheWrite = metadata?.["anthropic"]?.["cacheCreationInputTokens"]
  // @ts-expect-error provider metadata values are untyped JSON, so the nested bedrock lookup does not type-check
  const bedrockCacheWrite = metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"]
  const tokens = {
    input: usage.inputTokens ?? 0,
    output: usage.outputTokens ?? 0,
    // NOTE(review): reasoning is always recorded as 0 here, even if the provider reports reasoning tokens — confirm intended.
    reasoning: 0,
    cache: {
      write: (anthropicCacheWrite ?? bedrockCacheWrite ?? 0) as number,
      read: usage.cachedInputTokens ?? 0,
    },
  }
  const perMillion = (count: number, rate: number) => new Decimal(count).mul(rate).div(1_000_000)
  const cost = [
    perMillion(tokens.input, model.cost.input),
    perMillion(tokens.output, model.cost.output),
    perMillion(tokens.cache.read, model.cost.cache_read ?? 0),
    perMillion(tokens.cache.write, model.cost.cache_write ?? 0),
  ].reduce((total, term) => total.add(term), new Decimal(0))
  return { cost: cost.toNumber(), tokens }
}
/** Thrown by `lock` when the session is already claimed by another run. */
export class BusyError extends Error {
  constructor(public readonly sessionID: string) {
    super(`Session ${sessionID} is busy`)
    // Without this, `toString()` and stack traces report the generic "Error".
    this.name = "BusyError"
  }
}
/**
 * Sends the project-initialization prompt as a user message in the given session,
 * with its `${path}` placeholder filled in from the app root, then calls
 * `App.initialize()`.
 */
export async function initialize(input: {
  sessionID: string
  modelID: string
  providerID: string
  messageID: string
}) {
  const app = App.info()
  // Literal substitution of every placeholder. String#replace with a string argument
  // would interpret `$&`, `$'`, `$$`, ... inside the path and only replace the first match.
  const prompt = PROMPT_INITIALIZE.split("${path}").join(app.path.root)
  await Session.chat({
    sessionID: input.sessionID,
    messageID: input.messageID,
    providerID: input.providerID,
    modelID: input.modelID,
    parts: [
      {
        id: Identifier.ascending("part"),
        type: "text",
        text: prompt,
      },
    ],
  })
  await App.initialize()
}
- }
|