Просмотр исходного кода

core: convert Model type to Zod schema for better type safety and validation

Dax Raad 4 месяцев назад
Родитель
Сommit
a844eb2429
1 измененных файлов с 252 добавлено и 110 удалено
  1. 252 110
      packages/opencode/src/provider/provider.ts

+ 252 - 110
packages/opencode/src/provider/provider.ts

@@ -1,7 +1,7 @@
 import z from "zod"
 import fuzzysort from "fuzzysort"
 import { Config } from "../config/config"
-import { mergeDeep, sortBy } from "remeda"
+import { entries, mapValues, mergeDeep, pipe, sortBy } from "remeda"
 import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
 import { Log } from "../util/log"
 import { BunProc } from "../bun"
@@ -43,7 +43,7 @@ export namespace Provider {
     "@ai-sdk/github-copilot": createGitHubCopilotOpenAICompatible,
   }
 
-  type CustomLoader = (provider: ModelsDev.Provider) => Promise<{
+  type CustomLoader = (provider: Info) => Promise<{
     autoload: boolean
     getModel?: (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>
     options?: Record<string, any>
@@ -299,18 +299,155 @@ export namespace Provider {
     },
   }
 
-  export type Model = {
-    providerID: string
-    modelID: string
+  export const Model = z
+    .object({
+      id: z.string(),
+      providerID: z.string(),
+      api: z.object({
+        id: z.string(),
+        url: z.string(),
+        npm: z.string(),
+      }),
+      name: z.string(),
+      capabilities: z.object({
+        temperature: z.boolean(),
+        reasoning: z.boolean(),
+        attachment: z.boolean(),
+        toolcall: z.boolean(),
+        input: {
+          text: z.boolean(),
+          audio: z.boolean(),
+          image: z.boolean(),
+          video: z.boolean(),
+          pdf: z.boolean(),
+        },
+        output: {
+          text: z.boolean(),
+          audio: z.boolean(),
+          image: z.boolean(),
+          video: z.boolean(),
+          pdf: z.boolean(),
+        },
+      }),
+      cost: z.object({
+        input: z.number(),
+        output: z.number(),
+        cache: z.object({
+          read: z.number(),
+          write: z.number(),
+        }),
+        experimentalOver200K: z
+          .object({
+            input: z.number(),
+            output: z.number(),
+            cache: z.object({
+              read: z.number(),
+              write: z.number(),
+            }),
+          })
+          .optional(),
+      }),
+      limit: z.object({
+        context: z.number(),
+        output: z.number(),
+      }),
+      status: z.enum(["alpha", "beta", "deprecated", "active"]),
+      options: z.record(z.string(), z.any()),
+      headers: z.record(z.string(), z.string()),
+    })
+    .meta({
+      ref: "Model",
+    })
+  export type Model = z.infer<typeof Model>
+
+  export const Info = z.object({
+    id: z.string(),
+    name: z.string(),
+    source: z.enum(["env", "config", "custom", "api"]),
+    env: z.string().array(),
+    key: z.string().optional(),
+    options: z.record(z.string(), z.any()),
+    models: z.record(z.string(), Model),
+  })
+  export type Info = z.infer<typeof Info>
+
+  function fromModelsDevModel(provider: ModelsDev.Provider, model: ModelsDev.Model): Model {
+    return {
+      id: model.id,
+      name: model.name,
+      api: {
+        id: model.id,
+        url: provider.api!,
+        npm: model.provider?.npm ?? provider.npm ?? provider.id,
+      },
+      status: model.status ?? "active",
+      headers: model.headers ?? {},
+      options: model.options ?? {},
+      cost: {
+        input: model.cost.input,
+        output: model.cost.output,
+        cache: {
+          read: model.cost.cache_read ?? 0,
+          write: model.cost.cache_write ?? 0,
+        },
+        experimentalOver200K: model.cost.context_over_200k
+          ? {
+              cache: {
+                read: model.cost.context_over_200k.cache_read ?? 0,
+                write: model.cost.context_over_200k.cache_write ?? 0,
+              },
+              input: model.cost.context_over_200k.input,
+              output: model.cost.context_over_200k.output,
+            }
+          : undefined,
+      },
+      limit: {
+        context: model.limit.context,
+        output: model.limit.output,
+      },
+      capabilities: {
+        temperature: model.temperature,
+        reasoning: model.reasoning,
+        attachment: model.attachment,
+        toolcall: model.tool_call,
+        input: {
+          text: model.modalities?.input?.includes("text") ?? false,
+          audio: model.modalities?.input?.includes("audio") ?? false,
+          image: model.modalities?.input?.includes("image") ?? false,
+          video: model.modalities?.input?.includes("video") ?? false,
+          pdf: model.modalities?.input?.includes("pdf") ?? false,
+        },
+        output: {
+          text: model.modalities?.output?.includes("text") ?? false,
+          audio: model.modalities?.output?.includes("audio") ?? false,
+          image: model.modalities?.output?.includes("image") ?? false,
+          video: model.modalities?.output?.includes("video") ?? false,
+          pdf: model.modalities?.output?.includes("pdf") ?? false,
+        },
+      },
+    }
+  }
+
+  function fromModelsDevProvider(provider: ModelsDev.Provider): Info {
+    return {
+      id: provider.id,
+      source: "custom",
+      name: provider.name,
+      env: provider.env ?? [],
+      options: {},
+      models: mapValues(provider.models, (model) => fromModelsDevModel(provider, model)),
+    }
+  }
+
+  export type ModelWithStuff = {
     language: LanguageModel
-    info: ModelsDev.Model
-    npm: string
+    info: Model
   }
 
   const state = Instance.state(async () => {
     using _ = log.time("state")
     const config = await Config.get()
-    const database = await ModelsDev.get()
+    const database = mapValues(await ModelsDev.get(), fromModelsDevProvider)
 
     const disabled = new Set(config.disabled_providers ?? [])
     const enabled = config.enabled_providers ? new Set(config.enabled_providers) : null
@@ -321,43 +458,12 @@ export namespace Provider {
       return true
     }
 
-    const providers: {
-      [providerID: string]: {
-        source: Source
-        info: ModelsDev.Provider
-        getModel?: (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>
-        options: Record<string, any>
-      }
-    } = {}
-    const models = new Map<string, Model>()
+    const providers: { [providerID: string]: Info } = {}
+    const models = new Map<string, ModelWithStuff>()
     const sdk = new Map<number, SDK>()
 
     log.info("init")
 
-    function mergeProvider(
-      id: string,
-      options: Record<string, any>,
-      source: Source,
-      getModel?: (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>,
-    ) {
-      const provider = providers[id]
-      if (!provider) {
-        const info = database[id]
-        if (!info) return
-        if (info.api && !options["baseURL"]) options["baseURL"] = info.api
-        providers[id] = {
-          source,
-          info,
-          options,
-          getModel,
-        }
-        return
-      }
-      provider.options = mergeDeep(provider.options, options)
-      provider.source = source
-      provider.getModel = getModel ?? provider.getModel
-    }
-
     const configProviders = Object.entries(config.provider ?? {})
 
     // Add GitHub Copilot Enterprise provider that inherits from GitHub Copilot
@@ -367,11 +473,17 @@ export namespace Provider {
         ...githubCopilot,
         id: "github-copilot-enterprise",
         name: "GitHub Copilot Enterprise",
-        // Enterprise uses a different API endpoint - will be set dynamically based on auth
-        api: undefined,
       }
     }
 
+    function mergeProvider(providerID: string, provider: Partial<Info>) {
+      const match = database[providerID]
+      if (!match) return
+      // @ts-expect-error
+      providers[providerID] = mergeDeep(match, provider)
+    }
+
+    // TODO: load config
     for (const [providerID, provider] of configProviders) {
       const existing = database[providerID]
       const parsed: ModelsDev.Provider = {
@@ -390,29 +502,27 @@ export namespace Provider {
           if (model.id && model.id !== modelID) return modelID
           return existing?.name ?? modelID
         })
-        const parsedModel: ModelsDev.Model = {
+        const parsedModel: Model = {
           id: modelID,
-          target: model.target ?? existing?.target ?? modelID,
+          apiID: model.target ?? existing?.target ?? modelID,
+          status: model.status ?? existing?.status ?? "alpha",
           name,
-          release_date: model.release_date ?? existing?.release_date,
-          attachment: model.attachment ?? existing?.attachment ?? false,
-          reasoning: model.reasoning ?? existing?.reasoning ?? false,
-          temperature: model.temperature ?? existing?.temperature ?? false,
-          tool_call: model.tool_call ?? existing?.tool_call ?? true,
-          cost:
-            !model.cost && !existing?.cost
-              ? {
-                  input: 0,
-                  output: 0,
-                  cache_read: 0,
-                  cache_write: 0,
-                }
-              : {
-                  cache_read: 0,
-                  cache_write: 0,
-                  ...existing?.cost,
-                  ...model.cost,
-                },
+          providerID,
+          npm: model.provider?.npm ?? existing?.provider?.npm ?? provider.npm ?? providerID,
+          support: {
+            temperature: model.temperature ?? existing?.temperature ?? false,
+            reasoning: model.reasoning ?? existing?.reasoning ?? false,
+            attachment: model.attachment ?? existing?.attachment ?? false,
+            toolcall: model.tool_call ?? existing?.tool_call ?? true,
+          },
+          cost: {
+            input: model?.cost?.input ?? existing?.cost?.input ?? 0,
+            output: model?.cost?.output ?? existing?.cost?.output ?? 0,
+            cache: {
+              read: model?.cost?.cache_read ?? existing?.cost?.cache_read ?? 0,
+              write: model?.cost?.cache_write ?? existing?.cost?.cache_write ?? 0,
+            },
+          },
           options: {
             ...existing?.options,
             ...model.options,
@@ -427,8 +537,7 @@ export namespace Provider {
               input: ["text"],
               output: ["text"],
             },
-          headers: model.headers,
-          provider: model.provider ?? existing?.provider,
+          headers: model.headers ?? {},
         }
         parsed.models[modelID] = parsedModel
       }
@@ -442,19 +551,20 @@ export namespace Provider {
       if (disabled.has(providerID)) continue
       const apiKey = provider.env.map((item) => env[item]).find(Boolean)
       if (!apiKey) continue
-      mergeProvider(
-        providerID,
-        // only include apiKey if there's only one potential option
-        provider.env.length === 1 ? { apiKey } : {},
-        "env",
-      )
+      mergeProvider(providerID, {
+        source: "env",
+        key: provider.env.length === 1 ? apiKey : undefined,
+      })
     }
 
     // load apikeys
     for (const [providerID, provider] of Object.entries(await Auth.all())) {
       if (disabled.has(providerID)) continue
       if (provider.type === "api") {
-        mergeProvider(providerID, { apiKey: provider.key }, "api")
+        mergeProvider(providerID, {
+          source: "api",
+          key: provider.key,
+        })
       }
     }
 
@@ -480,7 +590,10 @@ export namespace Provider {
       // Load for the main provider if auth exists
       if (auth) {
         const options = await plugin.auth.loader(() => Auth.get(providerID) as any, database[plugin.auth.provider])
-        mergeProvider(plugin.auth.provider, options ?? {}, "custom")
+        mergeProvider(plugin.auth.provider, {
+          source: "custom",
+          options: options,
+        })
       }
 
       // If this is github-copilot plugin, also register for github-copilot-enterprise if auth exists
@@ -493,7 +606,10 @@ export namespace Provider {
               () => Auth.get(enterpriseProviderID) as any,
               database[enterpriseProviderID],
             )
-            mergeProvider(enterpriseProviderID, enterpriseOptions ?? {}, "custom")
+            mergeProvider(enterpriseProviderID, {
+              source: "custom",
+              options: enterpriseOptions,
+            })
           }
         }
       }
@@ -503,13 +619,22 @@ export namespace Provider {
       if (disabled.has(providerID)) continue
       const result = await fn(database[providerID])
       if (result && (result.autoload || providers[providerID])) {
-        mergeProvider(providerID, result.options ?? {}, "custom", result.getModel)
+        mergeProvider(providerID, {
+          source: "custom",
+          options: result.options,
+        })
       }
     }
 
     // load config
     for (const [providerID, provider] of configProviders) {
-      mergeProvider(providerID, provider.options ?? {}, "config")
+      mergeProvider(providerID, {
+        source: "config",
+        env: provider.env,
+        name: provider.name,
+        options: provider.options,
+        // TODO: merge models
+      })
     }
 
     for (const [providerID, provider] of Object.entries(providers)) {
@@ -519,33 +644,36 @@ export namespace Provider {
       }
 
       if (providerID === "github-copilot" || providerID === "github-copilot-enterprise") {
-        provider.info.npm = "@ai-sdk/github-copilot"
+        provider.models = mapValues(provider.models, (model) => ({
+          ...model,
+          api: {
+            ...model.api,
+            npm: "@ai-sdk/github-copilot",
+          },
+        }))
       }
 
       const configProvider = config.provider?.[providerID]
 
-      for (const [modelID, model] of Object.entries(provider.info.models)) {
-        model.target = model.target ?? model.id ?? modelID
+      for (const [modelID, model] of Object.entries(provider.models)) {
+        model.api.id = model.api.id ?? model.id ?? modelID
         if (modelID === "gpt-5-chat-latest" || (providerID === "openrouter" && modelID === "openai/gpt-5-chat"))
-          delete provider.info.models[modelID]
-        if (
-          ((model.status === "alpha" || model.experimental) && !Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) ||
-          model.status === "deprecated"
-        )
-          delete provider.info.models[modelID]
+          delete provider.models[modelID]
+        if ((model.status === "alpha" && !Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) || model.status === "deprecated")
+          delete provider.models[modelID]
         if (
           (configProvider?.blacklist && configProvider.blacklist.includes(modelID)) ||
           (configProvider?.whitelist && !configProvider.whitelist.includes(modelID))
         )
-          delete provider.info.models[modelID]
+          delete provider.models[modelID]
       }
 
-      if (Object.keys(provider.info.models).length === 0) {
+      if (Object.keys(provider.models).length === 0) {
         delete providers[providerID]
         continue
       }
 
-      log.info("found", { providerID, npm: provider.info.npm })
+      log.info("found", { providerID })
     }
 
     return {
@@ -559,18 +687,28 @@ export namespace Provider {
     return state().then((state) => state.providers)
   }
 
-  async function getSDK(npm: string, providerID: string) {
+  async function getSDK(model: Model) {
     try {
       using _ = log.time("getSDK", {
-        providerID,
+        providerID: model.providerID,
       })
       const s = await state()
-      const options = { ...s.providers[providerID]?.options }
-      if (npm.includes("@ai-sdk/openai-compatible") && options["includeUsage"] !== false) {
+      const provider = s.providers[model.providerID]
+      const options = { ...provider.options }
+
+      if (model.api.npm.includes("@ai-sdk/openai-compatible") && options["includeUsage"] !== false) {
         options["includeUsage"] = true
       }
 
-      const key = Bun.hash.xxHash32(JSON.stringify({ pkg: npm, options }))
+      if (!options["baseURL"]) options["baseURL"] = model.api.url
+      if (!options["apiKey"]) options["apiKey"] = provider.key
+      if (model.headers)
+        options["headers"] = {
+          ...options["headers"],
+          ...model.headers,
+        }
+
+      const key = Bun.hash.xxHash32(JSON.stringify({ npm: model.api.npm, options }))
       const existing = s.sdk.get(key)
       if (existing) return existing
 
@@ -599,12 +737,13 @@ export namespace Provider {
       }
 
       // Special case: google-vertex-anthropic uses a subpath import
-      const bundledKey = providerID === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : npm
+      const bundledKey =
+        model.providerID === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : model.api.npm
       const bundledFn = BUNDLED_PROVIDERS[bundledKey]
       if (bundledFn) {
-        log.info("using bundled provider", { providerID, pkg: bundledKey })
+        log.info("using bundled provider", { providerID: model.providerID, pkg: bundledKey })
         const loaded = bundledFn({
-          name: providerID,
+          name: model.providerID,
           ...options,
         })
         s.sdk.set(key, loaded)
@@ -612,24 +751,24 @@ export namespace Provider {
       }
 
       let installedPath: string
-      if (!npm.startsWith("file://")) {
-        installedPath = await BunProc.install(npm, "latest")
+      if (!model.api.npm.startsWith("file://")) {
+        installedPath = await BunProc.install(model.api.npm, "latest")
       } else {
-        log.info("loading local provider", { pkg: npm })
-        installedPath = npm
+        log.info("loading local provider", { pkg: model.api.npm })
+        installedPath = model.api.npm
       }
 
       const mod = await import(installedPath)
 
       const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!]
       const loaded = fn({
-        name: providerID,
+        name: model.providerID,
         ...options,
       })
       s.sdk.set(key, loaded)
       return loaded as SDK
     } catch (e) {
-      throw new InitError({ providerID }, { cause: e })
+      throw new InitError({ providerID: model.providerID }, { cause: e })
     }
   }
 
@@ -655,22 +794,25 @@ export namespace Provider {
       throw new ModelNotFoundError({ providerID, modelID, suggestions })
     }
 
-    const info = provider.info.models[modelID]
+    const info = provider.models[modelID]
     if (!info) {
-      const availableModels = Object.keys(provider.info.models)
+      const availableModels = Object.keys(provider.models)
       const matches = fuzzysort.go(modelID, availableModels, { limit: 3, threshold: -10000 })
       const suggestions = matches.map((m) => m.target)
       throw new ModelNotFoundError({ providerID, modelID, suggestions })
     }
 
-    const npm = info.provider?.npm ?? provider.info.npm ?? info.id
-    const sdk = await getSDK(npm, providerID)
+    const sdk = await getSDK(info)
 
     try {
       const language = provider.getModel
-        ? await provider.getModel(sdk, info.target, provider.options)
-        : sdk.languageModel(info.target)
+        ? await provider.getModel(sdk, info.api.id, provider.options)
+        : sdk.languageModel(info.api.id)
       log.info("found", { providerID, modelID })
+      const cached: ModelWithStuff = {
+        info,
+        language,
+      }
       s.models.set(key, {
         providerID,
         modelID,
@@ -755,7 +897,7 @@ export namespace Provider {
   }
 
   const priority = ["gpt-5", "claude-sonnet-4", "big-pickle", "gemini-3-pro"]
-  export function sort(models: ModelsDev.Model[]) {
+  export function sort(models: Model[]) {
     return sortBy(
       models,
       [(model) => priority.findIndex((filter) => model.id.includes(filter)), "desc"],