Browse Source

fix: azure issue where azure sdk was being used instead of anthropic one for anthropic models

Aiden Cline 3 tháng trước cách đây
mục cha
commit
84c4fe971a
1 tập tin đã thay đổi với 37 bổ sung33 xóa
  1. 37 33
      packages/opencode/src/provider/provider.ts

+ 37 - 33
packages/opencode/src/provider/provider.ts

@@ -66,7 +66,7 @@ export namespace Provider {
     "@ai-sdk/github-copilot": createGitHubCopilotOpenAICompatible,
   }
 
-  type CustomModelLoader = (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>
+  type CustomModelLoader = (sdk: any, model: Model, options?: Record<string, any>) => Promise<any>
   type CustomLoader = (provider: Info) => Promise<{
     autoload: boolean
     getModel?: CustomModelLoader
@@ -110,8 +110,8 @@ export namespace Provider {
     openai: async () => {
       return {
         autoload: false,
-        async getModel(sdk: any, modelID: string, _options?: Record<string, any>) {
-          return sdk.responses(modelID)
+        async getModel(sdk: any, model: Model, _options?: Record<string, any>) {
+          return sdk.responses(model.api.id)
         },
         options: {},
       }
@@ -119,11 +119,11 @@ export namespace Provider {
     "github-copilot": async () => {
       return {
         autoload: false,
-        async getModel(sdk: any, modelID: string, _options?: Record<string, any>) {
-          if (modelID.includes("codex")) {
-            return sdk.responses(modelID)
+        async getModel(sdk: any, model: Model, _options?: Record<string, any>) {
+          if (model.api.id.includes("codex")) {
+            return sdk.responses(model.api.id)
           }
-          return sdk.chat(modelID)
+          return sdk.chat(model.api.id)
         },
         options: {},
       }
@@ -131,11 +131,11 @@ export namespace Provider {
     "github-copilot-enterprise": async () => {
       return {
         autoload: false,
-        async getModel(sdk: any, modelID: string, _options?: Record<string, any>) {
-          if (modelID.includes("codex")) {
-            return sdk.responses(modelID)
+        async getModel(sdk: any, model: Model, _options?: Record<string, any>) {
+          if (model.api.id.includes("codex")) {
+            return sdk.responses(model.api.id)
           }
-          return sdk.chat(modelID)
+          return sdk.chat(model.api.id)
         },
         options: {},
       }
@@ -143,12 +143,14 @@ export namespace Provider {
     azure: async () => {
       return {
         autoload: false,
-        async getModel(sdk: any, modelID: string, options?: Record<string, any>) {
+        async getModel(sdk: any, model: Model, options?: Record<string, any>) {
+          if (model && model.api.npm !== "@ai-sdk/azure") {
+            return sdk.languageModel(model.api.id)
+          }
           if (options?.["useCompletionUrls"]) {
-            return sdk.chat(modelID)
-          } else {
-            return sdk.responses(modelID)
+            return sdk.chat(model.api.id)
           }
+          return sdk.responses(model.api.id)
         },
         options: {},
       }
@@ -157,12 +159,14 @@ export namespace Provider {
       const resourceName = Env.get("AZURE_COGNITIVE_SERVICES_RESOURCE_NAME")
       return {
         autoload: false,
-        async getModel(sdk: any, modelID: string, options?: Record<string, any>) {
+        async getModel(sdk: any, model: Model, options?: Record<string, any>) {
+          if (model && model.api.npm !== "@ai-sdk/azure") {
+            return sdk.languageModel(model.api.id)
+          }
           if (options?.["useCompletionUrls"]) {
-            return sdk.chat(modelID)
-          } else {
-            return sdk.responses(modelID)
+            return sdk.chat(model.api.id)
           }
+          return sdk.responses(model.api.id)
         },
         options: {
           baseURL: resourceName ? `https://${resourceName}.cognitiveservices.azure.com/openai` : undefined,
@@ -225,7 +229,8 @@ export namespace Provider {
       return {
         autoload: true,
         options: providerOptions,
-        async getModel(sdk: any, modelID: string, options?: Record<string, any>) {
+        async getModel(sdk: any, model: Model, options?: Record<string, any>) {
+          let modelID = model.api.id
           // Skip region prefixing if model already has a cross-region inference profile prefix
           if (modelID.startsWith("global.") || modelID.startsWith("jp.")) {
             return sdk.languageModel(modelID)
@@ -343,8 +348,8 @@ export namespace Provider {
           project,
           location,
         },
-        async getModel(sdk: any, modelID: string) {
-          const id = String(modelID).trim()
+        async getModel(sdk: any, model: Model) {
+          const id = String(model.api.id).trim()
           return sdk.languageModel(id)
         },
       }
@@ -360,8 +365,8 @@ export namespace Provider {
           project,
           location,
         },
-        async getModel(sdk: any, modelID) {
-          const id = String(modelID).trim()
+        async getModel(sdk: any, model: Model) {
+          const id = String(model.api.id).trim()
           return sdk.languageModel(id)
         },
       }
@@ -383,8 +388,8 @@ export namespace Provider {
       return {
         autoload: !!envServiceKey,
         options: envServiceKey ? { deploymentId, resourceGroup } : {},
-        async getModel(sdk: any, modelID: string) {
-          return sdk(modelID)
+        async getModel(sdk: any, model: Model) {
+          return sdk(model.api.id)
         },
       }
     },
@@ -423,8 +428,8 @@ export namespace Provider {
             ...(providerConfig?.options?.featureFlags || {}),
           },
         },
-        async getModel(sdk: ReturnType<typeof createGitLab>, modelID: string) {
-          return sdk.agenticChat(modelID, {
+        async getModel(sdk: ReturnType<typeof createGitLab>, model: Model) {
+          return sdk.agenticChat(model.api.id, {
             featureFlags: {
               duo_agent_platform_agentic_chat: true,
               duo_agent_platform: true,
@@ -451,8 +456,8 @@ export namespace Provider {
 
       return {
         autoload: true,
-        async getModel(sdk: any, modelID: string, _options?: Record<string, any>) {
-          return sdk.languageModel(modelID)
+        async getModel(sdk: any, model: Model, _options?: Record<string, any>) {
+          return sdk.languageModel(model.api.id)
         },
         options: {
           baseURL: `https://gateway.ai.cloudflare.com/v1/${accountId}/${gateway}/compat`,
@@ -1093,9 +1098,8 @@ export namespace Provider {
     const sdk = await getSDK(model)
 
     try {
-      const language = s.modelLoaders[model.providerID]
-        ? await s.modelLoaders[model.providerID](sdk, model.api.id, provider.options)
-        : sdk.languageModel(model.api.id)
+      const loader = s.modelLoaders[model.providerID]
+      const language = loader ? await loader(sdk, model, provider.options) : sdk.languageModel(model.api.id)
       s.models.set(key, language)
       return language
     } catch (e) {