compaction.ts 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. import { BusEvent } from "@/bus/bus-event"
  2. import { Bus } from "@/bus"
  3. import { Session } from "."
  4. import { Identifier } from "../id/id"
  5. import { Instance } from "../project/instance"
  6. import { Provider } from "../provider/provider"
  7. import { MessageV2 } from "./message-v2"
  8. import z from "zod"
  9. import { SessionPrompt } from "./prompt"
  10. import { Token } from "../util/token"
  11. import { Log } from "../util/log"
  12. import { SessionProcessor } from "./processor"
  13. import { fn } from "@/util/fn"
  14. import { Agent } from "@/agent/agent"
  15. import { Plugin } from "@/plugin"
  16. import { Config } from "@/config/config"
  17. export namespace SessionCompaction {
  /** Logger shared by every function in this namespace. */
  const log = Log.create({ service: "session.compaction" })

  /** Payload carried by the `session.compacted` event: the id of the affected session. */
  const compactedPayload = z.object({
    sessionID: z.string(),
  })

  /** Bus events defined by the compaction module. */
  export const Event = {
    /** Published by `process` when a compaction run finishes without a processor error. */
    Compacted: BusEvent.define("session.compacted", compactedPayload),
  }
  /**
   * Decides whether the token usage recorded on an assistant message exceeds
   * the usable part of the model's context window.
   *
   * The usable window is `limit.context` minus the space reserved for output.
   * The reservation is `limit.output` capped at `SessionPrompt.OUTPUT_TOKEN_MAX`;
   * when that value is falsy (e.g. an output limit of 0) the cap itself is used.
   * Usage is counted as `input + cache.read + output`; `reasoning` and
   * `cache.write` are not part of the sum.
   *
   * @param input.tokens Token counters of the assistant message to check.
   * @param input.model Model whose `limit.context` / `limit.output` apply.
   * @returns `false` when `compaction.auto` is explicitly `false` in the config
   *   or when the context limit is 0; otherwise whether usage is strictly
   *   greater than the usable window.
   */
  export async function isOverflow(input: { tokens: MessageV2.Assistant["tokens"]; model: Provider.Model }) {
    const { compaction } = await Config.get()
    if (compaction?.auto === false) return false

    const { limit } = input.model
    // NOTE(review): a context limit of 0 presumably means "unknown" — confirm with the provider data.
    if (limit.context === 0) return false

    const { tokens } = input
    const used = tokens.input + tokens.cache.read + tokens.output

    const cappedOutput = Math.min(limit.output, SessionPrompt.OUTPUT_TOKEN_MAX)
    const reserved = cappedOutput ? cappedOutput : SessionPrompt.OUTPUT_TOKEN_MAX

    return used > limit.context - reserved
  }
  /** Pruning is only applied when more than this many estimated tokens would be marked. */
  export const PRUNE_MINIMUM = 20_000

  /** Estimated tokens of the newest eligible tool output that are left untouched. */
  export const PRUNE_PROTECT = 40_000

  /** Tools whose completed calls are never counted or marked by `prune`. */
  const PRUNE_PROTECTED_TOOLS = ["skill"]

  /**
   * Marks old tool-call output as compacted so that stale results can be
   * dropped from the conversation.
   *
   * The session's messages are walked from the end of the list towards the
   * start (assumes `Session.messages` returns them oldest-first — confirm).
   * Messages are skipped until two user messages have been seen. From there on,
   * every completed tool part (except protected tools) has its output size
   * estimated and added to a running total; once the total exceeds
   * `PRUNE_PROTECT`, further parts become pruning candidates. The walk stops at
   * an assistant summary message or at a part that is already marked compacted.
   *
   * Candidates are only written back when their combined estimate is above
   * `PRUNE_MINIMUM`: each one gets `state.time.compacted` set to the current
   * time and is persisted via `Session.updatePart`, one after another.
   *
   * Does nothing when `compaction.prune` is explicitly `false` in the config.
   *
   * @param input.sessionID Session whose messages are inspected.
   */
  export async function prune(input: { sessionID: string }) {
    const { compaction } = await Config.get()
    if (compaction?.prune === false) return

    log.info("pruning")
    const msgs = await Session.messages({ sessionID: input.sessionID })

    let seenTokens = 0
    let reclaimable = 0
    let userTurns = 0
    const candidates = []

    scan: for (let m = msgs.length - 1; m >= 0; m--) {
      const { info, parts } = msgs[m]
      if (info.role === "user") userTurns += 1
      // leave the latest exchange alone
      if (userTurns < 2) continue
      // everything before a summary has already been compacted away
      if (info.role === "assistant" && info.summary) break scan

      for (let p = parts.length - 1; p >= 0; p--) {
        const part = parts[p]
        if (part.type !== "tool") continue
        if (part.state.status !== "completed") continue
        if (PRUNE_PROTECTED_TOOLS.includes(part.tool)) continue
        // reached the region handled by an earlier prune
        if (part.state.time.compacted) break scan

        const size = Token.estimate(part.state.output)
        seenTokens += size
        if (seenTokens > PRUNE_PROTECT) {
          reclaimable += size
          candidates.push(part)
        }
      }
    }

    log.info("found", { pruned: reclaimable, total: seenTokens })
    if (!(reclaimable > PRUNE_MINIMUM)) return

    for (const part of candidates) {
      if (part.state.status !== "completed") continue
      part.state.time.compacted = Date.now()
      await Session.updatePart(part)
    }
    log.info("pruned", { count: candidates.length })
  }
  /**
   * Stores a new user message for the session, reusing the agent and model of
   * `origin`, and attaches a synthetic text part asking the assistant to
   * continue.
   */
  async function requestContinuation(sessionID: string, origin: MessageV2.User) {
    const followUp = await Session.updateMessage({
      id: Identifier.ascending("message"),
      role: "user",
      sessionID,
      time: {
        created: Date.now(),
      },
      agent: origin.agent,
      model: origin.model,
    })
    await Session.updatePart({
      id: Identifier.ascending("part"),
      messageID: followUp.id,
      sessionID,
      type: "text",
      synthetic: true,
      text: "Continue if you have next steps",
      time: {
        start: Date.now(),
        end: Date.now(),
      },
    })
  }

  /**
   * Runs one compaction pass for a session.
   *
   * Steps, in order:
   * 1. Looks up the message whose id equals `parentID` in `messages` and treats
   *    it as the triggering user message (a missing message makes this throw).
   * 2. Resolves the model: the "compaction" agent's own model when it has one,
   *    otherwise the model recorded on the triggering user message.
   * 3. Stores a fresh assistant message flagged `summary: true` with zeroed
   *    cost and token counters, and builds a `SessionProcessor` around it.
   * 4. Triggers the `experimental.session.compacting` plugin hook, which may
   *    supply extra context strings or a full replacement prompt.
   * 5. Sends the existing conversation plus the compaction prompt through the
   *    processor with no tools and no system prompt.
   * 6. When the processor reports "continue" and `auto` is set, stores a
   *    synthetic user message asking the assistant to carry on.
   *
   * @param input.parentID Id of the user message that triggered compaction.
   * @param input.messages Conversation to summarise; must contain `parentID`.
   * @param input.sessionID Session the new messages belong to.
   * @param input.abort Signal forwarded to the processor.
   * @param input.auto Whether a follow-up "continue" message should be queued.
   * @returns `"stop"` when the processor's message carries an error; otherwise
   *   publishes `Event.Compacted` and returns `"continue"`.
   */
  export async function process(input: {
    parentID: string
    messages: MessageV2.WithParts[]
    sessionID: string
    abort: AbortSignal
    auto: boolean
  }) {
    const parent = input.messages.findLast((candidate) => candidate.info.id === input.parentID)
    const userMessage = parent!.info as MessageV2.User

    const agent = await Agent.get("compaction")
    const modelRef = agent.model ? agent.model : userMessage.model
    const model = await Provider.getModel(modelRef.providerID, modelRef.modelID)

    const summary = (await Session.updateMessage({
      id: Identifier.ascending("message"),
      role: "assistant",
      parentID: input.parentID,
      sessionID: input.sessionID,
      mode: "compaction",
      agent: "compaction",
      summary: true,
      path: {
        cwd: Instance.directory,
        root: Instance.worktree,
      },
      cost: 0,
      tokens: {
        output: 0,
        input: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: model.id,
      providerID: model.providerID,
      time: {
        created: Date.now(),
      },
    })) as MessageV2.Assistant

    const runner = SessionProcessor.create({
      assistantMessage: summary,
      sessionID: input.sessionID,
      model,
      abort: input.abort,
    })

    // Allow plugins to inject context or replace compaction prompt
    const hook = await Plugin.trigger(
      "experimental.session.compacting",
      { sessionID: input.sessionID },
      { context: [], prompt: undefined },
    )
    const fallbackPrompt =
      "Provide a detailed prompt for continuing our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next considering new session will not have access to our conversation."
    // A plugin-provided prompt wins outright; otherwise plugin context is appended to the default.
    const instruction = hook.prompt ?? [fallbackPrompt, ...hook.context].join("\n\n")

    const outcome = await runner.process({
      user: userMessage,
      agent,
      abort: input.abort,
      sessionID: input.sessionID,
      tools: {},
      system: [],
      messages: [
        ...MessageV2.toModelMessage(input.messages),
        {
          role: "user",
          content: [
            {
              type: "text",
              text: instruction,
            },
          ],
        },
      ],
      model,
    })

    const shouldResume = outcome === "continue" && input.auto
    if (shouldResume) await requestContinuation(input.sessionID, userMessage)

    if (runner.message.error) return "stop"
    Bus.publish(Event.Compacted, { sessionID: input.sessionID })
    return "continue"
  }
  /** Input accepted by `create`. */
  const createInput = z.object({
    sessionID: Identifier.schema("session"),
    agent: z.string(),
    model: z.object({
      providerID: z.string(),
      modelID: z.string(),
    }),
    auto: z.boolean(),
  })

  /**
   * Records a compaction request in a session: stores a new user message with
   * the given agent and model, then attaches a part of type "compaction" that
   * carries the `auto` flag.
   *
   * @param input.sessionID Session to add the message to.
   * @param input.agent Agent name stored on the user message.
   * @param input.model Provider/model pair stored on the user message.
   * @param input.auto Stored on the compaction part.
   */
  export const create = fn(createInput, async ({ sessionID, agent, model, auto }) => {
    const request = await Session.updateMessage({
      id: Identifier.ascending("message"),
      role: "user",
      model,
      sessionID,
      agent,
      time: {
        created: Date.now(),
      },
    })
    await Session.updatePart({
      id: Identifier.ascending("part"),
      messageID: request.id,
      sessionID: request.sessionID,
      type: "compaction",
      auto,
    })
  })
  214. }