Aiden Cline 2 mesi fa
parent
commit
66d1418b27

+ 12 - 100
packages/opencode/src/provider/model-detection.ts

@@ -1,117 +1,31 @@
 import z from "zod"
 import { iife } from "@/util/iife"
 import { Log } from "@/util/log"
-import { Config } from "../config/config"
-import { ModelsDev } from "./models"
 import { Provider } from "./provider"
 
 export namespace ProviderModelDetection {
-  function mergeModel(
-    detectedModel: Partial<Provider.Model>,
-    providerModel: Provider.Model | undefined,
-    modelID: string,
-    providerID: string,
-    providerBaseURL: string,
-  ): Provider.Model {
-    return {
-      id: modelID,
-      providerID: detectedModel.providerID ?? providerModel?.providerID ?? providerID,
-      api: {
-        id: modelID,
-        url: detectedModel.api?.url ?? providerModel?.api?.url ?? providerBaseURL,
-        npm: detectedModel.api?.npm ?? providerModel?.api?.npm ?? "@ai-sdk/openai-compatible",
-      },
-      name: detectedModel.name ?? providerModel?.name ?? modelID,
-      family: detectedModel.family ?? providerModel?.family ?? "",
-      capabilities: {
-        temperature: detectedModel.capabilities?.temperature ?? providerModel?.capabilities?.temperature ?? false,
-        reasoning: detectedModel.capabilities?.reasoning ?? providerModel?.capabilities?.reasoning ?? false,
-        attachment: detectedModel.capabilities?.attachment ?? providerModel?.capabilities?.attachment ?? false,
-        toolcall: detectedModel.capabilities?.toolcall ?? providerModel?.capabilities?.toolcall ?? true,
-        input: {
-          text: detectedModel.capabilities?.input?.text ?? providerModel?.capabilities?.input?.text ?? true,
-          audio: detectedModel.capabilities?.input?.audio ?? providerModel?.capabilities?.input?.audio ?? false,
-          image: detectedModel.capabilities?.input?.image ?? providerModel?.capabilities?.input?.image ?? false,
-          video: detectedModel.capabilities?.input?.video ?? providerModel?.capabilities?.input?.video ?? false,
-          pdf: detectedModel.capabilities?.input?.pdf ?? providerModel?.capabilities?.input?.pdf ?? false,
-        },
-        output: {
-          text: detectedModel.capabilities?.output?.text ?? providerModel?.capabilities?.output?.text ?? true,
-          audio: detectedModel.capabilities?.output?.audio ?? providerModel?.capabilities?.output?.audio ?? false,
-          image: detectedModel.capabilities?.output?.image ?? providerModel?.capabilities?.output?.image ?? false,
-          video: detectedModel.capabilities?.output?.video ?? providerModel?.capabilities?.output?.video ?? false,
-          pdf: detectedModel.capabilities?.output?.pdf ?? providerModel?.capabilities?.output?.pdf ?? false,
-        },
-        interleaved: detectedModel.capabilities?.interleaved ?? providerModel?.capabilities?.interleaved ?? false,
-      },
-      cost: {
-        input: detectedModel.cost?.input ?? providerModel?.cost?.input ?? 0,
-        output: detectedModel.cost?.output ?? providerModel?.cost?.output ?? 0,
-        cache: {
-          read: detectedModel.cost?.cache?.read ?? providerModel?.cost?.cache?.read ?? 0,
-          write: detectedModel.cost?.cache?.write ?? providerModel?.cost?.cache?.write ?? 0,
-        },
-        experimentalOver200K: detectedModel.cost?.experimentalOver200K ?? providerModel?.cost?.experimentalOver200K,
-      },
-      limit: {
-        context: detectedModel.limit?.context ?? providerModel?.limit?.context ?? 0,
-        input: detectedModel.limit?.input ?? providerModel?.limit?.input ?? 0,
-        output: detectedModel.limit?.output ?? providerModel?.limit?.output ?? 0,
-      },
-      status: detectedModel.status ?? providerModel?.status ?? "active",
-      options: detectedModel.options ?? providerModel?.options ?? {},
-      headers: detectedModel.headers ?? providerModel?.headers ?? {},
-      release_date: detectedModel.release_date ?? providerModel?.release_date ?? "",
-      variants: detectedModel.variants ?? providerModel?.variants ?? {},
-    }
-  }
-
-  export async function populateModels(
-    provider: Provider.Info,
-    configProvider?: Config.Provider,
-    modelsDevProvider?: ModelsDev.Provider,
-  ): Promise<void> {
+  export async function detect(provider: Provider.Info): Promise<string[] | undefined> {
     const log = Log.create({ service: "provider.model-detection" })
 
-    const providerNPM = configProvider?.npm ?? modelsDevProvider?.npm ?? "@ai-sdk/openai-compatible"
-    const providerBaseURL = configProvider?.options?.baseURL ?? configProvider?.api ?? modelsDevProvider?.api ?? ""
+    const model = Object.values(provider.models)[0]
+    const providerNPM = model?.api?.npm ?? "@ai-sdk/openai-compatible"
+    const providerBaseURL = provider.options["baseURL"] ?? model?.api?.url ?? ""
 
     const detectedModels = await iife(async () => {
-      if (provider.id === "opencode") return
-
       try {
         if (providerNPM === "@ai-sdk/openai-compatible" && providerBaseURL) {
           log.info("using OpenAI-compatible method", { providerID: provider.id })
           return await ProviderModelDetection.OpenAICompatible.listModels(providerBaseURL, provider)
         }
       } catch (error) {
-        log.warn(`failed to populate models\n${error}`, { providerID: provider.id })
+        log.warn(`failed to detect models\n${error}`, { providerID: provider.id })
       }
     })
-    if (!detectedModels || Object.entries(detectedModels).length === 0) return
 
-    // Only keep models detected and models specified in config
-    const modelIDs = Array.from(new Set([
-      ...Object.keys(detectedModels),
-      ...Object.keys(configProvider?.models ?? {}),
-    ]))
-    // Provider models are merged from config and Models.dev, delete models only from Models.dev
-    for (const [modelID] of Object.entries(provider.models)) {
-      if (!modelIDs.includes(modelID)) delete provider.models[modelID]
-    }
-    // Add detected models, and take precedence over provider models (which are from config and Models.dev)
-    for (const modelID of modelIDs) {
-      if (!(modelID in detectedModels)) continue
-      provider.models[modelID] = mergeModel(
-        detectedModels[modelID],
-        provider.models[modelID],
-        modelID,
-        provider.id,
-        providerBaseURL,
-      )
-    }
+    if (!detectedModels || detectedModels.length === 0) return
 
-    log.info("populated models", { providerID: provider.id })
+    log.info("detected models", { providerID: provider.id, count: detectedModels.length })
+    return detectedModels
   }
 }
 
@@ -129,7 +43,7 @@ export namespace ProviderModelDetection.OpenAICompatible {
   })
   type OpenAICompatibleResponse = z.infer<typeof OpenAICompatibleResponse>
 
