Просмотр исходного кода

refactor(session): effectify SessionRevert service (#20143)

Kit Langton 2 недель назад
Родитель
Сommit
3fc0367b93
2 измененных файлов с 287 добавлено и 105 удалено
  1. 142 104
      packages/opencode/src/session/revert.ts
  2. 145 1
      packages/opencode/test/session/revert-compact.test.ts

+ 142 - 104
packages/opencode/src/session/revert.ts

@@ -1,12 +1,14 @@
 import z from "zod"
-import { SessionID, MessageID, PartID } from "./schema"
+import { Effect, Layer, ServiceMap } from "effect"
+import { makeRuntime } from "@/effect/run-service"
+import { Bus } from "../bus"
 import { Snapshot } from "../snapshot"
-import { MessageV2 } from "./message-v2"
-import { Session } from "."
-import { Log } from "../util/log"
-import { SyncEvent } from "../sync"
 import { Storage } from "@/storage/storage"
-import { Bus } from "../bus"
+import { SyncEvent } from "../sync"
+import { Log } from "../util/log"
+import { Session } from "."
+import { MessageV2 } from "./message-v2"
+import { SessionID, MessageID, PartID } from "./schema"
 import { SessionPrompt } from "./prompt"
 import { SessionSummary } from "./summary"
 
@@ -20,116 +22,152 @@ export namespace SessionRevert {
   })
   export type RevertInput = z.infer<typeof RevertInput>
 
