compaction.ts 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. import { streamText, type ModelMessage, LoadAPIKeyError, type StreamTextResult, type Tool as AITool } from "ai"
  2. import { Session } from "."
  3. import { Identifier } from "../id/id"
  4. import { Instance } from "../project/instance"
  5. import { Provider } from "../provider/provider"
  6. import { defer } from "../util/defer"
  7. import { MessageV2 } from "./message-v2"
  8. import { SystemPrompt } from "./system"
  9. import { Bus } from "../bus"
  10. import z from "zod"
  11. import type { ModelsDev } from "../provider/models"
  12. import { SessionPrompt } from "./prompt"
  13. import { Flag } from "../flag/flag"
  14. import { Token } from "../util/token"
  15. import { Log } from "../util/log"
  16. import { SessionLock } from "./lock"
  17. import { ProviderTransform } from "@/provider/transform"
  18. import { SessionRetry } from "./retry"
  19. import { Config } from "@/config/config"
export namespace SessionCompaction {
  const log = Log.create({ service: "session.compaction" })

  /** Bus events published by this module. */
  export const Event = {
    // Emitted once a summary message for the session has been produced.
    Compacted: Bus.event(
      "session.compacted",
      z.object({
        sessionID: z.string(),
      }),
    ),
  }

  /**
   * Returns true when the given token usage no longer fits inside the model's
   * context window after reserving room for the model's reply.
   *
   * Counted tokens are input + cached-read + output. The reserved reply size is
   * the smaller of the model's output limit and OUTPUT_TOKEN_MAX; if that comes
   * out falsy (e.g. the model reports an output limit of 0), OUTPUT_TOKEN_MAX is
   * used as the fallback. Always false when autocompaction is disabled via flag,
   * or when the model reports a context limit of 0 (presumably "unknown" —
   * TODO confirm against ModelsDev semantics).
   */
  export function isOverflow(input: { tokens: MessageV2.Assistant["tokens"]; model: ModelsDev.Model }) {
    if (Flag.OPENCODE_DISABLE_AUTOCOMPACT) return false
    const context = input.model.limit.context
    if (context === 0) return false
    const count = input.tokens.input + input.tokens.cache.read + input.tokens.output
    // Math.min(...) is 0 (falsy) when the model's output limit is 0 → fall back.
    const output = Math.min(input.model.limit.output, SessionPrompt.OUTPUT_TOKEN_MAX) || SessionPrompt.OUTPUT_TOKEN_MAX
    const usable = context - output
    return count > usable
  }

  // Pruning only happens when it would reclaim more than this many (estimated) tokens.
  export const PRUNE_MINIMUM = 20_000
  // The most recent ~40k tokens of completed tool output are never pruned.
  export const PRUNE_PROTECT = 40_000
  // Default cap on summarize-stream retry attempts (overridable via config).
  const MAX_RETRIES = 10

  // goes backwards through parts until there are 40_000 tokens worth of tool
  // calls. then erases output of previous tool calls. idea is to throw away old
  // tool calls that are no longer relevant.
  /**
   * Marks old completed tool-call parts as compacted so their output is dropped.
   *
   * Walks messages newest-to-oldest, skipping the most recent user turn
   * entirely (turns < 2). Stops at the last summary message or at the first
   * part already marked compacted — everything older was handled previously.
   * Tool outputs are tallied via Token.estimate; once the running total passes
   * PRUNE_PROTECT, further (older) parts become prune candidates. Candidates
   * are only written back when the reclaimable total exceeds PRUNE_MINIMUM.
   * No-op when OPENCODE_DISABLE_PRUNE is set.
   */
  export async function prune(input: { sessionID: string }) {
    if (Flag.OPENCODE_DISABLE_PRUNE) return
    log.info("pruning")
    const msgs = await Session.messages(input.sessionID)
    let total = 0 // estimated tokens of all completed tool output seen so far
    let pruned = 0 // estimated tokens the candidate parts would reclaim
    const toPrune = []
    let turns = 0
    loop: for (let msgIndex = msgs.length - 1; msgIndex >= 0; msgIndex--) {
      const msg = msgs[msgIndex]
      if (msg.info.role === "user") turns++
      // Never prune anything from the current (most recent) user turn.
      if (turns < 2) continue
      // A summary message marks a compaction boundary; older history is already condensed.
      if (msg.info.role === "assistant" && msg.info.summary) break loop
      for (let partIndex = msg.parts.length - 1; partIndex >= 0; partIndex--) {
        const part = msg.parts[partIndex]
        if (part.type === "tool")
          if (part.state.status === "completed") {
            // Already-compacted part ⇒ everything older was pruned in a prior pass.
            if (part.state.time.compacted) break loop
            const estimate = Token.estimate(part.state.output)
            total += estimate
            // Protect the newest PRUNE_PROTECT tokens; older output is fair game.
            if (total > PRUNE_PROTECT) {
              pruned += estimate
              toPrune.push(part)
            }
          }
      }
    }
    log.info("found", { pruned, total })
    if (pruned > PRUNE_MINIMUM) {
      for (const part of toPrune) {
        if (part.state.status === "completed") {
          part.state.time.compacted = Date.now()
          await Session.updatePart(part)
        }
      }
      log.info("pruned", { count: toPrune.length })
    }
  }

