revert.ts 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. import z from "zod/v4"
  2. import { Identifier } from "../id/id"
  3. import { Snapshot } from "../snapshot"
  4. import { MessageV2 } from "./message-v2"
  5. import { Session } from "."
  6. import { Log } from "../util/log"
  7. import { splitWhen } from "remeda"
  8. import { Storage } from "../storage/storage"
  9. import { Bus } from "../bus"
  /**
   * Rewinding a session to an earlier point in its message history.
   *
   * - `revert` records a revert marker on the session and rolls back the
   *   patches that come after the chosen message/part.
   * - `unrevert` clears that marker again.
   * - `cleanup` permanently removes the messages/parts hidden by the marker.
   */
  export namespace SessionRevert {
    const log = Log.create({ service: "session.revert" })

    /** Input for {@link revert}: the session plus the message (and optionally the part) to rewind to. */
    export const RevertInput = z.object({
      sessionID: Identifier.schema("session"),
      messageID: Identifier.schema("message"),
      partID: Identifier.schema("part").optional(),
    })
    export type RevertInput = z.infer<typeof RevertInput>

    /**
     * Marks the session as reverted at the given message or part.
     *
     * Messages are walked in the order returned by `Session.messages`. The first
     * part that matches the input (any part of `messageID` when no `partID` is
     * given, otherwise the part whose id equals `partID`) fixes the revert point;
     * every `patch` part encountered after it is collected and handed to
     * `Snapshot.revert`. The matched part itself is not inspected for patches,
     * and a message without parts can never match because matching is per part.
     *
     * @param input Session, message and optional part identifying the revert point.
     * @returns The session returned by `Session.update` with `revert` set, or the
     *   session as loaded when nothing matched.
     */
    export async function revert(input: RevertInput) {
      const all = await Session.messages(input.sessionID)
      const session = await Session.get(input.sessionID)

      let lastUser: MessageV2.User | undefined
      let marker: Session.Info["revert"]
      const patches: Snapshot.Patch[] = []

      for (const msg of all) {
        if (msg.info.role === "user") lastUser = msg.info
        // parts of this message seen before the revert point
        const remaining = []
        for (const part of msg.parts) {
          if (marker) {
            // past the revert point: only patch parts matter from here on
            if (part.type === "patch") patches.push(part)
            continue
          }
          const wholeMessage = msg.info.id === input.messageID && !input.partID
          if (wholeMessage || part.id === input.partID) {
            // if no useful parts left in message, same as reverting whole message
            const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
            marker = {
              // a whole-message revert is anchored at the most recent user message, when there is one
              messageID: !partID && lastUser ? lastUser.id : msg.info.id,
              partID,
            }
          }
          remaining.push(part)
        }
      }

      if (!marker) return session

      // Reuse the snapshot of an earlier revert if one is still recorded; otherwise take a new one.
      marker.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
      await Snapshot.revert(patches)
      if (marker.snapshot) marker.diff = await Snapshot.diff(marker.snapshot)
      return Session.update(input.sessionID, (draft) => {
        draft.revert = marker
      })
    }

    /**
     * Clears the revert marker of a session.
     *
     * When the marker carries a snapshot it is passed to `Snapshot.restore`
     * before the marker is removed.
     *
     * @param input Object holding the id of the session to unrevert.
     * @returns The session as loaded when it has no revert marker, otherwise the
     *   session returned by `Session.update` after clearing it.
     */
    export async function unrevert(input: { sessionID: string }) {
      log.info("unreverting", input)
      const session = await Session.get(input.sessionID)
      if (!session.revert) return session
      if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)
      return Session.update(input.sessionID, (draft) => {
        draft.revert = undefined
      })
    }

    /**
     * Permanently drops everything hidden by the session's revert marker and
     * then clears the marker. Does nothing when the session is not reverted.
     *
     * - Whole-message revert (no `partID`): the message `revert.messageID` and
     *   every later message are removed.
     * - Partial revert (`partID` set): the message `revert.messageID` is kept,
     *   its parts from `partID` onward are removed, and every later message is
     *   removed.
     *
     * Each removal is followed by the matching `MessageV2.Event.Removed` /
     * `MessageV2.Event.PartRemoved` bus event.
     *
     * @param session Session whose reverted tail should be deleted.
     */
    export async function cleanup(session: Session.Info) {
      if (!session.revert) return
      const sessionID = session.id
      const { messageID, partID } = session.revert
      const msgs = await Session.messages(sessionID)

      // splitWhen puts the first matching element at the head of the second half,
      // so `tail[0]` (if any) is the message the revert marker points at.
      const [, tail] = splitWhen(msgs, (x) => x.info.id === messageID)

      // On a partial revert the target message must survive: only its parts from
      // `partID` onward go away. Previously it was deleted together with the tail
      // and the part trimming ran on the message before it, which never contains
      // `partID`.
      const target = partID ? tail[0] : undefined
      const remove = target ? tail.slice(1) : tail

      for (const msg of remove) {
        await Storage.remove(["message", sessionID, msg.info.id])
        await Bus.publish(MessageV2.Event.Removed, { sessionID: sessionID, messageID: msg.info.id })
      }

      if (target && partID) {
        const [preserveParts, removeParts] = splitWhen(target.parts, (x) => x.id === partID)
        // keep the in-memory message consistent with storage (the original code did the same)
        target.parts = preserveParts
        for (const part of removeParts) {
          await Storage.remove(["part", target.info.id, part.id])
          await Bus.publish(MessageV2.Event.PartRemoved, {
            sessionID: sessionID,
            messageID: target.info.id,
            partID: part.id,
          })
        }
      }

      await Session.update(sessionID, (draft) => {
        draft.revert = undefined
      })
    }
  }