// compaction.ts — session history compaction and pruning
  1. import { streamText, wrapLanguageModel, type ModelMessage } from "ai"
  2. import { Session } from "."
  3. import { Identifier } from "../id/id"
  4. import { Instance } from "../project/instance"
  5. import { Provider } from "../provider/provider"
  6. import { MessageV2 } from "./message-v2"
  7. import { SystemPrompt } from "./system"
  8. import { Bus } from "../bus"
  9. import z from "zod"
  10. import type { ModelsDev } from "../provider/models"
  11. import { SessionPrompt } from "./prompt"
  12. import { Flag } from "../flag/flag"
  13. import { Token } from "../util/token"
  14. import { Log } from "../util/log"
  15. import { ProviderTransform } from "@/provider/transform"
  16. import { SessionProcessor } from "./processor"
  17. import { fn } from "@/util/fn"
  18. import { mergeDeep, pipe } from "remeda"
  19. export namespace SessionCompaction {
  // Module-scoped logger for compaction/prune activity.
  const log = Log.create({ service: "session.compaction" })
  // Bus events published by this module.
  export const Event = {
    // Emitted after a session's history has been compacted into a summary.
    Compacted: Bus.event(
      "session.compacted",
      z.object({
        sessionID: z.string(),
      }),
    ),
  }
  29. export function isOverflow(input: { tokens: MessageV2.Assistant["tokens"]; model: ModelsDev.Model }) {
  30. if (Flag.OPENCODE_DISABLE_AUTOCOMPACT) return false
  31. const context = input.model.limit.context
  32. if (context === 0) return false
  33. const count = input.tokens.input + input.tokens.cache.read + input.tokens.output
  34. const output = Math.min(input.model.limit.output, SessionPrompt.OUTPUT_TOKEN_MAX) || SessionPrompt.OUTPUT_TOKEN_MAX
  35. const usable = context - output
  36. return count > usable
  37. }
  // Skip pruning entirely unless at least this many estimated tokens would be reclaimed.
  export const PRUNE_MINIMUM = 20_000
  // The most recent tool output (by estimated tokens) is never pruned.
  export const PRUNE_PROTECT = 40_000
  // goes backwards through parts until there are 40_000 tokens worth of tool
  // calls. then erases output of previous tool calls. idea is to throw away old
  // tool calls that are no longer relevant.
  43. export async function prune(input: { sessionID: string }) {
  44. if (Flag.OPENCODE_DISABLE_PRUNE) return
  45. log.info("pruning")
  46. const msgs = await Session.messages({ sessionID: input.sessionID })
  47. let total = 0
  48. let pruned = 0
  49. const toPrune = []
  50. let turns = 0
  51. loop: for (let msgIndex = msgs.length - 1; msgIndex >= 0; msgIndex--) {
  52. const msg = msgs[msgIndex]
  53. if (msg.info.role === "user") turns++
  54. if (turns < 2) continue
  55. if (msg.info.role === "assistant" && msg.info.summary) break loop
  56. for (let partIndex = msg.parts.length - 1; partIndex >= 0; partIndex--) {
  57. const part = msg.parts[partIndex]
  58. if (part.type === "tool")
  59. if (part.state.status === "completed") {
  60. if (part.state.time.compacted) break loop
  61. const estimate = Token.estimate(part.state.output)
  62. total += estimate
  63. if (total > PRUNE_PROTECT) {
  64. pruned += estimate
  65. toPrune.push(part)
  66. }
  67. }
  68. }
  69. }
  70. log.info("found", { pruned, total })
  71. if (pruned > PRUNE_MINIMUM) {
  72. for (const part of toPrune) {
  73. if (part.state.status === "completed") {
  74. part.state.time.compacted = Date.now()
  75. await Session.updatePart(part)
  76. }
  77. }
  78. log.info("pruned", { count: toPrune.length })
  79. }
  80. }
  /**
   * Runs one compaction turn: streams a model-generated summary of the
   * conversation into a fresh assistant message flagged `summary: true`, and —
   * when the processor reports "continue" — queues a synthetic user message
   * prompting the model to keep going.
   *
   * Returns "stop" if the processor recorded an error, otherwise "continue".
   */
  export async function process(input: {
    parentID: string
    messages: MessageV2.WithParts[]
    sessionID: string
    model: {
      providerID: string
      modelID: string
    }
    agent: string
    abort: AbortSignal
  }) {
    const model = await Provider.getModel(input.model.providerID, input.model.modelID)
    const system = [...SystemPrompt.compaction(model.providerID)]
    // Placeholder assistant message that will receive the summary output.
    // NOTE(review): this records model.providerID while the processor below is
    // given input.model.providerID — presumably identical after provider
    // resolution, but verify.
    const msg = (await Session.updateMessage({
      id: Identifier.ascending("message"),
      role: "assistant",
      parentID: input.parentID,
      sessionID: input.sessionID,
      mode: input.agent,
      summary: true,
      path: {
        cwd: Instance.directory,
        root: Instance.worktree,
      },
      cost: 0,
      tokens: {
        output: 0,
        input: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.model.modelID,
      providerID: model.providerID,
      time: {
        created: Date.now(),
      },
    })) as MessageV2.Assistant
    const processor = SessionProcessor.create({
      assistantMessage: msg,
      sessionID: input.sessionID,
      providerID: input.model.providerID,
      model: model.info,
      abort: input.abort,
    })
    const result = await processor.process(() =>
      streamText({
        onError(error) {
          log.error("stream error", {
            error,
          })
        },
        // set to 0, we handle loop
        maxRetries: 0,
        // Merge provider-level and model-level options before transforming
        // them into the provider-specific options shape.
        providerOptions: ProviderTransform.providerOptions(
          model.npm,
          model.providerID,
          pipe(
            {},
            mergeDeep(ProviderTransform.options(model.providerID, model.modelID, model.npm ?? "", input.sessionID)),
            mergeDeep(model.info.options),
          ),
        ),
        headers: model.info.headers,
        abortSignal: input.abort,
        // Empty toolset: summarization must not invoke tools.
        tools: model.info.tool_call ? {} : undefined,
        messages: [
          ...system.map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          // Drop errored assistant messages, except aborted ones that still
          // produced substantive parts (not just step-start/reasoning).
          ...MessageV2.toModelMessage(
            input.messages.filter((m) => {
              if (m.info.role !== "assistant" || m.info.error === undefined) {
                return true
              }
              if (
                MessageV2.AbortedError.isInstance(m.info.error) &&
                m.parts.some((part) => part.type !== "step-start" && part.type !== "reasoning")
              ) {
                return true
              }
              return false
            }),
          ),
          // The actual summarization instruction, appended as the final turn.
          {
            role: "user",
            content: [
              {
                type: "text",
                text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
              },
            ],
          },
        ],
        // Middleware applies provider-specific prompt rewrites before streaming.
        model: wrapLanguageModel({
          model: model.language,
          middleware: [
            {
              async transformParams(args) {
                if (args.type === "stream") {
                  // @ts-expect-error
                  args.params.prompt = ProviderTransform.message(args.params.prompt, model.providerID, model.modelID)
                }
                return args.params
              },
            },
          ],
        }),
      }),
    )
    // On "continue", seed the next turn with a synthetic user nudge.
    if (result === "continue") {
      const continueMsg = await Session.updateMessage({
        id: Identifier.ascending("message"),
        role: "user",
        sessionID: input.sessionID,
        time: {
          created: Date.now(),
        },
        agent: input.agent,
        model: input.model,
      })
      await Session.updatePart({
        id: Identifier.ascending("part"),
        messageID: continueMsg.id,
        sessionID: input.sessionID,
        type: "text",
        synthetic: true,
        text: "Continue if you have next steps",
        time: {
          start: Date.now(),
          end: Date.now(),
        },
      })
    }
    if (processor.message.error) return "stop"
    return "continue"
  }
  220. export const create = fn(
  221. z.object({
  222. sessionID: Identifier.schema("session"),
  223. agent: z.string(),
  224. model: z.object({
  225. providerID: z.string(),
  226. modelID: z.string(),
  227. }),
  228. }),
  229. async (input) => {
  230. const msg = await Session.updateMessage({
  231. id: Identifier.ascending("message"),
  232. role: "user",
  233. model: input.model,
  234. sessionID: input.sessionID,
  235. agent: input.agent,
  236. time: {
  237. created: Date.now(),
  238. },
  239. })
  240. await Session.updatePart({
  241. id: Identifier.ascending("part"),
  242. messageID: msg.id,
  243. sessionID: msg.sessionID,
  244. type: "compaction",
  245. })
  246. },
  247. )
  248. }