| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621 |
- import { describe, expect, test, beforeEach, afterEach } from "bun:test"
- import fs from "fs/promises"
- import path from "path"
- import { Session } from "../../src/session"
- import { ModelID, ProviderID } from "../../src/provider/schema"
- import { SessionRevert } from "../../src/session/revert"
- import { SessionCompaction } from "../../src/session/compaction"
- import { MessageV2 } from "../../src/session/message-v2"
- import { Snapshot } from "../../src/snapshot"
- import { Log } from "../../src/util/log"
- import { Instance } from "../../src/project/instance"
- import { MessageID, PartID } from "../../src/session/schema"
- import { tmpdir } from "../fixture/fixture"
// Initialise the logger with printing turned off (presumably to keep test output quiet).
Log.init({ print: false })
/**
 * Stores a new user message in a session through `Session.updateMessage`.
 *
 * The message gets a freshly generated ascending ID, the fixed model
 * `openai` / `gpt-4`, and the current time as its creation timestamp.
 *
 * @param sessionID - ID of the session that owns the message.
 * @param agent - Agent name recorded on the message; defaults to `"default"`.
 * @returns The promise returned by `Session.updateMessage` (callers read `.id` from its result).
 */
function user(sessionID: string, agent = "default") {
  return Session.updateMessage({
    id: MessageID.ascending(),
    role: "user" as const,
    // NOTE(review): cast because the helper accepts a plain string; the schema
    // presumably expects a branded session ID type — confirm.
    sessionID: sessionID as any,
    agent,
    model: { providerID: ProviderID.make("openai"), modelID: ModelID.make("gpt-4") },
    time: { created: Date.now() },
  })
}
/**
 * Stores a new, already-finished assistant message through `Session.updateMessage`.
 *
 * The message has zero cost and zero token usage, uses the fixed model
 * `openai` / `gpt-4`, and is marked with `finish: "end_turn"`.
 *
 * @param sessionID - ID of the session that owns the message.
 * @param parentID - ID of the message this reply is attached to (tests pass a user message ID).
 * @param dir - Directory recorded as both `path.cwd` and `path.root`.
 * @returns The promise returned by `Session.updateMessage` (callers read `.id` from its result).
 */
function assistant(sessionID: string, parentID: string, dir: string) {
  return Session.updateMessage({
    id: MessageID.ascending(),
    role: "assistant" as const,
    // NOTE(review): `as any` casts bridge plain strings to what are presumably
    // branded ID types in the schema — confirm.
    sessionID: sessionID as any,
    mode: "default",
    agent: "default",
    path: { cwd: dir, root: dir },
    cost: 0,
    tokens: { output: 0, input: 0, reasoning: 0, cache: { read: 0, write: 0 } },
    modelID: ModelID.make("gpt-4"),
    providerID: ProviderID.make("openai"),
    parentID: parentID as any,
    time: { created: Date.now() },
    finish: "end_turn",
  })
}
/**
 * Attaches a text part to an existing message through `Session.updatePart`.
 *
 * @param sessionID - ID of the session that owns the message.
 * @param messageID - ID of the message the part is added to.
 * @param content - Text stored in the part's `text` field.
 * @returns The promise returned by `Session.updatePart` (callers read `.id` from its result).
 */
function text(sessionID: string, messageID: string, content: string) {
  return Session.updatePart({
    id: PartID.ascending(),
    // NOTE(review): `as any` casts bridge plain strings to what are presumably
    // branded ID types in the schema — confirm.
    messageID: messageID as any,
    sessionID: sessionID as any,
    type: "text" as const,
    text: content,
  })
}
/**
 * Attaches a completed `bash` tool part to an existing message through
 * `Session.updatePart`.
 *
 * The recorded call has empty input and metadata, the output `"done"`, an
 * empty title, and a fixed `start: 0` / `end: 1` time range.
 *
 * @param sessionID - ID of the session that owns the message.
 * @param messageID - ID of the message the part is added to.
 * @param callID - Call identifier stored on the part; defaults to `"call-1"`,
 *   so existing two-argument callers are unaffected.
 * @returns The promise returned by `Session.updatePart` (callers read `.id` from its result).
 */
