Просмотр исходного кода

core: add session diff API to show file changes between snapshots

Dax Raad 4 месяцев назад
Родитель
Сommit
a0a09f421c

+ 87 - 14
packages/opencode/src/server/server.ts

@@ -1,6 +1,12 @@
 import { Log } from "../util/log"
 import { Bus } from "../bus"
-import { describeRoute, generateSpecs, validator, resolver, openAPIRouteHandler } from "hono-openapi"
+import {
+  describeRoute,
+  generateSpecs,
+  validator,
+  resolver,
+  openAPIRouteHandler,
+} from "hono-openapi"
 import { Hono } from "hono"
 import { cors } from "hono/cors"
 import { streamSSE } from "hono/streaming"
@@ -35,6 +41,7 @@ import { InstanceBootstrap } from "../project/bootstrap"
 import { MCP } from "../mcp"
 import { Storage } from "../storage/storage"
 import type { ContentfulStatusCode } from "hono/utils/http-status"
+import { Snapshot } from "@/snapshot"
 
 const ERRORS = {
   400: {
@@ -66,7 +73,9 @@ const ERRORS = {
 } as const
 
 function errors(...codes: number[]) {
-  return Object.fromEntries(codes.map((code) => [code, ERRORS[code as keyof typeof ERRORS]]))
+  return Object.fromEntries(
+    codes.map((code) => [code, ERRORS[code as keyof typeof ERRORS]]),
+  )
 }
 
 export namespace Server {
@@ -90,7 +99,8 @@ export namespace Server {
           else status = 500
           return c.json(err.toObject(), { status })
         }
-        const message = err instanceof Error && err.stack ? err.stack : err.toString()
+        const message =
+          err instanceof Error && err.stack ? err.stack : err.toString()
         return c.json(new NamedError.Unknown({ message }).toObject(), {
           status: 500,
         })
@@ -184,14 +194,17 @@ export namespace Server {
       .get(
         "/experimental/tool/ids",
         describeRoute({
-          description: "List all tool IDs (including built-in and dynamically registered)",
+          description:
+            "List all tool IDs (including built-in and dynamically registered)",
           operationId: "tool.ids",
           responses: {
             200: {
               description: "Tool IDs",
               content: {
                 "application/json": {
-                  schema: resolver(z.array(z.string()).meta({ ref: "ToolIDs" })),
+                  schema: resolver(
+                    z.array(z.string()).meta({ ref: "ToolIDs" }),
+                  ),
                 },
               },
             },
@@ -205,7 +218,8 @@ export namespace Server {
       .get(
         "/experimental/tool",
         describeRoute({
-          description: "List tools with JSON schema parameters for a provider/model",
+          description:
+            "List tools with JSON schema parameters for a provider/model",
           operationId: "tool.list",
           responses: {
             200: {
@@ -246,7 +260,9 @@ export namespace Server {
               id: t.id,
               description: t.description,
               // Handle both Zod schemas and plain JSON schemas
-              parameters: (t.parameters as any)?._def ? zodToJsonSchema(t.parameters as any) : t.parameters,
+              parameters: (t.parameters as any)?._def
+                ? zodToJsonSchema(t.parameters as any)
+                : t.parameters,
             })),
           )
         },
@@ -608,6 +624,44 @@ export namespace Server {
           return c.json(session)
         },
       )
+      .get(
+        "/session/:id/diff",
+        describeRoute({
+          description: "Get the diff that resulted from this user message",
+          operationId: "session.diff",
+          responses: {
+            200: {
+              description: "Successfully retrieved diff",
+              content: {
+                "application/json": {
+                  schema: resolver(Snapshot.FileDiff.array()),
+                },
+              },
+            },
+          },
+        }),
+        validator(
+          "param",
+          z.object({
+            id: Session.diff.schema.shape.sessionID,
+          }),
+        ),
+        validator(
+          "query",
+          z.object({
+            messageID: Session.diff.schema.shape.messageID,
+          }),
+        ),
+        async (c) => {
+          const query = c.req.valid("query")
+          const params = c.req.valid("param")
+          const result = await Session.diff({
+            sessionID: params.id,
+            messageID: query.messageID,
+          })
+          return c.json(result)
+        },
+      )
       .delete(
         "/session/:id/share",
         describeRoute({
@@ -734,7 +788,10 @@ export namespace Server {
         ),
         async (c) => {
           const params = c.req.valid("param")
-          const message = await Session.getMessage({ sessionID: params.id, messageID: params.messageID })
+          const message = await Session.getMessage({
+            sessionID: params.id,
+            messageID: params.messageID,
+          })
           return c.json(message)
         },
       )
@@ -868,7 +925,10 @@ export namespace Server {
         async (c) => {
           const id = c.req.valid("param").id
           log.info("revert", c.req.valid("json"))
-          const session = await SessionRevert.revert({ sessionID: id, ...c.req.valid("json") })
+          const session = await SessionRevert.revert({
+            sessionID: id,
+            ...c.req.valid("json"),
+          })
           return c.json(session)
         },
       )
@@ -929,7 +989,11 @@ export namespace Server {
           const params = c.req.valid("param")
           const id = params.id
           const permissionID = params.permissionID
-          Permission.respond({ sessionID: id, permissionID, response: c.req.valid("json").response })
+          Permission.respond({
+            sessionID: id,
+            permissionID,
+            response: c.req.valid("json").response,
+          })
           return c.json(true)
         },
       )
@@ -976,10 +1040,15 @@ export namespace Server {
           },
         }),
         async (c) => {
-          const providers = await Provider.list().then((x) => mapValues(x, (item) => item.info))
+          const providers = await Provider.list().then((x) =>
+            mapValues(x, (item) => item.info),
+          )
           return c.json({
             providers: Object.values(providers),
-            default: mapValues(providers, (item) => Provider.sort(Object.values(item.models))[0].id),
+            default: mapValues(
+              providers,
+              (item) => Provider.sort(Object.values(item.models))[0].id,
+            ),
           })
         },
       )
@@ -1174,8 +1243,12 @@ export namespace Server {
         validator(
           "json",
           z.object({
-            service: z.string().meta({ description: "Service name for the log entry" }),
-            level: z.enum(["debug", "info", "error", "warn"]).meta({ description: "Log level" }),
+            service: z
+              .string()
+              .meta({ description: "Service name for the log entry" }),
+            level: z
+              .enum(["debug", "info", "error", "warn"])
+              .meta({ description: "Log level" }),
             message: z.string().meta({ description: "Log message" }),
             extra: z
               .record(z.string(), z.any())

+ 53 - 2
packages/opencode/src/session/index.ts

@@ -18,6 +18,7 @@ import { Project } from "../project/project"
 import { Instance } from "../project/instance"
 import { SessionPrompt } from "./prompt"
 import { fn } from "@/util/fn"
+import { Snapshot } from "@/snapshot"
 
 export namespace Session {
   const log = Log.create({ service: "session" })
@@ -146,7 +147,12 @@ export namespace Session {
     })
   })
 
-  export async function createNext(input: { id?: string; title?: string; parentID?: string; directory: string }) {
+  export async function createNext(input: {
+    id?: string
+    title?: string
+    parentID?: string
+    directory: string
+  }) {
     const result: Info = {
       id: Identifier.descending("session", input.id),
       version: Installation.VERSION,
@@ -366,7 +372,9 @@ export namespace Session {
           .add(new Decimal(tokens.input).mul(input.model.cost?.input ?? 0).div(1_000_000))
           .add(new Decimal(tokens.output).mul(input.model.cost?.output ?? 0).div(1_000_000))
           .add(new Decimal(tokens.cache.read).mul(input.model.cost?.cache_read ?? 0).div(1_000_000))
-          .add(new Decimal(tokens.cache.write).mul(input.model.cost?.cache_write ?? 0).div(1_000_000))
+          .add(
+            new Decimal(tokens.cache.write).mul(input.model.cost?.cache_write ?? 0).div(1_000_000),
+          )
           .toNumber(),
         tokens,
       }
@@ -405,4 +413,47 @@ export namespace Session {
       await Project.setInitialized(Instance.project.id)
     },
   )
+
+  export const diff = fn(
+    z.object({
+      sessionID: Identifier.schema("session"),
+      messageID: Identifier.schema("message").optional(),
+    }),
+    async (input) => {
+      const all = await messages(input.sessionID)
+      const index = !input.messageID ? 0 : all.findIndex((x) => x.info.id === input.messageID)
+      if (index === -1) return []
+
+      let from: string | undefined
+      let to: string | undefined
+
+      // scan assistant messages to find earliest from and latest to
+      // snapshot
+      for (let i = index + 1; i < all.length; i++) {
+        const item = all[i]
+
+        // if messageID is provided, stop at the next user message
+        if (input.messageID && item.info.role === "user") break
+
+        if (!from) {
+          for (const part of item.parts) {
+            if (part.type === "step-start" && part.snapshot) {
+              from = part.snapshot
+              break
+            }
+          }
+        }
+
+        for (const part of item.parts) {
+          if (part.type === "step-finish" && part.snapshot) {
+            to = part.snapshot
+            break
+          }
+        }
+      }
+
+      if (from && to) return Snapshot.diffFull(from, to)
+      return []
+    },
+  )
 }

+ 2 - 0
packages/opencode/src/session/message-v2.ts

@@ -130,6 +130,7 @@ export namespace MessageV2 {
 
   export const StepStartPart = PartBase.extend({
     type: z.literal("step-start"),
+    snapshot: z.string().optional(),
   }).meta({
     ref: "StepStartPart",
   })
@@ -137,6 +138,7 @@ export namespace MessageV2 {
 
   export const StepFinishPart = PartBase.extend({
     type: z.literal("step-finish"),
+    snapshot: z.string().optional(),
     cost: z.number(),
     tokens: z.object({
       input: z.number(),

+ 3 - 1
packages/opencode/src/session/prompt.ts

@@ -1195,13 +1195,14 @@ export namespace SessionPrompt {
                 throw value.error
 
               case "start-step":
+                snapshot = await Snapshot.track()
                 await Session.updatePart({
                   id: Identifier.ascending("part"),
                   messageID: assistantMsg.id,
                   sessionID: assistantMsg.sessionID,
+                  snapshot,
                   type: "step-start",
                 })
-                snapshot = await Snapshot.track()
                 break
 
               case "finish-step":
@@ -1214,6 +1215,7 @@ export namespace SessionPrompt {
                 assistantMsg.tokens = usage.tokens
                 await Session.updatePart({
                   id: Identifier.ascending("part"),
+                  snapshot: await Snapshot.track(),
                   messageID: assistantMsg.id,
                   sessionID: assistantMsg.sessionID,
                   type: "step-finish",

+ 81 - 18
packages/opencode/src/snapshot/index.ts

@@ -26,8 +26,15 @@ export namespace Snapshot {
         .nothrow()
       log.info("initialized")
     }
-    await $`git --git-dir ${git} add .`.quiet().cwd(Instance.directory).nothrow()
-    const hash = await $`git --git-dir ${git} write-tree`.quiet().cwd(Instance.directory).nothrow().text()
+    await $`git --git-dir ${git} add .`
+      .quiet()
+      .cwd(Instance.directory)
+      .nothrow()
+    const hash = await $`git --git-dir ${git} write-tree`
+      .quiet()
+      .cwd(Instance.directory)
+      .nothrow()
+      .text()
     log.info("tracking", { hash, cwd: Instance.directory, git })
     return hash.trim()
   }
@@ -40,8 +47,14 @@ export namespace Snapshot {
 
   export async function patch(hash: string): Promise<Patch> {
     const git = gitdir()
-    await $`git --git-dir ${git} add .`.quiet().cwd(Instance.directory).nothrow()
-    const result = await $`git --git-dir ${git} diff --name-only ${hash} -- .`.quiet().cwd(Instance.directory).nothrow()
+    await $`git --git-dir ${git} add .`
+      .quiet()
+      .cwd(Instance.directory)
+      .nothrow()
+    const result = await $`git --git-dir ${git} diff --name-only ${hash} -- .`
+      .quiet()
+      .cwd(Instance.directory)
+      .nothrow()
 
     // If git diff fails, return empty patch
     if (result.exitCode !== 0) {
@@ -64,10 +77,11 @@ export namespace Snapshot {
   export async function restore(snapshot: string) {
     log.info("restore", { commit: snapshot })
     const git = gitdir()
-    const result = await $`git --git-dir=${git} read-tree ${snapshot} && git --git-dir=${git} checkout-index -a -f`
-      .quiet()
-      .cwd(Instance.worktree)
-      .nothrow()
+    const result =
+      await $`git --git-dir=${git} read-tree ${snapshot} && git --git-dir=${git} checkout-index -a -f`
+        .quiet()
+        .cwd(Instance.worktree)
+        .nothrow()
 
     if (result.exitCode !== 0) {
       log.error("failed to restore snapshot", {
@@ -86,18 +100,22 @@ export namespace Snapshot {
       for (const file of item.files) {
         if (files.has(file)) continue
         log.info("reverting", { file, hash: item.hash })
-        const result = await $`git --git-dir=${git} checkout ${item.hash} -- ${file}`
-          .quiet()
-          .cwd(Instance.worktree)
-          .nothrow()
-        if (result.exitCode !== 0) {
-          const relativePath = path.relative(Instance.worktree, file)
-          const checkTree = await $`git --git-dir=${git} ls-tree ${item.hash} -- ${relativePath}`
+        const result =
+          await $`git --git-dir=${git} checkout ${item.hash} -- ${file}`
             .quiet()
             .cwd(Instance.worktree)
             .nothrow()
+        if (result.exitCode !== 0) {
+          const relativePath = path.relative(Instance.worktree, file)
+          const checkTree =
+            await $`git --git-dir=${git} ls-tree ${item.hash} -- ${relativePath}`
+              .quiet()
+              .cwd(Instance.worktree)
+              .nothrow()
           if (checkTree.exitCode === 0 && checkTree.text().trim()) {
-            log.info("file existed in snapshot but checkout failed, keeping", { file })
+            log.info("file existed in snapshot but checkout failed, keeping", {
+              file,
+            })
           } else {
             log.info("file did not exist in snapshot, deleting", { file })
             await fs.unlink(file).catch(() => {})
@@ -110,8 +128,14 @@ export namespace Snapshot {
 
   export async function diff(hash: string) {
     const git = gitdir()
-    await $`git --git-dir ${git} add .`.quiet().cwd(Instance.directory).nothrow()
-    const result = await $`git --git-dir=${git} diff ${hash} -- .`.quiet().cwd(Instance.worktree).nothrow()
+    await $`git --git-dir ${git} add .`
+      .quiet()
+      .cwd(Instance.directory)
+      .nothrow()
+    const result = await $`git --git-dir=${git} diff ${hash} -- .`
+      .quiet()
+      .cwd(Instance.worktree)
+      .nothrow()
 
     if (result.exitCode !== 0) {
       log.warn("failed to get diff", {
@@ -126,6 +150,45 @@ export namespace Snapshot {
     return result.text().trim()
   }
 
+  export const FileDiff = z
+    .object({
+      file: z.string(),
+      left: z.string(),
+      right: z.string(),
+    })
+    .meta({
+      ref: "FileDiff",
+    })
+  export type FileDiff = z.infer<typeof FileDiff>
+  export async function diffFull(
+    from: string,
+    to: string,
+  ): Promise<FileDiff[]> {
+    const git = gitdir()
+    const result: FileDiff[] = []
+    for await (const line of $`git --git-dir=${git} diff --name-only ${from} ${to} -- .`
+      .quiet()
+      .cwd(Instance.directory)
+      .nothrow()
+      .lines()) {
+      if (!line) continue
+      const left = await $`git --git-dir=${git} show ${from}:${line}`
+        .quiet()
+        .nothrow()
+        .text()
+      const right = await $`git --git-dir=${git} show ${to}:${line}`
+        .quiet()
+        .nothrow()
+        .text()
+      result.push({
+        file: line,
+        left,
+        right,
+      })
+    }
+    return result
+  }
+
   function gitdir() {
     const project = Instance.project
     return path.join(Global.Path.data, "snapshot", project.id)

+ 91 - 11
packages/opencode/test/snapshot/snapshot.test.ts

@@ -33,7 +33,9 @@ test("tracks deleted files correctly", async () => {
 
       await $`rm ${tmp.path}/a.txt`.quiet()
 
-      expect((await Snapshot.patch(before!)).files).toContain(`${tmp.path}/a.txt`)
+      expect((await Snapshot.patch(before!)).files).toContain(
+        `${tmp.path}/a.txt`,
+      )
     },
   })
 })
@@ -91,11 +93,15 @@ test("multiple file operations", async () => {
 
       await Snapshot.revert([await Snapshot.patch(before!)])
 
-      expect(await Bun.file(`${tmp.path}/a.txt`).text()).toBe(tmp.extra.aContent)
+      expect(await Bun.file(`${tmp.path}/a.txt`).text()).toBe(
+        tmp.extra.aContent,
+      )
       expect(await Bun.file(`${tmp.path}/c.txt`).exists()).toBe(false)
       // Note: revert currently only removes files, not directories
       // The empty directory will remain
-      expect(await Bun.file(`${tmp.path}/b.txt`).text()).toBe(tmp.extra.bContent)
+      expect(await Bun.file(`${tmp.path}/b.txt`).text()).toBe(
+        tmp.extra.bContent,
+      )
     },
   })
 })
@@ -123,7 +129,10 @@ test("binary file handling", async () => {
       const before = await Snapshot.track()
       expect(before).toBeTruthy()
 
-      await Bun.write(`${tmp.path}/image.png`, new Uint8Array([0x89, 0x50, 0x4e, 0x47]))
+      await Bun.write(
+        `${tmp.path}/image.png`,
+        new Uint8Array([0x89, 0x50, 0x4e, 0x47]),
+      )
 
       const patch = await Snapshot.patch(before!)
       expect(patch.files).toContain(`${tmp.path}/image.png`)
@@ -144,7 +153,9 @@ test("symlink handling", async () => {
 
       await $`ln -s ${tmp.path}/a.txt ${tmp.path}/link.txt`.quiet()
 
-      expect((await Snapshot.patch(before!)).files).toContain(`${tmp.path}/link.txt`)
+      expect((await Snapshot.patch(before!)).files).toContain(
+        `${tmp.path}/link.txt`,
+      )
     },
   })
 })
@@ -159,7 +170,9 @@ test("large file handling", async () => {
 
       await Bun.write(`${tmp.path}/large.txt`, "x".repeat(1024 * 1024))
 
-      expect((await Snapshot.patch(before!)).files).toContain(`${tmp.path}/large.txt`)
+      expect((await Snapshot.patch(before!)).files).toContain(
+        `${tmp.path}/large.txt`,
+      )
     },
   })
 })
@@ -177,7 +190,9 @@ test("nested directory revert", async () => {
 
       await Snapshot.revert([await Snapshot.patch(before!)])
 
-      expect(await Bun.file(`${tmp.path}/level1/level2/level3/deep.txt`).exists()).toBe(false)
+      expect(
+        await Bun.file(`${tmp.path}/level1/level2/level3/deep.txt`).exists(),
+      ).toBe(false)
     },
   })
 })
@@ -211,7 +226,9 @@ test("revert with empty patches", async () => {
       expect(Snapshot.revert([])).resolves.toBeUndefined()
 
       // Should not crash with patches that have empty file lists
-      expect(Snapshot.revert([{ hash: "dummy", files: [] }])).resolves.toBeUndefined()
+      expect(
+        Snapshot.revert([{ hash: "dummy", files: [] }]),
+      ).resolves.toBeUndefined()
     },
   })
 })
@@ -526,9 +543,13 @@ test("restore function", async () => {
       await Snapshot.restore(before!)
 
       expect(await Bun.file(`${tmp.path}/a.txt`).exists()).toBe(true)
-      expect(await Bun.file(`${tmp.path}/a.txt`).text()).toBe(tmp.extra.aContent)
+      expect(await Bun.file(`${tmp.path}/a.txt`).text()).toBe(
+        tmp.extra.aContent,
+      )
       expect(await Bun.file(`${tmp.path}/new.txt`).exists()).toBe(true) // New files should remain
-      expect(await Bun.file(`${tmp.path}/b.txt`).text()).toBe(tmp.extra.bContent)
+      expect(await Bun.file(`${tmp.path}/b.txt`).text()).toBe(
+        tmp.extra.bContent,
+      )
     },
   })
 })
@@ -580,7 +601,66 @@ test("revert preserves file that existed in snapshot when deleted then recreated
 
       expect(await Bun.file(`${tmp.path}/newfile.txt`).exists()).toBe(false)
       expect(await Bun.file(`${tmp.path}/existing.txt`).exists()).toBe(true)
-      expect(await Bun.file(`${tmp.path}/existing.txt`).text()).toBe("original content")
+      expect(await Bun.file(`${tmp.path}/existing.txt`).text()).toBe(
+        "original content",
+      )
+    },
+  })
+})
+
+test("diffFull function", async () => {
+  await using tmp = await bootstrap()
+  await Instance.provide({
+    directory: tmp.path,
+    fn: async () => {
+      const before = await Snapshot.track()
+      expect(before).toBeTruthy()
+
+      await Bun.write(`${tmp.path}/new.txt`, "new content")
+      await Bun.write(`${tmp.path}/b.txt`, "modified content")
+
+      const after = await Snapshot.track()
+      expect(after).toBeTruthy()
+
+      const diffs = await Snapshot.diffFull(before!, after!)
+      expect(diffs.length).toBe(2)
+
+      const newFileDiff = diffs.find((d) => d.file === "new.txt")
+      expect(newFileDiff).toBeDefined()
+      expect(newFileDiff!.left).toBe("")
+      expect(newFileDiff!.right).toBe("new content")
+
+      const modifiedFileDiff = diffs.find((d) => d.file === "b.txt")
+      expect(modifiedFileDiff).toBeDefined()
+      expect(modifiedFileDiff!.left).toBe(tmp.extra.bContent)
+      expect(modifiedFileDiff!.right).toBe("modified content")
+    },
+  })
+
+  await Instance.provide({
+    directory: tmp.path,
+    fn: async () => {
+      const before = await Snapshot.track()
+      expect(before).toBeTruthy()
+
+      await Bun.write(`${tmp.path}/added.txt`, "added content")
+      await $`rm ${tmp.path}/a.txt`.quiet()
+
+      const after = await Snapshot.track()
+      expect(after).toBeTruthy()
+
+      const diffs = await Snapshot.diffFull(before!, after!)
+      expect(diffs.length).toBe(2)
+
+      const addedFileDiff = diffs.find((d) => d.file === "added.txt")
+      expect(addedFileDiff).toBeDefined()
+      expect(addedFileDiff!.left).toBe("")
+      expect(addedFileDiff!.right).toBe("added content")
+
+      const removedFileDiff = diffs.find((d) => d.file === "a.txt")
+      expect(removedFileDiff).toBeDefined()
+      expect(removedFileDiff!.left).toBe(tmp.extra.aContent)
+      expect(removedFileDiff!.right).toBe("")
     },
   })
 })