Browse Source

feat: integrate support for multi step auth flows for providers that require additional questions (#18035)

Aiden Cline 1 tháng trước cách đây
mục cha
commit
171e69c2fc

+ 6 - 2
packages/opencode/src/cli/cmd/providers.ts

@@ -46,9 +46,13 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string,
   const inputs: Record<string, string> = {}
   if (method.prompts) {
     for (const prompt of method.prompts) {
-      if (prompt.condition && !prompt.condition(inputs)) {
-        continue
+      if (prompt.when) {
+        const value = inputs[prompt.when.key]
+        if (value === undefined) continue
+        const matches = prompt.when.op === "eq" ? value === prompt.when.value : value !== prompt.when.value
+        if (!matches) continue
       }
+      if (prompt.condition && !prompt.condition(inputs)) continue
       if (prompt.type === "select") {
         const value = await prompts.select({
           message: prompt.message,

+ 71 - 1
packages/opencode/src/cli/cmd/tui/component/dialog-provider.tsx

@@ -8,7 +8,7 @@ import { DialogPrompt } from "../ui/dialog-prompt"
 import { Link } from "../ui/link"
 import { useTheme } from "../context/theme"
 import { TextAttributes } from "@opentui/core"
-import type { ProviderAuthAuthorization } from "@opencode-ai/sdk/v2"
+import type { ProviderAuthAuthorization, ProviderAuthMethod } from "@opencode-ai/sdk/v2"
 import { DialogModel } from "./dialog-model"
 import { useKeyboard } from "@opentui/solid"
 import { Clipboard } from "@tui/util/clipboard"
@@ -27,6 +27,7 @@ export function createDialogProviderOptions() {
   const sync = useSync()
   const dialog = useDialog()
   const sdk = useSDK()
+  const toast = useToast()
   const options = createMemo(() => {
     return pipe(
       sync.data.provider_next.all,
@@ -69,10 +70,29 @@ export function createDialogProviderOptions() {
           if (index == null) return
           const method = methods[index]
           if (method.type === "oauth") {
+            let inputs: Record<string, string> | undefined
+            if (method.prompts?.length) {
+              const value = await PromptsMethod({
+                dialog,
+                prompts: method.prompts,
+              })
+              if (!value) return
+              inputs = value
+            }
+
             const result = await sdk.client.provider.oauth.authorize({
               providerID: provider.id,
               method: index,
+              inputs,
             })
+            if (result.error) {
+              toast.show({
+                variant: "error",
+                message: JSON.stringify(result.error),
+              })
+              dialog.clear()
+              return
+            }
             if (result.data?.method === "code") {
               dialog.replace(() => (
                 <CodeMethod providerID={provider.id} title={method.label} index={index} authorization={result.data!} />
@@ -257,3 +277,53 @@ function ApiMethod(props: ApiMethodProps) {
     />
   )
 }
+
+interface PromptsMethodProps {
+  dialog: ReturnType<typeof useDialog>
+  prompts: NonNullable<ProviderAuthMethod["prompts"]>[number][]
+}
+async function PromptsMethod(props: PromptsMethodProps) {
+  const inputs: Record<string, string> = {}
+  for (const prompt of props.prompts) {
+    if (prompt.when) {
+      const value = inputs[prompt.when.key]
+      if (value === undefined) continue
+      const matches = prompt.when.op === "eq" ? value === prompt.when.value : value !== prompt.when.value
+      if (!matches) continue
+    }
+
+    if (prompt.type === "select") {
+      const value = await new Promise<string | null>((resolve) => {
+        props.dialog.replace(
+          () => (
+            <DialogSelect
+              title={prompt.message}
+              options={prompt.options.map((x) => ({
+                title: x.label,
+                value: x.value,
+                description: x.hint,
+              }))}
+              onSelect={(option) => resolve(option.value)}
+            />
+          ),
+          () => resolve(null),
+        )
+      })
+      if (value === null) return null
+      inputs[prompt.key] = value
+      continue
+    }
+
+    const value = await new Promise<string | null>((resolve) => {
+      props.dialog.replace(
+        () => (
+          <DialogPrompt title={prompt.message} placeholder={prompt.placeholder} onConfirm={(value) => resolve(value)} />
+        ),
+        () => resolve(null),
+      )
+    })
+    if (value === null) return null
+    inputs[prompt.key] = value
+  }
+  return inputs
+}

+ 1 - 1
packages/opencode/src/plugin/copilot.ts

@@ -168,7 +168,7 @@ export async function CopilotAuthPlugin(input: PluginInput): Promise<Hooks> {
               key: "enterpriseUrl",
               message: "Enter your GitHub Enterprise URL or domain",
               placeholder: "company.ghe.com or https://company.ghe.com",
-              condition: (inputs) => inputs.deploymentType === "enterprise",
+              when: { key: "deploymentType", op: "eq", value: "enterprise" },
               validate: (value) => {
                 if (!value) return "URL or domain is required"
                 try {

+ 90 - 3
packages/opencode/src/provider/auth-service.ts

@@ -10,6 +10,44 @@ export const Method = z
   .object({
     type: z.union([z.literal("oauth"), z.literal("api")]),
     label: z.string(),
+    prompts: z
+      .array(
+        z.union([
+          z.object({
+            type: z.literal("text"),
+            key: z.string(),
+            message: z.string(),
+            placeholder: z.string().optional(),
+            when: z
+              .object({
+                key: z.string(),
+                op: z.union([z.literal("eq"), z.literal("neq")]),
+                value: z.string(),
+              })
+              .optional(),
+          }),
+          z.object({
+            type: z.literal("select"),
+            key: z.string(),
+            message: z.string(),
+            options: z.array(
+              z.object({
+                label: z.string(),
+                value: z.string(),
+                hint: z.string().optional(),
+              }),
+            ),
+            when: z
+              .object({
+                key: z.string(),
+                op: z.union([z.literal("eq"), z.literal("neq")]),
+                value: z.string(),
+              })
+              .optional(),
+          }),
+        ]),
+      )
+      .optional(),
   })
   .meta({
     ref: "ProviderAuthMethod",
@@ -43,16 +81,29 @@ export const OauthCodeMissing = NamedError.create(
 
 export const OauthCallbackFailed = NamedError.create("ProviderAuthOauthCallbackFailed", z.object({}))
 
+export const ValidationFailed = NamedError.create(
+  "ProviderAuthValidationFailed",
+  z.object({
+    field: z.string(),
+    message: z.string(),
+  }),
+)
+
 export type ProviderAuthError =
   | Auth.AuthServiceError
   | InstanceType<typeof OauthMissing>
   | InstanceType<typeof OauthCodeMissing>
   | InstanceType<typeof OauthCallbackFailed>
+  | InstanceType<typeof ValidationFailed>
 
 export namespace ProviderAuthService {
   export interface Service {
     readonly methods: () => Effect.Effect<Record<string, Method[]>>
-    readonly authorize: (input: { providerID: ProviderID; method: number }) => Effect.Effect<Authorization | undefined>
+    readonly authorize: (input: {
+      providerID: ProviderID
+      method: number
+      inputs?: Record<string, string>
+    }) => Effect.Effect<Authorization | undefined, ProviderAuthError>
     readonly callback: (input: {
       providerID: ProviderID
       method: number
@@ -80,16 +131,52 @@ export class ProviderAuthService extends ServiceMap.Service<ProviderAuthService,
       const pending = new Map<ProviderID, AuthOuathResult>()
 
       const methods = Effect.fn("ProviderAuthService.methods")(function* () {
-        return Record.map(hooks, (item) => item.methods.map((method): Method => Struct.pick(method, ["type", "label"])))
+        return Record.map(hooks, (item) =>
+          item.methods.map(
+            (method): Method => ({
+              type: method.type,
+              label: method.label,
+              prompts: method.prompts?.map((prompt) => {
+                if (prompt.type === "select") {
+                  return {
+                    type: "select" as const,
+                    key: prompt.key,
+                    message: prompt.message,
+                    options: prompt.options,
+                    when: prompt.when,
+                  }
+                }
+                return {
+                  type: "text" as const,
+                  key: prompt.key,
+                  message: prompt.message,
+                  placeholder: prompt.placeholder,
+                  when: prompt.when,
+                }
+              }),
+            }),
+          ),
+        )
       })
 
       const authorize = Effect.fn("ProviderAuthService.authorize")(function* (input: {
         providerID: ProviderID
         method: number
+        inputs?: Record<string, string>
       }) {
         const method = hooks[input.providerID].methods[input.method]
         if (method.type !== "oauth") return
-        const result = yield* Effect.promise(() => method.authorize())
+
+        if (method.prompts && input.inputs) {
+          for (const prompt of method.prompts) {
+            if (prompt.type === "text" && prompt.validate && input.inputs[prompt.key] !== undefined) {
+              const error = prompt.validate(input.inputs[prompt.key])
+              if (error) return yield* Effect.fail(new ValidationFailed({ field: prompt.key, message: error }))
+            }
+          }
+        }
+
+        const result = yield* Effect.promise(() => method.authorize(input.inputs))
         pending.set(input.providerID, result)
         return {
           url: result.url,

+ 2 - 0
packages/opencode/src/provider/auth.ts

@@ -20,6 +20,7 @@ export namespace ProviderAuth {
     z.object({
       providerID: ProviderID.zod,
       method: z.number(),
+      inputs: z.record(z.string(), z.string()).optional(),
     }),
     async (input): Promise<Authorization | undefined> =>
       runPromiseInstance(S.ProviderAuthService.use((service) => service.authorize(input))),
@@ -37,4 +38,5 @@ export namespace ProviderAuth {
   export import OauthMissing = S.OauthMissing
   export import OauthCodeMissing = S.OauthCodeMissing
   export import OauthCallbackFailed = S.OauthCallbackFailed
+  export import ValidationFailed = S.ValidationFailed
 }

+ 3 - 1
packages/opencode/src/server/routes/provider.ts

@@ -109,14 +109,16 @@ export const ProviderRoutes = lazy(() =>
         "json",
         z.object({
           method: z.number().meta({ description: "Auth method index" }),
+          inputs: z.record(z.string(), z.string()).optional().meta({ description: "Prompt inputs" }),
         }),
       ),
       async (c) => {
         const providerID = c.req.valid("param").providerID
-        const { method } = c.req.valid("json")
+        const { method, inputs } = c.req.valid("json")
         const result = await ProviderAuth.authorize({
           providerID,
           method,
+          inputs,
         })
         return c.json(result)
       },

+ 1 - 0
packages/opencode/src/server/server.ts

@@ -66,6 +66,7 @@ export namespace Server {
           let status: ContentfulStatusCode
           if (err instanceof NotFoundError) status = 404
           else if (err instanceof Provider.ModelNotFoundError) status = 400
+          else if (err.name === "ProviderAuthValidationFailed") status = 400
           else if (err.name.startsWith("Worktree")) status = 400
           else status = 500
           return c.json(err.toObject(), { status })

+ 14 - 0
packages/plugin/src/index.ts

@@ -34,6 +34,12 @@ export type PluginInput = {
 
 export type Plugin = (input: PluginInput) => Promise<Hooks>
 
+type Rule = {
+  key: string
+  op: "eq" | "neq"
+  value: string
+}
+
 export type AuthHook = {
   provider: string
   loader?: (auth: () => Promise<Auth>, provider: Provider) => Promise<Record<string, any>>
@@ -48,7 +54,9 @@ export type AuthHook = {
               message: string
               placeholder?: string
               validate?: (value: string) => string | undefined
+              /** @deprecated Use `when` instead */
               condition?: (inputs: Record<string, string>) => boolean
+              when?: Rule
             }
           | {
               type: "select"
@@ -59,7 +67,9 @@ export type AuthHook = {
                 value: string
                 hint?: string
               }>
+              /** @deprecated Use `when` instead */
               condition?: (inputs: Record<string, string>) => boolean
+              when?: Rule
             }
         >
         authorize(inputs?: Record<string, string>): Promise<AuthOuathResult>
@@ -74,7 +84,9 @@ export type AuthHook = {
               message: string
               placeholder?: string
               validate?: (value: string) => string | undefined
+              /** @deprecated Use `when` instead */
               condition?: (inputs: Record<string, string>) => boolean
+              when?: Rule
             }
           | {
               type: "select"
@@ -85,7 +97,9 @@ export type AuthHook = {
                 value: string
                 hint?: string
               }>
+              /** @deprecated Use `when` instead */
               condition?: (inputs: Record<string, string>) => boolean
+              when?: Rule
             }
         >
         authorize?(inputs?: Record<string, string>): Promise<

+ 4 - 0
packages/sdk/js/src/v2/gen/sdk.gen.ts

@@ -2496,6 +2496,9 @@ export class Oauth extends HeyApiClient {
       directory?: string
       workspace?: string
       method?: number
+      inputs?: {
+        [key: string]: string
+      }
     },
     options?: Options<never, ThrowOnError>,
   ) {
@@ -2508,6 +2511,7 @@ export class Oauth extends HeyApiClient {
             { in: "query", key: "directory" },
             { in: "query", key: "workspace" },
             { in: "body", key: "method" },
+            { in: "body", key: "inputs" },
           ],
         },
       ],

+ 34 - 0
packages/sdk/js/src/v2/gen/types.gen.ts

@@ -1769,6 +1769,34 @@ export type SubtaskPartInput = {
 export type ProviderAuthMethod = {
   type: "oauth" | "api"
   label: string
+  prompts?: Array<
+    | {
+        type: "text"
+        key: string
+        message: string
+        placeholder?: string
+        when?: {
+          key: string
+          op: "eq" | "neq"
+          value: string
+        }
+      }
+    | {
+        type: "select"
+        key: string
+        message: string
+        options: Array<{
+          label: string
+          value: string
+          hint?: string
+        }>
+        when?: {
+          key: string
+          op: "eq" | "neq"
+          value: string
+        }
+      }
+  >
 }
 
 export type ProviderAuthAuthorization = {
@@ -3983,6 +4011,12 @@ export type ProviderOauthAuthorizeData = {
      * Auth method index
      */
     method: number
+    /**
+     * Prompt inputs
+     */
+    inputs?: {
+      [key: string]: string
+    }
   }
   path: {
     /**

+ 118 - 0
packages/sdk/openapi.json

@@ -4761,6 +4761,16 @@
                   "method": {
                     "description": "Auth method index",
                     "type": "number"
+                  },
+                  "inputs": {
+                    "description": "Prompt inputs",
+                    "type": "object",
+                    "propertyNames": {
+                      "type": "string"
+                    },
+                    "additionalProperties": {
+                      "type": "string"
+                    }
                   }
                 },
                 "required": ["method"]
@@ -11541,6 +11551,114 @@
           },
           "label": {
             "type": "string"
+          },
+          "prompts": {
+            "type": "array",
+            "items": {
+              "anyOf": [
+                {
+                  "type": "object",
+                  "properties": {
+                    "type": {
+                      "type": "string",
+                      "const": "text"
+                    },
+                    "key": {
+                      "type": "string"
+                    },
+                    "message": {
+                      "type": "string"
+                    },
+                    "placeholder": {
+                      "type": "string"
+                    },
+                    "when": {
+                      "type": "object",
+                      "properties": {
+                        "key": {
+                          "type": "string"
+                        },
+                        "op": {
+                          "anyOf": [
+                            {
+                              "type": "string",
+                              "const": "eq"
+                            },
+                            {
+                              "type": "string",
+                              "const": "neq"
+                            }
+                          ]
+                        },
+                        "value": {
+                          "type": "string"
+                        }
+                      },
+                      "required": ["key", "op", "value"]
+                    }
+                  },
+                  "required": ["type", "key", "message"]
+                },
+                {
+                  "type": "object",
+                  "properties": {
+                    "type": {
+                      "type": "string",
+                      "const": "select"
+                    },
+                    "key": {
+                      "type": "string"
+                    },
+                    "message": {
+                      "type": "string"
+                    },
+                    "options": {
+                      "type": "array",
+                      "items": {
+                        "type": "object",
+                        "properties": {
+                          "label": {
+                            "type": "string"
+                          },
+                          "value": {
+                            "type": "string"
+                          },
+                          "hint": {
+                            "type": "string"
+                          }
+                        },
+                        "required": ["label", "value"]
+                      }
+                    },
+                    "when": {
+                      "type": "object",
+                      "properties": {
+                        "key": {
+                          "type": "string"
+                        },
+                        "op": {
+                          "anyOf": [
+                            {
+                              "type": "string",
+                              "const": "eq"
+                            },
+                            {
+                              "type": "string",
+                              "const": "neq"
+                            }
+                          ]
+                        },
+                        "value": {
+                          "type": "string"
+                        }
+                      },
+                      "required": ["key", "op", "value"]
+                    }
+                  },
+                  "required": ["type", "key", "message", "options"]
+                }
+              ]
+            }
           }
         },
         "required": ["type", "label"]