Browse Source

wip: more snapshot stuff

Dax Raad 7 months ago
parent
commit
284c01018e

+ 9 - 55
packages/opencode/src/cli/cmd/debug/snapshot.ts

@@ -5,76 +5,30 @@ import { cmd } from "../cmd"
 
 export const SnapshotCommand = cmd({
   command: "snapshot",
-  builder: (yargs) =>
-    yargs.command(CreateCommand).command(RestoreCommand).command(DiffCommand).command(RevertCommand).demandCommand(),
+  builder: (yargs) => yargs.command(TrackCommand).command(PatchCommand).demandCommand(),
   async handler() {},
 })
 
-const CreateCommand = cmd({
-  command: "create",
+const TrackCommand = cmd({
+  command: "track",
   async handler() {
     await bootstrap({ cwd: process.cwd() }, async () => {
-      const result = await Snapshot.create()
-      console.log(result)
-    })
-  },
-})
-
-const RestoreCommand = cmd({
-  command: "restore <commit>",
-  builder: (yargs) =>
-    yargs.positional("commit", {
-      type: "string",
-      description: "commit",
-      demandOption: true,
-    }),
-  async handler(args) {
-    await bootstrap({ cwd: process.cwd() }, async () => {
-      await Snapshot.restore(args.commit)
-      console.log("restored")
+      console.log(await Snapshot.track())
     })
   },
 })
 
-export const DiffCommand = cmd({
-  command: "diff <commit>",
-  describe: "diff",
+const PatchCommand = cmd({
+  command: "patch <hash>",
   builder: (yargs) =>
-    yargs.positional("commit", {
+    yargs.positional("hash", {
       type: "string",
-      description: "commit",
+      description: "hash",
       demandOption: true,
     }),
   async handler(args) {
     await bootstrap({ cwd: process.cwd() }, async () => {
-      const diff = await Snapshot.diff(args.commit)
-      console.log(diff)
-    })
-  },
-})
-
-export const RevertCommand = cmd({
-  command: "revert <sessionID> <messageID>",
-  describe: "revert",
-  builder: (yargs) =>
-    yargs
-      .positional("sessionID", {
-        type: "string",
-        description: "sessionID",
-        demandOption: true,
-      })
-      .positional("messageID", {
-        type: "string",
-        description: "messageID",
-        demandOption: true,
-      }),
-  async handler(args) {
-    await bootstrap({ cwd: process.cwd() }, async () => {
-      const session = await Session.revert({
-        sessionID: args.sessionID,
-        messageID: args.messageID,
-      })
-      console.log(session?.revert)
+      console.log(await Snapshot.patch(args.hash))
     })
   },
 })

+ 59 - 38
packages/opencode/src/session/index.ts

@@ -661,6 +661,7 @@ export namespace Session {
         description: item.description,
         inputSchema: item.parameters as ZodSchema,
         async execute(args, options) {
+          await processor.track(options.toolCallId)
           const result = await item.execute(args, {
             sessionID: input.sessionID,
             abort: abort.signal,
@@ -699,6 +700,7 @@ export namespace Session {
       const execute = item.execute
       if (!execute) continue
       item.execute = async (args, opts) => {
+        await processor.track(opts.toolCallId)
         const result = await execute(args, opts)
         const output = result.content
           .filter((x: any) => x.type === "text")
@@ -814,7 +816,12 @@ export namespace Session {
 
   function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
     const toolCalls: Record<string, MessageV2.ToolPart> = {}
+    const snapshots: Record<string, string> = {}
     return {
+      async track(toolCallID: string) {
+        const hash = await Snapshot.track()
+        if (hash) snapshots[toolCallID] = hash
+      },
       partFromToolCall(toolCallID: string) {
         return toolCalls[toolCallID]
       },
@@ -828,15 +835,6 @@ export namespace Session {
             })
             switch (value.type) {
               case "start":
-                const snapshot = await Snapshot.create()
-                if (snapshot)
-                  await updatePart({
-                    id: Identifier.ascending("part"),
-                    messageID: assistantMsg.id,
-                    sessionID: assistantMsg.sessionID,
-                    type: "snapshot",
-                    snapshot,
-                  })
                 break
 
               case "tool-input-start":
@@ -857,6 +855,9 @@ export namespace Session {
               case "tool-input-delta":
                 break
 
+              case "tool-input-end":
+                break
+
               case "tool-call": {
                 const match = toolCalls[value.toolCallId]
                 if (match) {
@@ -892,15 +893,20 @@ export namespace Session {
                     },
                   })
                   delete toolCalls[value.toolCallId]
-                  const snapshot = await Snapshot.create()
-                  if (snapshot)
-                    await updatePart({
-                      id: Identifier.ascending("part"),
-                      messageID: assistantMsg.id,
-                      sessionID: assistantMsg.sessionID,
-                      type: "snapshot",
-                      snapshot,
-                    })
+                  const snapshot = snapshots[value.toolCallId]
+                  if (snapshot) {
+                    const patch = await Snapshot.patch(snapshot)
+                    if (patch.files.length) {
+                      await updatePart({
+                        id: Identifier.ascending("part"),
+                        messageID: assistantMsg.id,
+                        sessionID: assistantMsg.sessionID,
+                        type: "patch",
+                        hash: patch.hash,
+                        files: patch.files,
+                      })
+                    }
+                  }
                 }
                 break
               }
@@ -921,15 +927,18 @@ export namespace Session {
                     },
                   })
                   delete toolCalls[value.toolCallId]
-                  const snapshot = await Snapshot.create()
-                  if (snapshot)
+                  const snapshot = snapshots[value.toolCallId]
+                  if (snapshot) {
+                    const patch = await Snapshot.patch(snapshot)
                     await updatePart({
                       id: Identifier.ascending("part"),
                       messageID: assistantMsg.id,
                       sessionID: assistantMsg.sessionID,
-                      type: "snapshot",
-                      snapshot,
+                      type: "patch",
+                      hash: patch.hash,
+                      files: patch.files,
                     })
+                  }
                 }
                 break
               }
@@ -1073,33 +1082,45 @@ export namespace Session {
 
   export async function revert(input: RevertInput) {
     const all = await messages(input.sessionID)
-    const session = await get(input.sessionID)
     let lastUser: MessageV2.User | undefined
-    let lastSnapshot: MessageV2.SnapshotPart | undefined
+    const session = await get(input.sessionID)
+
+    let revert: Info["revert"]
+    const patches: Snapshot.Patch[] = []
     for (const msg of all) {
       if (msg.info.role === "user") lastUser = msg.info
       const remaining = []
       for (const part of msg.parts) {
-        if (part.type === "snapshot") lastSnapshot = part
-        if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
-          // if no useful parts left in message, same as reverting whole message
-          const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
-          const snapshot = session.revert?.snapshot ?? (await Snapshot.create())
-          log.info("revert snapshot", { snapshot })
-          if (lastSnapshot) await Snapshot.restore(lastSnapshot.snapshot)
-          const next = await update(input.sessionID, (draft) => {
-            draft.revert = {
-              // if not part id jump to the last user message
+        if (revert) {
+          if (part.type === "patch") {
+            patches.push(part)
+          }
+          continue
+        }
+
+        if (!revert) {
+          if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) {
+            // if no useful parts left in message, same as reverting whole message
+            const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined
+            revert = {
               messageID: !partID && lastUser ? lastUser.id : msg.info.id,
               partID,
-              snapshot,
             }
-          })
-          return next
+          }
+          remaining.push(part)
         }
-        remaining.push(part)
       }
     }
+
+    if (revert) {
+      const session = await get(input.sessionID)
+      revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track())
+      await Snapshot.revert(patches)
+      return update(input.sessionID, (draft) => {
+        draft.revert = revert
+      })
+    }
+    return session
   }
 
   export async function unrevert(input: { sessionID: string }) {

+ 10 - 1
packages/opencode/src/session/message-v2.ts

@@ -94,6 +94,15 @@ export namespace MessageV2 {
   })
   export type SnapshotPart = z.infer<typeof SnapshotPart>
 
+  export const PatchPart = PartBase.extend({
+    type: z.literal("patch"),
+    hash: z.string(),
+    files: z.string().array(),
+  }).openapi({
+    ref: "PatchPart",
+  })
+  export type PatchPart = z.infer<typeof PatchPart>
+
   export const TextPart = PartBase.extend({
     type: z.literal("text"),
     text: z.string(),
@@ -203,7 +212,7 @@ export namespace MessageV2 {
   export type User = z.infer<typeof User>
 
   export const Part = z
-    .discriminatedUnion("type", [TextPart, FilePart, ToolPart, StepStartPart, StepFinishPart, SnapshotPart])
+    .discriminatedUnion("type", [TextPart, FilePart, ToolPart, StepStartPart, StepFinishPart, SnapshotPart, PatchPart])
     .openapi({
       ref: "Part",
     })

+ 37 - 29
packages/opencode/src/snapshot/index.ts

@@ -6,6 +6,7 @@ import { Ripgrep } from "../file/ripgrep"
 import { Log } from "../util/log"
 import { Global } from "../global"
 import { Installation } from "../installation"
+import { z } from "zod"
 
 export namespace Snapshot {
   const log = Log.create({ service: "snapshot" })
@@ -24,21 +25,9 @@ export namespace Snapshot {
     })
   }
 
-  export async function create() {
-    log.info("creating snapshot")
+  export async function track() {
     const app = App.info()
-
-    // not a git repo, check if too big to snapshot
-    if (!app.git || !Installation.isDev()) {
-      return
-      const files = await Ripgrep.files({
-        cwd: app.path.cwd,
-        limit: 1000,
-      })
-      log.info("found files", { count: files.length })
-      if (files.length >= 1000) return
-    }
-
+    if (!app.git) return
     const git = gitdir()
     if (await fs.mkdir(git, { recursive: true })) {
       await $`git init`
@@ -51,33 +40,52 @@ export namespace Snapshot {
         .nothrow()
       log.info("initialized")
     }
-
     await $`git --git-dir ${git} add .`.quiet().cwd(app.path.cwd).nothrow()
-    log.info("added files")
+    const hash = await $`git --git-dir ${git} write-tree`.quiet().cwd(app.path.cwd).text()
+    return hash.trim()
+  }
 
-    const result =
-      await $`git --git-dir ${git} commit --allow-empty -m "snapshot" --no-gpg-sign --author="opencode <[email protected]>"`
-        .quiet()
-        .cwd(app.path.cwd)
-        .nothrow()
+  export const Patch = z.object({
+    hash: z.string(),
+    files: z.string().array(),
+  })
+  export type Patch = z.infer<typeof Patch>
 
-    const match = result.stdout.toString().match(/\[.+ ([a-f0-9]+)\]/)
-    if (!match) return
-    return match![1]
+  export async function patch(hash: string): Promise<Patch> {
+    const app = App.info()
+    const git = gitdir()
+    const files = await $`git --git-dir ${git} diff --name-only ${hash} -- .`.cwd(app.path.cwd).text()
+    return {
+      hash,
+      files: files
+        .trim()
+        .split("\n")
+        .map((x) => x.trim())
+        .filter(Boolean)
+        .map((x) => path.join(app.path.cwd, x)),
+    }
   }
 
   export async function restore(snapshot: string) {
     log.info("restore", { commit: snapshot })
     const app = App.info()
     const git = gitdir()
-    await $`git --git-dir=${git} reset --hard ${snapshot}`.quiet().cwd(app.path.root)
+    await $`git --git-dir=${git} read-tree ${snapshot} && git --git-dir=${git} checkout-index -a -f`
+      .quiet()
+      .cwd(app.path.root)
   }
 
-  export async function diff(commit: string) {
+  export async function revert(patches: Patch[]) {
+    const files = new Set<string>()
     const git = gitdir()
-    const result = await $`git --git-dir=${git} diff -R ${commit}`.quiet().cwd(App.info().path.root)
-    const text = result.stdout.toString("utf8")
-    return text
+    for (const item of patches) {
+      for (const file of item.files) {
+        if (files.has(file)) continue
+        log.info("reverting", { file, hash: item.hash })
+        await $`git --git-dir=${git} checkout ${item.hash} -- ${file}`.quiet().cwd(App.info().path.root)
+        files.add(file)
+      }
+    }
   }
 
   function gitdir() {

+ 2 - 2
packages/sdk/.stats.yml

@@ -1,4 +1,4 @@
 configured_endpoints: 26
-openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/opencode%2Fopencode-1efc45c35b58e88b0550fbb0c7a204ef66522742f87c9e29c76a18b120c0d945.yml
-openapi_spec_hash: 5e15d85e4704624f9b13bae1c71aa416
+openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/opencode%2Fopencode-5748199af356c3243a46a466e73b5d0bab7eaa0c56895e1d0f903d637f61d0bb.yml
+openapi_spec_hash: c04f6b6be54b05d9b1283c24e870163b
 config_hash: 1ae82c93499b9f0b9ba828b8919f9cb3

+ 9 - 1
packages/sdk/src/resources/config.ts

@@ -147,7 +147,7 @@ export namespace Config {
 
     npm?: string;
 
-    options?: { [key: string]: unknown };
+    options?: Provider.Options;
   }
 
   export namespace Provider {
@@ -190,6 +190,14 @@ export namespace Config {
         output: number;
       }
     }
+
+    export interface Options {
+      apiKey?: string;
+
+      baseURL?: string;
+
+      [k: string]: unknown;
+    }
   }
 }
 

+ 15 - 0
packages/sdk/src/resources/event.ts

@@ -24,6 +24,7 @@ export type EventListResponse =
   | EventListResponse.EventMessageUpdated
   | EventListResponse.EventMessageRemoved
   | EventListResponse.EventMessagePartUpdated
+  | EventListResponse.EventMessagePartRemoved
   | EventListResponse.EventStorageWrite
   | EventListResponse.EventSessionUpdated
   | EventListResponse.EventSessionDeleted
@@ -135,6 +136,20 @@ export namespace EventListResponse {
     }
   }
 
+  export interface EventMessagePartRemoved {
+    properties: EventMessagePartRemoved.Properties;
+
+    type: 'message.part.removed';
+  }
+
+  export namespace EventMessagePartRemoved {
+    export interface Properties {
+      messageID: string;
+
+      partID: string;
+    }
+  }
+
   export interface EventStorageWrite {
     properties: EventStorageWrite.Properties;
 

+ 26 - 1
packages/sdk/src/resources/session.ts

@@ -102,6 +102,8 @@ export interface AssistantMessage {
 
   cost: number;
 
+  mode: string;
+
   modelID: string;
 
   path: AssistantMessage.Path;
@@ -217,7 +219,30 @@ export interface FileSource {
 
 export type Message = UserMessage | AssistantMessage;
 
-export type Part = TextPart | FilePart | ToolPart | StepStartPart | StepFinishPart | SnapshotPart;
+export type Part =
+  | TextPart
+  | FilePart
+  | ToolPart
+  | StepStartPart
+  | StepFinishPart
+  | SnapshotPart
+  | Part.PatchPart;
+
+export namespace Part {
+  export interface PatchPart {
+    id: string;
+
+    files: Array<string>;
+
+    hash: string;
+
+    messageID: string;
+
+    sessionID: string;
+
+    type: 'patch';
+  }
+}
 
 export interface Session {
   id: string;

+ 2 - 2
packages/tui/sdk/.stats.yml

@@ -1,4 +1,4 @@
 configured_endpoints: 26
-openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/opencode%2Fopencode-1efc45c35b58e88b0550fbb0c7a204ef66522742f87c9e29c76a18b120c0d945.yml
-openapi_spec_hash: 5e15d85e4704624f9b13bae1c71aa416
+openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/opencode%2Fopencode-5748199af356c3243a46a466e73b5d0bab7eaa0c56895e1d0f903d637f61d0bb.yml
+openapi_spec_hash: c04f6b6be54b05d9b1283c24e870163b
 config_hash: 1ae82c93499b9f0b9ba828b8919f9cb3

+ 25 - 1
packages/tui/sdk/config.go

@@ -333,7 +333,7 @@ type ConfigProvider struct {
 	Env     []string                       `json:"env"`
 	Name    string                         `json:"name"`
 	Npm     string                         `json:"npm"`
-	Options map[string]interface{}         `json:"options"`
+	Options ConfigProviderOptions          `json:"options"`
 	JSON    configProviderJSON             `json:"-"`
 }
 
@@ -447,6 +447,30 @@ func (r configProviderModelsLimitJSON) RawJSON() string {
 	return r.raw
 }
 
+type ConfigProviderOptions struct {
+	APIKey      string                    `json:"apiKey"`
+	BaseURL     string                    `json:"baseURL"`
+	ExtraFields map[string]interface{}    `json:"-,extras"`
+	JSON        configProviderOptionsJSON `json:"-"`
+}
+
+// configProviderOptionsJSON contains the JSON metadata for the struct
+// [ConfigProviderOptions]
+type configProviderOptionsJSON struct {
+	APIKey      apijson.Field
+	BaseURL     apijson.Field
+	raw         string
+	ExtraFields map[string]apijson.Field
+}
+
+func (r *ConfigProviderOptions) UnmarshalJSON(data []byte) (err error) {
+	return apijson.UnmarshalRoot(data, r)
+}
+
+func (r configProviderOptionsJSON) RawJSON() string {
+	return r.raw
+}
+
 // Control sharing behavior:'manual' allows manual sharing via commands, 'auto'
 // enables automatic sharing, 'disabled' disables all sharing
 type ConfigShare string

+ 72 - 1
packages/tui/sdk/event.go

@@ -55,6 +55,7 @@ type EventListResponse struct {
 	// [EventListResponseEventMessageUpdatedProperties],
 	// [EventListResponseEventMessageRemovedProperties],
 	// [EventListResponseEventMessagePartUpdatedProperties],
+	// [EventListResponseEventMessagePartRemovedProperties],
 	// [EventListResponseEventStorageWriteProperties],
 	// [EventListResponseEventSessionUpdatedProperties],
 	// [EventListResponseEventSessionDeletedProperties],
@@ -99,6 +100,7 @@ func (r *EventListResponse) UnmarshalJSON(data []byte) (err error) {
 // [EventListResponseEventInstallationUpdated],
 // [EventListResponseEventMessageUpdated], [EventListResponseEventMessageRemoved],
 // [EventListResponseEventMessagePartUpdated],
+// [EventListResponseEventMessagePartRemoved],
 // [EventListResponseEventStorageWrite], [EventListResponseEventSessionUpdated],
 // [EventListResponseEventSessionDeleted], [EventListResponseEventSessionIdle],
 // [EventListResponseEventSessionError],
@@ -113,6 +115,7 @@ func (r EventListResponse) AsUnion() EventListResponseUnion {
 // [EventListResponseEventInstallationUpdated],
 // [EventListResponseEventMessageUpdated], [EventListResponseEventMessageRemoved],
 // [EventListResponseEventMessagePartUpdated],
+// [EventListResponseEventMessagePartRemoved],
 // [EventListResponseEventStorageWrite], [EventListResponseEventSessionUpdated],
 // [EventListResponseEventSessionDeleted], [EventListResponseEventSessionIdle],
 // [EventListResponseEventSessionError], [EventListResponseEventFileWatcherUpdated]
@@ -160,6 +163,11 @@ func init() {
 			Type:               reflect.TypeOf(EventListResponseEventMessagePartUpdated{}),
 			DiscriminatorValue: "message.part.updated",
 		},
+		apijson.UnionVariant{
+			TypeFilter:         gjson.JSON,
+			Type:               reflect.TypeOf(EventListResponseEventMessagePartRemoved{}),
+			DiscriminatorValue: "message.part.removed",
+		},
 		apijson.UnionVariant{
 			TypeFilter:         gjson.JSON,
 			Type:               reflect.TypeOf(EventListResponseEventStorageWrite{}),
@@ -651,6 +659,68 @@ func (r EventListResponseEventMessagePartUpdatedType) IsKnown() bool {
 	return false
 }
 
+type EventListResponseEventMessagePartRemoved struct {
+	Properties EventListResponseEventMessagePartRemovedProperties `json:"properties,required"`
+	Type       EventListResponseEventMessagePartRemovedType       `json:"type,required"`
+	JSON       eventListResponseEventMessagePartRemovedJSON       `json:"-"`
+}
+
+// eventListResponseEventMessagePartRemovedJSON contains the JSON metadata for the
+// struct [EventListResponseEventMessagePartRemoved]
+type eventListResponseEventMessagePartRemovedJSON struct {
+	Properties  apijson.Field
+	Type        apijson.Field
+	raw         string
+	ExtraFields map[string]apijson.Field
+}
+
+func (r *EventListResponseEventMessagePartRemoved) UnmarshalJSON(data []byte) (err error) {
+	return apijson.UnmarshalRoot(data, r)
+}
+
+func (r eventListResponseEventMessagePartRemovedJSON) RawJSON() string {
+	return r.raw
+}
+
+func (r EventListResponseEventMessagePartRemoved) implementsEventListResponse() {}
+
+type EventListResponseEventMessagePartRemovedProperties struct {
+	MessageID string                                                 `json:"messageID,required"`
+	PartID    string                                                 `json:"partID,required"`
+	JSON      eventListResponseEventMessagePartRemovedPropertiesJSON `json:"-"`
+}
+
+// eventListResponseEventMessagePartRemovedPropertiesJSON contains the JSON
+// metadata for the struct [EventListResponseEventMessagePartRemovedProperties]
+type eventListResponseEventMessagePartRemovedPropertiesJSON struct {
+	MessageID   apijson.Field
+	PartID      apijson.Field
+	raw         string
+	ExtraFields map[string]apijson.Field
+}
+
+func (r *EventListResponseEventMessagePartRemovedProperties) UnmarshalJSON(data []byte) (err error) {
+	return apijson.UnmarshalRoot(data, r)
+}
+
+func (r eventListResponseEventMessagePartRemovedPropertiesJSON) RawJSON() string {
+	return r.raw
+}
+
+type EventListResponseEventMessagePartRemovedType string
+
+const (
+	EventListResponseEventMessagePartRemovedTypeMessagePartRemoved EventListResponseEventMessagePartRemovedType = "message.part.removed"
+)
+
+func (r EventListResponseEventMessagePartRemovedType) IsKnown() bool {
+	switch r {
+	case EventListResponseEventMessagePartRemovedTypeMessagePartRemoved:
+		return true
+	}
+	return false
+}
+
 type EventListResponseEventStorageWrite struct {
 	Properties EventListResponseEventStorageWriteProperties `json:"properties,required"`
 	Type       EventListResponseEventStorageWriteType       `json:"type,required"`
@@ -1236,6 +1306,7 @@ const (
 	EventListResponseTypeMessageUpdated       EventListResponseType = "message.updated"
 	EventListResponseTypeMessageRemoved       EventListResponseType = "message.removed"
 	EventListResponseTypeMessagePartUpdated   EventListResponseType = "message.part.updated"
+	EventListResponseTypeMessagePartRemoved   EventListResponseType = "message.part.removed"
 	EventListResponseTypeStorageWrite         EventListResponseType = "storage.write"
 	EventListResponseTypeSessionUpdated       EventListResponseType = "session.updated"
 	EventListResponseTypeSessionDeleted       EventListResponseType = "session.deleted"
@@ -1247,7 +1318,7 @@ const (
 
 func (r EventListResponseType) IsKnown() bool {
 	switch r {
-	case EventListResponseTypeLspClientDiagnostics, EventListResponseTypePermissionUpdated, EventListResponseTypeFileEdited, EventListResponseTypeInstallationUpdated, EventListResponseTypeMessageUpdated, EventListResponseTypeMessageRemoved, EventListResponseTypeMessagePartUpdated, EventListResponseTypeStorageWrite, EventListResponseTypeSessionUpdated, EventListResponseTypeSessionDeleted, EventListResponseTypeSessionIdle, EventListResponseTypeSessionError, EventListResponseTypeFileWatcherUpdated, EventListResponseTypeIdeInstalled:
+	case EventListResponseTypeLspClientDiagnostics, EventListResponseTypePermissionUpdated, EventListResponseTypeFileEdited, EventListResponseTypeInstallationUpdated, EventListResponseTypeMessageUpdated, EventListResponseTypeMessageRemoved, EventListResponseTypeMessagePartUpdated, EventListResponseTypeMessagePartRemoved, EventListResponseTypeStorageWrite, EventListResponseTypeSessionUpdated, EventListResponseTypeSessionDeleted, EventListResponseTypeSessionIdle, EventListResponseTypeSessionError, EventListResponseTypeFileWatcherUpdated, EventListResponseTypeIdeInstalled:
 		return true
 	}
 	return false

+ 74 - 13
packages/tui/sdk/session.go

@@ -175,6 +175,7 @@ func (r *SessionService) Unshare(ctx context.Context, id string, opts ...option.
 type AssistantMessage struct {
 	ID         string                 `json:"id,required"`
 	Cost       float64                `json:"cost,required"`
+	Mode       string                 `json:"mode,required"`
 	ModelID    string                 `json:"modelID,required"`
 	Path       AssistantMessagePath   `json:"path,required"`
 	ProviderID string                 `json:"providerID,required"`
@@ -193,6 +194,7 @@ type AssistantMessage struct {
 type assistantMessageJSON struct {
 	ID          apijson.Field
 	Cost        apijson.Field
+	Mode        apijson.Field
 	ModelID     apijson.Field
 	Path        apijson.Field
 	ProviderID  apijson.Field
@@ -735,6 +737,7 @@ type Message struct {
 	Cost float64     `json:"cost"`
 	// This field can have the runtime type of [AssistantMessageError].
 	Error   interface{} `json:"error"`
+	Mode    string      `json:"mode"`
 	ModelID string      `json:"modelID"`
 	// This field can have the runtime type of [AssistantMessagePath].
 	Path       interface{} `json:"path"`
@@ -756,6 +759,7 @@ type messageJSON struct {
 	Time        apijson.Field
 	Cost        apijson.Field
 	Error       apijson.Field
+	Mode        apijson.Field
 	ModelID     apijson.Field
 	Path        apijson.Field
 	ProviderID  apijson.Field
@@ -825,16 +829,19 @@ func (r MessageRole) IsKnown() bool {
 }
 
 type Part struct {
-	ID        string         `json:"id,required"`
-	MessageID string         `json:"messageID,required"`
-	SessionID string         `json:"sessionID,required"`
-	Type      PartType       `json:"type,required"`
-	CallID    string         `json:"callID"`
-	Cost      float64        `json:"cost"`
-	Filename  string         `json:"filename"`
-	Mime      string         `json:"mime"`
-	Snapshot  string         `json:"snapshot"`
-	Source    FilePartSource `json:"source"`
+	ID        string   `json:"id,required"`
+	MessageID string   `json:"messageID,required"`
+	SessionID string   `json:"sessionID,required"`
+	Type      PartType `json:"type,required"`
+	CallID    string   `json:"callID"`
+	Cost      float64  `json:"cost"`
+	Filename  string   `json:"filename"`
+	// This field can have the runtime type of [[]string].
+	Files    interface{}    `json:"files"`
+	Hash     string         `json:"hash"`
+	Mime     string         `json:"mime"`
+	Snapshot string         `json:"snapshot"`
+	Source   FilePartSource `json:"source"`
 	// This field can have the runtime type of [ToolPartState].
 	State     interface{} `json:"state"`
 	Synthetic bool        `json:"synthetic"`
@@ -858,6 +865,8 @@ type partJSON struct {
 	CallID      apijson.Field
 	Cost        apijson.Field
 	Filename    apijson.Field
+	Files       apijson.Field
+	Hash        apijson.Field
 	Mime        apijson.Field
 	Snapshot    apijson.Field
 	Source      apijson.Field
@@ -889,13 +898,13 @@ func (r *Part) UnmarshalJSON(data []byte) (err error) {
 // for more type safety.
 //
 // Possible runtime types of the union are [TextPart], [FilePart], [ToolPart],
-// [StepStartPart], [StepFinishPart], [SnapshotPart].
+// [StepStartPart], [StepFinishPart], [SnapshotPart], [PartPatchPart].
 func (r Part) AsUnion() PartUnion {
 	return r.union
 }
 
 // Union satisfied by [TextPart], [FilePart], [ToolPart], [StepStartPart],
-// [StepFinishPart] or [SnapshotPart].
+// [StepFinishPart], [SnapshotPart] or [PartPatchPart].
 type PartUnion interface {
 	implementsPart()
 }
@@ -934,9 +943,60 @@ func init() {
 			Type:               reflect.TypeOf(SnapshotPart{}),
 			DiscriminatorValue: "snapshot",
 		},
+		apijson.UnionVariant{
+			TypeFilter:         gjson.JSON,
+			Type:               reflect.TypeOf(PartPatchPart{}),
+			DiscriminatorValue: "patch",
+		},
 	)
 }
 
+type PartPatchPart struct {
+	ID        string            `json:"id,required"`
+	Files     []string          `json:"files,required"`
+	Hash      string            `json:"hash,required"`
+	MessageID string            `json:"messageID,required"`
+	SessionID string            `json:"sessionID,required"`
+	Type      PartPatchPartType `json:"type,required"`
+	JSON      partPatchPartJSON `json:"-"`
+}
+
+// partPatchPartJSON contains the JSON metadata for the struct [PartPatchPart]
+type partPatchPartJSON struct {
+	ID          apijson.Field
+	Files       apijson.Field
+	Hash        apijson.Field
+	MessageID   apijson.Field
+	SessionID   apijson.Field
+	Type        apijson.Field
+	raw         string
+	ExtraFields map[string]apijson.Field
+}
+
+func (r *PartPatchPart) UnmarshalJSON(data []byte) (err error) {
+	return apijson.UnmarshalRoot(data, r)
+}
+
+func (r partPatchPartJSON) RawJSON() string {
+	return r.raw
+}
+
+func (r PartPatchPart) implementsPart() {}
+
+type PartPatchPartType string
+
+const (
+	PartPatchPartTypePatch PartPatchPartType = "patch"
+)
+
+func (r PartPatchPartType) IsKnown() bool {
+	switch r {
+	case PartPatchPartTypePatch:
+		return true
+	}
+	return false
+}
+
 type PartType string
 
 const (
@@ -946,11 +1006,12 @@ const (
 	PartTypeStepStart  PartType = "step-start"
 	PartTypeStepFinish PartType = "step-finish"
 	PartTypeSnapshot   PartType = "snapshot"
+	PartTypePatch      PartType = "patch"
 )
 
 func (r PartType) IsKnown() bool {
 	switch r {
-	case PartTypeText, PartTypeFile, PartTypeTool, PartTypeStepStart, PartTypeStepFinish, PartTypeSnapshot:
+	case PartTypeText, PartTypeFile, PartTypeTool, PartTypeStepStart, PartTypeStepFinish, PartTypeSnapshot, PartTypePatch:
 		return true
 	}
 	return false