|
@@ -118,11 +118,22 @@ export namespace Session {
|
|
|
const sessions = new Map<string, Info>()
|
|
const sessions = new Map<string, Info>()
|
|
|
const messages = new Map<string, MessageV2.Info[]>()
|
|
const messages = new Map<string, MessageV2.Info[]>()
|
|
|
const pending = new Map<string, AbortController>()
|
|
const pending = new Map<string, AbortController>()
|
|
|
|
|
+ const queued = new Map<
|
|
|
|
|
+ string,
|
|
|
|
|
+ {
|
|
|
|
|
+ input: ChatInput
|
|
|
|
|
+ message: MessageV2.User
|
|
|
|
|
+ parts: MessageV2.Part[]
|
|
|
|
|
+ processed: boolean
|
|
|
|
|
+ callback: (input: ReturnType<typeof chat>) => void
|
|
|
|
|
+ }[]
|
|
|
|
|
+ >()
|
|
|
|
|
|
|
|
return {
|
|
return {
|
|
|
sessions,
|
|
sessions,
|
|
|
messages,
|
|
messages,
|
|
|
pending,
|
|
pending,
|
|
|
|
|
+ queued,
|
|
|
}
|
|
}
|
|
|
},
|
|
},
|
|
|
async (state) => {
|
|
async (state) => {
|
|
@@ -351,64 +362,14 @@ export namespace Session {
|
|
|
]),
|
|
]),
|
|
|
),
|
|
),
|
|
|
})
|
|
})
|
|
|
|
|
+ export type ChatInput = z.infer<typeof ChatInput>
|
|
|
|
|
|
|
|
- export async function chat(input: z.infer<typeof ChatInput>) {
|
|
|
|
|
|
|
+ export async function chat(
|
|
|
|
|
+ input: z.infer<typeof ChatInput>,
|
|
|
|
|
+ ): Promise<{ info: MessageV2.Assistant; parts: MessageV2.Part[] }> {
|
|
|
const l = log.clone().tag("session", input.sessionID)
|
|
const l = log.clone().tag("session", input.sessionID)
|
|
|
l.info("chatting")
|
|
l.info("chatting")
|
|
|
|
|
|
|
|
- const model = await Provider.getModel(input.providerID, input.modelID)
|
|
|
|
|
- let msgs = await messages(input.sessionID)
|
|
|
|
|
- const session = await get(input.sessionID)
|
|
|
|
|
-
|
|
|
|
|
- if (session.revert) {
|
|
|
|
|
- const trimmed = []
|
|
|
|
|
- for (const msg of msgs) {
|
|
|
|
|
- if (
|
|
|
|
|
- msg.info.id > session.revert.messageID ||
|
|
|
|
|
- (msg.info.id === session.revert.messageID && session.revert.part === 0)
|
|
|
|
|
- ) {
|
|
|
|
|
- await Storage.remove("session/message/" + input.sessionID + "/" + msg.info.id)
|
|
|
|
|
- await Bus.publish(MessageV2.Event.Removed, {
|
|
|
|
|
- sessionID: input.sessionID,
|
|
|
|
|
- messageID: msg.info.id,
|
|
|
|
|
- })
|
|
|
|
|
- continue
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if (msg.info.id === session.revert.messageID) {
|
|
|
|
|
- if (session.revert.part === 0) break
|
|
|
|
|
- msg.parts = msg.parts.slice(0, session.revert.part)
|
|
|
|
|
- }
|
|
|
|
|
- trimmed.push(msg)
|
|
|
|
|
- }
|
|
|
|
|
- msgs = trimmed
|
|
|
|
|
- await update(input.sessionID, (draft) => {
|
|
|
|
|
- draft.revert = undefined
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- const previous = msgs.filter((x) => x.info.role === "assistant").at(-1)?.info as MessageV2.Assistant
|
|
|
|
|
- const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
|
|
|
|
|
-
|
|
|
|
|
- // auto summarize if too long
|
|
|
|
|
- if (previous && previous.tokens) {
|
|
|
|
|
- const tokens =
|
|
|
|
|
- previous.tokens.input + previous.tokens.cache.read + previous.tokens.cache.write + previous.tokens.output
|
|
|
|
|
- if (model.info.limit.context && tokens > Math.max((model.info.limit.context - outputLimit) * 0.9, 0)) {
|
|
|
|
|
- await summarize({
|
|
|
|
|
- sessionID: input.sessionID,
|
|
|
|
|
- providerID: input.providerID,
|
|
|
|
|
- modelID: input.modelID,
|
|
|
|
|
- })
|
|
|
|
|
- return chat(input)
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- using abort = lock(input.sessionID)
|
|
|
|
|
-
|
|
|
|
|
- const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
|
|
|
|
|
- if (lastSummary) msgs = msgs.filter((msg) => msg.info.id >= lastSummary.info.id)
|
|
|
|
|
-
|
|
|
|
|
const userMsg: MessageV2.Info = {
|
|
const userMsg: MessageV2.Info = {
|
|
|
id: input.messageID ?? Identifier.ascending("message"),
|
|
id: input.messageID ?? Identifier.ascending("message"),
|
|
|
role: "user",
|
|
role: "user",
|
|
@@ -533,7 +494,6 @@ export namespace Session {
|
|
|
]
|
|
]
|
|
|
}),
|
|
}),
|
|
|
).then((x) => x.flat())
|
|
).then((x) => x.flat())
|
|
|
-
|
|
|
|
|
if (input.mode === "plan")
|
|
if (input.mode === "plan")
|
|
|
userParts.push({
|
|
userParts.push({
|
|
|
id: Identifier.ascending("part"),
|
|
id: Identifier.ascending("part"),
|
|
@@ -544,6 +504,78 @@ export namespace Session {
|
|
|
synthetic: true,
|
|
synthetic: true,
|
|
|
})
|
|
})
|
|
|
|
|
|
|
|
|
|
+ await updateMessage(userMsg)
|
|
|
|
|
+ for (const part of userParts) {
|
|
|
|
|
+ await updatePart(part)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (isLocked(input.sessionID)) {
|
|
|
|
|
+ return new Promise((resolve) => {
|
|
|
|
|
+ const queue = state().queued.get(input.sessionID) ?? []
|
|
|
|
|
+ queue.push({
|
|
|
|
|
+ input: input,
|
|
|
|
|
+ message: userMsg,
|
|
|
|
|
+ parts: userParts,
|
|
|
|
|
+ processed: false,
|
|
|
|
|
+ callback: resolve,
|
|
|
|
|
+ })
|
|
|
|
|
+ state().queued.set(input.sessionID, queue)
|
|
|
|
|
+ })
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ const model = await Provider.getModel(input.providerID, input.modelID)
|
|
|
|
|
+ let msgs = await messages(input.sessionID)
|
|
|
|
|
+ const session = await get(input.sessionID)
|
|
|
|
|
+
|
|
|
|
|
+ if (session.revert) {
|
|
|
|
|
+ const trimmed = []
|
|
|
|
|
+ for (const msg of msgs) {
|
|
|
|
|
+ if (
|
|
|
|
|
+ msg.info.id > session.revert.messageID ||
|
|
|
|
|
+ (msg.info.id === session.revert.messageID && session.revert.part === 0)
|
|
|
|
|
+ ) {
|
|
|
|
|
+ await Storage.remove("session/message/" + input.sessionID + "/" + msg.info.id)
|
|
|
|
|
+ await Bus.publish(MessageV2.Event.Removed, {
|
|
|
|
|
+ sessionID: input.sessionID,
|
|
|
|
|
+ messageID: msg.info.id,
|
|
|
|
|
+ })
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (msg.info.id === session.revert.messageID) {
|
|
|
|
|
+ if (session.revert.part === 0) break
|
|
|
|
|
+ msg.parts = msg.parts.slice(0, session.revert.part)
|
|
|
|
|
+ }
|
|
|
|
|
+ trimmed.push(msg)
|
|
|
|
|
+ }
|
|
|
|
|
+ msgs = trimmed
|
|
|
|
|
+ await update(input.sessionID, (draft) => {
|
|
|
|
|
+ draft.revert = undefined
|
|
|
|
|
+ })
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ const previous = msgs.filter((x) => x.info.role === "assistant").at(-1)?.info as MessageV2.Assistant
|
|
|
|
|
+ const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
|
|
|
|
|
+
|
|
|
|
|
+ // auto summarize if too long
|
|
|
|
|
+ if (previous && previous.tokens) {
|
|
|
|
|
+ const tokens =
|
|
|
|
|
+ previous.tokens.input + previous.tokens.cache.read + previous.tokens.cache.write + previous.tokens.output
|
|
|
|
|
+ if (model.info.limit.context && tokens > Math.max((model.info.limit.context - outputLimit) * 0.9, 0)) {
|
|
|
|
|
+ await summarize({
|
|
|
|
|
+ sessionID: input.sessionID,
|
|
|
|
|
+ providerID: input.providerID,
|
|
|
|
|
+ modelID: input.modelID,
|
|
|
|
|
+ })
|
|
|
|
|
+ return chat(input)
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ using abort = lock(input.sessionID)
|
|
|
|
|
+
|
|
|
|
|
+ const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
|
|
|
|
|
+ if (lastSummary) msgs = msgs.filter((msg) => msg.info.id >= lastSummary.info.id)
|
|
|
|
|
+
|
|
|
if (msgs.length === 0 && !session.parentID) {
|
|
if (msgs.length === 0 && !session.parentID) {
|
|
|
const small = (await Provider.getSmallModel(input.providerID)) ?? model
|
|
const small = (await Provider.getSmallModel(input.providerID)) ?? model
|
|
|
generateText({
|
|
generateText({
|
|
@@ -582,10 +614,6 @@ export namespace Session {
|
|
|
})
|
|
})
|
|
|
.catch(() => {})
|
|
.catch(() => {})
|
|
|
}
|
|
}
|
|
|
- await updateMessage(userMsg)
|
|
|
|
|
- for (const part of userParts) {
|
|
|
|
|
- await updatePart(part)
|
|
|
|
|
- }
|
|
|
|
|
msgs.push({ info: userMsg, parts: userParts })
|
|
msgs.push({ info: userMsg, parts: userParts })
|
|
|
|
|
|
|
|
const mode = await Mode.get(input.mode ?? "build")
|
|
const mode = await Mode.get(input.mode ?? "build")
|
|
@@ -692,6 +720,51 @@ export namespace Session {
|
|
|
|
|
|
|
|
const stream = streamText({
|
|
const stream = streamText({
|
|
|
onError() {},
|
|
onError() {},
|
|
|
|
|
+ async prepareStep({ messages }) {
|
|
|
|
|
+ const queue = (state().queued.get(input.sessionID) ?? []).filter((x) => !x.processed)
|
|
|
|
|
+ if (queue.length) {
|
|
|
|
|
+ for (const item of queue) {
|
|
|
|
|
+ if (item.processed) continue
|
|
|
|
|
+ messages.push(
|
|
|
|
|
+ ...MessageV2.toModelMessage([
|
|
|
|
|
+ {
|
|
|
|
|
+ info: item.message,
|
|
|
|
|
+ parts: item.parts,
|
|
|
|
|
+ },
|
|
|
|
|
+ ]),
|
|
|
|
|
+ )
|
|
|
|
|
+ item.processed = true
|
|
|
|
|
+ }
|
|
|
|
|
+ assistantMsg.time.completed = Date.now()
|
|
|
|
|
+ await updateMessage(assistantMsg)
|
|
|
|
|
+ Object.assign(assistantMsg, {
|
|
|
|
|
+ id: Identifier.ascending("message"),
|
|
|
|
|
+ role: "assistant",
|
|
|
|
|
+ system,
|
|
|
|
|
+ path: {
|
|
|
|
|
+ cwd: app.path.cwd,
|
|
|
|
|
+ root: app.path.root,
|
|
|
|
|
+ },
|
|
|
|
|
+ cost: 0,
|
|
|
|
|
+ tokens: {
|
|
|
|
|
+ input: 0,
|
|
|
|
|
+ output: 0,
|
|
|
|
|
+ reasoning: 0,
|
|
|
|
|
+ cache: { read: 0, write: 0 },
|
|
|
|
|
+ },
|
|
|
|
|
+ modelID: input.modelID,
|
|
|
|
|
+ providerID: input.providerID,
|
|
|
|
|
+ time: {
|
|
|
|
|
+ created: Date.now(),
|
|
|
|
|
+ },
|
|
|
|
|
+ sessionID: input.sessionID,
|
|
|
|
|
+ })
|
|
|
|
|
+ await updateMessage(assistantMsg)
|
|
|
|
|
+ }
|
|
|
|
|
+ return {
|
|
|
|
|
+ messages,
|
|
|
|
|
+ }
|
|
|
|
|
+ },
|
|
|
maxRetries: 10,
|
|
maxRetries: 10,
|
|
|
maxOutputTokens: outputLimit,
|
|
maxOutputTokens: outputLimit,
|
|
|
abortSignal: abort.signal,
|
|
abortSignal: abort.signal,
|
|
@@ -726,6 +799,12 @@ export namespace Session {
|
|
|
}),
|
|
}),
|
|
|
})
|
|
})
|
|
|
const result = await processor.process(stream)
|
|
const result = await processor.process(stream)
|
|
|
|
|
+ const queued = (state().queued.get(input.sessionID) ?? []).find((item) => !item.processed)
|
|
|
|
|
+ if (queued) {
|
|
|
|
|
+ queued.processed = true
|
|
|
|
|
+ return chat(queued.input)
|
|
|
|
|
+ }
|
|
|
|
|
+ state().queued.delete(input.sessionID)
|
|
|
return result
|
|
return result
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -1087,6 +1166,10 @@ export namespace Session {
|
|
|
return result
|
|
return result
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ function isLocked(sessionID: string) {
|
|
|
|
|
+ return state().pending.has(sessionID)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
function lock(sessionID: string) {
|
|
function lock(sessionID: string) {
|
|
|
log.info("locking", { sessionID })
|
|
log.info("locking", { sessionID })
|
|
|
if (state().pending.has(sessionID)) throw new BusyError(sessionID)
|
|
if (state().pending.has(sessionID)) throw new BusyError(sessionID)
|