// revert-compact.test.ts
  1. import { describe, expect, test, beforeEach, afterEach } from "bun:test"
  2. import fs from "fs/promises"
  3. import path from "path"
  4. import { Session } from "../../src/session"
  5. import { ModelID, ProviderID } from "../../src/provider/schema"
  6. import { SessionRevert } from "../../src/session/revert"
  7. import { SessionCompaction } from "../../src/session/compaction"
  8. import { MessageV2 } from "../../src/session/message-v2"
  9. import { Snapshot } from "../../src/snapshot"
  10. import { Log } from "../../src/util/log"
  11. import { Instance } from "../../src/project/instance"
  12. import { MessageID, PartID } from "../../src/session/schema"
  13. import { tmpdir } from "../fixture/fixture"
  14. Log.init({ print: false })
  15. function user(sessionID: string, agent = "default") {
  16. return Session.updateMessage({
  17. id: MessageID.ascending(),
  18. role: "user" as const,
  19. sessionID: sessionID as any,
  20. agent,
  21. model: { providerID: ProviderID.make("openai"), modelID: ModelID.make("gpt-4") },
  22. time: { created: Date.now() },
  23. })
  24. }
  25. function assistant(sessionID: string, parentID: string, dir: string) {
  26. return Session.updateMessage({
  27. id: MessageID.ascending(),
  28. role: "assistant" as const,
  29. sessionID: sessionID as any,
  30. mode: "default",
  31. agent: "default",
  32. path: { cwd: dir, root: dir },
  33. cost: 0,
  34. tokens: { output: 0, input: 0, reasoning: 0, cache: { read: 0, write: 0 } },
  35. modelID: ModelID.make("gpt-4"),
  36. providerID: ProviderID.make("openai"),
  37. parentID: parentID as any,
  38. time: { created: Date.now() },
  39. finish: "end_turn",
  40. })
  41. }
  42. function text(sessionID: string, messageID: string, content: string) {
  43. return Session.updatePart({
  44. id: PartID.ascending(),
  45. messageID: messageID as any,
  46. sessionID: sessionID as any,
  47. type: "text" as const,
  48. text: content,
  49. })
  50. }
  51. function tool(sessionID: string, messageID: string) {
  52. return Session.updatePart({
  53. id: PartID.ascending(),
  54. messageID: messageID as any,
  55. sessionID: sessionID as any,
  56. type: "tool" as const,
  57. tool: "bash",
  58. callID: "call-1",
  59. state: {
  60. status: "completed" as const,
  61. input: {},
  62. output: "done",
  63. title: "",
  64. metadata: {},
  65. time: { start: 0, end: 1 },
  66. },
  67. })
  68. }
  69. const tokens = {
  70. input: 0,
  71. output: 0,
  72. reasoning: 0,
  73. cache: { read: 0, write: 0 },
  74. }
  75. describe("revert + compact workflow", () => {
  76. test("should properly handle compact command after revert", async () => {
  77. await using tmp = await tmpdir({ git: true })
  78. await Instance.provide({
  79. directory: tmp.path,
  80. fn: async () => {
  81. // Create a session
  82. const session = await Session.create({})
  83. const sessionID = session.id
  84. // Create a user message
  85. const userMsg1 = await Session.updateMessage({
  86. id: MessageID.ascending(),
  87. role: "user",
  88. sessionID,
  89. agent: "default",
  90. model: {
  91. providerID: ProviderID.make("openai"),
  92. modelID: ModelID.make("gpt-4"),
  93. },
  94. time: {
  95. created: Date.now(),
  96. },
  97. })
  98. // Add a text part to the user message
  99. await Session.updatePart({
  100. id: PartID.ascending(),
  101. messageID: userMsg1.id,
  102. sessionID,
  103. type: "text",
  104. text: "Hello, please help me",
  105. })
  106. // Create an assistant response message
  107. const assistantMsg1: MessageV2.Assistant = {
  108. id: MessageID.ascending(),
  109. role: "assistant",
  110. sessionID,
  111. mode: "default",
  112. agent: "default",
  113. path: {
  114. cwd: tmp.path,
  115. root: tmp.path,
  116. },
  117. cost: 0,
  118. tokens: {
  119. output: 0,
  120. input: 0,
  121. reasoning: 0,
  122. cache: { read: 0, write: 0 },
  123. },
  124. modelID: ModelID.make("gpt-4"),
  125. providerID: ProviderID.make("openai"),
  126. parentID: userMsg1.id,
  127. time: {
  128. created: Date.now(),
  129. },
  130. finish: "end_turn",
  131. }
  132. await Session.updateMessage(assistantMsg1)
  133. // Add a text part to the assistant message
  134. await Session.updatePart({
  135. id: PartID.ascending(),
  136. messageID: assistantMsg1.id,
  137. sessionID,
  138. type: "text",
  139. text: "Sure, I'll help you!",
  140. })
  141. // Create another user message
  142. const userMsg2 = await Session.updateMessage({
  143. id: MessageID.ascending(),
  144. role: "user",
  145. sessionID,
  146. agent: "default",
  147. model: {
  148. providerID: ProviderID.make("openai"),
  149. modelID: ModelID.make("gpt-4"),
  150. },
  151. time: {
  152. created: Date.now(),
  153. },
  154. })
  155. await Session.updatePart({
  156. id: PartID.ascending(),
  157. messageID: userMsg2.id,
  158. sessionID,
  159. type: "text",
  160. text: "What's the capital of France?",
  161. })
  162. // Create another assistant response
  163. const assistantMsg2: MessageV2.Assistant = {
  164. id: MessageID.ascending(),
  165. role: "assistant",
  166. sessionID,
  167. mode: "default",
  168. agent: "default",
  169. path: {
  170. cwd: tmp.path,
  171. root: tmp.path,
  172. },
  173. cost: 0,
  174. tokens: {
  175. output: 0,
  176. input: 0,
  177. reasoning: 0,
  178. cache: { read: 0, write: 0 },
  179. },
  180. modelID: ModelID.make("gpt-4"),
  181. providerID: ProviderID.make("openai"),
  182. parentID: userMsg2.id,
  183. time: {
  184. created: Date.now(),
  185. },
  186. finish: "end_turn",
  187. }
  188. await Session.updateMessage(assistantMsg2)
  189. await Session.updatePart({
  190. id: PartID.ascending(),
  191. messageID: assistantMsg2.id,
  192. sessionID,
  193. type: "text",
  194. text: "The capital of France is Paris.",
  195. })
  196. // Verify messages before revert
  197. let messages = await Session.messages({ sessionID })
  198. expect(messages.length).toBe(4) // 2 user + 2 assistant messages
  199. const messageIds = messages.map((m) => m.info.id)
  200. expect(messageIds).toContain(userMsg1.id)
  201. expect(messageIds).toContain(userMsg2.id)
  202. expect(messageIds).toContain(assistantMsg1.id)
  203. expect(messageIds).toContain(assistantMsg2.id)
  204. // Revert the last user message (userMsg2)
  205. await SessionRevert.revert({
  206. sessionID,
  207. messageID: userMsg2.id,
  208. })
  209. // Check that revert state is set
  210. let sessionInfo = await Session.get(sessionID)
  211. expect(sessionInfo.revert).toBeDefined()
  212. const revertMessageID = sessionInfo.revert?.messageID
  213. expect(revertMessageID).toBeDefined()
  214. // Messages should still be in the list (not removed yet, just marked for revert)
  215. messages = await Session.messages({ sessionID })
  216. expect(messages.length).toBe(4)
  217. // Now clean up the revert state (this is what the compact endpoint should do)
  218. await SessionRevert.cleanup(sessionInfo)
  219. // After cleanup, the reverted messages (those after the revert point) should be removed
  220. messages = await Session.messages({ sessionID })
  221. const remainingIds = messages.map((m) => m.info.id)
  222. // The revert point is somewhere in the message chain, so we should have fewer messages
  223. expect(messages.length).toBeLessThan(4)
  224. // userMsg2 and assistantMsg2 should be removed (they come after the revert point)
  225. expect(remainingIds).not.toContain(userMsg2.id)
  226. expect(remainingIds).not.toContain(assistantMsg2.id)
  227. // Revert state should be cleared
  228. sessionInfo = await Session.get(sessionID)
  229. expect(sessionInfo.revert).toBeUndefined()
  230. // Clean up
  231. await Session.remove(sessionID)
  232. },
  233. })
  234. })
  235. test("should properly clean up revert state before creating compaction message", async () => {
  236. await using tmp = await tmpdir({ git: true })
  237. await Instance.provide({
  238. directory: tmp.path,
  239. fn: async () => {
  240. // Create a session
  241. const session = await Session.create({})
  242. const sessionID = session.id
  243. // Create initial messages
  244. const userMsg = await Session.updateMessage({
  245. id: MessageID.ascending(),
  246. role: "user",
  247. sessionID,
  248. agent: "default",
  249. model: {
  250. providerID: ProviderID.make("openai"),
  251. modelID: ModelID.make("gpt-4"),
  252. },
  253. time: {
  254. created: Date.now(),
  255. },
  256. })
  257. await Session.updatePart({
  258. id: PartID.ascending(),
  259. messageID: userMsg.id,
  260. sessionID,
  261. type: "text",
  262. text: "Hello",
  263. })
  264. const assistantMsg: MessageV2.Assistant = {
  265. id: MessageID.ascending(),
  266. role: "assistant",
  267. sessionID,
  268. mode: "default",
  269. agent: "default",
  270. path: {
  271. cwd: tmp.path,
  272. root: tmp.path,
  273. },
  274. cost: 0,
  275. tokens: {
  276. output: 0,
  277. input: 0,
  278. reasoning: 0,
  279. cache: { read: 0, write: 0 },
  280. },
  281. modelID: ModelID.make("gpt-4"),
  282. providerID: ProviderID.make("openai"),
  283. parentID: userMsg.id,
  284. time: {
  285. created: Date.now(),
  286. },
  287. finish: "end_turn",
  288. }
  289. await Session.updateMessage(assistantMsg)
  290. await Session.updatePart({
  291. id: PartID.ascending(),
  292. messageID: assistantMsg.id,
  293. sessionID,
  294. type: "text",
  295. text: "Hi there!",
  296. })
  297. // Revert the user message
  298. await SessionRevert.revert({
  299. sessionID,
  300. messageID: userMsg.id,
  301. })
  302. // Check that revert state is set
  303. let sessionInfo = await Session.get(sessionID)
  304. expect(sessionInfo.revert).toBeDefined()
  305. // Simulate what the compact endpoint does: cleanup revert before creating compaction
  306. await SessionRevert.cleanup(sessionInfo)
  307. // Verify revert state is cleared
  308. sessionInfo = await Session.get(sessionID)
  309. expect(sessionInfo.revert).toBeUndefined()
  310. // Verify messages are properly cleaned up
  311. const messages = await Session.messages({ sessionID })
  312. expect(messages.length).toBe(0) // All messages should be reverted
  313. // Clean up
  314. await Session.remove(sessionID)
  315. },
  316. })
  317. })
  318. test("cleanup with partID removes parts from the revert point onward", async () => {
  319. await using tmp = await tmpdir({ git: true })
  320. await Instance.provide({
  321. directory: tmp.path,
  322. fn: async () => {
  323. const session = await Session.create({})
  324. const sid = session.id
  325. const u1 = await user(sid)
  326. const p1 = await text(sid, u1.id, "first part")
  327. const p2 = await tool(sid, u1.id)
  328. const p3 = await text(sid, u1.id, "third part")
  329. // Set revert state pointing at a specific part
  330. await Session.setRevert({
  331. sessionID: sid,
  332. revert: { messageID: u1.id, partID: p2.id },
  333. summary: { additions: 0, deletions: 0, files: 0 },
  334. })
  335. const info = await Session.get(sid)
  336. await SessionRevert.cleanup(info)
  337. const msgs = await Session.messages({ sessionID: sid })
  338. expect(msgs.length).toBe(1)
  339. // Only the first part should remain (before the revert partID)
  340. expect(msgs[0].parts.length).toBe(1)
  341. expect(msgs[0].parts[0].id).toBe(p1.id)
  342. const cleared = await Session.get(sid)
  343. expect(cleared.revert).toBeUndefined()
  344. },
  345. })
  346. })
  347. test("cleanup removes messages after revert point but keeps earlier ones", async () => {
  348. await using tmp = await tmpdir({ git: true })
  349. await Instance.provide({
  350. directory: tmp.path,
  351. fn: async () => {
  352. const session = await Session.create({})
  353. const sid = session.id
  354. const u1 = await user(sid)
  355. await text(sid, u1.id, "hello")
  356. const a1 = await assistant(sid, u1.id, tmp.path)
  357. await text(sid, a1.id, "hi back")
  358. const u2 = await user(sid)
  359. await text(sid, u2.id, "second question")
  360. const a2 = await assistant(sid, u2.id, tmp.path)
  361. await text(sid, a2.id, "second answer")
  362. // Revert from u2 onward
  363. await Session.setRevert({
  364. sessionID: sid,
  365. revert: { messageID: u2.id },
  366. summary: { additions: 0, deletions: 0, files: 0 },
  367. })
  368. const info = await Session.get(sid)
  369. await SessionRevert.cleanup(info)
  370. const msgs = await Session.messages({ sessionID: sid })
  371. const ids = msgs.map((m) => m.info.id)
  372. expect(ids).toContain(u1.id)
  373. expect(ids).toContain(a1.id)
  374. expect(ids).not.toContain(u2.id)
  375. expect(ids).not.toContain(a2.id)
  376. },
  377. })
  378. })
  379. test("cleanup is a no-op when session has no revert state", async () => {
  380. await using tmp = await tmpdir({ git: true })
  381. await Instance.provide({
  382. directory: tmp.path,
  383. fn: async () => {
  384. const session = await Session.create({})
  385. const sid = session.id
  386. const u1 = await user(sid)
  387. await text(sid, u1.id, "hello")
  388. const info = await Session.get(sid)
  389. expect(info.revert).toBeUndefined()
  390. await SessionRevert.cleanup(info)
  391. const msgs = await Session.messages({ sessionID: sid })
  392. expect(msgs.length).toBe(1)
  393. },
  394. })
  395. })
  396. test("restore messages in sequential order", async () => {
  397. await using tmp = await tmpdir({ git: true })
  398. await Instance.provide({
  399. directory: tmp.path,
  400. fn: async () => {
  401. await fs.writeFile(path.join(tmp.path, "a.txt"), "a0")
  402. await fs.writeFile(path.join(tmp.path, "b.txt"), "b0")
  403. await fs.writeFile(path.join(tmp.path, "c.txt"), "c0")
  404. const session = await Session.create({})
  405. const sid = session.id
  406. const turn = async (file: string, next: string) => {
  407. const u = await user(sid)
  408. await text(sid, u.id, `${file}:${next}`)
  409. const a = await assistant(sid, u.id, tmp.path)
  410. const before = await Snapshot.track()
  411. if (!before) throw new Error("expected snapshot")
  412. await fs.writeFile(path.join(tmp.path, file), next)
  413. const after = await Snapshot.track()
  414. if (!after) throw new Error("expected snapshot")
  415. const patch = await Snapshot.patch(before)
  416. await Session.updatePart({
  417. id: PartID.ascending(),
  418. messageID: a.id,
  419. sessionID: sid,
  420. type: "step-start",
  421. snapshot: before,
  422. })
  423. await Session.updatePart({
  424. id: PartID.ascending(),
  425. messageID: a.id,
  426. sessionID: sid,
  427. type: "step-finish",
  428. reason: "stop",
  429. snapshot: after,
  430. cost: 0,
  431. tokens,
  432. })
  433. await Session.updatePart({
  434. id: PartID.ascending(),
  435. messageID: a.id,
  436. sessionID: sid,
  437. type: "patch",
  438. hash: patch.hash,
  439. files: patch.files,
  440. })
  441. return u.id
  442. }
  443. const first = await turn("a.txt", "a1")
  444. const second = await turn("b.txt", "b2")
  445. const third = await turn("c.txt", "c3")
  446. await SessionRevert.revert({
  447. sessionID: sid,
  448. messageID: first,
  449. })
  450. expect((await Session.get(sid)).revert?.messageID).toBe(first)
  451. expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a0")
  452. expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b0")
  453. expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0")
  454. await SessionRevert.revert({
  455. sessionID: sid,
  456. messageID: second,
  457. })
  458. expect((await Session.get(sid)).revert?.messageID).toBe(second)
  459. expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
  460. expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b0")
  461. expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0")
  462. await SessionRevert.revert({
  463. sessionID: sid,
  464. messageID: third,
  465. })
  466. expect((await Session.get(sid)).revert?.messageID).toBe(third)
  467. expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
  468. expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b2")
  469. expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c0")
  470. await SessionRevert.unrevert({
  471. sessionID: sid,
  472. })
  473. expect((await Session.get(sid)).revert).toBeUndefined()
  474. expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
  475. expect(await fs.readFile(path.join(tmp.path, "b.txt"), "utf-8")).toBe("b2")
  476. expect(await fs.readFile(path.join(tmp.path, "c.txt"), "utf-8")).toBe("c3")
  477. },
  478. })
  479. })
  480. test("restore same file in sequential order", async () => {
  481. await using tmp = await tmpdir({ git: true })
  482. await Instance.provide({
  483. directory: tmp.path,
  484. fn: async () => {
  485. await fs.writeFile(path.join(tmp.path, "a.txt"), "a0")
  486. const session = await Session.create({})
  487. const sid = session.id
  488. const turn = async (next: string) => {
  489. const u = await user(sid)
  490. await text(sid, u.id, `a.txt:${next}`)
  491. const a = await assistant(sid, u.id, tmp.path)
  492. const before = await Snapshot.track()
  493. if (!before) throw new Error("expected snapshot")
  494. await fs.writeFile(path.join(tmp.path, "a.txt"), next)
  495. const after = await Snapshot.track()
  496. if (!after) throw new Error("expected snapshot")
  497. const patch = await Snapshot.patch(before)
  498. await Session.updatePart({
  499. id: PartID.ascending(),
  500. messageID: a.id,
  501. sessionID: sid,
  502. type: "step-start",
  503. snapshot: before,
  504. })
  505. await Session.updatePart({
  506. id: PartID.ascending(),
  507. messageID: a.id,
  508. sessionID: sid,
  509. type: "step-finish",
  510. reason: "stop",
  511. snapshot: after,
  512. cost: 0,
  513. tokens,
  514. })
  515. await Session.updatePart({
  516. id: PartID.ascending(),
  517. messageID: a.id,
  518. sessionID: sid,
  519. type: "patch",
  520. hash: patch.hash,
  521. files: patch.files,
  522. })
  523. return u.id
  524. }
  525. const first = await turn("a1")
  526. const second = await turn("a2")
  527. const third = await turn("a3")
  528. expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a3")
  529. await SessionRevert.revert({
  530. sessionID: sid,
  531. messageID: first,
  532. })
  533. expect((await Session.get(sid)).revert?.messageID).toBe(first)
  534. expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a0")
  535. await SessionRevert.revert({
  536. sessionID: sid,
  537. messageID: second,
  538. })
  539. expect((await Session.get(sid)).revert?.messageID).toBe(second)
  540. expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a1")
  541. await SessionRevert.revert({
  542. sessionID: sid,
  543. messageID: third,
  544. })
  545. expect((await Session.get(sid)).revert?.messageID).toBe(third)
  546. expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a2")
  547. await SessionRevert.unrevert({
  548. sessionID: sid,
  549. })
  550. expect((await Session.get(sid)).revert).toBeUndefined()
  551. expect(await fs.readFile(path.join(tmp.path, "a.txt"), "utf-8")).toBe("a3")
  552. },
  553. })
  554. })
  555. })