Răsfoiți Sursa

feat(models): enable Kimi k2 ⇄ Claude trajectory handoff (#1525)

Ricardo Gonzalez 6 luni în urmă
părinte
comite
8f45a0e227
1 a modificat fișierele cu 61 adăugiri și 38 ștergeri
  1. 61 38
      packages/opencode/src/provider/transform.ts

+ 61 - 38
packages/opencode/src/provider/transform.ts

@@ -1,48 +1,76 @@
 import type { ModelMessage } from "ai"
 import { unique } from "remeda"
 
+
 export namespace ProviderTransform {
-  export function message(msgs: ModelMessage[], providerID: string, modelID: string) {
-    if (providerID === "anthropic" || modelID.includes("anthropic") || modelID.includes("claude")) {
-      const system = msgs.filter((msg) => msg.role === "system").slice(0, 2)
-      const final = msgs.filter((msg) => msg.role !== "system").slice(-2)
-
-      const providerOptions = {
-        anthropic: {
-          cacheControl: { type: "ephemeral" },
-        },
-        openrouter: {
-          cache_control: { type: "ephemeral" },
-        },
-        bedrock: {
-          cachePoint: { type: "ephemeral" },
-        },
-        openaiCompatible: {
-          cache_control: { type: "ephemeral" },
-        },
+  function normalizeToolCallIds(msgs: ModelMessage[]): ModelMessage[] {
+    return msgs.map((msg) => {
+      if ((msg.role === "assistant" || msg.role === "tool") && Array.isArray(msg.content)) {
+        msg.content = msg.content.map((part) => {
+          if ((part.type === "tool-call" || part.type === "tool-result") && "toolCallId" in part) {
+            return {
+              ...part,
+              toolCallId: part.toolCallId.replace(/[^a-zA-Z0-9_-]/g, '_')
+            }
+          }
+          return part
+        })
       }
+      return msg
+    })
+  }
 
-      for (const msg of unique([...system, ...final])) {
-        const shouldUseContentOptions =
-          providerID !== "anthropic" && Array.isArray(msg.content) && msg.content.length > 0
+  function applyCaching(msgs: ModelMessage[], providerID: string): ModelMessage[] {
+    const system = msgs.filter((msg) => msg.role === "system").slice(0, 2)
+    const final = msgs.filter((msg) => msg.role !== "system").slice(-2)
 
-        if (shouldUseContentOptions) {
-          const lastContent = msg.content[msg.content.length - 1]
-          if (lastContent && typeof lastContent === "object") {
-            lastContent.providerOptions = {
-              ...lastContent.providerOptions,
-              ...providerOptions,
-            }
-            continue
+    const providerOptions = {
+      anthropic: {
+        cacheControl: { type: "ephemeral" },
+      },
+      openrouter: {
+        cache_control: { type: "ephemeral" },
+      },
+      bedrock: {
+        cachePoint: { type: "ephemeral" },
+      },
+      openaiCompatible: {
+        cache_control: { type: "ephemeral" },
+      },
+    }
+
+    for (const msg of unique([...system, ...final])) {
+      const shouldUseContentOptions =
+        providerID !== "anthropic" && Array.isArray(msg.content) && msg.content.length > 0
+
+      if (shouldUseContentOptions) {
+        const lastContent = msg.content[msg.content.length - 1]
+        if (lastContent && typeof lastContent === "object") {
+          lastContent.providerOptions = {
+            ...lastContent.providerOptions,
+            ...providerOptions,
           }
+          continue
         }
+      }
 
-        msg.providerOptions = {
-          ...msg.providerOptions,
-          ...providerOptions,
-        }
+      msg.providerOptions = {
+        ...msg.providerOptions,
+        ...providerOptions,
       }
     }
+
+    return msgs
+  }
+
+  export function message(msgs: ModelMessage[], providerID: string, modelID: string) {
+    if (modelID.includes("claude")) {
+      msgs = normalizeToolCallIds(msgs)
+    }
+    if (providerID === "anthropic" || modelID.includes("anthropic") || modelID.includes("claude")) {
+      msgs = applyCaching(msgs, providerID)
+    }
+    
     return msgs
   }
 
@@ -50,9 +78,4 @@ export namespace ProviderTransform {
     if (modelID.toLowerCase().includes("qwen")) return 0.55
     return 0
   }
-
-  export function topP(_providerID: string, modelID: string) {
-    if (modelID.toLowerCase().includes("qwen")) return 1
-    return undefined
-  }
 }