summary.ts 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import { Provider } from "@/provider/provider"
  2. import { fn } from "@/util/fn"
  3. import z from "zod"
  4. import { Session } from "."
  5. import { generateText, type ModelMessage } from "ai"
  6. import { MessageV2 } from "./message-v2"
  7. import { Identifier } from "@/id/id"
  8. import { Snapshot } from "@/snapshot"
  9. import { ProviderTransform } from "@/provider/transform"
  10. import { SystemPrompt } from "./system"
  11. import { Log } from "@/util/log"
  12. import path from "path"
  13. import { Instance } from "@/project/instance"
  14. import { Storage } from "@/storage/storage"
  15. import { Bus } from "@/bus"
  16. import { mergeDeep, pipe } from "remeda"
  17. export namespace SessionSummary {
  /** Logger tagged with the "session.summary" service name. */
  const log = Log.create({ service: "session.summary" })

  /**
   * Refreshes the stored summaries after a message exchange.
   *
   * Loads the session's messages once and hands the same list to both the
   * session-level and the message-level summarizer, which run concurrently.
   * The returned promise rejects if either summarizer rejects.
   */
  export const summarize = fn(
    z.object({
      sessionID: z.string(),
      messageID: z.string(),
    }),
    async ({ sessionID, messageID }) => {
      const messages = await Session.messages({ sessionID })
      // Both jobs are started before either is awaited so they overlap.
      const sessionJob = summarizeSession({ sessionID, messages })
      const messageJob = summarizeMessage({ messageID, messages })
      await Promise.all([sessionJob, messageJob])
    },
  )
  /**
   * Recomputes the change summary for a whole session.
   *
   * Collects the worktree-relative paths named by "patch" parts, keeps only
   * the computed diff entries for those paths, then:
   *  1. stores the addition/deletion/file totals on the session,
   *  2. writes the filtered entries under ["session_diff", sessionID],
   *  3. publishes them as a `Session.Event.Diff` event.
   */
  async function summarizeSession(input: { sessionID: string; messages: MessageV2.WithParts[] }) {
    const { sessionID, messages } = input

    // Paths reported by patch parts, made relative to the worktree so they can
    // be matched against the `file` field of the diff entries.
    const patched = new Set<string>()
    for (const message of messages) {
      for (const part of message.parts) {
        if (part.type !== "patch") continue
        for (const file of part.files) patched.add(path.relative(Instance.worktree, file))
      }
    }

    const computed = await computeDiff({ messages })
    const diffs = computed.filter((entry) => patched.has(entry.file))

    let additions = 0
    let deletions = 0
    for (const entry of diffs) {
      additions += entry.additions
      deletions += entry.deletions
    }

    await Session.update(sessionID, (draft) => {
      draft.summary = { additions, deletions, files: diffs.length }
    })
    await Storage.write(["session_diff", sessionID], diffs)
    // The publish call's return value is deliberately not awaited, as before.
    Bus.publish(Session.Event.Diff, { sessionID, diff: diffs })
  }
  /**
   * Returns copies of `messages` in which the output of every completed tool
   * call is replaced by a short placeholder. The input objects are not touched.
   */
  function pruneToolOutputs(messages: MessageV2.WithParts[]): MessageV2.WithParts[] {
    return messages.map((msg) => ({
      ...msg,
      parts: msg.parts.map((part) =>
        part.type === "tool" && part.state.status === "completed"
          ? { ...part, state: { ...part.state, output: "[TOOL OUTPUT PRUNED]" } }
          : part,
      ),
    }))
  }

  /**
   * Updates the summary stored on one message: the diff of its turn, a
   * generated title and a short body.
   *
   * The "turn" is the message with id `messageID` plus every assistant message
   * whose `parentID` equals it. The target message is treated as a user
   * message (NOTE(review): its role is not checked here; confirm callers only
   * pass user message ids).
   *
   * Steps:
   *  1. Store the turn's diff on the message summary.
   *  2. If the message has a non-synthetic text part and no title yet, generate
   *     a title with the provider's small model (falling back to the model the
   *     assistant used) and store it.
   *  3. If any assistant message has a "step-finish" part whose reason is not
   *     "tool-calls", store a body: the last text part of the last assistant
   *     message, or a generated summary when that text is missing or the turn
   *     changed files. A failed body generation is logged and the existing
   *     text (possibly undefined) is kept.
   *
   * Returns without throwing when the target message is not in `messages`, or
   * (after step 1) when the turn has no assistant message.
   *
   * @param input.messageID id of the message to summarize
   * @param input.messages  messages of the session; never mutated
   */
  async function summarizeMessage(input: { messageID: string; messages: MessageV2.WithParts[] }) {
    const messages = input.messages.filter(
      (m) => m.info.id === input.messageID || (m.info.role === "assistant" && m.info.parentID === input.messageID),
    )
    const msgWithParts = messages.find((m) => m.info.id === input.messageID)
    if (!msgWithParts) {
      log.info("message not found, skipping summary", { messageID: input.messageID })
      return
    }
    const userMsg = msgWithParts.info as MessageV2.User

    const diffs = await computeDiff({ messages })
    userMsg.summary = {
      ...userMsg.summary,
      diffs,
    }
    await Session.updateMessage(userMsg)

    // Title and body generation both need the assistant's provider/model.
    const assistantMsg = messages.find((m) => m.info.role === "assistant")?.info as MessageV2.Assistant | undefined
    if (!assistantMsg) {
      log.info("no assistant message, skipping title and body", { messageID: input.messageID })
      return
    }

    const small =
      (await Provider.getSmallModel(assistantMsg.providerID)) ??
      (await Provider.getModel(assistantMsg.providerID, assistantMsg.modelID))
    // Later merges win: base options, then small-model options, then the
    // options configured on the model itself.
    const options = pipe(
      {},
      mergeDeep(ProviderTransform.options(small.providerID, small.modelID, small.npm ?? "", assistantMsg.sessionID)),
      mergeDeep(ProviderTransform.smallOptions({ providerID: small.providerID, modelID: small.modelID })),
      mergeDeep(small.info.options),
    )

    const textPart = msgWithParts.parts.find((p) => p.type === "text" && !p.synthetic) as
      | MessageV2.TextPart
      | undefined
    if (textPart && !userMsg.summary?.title) {
      const result = await generateText({
        // Models flagged as reasoning get a much larger output budget.
        maxOutputTokens: small.info.reasoning ? 1500 : 20,
        providerOptions: ProviderTransform.providerOptions(small.npm, small.providerID, options),
        messages: [
          ...SystemPrompt.title(small.providerID).map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          {
            role: "user" as const,
            content: ["", "The following is the text to summarize:", "<text>", textPart.text, "</text>", ""].join("\n"),
          },
        ],
        headers: small.info.headers,
        model: small.language,
      })
      log.info("title", { title: result.text })
      userMsg.summary.title = result.text
      await Session.updateMessage(userMsg)
    }

    // Only write a body once some assistant step finished for a reason other
    // than issuing tool calls.
    const finished = messages.some(
      (m) => m.info.role === "assistant" && m.parts.some((p) => p.type === "step-finish" && p.reason !== "tool-calls"),
    )
    if (!finished) return

    let summary = messages
      .findLast((m) => m.info.role === "assistant")
      ?.parts.findLast((p) => p.type === "text")?.text
    if (!summary || diffs.length > 0) {
      const result = await generateText({
        model: small.language,
        maxOutputTokens: 100,
        providerOptions: ProviderTransform.providerOptions(small.npm, small.providerID, options),
        messages: [
          ...SystemPrompt.summarize(small.providerID).map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          // Tool outputs are replaced on copies so the caller's messages stay intact.
          ...MessageV2.toModelMessage(pruneToolOutputs(messages)),
          {
            role: "user",
            content: `Summarize the above conversation according to your system prompts.`,
          },
        ],
        headers: small.info.headers,
      }).catch((error: unknown) => {
        // Best effort: keep whatever text we already have, but leave a trace.
        log.error("failed to generate summary body", { error })
        return undefined
      })
      if (result) summary = result.text
    }
    userMsg.summary.body = summary
    log.info("body", { body: summary })
    await Session.updateMessage(userMsg)
  }
  /**
   * Reads the diff entries stored under ["session_diff", sessionID].
   *
   * `messageID` is accepted by the input schema but is not used by the lookup.
   * Resolves to an empty list when the storage read rejects.
   */
  export const diff = fn(
    z.object({
      sessionID: Identifier.schema("session"),
      messageID: Identifier.schema("message").optional(),
    }),
    async ({ sessionID }) => {
      const stored = Storage.read<Snapshot.FileDiff[]>(["session_diff", sessionID])
      return stored.catch(() => [])
    },
  )
  /**
   * Computes the file diff spanned by a run of messages.
   *
   * `from` is the snapshot of the first "step-start" part that carries one and
   * `to` is the snapshot of the last "step-finish" part that carries one, both
   * in message/part order. A message may hold several steps, so every
   * step-finish part is considered rather than only the first per message.
   *
   * @returns the result of `Snapshot.diffFull(from, to)`, or an empty list
   *          when either snapshot is missing
   */
  async function computeDiff(input: { messages: MessageV2.WithParts[] }) {
    let from: string | undefined
    let to: string | undefined
    for (const item of input.messages) {
      for (const part of item.parts) {
        // Earliest start: only the first snapshot seen is kept.
        if (!from && part.type === "step-start" && part.snapshot) from = part.snapshot
        // Latest finish: each later snapshot overwrites the previous one.
        if (part.type === "step-finish" && part.snapshot) to = part.snapshot
      }
    }
    if (from && to) return Snapshot.diffFull(from, to)
    return []
  }
  185. }