Răsfoiți Sursa

feat(provider): add GitHub Enterprise support for Copilot (#2522)

Co-authored-by: Jon-Mikkel Korsvik <[email protected]>
Co-authored-by: GitHub Action <[email protected]>
Co-authored-by: Aiden Cline <[email protected]>
Ola 3 luni în urmă
părinte
comite
fe94bb8e50

+ 1 - 0
packages/opencode/src/auth/index.ts

@@ -10,6 +10,7 @@ export namespace Auth {
       refresh: z.string(),
       access: z.string(),
       expires: z.number(),
+      enterpriseUrl: z.string().optional(),
     })
     .meta({ ref: "OAuth" })
 

+ 193 - 148
packages/opencode/src/cli/cmd/auth.ts

@@ -102,178 +102,223 @@ export const AuthLoginCommand = cmd({
           prompts.outro("Done")
           return
         }
-        await ModelsDev.refresh().catch(() => {})
-        const providers = await ModelsDev.get()
-        const priority: Record<string, number> = {
-          opencode: 0,
-          anthropic: 1,
-          "github-copilot": 2,
-          openai: 3,
-          google: 4,
-          openrouter: 5,
-          vercel: 6,
-        }
-        let provider = await prompts.autocomplete({
-          message: "Select provider",
-          maxItems: 8,
-          options: [
-            ...pipe(
-              providers,
-              values(),
-              sortBy(
-                (x) => priority[x.id] ?? 99,
-                (x) => x.name ?? x.id,
-              ),
-              map((x) => ({
-                label: x.name,
-                value: x.id,
-                hint: priority[x.id] <= 1 ? "recommended" : undefined,
-              })),
+      await ModelsDev.refresh().catch(() => {})
+      const providers = await ModelsDev.get()
+      const priority: Record<string, number> = {
+        opencode: 0,
+        anthropic: 1,
+        "github-copilot": 2,
+        openai: 3,
+        google: 4,
+        openrouter: 5,
+        vercel: 6,
+      }
+      let provider = await prompts.autocomplete({
+        message: "Select provider",
+        maxItems: 8,
+        options: [
+          ...pipe(
+            providers,
+            values(),
+            sortBy(
+              (x) => priority[x.id] ?? 99,
+              (x) => x.name ?? x.id,
             ),
-            {
-              value: "other",
-              label: "Other",
-            },
-          ],
-        })
+            map((x) => ({
+              label: x.name,
+              value: x.id,
+              hint: priority[x.id] <= 1 ? "recommended" : undefined,
+            })),
+          ),
+          {
+            value: "other",
+            label: "Other",
+          },
+        ],
+      })
 
-        if (prompts.isCancel(provider)) throw new UI.CancelledError()
+      if (prompts.isCancel(provider)) throw new UI.CancelledError()
 
-        const plugin = await Plugin.list().then((x) => x.find((x) => x.auth?.provider === provider))
-        if (plugin && plugin.auth) {
-          let index = 0
-          if (plugin.auth.methods.length > 1) {
-            const method = await prompts.select({
-              message: "Login method",
-              options: [
-                ...plugin.auth.methods.map((x, index) => ({
-                  label: x.label,
-                  value: index.toString(),
-                })),
-              ],
-            })
-            if (prompts.isCancel(method)) throw new UI.CancelledError()
-            index = parseInt(method)
-          }
-          const method = plugin.auth.methods[index]
-          if (method.type === "oauth") {
-            await new Promise((resolve) => setTimeout(resolve, 10))
-            const authorize = await method.authorize()
+      const plugin = await Plugin.list().then((x) => x.find((x) => x.auth?.provider === provider))
+      if (plugin && plugin.auth) {
+        let index = 0
+        if (plugin.auth.methods.length > 1) {
+          const method = await prompts.select({
+            message: "Login method",
+            options: [
+              ...plugin.auth.methods.map((x, index) => ({
+                label: x.label,
+                value: index.toString(),
+              })),
+            ],
+          })
+          if (prompts.isCancel(method)) throw new UI.CancelledError()
+          index = parseInt(method)
+        }
+        const method = plugin.auth.methods[index]
 
-            if (authorize.url) {
-              prompts.log.info("Go to: " + authorize.url)
+        // Handle prompts for all auth types
+        await new Promise((resolve) => setTimeout(resolve, 10))
+        const inputs: Record<string, string> = {}
+        if (method.prompts) {
+          for (const prompt of method.prompts) {
+            if (prompt.condition && !prompt.condition(inputs)) {
+              continue
             }
+            if (prompt.type === "select") {
+              const value = await prompts.select({
+                message: prompt.message,
+                options: prompt.options,
+              })
+              if (prompts.isCancel(value)) throw new UI.CancelledError()
+              inputs[prompt.key] = value
+            } else {
+              const value = await prompts.text({
+                message: prompt.message,
+                placeholder: prompt.placeholder,
+                validate: prompt.validate ? (v) => prompt.validate!(v ?? "") : undefined,
+              })
+              if (prompts.isCancel(value)) throw new UI.CancelledError()
+              inputs[prompt.key] = value
+            }
+          }
+        }
 
-            if (authorize.method === "auto") {
-              if (authorize.instructions) {
-                prompts.log.info(authorize.instructions)
-              }
-              const spinner = prompts.spinner()
-              spinner.start("Waiting for authorization...")
-              const result = await authorize.callback()
-              if (result.type === "failed") {
-                spinner.stop("Failed to authorize", 1)
+        if (method.type === "oauth") {
+          const authorize = await method.authorize(inputs)
+
+          if (authorize.url) {
+            prompts.log.info("Go to: " + authorize.url)
+          }
+
+          if (authorize.method === "auto") {
+            if (authorize.instructions) {
+              prompts.log.info(authorize.instructions)
+            }
+            const spinner = prompts.spinner()
+            spinner.start("Waiting for authorization...")
+            const result = await authorize.callback()
+            if (result.type === "failed") {
+              spinner.stop("Failed to authorize", 1)
+            }
+            if (result.type === "success") {
+              const saveProvider = result.provider ?? provider
+              if ("refresh" in result) {
+                const { type: _, provider: __, refresh, access, expires, ...extraFields } = result
+                await Auth.set(saveProvider, {
+                  type: "oauth",
+                  refresh,
+                  access,
+                  expires,
+                  ...extraFields,
+                })
               }
-              if (result.type === "success") {
-                if ("refresh" in result) {
-                  await Auth.set(provider, {
-                    type: "oauth",
-                    refresh: result.refresh,
-                    access: result.access,
-                    expires: result.expires,
-                  })
-                }
-                if ("key" in result) {
-                  await Auth.set(provider, {
-                    type: "api",
-                    key: result.key,
-                  })
-                }
-                spinner.stop("Login successful")
+              if ("key" in result) {
+                await Auth.set(saveProvider, {
+                  type: "api",
+                  key: result.key,
+                })
               }
+              spinner.stop("Login successful")
             }
+          }
 
-            if (authorize.method === "code") {
-              const code = await prompts.text({
-                message: "Paste the authorization code here: ",
-                validate: (x) => (x && x.length > 0 ? undefined : "Required"),
-              })
-              if (prompts.isCancel(code)) throw new UI.CancelledError()
-              const result = await authorize.callback(code)
-              if (result.type === "failed") {
-                prompts.log.error("Failed to authorize")
+          if (authorize.method === "code") {
+            const code = await prompts.text({
+              message: "Paste the authorization code here: ",
+              validate: (x) => (x && x.length > 0 ? undefined : "Required"),
+            })
+            if (prompts.isCancel(code)) throw new UI.CancelledError()
+            const result = await authorize.callback(code)
+            if (result.type === "failed") {
+              prompts.log.error("Failed to authorize")
+            }
+            if (result.type === "success") {
+              const saveProvider = result.provider ?? provider
+              if ("refresh" in result) {
+                const { type: _, provider: __, refresh, access, expires, ...extraFields } = result
+                await Auth.set(saveProvider, {
+                  type: "oauth",
+                  refresh,
+                  access,
+                  expires,
+                  ...extraFields,
+                })
               }
-              if (result.type === "success") {
-                if ("refresh" in result) {
-                  await Auth.set(provider, {
-                    type: "oauth",
-                    refresh: result.refresh,
-                    access: result.access,
-                    expires: result.expires,
-                  })
-                }
-                if ("key" in result) {
-                  await Auth.set(provider, {
-                    type: "api",
-                    key: result.key,
-                  })
-                }
-                prompts.log.success("Login successful")
+              if ("key" in result) {
+                await Auth.set(saveProvider, {
+                  type: "api",
+                  key: result.key,
+                })
               }
+              prompts.log.success("Login successful")
             }
-            prompts.outro("Done")
-            return
           }
-        }
 
-        if (provider === "other") {
-          provider = await prompts.text({
-            message: "Enter provider id",
-            validate: (x) => (x && x.match(/^[0-9a-z-]+$/) ? undefined : "a-z, 0-9 and hyphens only"),
-          })
-          if (prompts.isCancel(provider)) throw new UI.CancelledError()
-          provider = provider.replace(/^@ai-sdk\//, "")
-          if (prompts.isCancel(provider)) throw new UI.CancelledError()
-          prompts.log.warn(
-            `This only stores a credential for ${provider} - you will need configure it in opencode.json, check the docs for examples.`,
-          )
-        }
-
-        if (provider === "amazon-bedrock") {
-          prompts.log.info(
-            "Amazon bedrock can be configured with standard AWS environment variables like AWS_BEARER_TOKEN_BEDROCK, AWS_PROFILE or AWS_ACCESS_KEY_ID",
-          )
-          prompts.outro("Done")
-          return
-        }
-
-        if (provider === "google-vertex") {
-          prompts.log.info(
-            "Google Cloud Vertex AI uses Application Default Credentials. Set GOOGLE_APPLICATION_CREDENTIALS or run 'gcloud auth application-default login'. Optionally set GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION (or VERTEX_LOCATION)",
-          )
           prompts.outro("Done")
           return
         }
 
-        if (provider === "opencode") {
-          prompts.log.info("Create an api key at https://opencode.ai/auth")
-        }
-
-        if (provider === "vercel") {
-          prompts.log.info("You can create an api key at https://vercel.link/ai-gateway-token")
+        if (method.type === "api") {
+          if (method.authorize) {
+            const result = await method.authorize(inputs)
+            if (result.type === "failed") {
+              prompts.log.error("Failed to authorize")
+            }
+            if (result.type === "success") {
+              const saveProvider = result.provider ?? provider
+              await Auth.set(saveProvider, {
+                type: "api",
+                key: result.key,
+              })
+              prompts.log.success("Login successful")
+            }
+            prompts.outro("Done")
+            return
+          }
         }
+      }
 
-        const key = await prompts.password({
-          message: "Enter your API key",
-          validate: (x) => (x && x.length > 0 ? undefined : "Required"),
-        })
-        if (prompts.isCancel(key)) throw new UI.CancelledError()
-        await Auth.set(provider, {
-          type: "api",
-          key,
+      if (provider === "other") {
+        provider = await prompts.text({
+          message: "Enter provider id",
+          validate: (x) => (x && x.match(/^[0-9a-z-]+$/) ? undefined : "a-z, 0-9 and hyphens only"),
         })
+        if (prompts.isCancel(provider)) throw new UI.CancelledError()
+        provider = provider.replace(/^@ai-sdk\//, "")
+        if (prompts.isCancel(provider)) throw new UI.CancelledError()
+        prompts.log.warn(
+          `This only stores a credential for ${provider} - you will need configure it in opencode.json, check the docs for examples.`,
+        )
+      }
 
+      if (provider === "amazon-bedrock") {
+        prompts.log.info(
+          "Amazon bedrock can be configured with standard AWS environment variables like AWS_BEARER_TOKEN_BEDROCK, AWS_PROFILE or AWS_ACCESS_KEY_ID",
+        )
         prompts.outro("Done")
+        return
+      }
+
+      if (provider === "opencode") {
+        prompts.log.info("Create an api key at https://opencode.ai/auth")
+      }
+
+      if (provider === "vercel") {
+        prompts.log.info("You can create an api key at https://vercel.link/ai-gateway-token")
+      }
+
+      const key = await prompts.password({
+        message: "Enter your API key",
+        validate: (x) => (x && x.length > 0 ? undefined : "Required"),
+      })
+      if (prompts.isCancel(key)) throw new UI.CancelledError()
+      await Auth.set(provider, {
+        type: "api",
+        key,
+      })
+
+      prompts.outro("Done")
       },
     })
   },

+ 1 - 0
packages/opencode/src/config/config.ts

@@ -574,6 +574,7 @@ export namespace Config {
                 .object({
                   apiKey: z.string().optional(),
                   baseURL: z.string().optional(),
+                  enterpriseUrl: z.string().optional().describe("GitHub Enterprise URL for copilot authentication"),
                   timeout: z
                     .union([
                       z

+ 1 - 1
packages/opencode/src/plugin/index.ts

@@ -28,7 +28,7 @@ export namespace Plugin {
     }
     const plugins = [...(config.plugin ?? [])]
     if (!Flag.OPENCODE_DISABLE_DEFAULT_PLUGINS) {
-      plugins.push("[email protected].3")
+      plugins.push("[email protected].4")
       plugins.push("[email protected]")
     }
     for (let plugin of plugins) {

+ 52 - 8
packages/opencode/src/provider/provider.ts

@@ -283,6 +283,18 @@ export namespace Provider {
 
     const configProviders = Object.entries(config.provider ?? {})
 
+    // Add GitHub Copilot Enterprise provider that inherits from GitHub Copilot
+    if (database["github-copilot"]) {
+      const githubCopilot = database["github-copilot"]
+      database["github-copilot-enterprise"] = {
+        ...githubCopilot,
+        id: "github-copilot-enterprise",
+        name: "GitHub Copilot Enterprise",
+        // Enterprise uses a different API endpoint - will be set dynamically based on auth
+        api: undefined,
+      }
+    }
+
     for (const [providerID, provider] of configProviders) {
       const existing = database[providerID]
       const parsed: ModelsDev.Provider = {
@@ -378,14 +390,44 @@ export namespace Provider {
       if (!plugin.auth) continue
       const providerID = plugin.auth.provider
       if (disabled.has(providerID)) continue
+
+      // For github-copilot plugin, check if auth exists for either github-copilot or github-copilot-enterprise
+      let hasAuth = false
       const auth = await Auth.get(providerID)
-      if (!auth) continue
+      if (auth) hasAuth = true
+
+      // Special handling for github-copilot: also check for enterprise auth
+      if (providerID === "github-copilot" && !hasAuth) {
+        const enterpriseAuth = await Auth.get("github-copilot-enterprise")
+        if (enterpriseAuth) hasAuth = true
+      }
+
+      if (!hasAuth) continue
       if (!plugin.auth.loader) continue
-      const options = await plugin.auth.loader(
-        () => Auth.get(providerID) as any,
-        database[plugin.auth.provider],
-      )
-      mergeProvider(plugin.auth.provider, options ?? {}, "custom")
+
+      // Load for the main provider if auth exists
+      if (auth) {
+        const options = await plugin.auth.loader(
+          () => Auth.get(providerID) as any,
+          database[plugin.auth.provider],
+        )
+        mergeProvider(plugin.auth.provider, options ?? {}, "custom")
+      }
+
+      // If this is github-copilot plugin, also register for github-copilot-enterprise if auth exists
+      if (providerID === "github-copilot") {
+        const enterpriseProviderID = "github-copilot-enterprise"
+        if (!disabled.has(enterpriseProviderID)) {
+          const enterpriseAuth = await Auth.get(enterpriseProviderID)
+          if (enterpriseAuth) {
+            const enterpriseOptions = await plugin.auth.loader(
+              () => Auth.get(enterpriseProviderID) as any,
+              database[enterpriseProviderID],
+            )
+            mergeProvider(enterpriseProviderID, enterpriseOptions ?? {}, "custom")
+          }
+        }
+      }
     }
 
     // load config
@@ -458,7 +500,8 @@ export namespace Provider {
           : installedPath
       const mod = await import(modPath)
       if (options["timeout"] !== undefined && options["timeout"] !== null) {
-        // Only override fetch if user explicitly sets timeout
+        // Preserve custom fetch if it exists, wrap it with timeout logic
+        const customFetch = options["fetch"]
         options["fetch"] = async (input: any, init?: BunFetchRequestInit) => {
           const { signal, ...rest } = init ?? {}
 
@@ -468,7 +511,8 @@ export namespace Provider {
 
           const combined = signals.length > 1 ? AbortSignal.any(signals) : signals[0]
 
-          return fetch(input, {
+          const fetchFn = customFetch ?? fetch
+          return fetchFn(input, {
             ...rest,
             signal: combined,
             // @ts-ignore see here: https://github.com/oven-sh/bun/issues/16682

+ 59 - 2
packages/plugin/src/index.ts

@@ -39,13 +39,35 @@ export interface Hooks {
       | {
           type: "oauth"
           label: string
-          authorize(): Promise<
+          prompts?: Array<
+            | {
+                type: "text"
+                key: string
+                message: string
+                placeholder?: string
+                validate?: (value: string) => string | undefined
+                condition?: (inputs: Record<string, string>) => boolean
+              }
+            | {
+                type: "select"
+                key: string
+                message: string
+                options: Array<{
+                  label: string
+                  value: string
+                  hint?: string
+                }>
+                condition?: (inputs: Record<string, string>) => boolean
+              }
+          >
+          authorize(inputs?: Record<string, string>): Promise<
             { url: string; instructions: string } & (
               | {
                   method: "auto"
                   callback(): Promise<
                     | ({
                         type: "success"
+                        provider?: string
                       } & (
                         | {
                             refresh: string
@@ -64,6 +86,7 @@ export interface Hooks {
                   callback(code: string): Promise<
                     | ({
                         type: "success"
+                        provider?: string
                       } & (
                         | {
                             refresh: string
@@ -80,7 +103,41 @@ export interface Hooks {
             )
           >
         }
-      | { type: "api"; label: string }
+      | {
+          type: "api"
+          label: string
+          prompts?: Array<
+            | {
+                type: "text"
+                key: string
+                message: string
+                placeholder?: string
+                validate?: (value: string) => string | undefined
+                condition?: (inputs: Record<string, string>) => boolean
+              }
+            | {
+                type: "select"
+                key: string
+                message: string
+                options: Array<{
+                  label: string
+                  value: string
+                  hint?: string
+                }>
+                condition?: (inputs: Record<string, string>) => boolean
+              }
+          >
+          authorize?(inputs?: Record<string, string>): Promise<
+            | {
+                type: "success"
+                key: string
+                provider?: string
+              }
+            | {
+                type: "failed"
+              }
+          >
+        }
     )[]
   }
   /**

+ 5 - 0
packages/sdk/js/src/gen/types.gen.ts

@@ -405,6 +405,10 @@ export type Config = {
       options?: {
         apiKey?: string
         baseURL?: string
+        /**
+         * GitHub Enterprise URL for copilot authentication
+         */
+        enterpriseUrl?: string
         /**
          * Timeout in milliseconds for requests to this provider. Default is 300000 (5 minutes). Set to false to disable timeout.
          */
@@ -1135,6 +1139,7 @@ export type OAuth = {
   refresh: string
   access: string
   expires: number
+  enterpriseUrl?: string
 }
 
 export type ApiAuth = {