revert.ts 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import z from "zod"
  2. import { Identifier } from "../id/id"
  3. import { Snapshot } from "../snapshot"
  4. import { MessageV2 } from "./message-v2"
  5. import { Session } from "."
  6. import { Log } from "../util/log"
  7. import { splitWhen } from "remeda"
  8. import { Storage } from "../storage/storage"
  9. import { Bus } from "../bus"
  10. import { SessionLock } from "./lock"
  /**
   * Reverting a session to an earlier point in its history, undoing that revert, and finally
   * committing it by deleting the reverted messages/parts.
   */
  export namespace SessionRevert {
    const log = Log.create({ service: "session.revert" })

    /** Identifies the point to revert to: a whole message, or a specific part within it. */
    export const RevertInput = z.object({
      sessionID: Identifier.schema("session"),
      messageID: Identifier.schema("message"),
      partID: Identifier.schema("part").optional(),
    })
    export type RevertInput = z.infer<typeof RevertInput>

    /** Messages (with their parts) as returned by `Session.messages`. */
    type SessionMessages = Awaited<ReturnType<typeof Session.messages>>
    /** A single part of one of those messages. */
    type SessionPart = SessionMessages[number]["parts"][number]
    /** Where history is cut: the message (and optionally the part within it) from which everything is reverted. */
    type RevertPoint = Pick<NonNullable<Session.Info["revert"]>, "messageID" | "partID">

    /**
     * Locate the cut point for `input`.
     *
     * The match is the first part whose id equals `input.partID`, or — when no part id is given —
     * the first part of message `input.messageID`. If no text/tool part precedes the match in its
     * message, nothing useful would survive, so the revert is widened to a whole-message revert
     * anchored at the most recent user message seen so far (or the matched message if there is none).
     *
     * @returns the cut point, or `undefined` when nothing matches (including a target message with no parts).
     */
    function findRevertPoint(all: SessionMessages, input: RevertInput): RevertPoint | undefined {
      let lastUser: MessageV2.User | undefined
      for (const msg of all) {
        if (msg.info.role === "user") lastUser = msg.info
        const remaining: SessionPart[] = []
        for (const part of msg.parts) {
          const matches = (msg.info.id === input.messageID && !input.partID) || part.id === input.partID
          if (matches) {
            // if no useful parts left in message, same as reverting whole message
            const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
            return {
              messageID: !partID && lastUser ? lastUser.id : msg.info.id,
              partID,
            }
          }
          remaining.push(part)
        }
      }
      return undefined
    }

    /**
     * Collect, in message order, every patch part that `cleanup` will delete for `point`:
     * all parts of `point.messageID` and later messages, except that for a part-level revert only
     * the target message's parts from `point.partID` onward are included. Keeping this span
     * identical to what `cleanup` removes ensures no deleted step leaves its file changes behind.
     */
    function collectPatches(all: SessionMessages, point: RevertPoint): Snapshot.Patch[] {
      const patches: Snapshot.Patch[] = []
      const start = all.findIndex((msg) => msg.info.id === point.messageID)
      if (start === -1) return patches
      for (const msg of all.slice(start)) {
        let parts: SessionPart[] = msg.parts
        if (point.partID && msg.info.id === point.messageID) {
          const from = parts.findIndex((part) => part.id === point.partID)
          parts = from === -1 ? [] : parts.slice(from)
        }
        for (const part of parts) {
          if (part.type === "patch") patches.push(part)
        }
      }
      return patches
    }

    /**
     * Revert a session to the point described by `input`.
     *
     * Records the revert point on the session, restores workspace files via `Snapshot.revert`
     * for every patch in the reverted span, and stores a snapshot + diff so `unrevert` can
     * restore the files. Messages are NOT deleted here; `cleanup` does that.
     *
     * @returns the updated session, or the unchanged session when nothing matches `input`.
     */
    export async function revert(input: RevertInput) {
      SessionLock.assertUnlocked(input.sessionID)
      using _ = SessionLock.acquire({
        sessionID: input.sessionID,
      })
      const all = await Session.messages({ sessionID: input.sessionID })
      const session = await Session.get(input.sessionID)

      const point = findRevertPoint(all, input)
      if (!point) return session

      const patches = collectPatches(all, point)
      const revert: NonNullable<Session.Info["revert"]> = { ...point }
      // Reuse the snapshot of a still-pending earlier revert so `unrevert` goes back to the
      // state captured before that first revert, not an intermediate one.
      revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
      await Snapshot.revert(patches)
      if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot)
      // `return await` keeps the lock held until the update settles; a bare `return` would
      // dispose the `using` lock while the update is still in flight.
      return await Session.update(input.sessionID, (draft) => {
        draft.revert = revert
      })
    }

    /**
     * Undo a pending revert: restore the snapshot recorded by `revert` (if any) and clear
     * the session's revert marker.
     *
     * @returns the updated session, or the unchanged session when no revert is pending.
     */
    export async function unrevert(input: { sessionID: string }) {
      log.info("unreverting", input)
      SessionLock.assertUnlocked(input.sessionID)
      using _ = SessionLock.acquire({
        sessionID: input.sessionID,
      })
      const session = await Session.get(input.sessionID)
      if (!session.revert) return session
      if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)
      // Awaited so the lock is held until the update completes.
      return await Session.update(input.sessionID, (draft) => {
        draft.revert = undefined
      })
    }

    /**
     * Commit a pending revert by deleting everything from the revert point onward, then clear
     * the revert marker.
     *
     * Whole-message revert: message `revert.messageID` and all later messages are removed.
     * Part-level revert: the target message is kept and only its parts from `revert.partID`
     * onward are removed, together with all later messages.
     */
    export async function cleanup(session: Session.Info) {
      if (!session.revert) return
      const sessionID = session.id
      const { messageID, partID } = session.revert
      const msgs = await Session.messages({ sessionID })
      const [, tail] = splitWhen(msgs, (x) => x.info.id === messageID)
      // For a part-level revert the target message (first in `tail`) survives and is trimmed below.
      const target = partID ? tail[0] : undefined
      const removedMessages = target ? tail.slice(1) : tail
      for (const msg of removedMessages) {
        // NOTE(review): only the message record is removed; its parts (stored under
        // ["part", messageID, partID] as seen below) are not — confirm they are cleaned elsewhere.
        await Storage.remove(["message", sessionID, msg.info.id])
        await Bus.publish(MessageV2.Event.Removed, { sessionID, messageID: msg.info.id })
      }
      if (target && partID) {
        const [, removeParts] = splitWhen(target.parts, (x) => x.id === partID)
        for (const part of removeParts) {
          await Storage.remove(["part", target.info.id, part.id])
          await Bus.publish(MessageV2.Event.PartRemoved, {
            sessionID,
            messageID: target.info.id,
            partID: part.id,
          })
        }
      }
      await Session.update(sessionID, (draft) => {
        draft.revert = undefined
      })
    }
  }