Browse Source

fix(core): fix restoring earlier messages in a reverted chain (#20780)

Nate Williams 2 weeks ago
parent
commit
6359d00fb4

+ 1 - 0
packages/opencode/src/session/revert.ts

@@ -72,6 +72,7 @@ export namespace SessionRevert {
         if (!rev) return session
 
         rev.snapshot = session.revert?.snapshot ?? (yield* snap.track())
+        if (session.revert?.snapshot) yield* snap.restore(session.revert.snapshot)
         yield* snap.revert(patches)
         if (rev.snapshot) rev.diff = yield* snap.diff(rev.snapshot as string)
         const range = all.filter((msg) => msg.info.id >= rev!.messageID)

+ 184 - 0
packages/opencode/test/session/revert-compact.test.ts

@@ -1,10 +1,12 @@
 import { describe, expect, test, beforeEach, afterEach } from "bun:test"
+import fs from "fs/promises"
 import path from "path"
 import { Session } from "../../src/session"
 import { ModelID, ProviderID } from "../../src/provider/schema"
 import { SessionRevert } from "../../src/session/revert"
 import { SessionCompaction } from "../../src/session/compaction"
 import { MessageV2 } from "../../src/session/message-v2"
+import { Snapshot } from "../../src/snapshot"
 import { Log } from "../../src/util/log"
 import { Instance } from "../../src/project/instance"
 import { MessageID, PartID } from "../../src/session/schema"
@@ -70,6 +72,13 @@ function tool(sessionID: string, messageID: string) {
   })
 }
 
+const tokens = {
+  input: 0,
+  output: 0,
+  reasoning: 0,
+  cache: { read: 0, write: 0 },
+}
+
 describe("revert + compact workflow", () => {
   test("should properly handle compact command after revert", async () => {
     await using tmp = await tmpdir({ git: true })
@@ -434,4 +443,179 @@ describe("revert + compact workflow", () => {
       },
     })
   })
+
+  test("restore messages in sequential order", async () => {
+    await using tmp = await tmpdir({ git: true })
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        await fs.writeFile(path.join(tmp.path, "a.txt"), "a0")
+        await fs.writeFile(path.join(tmp.path, "b.txt"), "b0")
+        await fs.writeFile(path.join(tmp.path, "c.txt"), "c0")
+
+        const session = await Session.create({})
+        const sid = session.id
+
+        const turn = async (file: string, next: string) => {
+          const u = await user(sid)
+          await text(sid, u.id, `${file}:${next}`)
+          const a = await assistant(sid, u.id, tmp.path)
+          const before = await Snapshot.track()
+          if (!before) throw new Error("expected snapshot")
+          await fs.writeFile(path.join(tmp.path, file), next)
+          const after = await Snapshot.track()
+          if (!after) throw new Error("expected snapshot")
+          const patch = await Snapshot.patch(before)
+          await Session.updatePart({
+            id: PartID.ascending(),
+            messageID: a.id,
+            sessionID: sid,
+            type: "step-start",
+            snapshot: before,
+          })
+          await Session.updatePart({
+            id: PartID.ascending(),
+            messageID: a.id,
+            sessionID: sid,
+            type: "step-finish",
+            reason: "stop",
+            snapshot: after,
+            cost: 0,
+            tokens,
+          })
+          await Session.updatePart({
+            id: PartID.ascending(),
+            messageID: a.id,
+            sessionID: sid,
+            type: "patch",
+            hash: patch.hash,
+            files: patch.files,
+          })
+          return u.id
+        }
+
+        const first = await turn("a.txt", "a1")
+        const second = await turn("b.txt", "b2")
+        const third = await turn("c.txt", "c3")
+
+        await SessionRevert.revert({
+          sessionID: sid,
+          messageID: first,
+        })
+        expect((await Session.get(sid)).revert?.messageID).toBe(first)
+        expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a0")
+        expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b0")
+        expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0")
+
+        await SessionRevert.revert({
+          sessionID: sid,
+          messageID: second,
+        })
+        expect((await Session.get(sid)).revert?.messageID).toBe(second)
+        expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
+        expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b0")
+        expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0")
+
+        await SessionRevert.revert({
+          sessionID: sid,
+          messageID: third,
+        })
+        expect((await Session.get(sid)).revert?.messageID).toBe(third)
+        expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
+        expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b2")
+        expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0")
+
+        await SessionRevert.unrevert({
+          sessionID: sid,
+        })
+        expect((await Session.get(sid)).revert).toBeUndefined()
+        expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
+        expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b2")
+        expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c3")
+      },
+    })
+  })
+
+  test("restore same file in sequential order", async () => {
+    await using tmp = await tmpdir({ git: true })
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        await fs.writeFile(path.join(tmp.path, "a.txt"), "a0")
+
+        const session = await Session.create({})
+        const sid = session.id
+
+        const turn = async (next: string) => {
+          const u = await user(sid)
+          await text(sid, u.id, `a.txt:${next}`)
+          const a = await assistant(sid, u.id, tmp.path)
+          const before = await Snapshot.track()
+          if (!before) throw new Error("expected snapshot")
+          await fs.writeFile(path.join(tmp.path, "a.txt"), next)
+          const after = await Snapshot.track()
+          if (!after) throw new Error("expected snapshot")
+          const patch = await Snapshot.patch(before)
+          await Session.updatePart({
+            id: PartID.ascending(),
+            messageID: a.id,
+            sessionID: sid,
+            type: "step-start",
+            snapshot: before,
+          })
+          await Session.updatePart({
+            id: PartID.ascending(),
+            messageID: a.id,
+            sessionID: sid,
+            type: "step-finish",
+            reason: "stop",
+            snapshot: after,
+            cost: 0,
+            tokens,
+          })
+          await Session.updatePart({
+            id: PartID.ascending(),
+            messageID: a.id,
+            sessionID: sid,
+            type: "patch",
+            hash: patch.hash,
+            files: patch.files,
+          })
+          return u.id
+        }
+
+        const first = await turn("a1")
+        const second = await turn("a2")
+        const third = await turn("a3")
+        expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a3")
+
+        await SessionRevert.revert({
+          sessionID: sid,
+          messageID: first,
+        })
+        expect((await Session.get(sid)).revert?.messageID).toBe(first)
+        expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a0")
+
+        await SessionRevert.revert({
+          sessionID: sid,
+          messageID: second,
+        })
+        expect((await Session.get(sid)).revert?.messageID).toBe(second)
+        expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
+
+        await SessionRevert.revert({
+          sessionID: sid,
+          messageID: third,
+        })
+        expect((await Session.get(sid)).revert?.messageID).toBe(third)
+        expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a2")
+
+        await SessionRevert.unrevert({
+          sessionID: sid,
+        })
+        expect((await Session.get(sid)).revert).toBeUndefined()
+        expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a3")
+      },
+    })
+  })
 })