Browse Source

add provider whitelist

Aiden Cline 3 months ago
parent
commit
fffe20cbe5
2 changed files with 18 additions and 1 deletions
  1. 4 0
      packages/opencode/src/config/config.ts
  2. 14 1
      packages/opencode/src/provider/provider.ts

+ 4 - 0
packages/opencode/src/config/config.ts

@@ -480,6 +480,10 @@ export namespace Config {
         .describe("@deprecated Use 'share' field instead. Share newly created sessions automatically"),
       autoupdate: z.boolean().optional().describe("Automatically update to the latest version"),
       disabled_providers: z.array(z.string()).optional().describe("Disable providers that are loaded automatically"),
+      enabled_providers: z
+        .array(z.string())
+        .optional()
+        .describe("When set, ONLY these providers will be enabled. All other providers will be ignored"),
       model: z.string().describe("Model to use in the format of provider/model, eg anthropic/claude-2").optional(),
       small_model: z
         .string()

+ 14 - 1
packages/opencode/src/provider/provider.ts

@@ -241,6 +241,15 @@ export namespace Provider {
     const config = await Config.get()
     const database = await ModelsDev.get()
 
+    const disabled = new Set(config.disabled_providers ?? [])
+    const enabled = config.enabled_providers ? new Set(config.enabled_providers) : null
+
+    function isProviderAllowed(providerID: string): boolean {
+      if (enabled && !enabled.has(providerID)) return false
+      if (disabled.has(providerID)) return false
+      return true
+    }
+
     const providers: {
       [providerID: string]: {
         source: Source
@@ -369,7 +378,6 @@ export namespace Provider {
       database[providerID] = parsed
     }
 
-    const disabled = await Config.get().then((cfg) => new Set(cfg.disabled_providers ?? []))
     // load env
     for (const [providerID, provider] of Object.entries(database)) {
       if (disabled.has(providerID)) continue
@@ -447,6 +455,11 @@ export namespace Provider {
     }
 
     for (const [providerID, provider] of Object.entries(providers)) {
+      if (!isProviderAllowed(providerID)) {
+        delete providers[providerID]
+        continue
+      }
+
       const configProvider = config.provider?.[providerID]
       const filteredModels = Object.fromEntries(
         Object.entries(provider.info.models)