  /**
   * Summarizes (compacts) a session's history into a new assistant message.
   *
   * Streams a summary request to the given provider/model over the session's
   * non-compacted messages, accumulating the response into a single text part.
   * Retries retryable API errors with backoff up to the configured maximum.
   * On success (or on abort when partial text was produced) the message is
   * flagged `summary: true` and Event.Compacted is published.
   *
   * Locking: when no external `signal` is passed, the session lock is acquired
   * for the duration of the call (and its AbortSignal drives cancellation);
   * callers that pass a `signal` are assumed to already hold the lock.
   * `session.time.compacting` is set for the duration and cleared via defer,
   * even on error.
   *
   * @returns the summary message info plus its parts.
   */
  export async function run(input: { sessionID: string; providerID: string; modelID: string; signal?: AbortSignal }) {
    if (!input.signal) SessionLock.assertUnlocked(input.sessionID)
    // Lock is disposed automatically at scope exit (`await using`).
    await using lock = input.signal === undefined ? SessionLock.acquire({ sessionID: input.sessionID }) : undefined
    const signal = input.signal ?? lock!.signal
    await Session.update(input.sessionID, (draft) => {
      draft.time.compacting = Date.now()
    })
    // Always clear the compacting marker, regardless of how we exit.
    await using _ = defer(async () => {
      await Session.update(input.sessionID, (draft) => {
        draft.time.compacting = undefined
      })
    })
    const toSummarize = await Session.messages(input.sessionID).then(MessageV2.filterCompacted)
    const model = await Provider.getModel(input.providerID, input.modelID)
    const system = [
      ...SystemPrompt.summarize(model.providerID),
      ...(await SystemPrompt.environment()),
      ...(await SystemPrompt.custom()),
    ]
    // The summary is written as an assistant message parented to the last user message.
    const msg = (await Session.updateMessage({
      id: Identifier.ascending("message"),
      role: "assistant",
      parentID: toSummarize.findLast((m) => m.info.role === "user")?.info.id!,
      sessionID: input.sessionID,
      system,
      mode: "build",
      path: {
        cwd: Instance.directory,
        root: Instance.worktree,
      },
      cost: 0,
      tokens: {
        output: 0,
        input: 0,
        reasoning: 0,
        cache: { read: 0, write: 0 },
      },
      modelID: input.modelID,
      providerID: model.providerID,
      time: {
        created: Date.now(),
      },
    })) as MessageV2.Assistant
    // Single text part that the stream's deltas are appended to.
    const part = (await Session.updatePart({
      type: "text",
      sessionID: input.sessionID,
      messageID: msg.id,
      id: Identifier.ascending("part"),
      text: "",
      time: {
        start: Date.now(),
      },
    })) as MessageV2.TextPart
    const doStream = () =>
      streamText({
        // set to 0, we handle loop
        maxRetries: 0,
        model: model.language,
        providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
        headers: model.info.headers,
        abortSignal: signal,
        onError(error) {
          log.error("stream error", {
            error,
          })
        },
        // Empty tool map keeps tool-calling models happy; omit entirely otherwise.
        tools: model.info.tool_call ? {} : undefined,
        messages: [
          ...system.map(
            (x): ModelMessage => ({
              role: "system",
              content: x,
            }),
          ),
          ...MessageV2.toModelMessage(toSummarize),
          {
            role: "user",
            content: [
              {
                type: "text",
                text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
              },
            ],
          },
        ],
      })
    // TODO: reduce duplication between compaction.ts & prompt.ts
    // Drains one stream attempt; reports whether the caller should retry.
    const process = async (
      stream: StreamTextResult<Record<string, AITool>, never>,
      retries: { count: number; max: number },
    ) => {
      let shouldRetry = false
      try {
        for await (const value of stream.fullStream) {
          signal.throwIfAborted()
          switch (value.type) {
            case "text-delta":
              part.text += value.text
              if (value.providerMetadata) part.metadata = value.providerMetadata
              if (part.text)
                await Session.updatePart({
                  part,
                  delta: value.text,
                })
              continue
            case "text-end": {
              part.text = part.text.trimEnd()
              part.time = {
                start: Date.now(),
                end: Date.now(),
              }
              if (value.providerMetadata) part.metadata = value.providerMetadata
              await Session.updatePart(part)
              continue
            }
            case "finish-step": {
              // Fold this step's usage into the message totals.
              const usage = Session.getUsage({
                model: model.info,
                usage: value.usage,
                metadata: value.providerMetadata,
              })
              msg.cost += usage.cost
              msg.tokens = usage.tokens
              await Session.updateMessage(msg)
              continue
            }
            case "error":
              // Surface stream-level errors to the catch below.
              throw value.error
            default:
              continue
          }
        }
      } catch (e) {
        log.error("compaction error", {
          error: e,
        })
        const error = MessageV2.fromError(e, { providerID: input.providerID })
        // Retryable API errors get a visible "retry" part; anything else is terminal.
        if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
          shouldRetry = true
          await Session.updatePart({
            id: Identifier.ascending("part"),
            messageID: msg.id,
            sessionID: msg.sessionID,
            type: "retry",
            attempt: retries.count + 1,
            time: {
              created: Date.now(),
            },
            error,
          })
        } else {
          msg.error = error
          Bus.publish(Session.Event.Error, {
            sessionID: msg.sessionID,
            error: msg.error,
          })
        }
      }
      const parts = await Session.getParts(msg.id)
      return {
        info: msg,
        parts,
        shouldRetry,
      }
    }
    let stream = doStream()
    const cfg = await Config.get()
    const maxRetries = cfg.experimental?.chatMaxRetries ?? MAX_RETRIES
    // First attempt (count 0); the loop below covers attempts 1..maxRetries-1.
    let result = await process(stream, {
      count: 0,
      max: maxRetries,
    })
    if (result.shouldRetry) {
      for (let retry = 1; retry < maxRetries; retry++) {
        // Backoff delay is derived from the most recent retry part's error.
        const lastRetryPart = result.parts.findLast((p) => p.type === "retry")
        if (lastRetryPart) {
          const delayMs = SessionRetry.getRetryDelayInMs(lastRetryPart.error, retry)
          log.info("retrying with backoff", {
            attempt: retry,
            delayMs,
          })
          // Sleep is abortable: an abort during the wait records the error and stops retrying.
          const stop = await SessionRetry.sleep(delayMs, signal)
            .then(() => false)
            .catch((error) => {
              if (error instanceof DOMException && error.name === "AbortError") {
                const err = new MessageV2.AbortedError(
                  { message: error.message },
                  {
                    cause: error,
                  },
                ).toObject()
                result.info.error = err
                Bus.publish(Session.Event.Error, {
                  sessionID: result.info.sessionID,
                  error: result.info.error,
                })
                return true
              }
              throw error
            })
          if (stop) break
        }
        stream = doStream()
        result = await process(stream, {
          count: retry,
          max: maxRetries,
        })
        if (!result.shouldRetry) {
          break
        }
      }
    }
    msg.time.completed = Date.now()
    // Count the run as a successful summary if it finished cleanly, or if it was
    // aborted after producing at least some text.
    if (
      !msg.error ||
      (MessageV2.AbortedError.isInstance(msg.error) &&
        result.parts.some((part) => part.type === "text" && part.text.length > 0))
    ) {
      msg.summary = true
      Bus.publish(Event.Compacted, {
        sessionID: input.sessionID,
      })
    }
    await Session.updateMessage(msg)
    return {
      info: msg,
      parts: result.parts,
    }
  }
}