Browse Source

fix compaction issues

Dax Raad 5 months ago
parent
commit
983e3b2ee3
2 changed files with 185 additions and 115 deletions
  1. 178 115
      packages/opencode/src/session/index.ts
  2. 7 0
      packages/opencode/src/util/token.ts

+ 178 - 115
packages/opencode/src/session/index.ts

@@ -53,6 +53,7 @@ import { defer } from "../util/defer"
 import { Command } from "../command"
 import { $ } from "bun"
 import { ListTool } from "../tool/ls"
+import { Token } from "../util/token"
 
 export namespace Session {
   const log = Log.create({ service: "session" })
@@ -83,6 +84,12 @@ export namespace Session {
         .optional(),
       title: z.string(),
       version: z.string(),
+      compaction: z
+        .object({
+          full: z.string().optional(),
+          micro: z.string().optional(),
+        })
+        .optional(),
       time: z.object({
         created: z.number(),
         updated: z.number(),
@@ -361,6 +368,7 @@ export namespace Session {
     Bus.publish(MessageV2.Event.Updated, {
       info: msg,
     })
+    return msg
   }
 
   async function updatePart(part: MessageV2.Part) {
@@ -717,14 +725,34 @@ export namespace Session {
       }
       return Provider.defaultModel()
     })().then((x) => Provider.getModel(x.providerID, x.modelID))
+
     let msgs = await messages(input.sessionID)
+    const lastSummary = Math.max(
+      0,
+      msgs.findLastIndex((msg) => msg.info.role === "assistant" && msg.info.summary === true),
+    )
+    msgs = msgs.slice(lastSummary)
+
+    const lastAssistant = msgs.findLast((msg) => msg.info.role === "assistant")
+    if (
+      lastAssistant?.info.role === "assistant" &&
+      needsCompaction({
+        tokens: lastAssistant.info.tokens,
+        model: model.info,
+      })
+    ) {
+      const msg = await summarize({
+        sessionID: input.sessionID,
+        providerID: model.providerID,
+        modelID: model.info.id,
+      })
+      msgs = [msg]
+    }
 
     const outputLimit = Math.min(model.info.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
 
     using abort = lock(input.sessionID)
 
-    const lastSummary = msgs.findLast((msg) => msg.info.role === "assistant" && msg.info.summary === true)
-    if (lastSummary) msgs = msgs.filter((msg) => msg.info.id >= lastSummary.info.id)
     const numRealUserMsgs = msgs.filter(
       (m) => m.info.role === "user" && !m.parts.every((p) => "synthetic" in p && p.synthetic),
     ).length
@@ -819,39 +847,21 @@ export namespace Session {
     const [first, ...rest] = system
     system = [first, rest.join("\n")]
 
-    const assistantMsg: MessageV2.Info = {
-      id: Identifier.ascending("message"),
-      role: "assistant",
-      system,
-      mode: inputAgent,
-      path: {
-        cwd: Instance.directory,
-        root: Instance.worktree,
-      },
-      cost: 0,
-      tokens: {
-        input: 0,
-        output: 0,
-        reasoning: 0,
-        cache: { read: 0, write: 0 },
-      },
-      modelID: model.modelID,
-      providerID: model.providerID,
-      time: {
-        created: Date.now(),
-      },
+    const processor = await createProcessor({
       sessionID: input.sessionID,
-    }
-    await updateMessage(assistantMsg)
+      model: model.info,
+      providerID: model.providerID,
+      agent: inputAgent,
+      system,
+    })
+
     await using _ = defer(async () => {
-      if (assistantMsg.time.completed) return
-      await Storage.remove(["session", "message", input.sessionID, assistantMsg.id])
-      await Bus.publish(MessageV2.Event.Removed, { sessionID: input.sessionID, messageID: assistantMsg.id })
+      if (processor.message.time.completed) return
+      await Storage.remove(["session", "message", input.sessionID, processor.message.id])
+      await Bus.publish(MessageV2.Event.Removed, { sessionID: input.sessionID, messageID: processor.message.id })
     })
     const tools: Record<string, AITool> = {}
 
-    const processor = createProcessor(assistantMsg, model.info)
-
     const enabledTools = pipe(
       agent.tools,
       mergeDeep(await ToolRegistry.enabled(model.providerID, model.modelID, agent)),
@@ -878,7 +888,7 @@ export namespace Session {
           const result = await item.execute(args, {
             sessionID: input.sessionID,
             abort: options.abortSignal!,
-            messageID: assistantMsg.id,
+            messageID: processor.message.id,
             callID: options.toolCallId,
             agent: agent.name,
             metadata: async (val) => {
@@ -982,6 +992,8 @@ export namespace Session {
         },
       },
     )
+
+    let pointer = 0
     const stream = streamText({
       onError(e) {
         log.error("streamText error", {
@@ -989,39 +1001,32 @@ export namespace Session {
         })
       },
       async prepareStep({ messages, steps }) {
-        // Auto compact if too long
-        const tokens = (() => {
-          if (steps.length) {
-            const previous = steps.at(-1)
-            if (previous) return getUsage(model.info, previous.usage, previous.providerMetadata).tokens
-          }
-          const msg = msgs.findLast((x) => x.info.role === "assistant")?.info as MessageV2.Assistant
-          if (msg && msg.tokens) {
-            return msg.tokens
-          }
-        })()
-        if (tokens) {
-          log.info("compact check", tokens)
-          const count = tokens.input + tokens.cache.read + tokens.cache.write + tokens.output
-          if (model.info.limit.context && count > Math.max((model.info.limit.context - outputLimit) * 0.9, 0)) {
-            log.info("compacting in prepareStep")
-            const summarized = await summarize({
-              sessionID: input.sessionID,
-              providerID: model.providerID,
-              modelID: model.info.id,
-            })
-            const msgs = await Session.messages(input.sessionID).then((x) =>
-              x.filter((x) => x.info.id >= summarized.id),
-            )
-            return {
-              messages: MessageV2.toModelMessage(msgs),
-            }
-          }
+        log.info("search", {
+          length: messages.length,
+        })
+        const step = steps.at(-1)
+        if (
+          step &&
+          needsCompaction({
+            tokens: getUsage(model.info, step.usage, step.providerMetadata).tokens,
+            model: model.info,
+          })
+        ) {
+          await processor.end()
+          const msg = await Session.summarize({
+            sessionID: input.sessionID,
+            providerID: model.providerID,
+            modelID: model.info.id,
+          })
+          await processor.next()
+          pointer = messages.length - 1
+          messages.push(...MessageV2.toModelMessage([msg]))
         }
 
         // Add queued messages to the stream
         const queue = (state().queued.get(input.sessionID) ?? []).filter((x) => !x.processed)
         if (queue.length) {
+          await processor.end()
           for (const item of queue) {
             if (item.processed) continue
             messages.push(
@@ -1034,35 +1039,10 @@ export namespace Session {
             )
             item.processed = true
           }
-          assistantMsg.time.completed = Date.now()
-          await updateMessage(assistantMsg)
-          Object.assign(assistantMsg, {
-            id: Identifier.ascending("message"),
-            role: "assistant",
-            system,
-            path: {
-              cwd: Instance.directory,
-              root: Instance.worktree,
-            },
-            cost: 0,
-            tokens: {
-              input: 0,
-              output: 0,
-              reasoning: 0,
-              cache: { read: 0, write: 0 },
-            },
-            modelID: model.modelID,
-            providerID: model.providerID,
-            mode: inputAgent,
-            time: {
-              created: Date.now(),
-            },
-            sessionID: input.sessionID,
-          })
-          await updateMessage(assistantMsg)
+          await processor.next()
         }
         return {
-          messages,
+          messages: messages.slice(pointer),
         }
       },
       async experimental_repairToolCall(input) {
@@ -1421,11 +1401,60 @@ export namespace Session {
     })
   }
 
-  function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
+  async function createProcessor(input: {
+    sessionID: string
+    providerID: string
+    model: ModelsDev.Model
+    system: string[]
+    agent: string
+  }) {
     const toolcalls: Record<string, MessageV2.ToolPart> = {}
     let snapshot: string | undefined
     let shouldStop = false
-    return {
+
+    async function createMessage() {
+      const msg: MessageV2.Info = {
+        id: Identifier.ascending("message"),
+        role: "assistant",
+        system: input.system,
+        mode: input.agent,
+        path: {
+          cwd: Instance.directory,
+          root: Instance.worktree,
+        },
+        cost: 0,
+        tokens: {
+          input: 0,
+          output: 0,
+          reasoning: 0,
+          cache: { read: 0, write: 0 },
+        },
+        modelID: input.model.id,
+        providerID: input.providerID,
+        time: {
+          created: Date.now(),
+        },
+        sessionID: input.sessionID,
+      }
+      await updateMessage(msg)
+      return msg
+    }
+
+    let assistantMsg = await createMessage()
+
+    const result = {
+      async end() {
+        if (assistantMsg) {
+          assistantMsg.time.completed = Date.now()
+          await updateMessage(assistantMsg)
+        }
+      },
+      async next() {
+        assistantMsg = await createMessage()
+      },
+      get message() {
+        return assistantMsg
+      },
       partFromToolCall(toolCallID: string) {
         return toolcalls[toolCallID]
       },
@@ -1581,7 +1610,7 @@ export namespace Session {
                 break
 
               case "finish-step":
-                const usage = getUsage(model, value.usage, value.providerMetadata)
+                const usage = getUsage(input.model, value.usage, value.providerMetadata)
                 assistantMsg.cost += usage.cost
                 assistantMsg.tokens = usage.tokens
                 await updatePart({
@@ -1672,7 +1701,7 @@ export namespace Session {
             case LoadAPIKeyError.isInstance(e):
               assistantMsg.error = new MessageV2.AuthError(
                 {
-                  providerID: model.id,
+                  providerID: input.providerID,
                   message: e.message,
                 },
                 { cause: e },
@@ -1711,6 +1740,7 @@ export namespace Session {
         return { info: assistantMsg, parts: p }
       },
     }
+    return result
   }
 
   export const RevertInput = z.object({
@@ -1789,9 +1819,8 @@ export namespace Session {
       0,
       msgs.findLastIndex((msg) => msg.info.role === "assistant" && msg.info.summary === true),
     )
-    const split = start + Math.floor((msgs.length - start) / 2)
-    log.info("summarizing", { start, split })
-    const toSummarize = msgs.slice(start, split)
+    log.info("summarizing", { start })
+    const toSummarize = msgs.slice(start)
     const model = await Provider.getModel(input.providerID, input.modelID)
     const system = [
       ...SystemPrompt.summarize(model.providerID),
@@ -1799,6 +1828,29 @@ export namespace Session {
       ...(await SystemPrompt.custom()),
     ]
 
+    const msg = (await updateMessage({
+      id: Identifier.ascending("message"),
+      role: "assistant",
+      sessionID: input.sessionID,
+      system,
+      mode: "build",
+      path: {
+        cwd: Instance.directory,
+        root: Instance.worktree,
+      },
+      cost: 0,
+      tokens: {
+        output: 0,
+        input: 0,
+        reasoning: 0,
+        cache: { read: 0, write: 0 },
+      },
+      modelID: input.modelID,
+      providerID: model.providerID,
+      time: {
+        created: Date.now(),
+      },
+    })) as MessageV2.Assistant
     const generated = await generateText({
       maxRetries: 10,
       model: model.language,
@@ -1822,28 +1874,12 @@ export namespace Session {
       ],
     })
     const usage = getUsage(model.info, generated.usage, generated.providerMetadata)
-    const msg: MessageV2.Info = {
-      id: Identifier.create("message", false, toSummarize.at(-1)!.info.time.created + 1),
-      role: "assistant",
-      sessionID: input.sessionID,
-      system,
-      mode: "build",
-      path: {
-        cwd: Instance.directory,
-        root: Instance.worktree,
-      },
-      summary: true,
-      cost: usage.cost,
-      tokens: usage.tokens,
-      modelID: input.modelID,
-      providerID: model.providerID,
-      time: {
-        created: Date.now(),
-        completed: Date.now(),
-      },
-    }
+    msg.cost += usage.cost
+    msg.tokens = usage.tokens
+    msg.summary = true
+    msg.time.completed = Date.now()
     await updateMessage(msg)
-    await updatePart({
+    const part = await updatePart({
       type: "text",
       sessionID: input.sessionID,
       messageID: msg.id,
@@ -1859,7 +1895,34 @@ export namespace Session {
       sessionID: input.sessionID,
     })
 
-    return msg
+    return {
+      info: msg,
+      parts: [part],
+    }
+  }
+
+  function needsCompaction(input: { tokens: MessageV2.Assistant["tokens"]; model: ModelsDev.Model }) {
+    const count = input.tokens.input + input.tokens.cache.read + input.tokens.output
+    const output = Math.min(input.model.limit.output, OUTPUT_TOKEN_MAX) || OUTPUT_TOKEN_MAX
+    const usable = input.model.limit.context - output
+    return count > usable / 2
+  }
+
+  export async function microcompact(input: { sessionID: string }) {
+    const msgs = await messages(input.sessionID)
+    let sum = 0
+    for (let msgIndex = msgs.length - 1; msgIndex >= 0; msgIndex--) {
+      const msg = msgs[msgIndex]
+      for (let partIndex = msg.parts.length - 1; partIndex >= 0; partIndex--) {
+        const part = msg.parts[partIndex]
+        if (part.type === "tool")
+          if (part.state.status === "completed") {
+            sum += Token.estimate(part.state.output)
+            if (sum > 40_000) {
+            }
+          }
+      }
+    }
   }
 
   function isLocked(sessionID: string) {

+ 7 - 0
packages/opencode/src/util/token.ts

@@ -0,0 +1,7 @@
+export namespace Token {
+  const CHARS_PER_TOKEN = 4
+
+  export function estimate(input: string) {
+    return Math.max(0, Math.round((input || "").length / CHARS_PER_TOKEN))
+  }
+}