summary.ts 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import { Provider } from "@/provider/provider"
  2. import { fn } from "@/util/fn"
  3. import z from "zod"
  4. import { Session } from "."
  5. import { MessageV2 } from "./message-v2"
  6. import { Identifier } from "@/id/id"
  7. import { Snapshot } from "@/snapshot"
  8. import { Log } from "@/util/log"
  9. import path from "path"
  10. import { Instance } from "@/project/instance"
  11. import { Storage } from "@/storage/storage"
  12. import { Bus } from "@/bus"
  13. import { LLM } from "./llm"
  14. import { Agent } from "@/agent/agent"
  export namespace SessionSummary {
    const log = Log.create({ service: "session.summary" })

    /** Single-character escapes used by git's C-style path quoting, mapped to the byte each encodes. */
    const GIT_ESCAPES = new Map<string, number>([
      ["a", 0x07],
      ["b", 0x08],
      ["t", 0x09],
      ["n", 0x0a],
      ["v", 0x0b],
      ["f", 0x0c],
      ["r", 0x0d],
      ['"', 0x22],
      ["\\", 0x5c],
    ])

    /** One to three octal digits at the start of a string (git's `\NNN` byte escape). */
    const OCTAL_ESCAPE = /^[0-7]{1,3}/

    /**
     * Append the UTF-8 bytes of the code point starting at `index` in `text` to `out`.
     * @returns the number of UTF-16 units consumed (2 for a surrogate pair, otherwise 1).
     */
    function pushCodePoint(text: string, index: number, out: number[]) {
      const char = String.fromCodePoint(text.codePointAt(index)!)
      out.push(...Buffer.from(char, "utf8"))
      return char.length
    }

    /**
     * Undo git's C-style quoting of a path, e.g. `"caf\303\251.txt"` -> `café.txt`.
     *
     * Strings that are not wrapped in double quotes are returned unchanged. Escaped bytes
     * (octal `\NNN` and single-letter escapes) are collected together with the UTF-8 bytes
     * of every literal character, and the whole buffer is decoded as UTF-8 at the end, so
     * multi-byte sequences spread over several octal escapes are reassembled and characters
     * that were not escaped survive intact.
     */
    function unquoteGitPath(input: string) {
      if (input.length < 2 || !input.startsWith('"') || !input.endsWith('"')) return input
      const body = input.slice(1, -1)
      const bytes: number[] = []
      let i = 0
      while (i < body.length) {
        if (body[i] !== "\\") {
          i += pushCodePoint(body, i, bytes)
          continue
        }
        // A lone trailing backslash is kept as a literal backslash.
        if (i + 1 >= body.length) {
          bytes.push(0x5c)
          i += 1
          continue
        }
        const octal = OCTAL_ESCAPE.exec(body.slice(i + 1, i + 4))
        if (octal) {
          // Mask to a single byte; Buffer.from would truncate larger values the same way.
          bytes.push(parseInt(octal[0], 8) & 0xff)
          i += 1 + octal[0].length
          continue
        }
        const escaped = GIT_ESCAPES.get(body.charAt(i + 1))
        if (escaped !== undefined) {
          bytes.push(escaped)
          i += 2
          continue
        }
        // Unknown escape: drop the backslash and keep the character that follows it.
        i += 1 + pushCodePoint(body, i + 1, bytes)
      }
      return Buffer.from(bytes).toString("utf8")
    }

    /**
     * Return `item` with its `file` unquoted when git quoted it. The same object is returned
     * when nothing changes, so callers can detect a rewrite by identity.
     */
    function normalizeDiffFile<T extends { file: string }>(item: T): T {
      const file = unquoteGitPath(item.file)
      return file === item.file ? item : { ...item, file }
    }

    /**
     * Recompute the session-level diff summary and the per-message summary for `messageID`.
     * Both are derived from one read of the session's messages and run concurrently.
     */
    export const summarize = fn(
      z.object({
        sessionID: z.string(),
        messageID: z.string(),
      }),
      async (input) => {
        const all = await Session.messages({ sessionID: input.sessionID })
        await Promise.all([
          summarizeSession({ sessionID: input.sessionID, messages: all }),
          summarizeMessage({ messageID: input.messageID, messages: all }),
        ])
      },
    )

    /**
     * Diff the whole session, keep only files touched by `patch` parts, store the totals on the
     * session, persist the per-file diffs under `["session_diff", sessionID]`, and publish
     * `Session.Event.Diff`.
     */
    async function summarizeSession(input: { sessionID: string; messages: MessageV2.WithParts[] }) {
      // Patch paths are made worktree-relative with forward slashes so they can be compared
      // against the `file` entries of the snapshot diff.
      const touched = new Set(
        input.messages
          .flatMap((message) => message.parts)
          .filter((part) => part.type === "patch")
          .flatMap((part) => part.files)
          .map((file) => path.relative(Instance.worktree, file).replaceAll("\\", "/")),
      )
      const computed = await computeDiff({ messages: input.messages })
      // Unquote git-quoted paths first; otherwise files with special characters never match `touched`.
      const diffs = computed.map((item) => normalizeDiffFile(item)).filter((item) => touched.has(item.file))
      await Session.update(input.sessionID, (draft) => {
        draft.summary = {
          additions: diffs.reduce((sum, x) => sum + x.additions, 0),
          deletions: diffs.reduce((sum, x) => sum + x.deletions, 0),
          files: diffs.length,
        }
      })
      await Storage.write(["session_diff", input.sessionID], diffs)
      Bus.publish(Session.Event.Diff, {
        sessionID: input.sessionID,
        diff: diffs,
      })
    }

    /**
     * Store the diff produced by user message `messageID` (together with the assistant replies
     * whose `parentID` points at it) on that message's `summary`. If the message has no title yet
     * and has a non-synthetic text part, ask the `title` agent to generate one.
     */
    async function summarizeMessage(input: { messageID: string; messages: MessageV2.WithParts[] }) {
      const target = input.messages.find((m) => m.info.id === input.messageID)
      const userMsg = target?.info
      // This routine reads `model` and `summary` off a user message; bail out instead of
      // throwing on an unknown id or a non-user message so summarize() does not reject.
      if (!target || !userMsg || userMsg.role !== "user") {
        log.info("skip message summary", { messageID: input.messageID })
        return
      }

      const related = input.messages.filter(
        (m) => m.info.id === input.messageID || (m.info.role === "assistant" && m.info.parentID === input.messageID),
      )
      const diffs = (await computeDiff({ messages: related })).map((item) => normalizeDiffFile(item))
      userMsg.summary = {
        ...userMsg.summary,
        diffs,
      }
      await Session.updateMessage(userMsg)

      if (userMsg.summary.title) return
      const textPart = target.parts.find((p): p is MessageV2.TextPart => p.type === "text" && !p.synthetic)
      if (!textPart) return
      const agent = await Agent.get("title")
      if (!agent) return

      // Prefer the title agent's own model; otherwise the provider's small model, falling back
      // to the model the user message was sent with.
      const model = agent.model
        ? await Provider.getModel(agent.model.providerID, agent.model.modelID)
        : ((await Provider.getSmallModel(userMsg.model.providerID)) ??
          (await Provider.getModel(userMsg.model.providerID, userMsg.model.modelID)))
      const stream = await LLM.stream({
        agent,
        user: userMsg,
        tools: {},
        model,
        small: true,
        messages: [
          {
            role: "user" as const,
            content: ["", "The following is the text to summarize:", "<text>", textPart.text, "</text>", ""].join("\n"),
          },
        ],
        abort: new AbortController().signal,
        sessionID: userMsg.sessionID,
        system: [],
        retries: 3,
      })
      const title = await stream.text
      log.info("title", { title })
      userMsg.summary.title = title
      await Session.updateMessage(userMsg)
    }

    /**
     * Read the stored per-file diffs for a session (empty list if nothing is stored or the read
     * fails), with any git-quoted paths unquoted. When unquoting changed an entry, the normalized
     * list is written back best-effort so later reads get it directly.
     */
    export const diff = fn(
      z.object({
        sessionID: Identifier.schema("session"),
        messageID: Identifier.schema("message").optional(),
      }),
      async (input) => {
        const stored: Snapshot.FileDiff[] = await Storage.read<Snapshot.FileDiff[]>([
          "session_diff",
          input.sessionID,
        ]).catch(() => [])
        const normalized = stored.map((item) => normalizeDiffFile(item))
        // normalizeDiffFile returns the same object when nothing changed, so identity detects rewrites.
        if (normalized.some((item, i) => item !== stored[i])) {
          void Storage.write(["session_diff", input.sessionID], normalized).catch(() => {})
        }
        return normalized
      },
    )

    /**
     * Diff between the first `step-start` snapshot and the last `step-finish` snapshot found
     * while walking `messages` (and their parts) in order. Returns an empty list when either
     * snapshot is missing.
     */
    export async function computeDiff(input: { messages: MessageV2.WithParts[] }) {
      let from: string | undefined
      let to: string | undefined
      for (const item of input.messages) {
        for (const part of item.parts) {
          if (part.type === "step-start" && part.snapshot) from ??= part.snapshot
          // Keep overwriting so the latest step-finish wins, including when one message has several steps.
          if (part.type === "step-finish" && part.snapshot) to = part.snapshot
        }
      }
      if (from && to) return Snapshot.diffFull(from, to)
      return []
    }
  }