-  export async function listModels(baseURL: string, provider: Provider.Info): Promise<Record<string, Partial<Provider.Model>>> {
+  export async function listModels(baseURL: string, provider: Provider.Info): Promise<string[]> {
     const fetchFn = provider.options["fetch"] ?? fetch
     const apiKey = provider.options["apiKey"] ?? provider.key ?? ""
     const headers = new Headers()
@@ -142,10 +56,8 @@ export namespace ProviderModelDetection.OpenAICompatible {
     if (!res.ok) throw new Error(`bad http status ${res.status}`)
     const parsed = OpenAICompatibleResponse.parse(await res.json())
 
-    return Object.fromEntries(
-      parsed.data
-        .filter((model) => model.id && !model.id.includes("embedding") && !model.id.includes("embed"))
-        .map((model) => [model.id, {}])
-    )
+    return parsed.data
+      .filter((model) => model.id && !model.id.includes("embedding") && !model.id.includes("embed"))
+      .map((model) => model.id)
   }
 }

+ 50 - 4
packages/opencode/src/provider/provider.ts

@@ -673,6 +673,45 @@ export namespace Provider {
     }
   }
 
+  const ModelsList = z.object({
+    object: z.string(),
+    data: z.array(
+      z
+        .object({
+          id: z.string(),
+          object: z.string().optional(),
+          created: z.number().optional(),
+          owned_by: z.string().optional(),
+        })
+        .catchall(z.any()),
+    ),
+  })
+  type ModelsList = z.infer<typeof ModelsList>
+
+  async function listModels(provider: Info) {
+    const baseURL = provider.options["baseURL"]
+    const fetchFn = (provider.options["fetch"] as typeof fetch) ?? fetch
+    const apiKey = provider.options["apiKey"] ?? provider.key ?? ""
+    const headers = new Headers()
+    if (apiKey) headers.append("Authorization", `Bearer ${apiKey}`)
+    const models = await fetchFn(`${baseURL}/models`, {
+      headers,
+      signal: AbortSignal.timeout(3 * 1000),
+    })
+      .then(async (resp) => {
+        if (!resp.ok) return
+        return ModelsList.parse(await resp.json())
+      })
+      .catch((err) => {
+        log.error(`Failed to fetch models from: ${baseURL}/models`, { error: err })
+      })
+    if (!models) return
+
+    return models.data
+      .filter((model) => model.id && !model.id.includes("embedding") && !model.id.includes("embed"))
+      .map((model) => model.id)
+  }
+
   const state = Instance.state(async () => {
     using _ = log.time("state")
     const config = await Config.get()
@@ -904,11 +943,18 @@ export namespace Provider {
       mergeProvider(providerID, partial)
     }
 
-    // detect and populate models
+    // detect models and prune invalid ones
     await Promise.all(
-      Object.entries(providers).map(async ([providerID, provider]) => {
-        await ProviderModelDetection.populateModels(provider, config.provider?.[providerID], modelsDev[providerID])
-      })
+      Object.values(providers).map(async (provider) => {
+        const detected = await listModels(provider)
+        if (!detected) return
+        const detectedSet = new Set(detected)
+        for (const modelID of Object.keys(provider.models)) {
+          if (!detectedSet.has(modelID)) delete provider.models[modelID]
+        }
+        // TODO: add detected models not present in config/models.dev
+        // for (const modelID of detected) {}
+      }),
     )
 
     for (const [providerID, provider] of Object.entries(providers)) {