| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198 |
- import { Provider } from "@/provider/provider"
- import { fn } from "@/util/fn"
- import z from "zod"
- import { Session } from "."
- import { generateText, type ModelMessage } from "ai"
- import { MessageV2 } from "./message-v2"
- import { Identifier } from "@/id/id"
- import { Snapshot } from "@/snapshot"
- import { ProviderTransform } from "@/provider/transform"
- import { SystemPrompt } from "./system"
- import { Log } from "@/util/log"
- import path from "path"
- import { Instance } from "@/project/instance"
- import { Storage } from "@/storage/storage"
- import { Bus } from "@/bus"
- import { mergeDeep, pipe } from "remeda"
export namespace SessionSummary {
  const log = Log.create({ service: "session.summary" })

  /** Text substituted for completed tool outputs before asking the model to summarize a conversation. */
  const PRUNED_TOOL_OUTPUT = "[TOOL OUTPUT PRUNED]"

  /**
   * Recompute both the session-wide diff summary and the per-message summary
   * (diffs, title, body) for one user message.
   *
   * The session's messages are loaded once and shared by both halves, which run
   * concurrently; neither half mutates the shared message list.
   */
  export const summarize = fn(
    z.object({
      sessionID: z.string(),
      messageID: z.string(),
    }),
    async (input) => {
      const all = await Session.messages({ sessionID: input.sessionID })
      await Promise.all([
        summarizeSession({ sessionID: input.sessionID, messages: all }),
        summarizeMessage({ messageID: input.messageID, messages: all }),
      ])
    },
  )

  /**
   * Compute the snapshot diff over every message in the session, keep only the
   * files that some patch part touched, and persist the result: addition /
   * deletion / file totals go on the session record, the per-file list is
   * written to storage under ["session_diff", sessionID], and a
   * `Session.Event.Diff` event is published with the same list.
   */
  async function summarizeSession(input: { sessionID: string; messages: MessageV2.WithParts[] }) {
    // Patch-part paths are made relative to the worktree so they can be
    // matched against the `file` field of the computed diffs.
    const files = new Set(
      input.messages
        .flatMap((x) => x.parts)
        .filter((x) => x.type === "patch")
        .flatMap((x) => x.files)
        .map((x) => path.relative(Instance.worktree, x)),
    )
    const diffs = (await computeDiff({ messages: input.messages })).filter((x) => files.has(x.file))
    await Session.update(input.sessionID, (draft) => {
      draft.summary = {
        additions: diffs.reduce((sum, x) => sum + x.additions, 0),
        deletions: diffs.reduce((sum, x) => sum + x.deletions, 0),
        files: diffs.length,
      }
    })
    await Storage.write(["session_diff", input.sessionID], diffs)
    Bus.publish(Session.Event.Diff, {
      sessionID: input.sessionID,
      diff: diffs,
    })
  }

