فهرست منبع

feat: add support for fast modes for claude and gpt models (that support it) (#21706)

Aiden Cline 1 هفته پیش
والد
کامیت
b0600664ab

+ 2 - 2
packages/opencode/src/plugin/codex.ts

@@ -376,9 +376,9 @@ export async function CodexAuthPlugin(input: PluginInput): Promise<Hooks> {
           "gpt-5.4",
           "gpt-5.4-mini",
         ])
-        for (const modelId of Object.keys(provider.models)) {
+        for (const [modelId, model] of Object.entries(provider.models)) {
           if (modelId.includes("codex")) continue
-          if (allowedModels.has(modelId)) continue
+          if (allowedModels.has(model.api.id)) continue
           delete provider.models[modelId]
         }
 

+ 40 - 17
packages/opencode/src/provider/models.ts

@@ -22,6 +22,27 @@ export namespace ModelsDev {
   )
   const ttl = 5 * 60 * 1000
 
+  type JsonValue = string | number | boolean | null | { [key: string]: JsonValue } | JsonValue[]
+
+  const JsonValue: z.ZodType<JsonValue> = z.lazy(() =>
+    z.union([z.string(), z.number(), z.boolean(), z.null(), z.array(JsonValue), z.record(z.string(), JsonValue)]),
+  )
+
+  const Cost = z.object({
+    input: z.number(),
+    output: z.number(),
+    cache_read: z.number().optional(),
+    cache_write: z.number().optional(),
+    context_over_200k: z
+      .object({
+        input: z.number(),
+        output: z.number(),
+        cache_read: z.number().optional(),
+        cache_write: z.number().optional(),
+      })
+      .optional(),
+  })
+
   export const Model = z.object({
     id: z.string(),
     name: z.string(),
@@ -41,22 +62,7 @@ export namespace ModelsDev {
           .strict(),
       ])
       .optional(),
-    cost: z
-      .object({
-        input: z.number(),
-        output: z.number(),
-        cache_read: z.number().optional(),
-        cache_write: z.number().optional(),
-        context_over_200k: z
-          .object({
-            input: z.number(),
-            output: z.number(),
-            cache_read: z.number().optional(),
-            cache_write: z.number().optional(),
-          })
-          .optional(),
-      })
-      .optional(),
+    cost: Cost.optional(),
     limit: z.object({
       context: z.number(),
       input: z.number().optional(),
@@ -68,7 +74,24 @@ export namespace ModelsDev {
         output: z.array(z.enum(["text", "audio", "image", "video", "pdf"])),
       })
       .optional(),
-    experimental: z.boolean().optional(),
+    experimental: z
+      .object({
+        modes: z
+          .record(
+            z.string(),
+            z.object({
+              cost: Cost.optional(),
+              provider: z
+                .object({
+                  body: z.record(z.string(), JsonValue).optional(),
+                  headers: z.record(z.string(), z.string()).optional(),
+                })
+                .optional(),
+            }),
+          )
+          .optional(),
+      })
+      .optional(),
     status: z.enum(["alpha", "beta", "deprecated"]).optional(),
     provider: z.object({ npm: z.string().optional(), api: z.string().optional() }).optional(),
   })

+ 42 - 19
packages/opencode/src/provider/provider.ts

@@ -926,6 +926,28 @@ export namespace Provider {
 
   export class Service extends ServiceMap.Service<Service, Interface>()("@opencode/Provider") {}
 
+  function cost(c: ModelsDev.Model["cost"]): Model["cost"] {
+    const result: Model["cost"] = {
+      input: c?.input ?? 0,
+      output: c?.output ?? 0,
+      cache: {
+        read: c?.cache_read ?? 0,
+        write: c?.cache_write ?? 0,
+      },
+    }
+    if (c?.context_over_200k) {
+      result.experimentalOver200K = {
+        cache: {
+          read: c.context_over_200k.cache_read ?? 0,
+          write: c.context_over_200k.cache_write ?? 0,
+        },
+        input: c.context_over_200k.input,
+        output: c.context_over_200k.output,
+      }
+    }
+    return result
+  }
+
   function fromModelsDevModel(provider: ModelsDev.Provider, model: ModelsDev.Model): Model {
     const m: Model = {
       id: ModelID.make(model.id),
@@ -940,24 +962,7 @@ export namespace Provider {
       status: model.status ?? "active",
       headers: {},
       options: {},
-      cost: {
-        input: model.cost?.input ?? 0,
-        output: model.cost?.output ?? 0,
-        cache: {
-          read: model.cost?.cache_read ?? 0,
-          write: model.cost?.cache_write ?? 0,
-        },
-        experimentalOver200K: model.cost?.context_over_200k
-          ? {
-              cache: {
-                read: model.cost.context_over_200k.cache_read ?? 0,
-                write: model.cost.context_over_200k.cache_write ?? 0,
-              },
-              input: model.cost.context_over_200k.input,
-              output: model.cost.context_over_200k.output,
-            }
-          : undefined,
-      },
+      cost: cost(model.cost),
       limit: {
         context: model.limit.context,
         input: model.limit.input,
@@ -994,13 +999,31 @@ export namespace Provider {
   }
 
   export function fromModelsDevProvider(provider: ModelsDev.Provider): Info {
+    const models: Record<string, Model> = {}
+    for (const [key, model] of Object.entries(provider.models)) {
+      models[key] = fromModelsDevModel(provider, model)
+      for (const [mode, opts] of Object.entries(model.experimental?.modes ?? {})) {
+        const id = `${model.id}-${mode}`
+        const m = fromModelsDevModel(provider, model)
+        m.id = ModelID.make(id)
+        m.name = `${model.name} ${mode[0].toUpperCase()}${mode.slice(1)}`
+        if (opts.cost) m.cost = mergeDeep(m.cost, cost(opts.cost))
+        // convert body params to camelCase for ai sdk compatibility
+        if (opts.provider?.body)
+          m.options = Object.fromEntries(
+            Object.entries(opts.provider.body).map(([k, v]) => [k.replace(/_([a-z])/g, (_, c) => c.toUpperCase()), v]),
+          )
+        if (opts.provider?.headers) m.headers = opts.provider.headers
+        models[id] = m
+      }
+    }
     return {
       id: ProviderID.make(provider.id),
       source: "custom",
       name: provider.name,
       env: provider.env ?? [],
       options: {},
-      models: mapValues(provider.models, (model) => fromModelsDevModel(provider, model)),
+      models,
     }
   }
 

+ 68 - 0
packages/opencode/test/provider/provider.test.ts

@@ -6,6 +6,7 @@ import { tmpdir } from "../fixture/fixture"
 import { Global } from "../../src/global"
 import { Instance } from "../../src/project/instance"
 import { Plugin } from "../../src/plugin/index"
+import { ModelsDev } from "../../src/provider/models"
 import { Provider } from "../../src/provider/provider"
 import { ProviderID, ModelID } from "../../src/provider/schema"
 import { Filesystem } from "../../src/util/filesystem"
@@ -1823,6 +1824,73 @@ test("custom model inherits api.url from models.dev provider", async () => {
   })
 })
 
+test("mode cost preserves over-200k pricing from base model", () => {
+  const provider = {
+    id: "openai",
+    name: "OpenAI",
+    env: [],
+    api: "https://api.openai.com/v1",
+    models: {
+      "gpt-5.4": {
+        id: "gpt-5.4",
+        name: "GPT-5.4",
+        family: "gpt",
+        release_date: "2026-03-05",
+        attachment: true,
+        reasoning: true,
+        temperature: false,
+        tool_call: true,
+        cost: {
+          input: 2.5,
+          output: 15,
+          cache_read: 0.25,
+          context_over_200k: {
+            input: 5,
+            output: 22.5,
+            cache_read: 0.5,
+          },
+        },
+        limit: {
+          context: 1_050_000,
+          input: 922_000,
+          output: 128_000,
+        },
+        experimental: {
+          modes: {
+            fast: {
+              cost: {
+                input: 5,
+                output: 30,
+                cache_read: 0.5,
+              },
+              provider: {
+                body: {
+                  service_tier: "priority",
+                },
+              },
+            },
+          },
+        },
+      },
+    },
+  } as ModelsDev.Provider
+
+  const model = Provider.fromModelsDevProvider(provider).models["gpt-5.4-fast"]
+  expect(model.cost.input).toEqual(5)
+  expect(model.cost.output).toEqual(30)
+  expect(model.cost.cache.read).toEqual(0.5)
+  expect(model.cost.cache.write).toEqual(0)
+  expect(model.options["serviceTier"]).toEqual("priority")
+  expect(model.cost.experimentalOver200K).toEqual({
+    input: 5,
+    output: 22.5,
+    cache: {
+      read: 0.5,
+      write: 0,
+    },
+  })
+})
+
 test("model variants are generated for reasoning models", async () => {
   await using tmp = await tmpdir({
     init: async (dir) => {