compaction.ts 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. import { streamText, type ModelMessage, LoadAPIKeyError } from "ai"
  2. import { Session } from "."
  3. import { Identifier } from "../id/id"
  4. import { Instance } from "../project/instance"
  5. import { Provider } from "../provider/provider"
  6. import { defer } from "../util/defer"
  7. import { MessageV2 } from "./message-v2"
  8. import { SystemPrompt } from "./system"
  9. import { Bus } from "../bus"
  10. import z from "zod/v4"
  11. import type { ModelsDev } from "../provider/models"
  12. import { SessionPrompt } from "./prompt"
  13. import { Flag } from "../flag/flag"
  14. import { Token } from "../util/token"
  15. import { Log } from "../util/log"
  16. import { SessionLock } from "./lock"
  17. import { NamedError } from "../util/error"
  18. export namespace SessionCompaction {
  // Logger shared by every function in this namespace; entries are tagged
  // with the "session.compaction" service name.
  const log = Log.create({ service: "session.compaction" })

  /** Bus event definitions exposed by the compaction namespace. */
  export const Event = {
    /** Payload carries only the ID of the session that was compacted. */
    Compacted: Bus.event("session.compacted", z.object({ sessionID: z.string() })),
  }
  /**
   * Decides whether the token usage of an assistant message no longer fits in
   * the model's context window once room for the next response is reserved.
   *
   * @param input.tokens Token counters of the assistant message; input,
   *   cached-read and output tokens are summed.
   * @param input.model Model description providing `limit.context` and
   *   `limit.output`.
   * @returns `true` when the summed tokens exceed `limit.context` minus the
   *   reserved output budget. Always `false` when
   *   `Flag.OPENCODE_DISABLE_AUTOCOMPACT` is set or the model reports a
   *   context limit of `0`.
   */
  export function isOverflow(input: { tokens: MessageV2.Assistant["tokens"]; model: ModelsDev.Model }) {
    if (Flag.OPENCODE_DISABLE_AUTOCOMPACT) return false

    const { tokens, model } = input
    const contextLimit = model.limit.context
    // A context limit of 0 is treated as "no limit known": never overflow.
    if (contextLimit === 0) return false

    const used = tokens.input + tokens.cache.read + tokens.output
    // Reserve the model's output limit, capped at OUTPUT_TOKEN_MAX. The `||`
    // is deliberate: a falsy result (output limit of 0) falls back to the cap.
    const reserved = Math.min(model.limit.output, SessionPrompt.OUTPUT_TOKEN_MAX) || SessionPrompt.OUTPUT_TOKEN_MAX

    return used > contextLimit - reserved
  }
  /** A prune pass only writes anything when more than this many estimated tokens can be reclaimed. */
  export const PRUNE_MINIMUM = 20_000
  /** Estimated tokens of the most recent completed tool output that are never marked for pruning. */
  export const PRUNE_PROTECT = 40_000

  /**
   * Marks the output of old tool calls in a session as compacted.
   *
   * Messages and their parts are scanned from newest to oldest. Nothing is
   * counted until the second user message (counting from the end) has been
   * reached. From there on, the estimated size of every completed tool output
   * is accumulated; once the running total exceeds `PRUNE_PROTECT`, each
   * further tool part becomes a prune candidate. The scan stops at a summary
   * message or at a tool part that was already compacted.
   *
   * Candidates are only written back (with `state.time.compacted` set to the
   * current time) when their combined estimate exceeds `PRUNE_MINIMUM`.
   *
   * Does nothing when `Flag.OPENCODE_DISABLE_PRUNE` is set.
   *
   * @param input.sessionID Session whose messages are scanned.
   */
  export async function prune(input: { sessionID: string }) {
    if (Flag.OPENCODE_DISABLE_PRUNE) return
    log.info("pruning")

    const history = await Session.messages(input.sessionID)
    const candidates = []
    let seenTokens = 0
    let reclaimable = 0
    let userTurns = 0

    scan: for (let m = history.length; m-- > 0; ) {
      const message = history[m]
      if (message.info.role === "user") userTurns++
      // Leave everything up to the second-most-recent user message untouched.
      if (userTurns < 2) continue
      // Anything at or before a summary message is not considered.
      if (message.info.role === "assistant" && message.info.summary) break scan

      for (let p = message.parts.length; p-- > 0; ) {
        const part = message.parts[p]
        if (part.type !== "tool") continue
        if (part.state.status !== "completed") continue
        // An already-compacted part ends the whole scan.
        if (part.state.time.compacted) break scan

        const size = Token.estimate(part.state.output)
        seenTokens += size
        if (seenTokens <= PRUNE_PROTECT) continue
        reclaimable += size
        candidates.push(part)
      }
    }

    log.info("found", { pruned: reclaimable, total: seenTokens })
    if (reclaimable <= PRUNE_MINIMUM) return

    for (const candidate of candidates) {
      // Re-check the status so the compiler narrows `state` to the completed variant.
      if (candidate.state.status !== "completed") continue
      candidate.state.time.compacted = Date.now()
      await Session.updatePart(candidate)
    }
    log.info("pruned", { count: candidates.length })
  }
  /**
   * Generates a summary of a session's not-yet-summarized messages and stores
   * it as a new assistant message with a single text part.
   *
   * Flow:
   * 1. When no `signal` is supplied, the session lock is asserted free and then
   *    acquired for the duration of the call; its signal is used for aborting.
   * 2. `time.compacting` is set on the session and cleared again on exit.
   * 3. The summarization request is streamed; text deltas are written to the
   *    part as they arrive and usage/cost is recorded on `finish-step`.
   * 4. Any failure is mapped onto `msg.error` and published as
   *    `Session.Event.Error`.
   * 5. The message is flagged as a summary (and `Event.Compacted` published)
   *    when there was no error or the error is an `AbortedError`.
   *
   * @param input.sessionID Session to compact.
   * @param input.providerID Provider used to resolve the model.
   * @param input.modelID Model used to produce the summary.
   * @param input.signal Optional abort signal. When provided, no session lock
   *   is taken by this call.
   * @returns The stored assistant message (`info`) and its text part (`parts`).
   */
  export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) {
    // NOTE(review): presumably throws when the session is already locked; confirm in SessionLock.
    if (!input.signal) SessionLock.assertUnlocked(input.sessionID)
    await using lock = input.signal === undefined ? SessionLock.acquire({ sessionID: input.sessionID }) : undefined
    // `lock` is only undefined when `input.signal` is set, so the assertion cannot fire.
    const signal = input.signal ?? lock!.signal

    await Session.update(input.sessionID, (draft) => {
      draft.time.compacting = Date.now()
    })
    // Clear the compacting marker on every exit path, including thrown errors.
    await using _ = defer(async () => {
      await Session.update(input.sessionID, (draft) => {
        draft.time.compacting = undefined
      })
    })

    const toSummarize = await Session.messages(input.sessionID).then(MessageV2.filterSummarized)
    const model = await Provider.getModel(input.providerID, input.modelID)
    const system = [
      ...SystemPrompt.summarize(model.providerID),
      ...(await SystemPrompt.environment()),
      ...(await SystemPrompt.custom()),
    ]

    // Create the (initially empty) assistant message that will hold the summary.
    const msg = (await Session.updateMessage({
      id: Identifier.ascending("message"),
      role: "assistant",
      sessionID: input.sessionID,
      system,
      mode: "build",
      path: {
        cwd: Instance.directory,
        root: Instance.worktree,
      },
      cost: 0,
      tokens: {
        output: 0,
        input: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.modelID,
      providerID: model.providerID,
      time: {
        created: Date.now(),
      },
    })) as MessageV2.Assistant

    // Single text part that accumulates the streamed summary.
    const part = (await Session.updatePart({
      type: "text",
      sessionID: input.sessionID,
      messageID: msg.id,
      id: Identifier.ascending("part"),
      text: "",
      time: {
        start: Date.now(),
      },
    })) as MessageV2.TextPart

    const stream = streamText({
      maxRetries: 10,
      model: model.language,
      providerOptions: {
        // Options for the "@ai-sdk/openai" package are always keyed as "openai".
        [model.npm === "@ai-sdk/openai" ? "openai" : model.providerID]: model.info.options,
      },
      abortSignal: signal,
      onError(error) {
        log.error("stream error", {
          error,
        })
      },
      messages: [
        ...system.map(
          (x): ModelMessage => ({
            role: "system",
            content: x,
          }),
        ),
        ...MessageV2.toModelMessage(toSummarize),
        {
          role: "user",
          content: [
            {
              type: "text",
              text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
            },
          ],
        },
      ],
    })

    try {
      for await (const value of stream.fullStream) {
        signal.throwIfAborted()
        switch (value.type) {
          case "text-delta":
            part.text += value.text
            if (value.providerMetadata) part.metadata = value.providerMetadata
            if (part.text) await Session.updatePart(part)
            continue
          case "text-end": {
            part.text = part.text.trimEnd()
            // Keep the start time recorded when the part was created; only the
            // end time is new. Overwriting `start` here would erase the duration.
            part.time = {
              start: part.time?.start ?? Date.now(),
              end: Date.now(),
            }
            if (value.providerMetadata) part.metadata = value.providerMetadata
            await Session.updatePart(part)
            continue
          }
          case "finish-step": {
            const usage = Session.getUsage({
              model: model.info,
              usage: value.usage,
              metadata: value.providerMetadata,
            })
            msg.cost += usage.cost
            msg.tokens = usage.tokens
            await Session.updateMessage(msg)
            continue
          }
          case "error":
            throw value.error
          default:
            continue
        }
      }
    } catch (e) {
      log.error("compaction error", {
        error: e,
      })
      // Map the failure onto the plain-object error shape stored on the message.
      switch (true) {
        case e instanceof DOMException && e.name === "AbortError":
          msg.error = new MessageV2.AbortedError(
            { message: e.message },
            {
              cause: e,
            },
          ).toObject()
          break
        case MessageV2.OutputLengthError.isInstance(e):
          msg.error = e
          break
        case LoadAPIKeyError.isInstance(e):
          msg.error = new MessageV2.AuthError(
            {
              providerID: model.providerID,
              message: e.message,
            },
            { cause: e },
          ).toObject()
          break
        case e instanceof Error:
          msg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
          break
        default:
          // Non-Error throwables: convert with toObject() like every other
          // branch so the stored/published error is a plain object.
          msg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e }).toObject()
      }
      Bus.publish(Session.Event.Error, {
        sessionID: input.sessionID,
        error: msg.error,
      })
    }

    msg.time.completed = Date.now()
    // An aborted run still counts as a summary; any other error does not.
    if (!msg.error || MessageV2.AbortedError.isInstance(msg.error)) {
      msg.summary = true
      Bus.publish(Event.Compacted, {
        sessionID: input.sessionID,
      })
    }
    await Session.updateMessage(msg)
    return {
      info: msg,
      parts: [part],
    }
  }
  248. }