Explorar el Código

Update provider configuration and server handling

🤖 Generated with [OpenCode](https://opencode.ai)

Co-Authored-By: OpenCode <[email protected]>
Dax Raad hace 8 meses
padre
commit
442e1b52ad

+ 9 - 1
packages/opencode/src/config/config.ts

@@ -2,6 +2,7 @@ import { Log } from "../util/log"
 import { z } from "zod"
 import { App } from "../app/app"
 import { Filesystem } from "../util/filesystem"
+import { ModelsDev } from "../provider/models"
 
 export namespace Config {
   const log = Log.create({ service: "config" })
@@ -49,7 +50,14 @@ export namespace Config {
 
   export const Info = z
     .object({
-      provider: z.record(z.string(), z.record(z.string(), z.any())).optional(),
+      provider: z
+        .record(
+          ModelsDev.Provider.partial().extend({
+            models: z.record(ModelsDev.Model.partial()),
+            options: z.record(z.any()).optional(),
+          }),
+        )
+        .optional(),
       tool: z
         .object({
           provider: z.record(z.string(), z.string().array()).optional(),

+ 29 - 1
packages/opencode/src/provider/models.ts

@@ -1,17 +1,45 @@
 import { Global } from "../global"
 import { Log } from "../util/log"
 import path from "path"
+import { z } from "zod"
 
 export namespace ModelsDev {
   const log = Log.create({ service: "models.dev" })
   const filepath = path.join(Global.Path.cache, "models.json")
 
+  export const Model = z.object({
+    name: z.string(),
+    attachment: z.boolean(),
+    reasoning: z.boolean(),
+    temperature: z.boolean(),
+    cost: z.object({
+      input: z.number(),
+      output: z.number(),
+      inputCached: z.number(),
+      outputCached: z.number(),
+    }),
+    limit: z.object({
+      context: z.number(),
+      output: z.number(),
+    }),
+    id: z.string(),
+  })
+  export type Model = z.infer<typeof Model>
+
+  export const Provider = z.object({
+    name: z.string(),
+    env: z.array(z.string()),
+    id: z.string(),
+    models: z.record(Model),
+  })
+  export type Provider = z.infer<typeof Provider>
+
   export async function get() {
     const file = Bun.file(filepath)
     const result = await file.json().catch(() => {})
     if (result) {
       refresh()
-      return result
+      return result as Record<string, Provider>
     }
     await refresh()
     return get()

+ 82 - 99
packages/opencode/src/provider/provider.ts

@@ -1,7 +1,7 @@
 import z from "zod"
 import { App } from "../app/app"
 import { Config } from "../config/config"
-import { mergeDeep, sortBy } from "remeda"
+import { mergeDeep, pipe, sortBy } from "remeda"
 import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
 import { Log } from "../util/log"
 import path from "path"
@@ -29,106 +29,52 @@ import { TaskTool } from "../tool/task"
 export namespace Provider {
   const log = Log.create({ service: "provider" })
 
-  export const Model = z
-    .object({
-      id: z.string(),
-      name: z.string().optional(),
-      attachment: z.boolean(),
-      reasoning: z.boolean().optional(),
-      cost: z.object({
-        input: z.number(),
-        inputCached: z.number(),
-        output: z.number(),
-        outputCached: z.number(),
-      }),
-      limit: z.object({
-        context: z.number(),
-        output: z.number(),
-      }),
-    })
-    .openapi({
-      ref: "Provider.Model",
-    })
-  export type Model = z.output<typeof Model>
-
-  export const Info = z
-    .object({
-      id: z.string(),
-      name: z.string(),
-      models: z.record(z.string(), Model),
-    })
-    .openapi({
-      ref: "Provider.Info",
-    })
-  export type Info = z.output<typeof Info>
-
-  type Autodetector = (provider: Info) => Promise<
-    | {
-        source: Source
-        options: Record<string, any>
-      }
-    | false
-  >
-
-  function env(...keys: string[]) {
-    const result: Autodetector = async () => {
-      for (const key of keys) {
-        if (process.env[key])
-          return {
-            source: "env",
-            options: {},
-          }
-      }
-      return false
-    }
-
-    return result
-  }
+  type CustomLoader = (
+    provider: ModelsDev.Provider,
+  ) => Promise<Record<string, any> | false>
 
-  type Source = "oauth" | "env" | "config" | "api"
+  type Source = "env" | "config" | "custom"
 
-  const AUTODETECT: Record<string, Autodetector> = {
+  const CUSTOM_LOADERS: Record<string, CustomLoader> = {
     async anthropic(provider) {
       const access = await AuthAnthropic.access()
-      if (access) {
-        // claude sub doesn't have usage cost
-        for (const model of Object.values(provider.models)) {
-          model.cost = {
-            input: 0,
-            inputCached: 0,
-            output: 0,
-            outputCached: 0,
-          }
+      if (!access) return false
+      for (const model of Object.values(provider.models)) {
+        model.cost = {
+          input: 0,
+          inputCached: 0,
+          output: 0,
+          outputCached: 0,
         }
-        return {
-          source: "oauth",
-          options: {
-            apiKey: "",
-            headers: {
-              authorization: `Bearer ${access}`,
-              "anthropic-beta": "oauth-2025-04-20",
-            },
+      }
+      return {
+        source: "oauth",
+        options: {
+          apiKey: "",
+          headers: {
+            authorization: `Bearer ${access}`,
+            "anthropic-beta": "oauth-2025-04-20",
           },
-        }
+        },
       }
-      return env("ANTHROPIC_API_KEY")(provider)
     },
-    google: env("GOOGLE_GENERATIVE_AI_API_KEY"),
-    openai: env("OPENAI_API_KEY"),
   }
 
   const state = App.state("provider", async () => {
     const config = await Config.get()
-    const database: Record<string, Provider.Info> = await ModelsDev.get()
+    const database = await ModelsDev.get()
 
     const providers: {
       [providerID: string]: {
         source: Source
-        info: Provider.Info
+        info: ModelsDev.Provider
         options: Record<string, any>
       }
     } = {}
-    const models = new Map<string, { info: Model; language: LanguageModel }>()
+    const models = new Map<
+      string,
+      { info: ModelsDev.Model; language: LanguageModel }
+    >()
     const sdk = new Map<string, SDK>()
 
     log.info("loading")
@@ -142,11 +88,7 @@ export namespace Provider {
       if (!provider) {
         providers[id] = {
           source,
-          info: database[id] ?? {
-            id,
-            name: id,
-            models: [],
-          },
+          info: database[id],
           options,
         }
         return
@@ -155,22 +97,63 @@ export namespace Provider {
       provider.source = source
     }
 
-    for (const [providerID, fn] of Object.entries(AUTODETECT)) {
-      const provider = database[providerID]
-      if (!provider) continue
-      const result = await fn(provider)
-      if (!result) continue
-      mergeProvider(providerID, result.options, result.source)
+    for (const [providerID, provider] of Object.entries(
+      config.provider ?? {},
+    )) {
+      const existing = database[providerID]
+      const parsed: ModelsDev.Provider = {
+        id: providerID,
+        name: provider.name ?? existing?.name ?? providerID,
+        env: provider.env ?? existing?.env ?? [],
+        models: existing?.models ?? {},
+      }
+
+      for (const [modelID, model] of Object.entries(provider.models ?? {})) {
+        const existing = parsed.models[modelID]
+        const parsedModel: ModelsDev.Model = {
+          id: modelID,
+          name: model.name ?? existing?.name ?? modelID,
+          attachment: model.attachment ?? existing?.attachment ?? false,
+          reasoning: model.reasoning ?? existing?.reasoning ?? false,
+          temperature: model.temperature ?? existing?.temperature ?? false,
+          cost: model.cost ??
+            existing?.cost ?? {
+              input: 0,
+              output: 0,
+              inputCached: 0,
+              outputCached: 0,
+            },
+          limit: model.limit ??
+            existing?.limit ?? {
+              context: 0,
+              output: 0,
+            },
+        }
+        parsed.models[modelID] = parsedModel
+      }
+      database[providerID] = parsed
+    }
+
+    // load env
+    for (const [providerID, provider] of Object.entries(database)) {
+      if (provider.env.some((item) => process.env[item])) {
+        mergeProvider(providerID, {}, "env")
+      }
     }
 
-    for (const [providerID, info] of Object.entries(await Auth.all())) {
-      if (info.type === "api") {
-        mergeProvider(providerID, { apiKey: info.key }, "api")
+    // load custom
+    for (const [providerID, fn] of Object.entries(CUSTOM_LOADERS)) {
+      const result = await fn(database[providerID])
+      if (result) {
+        mergeProvider(providerID, result, "custom")
       }
     }
 
-    for (const [providerID, options] of Object.entries(config.provider ?? {})) {
-      mergeProvider(providerID, options, "config")
+    // load config
+    for (const [providerID, provider] of Object.entries(
+      config.provider ?? {},
+    )) {
+      mergeProvider(providerID, provider.options ?? {}, "config")
     }
 
     for (const providerID of Object.keys(providers)) {
@@ -261,7 +244,7 @@ export namespace Provider {
   }
 
   const priority = ["gemini-2.5-pro-preview", "codex-mini", "claude-sonnet-4"]
-  export function sort(models: Model[]) {
+  export function sort(models: ModelsDev.Model[]) {
     return sortBy(
       models,
       [

+ 2 - 1
packages/opencode/src/server/server.ts

@@ -13,6 +13,7 @@ import { Global } from "../global"
 import { mapValues } from "remeda"
 import { NamedError } from "../util/error"
 import { Fzf } from "../external/fzf"
+import { ModelsDev } from "../provider/models"
 
 const ERRORS = {
   400: {
@@ -406,7 +407,7 @@ export namespace Server {
                 "application/json": {
                   schema: resolver(
                     z.object({
-                      providers: Provider.Info.array(),
+                      providers: ModelsDev.Provider.array(),
                       default: z.record(z.string(), z.string()),
                     }),
                   ),