// compaction.ts
  1. import { streamText, type ModelMessage, LoadAPIKeyError, type StreamTextResult, type Tool as AITool } from "ai"
  2. import { Session } from "."
  3. import { Identifier } from "../id/id"
  4. import { Instance } from "../project/instance"
  5. import { Provider } from "../provider/provider"
  6. import { defer } from "../util/defer"
  7. import { MessageV2 } from "./message-v2"
  8. import { SystemPrompt } from "./system"
  9. import { Bus } from "../bus"
  10. import z from "zod"
  11. import type { ModelsDev } from "../provider/models"
  12. import { SessionPrompt } from "./prompt"
  13. import { Flag } from "../flag/flag"
  14. import { Token } from "../util/token"
  15. import { Log } from "../util/log"
  16. import { SessionLock } from "./lock"
  17. import { ProviderTransform } from "@/provider/transform"
  18. import { SessionRetry } from "./retry"
  19. import { Config } from "@/config/config"
// Context-window compaction for sessions: summarizes prior conversation
// into a single assistant message and prunes stale tool output so the
// model stays within its context limit.
export namespace SessionCompaction {
  const log = Log.create({ service: "session.compaction" })

  // Bus events owned by this module. `Compacted` is published after a
  // session has been successfully summarized (see run()).
  export const Event = {
    Compacted: Bus.event(
      "session.compacted",
      z.object({
        sessionID: z.string(),
      }),
    ),
  }
  30. export function isOverflow(input: { tokens: MessageV2.Assistant["tokens"]; model: ModelsDev.Model }) {
  31. if (Flag.OPENCODE_DISABLE_AUTOCOMPACT) return false
  32. const context = input.model.limit.context
  33. if (context === 0) return false
  34. const count = input.tokens.input + input.tokens.cache.read + input.tokens.output
  35. const output = Math.min(input.model.limit.output, SessionPrompt.OUTPUT_TOKEN_MAX) || SessionPrompt.OUTPUT_TOKEN_MAX
  36. const usable = context - output
  37. return count > usable
  38. }
  // Pruning is only committed when at least this many estimated tokens
  // would be reclaimed (see prune()).
  export const PRUNE_MINIMUM = 20_000
  // The most recent ~40k tokens of completed tool output are never pruned.
  export const PRUNE_PROTECT = 40_000
  // Default cap on stream retry attempts; overridable via
  // config experimental.chatMaxRetries (see run()).
  const MAX_RETRIES = 10
  42. // goes backwards through parts until there are 40_000 tokens worth of tool
  43. // calls. then erases output of previous tool calls. idea is to throw away old
  44. // tool calls that are no longer relevant.
  45. export async function prune(input: { sessionID: string }) {
  46. if (Flag.OPENCODE_DISABLE_PRUNE) return
  47. log.info("pruning")
  48. const msgs = await Session.messages(input.sessionID)
  49. let total = 0
  50. let pruned = 0
  51. const toPrune = []
  52. let turns = 0
  53. loop: for (let msgIndex = msgs.length - 1; msgIndex >= 0; msgIndex--) {
  54. const msg = msgs[msgIndex]
  55. if (msg.info.role === "user") turns++
  56. if (turns < 2) continue
  57. if (msg.info.role === "assistant" && msg.info.summary) break loop
  58. for (let partIndex = msg.parts.length - 1; partIndex >= 0; partIndex--) {
  59. const part = msg.parts[partIndex]
  60. if (part.type === "tool")
  61. if (part.state.status === "completed") {
  62. if (part.state.time.compacted) break loop
  63. const estimate = Token.estimate(part.state.output)
  64. total += estimate
  65. if (total > PRUNE_PROTECT) {
  66. pruned += estimate
  67. toPrune.push(part)
  68. }
  69. }
  70. }
  71. }
  72. log.info("found", { pruned, total })
  73. if (pruned > PRUNE_MINIMUM) {
  74. for (const part of toPrune) {
  75. if (part.state.status === "completed") {
  76. part.state.time.compacted = Date.now()
  77. await Session.updatePart(part)
  78. }
  79. }
  80. log.info("pruned", { count: toPrune.length })
  81. }
  82. }
  /**
   * Summarizes a session's (not yet compacted) history into a new assistant
   * message flagged `summary: true`.
   *
   * Locking: when no external `signal` is passed, acquires the session lock
   * itself (and asserts it is free first); an externally supplied signal
   * implies the caller already holds/manages the lock.
   *
   * Marks the session as `compacting` for the duration (cleared by the
   * deferred block even on error), streams the summary from the model with
   * manual retry/backoff, and publishes Event.Compacted on success.
   *
   * Returns the summary message info plus its parts.
   */
  export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) {
    if (!input.signal) SessionLock.assertUnlocked(input.sessionID)
    // `await using` releases the lock automatically at scope exit.
    await using lock = input.signal === undefined ? SessionLock.acquire({ sessionID: input.sessionID }) : undefined
    // `lock!` is safe: lock is undefined only when input.signal is defined.
    const signal = input.signal ?? lock!.signal
    await Session.update(input.sessionID, (draft) => {
      draft.time.compacting = Date.now()
    })
    // Always clear the compacting flag, even if streaming throws.
    await using _ = defer(async () => {
      await Session.update(input.sessionID, (draft) => {
        draft.time.compacting = undefined
      })
    })
    // Only summarize messages not covered by a previous compaction.
    const toSummarize = await Session.messages(input.sessionID).then(MessageV2.filterCompacted)
    const model = await Provider.getModel(input.providerID, input.modelID)
    const system = [
      ...SystemPrompt.summarize(model.providerID),
      ...(await SystemPrompt.environment()),
      ...(await SystemPrompt.custom()),
    ]
    // The summary lives in a fresh assistant message parented to the most
    // recent user message; `!` assumes at least one user message exists.
    const msg = (await Session.updateMessage({
      id: Identifier.ascending("message"),
      role: "assistant",
      parentID: toSummarize.findLast((m) => m.info.role === "user")?.info.id!,
      sessionID: input.sessionID,
      system,
      mode: "build",
      path: {
        cwd: Instance.directory,
        root: Instance.worktree,
      },
      summary: true,
      cost: 0,
      tokens: {
        output: 0,
        input: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.modelID,
      providerID: model.providerID,
      time: {
        created: Date.now(),
      },
    })) as MessageV2.Assistant
    // Single text part that accumulates the streamed summary.
    const part = (await Session.updatePart({
      type: "text",
      sessionID: input.sessionID,
      messageID: msg.id,
      id: Identifier.ascending("part"),
      text: "",
      time: {
        start: Date.now(),
      },
    })) as MessageV2.TextPart
    const doStream = () =>
      streamText({
        // set to 0, we handle loop
        maxRetries: 0,
        model: model.language,
        providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
        headers: model.info.headers,
        abortSignal: signal,
        onError(error) {
          log.error("stream error", {
            error,
          })
        },
        // Empty tools object (not undefined) when the model supports tool
        // calls — presumably required by some providers; verify upstream.
        tools: model.info.tool_call ? {} : undefined,
        messages: [
          ...system.map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          ...MessageV2.toModelMessage(toSummarize),
          {
            role: "user",
            content: [
              {
                type: "text",
                text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
              },
            ],
          },
        ],
      })
    // TODO: reduce duplication between compaction.ts & prompt.ts
    // Drains one stream attempt into `part`/`msg`; reports whether the
    // caller should retry (retryable API error under the attempt cap).
    const process = async (
      stream: StreamTextResult<Record<string, AITool>, never>,
      retries: { count: number; max: number },
    ) => {
      let shouldRetry = false
      try {
        for await (const value of stream.fullStream) {
          signal.throwIfAborted()
          switch (value.type) {
            case "text-delta":
              part.text += value.text
              if (value.providerMetadata) part.metadata = value.providerMetadata
              if (part.text)
                await Session.updatePart({
                  part,
                  delta: value.text,
                })
              continue
            case "text-end": {
              part.text = part.text.trimEnd()
              part.time = {
                start: Date.now(),
                end: Date.now(),
              }
              if (value.providerMetadata) part.metadata = value.providerMetadata
              await Session.updatePart(part)
              continue
            }
            case "finish-step": {
              // Accumulate cost; token counts are overwritten per step.
              const usage = Session.getUsage({
                model: model.info,
                usage: value.usage,
                metadata: value.providerMetadata,
              })
              msg.cost += usage.cost
              msg.tokens = usage.tokens
              await Session.updateMessage(msg)
              continue
            }
            case "error":
              // Surface stream errors through the catch below.
              throw value.error
            default:
              continue
          }
        }
      } catch (e) {
        log.error("compaction error", {
          error: e,
        })
        const error = MessageV2.fromError(e, { providerID: input.providerID })
        if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
          // Record a retry part so the UI/backoff logic can see the attempt.
          shouldRetry = true
          await Session.updatePart({
            id: Identifier.ascending("part"),
            messageID: msg.id,
            sessionID: msg.sessionID,
            type: "retry",
            attempt: retries.count + 1,
            time: {
              created: Date.now(),
            },
            error,
          })
        } else {
          // Non-retryable (or out of attempts): attach error and broadcast.
          msg.error = error
          Bus.publish(Session.Event.Error, {
            sessionID: msg.sessionID,
            error: msg.error,
          })
        }
      }
      const parts = await Session.getParts(msg.id)
      return {
        info: msg,
        parts,
        shouldRetry,
      }
    }
    let stream = doStream()
    const cfg = await Config.get()
    const maxRetries = cfg.experimental?.chatMaxRetries ?? MAX_RETRIES
    let result = await process(stream, {
      count: 0,
      max: maxRetries,
    })
    if (result.shouldRetry) {
      // Manual retry loop with backoff derived from the last retry part.
      for (let retry = 1; retry < maxRetries; retry++) {
        const lastRetryPart = result.parts.findLast((p) => p.type === "retry")
        if (lastRetryPart) {
          const delayMs = SessionRetry.getRetryDelayInMs(lastRetryPart.error, retry)
          log.info("retrying with backoff", {
            attempt: retry,
            delayMs,
          })
          // sleep() rejects with AbortError when the signal fires; that is
          // converted into an AbortedError on the message and stops retrying.
          const stop = await SessionRetry.sleep(delayMs, signal)
            .then(() => false)
            .catch((error) => {
              if (error instanceof DOMException && error.name === "AbortError") {
                const err = new MessageV2.AbortedError(
                  { message: error.message },
                  {
                    cause: error,
                  },
                ).toObject()
                result.info.error = err
                Bus.publish(Session.Event.Error, {
                  sessionID: result.info.sessionID,
                  error: result.info.error,
                })
                return true
              }
              throw error
            })
          if (stop) break
        }
        stream = doStream()
        result = await process(stream, {
          count: retry,
          max: maxRetries,
        })
        if (!result.shouldRetry) {
          break
        }
      }
    }
    msg.time.completed = Date.now()
    // Publish success when there was no error, or when an abort still left
    // some usable summary text behind.
    if (
      !msg.error ||
      (MessageV2.AbortedError.isInstance(msg.error) &&
        result.parts.some((part) => part.type === "text" && part.text.length > 0))
    ) {
      msg.summary = true
      Bus.publish(Event.Compacted, {
        sessionID: input.sessionID,
      })
    }
    await Session.updateMessage(msg)
    return {
      info: msg,
      parts: result.parts,
    }
  }
}