Просмотр исходного кода

more efficient snapshots in parallel toolcalls

Dax Raad 6 месяцев назад
Родитель
Сommit
06830327e7
1 измененных файлов с 26 добавлено и 43 удалено
  1. 26 43
      packages/opencode/src/session/index.ts

+ 26 - 43
packages/opencode/src/session/index.ts

@@ -735,7 +735,6 @@ export namespace Session {
               args,
             },
           )
-          await processor.track(options.toolCallId)
           const result = await item.execute(args, {
             sessionID: input.sessionID,
             abort: abort.signal,
@@ -784,7 +783,6 @@ export namespace Session {
       const execute = item.execute
       if (!execute) continue
       item.execute = async (args, opts) => {
-        await processor.track(opts.toolCallId)
         const result = await execute(args, opts)
         const output = result.content
           .filter((x: any) => x.type === "text")
@@ -920,15 +918,11 @@ export namespace Session {
   }
 
   function createProcessor(assistantMsg: MessageV2.Assistant, model: ModelsDev.Model) {
-    const toolCalls: Record<string, MessageV2.ToolPart> = {}
-    const snapshots: Record<string, string> = {}
+    const toolcalls: Record<string, MessageV2.ToolPart> = {}
+    let snapshot: string | undefined
     return {
-      async track(toolCallID: string) {
-        const hash = await Snapshot.track()
-        if (hash) snapshots[toolCallID] = hash
-      },
       partFromToolCall(toolCallID: string) {
-        return toolCalls[toolCallID]
+        return toolcalls[toolCallID]
       },
       async process(stream: StreamTextResult<Record<string, AITool>, never>) {
         try {
@@ -944,7 +938,7 @@ export namespace Session {
 
               case "tool-input-start":
                 const part = await updatePart({
-                  id: toolCalls[value.id]?.id ?? Identifier.ascending("part"),
+                  id: toolcalls[value.id]?.id ?? Identifier.ascending("part"),
                   messageID: assistantMsg.id,
                   sessionID: assistantMsg.sessionID,
                   type: "tool",
@@ -954,7 +948,7 @@ export namespace Session {
                     status: "pending",
                   },
                 })
-                toolCalls[value.id] = part as MessageV2.ToolPart
+                toolcalls[value.id] = part as MessageV2.ToolPart
                 break
 
               case "tool-input-delta":
@@ -964,7 +958,7 @@ export namespace Session {
                 break
 
               case "tool-call": {
-                const match = toolCalls[value.toolCallId]
+                const match = toolcalls[value.toolCallId]
                 if (match) {
                   const part = await updatePart({
                     ...match,
@@ -976,12 +970,12 @@ export namespace Session {
                       },
                     },
                   })
-                  toolCalls[value.toolCallId] = part as MessageV2.ToolPart
+                  toolcalls[value.toolCallId] = part as MessageV2.ToolPart
                 }
                 break
               }
               case "tool-result": {
-                const match = toolCalls[value.toolCallId]
+                const match = toolcalls[value.toolCallId]
                 if (match && match.state.status === "running") {
                   await updatePart({
                     ...match,
@@ -997,27 +991,13 @@ export namespace Session {
                       },
                     },
                   })
-                  delete toolCalls[value.toolCallId]
-                  const snapshot = snapshots[value.toolCallId]
-                  if (snapshot) {
-                    const patch = await Snapshot.patch(snapshot)
-                    if (patch.files.length) {
-                      await updatePart({
-                        id: Identifier.ascending("part"),
-                        messageID: assistantMsg.id,
-                        sessionID: assistantMsg.sessionID,
-                        type: "patch",
-                        hash: patch.hash,
-                        files: patch.files,
-                      })
-                    }
-                  }
+                  delete toolcalls[value.toolCallId]
                 }
                 break
               }
 
               case "tool-error": {
-                const match = toolCalls[value.toolCallId]
+                const match = toolcalls[value.toolCallId]
                 if (match && match.state.status === "running") {
                   await updatePart({
                     ...match,
@@ -1031,19 +1011,7 @@ export namespace Session {
                       },
                     },
                   })
-                  delete toolCalls[value.toolCallId]
-                  const snapshot = snapshots[value.toolCallId]
-                  if (snapshot) {
-                    const patch = await Snapshot.patch(snapshot)
-                    await updatePart({
-                      id: Identifier.ascending("part"),
-                      messageID: assistantMsg.id,
-                      sessionID: assistantMsg.sessionID,
-                      type: "patch",
-                      hash: patch.hash,
-                      files: patch.files,
-                    })
-                  }
+                  delete toolcalls[value.toolCallId]
                 }
                 break
               }
@@ -1058,6 +1026,7 @@ export namespace Session {
                   sessionID: assistantMsg.sessionID,
                   type: "step-start",
                 })
+                snapshot = await Snapshot.track()
                 break
 
               case "finish-step":
@@ -1073,6 +1042,20 @@ export namespace Session {
                   cost: usage.cost,
                 })
                 await updateMessage(assistantMsg)
+                if (snapshot) {
+                  const patch = await Snapshot.patch(snapshot)
+                  if (patch.files.length) {
+                    await updatePart({
+                      id: Identifier.ascending("part"),
+                      messageID: assistantMsg.id,
+                      sessionID: assistantMsg.sessionID,
+                      type: "patch",
+                      hash: patch.hash,
+                      files: patch.files,
+                    })
+                  }
+                  snapshot = undefined
+                }
                 break
 
               case "text-start":