|
@@ -1,9 +1,8 @@
|
|
|
-import { streamText, type ModelMessage, type StreamTextResult, type Tool as AITool } from "ai"
|
|
|
|
|
|
|
+import { streamText, type ModelMessage } from "ai"
|
|
|
import { Session } from "."
|
|
import { Session } from "."
|
|
|
import { Identifier } from "../id/id"
|
|
import { Identifier } from "../id/id"
|
|
|
import { Instance } from "../project/instance"
|
|
import { Instance } from "../project/instance"
|
|
|
import { Provider } from "../provider/provider"
|
|
import { Provider } from "../provider/provider"
|
|
|
-import { defer } from "../util/defer"
|
|
|
|
|
import { MessageV2 } from "./message-v2"
|
|
import { MessageV2 } from "./message-v2"
|
|
|
import { SystemPrompt } from "./system"
|
|
import { SystemPrompt } from "./system"
|
|
|
import { Bus } from "../bus"
|
|
import { Bus } from "../bus"
|
|
@@ -13,10 +12,9 @@ import { SessionPrompt } from "./prompt"
|
|
|
import { Flag } from "../flag/flag"
|
|
import { Flag } from "../flag/flag"
|
|
|
import { Token } from "../util/token"
|
|
import { Token } from "../util/token"
|
|
|
import { Log } from "../util/log"
|
|
import { Log } from "../util/log"
|
|
|
-import { SessionLock } from "./lock"
|
|
|
|
|
import { ProviderTransform } from "@/provider/transform"
|
|
import { ProviderTransform } from "@/provider/transform"
|
|
|
-import { SessionRetry } from "./retry"
|
|
|
|
|
-import { Config } from "@/config/config"
|
|
|
|
|
|
|
+import { SessionProcessor } from "./processor"
|
|
|
|
|
+import { fn } from "@/util/fn"
|
|
|
|
|
|
|
|
export namespace SessionCompaction {
|
|
export namespace SessionCompaction {
|
|
|
const log = Log.create({ service: "session.compaction" })
|
|
const log = Log.create({ service: "session.compaction" })
|
|
@@ -42,7 +40,6 @@ export namespace SessionCompaction {
|
|
|
|
|
|
|
|
export const PRUNE_MINIMUM = 20_000
|
|
export const PRUNE_MINIMUM = 20_000
|
|
|
export const PRUNE_PROTECT = 40_000
|
|
export const PRUNE_PROTECT = 40_000
|
|
|
- const MAX_RETRIES = 10
|
|
|
|
|
|
|
|
|
|
// goes backwards through parts until there are 40_000 tokens worth of tool
|
|
// goes backwards through parts until there are 40_000 tokens worth of tool
|
|
|
// calls. then erases output of previous tool calls. idea is to throw away old
|
|
// calls. then erases output of previous tool calls. idea is to throw away old
|
|
@@ -87,38 +84,29 @@ export namespace SessionCompaction {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) {
|
|
|
|
|
- if (!input.signal) SessionLock.assertUnlocked(input.sessionID)
|
|
|
|
|
- await using lock = input.signal === undefined ? SessionLock.acquire({ sessionID: input.sessionID }) : undefined
|
|
|
|
|
- const signal = input.signal ?? lock!.signal
|
|
|
|
|
-
|
|
|
|
|
- await Session.update(input.sessionID, (draft) => {
|
|
|
|
|
- draft.time.compacting = Date.now()
|
|
|
|
|
- })
|
|
|
|
|
- await using _ = defer(async () => {
|
|
|
|
|
- await Session.update(input.sessionID, (draft) => {
|
|
|
|
|
- draft.time.compacting = undefined
|
|
|
|
|
- })
|
|
|
|
|
- })
|
|
|
|
|
- const toSummarize = await MessageV2.filterCompacted(MessageV2.stream(input.sessionID))
|
|
|
|
|
- const model = await Provider.getModel(input.providerID, input.modelID)
|
|
|
|
|
- const system = [
|
|
|
|
|
- ...SystemPrompt.summarize(model.providerID),
|
|
|
|
|
- ...(await SystemPrompt.environment()),
|
|
|
|
|
- ...(await SystemPrompt.custom()),
|
|
|
|
|
- ]
|
|
|
|
|
-
|
|
|
|
|
|
|
+ export async function process(input: {
|
|
|
|
|
+ parentID: string
|
|
|
|
|
+ messages: MessageV2.WithParts[]
|
|
|
|
|
+ sessionID: string
|
|
|
|
|
+ model: {
|
|
|
|
|
+ providerID: string
|
|
|
|
|
+ modelID: string
|
|
|
|
|
+ }
|
|
|
|
|
+ abort: AbortSignal
|
|
|
|
|
+ }) {
|
|
|
|
|
+ const model = await Provider.getModel(input.model.providerID, input.model.modelID)
|
|
|
|
|
+ const system = [...SystemPrompt.summarize(model.providerID)]
|
|
|
const msg = (await Session.updateMessage({
|
|
const msg = (await Session.updateMessage({
|
|
|
id: Identifier.ascending("message"),
|
|
id: Identifier.ascending("message"),
|
|
|
role: "assistant",
|
|
role: "assistant",
|
|
|
- parentID: toSummarize.findLast((m) => m.info.role === "user")?.info.id!,
|
|
|
|
|
|
|
+ parentID: input.parentID,
|
|
|
sessionID: input.sessionID,
|
|
sessionID: input.sessionID,
|
|
|
mode: "build",
|
|
mode: "build",
|
|
|
|
|
+ summary: true,
|
|
|
path: {
|
|
path: {
|
|
|
cwd: Instance.directory,
|
|
cwd: Instance.directory,
|
|
|
root: Instance.worktree,
|
|
root: Instance.worktree,
|
|
|
},
|
|
},
|
|
|
- summary: true,
|
|
|
|
|
cost: 0,
|
|
cost: 0,
|
|
|
tokens: {
|
|
tokens: {
|
|
|
output: 0,
|
|
output: 0,
|
|
@@ -126,37 +114,27 @@ export namespace SessionCompaction {
|
|
|
reasoning: 0,
|
|
reasoning: 0,
|
|
|
cache: { read: 0, write: 0 },
|
|
cache: { read: 0, write: 0 },
|
|
|
},
|
|
},
|
|
|
- modelID: input.modelID,
|
|
|
|
|
|
|
+ modelID: input.model.modelID,
|
|
|
providerID: model.providerID,
|
|
providerID: model.providerID,
|
|
|
time: {
|
|
time: {
|
|
|
created: Date.now(),
|
|
created: Date.now(),
|
|
|
},
|
|
},
|
|
|
})) as MessageV2.Assistant
|
|
})) as MessageV2.Assistant
|
|
|
-
|
|
|
|
|
- const part = (await Session.updatePart({
|
|
|
|
|
- type: "text",
|
|
|
|
|
|
|
+ const processor = SessionProcessor.create({
|
|
|
|
|
+ assistantMessage: msg,
|
|
|
sessionID: input.sessionID,
|
|
sessionID: input.sessionID,
|
|
|
- messageID: msg.id,
|
|
|
|
|
- id: Identifier.ascending("part"),
|
|
|
|
|
- text: "",
|
|
|
|
|
- time: {
|
|
|
|
|
- start: Date.now(),
|
|
|
|
|
- },
|
|
|
|
|
- })) as MessageV2.TextPart
|
|
|
|
|
-
|
|
|
|
|
- const doStream = () =>
|
|
|
|
|
|
|
+ providerID: input.model.providerID,
|
|
|
|
|
+ model: model.info,
|
|
|
|
|
+ abort: input.abort,
|
|
|
|
|
+ })
|
|
|
|
|
+ const result = await processor.process(() =>
|
|
|
streamText({
|
|
streamText({
|
|
|
// set to 0, we handle loop
|
|
// set to 0, we handle loop
|
|
|
maxRetries: 0,
|
|
maxRetries: 0,
|
|
|
model: model.language,
|
|
model: model.language,
|
|
|
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
|
|
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
|
|
|
headers: model.info.headers,
|
|
headers: model.info.headers,
|
|
|
- abortSignal: signal,
|
|
|
|
|
- onError(error) {
|
|
|
|
|
- log.error("stream error", {
|
|
|
|
|
- error,
|
|
|
|
|
- })
|
|
|
|
|
- },
|
|
|
|
|
|
|
+ abortSignal: input.abort,
|
|
|
tools: model.info.tool_call ? {} : undefined,
|
|
tools: model.info.tool_call ? {} : undefined,
|
|
|
messages: [
|
|
messages: [
|
|
|
...system.map(
|
|
...system.map(
|
|
@@ -165,7 +143,7 @@ export namespace SessionCompaction {
|
|
|
content: x,
|
|
content: x,
|
|
|
}),
|
|
}),
|
|
|
),
|
|
),
|
|
|
- ...MessageV2.toModelMessage(toSummarize),
|
|
|
|
|
|
|
+ ...MessageV2.toModelMessage(input.messages),
|
|
|
{
|
|
{
|
|
|
role: "user",
|
|
role: "user",
|
|
|
content: [
|
|
content: [
|
|
@@ -176,168 +154,60 @@ export namespace SessionCompaction {
|
|
|
],
|
|
],
|
|
|
},
|
|
},
|
|
|
],
|
|
],
|
|
|
|
|
+ }),
|
|
|
|
|
+ )
|
|
|
|
|
+ if (result === "continue") {
|
|
|
|
|
+ const continueMsg = await Session.updateMessage({
|
|
|
|
|
+ id: Identifier.ascending("message"),
|
|
|
|
|
+ role: "user",
|
|
|
|
|
+ sessionID: input.sessionID,
|
|
|
|
|
+ time: {
|
|
|
|
|
+ created: Date.now(),
|
|
|
|
|
+ },
|
|
|
|
|
+ agent: "build",
|
|
|
|
|
+ model: input.model,
|
|
|
})
|
|
})
|
|
|
-
|
|
|
|
|
- // TODO: reduce duplication between compaction.ts & prompt.ts
|
|
|
|
|
- const process = async (
|
|
|
|
|
- stream: StreamTextResult<Record<string, AITool>, never>,
|
|
|
|
|
- retries: { count: number; max: number },
|
|
|
|
|
- ) => {
|
|
|
|
|
- let shouldRetry = false
|
|
|
|
|
- try {
|
|
|
|
|
- for await (const value of stream.fullStream) {
|
|
|
|
|
- signal.throwIfAborted()
|
|
|
|
|
- switch (value.type) {
|
|
|
|
|
- case "text-delta":
|
|
|
|
|
- part.text += value.text
|
|
|
|
|
- if (value.providerMetadata) part.metadata = value.providerMetadata
|
|
|
|
|
- if (part.text)
|
|
|
|
|
- await Session.updatePart({
|
|
|
|
|
- part,
|
|
|
|
|
- delta: value.text,
|
|
|
|
|
- })
|
|
|
|
|
- continue
|
|
|
|
|
- case "text-end": {
|
|
|
|
|
- part.text = part.text.trimEnd()
|
|
|
|
|
- part.time = {
|
|
|
|
|
- start: Date.now(),
|
|
|
|
|
- end: Date.now(),
|
|
|
|
|
- }
|
|
|
|
|
- if (value.providerMetadata) part.metadata = value.providerMetadata
|
|
|
|
|
- await Session.updatePart(part)
|
|
|
|
|
- continue
|
|
|
|
|
- }
|
|
|
|
|
- case "finish-step": {
|
|
|
|
|
- const usage = Session.getUsage({
|
|
|
|
|
- model: model.info,
|
|
|
|
|
- usage: value.usage,
|
|
|
|
|
- metadata: value.providerMetadata,
|
|
|
|
|
- })
|
|
|
|
|
- msg.cost += usage.cost
|
|
|
|
|
- msg.tokens = usage.tokens
|
|
|
|
|
- await Session.updateMessage(msg)
|
|
|
|
|
- continue
|
|
|
|
|
- }
|
|
|
|
|
- case "error":
|
|
|
|
|
- throw value.error
|
|
|
|
|
- default:
|
|
|
|
|
- continue
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- } catch (e) {
|
|
|
|
|
- log.error("compaction error", {
|
|
|
|
|
- error: e,
|
|
|
|
|
- })
|
|
|
|
|
- const error = MessageV2.fromError(e, { providerID: input.providerID })
|
|
|
|
|
- if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
|
|
|
|
|
- shouldRetry = true
|
|
|
|
|
- await Session.updatePart({
|
|
|
|
|
- id: Identifier.ascending("part"),
|
|
|
|
|
- messageID: msg.id,
|
|
|
|
|
- sessionID: msg.sessionID,
|
|
|
|
|
- type: "retry",
|
|
|
|
|
- attempt: retries.count + 1,
|
|
|
|
|
- time: {
|
|
|
|
|
- created: Date.now(),
|
|
|
|
|
- },
|
|
|
|
|
- error,
|
|
|
|
|
- })
|
|
|
|
|
- } else {
|
|
|
|
|
- msg.error = error
|
|
|
|
|
- Bus.publish(Session.Event.Error, {
|
|
|
|
|
- sessionID: msg.sessionID,
|
|
|
|
|
- error: msg.error,
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- const parts = await MessageV2.parts(msg.id)
|
|
|
|
|
- return {
|
|
|
|
|
- info: msg,
|
|
|
|
|
- parts,
|
|
|
|
|
- shouldRetry,
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- let stream = doStream()
|
|
|
|
|
- const cfg = await Config.get()
|
|
|
|
|
- const maxRetries = cfg.experimental?.chatMaxRetries ?? MAX_RETRIES
|
|
|
|
|
- let result = await process(stream, {
|
|
|
|
|
- count: 0,
|
|
|
|
|
- max: maxRetries,
|
|
|
|
|
- })
|
|
|
|
|
- if (result.shouldRetry) {
|
|
|
|
|
- const start = Date.now()
|
|
|
|
|
- for (let retry = 1; retry < maxRetries; retry++) {
|
|
|
|
|
- const lastRetryPart = result.parts.findLast((p): p is MessageV2.RetryPart => p.type === "retry")
|
|
|
|
|
-
|
|
|
|
|
- if (lastRetryPart) {
|
|
|
|
|
- const delayMs = SessionRetry.getBoundedDelay({
|
|
|
|
|
- error: lastRetryPart.error,
|
|
|
|
|
- attempt: retry,
|
|
|
|
|
- startTime: start,
|
|
|
|
|
- })
|
|
|
|
|
- if (!delayMs) {
|
|
|
|
|
- break
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- log.info("retrying with backoff", {
|
|
|
|
|
- attempt: retry,
|
|
|
|
|
- delayMs,
|
|
|
|
|
- elapsed: Date.now() - start,
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- const stop = await SessionRetry.sleep(delayMs, signal)
|
|
|
|
|
- .then(() => false)
|
|
|
|
|
- .catch((error) => {
|
|
|
|
|
- if (error instanceof DOMException && error.name === "AbortError") {
|
|
|
|
|
- const err = new MessageV2.AbortedError(
|
|
|
|
|
- { message: error.message },
|
|
|
|
|
- {
|
|
|
|
|
- cause: error,
|
|
|
|
|
- },
|
|
|
|
|
- ).toObject()
|
|
|
|
|
- result.info.error = err
|
|
|
|
|
- Bus.publish(Session.Event.Error, {
|
|
|
|
|
- sessionID: result.info.sessionID,
|
|
|
|
|
- error: result.info.error,
|
|
|
|
|
- })
|
|
|
|
|
- return true
|
|
|
|
|
- }
|
|
|
|
|
- throw error
|
|
|
|
|
- })
|
|
|
|
|
-
|
|
|
|
|
- if (stop) break
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- stream = doStream()
|
|
|
|
|
- result = await process(stream, {
|
|
|
|
|
- count: retry,
|
|
|
|
|
- max: maxRetries,
|
|
|
|
|
- })
|
|
|
|
|
- if (!result.shouldRetry) {
|
|
|
|
|
- break
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- msg.time.completed = Date.now()
|
|
|
|
|
-
|
|
|
|
|
- if (
|
|
|
|
|
- !msg.error ||
|
|
|
|
|
- (MessageV2.AbortedError.isInstance(msg.error) &&
|
|
|
|
|
- result.parts.some((part): part is MessageV2.TextPart => part.type === "text" && part.text.length > 0))
|
|
|
|
|
- ) {
|
|
|
|
|
- msg.summary = true
|
|
|
|
|
- Bus.publish(Event.Compacted, {
|
|
|
|
|
|
|
+ await Session.updatePart({
|
|
|
|
|
+ id: Identifier.ascending("part"),
|
|
|
|
|
+ messageID: continueMsg.id,
|
|
|
sessionID: input.sessionID,
|
|
sessionID: input.sessionID,
|
|
|
|
|
+ type: "text",
|
|
|
|
|
+ synthetic: true,
|
|
|
|
|
+ text: "Continue if you have next steps",
|
|
|
|
|
+ time: {
|
|
|
|
|
+ start: Date.now(),
|
|
|
|
|
+ end: Date.now(),
|
|
|
|
|
+ },
|
|
|
})
|
|
})
|
|
|
}
|
|
}
|
|
|
- await Session.updateMessage(msg)
|
|
|
|
|
-
|
|
|
|
|
- return {
|
|
|
|
|
- info: msg,
|
|
|
|
|
- parts: result.parts,
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ return "continue"
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ export const create = fn(
|
|
|
|
|
+ z.object({
|
|
|
|
|
+ sessionID: Identifier.schema("session"),
|
|
|
|
|
+ model: z.object({
|
|
|
|
|
+ providerID: z.string(),
|
|
|
|
|
+ modelID: z.string(),
|
|
|
|
|
+ }),
|
|
|
|
|
+ }),
|
|
|
|
|
+ async (input) => {
|
|
|
|
|
+ const msg = await Session.updateMessage({
|
|
|
|
|
+ id: Identifier.ascending("message"),
|
|
|
|
|
+ role: "user",
|
|
|
|
|
+ model: input.model,
|
|
|
|
|
+ sessionID: input.sessionID,
|
|
|
|
|
+ agent: "build",
|
|
|
|
|
+ time: {
|
|
|
|
|
+ created: Date.now(),
|
|
|
|
|
+ },
|
|
|
|
|
+ })
|
|
|
|
|
+ await Session.updatePart({
|
|
|
|
|
+ id: Identifier.ascending("part"),
|
|
|
|
|
+ messageID: msg.id,
|
|
|
|
|
+ sessionID: msg.sessionID,
|
|
|
|
|
+ type: "compaction",
|
|
|
|
|
+ })
|
|
|
|
|
+ },
|
|
|
|
|
+ )
|
|
|
}
|
|
}
|