-  export async function revert(input: RevertInput) {
-    await SessionPrompt.assertNotBusy(input.sessionID)
-    const all = await Session.messages({ sessionID: input.sessionID })
-    let lastUser: MessageV2.User | undefined
-    const session = await Session.get(input.sessionID)
-
-    let revert: Session.Info["revert"]
-    const patches: Snapshot.Patch[] = []
-    for (const msg of all) {
-      if (msg.info.role === "user") lastUser = msg.info
-      const remaining = []
-      for (const part of msg.parts) {
-        if (revert) {
-          if (part.type === "patch") {
-            patches.push(part)
+  export interface Interface {
+    readonly revert: (input: RevertInput) => Effect.Effect<Session.Info>
+    readonly unrevert: (input: { sessionID: SessionID }) => Effect.Effect<Session.Info>
+    readonly cleanup: (session: Session.Info) => Effect.Effect<void>
+  }
+
+  export class Service extends ServiceMap.Service<Service, Interface>()("@opencode/SessionRevert") {}
+
+  export const layer = Layer.effect(
+    Service,
+    Effect.gen(function* () {
+      const sessions = yield* Session.Service
+      const snap = yield* Snapshot.Service
+      const storage = yield* Storage.Service
+      const bus = yield* Bus.Service
+
+      const revert = Effect.fn("SessionRevert.revert")(function* (input: RevertInput) {
+        yield* Effect.promise(() => SessionPrompt.assertNotBusy(input.sessionID))
+        const all = yield* sessions.messages({ sessionID: input.sessionID })
+        let lastUser: MessageV2.User | undefined
+        const session = yield* sessions.get(input.sessionID)
+
+        let rev: Session.Info["revert"]
+        const patches: Snapshot.Patch[] = []
+        for (const msg of all) {
+          if (msg.info.role === "user") lastUser = msg.info
+          const remaining = []
+          for (const part of msg.parts) {
+            if (rev) {
+              if (part.type === "patch") patches.push(part)
+              continue
+            }
+
+            if (!rev) {
+              if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
+                const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
+                rev = {
+                  messageID: !partID && lastUser ? lastUser.id : msg.info.id,
+                  partID,
+                }
+              }
+              remaining.push(part)
+            }
           }
-          continue
         }
 
-        if (!revert) {
-          if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
-            // if no useful parts left in message, same as reverting whole message
-            const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
-            revert = {
-              messageID: !partID && lastUser ? lastUser.id : msg.info.id,
-              partID,
+        if (!rev) return session
+
+        rev.snapshot = session.revert?.snapshot ?? (yield* snap.track())
+        yield* snap.revert(patches)
+        if (rev.snapshot) rev.diff = yield* snap.diff(rev.snapshot as string)
+        const range = all.filter((msg) => msg.info.id >= rev!.messageID)
+        const diffs = yield* Effect.promise(() => SessionSummary.computeDiff({ messages: range }))
+        yield* storage.write(["session_diff", input.sessionID], diffs).pipe(Effect.ignore)
+        yield* bus.publish(Session.Event.Diff, { sessionID: input.sessionID, diff: diffs })
+        yield* sessions.setRevert({
+          sessionID: input.sessionID,
+          revert: rev,
+          summary: {
+            additions: diffs.reduce((sum, x) => sum + x.additions, 0),
+            deletions: diffs.reduce((sum, x) => sum + x.deletions, 0),
+            files: diffs.length,
+          },
+        })
+        return yield* sessions.get(input.sessionID)
+      })
+
+      const unrevert = Effect.fn("SessionRevert.unrevert")(function* (input: { sessionID: SessionID }) {
+        log.info("unreverting", input)
+        yield* Effect.promise(() => SessionPrompt.assertNotBusy(input.sessionID))
+        const session = yield* sessions.get(input.sessionID)
+        if (!session.revert) return session
+        if (session.revert.snapshot) yield* snap.restore(session.revert!.snapshot!)
+        yield* sessions.clearRevert(input.sessionID)
+        return yield* sessions.get(input.sessionID)
+      })
+
+      const cleanup = Effect.fn("SessionRevert.cleanup")(function* (session: Session.Info) {
+        if (!session.revert) return
+        const sessionID = session.id
+        const msgs = yield* sessions.messages({ sessionID })
+        const messageID = session.revert.messageID
+        const remove = [] as MessageV2.WithParts[]
+        let target: MessageV2.WithParts | undefined
+        for (const msg of msgs) {
+          if (msg.info.id < messageID) continue
+          if (msg.info.id > messageID) {
+            remove.push(msg)
+            continue
+          }
+          if (session.revert.partID) {
+            target = msg
+            continue
+          }
+          remove.push(msg)
+        }
+        for (const msg of remove) {
+          SyncEvent.run(MessageV2.Event.Removed, {
+            sessionID,
+            messageID: msg.info.id,
+          })
+        }
+        if (session.revert.partID && target) {
+          const partID = session.revert.partID
+          const idx = target.parts.findIndex((part) => part.id === partID)
+          if (idx >= 0) {
+            const removeParts = target.parts.slice(idx)
+            target.parts = target.parts.slice(0, idx)
+            for (const part of removeParts) {
+              SyncEvent.run(MessageV2.Event.PartRemoved, {
+                sessionID,
+                messageID: target.info.id,
+                partID: part.id,
+              })
             }
           }
-          remaining.push(part)
         }
-      }
-    }
-
-    if (revert) {
-      const session = await Session.get(input.sessionID)
-      revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
-      await Snapshot.revert(patches)
-      if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot)
-      const rangeMessages = all.filter((msg) => msg.info.id >= revert!.messageID)
-      const diffs = await SessionSummary.computeDiff({ messages: rangeMessages })
-      await Storage.write(["session_diff", input.sessionID], diffs)
-      Bus.publish(Session.Event.Diff, {
-        sessionID: input.sessionID,
-        diff: diffs,
-      })
-      return Session.setRevert({
-        sessionID: input.sessionID,
-        revert,
-        summary: {
-          additions: diffs.reduce((sum, x) => sum + x.additions, 0),
-          deletions: diffs.reduce((sum, x) => sum + x.deletions, 0),
-          files: diffs.length,
-        },
+        yield* sessions.clearRevert(sessionID)
       })
-    }
-    return session
+
+      return Service.of({ revert, unrevert, cleanup })
+    }),
+  )
+
+  export const defaultLayer = Layer.unwrap(
+    Effect.sync(() =>
+      layer.pipe(
+        Layer.provide(Session.defaultLayer),
+        Layer.provide(Snapshot.defaultLayer),
+        Layer.provide(Storage.defaultLayer),
+        Layer.provide(Bus.layer),
+      ),
+    ),
+  )
+
+  const { runPromise } = makeRuntime(Service, defaultLayer)
+
+  export async function revert(input: RevertInput) {
+    return runPromise((svc) => svc.revert(input))
   }
 
   export async function unrevert(input: { sessionID: SessionID }) {
-    log.info("unreverting", input)
-    await SessionPrompt.assertNotBusy(input.sessionID)
-    const session = await Session.get(input.sessionID)
-    if (!session.revert) return session
-    if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)
-    return Session.clearRevert(input.sessionID)
+    return runPromise((svc) => svc.unrevert(input))
   }
 
   export async function cleanup(session: Session.Info) {
-    if (!session.revert) return
-    const sessionID = session.id
-    const msgs = await Session.messages({ sessionID })
-    const messageID = session.revert.messageID
-    const remove = [] as MessageV2.WithParts[]
-    let target: MessageV2.WithParts | undefined
-    for (const msg of msgs) {
-      if (msg.info.id < messageID) {
-        continue
-      }
-      if (msg.info.id > messageID) {
-        remove.push(msg)
-        continue
-      }
-      if (session.revert.partID) {
-        target = msg
-        continue
-      }
-      remove.push(msg)
-    }
-    for (const msg of remove) {
-      SyncEvent.run(MessageV2.Event.Removed, {
-        sessionID: sessionID,
-        messageID: msg.info.id,
-      })
-    }
-    if (session.revert.partID && target) {
-      const partID = session.revert.partID
-      const removeStart = target.parts.findIndex((part) => part.id === partID)
-      if (removeStart >= 0) {
-        const preserveParts = target.parts.slice(0, removeStart)
-        const removeParts = target.parts.slice(removeStart)
-        target.parts = preserveParts
-        for (const part of removeParts) {
-          SyncEvent.run(MessageV2.Event.PartRemoved, {
-            sessionID: sessionID,
-            messageID: target.info.id,
-            partID: part.id,
-          })
-        }
-      }
-    }
-    await Session.clearRevert(sessionID)
+    return runPromise((svc) => svc.cleanup(session))
   }
 }

+ 145 - 1
packages/opencode/test/session/revert-compact.test.ts

@@ -10,9 +10,59 @@ import { Instance } from "../../src/project/instance"
 import { MessageID, PartID } from "../../src/session/schema"
 import { tmpdir } from "../fixture/fixture"
 
-const projectRoot = path.join(__dirname, "../..")
 Log.init({ print: false })
 
+function user(sessionID: string, agent = "default") {
+  return Session.updateMessage({
+    id: MessageID.ascending(),
+    role: "user" as const,
+    sessionID: sessionID as any,
+    agent,
+    model: { providerID: ProviderID.make("openai"), modelID: ModelID.make("gpt-4") },
+    time: { created: Date.now() },
+  })
+}
+
+function assistant(sessionID: string, parentID: string, dir: string) {
+  return Session.updateMessage({
+    id: MessageID.ascending(),
+    role: "assistant" as const,
+    sessionID: sessionID as any,
+    mode: "default",
+    agent: "default",
+    path: { cwd: dir, root: dir },
+    cost: 0,
+    tokens: { output: 0, input: 0, reasoning: 0, cache: { read: 0, write: 0 } },
+    modelID: ModelID.make("gpt-4"),
+    providerID: ProviderID.make("openai"),
+    parentID: parentID as any,
+    time: { created: Date.now() },
+    finish: "end_turn",
+  })
+}
+
+function text(sessionID: string, messageID: string, content: string) {
+  return Session.updatePart({
+    id: PartID.ascending(),
+    messageID: messageID as any,
+    sessionID: sessionID as any,
+    type: "text" as const,
+    text: content,
+  })
+}
+
+function tool(sessionID: string, messageID: string) {
+  return Session.updatePart({
+    id: PartID.ascending(),
+    messageID: messageID as any,
+    sessionID: sessionID as any,
+    type: "tool" as const,
+    tool: "bash",
+    callID: "call-1",
+    state: { status: "completed" as const, input: {}, output: "done", title: "", metadata: {}, time: { start: 0, end: 1 } },
+  })
+}
+
 describe("revert + compact workflow", () => {
   test("should properly handle compact command after revert", async () => {
     await using tmp = await tmpdir({ git: true })
@@ -283,4 +333,98 @@ describe("revert + compact workflow", () => {
       },
     })
   })
+
+  test("cleanup with partID removes parts from the revert point onward", async () => {
+    await using tmp = await tmpdir({ git: true })
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        const session = await Session.create({})
+        const sid = session.id
+
+        const u1 = await user(sid)
+        const p1 = await text(sid, u1.id, "first part")
+        const p2 = await tool(sid, u1.id)
+        const p3 = await text(sid, u1.id, "third part")
+
+        // Set revert state pointing at a specific part
+        await Session.setRevert({
+          sessionID: sid,
+          revert: { messageID: u1.id, partID: p2.id },
+          summary: { additions: 0, deletions: 0, files: 0 },
+        })
+
+        const info = await Session.get(sid)
+        await SessionRevert.cleanup(info)
+
+        const msgs = await Session.messages({ sessionID: sid })
+        expect(msgs.length).toBe(1)
+        // Only the first part should remain (before the revert partID)
+        expect(msgs[0].parts.length).toBe(1)
+        expect(msgs[0].parts[0].id).toBe(p1.id)
+
+        const cleared = await Session.get(sid)
+        expect(cleared.revert).toBeUndefined()
+      },
+    })
+  })
+
+  test("cleanup removes messages after revert point but keeps earlier ones", async () => {
+    await using tmp = await tmpdir({ git: true })
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        const session = await Session.create({})
+        const sid = session.id
+
+        const u1 = await user(sid)
+        await text(sid, u1.id, "hello")
+        const a1 = await assistant(sid, u1.id, tmp.path)
+        await text(sid, a1.id, "hi back")
+
+        const u2 = await user(sid)
+        await text(sid, u2.id, "second question")
+        const a2 = await assistant(sid, u2.id, tmp.path)
+        await text(sid, a2.id, "second answer")
+
+        // Revert from u2 onward
+        await Session.setRevert({
+          sessionID: sid,
+          revert: { messageID: u2.id },
+          summary: { additions: 0, deletions: 0, files: 0 },
+        })
+
+        const info = await Session.get(sid)
+        await SessionRevert.cleanup(info)
+
+        const msgs = await Session.messages({ sessionID: sid })
+        const ids = msgs.map((m) => m.info.id)
+        expect(ids).toContain(u1.id)
+        expect(ids).toContain(a1.id)
+        expect(ids).not.toContain(u2.id)
+        expect(ids).not.toContain(a2.id)
+      },
+    })
+  })
+
+  test("cleanup is a no-op when session has no revert state", async () => {
+    await using tmp = await tmpdir({ git: true })
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        const session = await Session.create({})
+        const sid = session.id
+
+        const u1 = await user(sid)
+        await text(sid, u1.id, "hello")
+
+        const info = await Session.get(sid)
+        expect(info.revert).toBeUndefined()
+        await SessionRevert.cleanup(info)
+
+        const msgs = await Session.messages({ sessionID: sid })
+        expect(msgs.length).toBe(1)
+      },
+    })
+  })
 })