|
|
@@ -1,4 +1,4 @@
|
|
|
-import { streamText, type ModelMessage } from "ai"
|
|
|
+import { streamText, type ModelMessage, LoadAPIKeyError } from "ai"
|
|
|
import { Session } from "."
|
|
|
import { Identifier } from "../id/id"
|
|
|
import { Instance } from "../project/instance"
|
|
|
@@ -13,6 +13,8 @@ import { SessionPrompt } from "./prompt"
|
|
|
import { Flag } from "../flag/flag"
|
|
|
import { Token } from "../util/token"
|
|
|
import { Log } from "../util/log"
|
|
|
+import { SessionLock } from "./lock"
|
|
|
+import { NamedError } from "../util/error"
|
|
|
|
|
|
export namespace SessionCompaction {
|
|
|
const log = Log.create({ service: "session.compaction" })
|
|
|
@@ -82,7 +84,11 @@ export namespace SessionCompaction {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- export async function run(input: { sessionID: string; providerID: string; modelID: string }) {
|
|
|
+ export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) {
|
|
|
+ if (!input.signal) SessionLock.assertUnlocked(input.sessionID)
|
|
|
+ await using lock = input.signal === undefined ? SessionLock.acquire({ sessionID: input.sessionID }) : undefined
|
|
|
+ const signal = input.signal ?? lock!.signal
|
|
|
+
|
|
|
await Session.update(input.sessionID, (draft) => {
|
|
|
draft.time.compacting = Date.now()
|
|
|
})
|
|
|
@@ -122,6 +128,7 @@ export namespace SessionCompaction {
|
|
|
created: Date.now(),
|
|
|
},
|
|
|
})) as MessageV2.Assistant
|
|
|
+
|
|
|
const part = (await Session.updatePart({
|
|
|
type: "text",
|
|
|
sessionID: input.sessionID,
|
|
|
@@ -133,13 +140,18 @@ export namespace SessionCompaction {
|
|
|
},
|
|
|
})) as MessageV2.TextPart
|
|
|
|
|
|
- let summaryText = ""
|
|
|
const stream = streamText({
|
|
|
maxRetries: 10,
|
|
|
model: model.language,
|
|
|
providerOptions: {
|
|
|
[model.npm === "@ai-sdk/openai" ? "openai" : model.providerID]: model.info.options,
|
|
|
},
|
|
|
+ abortSignal: signal,
|
|
|
+ onError(error) {
|
|
|
+ log.error("stream error", {
|
|
|
+ error,
|
|
|
+ })
|
|
|
+ },
|
|
|
messages: [
|
|
|
...system.map(
|
|
|
(x): ModelMessage => ({
|
|
|
@@ -160,38 +172,88 @@ export namespace SessionCompaction {
|
|
|
],
|
|
|
})
|
|
|
|
|
|
- for await (const value of stream.fullStream) {
|
|
|
- switch (value.type) {
|
|
|
- case "text-delta":
|
|
|
- summaryText += value.text
|
|
|
- await Session.updatePart({
|
|
|
- ...part,
|
|
|
- text: summaryText,
|
|
|
- })
|
|
|
+ try {
|
|
|
+ for await (const value of stream.fullStream) {
|
|
|
+ signal.throwIfAborted()
|
|
|
+ switch (value.type) {
|
|
|
+ case "text-delta":
|
|
|
+ part.text += value.text
|
|
|
+ if (value.providerMetadata) part.metadata = value.providerMetadata
|
|
|
+ if (part.text) await Session.updatePart(part)
|
|
|
+ continue
|
|
|
+ case "text-end": {
|
|
|
+ part.text = part.text.trimEnd()
|
|
|
+ part.time = {
|
|
|
+ start: Date.now(),
|
|
|
+ end: Date.now(),
|
|
|
+ }
|
|
|
+ if (value.providerMetadata) part.metadata = value.providerMetadata
|
|
|
+ await Session.updatePart(part)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ case "finish-step": {
|
|
|
+ const usage = Session.getUsage({
|
|
|
+ model: model.info,
|
|
|
+ usage: value.usage,
|
|
|
+ metadata: value.providerMetadata,
|
|
|
+ })
|
|
|
+ msg.cost += usage.cost
|
|
|
+ msg.tokens = usage.tokens
|
|
|
+ await Session.updateMessage(msg)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ case "error":
|
|
|
+ throw value.error
|
|
|
+ default:
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } catch (e) {
|
|
|
+ log.error("compaction error", {
|
|
|
+ error: e,
|
|
|
+ })
|
|
|
+ switch (true) {
|
|
|
+ case e instanceof DOMException && e.name === "AbortError":
|
|
|
+ msg.error = new MessageV2.AbortedError(
|
|
|
+ { message: e.message },
|
|
|
+ {
|
|
|
+ cause: e,
|
|
|
+ },
|
|
|
+ ).toObject()
|
|
|
break
|
|
|
- case "text-end":
|
|
|
- part.text = summaryText
|
|
|
- await Session.updatePart({
|
|
|
- ...part,
|
|
|
- })
|
|
|
+ case MessageV2.OutputLengthError.isInstance(e):
|
|
|
+ msg.error = e
|
|
|
break
|
|
|
- case "finish": {
|
|
|
- const usage = Session.getUsage({ model: model.info, usage: value.totalUsage, metadata: undefined })
|
|
|
- msg.cost += usage.cost
|
|
|
- msg.tokens = usage.tokens
|
|
|
- msg.summary = true
|
|
|
- msg.time.completed = Date.now()
|
|
|
- await Session.updateMessage(msg)
|
|
|
- part.time!.end = Date.now()
|
|
|
- await Session.updatePart(part)
|
|
|
+ case LoadAPIKeyError.isInstance(e):
|
|
|
+ msg.error = new MessageV2.AuthError(
|
|
|
+ {
|
|
|
+ providerID: model.providerID,
|
|
|
+ message: e.message,
|
|
|
+ },
|
|
|
+ { cause: e },
|
|
|
+ ).toObject()
|
|
|
break
|
|
|
- }
|
|
|
+ case e instanceof Error:
|
|
|
+ msg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
|
|
+ break
|
|
|
+ default:
|
|
|
+ msg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
|
|
}
|
|
|
+ Bus.publish(Session.Event.Error, {
|
|
|
+ sessionID: input.sessionID,
|
|
|
+ error: msg.error,
|
|
|
+ })
|
|
|
}
|
|
|
|
|
|
- Bus.publish(Event.Compacted, {
|
|
|
- sessionID: input.sessionID,
|
|
|
- })
|
|
|
+ msg.time.completed = Date.now()
|
|
|
+
|
|
|
+ if (!msg.error || MessageV2.AbortedError.isInstance(msg.error)) {
|
|
|
+ msg.summary = true
|
|
|
+ Bus.publish(Event.Compacted, {
|
|
|
+ sessionID: input.sessionID,
|
|
|
+ })
|
|
|
+ }
|
|
|
+ await Session.updateMessage(msg)
|
|
|
|
|
|
return {
|
|
|
info: msg,
|