Jelajahi Sumber

huge optimization for token usage with anthropic

Dax Raad 8 bulan lalu
induk
melakukan
1684042fb6

+ 15 - 14
packages/opencode/src/provider/transform.ts

@@ -1,24 +1,25 @@
-import type { CoreMessage } from "ai"
+import type { CoreMessage, LanguageModelV1Prompt } from "ai"
+import { unique } from "remeda"
 
 export namespace ProviderTransform {
   export function message(
-    msg: CoreMessage,
-    index: number,
+    msgs: LanguageModelV1Prompt,
     providerID: string,
     modelID: string,
   ) {
-    if (
-      (providerID === "anthropic" || modelID.includes("anthropic")) &&
-      index < 4
-    ) {
-      msg.providerOptions = {
-        ...msg.providerOptions,
-        anthropic: {
-          cacheControl: { type: "ephemeral" },
-        },
+    if (providerID === "anthropic" || modelID.includes("anthropic")) {
+      const system = msgs.filter((msg) => msg.role === "system").slice(0, 2)
+      const final = msgs.filter((msg) => msg.role !== "system").slice(-2)
+
+      for (const msg of unique([...system, ...final])) {
+        msg.providerMetadata = {
+          ...msg.providerMetadata,
+          anthropic: {
+            cacheControl: { type: "ephemeral" },
+          },
+        }
       }
     }
-
-    return msg
+    return msgs
   }
 }

+ 20 - 7
packages/opencode/src/session/index.ts

@@ -14,6 +14,7 @@ import {
   type CoreMessage,
   type UIMessage,
   type ProviderMetadata,
+  wrapLanguageModel,
 } from "ai"
 import { z, ZodSchema } from "zod"
 import { Decimal } from "decimal.js"
@@ -285,9 +286,7 @@ export namespace Session {
               parts: toParts(input.parts),
             },
           ]),
-        ].map((msg, i) =>
-          ProviderTransform.message(msg, i, input.providerID, input.modelID),
-        ),
+        ],
         model: model.language,
       })
         .then((result) => {
@@ -527,12 +526,26 @@ export namespace Session {
         ...convertToCoreMessages(
           msgs.map(toUIMessage).filter((x) => x.parts.length > 0),
         ),
-      ].map((msg, i) =>
-        ProviderTransform.message(msg, i, input.providerID, input.modelID),
-      ),
+      ],
       temperature: model.info.temperature ? 0 : undefined,
       tools: model.info.tool_call === false ? undefined : tools,
-      model: model.language,
+      model: wrapLanguageModel({
+        model: model.language,
+        middleware: [
+          {
+            async transformParams(args) {
+              if (args.type === "stream") {
+                args.params.prompt = ProviderTransform.message(
+                  args.params.prompt,
+                  input.providerID,
+                  input.modelID,
+                )
+              }
+              return args.params
+            },
+          },
+        ],
+      }),
     })
     try {
       for await (const value of result.fullStream) {