summary.ts 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. import { Provider } from "@/provider/provider"
  2. import { fn } from "@/util/fn"
  3. import z from "zod"
  4. import { Session } from "."
  5. import { generateText, type ModelMessage } from "ai"
  6. import { MessageV2 } from "./message-v2"
  7. import { Identifier } from "@/id/id"
  8. import { Snapshot } from "@/snapshot"
  9. import { ProviderTransform } from "@/provider/transform"
  10. import { SystemPrompt } from "./system"
  11. import { Log } from "@/util/log"
  12. import path from "path"
  13. import { Instance } from "@/project/instance"
  export namespace SessionSummary {
    const log = Log.create({ service: "session.summary" })

    /**
     * Refresh the session-level and the turn-level summaries.
     *
     * Loads the session's messages once and runs `summarizeSession` and
     * `summarizeMessage` concurrently over that same list; rejects if either rejects.
     *
     * @param input.sessionID - session whose `summary.diffs` is rewritten
     * @param input.messageID - user message anchoring the turn to summarize
     */
    export const summarize = fn(
      z.object({
        sessionID: z.string(),
        messageID: z.string(),
      }),
      async (input) => {
        const all = await Session.messages(input.sessionID)
        await Promise.all([
          summarizeSession({ sessionID: input.sessionID, messages: all }),
          summarizeMessage({ messageID: input.messageID, messages: all }),
        ])
      },
    )

    /**
     * Select the turn anchored on `messageID`: that message plus every assistant
     * message whose `parentID` is `messageID`, keeping their original order.
     */
    function turnMessages(messages: MessageV2.WithParts[], messageID: string) {
      return messages.filter(
        (m) => m.info.id === messageID || (m.info.role === "assistant" && m.info.parentID === messageID),
      )
    }

    /**
     * Store on the session the diffs of the files touched by `patch` parts.
     *
     * The diff spans the whole message list (see `computeDiff`) and is narrowed to
     * the paths recorded in patch parts, made relative to `Instance.worktree` so
     * they can be compared against each diff entry's `file` field.
     */
    async function summarizeSession(input: { sessionID: string; messages: MessageV2.WithParts[] }) {
      const files = new Set(
        input.messages
          .flatMap((x) => x.parts)
          .filter((x) => x.type === "patch")
          .flatMap((x) => x.files)
          .map((x) => path.relative(Instance.worktree, x)),
      )
      const diffs = (await computeDiff({ messages: input.messages })).filter((x) => files.has(x.file))
      await Session.update(input.sessionID, (draft) => {
        draft.summary = {
          diffs,
        }
      })
    }

    /**
     * Fill in `summary` on the user message `messageID`.
     *
     * - Returns without changes if `messageID` is not a user message of the session.
     * - The turn's diffs are persisted first. If the turn has no assistant reply
     *   yet, or the provider has no small model, nothing more is done.
     * - A title is generated only when the user message has a non-synthetic text
     *   part and no title yet.
     * - A body is written once some assistant step finished for a reason other
     *   than `tool-calls`: the last assistant text is used as-is unless it is
     *   missing or the turn changed files, in which case the small model writes it.
     *
     * Small-model failures are logged and skipped so one failed call does not
     * prevent the remaining fields from being filled.
     */
    async function summarizeMessage(input: { messageID: string; messages: MessageV2.WithParts[] }) {
      const messages = turnMessages(input.messages, input.messageID)
      const target = messages.find((m) => m.info.id === input.messageID)
      // Guard rather than assert: a missing message would crash here, and a
      // non-user message must not get a user-message summary written onto it.
      if (!target || target.info.role !== "user") return
      const userMsg = target.info

      const diffs = await computeDiff({ messages })
      userMsg.summary = {
        ...userMsg.summary,
        diffs,
      }
      await Session.updateMessage(userMsg)

      const assistantInfo = messages
        .map((m) => m.info)
        .find((info): info is MessageV2.Assistant => info.role === "assistant")
      // No reply yet: there is no provider to pick a small model from.
      if (!assistantInfo) return
      const small = await Provider.getSmallModel(assistantInfo.providerID)
      if (!small) return

      const textPart = target.parts.find(
        (p): p is MessageV2.TextPart => p.type === "text" && !p.synthetic,
      )
      if (textPart && !userMsg.summary?.title) {
        const result = await generateText({
          maxOutputTokens: small.info.reasoning ? 1500 : 20,
          providerOptions: ProviderTransform.providerOptions(small.npm, small.providerID, {}),
          messages: [
            ...SystemPrompt.title(small.providerID).map(
              (x): ModelMessage => ({
                role: "system",
                content: x,
              }),
            ),
            {
              role: "user" as const,
              content: `
The following is the text to summarize:
<text>
${textPart.text}
</text>
`,
            },
          ],
          headers: small.info.headers,
          model: small.language,
        }).catch((error: unknown) => {
          log.error("title generation failed", { error })
          return undefined
        })
        if (result) {
          log.info("title", { title: result.text })
          userMsg.summary.title = result.text
          await Session.updateMessage(userMsg)
        }
      }

      // Only summarize the body once the assistant actually finished (not just
      // paused to run tools).
      const turnFinished = messages.some(
        (m) =>
          m.info.role === "assistant" &&
          m.parts.some((p) => p.type === "step-finish" && p.reason !== "tool-calls"),
      )
      if (!turnFinished) return

      let summary = messages
        .findLast((m) => m.info.role === "assistant")
        ?.parts.findLast((p) => p.type === "text")?.text
      if (!summary || diffs.length > 0) {
        const result = await generateText({
          model: small.language,
          // Reasoning models spend output tokens on reasoning before answering;
          // use the same reasoning-aware budget as the title request.
          maxOutputTokens: small.info.reasoning ? 1500 : 100,
          providerOptions: ProviderTransform.providerOptions(small.npm, small.providerID, {}),
          messages: [
            {
              role: "user",
              content: `
Summarize the following conversation into 2 sentences MAX explaining what the assistant did and why. Do not explain the user's input. Do not speak in the third person about the assistant.
<conversation>
${JSON.stringify(MessageV2.toModelMessage(messages))}
</conversation>
`,
            },
          ],
          headers: small.info.headers,
        }).catch((error: unknown) => {
          log.error("body generation failed", { error })
          return undefined
        })
        // An empty completion must not replace the assistant's own text.
        if (result?.text) summary = result.text
      }
      userMsg.summary.body = summary
      log.info("body", { body: summary })
      await Session.updateMessage(userMsg)
    }

    /**
     * Diff for a whole session, or for a single turn when `messageID` is given
     * (that message plus its assistant replies). Returns `[]` when no snapshot
     * pair is available.
     */
    export const diff = fn(
      z.object({
        sessionID: Identifier.schema("session"),
        messageID: Identifier.schema("message").optional(),
      }),
      async (input) => {
        const all = await Session.messages(input.sessionID)
        const messages = input.messageID ? turnMessages(all, input.messageID) : all
        return computeDiff({ messages })
      },
    )

    /**
     * Diff the tracked files between the first and the last snapshot of `messages`.
     *
     * `from` is the snapshot of the first `step-start` part (in message/part order)
     * that carries one; `to` is the snapshot of the last `step-finish` part that
     * carries one. Returns `[]` when either end is missing.
     */
    async function computeDiff(input: { messages: MessageV2.WithParts[] }) {
      let from: string | undefined
      let to: string | undefined
      for (const item of input.messages) {
        for (const part of item.parts) {
          if (!from && part.type === "step-start" && part.snapshot) from = part.snapshot
          // No early break: a later step-finish, even in the same message, must
          // win so that `to` really is the latest snapshot.
          if (part.type === "step-finish" && part.snapshot) to = part.snapshot
        }
      }
      if (from && to) return Snapshot.diffFull(from, to)
      return []
    }
  }