Forráskód Böngészése

fix: non openai azure models that use completions endpoints (#17128)

Aiden Cline 1 hónapja
szülő
commit
a776a3ee12
1 módosított fájl, 73 hozzáadás és 47 törlés
  1. 73 47
      packages/opencode/src/provider/provider.ts

+ 73 - 47
packages/opencode/src/provider/provider.ts

@@ -52,44 +52,10 @@ const DEFAULT_CHUNK_TIMEOUT = 120_000
 export namespace Provider {
   const log = Log.create({ service: "provider" })
 
-  function isGpt5OrLater(modelID: string): boolean {
-    const match = /^gpt-(\d+)/.exec(modelID)
-    if (!match) {
-      return false
-    }
-    return Number(match[1]) >= 5
-  }
-
   function shouldUseCopilotResponsesApi(modelID: string): boolean {
-    return isGpt5OrLater(modelID) && !modelID.startsWith("gpt-5-mini")
-  }
-
-  function googleVertexVars(options: Record<string, any>) {
-    const project =
-      options["project"] ?? Env.get("GOOGLE_CLOUD_PROJECT") ?? Env.get("GCP_PROJECT") ?? Env.get("GCLOUD_PROJECT")
-    const location =
-      options["location"] ??
-      Env.get("GOOGLE_VERTEX_LOCATION") ??
-      Env.get("GOOGLE_CLOUD_LOCATION") ??
-      Env.get("VERTEX_LOCATION") ??
-      "us-central1"
-    const endpoint = location === "global" ? "aiplatform.googleapis.com" : `${location}-aiplatform.googleapis.com`
-
-    return {
-      GOOGLE_VERTEX_PROJECT: project,
-      GOOGLE_VERTEX_LOCATION: location,
-      GOOGLE_VERTEX_ENDPOINT: endpoint,
-    }
-  }
-
-  function loadBaseURL(model: Model, options: Record<string, any>) {
-    const raw = options["baseURL"] ?? model.api.url
-    if (typeof raw !== "string") return raw
-    const vars = model.providerID === "google-vertex" ? googleVertexVars(options) : undefined
-    return raw.replace(/\$\{([^}]+)\}/g, (match, key) => {
-      const val = Env.get(String(key)) ?? vars?.[String(key) as keyof typeof vars]
-      return val ?? match
-    })
+    const match = /^gpt-(\d+)/.exec(modelID)
+    if (!match) return false
+    return Number(match[1]) >= 5 && !modelID.startsWith("gpt-5-mini")
   }
 
   function wrapSSE(res: Response, ms: number, ctl: AbortController) {
@@ -166,12 +132,18 @@ export namespace Provider {
   }
 
   type CustomModelLoader = (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>
+  type CustomVarsLoader = (options: Record<string, any>) => Record<string, string>
   type CustomLoader = (provider: Info) => Promise<{
     autoload: boolean
     getModel?: CustomModelLoader
+    vars?: CustomVarsLoader
     options?: Record<string, any>
   }>
 
+  function useLanguageModel(sdk: any) {
+    return sdk.responses === undefined && sdk.chat === undefined
+  }
+
   const CUSTOM_LOADERS: Record<string, CustomLoader> = {
     async anthropic() {
       return {
@@ -219,7 +191,7 @@ export namespace Provider {
       return {
         autoload: false,
         async getModel(sdk: any, modelID: string, _options?: Record<string, any>) {
-          if (sdk.responses === undefined && sdk.chat === undefined) return sdk.languageModel(modelID)
+          if (useLanguageModel(sdk)) return sdk.languageModel(modelID)
           return shouldUseCopilotResponsesApi(modelID) ? sdk.responses(modelID) : sdk.chat(modelID)
         },
         options: {},
@@ -229,16 +201,23 @@ export namespace Provider {
       return {
         autoload: false,
         async getModel(sdk: any, modelID: string, _options?: Record<string, any>) {
-          if (sdk.responses === undefined && sdk.chat === undefined) return sdk.languageModel(modelID)
+          if (useLanguageModel(sdk)) return sdk.languageModel(modelID)
           return shouldUseCopilotResponsesApi(modelID) ? sdk.responses(modelID) : sdk.chat(modelID)
         },
         options: {},
       }
     },
-    azure: async () => {
+    azure: async (provider) => {
+      const resource = iife(() => {
+        const name = provider.options?.resourceName
+        if (typeof name === "string" && name.trim() !== "") return name
+        return Env.get("AZURE_RESOURCE_NAME")
+      })
+
       return {
         autoload: false,
         async getModel(sdk: any, modelID: string, options?: Record<string, any>) {
+          if (useLanguageModel(sdk)) return sdk.languageModel(modelID)
           if (options?.["useCompletionUrls"]) {
             return sdk.chat(modelID)
           } else {
@@ -246,6 +225,11 @@ export namespace Provider {
           }
         },
         options: {},
+        vars(_options) {
+          return {
+            ...(resource && { AZURE_RESOURCE_NAME: resource }),
+          }
+        },
       }
     },
     "azure-cognitive-services": async () => {
@@ -253,6 +237,7 @@ export namespace Provider {
       return {
         autoload: false,
         async getModel(sdk: any, modelID: string, options?: Record<string, any>) {
+          if (useLanguageModel(sdk)) return sdk.languageModel(modelID)
           if (options?.["useCompletionUrls"]) {
             return sdk.chat(modelID)
           } else {
@@ -441,17 +426,26 @@ export namespace Provider {
         Env.get("GCP_PROJECT") ??
         Env.get("GCLOUD_PROJECT")
 
-      const location =
+      const location = String(
         provider.options?.location ??
-        Env.get("GOOGLE_VERTEX_LOCATION") ??
-        Env.get("GOOGLE_CLOUD_LOCATION") ??
-        Env.get("VERTEX_LOCATION") ??
-        "us-central1"
+          Env.get("GOOGLE_VERTEX_LOCATION") ??
+          Env.get("GOOGLE_CLOUD_LOCATION") ??
+          Env.get("VERTEX_LOCATION") ??
+          "us-central1",
+      )
 
       const autoload = Boolean(project)
       if (!autoload) return { autoload: false }
       return {
         autoload: true,
+        vars(_options: Record<string, any>) {
+          const endpoint = location === "global" ? "aiplatform.googleapis.com" : `${location}-aiplatform.googleapis.com`
+          return {
+            ...(project && { GOOGLE_VERTEX_PROJECT: project }),
+            GOOGLE_VERTEX_LOCATION: location,
+            GOOGLE_VERTEX_ENDPOINT: endpoint,
+          }
+        },
         options: {
           project,
           location,
@@ -583,11 +577,15 @@ export namespace Provider {
         autoload: !!apiKey,
         options: {
           apiKey,
-          baseURL: `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/v1`,
         },
         async getModel(sdk: any, modelID: string) {
           return sdk.languageModel(modelID)
         },
+        vars(_options) {
+          return {
+            CLOUDFLARE_ACCOUNT_ID: accountId,
+          }
+        },
       }
     },
     "cloudflare-ai-gateway": async (input) => {
@@ -856,6 +854,9 @@ export namespace Provider {
     const modelLoaders: {
       [providerID: string]: CustomModelLoader
     } = {}
+    const varsLoaders: {
+      [providerID: string]: CustomVarsLoader
+    } = {}
     const sdk = new Map<string, SDK>()
 
     log.info("init")
@@ -1055,6 +1056,7 @@ export namespace Provider {
       const result = await fn(data)
       if (result && (result.autoload || providers[providerID])) {
         if (result.getModel) modelLoaders[providerID] = result.getModel
+        if (result.vars) varsLoaders[providerID] = result.vars
         const opts = result.options ?? {}
         const patch: Partial<Info> = providers[providerID] ? { options: opts } : { source: "custom", options: opts }
         mergeProvider(providerID, patch)
@@ -1121,6 +1123,7 @@ export namespace Provider {
       providers,
       sdk,
       modelLoaders,
+      varsLoaders,
     }
   })
 
@@ -1145,7 +1148,30 @@ export namespace Provider {
         options["includeUsage"] = true
       }
 
-      const baseURL = loadBaseURL(model, options)
+      const baseURL = iife(() => {
+        let url =
+          typeof options["baseURL"] === "string" && options["baseURL"] !== "" ? options["baseURL"] : model.api.url
+        if (!url) return
+
+        // some models/providers have variable urls, ex: "https://${AZURE_RESOURCE_NAME}.services.ai.azure.com/anthropic/v1"
+        // We track this in models.dev, and then when we are resolving the baseURL
+        // we need to string replace that literal: "${AZURE_RESOURCE_NAME}"
+        const loader = s.varsLoaders[model.providerID]
+        if (loader) {
+          const vars = loader(options)
+          for (const [key, value] of Object.entries(vars)) {
+            const field = "${" + key + "}"
+            url = url.replaceAll(field, value)
+          }
+        }
+
+        url = url.replace(/\$\{([^}]+)\}/g, (item, key) => {
+          const val = Env.get(String(key))
+          return val ?? item
+        })
+        return url
+      })
+
       if (baseURL !== undefined) options["baseURL"] = baseURL
       if (options["apiKey"] === undefined && provider.key) options["apiKey"] = provider.key
       if (model.headers)