Explorar o código

wip: session revert/unrevert

Dax Raad hai 7 meses
pai
achega
35d6273fb3

+ 91 - 1
packages/opencode/src/session/index.ts

@@ -34,6 +34,7 @@ import type { ModelsDev } from "../provider/models"
 import { Installation } from "../installation"
 import { Installation } from "../installation"
 import { Config } from "../config/config"
 import { Config } from "../config/config"
 import { ProviderTransform } from "../provider/transform"
 import { ProviderTransform } from "../provider/transform"
+import { Snapshot } from "../snapshot"
 
 
 export namespace Session {
 export namespace Session {
   const log = Log.create({ service: "session" })
   const log = Log.create({ service: "session" })
@@ -53,6 +54,13 @@ export namespace Session {
         created: z.number(),
         created: z.number(),
         updated: z.number(),
         updated: z.number(),
       }),
       }),
+      revert: z
+        .object({
+          messageID: z.string(),
+          part: z.number(),
+          snapshot: z.string().optional(),
+        })
+        .optional(),
     })
     })
     .openapi({
     .openapi({
       ref: "Session",
       ref: "Session",
@@ -285,6 +293,37 @@ export namespace Session {
     l.info("chatting")
     l.info("chatting")
     const model = await Provider.getModel(input.providerID, input.modelID)
     const model = await Provider.getModel(input.providerID, input.modelID)
     let msgs = await messages(input.sessionID)
     let msgs = await messages(input.sessionID)
+    const session = await get(input.sessionID)
+
+    if (session.revert) {
+      const trimmed = []
+      for (const msg of msgs) {
+        if (
+          msg.id > session.revert.messageID ||
+          (msg.id === session.revert.messageID && session.revert.part === 0)
+        ) {
+          await Storage.remove(
+            "session/message/" + input.sessionID + "/" + msg.id,
+          )
+          await Bus.publish(Message.Event.Removed, {
+            sessionID: input.sessionID,
+            messageID: msg.id,
+          })
+          continue
+        }
+
+        if (msg.id === session.revert.messageID) {
+          if (session.revert.part === 0) break
+          msg.parts = msg.parts.slice(0, session.revert.part)
+        }
+        trimmed.push(msg)
+      }
+      msgs = trimmed
+      await update(input.sessionID, (draft) => {
+        draft.revert = undefined
+      })
+    }
+
     const previous = msgs.at(-1)
     const previous = msgs.at(-1)
 
 
     // auto summarize if too long
     // auto summarize if too long
@@ -319,7 +358,6 @@ export namespace Session {
     if (lastSummary) msgs = msgs.filter((msg) => msg.id >= lastSummary.id)
     if (lastSummary) msgs = msgs.filter((msg) => msg.id >= lastSummary.id)
 
 
     const app = App.info()
     const app = App.info()
-    const session = await get(input.sessionID)
     if (msgs.length === 0 && !session.parentID) {
     if (msgs.length === 0 && !session.parentID) {
       generateText({
       generateText({
         maxTokens: input.providerID === "google" ? 1024 : 20,
         maxTokens: input.providerID === "google" ? 1024 : 20,
@@ -349,6 +387,7 @@ export namespace Session {
         })
         })
         .catch(() => {})
         .catch(() => {})
     }
     }
+    const snapshot = await Snapshot.create(input.sessionID)
     const msg: Message.Info = {
     const msg: Message.Info = {
       role: "user",
       role: "user",
       id: Identifier.ascending("message"),
       id: Identifier.ascending("message"),
@@ -359,6 +398,7 @@ export namespace Session {
         },
         },
         sessionID: input.sessionID,
         sessionID: input.sessionID,
         tool: {},
         tool: {},
+        snapshot,
       },
       },
     }
     }
     await updateMessage(msg)
     await updateMessage(msg)
@@ -373,6 +413,7 @@ export namespace Session {
       role: "assistant",
       role: "assistant",
       parts: [],
       parts: [],
       metadata: {
       metadata: {
+        snapshot,
         assistant: {
         assistant: {
           system,
           system,
           path: {
           path: {
@@ -424,6 +465,7 @@ export namespace Session {
             })
             })
             next.metadata!.tool![opts.toolCallId] = {
             next.metadata!.tool![opts.toolCallId] = {
               ...result.metadata,
               ...result.metadata,
+              snapshot: await Snapshot.create(input.sessionID),
               time: {
               time: {
                 start,
                 start,
                 end: Date.now(),
                 end: Date.now(),
@@ -436,6 +478,7 @@ export namespace Session {
               error: true,
               error: true,
               message: e.toString(),
               message: e.toString(),
               title: e.toString(),
               title: e.toString(),
+              snapshot: await Snapshot.create(input.sessionID),
               time: {
               time: {
                 start,
                 start,
                 end: Date.now(),
                 end: Date.now(),
@@ -457,6 +500,7 @@ export namespace Session {
           const result = await execute(args, opts)
           const result = await execute(args, opts)
           next.metadata!.tool![opts.toolCallId] = {
           next.metadata!.tool![opts.toolCallId] = {
             ...result.metadata,
             ...result.metadata,
+            snapshot: await Snapshot.create(input.sessionID),
             time: {
             time: {
               start,
               start,
               end: Date.now(),
               end: Date.now(),
@@ -471,6 +515,7 @@ export namespace Session {
           next.metadata!.tool![opts.toolCallId] = {
           next.metadata!.tool![opts.toolCallId] = {
             error: true,
             error: true,
             message: e.toString(),
             message: e.toString(),
+            snapshot: await Snapshot.create(input.sessionID),
             title: "mcp",
             title: "mcp",
             time: {
             time: {
               start,
               start,
@@ -735,6 +780,51 @@ export namespace Session {
     return next
     return next
   }
   }
 
 
+  export async function revert(input: {
+    sessionID: string
+    messageID: string
+    part: number
+  }) {
+    const message = await getMessage(input.sessionID, input.messageID)
+    if (!message) return
+    const part = message.parts[input.part]
+    if (!part) return
+    const session = await get(input.sessionID)
+    const snapshot =
+      session.revert?.snapshot ?? (await Snapshot.create(input.sessionID))
+    const old = (() => {
+      if (message.role === "assistant") {
+        const lastTool = message.parts.findLast(
+          (part, index) =>
+            part.type === "tool-invocation" && index < input.part,
+        )
+        if (lastTool && lastTool.type === "tool-invocation")
+          return message.metadata.tool[lastTool.toolInvocation.toolCallId]
+            .snapshot
+      }
+      return message.metadata.snapshot
+    })()
+    if (old) await Snapshot.restore(input.sessionID, old)
+    await update(input.sessionID, (draft) => {
+      draft.revert = {
+        messageID: input.messageID,
+        part: input.part,
+        snapshot,
+      }
+    })
+  }
+
+  export async function unrevert(sessionID: string) {
+    const session = await get(sessionID)
+    if (!session) return
+    if (!session.revert) return
+    if (session.revert.snapshot)
+      await Snapshot.restore(sessionID, session.revert.snapshot)
+    update(sessionID, (draft) => {
+      draft.revert = undefined
+    })
+  }
+
   export async function summarize(input: {
   export async function summarize(input: {
     sessionID: string
     sessionID: string
     providerID: string
     providerID: string

+ 9 - 5
packages/opencode/src/session/message.ts

@@ -159,6 +159,7 @@ export namespace Message {
             z
             z
               .object({
               .object({
                 title: z.string(),
                 title: z.string(),
+                snapshot: z.string().optional(),
                 time: z.object({
                 time: z.object({
                   start: z.number(),
                   start: z.number(),
                   end: z.number(),
                   end: z.number(),
@@ -188,11 +189,7 @@ export namespace Message {
               }),
               }),
             })
             })
             .optional(),
             .optional(),
-          user: z
-            .object({
-              snapshot: z.string().optional(),
-            })
-            .optional(),
+          snapshot: z.string().optional(),
         })
         })
         .openapi({ ref: "MessageMetadata" }),
         .openapi({ ref: "MessageMetadata" }),
     })
     })
@@ -208,6 +205,13 @@ export namespace Message {
         info: Info,
         info: Info,
       }),
       }),
     ),
     ),
+    Removed: Bus.event(
+      "message.removed",
+      z.object({
+        sessionID: z.string(),
+        messageID: z.string(),
+      }),
+    ),
     PartUpdated: Bus.event(
     PartUpdated: Bus.event(
       "message.part.updated",
       "message.part.updated",
       z.object({
       z.object({