|
|
@@ -1,10 +1,12 @@
|
|
|
import { describe, expect, test, beforeEach, afterEach } from "bun:test"
|
|
|
+import fs from "fs/promises"
|
|
|
import path from "path"
|
|
|
import { Session } from "../../src/session"
|
|
|
import { ModelID, ProviderID } from "../../src/provider/schema"
|
|
|
import { SessionRevert } from "../../src/session/revert"
|
|
|
import { SessionCompaction } from "../../src/session/compaction"
|
|
|
import { MessageV2 } from "../../src/session/message-v2"
|
|
|
+import { Snapshot } from "../../src/snapshot"
|
|
|
import { Log } from "../../src/util/log"
|
|
|
import { Instance } from "../../src/project/instance"
|
|
|
import { MessageID, PartID } from "../../src/session/schema"
|
|
|
@@ -70,6 +72,13 @@ function tool(sessionID: string, messageID: string) {
|
|
|
})
|
|
|
}
|
|
|
|
|
|
+const tokens = {
|
|
|
+ input: 0,
|
|
|
+ output: 0,
|
|
|
+ reasoning: 0,
|
|
|
+ cache: { read: 0, write: 0 },
|
|
|
+}
|
|
|
+
|
|
|
describe("revert + compact workflow", () => {
|
|
|
test("should properly handle compact command after revert", async () => {
|
|
|
await using tmp = await tmpdir({ git: true })
|
|
|
@@ -434,4 +443,179 @@ describe("revert + compact workflow", () => {
|
|
|
},
|
|
|
})
|
|
|
})
|
|
|
+
|
|
|
+ test("restore messages in sequential order", async () => {
|
|
|
+ await using tmp = await tmpdir({ git: true })
|
|
|
+ await Instance.provide({
|
|
|
+ directory: tmp.path,
|
|
|
+ fn: async () => {
|
|
|
+ await fs.writeFile(path.join(tmp.path, "a.txt"), "a0")
|
|
|
+ await fs.writeFile(path.join(tmp.path, "b.txt"), "b0")
|
|
|
+ await fs.writeFile(path.join(tmp.path, "c.txt"), "c0")
|
|
|
+
|
|
|
+ const session = await Session.create({})
|
|
|
+ const sid = session.id
|
|
|
+
|
|
|
+ const turn = async (file: string, next: string) => {
|
|
|
+ const u = await user(sid)
|
|
|
+ await text(sid, u.id, `${file}:${next}`)
|
|
|
+ const a = await assistant(sid, u.id, tmp.path)
|
|
|
+ const before = await Snapshot.track()
|
|
|
+ if (!before) throw new Error("expected snapshot")
|
|
|
+ await fs.writeFile(path.join(tmp.path, file), next)
|
|
|
+ const after = await Snapshot.track()
|
|
|
+ if (!after) throw new Error("expected snapshot")
|
|
|
+ const patch = await Snapshot.patch(before)
|
|
|
+ await Session.updatePart({
|
|
|
+ id: PartID.ascending(),
|
|
|
+ messageID: a.id,
|
|
|
+ sessionID: sid,
|
|
|
+ type: "step-start",
|
|
|
+ snapshot: before,
|
|
|
+ })
|
|
|
+ await Session.updatePart({
|
|
|
+ id: PartID.ascending(),
|
|
|
+ messageID: a.id,
|
|
|
+ sessionID: sid,
|
|
|
+ type: "step-finish",
|
|
|
+ reason: "stop",
|
|
|
+ snapshot: after,
|
|
|
+ cost: 0,
|
|
|
+ tokens,
|
|
|
+ })
|
|
|
+ await Session.updatePart({
|
|
|
+ id: PartID.ascending(),
|
|
|
+ messageID: a.id,
|
|
|
+ sessionID: sid,
|
|
|
+ type: "patch",
|
|
|
+ hash: patch.hash,
|
|
|
+ files: patch.files,
|
|
|
+ })
|
|
|
+ return u.id
|
|
|
+ }
|
|
|
+
|
|
|
+ const first = await turn("a.txt", "a1")
|
|
|
+ const second = await turn("b.txt", "b2")
|
|
|
+ const third = await turn("c.txt", "c3")
|
|
|
+
|
|
|
+ await SessionRevert.revert({
|
|
|
+ sessionID: sid,
|
|
|
+ messageID: first,
|
|
|
+ })
|
|
|
+ expect((await Session.get(sid)).revert?.messageID).toBe(first)
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a0")
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b0")
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0")
|
|
|
+
|
|
|
+ await SessionRevert.revert({
|
|
|
+ sessionID: sid,
|
|
|
+ messageID: second,
|
|
|
+ })
|
|
|
+ expect((await Session.get(sid)).revert?.messageID).toBe(second)
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b0")
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0")
|
|
|
+
|
|
|
+ await SessionRevert.revert({
|
|
|
+ sessionID: sid,
|
|
|
+ messageID: third,
|
|
|
+ })
|
|
|
+ expect((await Session.get(sid)).revert?.messageID).toBe(third)
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b2")
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0")
|
|
|
+
|
|
|
+ await SessionRevert.unrevert({
|
|
|
+ sessionID: sid,
|
|
|
+ })
|
|
|
+ expect((await Session.get(sid)).revert).toBeUndefined()
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b2")
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c3")
|
|
|
+ },
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ test("restore same file in sequential order", async () => {
|
|
|
+ await using tmp = await tmpdir({ git: true })
|
|
|
+ await Instance.provide({
|
|
|
+ directory: tmp.path,
|
|
|
+ fn: async () => {
|
|
|
+ await fs.writeFile(path.join(tmp.path, "a.txt"), "a0")
|
|
|
+
|
|
|
+ const session = await Session.create({})
|
|
|
+ const sid = session.id
|
|
|
+
|
|
|
+ const turn = async (next: string) => {
|
|
|
+ const u = await user(sid)
|
|
|
+ await text(sid, u.id, `a.txt:${next}`)
|
|
|
+ const a = await assistant(sid, u.id, tmp.path)
|
|
|
+ const before = await Snapshot.track()
|
|
|
+ if (!before) throw new Error("expected snapshot")
|
|
|
+ await fs.writeFile(path.join(tmp.path, "a.txt"), next)
|
|
|
+ const after = await Snapshot.track()
|
|
|
+ if (!after) throw new Error("expected snapshot")
|
|
|
+ const patch = await Snapshot.patch(before)
|
|
|
+ await Session.updatePart({
|
|
|
+ id: PartID.ascending(),
|
|
|
+ messageID: a.id,
|
|
|
+ sessionID: sid,
|
|
|
+ type: "step-start",
|
|
|
+ snapshot: before,
|
|
|
+ })
|
|
|
+ await Session.updatePart({
|
|
|
+ id: PartID.ascending(),
|
|
|
+ messageID: a.id,
|
|
|
+ sessionID: sid,
|
|
|
+ type: "step-finish",
|
|
|
+ reason: "stop",
|
|
|
+ snapshot: after,
|
|
|
+ cost: 0,
|
|
|
+ tokens,
|
|
|
+ })
|
|
|
+ await Session.updatePart({
|
|
|
+ id: PartID.ascending(),
|
|
|
+ messageID: a.id,
|
|
|
+ sessionID: sid,
|
|
|
+ type: "patch",
|
|
|
+ hash: patch.hash,
|
|
|
+ files: patch.files,
|
|
|
+ })
|
|
|
+ return u.id
|
|
|
+ }
|
|
|
+
|
|
|
+ const first = await turn("a1")
|
|
|
+ const second = await turn("a2")
|
|
|
+ const third = await turn("a3")
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a3")
|
|
|
+
|
|
|
+ await SessionRevert.revert({
|
|
|
+ sessionID: sid,
|
|
|
+ messageID: first,
|
|
|
+ })
|
|
|
+ expect((await Session.get(sid)).revert?.messageID).toBe(first)
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a0")
|
|
|
+
|
|
|
+ await SessionRevert.revert({
|
|
|
+ sessionID: sid,
|
|
|
+ messageID: second,
|
|
|
+ })
|
|
|
+ expect((await Session.get(sid)).revert?.messageID).toBe(second)
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
|
|
|
+
|
|
|
+ await SessionRevert.revert({
|
|
|
+ sessionID: sid,
|
|
|
+ messageID: third,
|
|
|
+ })
|
|
|
+ expect((await Session.get(sid)).revert?.messageID).toBe(third)
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a2")
|
|
|
+
|
|
|
+ await SessionRevert.unrevert({
|
|
|
+ sessionID: sid,
|
|
|
+ })
|
|
|
+ expect((await Session.get(sid)).revert).toBeUndefined()
|
|
|
+ expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a3")
|
|
|
+ },
|
|
|
+ })
|
|
|
+ })
|
|
|
})
|