  /**
   * Build the summary for one user message, using that message plus the
   * assistant messages whose `parentID` points at it:
   *  1. always: record the snapshot diff of those messages on the user message;
   *  2. when the user message has a non-synthetic text part and no title yet:
   *     generate a title with the small model;
   *  3. once some assistant step finished for a reason other than "tool-calls":
   *     store a body — the last assistant text part, or a model-generated
   *     summary when there is no such text or the diff is non-empty.
   *
   * The model is the provider's small model, falling back to the model the
   * assistant used. Model failures are logged and skipped so the remaining
   * steps still run. Returns without changes when the user message is absent,
   * and after step 1 when no assistant reply exists yet.
   */
  async function summarizeMessage(input: { messageID: string; messages: MessageV2.WithParts[] }) {
    const messages = input.messages.filter(
      (m) => m.info.id === input.messageID || (m.info.role === "assistant" && m.info.parentID === input.messageID),
    )
    const target = messages.find((m) => m.info.id === input.messageID)
    if (!target || target.info.role !== "user") {
      log.info("user message not found", { messageID: input.messageID })
      return
    }
    const userMsg = target.info as MessageV2.User

    const diffs = await computeDiff({ messages })
    userMsg.summary = {
      ...userMsg.summary,
      diffs,
    }
    await Session.updateMessage(userMsg)

    // Without an assistant reply there is no model to summarize with yet.
    const assistantMsg = messages.find((m) => m.info.role === "assistant")?.info as MessageV2.Assistant | undefined
    if (!assistantMsg) return

    const small =
      (await Provider.getSmallModel(assistantMsg.providerID)) ??
      (await Provider.getModel(assistantMsg.providerID, assistantMsg.modelID))
    const options = pipe(
      {},
      mergeDeep(ProviderTransform.options(small.providerID, small.modelID, small.npm ?? "", assistantMsg.sessionID)),
      mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, modelID: small.modelID })),
      mergeDeep(small.info.options),
    )

    const textPart = target.parts.find((p): p is MessageV2.TextPart => p.type === "text" && !p.synthetic)
    if (textPart && !userMsg.summary?.title) {
      const result = await generateText({
        // Reasoning models spend output tokens on thinking, so they get a much larger budget.
        maxOutputTokens: small.info.reasoning ? 1500 : 20,
        providerOptions: ProviderTransform.providerOptions(small.npm, small.providerID, options),
        messages: [
          ...SystemPrompt.title(small.providerID).map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          {
            role: "user" as const,
            content: `
The following is the text to summarize:
<text>
${textPart.text}
</text>
`,
          },
        ],
        headers: small.info.headers,
        model: small.language,
      }).catch((error) => {
        log.error("title generation failed", { error })
        return undefined
      })
      if (result) {
        log.info("title", { title: result.text })
        userMsg.summary.title = result.text
        await Session.updateMessage(userMsg)
      }
    }

    // The body is only written once the assistant has finished a step for a
    // reason other than requesting more tool calls.
    const finished = messages.some(
      (m) => m.info.role === "assistant" && m.parts.some((p) => p.type === "step-finish" && p.reason !== "tool-calls"),
    )
    if (!finished) return

    let summary = messages
      .findLast((m) => m.info.role === "assistant")
      ?.parts.findLast((p) => p.type === "text")?.text
    if (!summary || diffs.length > 0) {
      const result = await generateText({
        model: small.language,
        maxOutputTokens: 100,
        providerOptions: ProviderTransform.providerOptions(small.npm, small.providerID, options),
        messages: [
          ...SystemPrompt.summarize(small.providerID).map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          // Tool outputs are replaced on a copy so the caller's messages stay untouched.
          ...MessageV2.toModelMessage(pruneToolOutputs(messages)),
          {
            role: "user",
            content: `Summarize the above conversation according to your system prompts.`,
          },
        ],
        headers: small.info.headers,
      }).catch((error) => {
        log.error("body generation failed", { error })
        return undefined
      })
      if (result) summary = result.text
    }
    userMsg.summary.body = summary
    log.info("body", { body: summary })
    await Session.updateMessage(userMsg)
  }

  /**
   * Return a copy of `messages` in which the output of every completed tool
   * part is replaced with PRUNED_TOOL_OUTPUT. The input messages and parts
   * are not mutated.
   */
  function pruneToolOutputs(messages: MessageV2.WithParts[]): MessageV2.WithParts[] {
    return messages.map((msg) => ({
      ...msg,
      parts: msg.parts.map((part) =>
        part.type === "tool" && part.state.status === "completed"
          ? { ...part, state: { ...part.state, output: PRUNED_TOOL_OUTPUT } }
          : part,
      ),
    }))
  }

  /**
   * Read the per-file diff list last persisted by `summarize` for a session.
   * `messageID` is accepted by the schema but not used here. Any read error
   * (presumably including a missing entry — verify against Storage) yields [].
   */
  export const diff = fn(
    z.object({
      sessionID: Identifier.schema("session"),
      messageID: Identifier.schema("message").optional(),
    }),
    async (input) => {
      return Storage.read<Snapshot.FileDiff[]>(["session_diff", input.sessionID]).catch(() => [])
    },
  )

  /**
   * Diff the workspace between the first `step-start` snapshot and the last
   * `step-finish` snapshot found in `messages` (scanned in array order).
   * Returns [] when either snapshot is missing.
   */
  async function computeDiff(input: { messages: MessageV2.WithParts[] }): Promise<Snapshot.FileDiff[]> {
    let from: string | undefined
    let to: string | undefined
    for (const item of input.messages) {
      for (const part of item.parts) {
        if (!from && part.type === "step-start" && part.snapshot) from = part.snapshot
        // No early break: a message can hold several steps, and the latest
        // finish snapshot must win so changes from later steps are included.
        if (part.type === "step-finish" && part.snapshot) to = part.snapshot
      }
    }
    if (from && to) return Snapshot.diffFull(from, to)
    return []
  }
}
|