Explorar o código

fix: add useCompletionUrls option to fix certain azure setups (#2528)

Co-authored-by: andreas.blass <[email protected]>
Co-authored-by: GitHub Action <[email protected]>
AB hai 4 meses
pai
achega
d0519be0d0
Modificáronse 1 ficheiros con 14 adicións e 8 borrados
  1. 14 8
      packages/opencode/src/provider/provider.ts

+ 14 - 8
packages/opencode/src/provider/provider.ts

@@ -18,7 +18,7 @@ export namespace Provider {
 
 
   type CustomLoader = (provider: ModelsDev.Provider) => Promise<{
   type CustomLoader = (provider: ModelsDev.Provider) => Promise<{
     autoload: boolean
     autoload: boolean
-    getModel?: (sdk: any, modelID: string) => Promise<any>
+    getModel?: (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>
     options?: Record<string, any>
     options?: Record<string, any>
   }>
   }>
 
 
@@ -58,7 +58,7 @@ export namespace Provider {
     openai: async () => {
     openai: async () => {
       return {
       return {
         autoload: false,
         autoload: false,
-        async getModel(sdk: any, modelID: string) {
+        async getModel(sdk: any, modelID: string, options?: Record<string, any>) {
           return sdk.responses(modelID)
           return sdk.responses(modelID)
         },
         },
         options: {},
         options: {},
@@ -67,8 +67,12 @@ export namespace Provider {
     azure: async () => {
     azure: async () => {
       return {
       return {
         autoload: false,
         autoload: false,
-        async getModel(sdk: any, modelID: string) {
-          return sdk.responses(modelID)
+        async getModel(sdk: any, modelID: string, options?: Record<string, any>) {
+          if (options?.["useCompletionUrls"]) {
+            return sdk.completion(modelID)
+          } else {
+            return sdk.responses(modelID)
+          }
         },
         },
         options: {},
         options: {},
       }
       }
@@ -86,7 +90,7 @@ export namespace Provider {
           region,
           region,
           credentialProvider: fromNodeProviderChain(),
           credentialProvider: fromNodeProviderChain(),
         },
         },
-        async getModel(sdk: any, modelID: string) {
+        async getModel(sdk: any, modelID: string, options?: Record<string, any>) {
           let regionPrefix = region.split("-")[0]
           let regionPrefix = region.split("-")[0]
 
 
           switch (regionPrefix) {
           switch (regionPrefix) {
@@ -197,7 +201,7 @@ export namespace Provider {
       [providerID: string]: {
       [providerID: string]: {
         source: Source
         source: Source
         info: ModelsDev.Provider
         info: ModelsDev.Provider
-        getModel?: (sdk: any, modelID: string) => Promise<any>
+        getModel?: (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>
         options: Record<string, any>
         options: Record<string, any>
       }
       }
     } = {}
     } = {}
@@ -213,7 +217,7 @@ export namespace Provider {
       id: string,
       id: string,
       options: Record<string, any>,
       options: Record<string, any>,
       source: Source,
       source: Source,
-      getModel?: (sdk: any, modelID: string) => Promise<any>,
+      getModel?: (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>,
     ) {
     ) {
       const provider = providers[id]
       const provider = providers[id]
       if (!provider) {
       if (!provider) {
@@ -432,7 +436,9 @@ export namespace Provider {
     const sdk = await getSDK(provider.info, info)
     const sdk = await getSDK(provider.info, info)
 
 
     try {
     try {
-      const language = provider.getModel ? await provider.getModel(sdk, modelID) : sdk.languageModel(modelID)
+      const language = provider.getModel
+        ? await provider.getModel(sdk, modelID, provider.options)
+        : sdk.languageModel(modelID)
       log.info("found", { providerID, modelID })
       log.info("found", { providerID, modelID })
       s.models.set(key, {
       s.models.set(key, {
         providerID,
         providerID,