compaction.ts 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. import { wrapLanguageModel, type ModelMessage } from "ai"
  2. import { Session } from "."
  3. import { Identifier } from "../id/id"
  4. import { Instance } from "../project/instance"
  5. import { Provider } from "../provider/provider"
  6. import { MessageV2 } from "./message-v2"
  7. import { SystemPrompt } from "./system"
  8. import { Bus } from "../bus"
  9. import z from "zod"
  10. import { SessionPrompt } from "./prompt"
  11. import { Flag } from "../flag/flag"
  12. import { Token } from "../util/token"
  13. import { Config } from "../config/config"
  14. import { Log } from "../util/log"
  15. import { ProviderTransform } from "@/provider/transform"
  16. import { SessionProcessor } from "./processor"
  17. import { fn } from "@/util/fn"
  18. import { mergeDeep, pipe } from "remeda"
  19. export namespace SessionCompaction {
  // Logger used by everything in this namespace; tagged with the service name below.
  const log = Log.create({ service: "session.compaction" })

  /** Bus events declared by this namespace. */
  export const Event = {
    /**
     * Event named "session.compacted"; its payload is an object holding the
     * session id as a string.
     */
    Compacted: Bus.event("session.compacted", z.object({ sessionID: z.string() })),
  }
  /**
   * Reports whether a message's token usage exceeds what the model can hold.
   *
   * Always returns `false` when `Flag.OPENCODE_DISABLE_AUTOCOMPACT` is set or
   * when the model's context limit is 0.
   *
   * @param input.tokens Token counters; `input`, `cache.read` and `output` are summed.
   * @param input.model Model whose `limit.context` and `limit.output` define the budget.
   * @returns `true` when the summed tokens are greater than the context limit
   *   minus the space reserved for output.
   */
  export function isOverflow(input: { tokens: MessageV2.Assistant["tokens"]; model: Provider.Model }) {
    if (Flag.OPENCODE_DISABLE_AUTOCOMPACT) return false

    const { tokens, model } = input
    const contextLimit = model.limit.context
    if (contextLimit === 0) return false

    // Space reserved for output: the model's output limit capped at
    // OUTPUT_TOKEN_MAX. Because `||` is used, a falsy result (e.g. an output
    // limit of 0) falls back to OUTPUT_TOKEN_MAX.
    const cappedOutput = Math.min(model.limit.output, SessionPrompt.OUTPUT_TOKEN_MAX)
    const reserved = cappedOutput || SessionPrompt.OUTPUT_TOKEN_MAX

    const used = tokens.input + tokens.cache.read + tokens.output
    return used > contextLimit - reserved
  }
  /** Tool-output tokens (estimated) that must be prunable before anything is marked. */
  export const PRUNE_MINIMUM = 20_000
  /** Most recent tool-output tokens (estimated) that are left untouched by the scan. */
  export const PRUNE_PROTECT = 40_000

  /**
   * Marks old tool-call outputs of a session as compacted.
   *
   * Walks the session's messages from newest to oldest, and the parts of each
   * message from last to first. Nothing is examined until the second user
   * message (counted from the end) has been reached. From there, the output
   * of every completed tool part is measured with `Token.estimate`; once the
   * running total exceeds `PRUNE_PROTECT`, each further part is queued.
   *
   * The walk ends early at an assistant message flagged `summary`, or at a
   * completed tool part that already carries `time.compacted`.
   *
   * Only when the queued parts add up to more than `PRUNE_MINIMUM` estimated
   * tokens are they stamped with `time.compacted = Date.now()` and saved via
   * `Session.updatePart`, one after another.
   *
   * Does nothing when `Flag.OPENCODE_DISABLE_PRUNE` is set.
   *
   * @param input.sessionID Session whose messages are scanned.
   */
  export async function prune(input: { sessionID: string }) {
    if (Flag.OPENCODE_DISABLE_PRUNE) return
    log.info("pruning")

    const history = await Session.messages({ sessionID: input.sessionID })
    const candidates = []
    let total = 0
    let pruned = 0
    let userTurns = 0

    let cursor = history.length
    scan: while (cursor-- > 0) {
      const message = history[cursor]
      if (message.info.role === "user") userTurns += 1
      // Leave everything newer than the second-to-last user message alone.
      if (userTurns < 2) continue
      // A summary message ends the scan.
      if (message.info.role === "assistant" && message.info.summary) break scan

      for (let at = message.parts.length - 1; at >= 0; at--) {
        const part = message.parts[at]
        if (part.type !== "tool" || part.state.status !== "completed") continue
        // An already-compacted part also ends the scan.
        if (part.state.time.compacted) break scan

        const cost = Token.estimate(part.state.output)
        total += cost
        if (total <= PRUNE_PROTECT) continue
        pruned += cost
        candidates.push(part)
      }
    }

    log.info("found", { pruned, total })
    if (pruned <= PRUNE_MINIMUM) return

    for (const part of candidates) {
      if (part.state.status !== "completed") continue
      part.state.time.compacted = Date.now()
      await Session.updatePart(part)
    }
    log.info("pruned", { count: candidates.length })
  }
  /**
   * Produces a summary of a conversation as a new assistant message.
   *
   * Steps, in order:
   * 1. Loads config, resolves the model and its language model, and collects
   *    the compaction system prompts for the model's provider.
   * 2. Saves a fresh assistant message flagged `summary: true` with zeroed
   *    cost and token counters.
   * 3. Hands a prompt to a `SessionProcessor` made of the system prompts, the
   *    filtered `input.messages`, and a final user instruction asking for a
   *    summary.
   * 4. When the processor returns "continue" and `input.auto` is set, saves a
   *    user message with a synthetic text part reading
   *    "Continue if you have next steps". This happens before the error check
   *    in step 5.
   * 5. Returns "stop" if the processor's message carries an error; otherwise
   *    publishes `Event.Compacted` and returns "continue".
   *
   * @param input.parentID Stored as `parentID` on the summary message.
   * @param input.messages Conversation to summarize.
   * @param input.sessionID Session the new messages and parts belong to.
   * @param input.model Provider/model ids used to resolve the model.
   * @param input.agent Stored as `mode` on the summary message and as `agent`
   *   on the follow-up user message.
   * @param input.abort Signal passed to the processor and to the request.
   * @param input.auto Whether to add the follow-up "continue" user message.
   * @returns "stop" or "continue" as described in step 5.
   */
  export async function process(input: {
    parentID: string
    messages: MessageV2.WithParts[]
    sessionID: string
    model: {
      providerID: string
      modelID: string
    }
    agent: string
    abort: AbortSignal
    auto: boolean
  }) {
    const cfg = await Config.get()
    const model = await Provider.getModel(input.model.providerID, input.model.modelID)
    const language = await Provider.getLanguage(model)
    const systemTexts = [...SystemPrompt.compaction(model.providerID)]

    // Assistant message that will hold the summary.
    const summaryMessage = (await Session.updateMessage({
      id: Identifier.ascending("message"),
      role: "assistant",
      parentID: input.parentID,
      sessionID: input.sessionID,
      mode: input.agent,
      summary: true,
      path: {
        cwd: Instance.directory,
        root: Instance.worktree,
      },
      cost: 0,
      tokens: {
        output: 0,
        input: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.model.modelID,
      providerID: model.providerID,
      time: {
        created: Date.now(),
      },
    })) as MessageV2.Assistant

    const processor = SessionProcessor.create({
      assistantMessage: summaryMessage,
      sessionID: input.sessionID,
      model,
      abort: input.abort,
    })

    // Transform-derived options first, then the model's own options merged on top.
    const providerOptions = ProviderTransform.providerOptions(
      model,
      pipe({}, mergeDeep(ProviderTransform.options(model, input.sessionID)), mergeDeep(model.options)),
    )

    // Which messages go into the prompt: everything except assistant messages
    // that carry an error. An errored assistant message is still kept when the
    // error is an AbortedError and the message has at least one part that is
    // neither "step-start" nor "reasoning".
    const keepInPrompt = (entry: MessageV2.WithParts) => {
      const info = entry.info
      if (info.role !== "assistant") return true
      if (info.error === undefined) return true
      if (!MessageV2.AbortedError.isInstance(info.error)) return false
      return entry.parts.some((part) => !(part.type === "step-start" || part.type === "reasoning"))
    }

    const result = await processor.process({
      onError(error) {
        log.error("stream error", { error })
      },
      // set to 0, we handle loop
      maxRetries: 0,
      providerOptions,
      headers: model.headers,
      abortSignal: input.abort,
      tools: model.capabilities.toolcall ? {} : undefined,
      messages: [
        ...systemTexts.map((text): ModelMessage => ({ role: "system", content: text })),
        ...MessageV2.toModelMessage(input.messages.filter(keepInPrompt)),
        {
          role: "user",
          content: [
            {
              type: "text",
              text: "Summarize our conversation above. This summary will be the only context available when the conversation continues, so preserve critical information including: what was accomplished, current work in progress, files involved, next steps, and any key user requests or constraints. Be concise but detailed enough that work can continue seamlessly.",
            },
          ],
        },
      ],
      model: wrapLanguageModel({
        model: language,
        middleware: [
          {
            // Only "stream" calls get their prompt rewritten by ProviderTransform.message.
            async transformParams(args) {
              if (args.type !== "stream") return args.params
              // @ts-expect-error
              args.params.prompt = ProviderTransform.message(args.params.prompt, model)
              return args.params
            },
          },
        ],
      }),
      experimental_telemetry: { isEnabled: cfg.experimental?.openTelemetry },
    })

    if (input.auto && result === "continue") {
      const followUp = await Session.updateMessage({
        id: Identifier.ascending("message"),
        role: "user",
        sessionID: input.sessionID,
        time: {
          created: Date.now(),
        },
        agent: input.agent,
        model: input.model,
      })
      await Session.updatePart({
        id: Identifier.ascending("part"),
        messageID: followUp.id,
        sessionID: input.sessionID,
        type: "text",
        synthetic: true,
        text: "Continue if you have next steps",
        time: {
          start: Date.now(),
          end: Date.now(),
        },
      })
    }

    if (processor.message.error) return "stop"
    Bus.publish(Event.Compacted, { sessionID: input.sessionID })
    return "continue"
  }
  /**
   * Records a compaction request in a session.
   *
   * Saves a new user message for the given session, agent and model, then
   * attaches a part of type "compaction" to it that carries the `auto` flag.
   * The handler itself resolves with no value.
   *
   * Input (validated by the schema passed to `fn`):
   * - `sessionID`: session identifier
   * - `agent`: agent name stored on the user message
   * - `model`: `{ providerID, modelID }` stored on the user message
   * - `auto`: stored on the compaction part
   */
  export const create = fn(
    z.object({
      sessionID: Identifier.schema("session"),
      agent: z.string(),
      model: z.object({ providerID: z.string(), modelID: z.string() }),
      auto: z.boolean(),
    }),
    async ({ sessionID, agent, model, auto }) => {
      const request = await Session.updateMessage({
        id: Identifier.ascending("message"),
        role: "user",
        model,
        sessionID,
        agent,
        time: { created: Date.now() },
      })

      await Session.updatePart({
        id: Identifier.ascending("part"),
        messageID: request.id,
        sessionID: request.sessionID,
        type: "compaction",
        auto,
      })
    },
  )
  247. }