summary.ts 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import { Provider } from "@/provider/provider"
  2. import { fn } from "@/util/fn"
  3. import z from "zod"
  4. import { Session } from "."
  5. import { MessageV2 } from "./message-v2"
  6. import { Identifier } from "@/id/id"
  7. import { Snapshot } from "@/snapshot"
  8. import { Log } from "@/util/log"
  9. import path from "path"
  10. import { Instance } from "@/project/instance"
  11. import { Storage } from "@/storage/storage"
  12. import { Bus } from "@/bus"
  13. import { LLM } from "./llm"
  14. import { Agent } from "@/agent/agent"
  /**
   * Computes and persists change summaries for a session and for individual
   * user messages: aggregate diff statistics, per-message diffs, and a short
   * generated title for the message text.
   */
  export namespace SessionSummary {
    const log = Log.create({ service: "session.summary" })

    /**
     * Refreshes both the session-level summary and the summary of one message.
     *
     * The session's messages are loaded once and shared by both tasks, which
     * run concurrently; a rejection from either one rejects the returned promise.
     *
     * @param input.sessionID Session whose messages are loaded and summarized.
     * @param input.messageID Message whose diff (and title) should be refreshed.
     */
    export const summarize = fn(
      z.object({
        sessionID: z.string(),
        messageID: z.string(),
      }),
      async (input) => {
        const all = await Session.messages({ sessionID: input.sessionID })
        await Promise.all([
          summarizeSession({ sessionID: input.sessionID, messages: all }),
          summarizeMessage({ messageID: input.messageID, messages: all }),
        ])
      },
    )

    /**
     * Recomputes the session-wide diff, restricted to files named by "patch"
     * parts, then stores the totals on the session, writes the diff list to
     * storage under ["session_diff", sessionID], and publishes a diff event.
     */
    async function summarizeSession(input: { sessionID: string; messages: MessageV2.WithParts[] }) {
      // Patch parts list file paths; make them worktree-relative with forward
      // slashes so they can be matched against the `file` field of each diff.
      const patchedFiles = new Set(
        input.messages
          .flatMap((message) => message.parts)
          .filter((part) => part.type === "patch")
          .flatMap((part) => part.files)
          .map((file) => path.relative(Instance.worktree, file).replaceAll("\\", "/")),
      )

      const allDiffs = await computeDiff({ messages: input.messages })
      const diffs = allDiffs.filter((entry) => patchedFiles.has(entry.file))

      await Session.update(input.sessionID, (draft) => {
        draft.summary = {
          additions: diffs.reduce((sum, entry) => sum + entry.additions, 0),
          deletions: diffs.reduce((sum, entry) => sum + entry.deletions, 0),
          files: diffs.length,
        }
      })
      await Storage.write(["session_diff", input.sessionID], diffs)
      Bus.publish(Session.Event.Diff, {
        sessionID: input.sessionID,
        diff: diffs,
      })
    }

    /**
     * Stores the diff produced in response to one user message on that message
     * and, when the message has non-synthetic text and no title yet, generates
     * a title for it with the "title" agent.
     *
     * Does nothing when the message is not part of `input.messages` or is not a
     * user message.
     */
    async function summarizeMessage(input: { messageID: string; messages: MessageV2.WithParts[] }) {
      // The target message plus every assistant message that answers it.
      const related = input.messages.filter(
        (m) => m.info.id === input.messageID || (m.info.role === "assistant" && m.info.parentID === input.messageID),
      )

      const target = related.find((m) => m.info.id === input.messageID)
      if (!target || target.info.role !== "user") {
        // Without this guard a missing message throws a TypeError below, and a
        // non-user message would get a summary object written onto it.
        log.info("skipping message summary", { messageID: input.messageID, found: target !== undefined })
        return
      }
      // The role check above establishes that this is the user variant.
      const userMsg = target.info as MessageV2.User

      const diffs = await computeDiff({ messages: related })
      userMsg.summary = {
        ...userMsg.summary,
        diffs,
      }
      await Session.updateMessage(userMsg)

      // `find` yields undefined when the message has no non-synthetic text part.
      const textPart = target.parts.find((p) => p.type === "text" && !p.synthetic) as
        | MessageV2.TextPart
        | undefined
      if (!textPart || userMsg.summary?.title) return

      const agent = await Agent.get("title")
      if (!agent) return

      // Prefer the model configured on the title agent; otherwise try the small
      // model of the user's provider and fall back to the user's own model.
      const model = agent.model
        ? await Provider.getModel(agent.model.providerID, agent.model.modelID)
        : ((await Provider.getSmallModel(userMsg.model.providerID)) ??
          (await Provider.getModel(userMsg.model.providerID, userMsg.model.modelID)))

      // Assembled line by line so the prompt text does not depend on how this
      // source file happens to be indented.
      const content = ["", "The following is the text to summarize:", "<text>", textPart.text, "</text>", ""].join(
        "\n",
      )

      const stream = await LLM.stream({
        agent,
        user: userMsg,
        tools: {},
        model,
        small: true,
        messages: [
          {
            role: "user" as const,
            content,
          },
        ],
        abort: new AbortController().signal,
        sessionID: userMsg.sessionID,
        system: [],
        retries: 3,
      })
      const result = await stream.text
      log.info("title", { title: result })
      userMsg.summary.title = result
      await Session.updateMessage(userMsg)
    }

    /**
     * Reads the diff list last written for a session by the session summary.
     *
     * Best effort: any rejection from the storage read is turned into an empty
     * list. `messageID` is accepted by the schema but not used by this
     * implementation.
     */
    export const diff = fn(
      z.object({
        sessionID: Identifier.schema("session"),
        messageID: Identifier.schema("message").optional(),
      }),
      async (input) => {
        return Storage.read<Snapshot.FileDiff[]>(["session_diff", input.sessionID]).catch(() => [])
      },
    )

    /**
     * Diffs the earliest "step-start" snapshot against the latest "step-finish"
     * snapshot found in the given messages, scanning them in array order.
     *
     * @returns The result of `Snapshot.diffFull(from, to)`, or an empty list
     *   when either snapshot is missing.
     */
    export async function computeDiff(input: { messages: MessageV2.WithParts[] }) {
      let from: string | undefined
      let to: string | undefined

      for (const message of input.messages) {
        for (const part of message.parts) {
          // The first step-start snapshot wins and is never overwritten.
          if (!from && part.type === "step-start" && part.snapshot) from = part.snapshot
          // Every later step-finish snapshot replaces the previous one, so a
          // message containing several steps contributes its final snapshot.
          if (part.type === "step-finish" && part.snapshot) to = part.snapshot
        }
      }

      if (from && to) return Snapshot.diffFull(from, to)
      return []
    }
  }