fix(session): check for context overflow mid-turn in finish-step (#6480)

Saatvik Arya, 3 months ago
parent
commit 7a3ff5b98f

+ 8 - 0
packages/opencode/src/session/processor.ts

@@ -13,6 +13,7 @@ import { Plugin } from "@/plugin"
 import type { Provider } from "@/provider/provider"
 import { LLM } from "./llm"
 import { Config } from "@/config/config"
+import { SessionCompaction } from "./compaction"
 
 export namespace SessionProcessor {
   const DOOM_LOOP_THRESHOLD = 3
@@ -31,6 +32,7 @@ export namespace SessionProcessor {
     let snapshot: string | undefined
     let blocked = false
     let attempt = 0
+    let needsCompaction = false
 
     const result = {
       get message() {
@@ -41,6 +43,7 @@ export namespace SessionProcessor {
       },
       async process(streamInput: LLM.StreamInput) {
         log.info("process")
+        needsCompaction = false
         const shouldBreak = (await Config.get()).experimental?.continue_loop_on_deny !== true
         while (true) {
           try {
@@ -279,6 +282,9 @@ export namespace SessionProcessor {
                     sessionID: input.sessionID,
                     messageID: input.assistantMessage.parentID,
                   })
+                  if (await SessionCompaction.isOverflow({ tokens: usage.tokens, model: input.model })) {
+                    needsCompaction = true
+                  }
                   break
 
                 case "text-start":
@@ -339,6 +345,7 @@ export namespace SessionProcessor {
                   })
                   continue
               }
+              if (needsCompaction) break
             }
           } catch (e: any) {
             log.error("process", {
@@ -398,6 +405,7 @@ export namespace SessionProcessor {
           }
           input.assistantMessage.time.completed = Date.now()
           await Session.updateMessage(input.assistantMessage)
+          if (needsCompaction) return "compact"
           if (blocked) return "stop"
           if (input.assistantMessage.error) return "stop"
           return "continue"

+ 8 - 0
packages/opencode/src/session/prompt.ts

@@ -549,6 +549,14 @@ export namespace SessionPrompt {
         model,
       })
       if (result === "stop") break
+      if (result === "compact") {
+        await SessionCompaction.create({
+          sessionID,
+          agent: lastUser.agent,
+          model: lastUser.model,
+          auto: true,
+        })
+      }
       continue
     }
     SessionCompaction.prune({ sessionID })
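
Taken together with the processor change, the turn loop in prompt.ts now reads roughly as follows. This is a simplified paraphrase of the hunk above, not the verbatim source; the shape of the process call is assumed from the truncated context lines:

while (true) {
  const result = await processor.process({ /* ... */ model })
  if (result === "stop") break
  if (result === "compact") {
    // mid-turn overflow: summarize the session history before the
    // next iteration so it starts back under the context limit
    await SessionCompaction.create({
      sessionID,
      agent: lastUser.agent,
      model: lastUser.model,
      auto: true,
    })
  }
  continue
}
SessionCompaction.prune({ sessionID })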

+ 251 - 0
packages/opencode/test/session/compaction.test.ts

@@ -0,0 +1,251 @@
+import { describe, expect, test } from "bun:test"
+import path from "path"
+import { SessionCompaction } from "../../src/session/compaction"
+import { Token } from "../../src/util/token"
+import { Instance } from "../../src/project/instance"
+import { Log } from "../../src/util/log"
+import { tmpdir } from "../fixture/fixture"
+import { Session } from "../../src/session"
+import type { Provider } from "../../src/provider/provider"
+
+Log.init({ print: false })
+
+function createModel(opts: { context: number; output: number; cost?: Provider.Model["cost"] }): Provider.Model {
+  return {
+    id: "test-model",
+    providerID: "test",
+    name: "Test",
+    limit: {
+      context: opts.context,
+      output: opts.output,
+    },
+    cost: opts.cost ?? { input: 0, output: 0, cache: { read: 0, write: 0 } },
+    capabilities: {
+      toolcall: true,
+      attachment: false,
+      reasoning: false,
+      temperature: true,
+      input: { text: true, image: false, audio: false, video: false },
+      output: { text: true, image: false, audio: false, video: false },
+    },
+    api: { npm: "@ai-sdk/anthropic" },
+    options: {},
+  } as Provider.Model
+}
+
+describe("session.compaction.isOverflow", () => {
+  test("returns true when token count exceeds usable context", async () => {
+    await using tmp = await tmpdir()
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        const model = createModel({ context: 100_000, output: 32_000 })
+        const tokens = { input: 75_000, output: 5_000, reasoning: 0, cache: { read: 0, write: 0 } }
+        expect(await SessionCompaction.isOverflow({ tokens, model })).toBe(true)
+      },
+    })
+  })
+
+  test("returns false when token count within usable context", async () => {
+    await using tmp = await tmpdir()
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        const model = createModel({ context: 200_000, output: 32_000 })
+        const tokens = { input: 100_000, output: 10_000, reasoning: 0, cache: { read: 0, write: 0 } }
+        expect(await SessionCompaction.isOverflow({ tokens, model })).toBe(false)
+      },
+    })
+  })
+
+  test("includes cache.read in token count", async () => {
+    await using tmp = await tmpdir()
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        const model = createModel({ context: 100_000, output: 32_000 })
+        const tokens = { input: 50_000, output: 10_000, reasoning: 0, cache: { read: 10_000, write: 0 } }
+        expect(await SessionCompaction.isOverflow({ tokens, model })).toBe(true)
+      },
+    })
+  })
+
+  test("returns false when model context limit is 0", async () => {
+    await using tmp = await tmpdir()
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        const model = createModel({ context: 0, output: 32_000 })
+        const tokens = { input: 100_000, output: 10_000, reasoning: 0, cache: { read: 0, write: 0 } }
+        expect(await SessionCompaction.isOverflow({ tokens, model })).toBe(false)
+      },
+    })
+  })
+
+  test("returns false when compaction.auto is disabled", async () => {
+    await using tmp = await tmpdir({
+      init: async (dir) => {
+        await Bun.write(
+          path.join(dir, "opencode.json"),
+          JSON.stringify({
+            compaction: { auto: false },
+          }),
+        )
+      },
+    })
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        const model = createModel({ context: 100_000, output: 32_000 })
+        const tokens = { input: 75_000, output: 5_000, reasoning: 0, cache: { read: 0, write: 0 } }
+        expect(await SessionCompaction.isOverflow({ tokens, model })).toBe(false)
+      },
+    })
+  })
+})
+
+describe("util.token.estimate", () => {
+  test("estimates tokens from text (4 chars per token)", () => {
+    const text = "x".repeat(4000)
+    expect(Token.estimate(text)).toBe(1000)
+  })
+
+  test("estimates tokens from larger text", () => {
+    const text = "y".repeat(20_000)
+    expect(Token.estimate(text)).toBe(5000)
+  })
+
+  test("returns 0 for empty string", () => {
+    expect(Token.estimate("")).toBe(0)
+  })
+})
+
+describe("session.getUsage", () => {
+  test("normalizes standard usage to token format", () => {
+    const model = createModel({ context: 100_000, output: 32_000 })
+    const result = Session.getUsage({
+      model,
+      usage: {
+        inputTokens: 1000,
+        outputTokens: 500,
+        totalTokens: 1500,
+      },
+    })
+
+    expect(result.tokens.input).toBe(1000)
+    expect(result.tokens.output).toBe(500)
+    expect(result.tokens.reasoning).toBe(0)
+    expect(result.tokens.cache.read).toBe(0)
+    expect(result.tokens.cache.write).toBe(0)
+  })
+
+  test("extracts cached tokens to cache.read", () => {
+    const model = createModel({ context: 100_000, output: 32_000 })
+    const result = Session.getUsage({
+      model,
+      usage: {
+        inputTokens: 1000,
+        outputTokens: 500,
+        totalTokens: 1500,
+        cachedInputTokens: 200,
+      },
+    })
+
+    expect(result.tokens.input).toBe(800)
+    expect(result.tokens.cache.read).toBe(200)
+  })
+
+  test("handles anthropic cache write metadata", () => {
+    const model = createModel({ context: 100_000, output: 32_000 })
+    const result = Session.getUsage({
+      model,
+      usage: {
+        inputTokens: 1000,
+        outputTokens: 500,
+        totalTokens: 1500,
+      },
+      metadata: {
+        anthropic: {
+          cacheCreationInputTokens: 300,
+        },
+      },
+    })
+
+    expect(result.tokens.cache.write).toBe(300)
+  })
+
+  test("does not subtract cached tokens for anthropic provider", () => {
+    const model = createModel({ context: 100_000, output: 32_000 })
+    const result = Session.getUsage({
+      model,
+      usage: {
+        inputTokens: 1000,
+        outputTokens: 500,
+        totalTokens: 1500,
+        cachedInputTokens: 200,
+      },
+      metadata: {
+        anthropic: {},
+      },
+    })
+
+    expect(result.tokens.input).toBe(1000)
+    expect(result.tokens.cache.read).toBe(200)
+  })
+
+  test("handles reasoning tokens", () => {
+    const model = createModel({ context: 100_000, output: 32_000 })
+    const result = Session.getUsage({
+      model,
+      usage: {
+        inputTokens: 1000,
+        outputTokens: 500,
+        totalTokens: 1500,
+        reasoningTokens: 100,
+      },
+    })
+
+    expect(result.tokens.reasoning).toBe(100)
+  })
+
+  test("handles undefined optional values gracefully", () => {
+    const model = createModel({ context: 100_000, output: 32_000 })
+    const result = Session.getUsage({
+      model,
+      usage: {
+        inputTokens: 0,
+        outputTokens: 0,
+        totalTokens: 0,
+      },
+    })
+
+    expect(result.tokens.input).toBe(0)
+    expect(result.tokens.output).toBe(0)
+    expect(result.tokens.reasoning).toBe(0)
+    expect(result.tokens.cache.read).toBe(0)
+    expect(result.tokens.cache.write).toBe(0)
+    expect(Number.isNaN(result.cost)).toBe(false)
+  })
+
+  test("calculates cost correctly", () => {
+    const model = createModel({
+      context: 100_000,
+      output: 32_000,
+      cost: {
+        input: 3,
+        output: 15,
+        cache: { read: 0.3, write: 3.75 },
+      },
+    })
+    const result = Session.getUsage({
+      model,
+      usage: {
+        inputTokens: 1_000_000,
+        outputTokens: 100_000,
+        totalTokens: 1_100_000,
+      },
+    })
+
+    expect(result.cost).toBe(3 + 1.5)
+  })
+})
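
The Token.estimate tests above encode a flat four-characters-per-token heuristic. All three assertions divide evenly, so they do not distinguish rounding modes; a sketch that satisfies them (the actual src/util/token.ts may round differently):

export namespace Token {
  // length-based estimate: roughly 4 characters per token
  export function estimate(text: string): number {
    return Math.ceil(text.length / 4)
  }
}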
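
The Session.getUsage tests likewise pin down how raw AI SDK usage is normalized. Read together: cachedInputTokens is copied to cache.read and subtracted from input, except when anthropic provider metadata is present (Anthropic already reports inputTokens net of cache reads); cache writes come only from anthropic.cacheCreationInputTokens; and cost is priced in dollars per million tokens. A sketch consistent with every assertion above (the Usage type below is an assumption standing in for the AI SDK's usage shape, not the project's actual types):

import type { Provider } from "@/provider/provider"

type Usage = {
  inputTokens?: number
  outputTokens?: number
  totalTokens?: number
  reasoningTokens?: number
  cachedInputTokens?: number
}

export function getUsage(input: {
  model: Provider.Model
  usage: Usage
  metadata?: { anthropic?: { cacheCreationInputTokens?: number } }
}) {
  const cacheRead = input.usage.cachedInputTokens ?? 0
  const cacheWrite = input.metadata?.anthropic?.cacheCreationInputTokens ?? 0
  // presence of anthropic metadata suppresses the cache-read subtraction
  const inputTokens = input.metadata?.anthropic
    ? (input.usage.inputTokens ?? 0)
    : (input.usage.inputTokens ?? 0) - cacheRead
  const tokens = {
    input: inputTokens,
    output: input.usage.outputTokens ?? 0,
    reasoning: input.usage.reasoningTokens ?? 0,
    cache: { read: cacheRead, write: cacheWrite },
  }
  // cost fields are dollars per million tokens of each kind, e.g.
  // 1M input at $3/M plus 100k output at $15/M gives the 3 + 1.5 assertion
  const cost =
    (tokens.input * input.model.cost.input +
      tokens.output * input.model.cost.output +
      tokens.cache.read * input.model.cost.cache.read +
      tokens.cache.write * input.model.cost.cache.write) /
    1_000_000
  return { tokens, cost }
}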