function tool(sessionID: string, messageID: string, callID = "call-1") {
  return Session.updatePart({
    id: PartID.ascending(),
    // NOTE(review): `as any` casts bridge plain strings to what are presumably
    // branded ID types in the schema — confirm.
    messageID: messageID as any,
    sessionID: sessionID as any,
    type: "tool" as const,
    tool: "bash",
    callID,
    state: {
      status: "completed" as const,
      input: {},
      output: "done",
      title: "",
      metadata: {},
      time: { start: 0, end: 1 },
    },
  })
}
/** All-zero token usage record, passed as `tokens` on the `step-finish` parts created by the tests. */
const tokens = {
  input: 0,
  output: 0,
  reasoning: 0,
  cache: { read: 0, write: 0 },
}
/**
 * Tests for `SessionRevert.revert`, `SessionRevert.unrevert` and
 * `SessionRevert.cleanup`: what happens to a session's messages, parts,
 * revert state and tracked files when a revert is applied and then cleaned up
 * (the step the compact endpoint is expected to run first).
 */
describe("revert + compact workflow", () => {
  /** ID type of a session, derived from what `Session.create` resolves to. */
  type SessionID = Awaited<ReturnType<typeof Session.create>>["id"]

  /**
   * Runs `body` inside `Instance.provide`, rooted at a temporary directory
   * created with `tmpdir({ git: true })`. The directory is disposed when this
   * helper returns, i.e. after `body` has settled.
   */
  const withProject = async (body: (dir: string) => Promise<void>) => {
    await using tmp = await tmpdir({ git: true })
    await Instance.provide({
      directory: tmp.path,
      fn: () => body(tmp.path),
    })
  }

  /** Reads `name` inside `dir` as UTF-8 text. */
  const read = (dir: string, name: string) => fs.readFile(path.join(dir, name), "utf-8")

  /**
   * Records one user/assistant turn in which `file` (inside `dir`) is
   * overwritten with `next`. Snapshots are taken before and after the write
   * and stored on the assistant message as `step-start`, `step-finish` and
   * `patch` parts, in that order.
   *
   * @returns The ID of the user message that opened the turn.
   * @throws Error("expected snapshot") when `Snapshot.track` yields nothing.
   */
  const recordTurn = async (sid: SessionID, dir: string, file: string, next: string) => {
    const u = await user(sid)
    await text(sid, u.id, `${file}:${next}`)
    const a = await assistant(sid, u.id, dir)

    const before = await Snapshot.track()
    if (!before) throw new Error("expected snapshot")
    await fs.writeFile(path.join(dir, file), next)
    const after = await Snapshot.track()
    if (!after) throw new Error("expected snapshot")
    const patch = await Snapshot.patch(before)

    await Session.updatePart({
      id: PartID.ascending(),
      messageID: a.id,
      sessionID: sid,
      type: "step-start",
      snapshot: before,
    })
    await Session.updatePart({
      id: PartID.ascending(),
      messageID: a.id,
      sessionID: sid,
      type: "step-finish",
      reason: "stop",
      snapshot: after,
      cost: 0,
      tokens,
    })
    await Session.updatePart({
      id: PartID.ascending(),
      messageID: a.id,
      sessionID: sid,
      type: "patch",
      hash: patch.hash,
      files: patch.files,
    })
    return u.id
  }

  test("should properly handle compact command after revert", async () => {
    await withProject(async (dir) => {
      const session = await Session.create({})
      const sessionID = session.id

      // Two complete user/assistant exchanges.
      const userMsg1 = await user(sessionID)
      await text(sessionID, userMsg1.id, "Hello, please help me")
      const assistantMsg1 = await assistant(sessionID, userMsg1.id, dir)
      await text(sessionID, assistantMsg1.id, "Sure, I'll help you!")

      const userMsg2 = await user(sessionID)
      await text(sessionID, userMsg2.id, "What's the capital of France?")
      const assistantMsg2 = await assistant(sessionID, userMsg2.id, dir)
      await text(sessionID, assistantMsg2.id, "The capital of France is Paris.")

      // Verify messages before revert
      let messages = await Session.messages({ sessionID })
      expect(messages.length).toBe(4) // 2 user + 2 assistant messages
      const messageIds = messages.map((m) => m.info.id)
      expect(messageIds).toContain(userMsg1.id)
      expect(messageIds).toContain(userMsg2.id)
      expect(messageIds).toContain(assistantMsg1.id)
      expect(messageIds).toContain(assistantMsg2.id)

      // Revert the last user message (userMsg2)
      await SessionRevert.revert({
        sessionID,
        messageID: userMsg2.id,
      })

      // Check that revert state is set
      let sessionInfo = await Session.get(sessionID)
      expect(sessionInfo.revert).toBeDefined()
      const revertMessageID = sessionInfo.revert?.messageID
      expect(revertMessageID).toBeDefined()

      // Messages should still be in the list (not removed yet, just marked for revert)
      messages = await Session.messages({ sessionID })
      expect(messages.length).toBe(4)

      // Now clean up the revert state (this is what the compact endpoint should do)
      await SessionRevert.cleanup(sessionInfo)

      // After cleanup, the reverted messages (those after the revert point) should be removed
      messages = await Session.messages({ sessionID })
      const remainingIds = messages.map((m) => m.info.id)
      expect(messages.length).toBeLessThan(4)
      expect(remainingIds).not.toContain(userMsg2.id)
      expect(remainingIds).not.toContain(assistantMsg2.id)
      // The exchange before the revert point must survive; without this the
      // test would also pass if cleanup wiped the whole session.
      expect(remainingIds).toContain(userMsg1.id)
      expect(remainingIds).toContain(assistantMsg1.id)

      // Revert state should be cleared
      sessionInfo = await Session.get(sessionID)
      expect(sessionInfo.revert).toBeUndefined()

      // Clean up
      await Session.remove(sessionID)
    })
  })

  test("should properly clean up revert state before creating compaction message", async () => {
    await withProject(async (dir) => {
      const session = await Session.create({})
      const sessionID = session.id

      // Create initial messages
      const userMsg = await user(sessionID)
      await text(sessionID, userMsg.id, "Hello")
      const assistantMsg = await assistant(sessionID, userMsg.id, dir)
      await text(sessionID, assistantMsg.id, "Hi there!")

      // Revert the user message
      await SessionRevert.revert({
        sessionID,
        messageID: userMsg.id,
      })

      // Check that revert state is set
      let sessionInfo = await Session.get(sessionID)
      expect(sessionInfo.revert).toBeDefined()

      // Simulate what the compact endpoint does: cleanup revert before creating compaction
      await SessionRevert.cleanup(sessionInfo)

      // Verify revert state is cleared
      sessionInfo = await Session.get(sessionID)
      expect(sessionInfo.revert).toBeUndefined()

      // Verify messages are properly cleaned up
      const messages = await Session.messages({ sessionID })
      expect(messages.length).toBe(0) // All messages should be reverted

      // Clean up
      await Session.remove(sessionID)
    })
  })

  test("cleanup with partID removes parts from the revert point onward", async () => {
    await withProject(async () => {
      const session = await Session.create({})
      const sid = session.id

      const u1 = await user(sid)
      const p1 = await text(sid, u1.id, "first part")
      const p2 = await tool(sid, u1.id)
      await text(sid, u1.id, "third part")

      // Set revert state pointing at a specific part
      await Session.setRevert({
        sessionID: sid,
        revert: { messageID: u1.id, partID: p2.id },
        summary: { additions: 0, deletions: 0, files: 0 },
      })

      const info = await Session.get(sid)
      await SessionRevert.cleanup(info)

      const msgs = await Session.messages({ sessionID: sid })
      expect(msgs.length).toBe(1)
      // Only the first part should remain (before the revert partID)
      expect(msgs[0].parts.length).toBe(1)
      expect(msgs[0].parts[0].id).toBe(p1.id)

      const cleared = await Session.get(sid)
      expect(cleared.revert).toBeUndefined()
    })
  })

  test("cleanup removes messages after revert point but keeps earlier ones", async () => {
    await withProject(async (dir) => {
      const session = await Session.create({})
      const sid = session.id

      const u1 = await user(sid)
      await text(sid, u1.id, "hello")
      const a1 = await assistant(sid, u1.id, dir)
      await text(sid, a1.id, "hi back")

      const u2 = await user(sid)
      await text(sid, u2.id, "second question")
      const a2 = await assistant(sid, u2.id, dir)
      await text(sid, a2.id, "second answer")

      // Revert from u2 onward
      await Session.setRevert({
        sessionID: sid,
        revert: { messageID: u2.id },
        summary: { additions: 0, deletions: 0, files: 0 },
      })

      const info = await Session.get(sid)
      await SessionRevert.cleanup(info)

      const msgs = await Session.messages({ sessionID: sid })
      const ids = msgs.map((m) => m.info.id)
      expect(ids).toContain(u1.id)
      expect(ids).toContain(a1.id)
      expect(ids).not.toContain(u2.id)
      expect(ids).not.toContain(a2.id)
    })
  })

  test("cleanup is a no-op when session has no revert state", async () => {
    await withProject(async () => {
      const session = await Session.create({})
      const sid = session.id

      const u1 = await user(sid)
      await text(sid, u1.id, "hello")

      const info = await Session.get(sid)
      expect(info.revert).toBeUndefined()
      await SessionRevert.cleanup(info)

      const msgs = await Session.messages({ sessionID: sid })
      expect(msgs.length).toBe(1)
    })
  })

  test("restore messages in sequential order", async () => {
    await withProject(async (dir) => {
      await fs.writeFile(path.join(dir, "a.txt"), "a0")
      await fs.writeFile(path.join(dir, "b.txt"), "b0")
      await fs.writeFile(path.join(dir, "c.txt"), "c0")

      const session = await Session.create({})
      const sid = session.id

      // Each turn rewrites a different file.
      const first = await recordTurn(sid, dir, "a.txt", "a1")
      const second = await recordTurn(sid, dir, "b.txt", "b2")
      const third = await recordTurn(sid, dir, "c.txt", "c3")

      // Reverting to the first turn restores every file to its initial content.
      await SessionRevert.revert({
        sessionID: sid,
        messageID: first,
      })
      expect((await Session.get(sid)).revert?.messageID).toBe(first)
      expect(await read(dir, "a.txt")).toBe("a0")
      expect(await read(dir, "b.txt")).toBe("b0")
      expect(await read(dir, "c.txt")).toBe("c0")

      // Moving the revert point forward re-applies the turns before it.
      await SessionRevert.revert({
        sessionID: sid,
        messageID: second,
      })
      expect((await Session.get(sid)).revert?.messageID).toBe(second)
      expect(await read(dir, "a.txt")).toBe("a1")
      expect(await read(dir, "b.txt")).toBe("b0")
      expect(await read(dir, "c.txt")).toBe("c0")

      await SessionRevert.revert({
        sessionID: sid,
        messageID: third,
      })
      expect((await Session.get(sid)).revert?.messageID).toBe(third)
      expect(await read(dir, "a.txt")).toBe("a1")
      expect(await read(dir, "b.txt")).toBe("b2")
      expect(await read(dir, "c.txt")).toBe("c0")

      // Undoing the revert brings back the state after the last turn.
      await SessionRevert.unrevert({
        sessionID: sid,
      })
      expect((await Session.get(sid)).revert).toBeUndefined()
      expect(await read(dir, "a.txt")).toBe("a1")
      expect(await read(dir, "b.txt")).toBe("b2")
      expect(await read(dir, "c.txt")).toBe("c3")
    })
  })

  test("restore same file in sequential order", async () => {
    await withProject(async (dir) => {
      await fs.writeFile(path.join(dir, "a.txt"), "a0")

      const session = await Session.create({})
      const sid = session.id

      // Every turn rewrites the same file.
      const first = await recordTurn(sid, dir, "a.txt", "a1")
      const second = await recordTurn(sid, dir, "a.txt", "a2")
      const third = await recordTurn(sid, dir, "a.txt", "a3")
      expect(await read(dir, "a.txt")).toBe("a3")

      await SessionRevert.revert({
        sessionID: sid,
        messageID: first,
      })
      expect((await Session.get(sid)).revert?.messageID).toBe(first)
      expect(await read(dir, "a.txt")).toBe("a0")

      await SessionRevert.revert({
        sessionID: sid,
        messageID: second,
      })
      expect((await Session.get(sid)).revert?.messageID).toBe(second)
      expect(await read(dir, "a.txt")).toBe("a1")

      await SessionRevert.revert({
        sessionID: sid,
        messageID: third,
      })
      expect((await Session.get(sid)).revert?.messageID).toBe(third)
      expect(await read(dir, "a.txt")).toBe("a2")

      await SessionRevert.unrevert({
        sessionID: sid,
      })
      expect((await Session.get(sid)).revert).toBeUndefined()
      expect(await read(dir, "a.txt")).toBe("a3")
    })
  })
})
|