summary.ts 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import { Provider } from "@/provider/provider"
  2. import { fn } from "@/util/fn"
  3. import z from "zod"
  4. import { Session } from "."
  5. import { generateText, type ModelMessage } from "ai"
  6. import { MessageV2 } from "./message-v2"
  7. import { Identifier } from "@/id/id"
  8. import { Snapshot } from "@/snapshot"
  9. import { ProviderTransform } from "@/provider/transform"
  10. import { SystemPrompt } from "./system"
  11. import { Log } from "@/util/log"
  12. import path from "path"
  13. import { Instance } from "@/project/instance"
  14. import { Storage } from "@/storage/storage"
  15. import { Bus } from "@/bus"
  16. export namespace SessionSummary {
  17. const log = Log.create({ service: "session.summary" })
  18. export const summarize = fn(
  19. z.object({
  20. sessionID: z.string(),
  21. messageID: z.string(),
  22. }),
  23. async (input) => {
  24. const all = await Session.messages({ sessionID: input.sessionID })
  25. await Promise.all([
  26. summarizeSession({ sessionID: input.sessionID, messages: all }),
  27. summarizeMessage({ messageID: input.messageID, messages: all }),
  28. ])
  29. },
  30. )
  31. async function summarizeSession(input: { sessionID: string; messages: MessageV2.WithParts[] }) {
  32. const files = new Set(
  33. input.messages
  34. .flatMap((x) => x.parts)
  35. .filter((x) => x.type === "patch")
  36. .flatMap((x) => x.files)
  37. .map((x) => path.relative(Instance.worktree, x)),
  38. )
  39. const diffs = await computeDiff({ messages: input.messages }).then((x) =>
  40. x.filter((x) => {
  41. return files.has(x.file)
  42. }),
  43. )
  44. await Session.update(input.sessionID, (draft) => {
  45. draft.summary = {
  46. additions: diffs.reduce((sum, x) => sum + x.additions, 0),
  47. deletions: diffs.reduce((sum, x) => sum + x.deletions, 0),
  48. files: diffs.length,
  49. }
  50. })
  51. await Storage.write(["session_diff", input.sessionID], diffs)
  52. Bus.publish(Session.Event.Diff, {
  53. sessionID: input.sessionID,
  54. diff: diffs,
  55. })
  56. }
  57. async function summarizeMessage(input: { messageID: string; messages: MessageV2.WithParts[] }) {
  58. const messages = input.messages.filter(
  59. (m) =>
  60. m.info.id === input.messageID ||
  61. (m.info.role === "assistant" && m.info.parentID === input.messageID),
  62. )
  63. const msgWithParts = messages.find((m) => m.info.id === input.messageID)!
  64. const userMsg = msgWithParts.info as MessageV2.User
  65. const diffs = await computeDiff({ messages })
  66. userMsg.summary = {
  67. ...userMsg.summary,
  68. diffs,
  69. }
  70. await Session.updateMessage(userMsg)
  71. const assistantMsg = messages.find((m) => m.info.role === "assistant")!
  72. .info as MessageV2.Assistant
  73. const small = await Provider.getSmallModel(assistantMsg.providerID)
  74. if (!small) return
  75. const textPart = msgWithParts.parts.find(
  76. (p) => p.type === "text" && !p.synthetic,
  77. ) as MessageV2.TextPart
  78. if (textPart && !userMsg.summary?.title) {
  79. const result = await generateText({
  80. maxOutputTokens: small.info.reasoning ? 1500 : 20,
  81. providerOptions: ProviderTransform.providerOptions(small.npm, small.providerID, {}),
  82. messages: [
  83. ...SystemPrompt.title(small.providerID).map(
  84. (x): ModelMessage => ({
  85. role: "system",
  86. content: x,
  87. }),
  88. ),
  89. {
  90. role: "user" as const,
  91. content: `
  92. The following is the text to summarize:
  93. <text>
  94. ${textPart?.text ?? ""}
  95. </text>
  96. `,
  97. },
  98. ],
  99. headers: small.info.headers,
  100. model: small.language,
  101. })
  102. log.info("title", { title: result.text })
  103. userMsg.summary.title = result.text
  104. await Session.updateMessage(userMsg)
  105. }
  106. if (
  107. messages.some(
  108. (m) =>
  109. m.info.role === "assistant" &&
  110. m.parts.some((p) => p.type === "step-finish" && p.reason !== "tool-calls"),
  111. )
  112. ) {
  113. let summary = messages
  114. .findLast((m) => m.info.role === "assistant")
  115. ?.parts.findLast((p) => p.type === "text")?.text
  116. if (!summary || diffs.length > 0) {
  117. const result = await generateText({
  118. model: small.language,
  119. maxOutputTokens: 100,
  120. messages: [
  121. {
  122. role: "user",
  123. content: `
  124. Summarize the following conversation into 2 sentences MAX explaining what the assistant did and why. Do not explain the user's input. Do not speak in the third person about the assistant.
  125. <conversation>
  126. ${JSON.stringify(MessageV2.toModelMessage(messages))}
  127. </conversation>
  128. `,
  129. },
  130. ],
  131. headers: small.info.headers,
  132. }).catch(() => {})
  133. if (result) summary = result.text
  134. }
  135. userMsg.summary.body = summary
  136. log.info("body", { body: summary })
  137. await Session.updateMessage(userMsg)
  138. }
  139. }
  140. export const diff = fn(
  141. z.object({
  142. sessionID: Identifier.schema("session"),
  143. messageID: Identifier.schema("message").optional(),
  144. }),
  145. async (input) => {
  146. return Storage.read<Snapshot.FileDiff[]>(["session_diff", input.sessionID]) ?? []
  147. },
  148. )
  149. async function computeDiff(input: { messages: MessageV2.WithParts[] }) {
  150. let from: string | undefined
  151. let to: string | undefined
  152. // scan assistant messages to find earliest from and latest to
  153. // snapshot
  154. for (const item of input.messages) {
  155. if (!from) {
  156. for (const part of item.parts) {
  157. if (part.type === "step-start" && part.snapshot) {
  158. from = part.snapshot
  159. break
  160. }
  161. }
  162. }
  163. for (const part of item.parts) {
  164. if (part.type === "step-finish" && part.snapshot) {
  165. to = part.snapshot
  166. break
  167. }
  168. }
  169. }
  170. if (from && to) return Snapshot.diffFull(from, to)
  171. return []
  172. }
  173. }