Просмотр исходного кода

chore: extract OAuth changes into #18327

Dax Raad 1 месяц назад
Родитель
Сommit
65e786258a
2 измененных файлов с 122 добавлено и 132 удалено
  1. 67 70
      packages/opencode/src/mcp/oauth-callback.ts
  2. 55 62
      packages/opencode/src/plugin/codex.ts

+ 67 - 70
packages/opencode/src/mcp/oauth-callback.ts

@@ -1,5 +1,4 @@
 import { createConnection } from "net"
 import { createConnection } from "net"
-import { createServer } from "http"
 import { Log } from "../util/log"
 import { Log } from "../util/log"
 import { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH } from "./oauth-provider"
 import { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH } from "./oauth-provider"
 
 
@@ -53,91 +52,89 @@ interface PendingAuth {
 }
 }
 
 
 export namespace McpOAuthCallback {
 export namespace McpOAuthCallback {
-  let server: ReturnType<typeof createServer> | undefined
+  let server: ReturnType<typeof Bun.serve> | undefined
   const pendingAuths = new Map<string, PendingAuth>()
   const pendingAuths = new Map<string, PendingAuth>()
 
 
   const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes
   const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes
 
 
-  function handleRequest(req: import("http").IncomingMessage, res: import("http").ServerResponse) {
-    const url = new URL(req.url || "/", `http://localhost:${OAUTH_CALLBACK_PORT}`)
+  export async function ensureRunning(): Promise<void> {
+    if (server) return
 
 
-    if (url.pathname !== OAUTH_CALLBACK_PATH) {
-      res.writeHead(404)
-      res.end("Not found")
+    const running = await isPortInUse()
+    if (running) {
+      log.info("oauth callback server already running on another instance", { port: OAUTH_CALLBACK_PORT })
       return
       return
     }
     }
 
 
-    const code = url.searchParams.get("code")
-    const state = url.searchParams.get("state")
-    const error = url.searchParams.get("error")
-    const errorDescription = url.searchParams.get("error_description")
+    server = Bun.serve({
+      port: OAUTH_CALLBACK_PORT,
+      fetch(req) {
+        const url = new URL(req.url)
 
 
-    log.info("received oauth callback", { hasCode: !!code, state, error })
-
-    // Enforce state parameter presence
-    if (!state) {
-      const errorMsg = "Missing required state parameter - potential CSRF attack"
-      log.error("oauth callback missing state parameter", { url: url.toString() })
-      res.writeHead(400, { "Content-Type": "text/html" })
-      res.end(HTML_ERROR(errorMsg))
-      return
-    }
-
-    if (error) {
-      const errorMsg = errorDescription || error
-      if (pendingAuths.has(state)) {
-        const pending = pendingAuths.get(state)!
-        clearTimeout(pending.timeout)
-        pendingAuths.delete(state)
-        pending.reject(new Error(errorMsg))
-      }
-      res.writeHead(200, { "Content-Type": "text/html" })
-      res.end(HTML_ERROR(errorMsg))
-      return
-    }
-
-    if (!code) {
-      res.writeHead(400, { "Content-Type": "text/html" })
-      res.end(HTML_ERROR("No authorization code provided"))
-      return
-    }
+        if (url.pathname !== OAUTH_CALLBACK_PATH) {
+          return new Response("Not found", { status: 404 })
+        }
 
 
-    // Validate state parameter
-    if (!pendingAuths.has(state)) {
-      const errorMsg = "Invalid or expired state parameter - potential CSRF attack"
-      log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) })
-      res.writeHead(400, { "Content-Type": "text/html" })
-      res.end(HTML_ERROR(errorMsg))
-      return
-    }
+        const code = url.searchParams.get("code")
+        const state = url.searchParams.get("state")
+        const error = url.searchParams.get("error")
+        const errorDescription = url.searchParams.get("error_description")
+
+        log.info("received oauth callback", { hasCode: !!code, state, error })
+
+        // Enforce state parameter presence
+        if (!state) {
+          const errorMsg = "Missing required state parameter - potential CSRF attack"
+          log.error("oauth callback missing state parameter", { url: url.toString() })
+          return new Response(HTML_ERROR(errorMsg), {
+            status: 400,
+            headers: { "Content-Type": "text/html" },
+          })
+        }
 
 
-    const pending = pendingAuths.get(state)!
+        if (error) {
+          const errorMsg = errorDescription || error
+          if (pendingAuths.has(state)) {
+            const pending = pendingAuths.get(state)!
+            clearTimeout(pending.timeout)
+            pendingAuths.delete(state)
+            pending.reject(new Error(errorMsg))
+          }
+          return new Response(HTML_ERROR(errorMsg), {
+            headers: { "Content-Type": "text/html" },
+          })
+        }
 
 
-    clearTimeout(pending.timeout)
-    pendingAuths.delete(state)
-    pending.resolve(code)
+        if (!code) {
+          return new Response(HTML_ERROR("No authorization code provided"), {
+            status: 400,
+            headers: { "Content-Type": "text/html" },
+          })
+        }
 
 
-    res.writeHead(200, { "Content-Type": "text/html" })
-    res.end(HTML_SUCCESS)
-  }
+        // Validate state parameter
+        if (!pendingAuths.has(state)) {
+          const errorMsg = "Invalid or expired state parameter - potential CSRF attack"
+          log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) })
+          return new Response(HTML_ERROR(errorMsg), {
+            status: 400,
+            headers: { "Content-Type": "text/html" },
+          })
+        }
 
 
-  export async function ensureRunning(): Promise<void> {
-    if (server) return
+        const pending = pendingAuths.get(state)!
 
 
-    const running = await isPortInUse()
-    if (running) {
-      log.info("oauth callback server already running on another instance", { port: OAUTH_CALLBACK_PORT })
-      return
-    }
+        clearTimeout(pending.timeout)
+        pendingAuths.delete(state)
+        pending.resolve(code)
 
 
-    server = createServer(handleRequest)
-    await new Promise<void>((resolve, reject) => {
-      server!.listen(OAUTH_CALLBACK_PORT, () => {
-        log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT })
-        resolve()
-      })
-      server!.on("error", reject)
+        return new Response(HTML_SUCCESS, {
+          headers: { "Content-Type": "text/html" },
+        })
+      },
     })
     })
+
+    log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT })
   }
   }
 
 
   export function waitForCallback(oauthState: string): Promise<string> {
   export function waitForCallback(oauthState: string): Promise<string> {
@@ -177,7 +174,7 @@ export namespace McpOAuthCallback {
 
 
   export async function stop(): Promise<void> {
   export async function stop(): Promise<void> {
     if (server) {
     if (server) {
-      await new Promise<void>((resolve) => server!.close(() => resolve()))
+      server.stop()
       server = undefined
       server = undefined
       log.info("oauth callback server stopped")
       log.info("oauth callback server stopped")
     }
     }

+ 55 - 62
packages/opencode/src/plugin/codex.ts

@@ -6,7 +6,6 @@ import os from "os"
 import { ProviderTransform } from "@/provider/transform"
 import { ProviderTransform } from "@/provider/transform"
 import { ModelID, ProviderID } from "@/provider/schema"
 import { ModelID, ProviderID } from "@/provider/schema"
 import { setTimeout as sleep } from "node:timers/promises"
 import { setTimeout as sleep } from "node:timers/promises"
-import { createServer } from "http"
 
 
 const log = Log.create({ service: "plugin.codex" })
 const log = Log.create({ service: "plugin.codex" })
 
 
@@ -242,7 +241,7 @@ interface PendingOAuth {
   reject: (error: Error) => void
   reject: (error: Error) => void
 }
 }
 
 
-let oauthServer: ReturnType<typeof createServer> | undefined
+let oauthServer: ReturnType<typeof Bun.serve> | undefined
 let pendingOAuth: PendingOAuth | undefined
 let pendingOAuth: PendingOAuth | undefined
 
 
 async function startOAuthServer(): Promise<{ port: number; redirectUri: string }> {
 async function startOAuthServer(): Promise<{ port: number; redirectUri: string }> {
@@ -250,83 +249,77 @@ async function startOAuthServer(): Promise<{ port: number; redirectUri: string }
     return { port: OAUTH_PORT, redirectUri: `http://localhost:${OAUTH_PORT}/auth/callback` }
     return { port: OAUTH_PORT, redirectUri: `http://localhost:${OAUTH_PORT}/auth/callback` }
   }
   }
 
 
-  oauthServer = createServer((req, res) => {
-    const url = new URL(req.url || "/", `http://localhost:${OAUTH_PORT}`)
+  oauthServer = Bun.serve({
+    port: OAUTH_PORT,
+    fetch(req) {
+      const url = new URL(req.url)
 
 
-    if (url.pathname === "/auth/callback") {
-      const code = url.searchParams.get("code")
-      const state = url.searchParams.get("state")
-      const error = url.searchParams.get("error")
-      const errorDescription = url.searchParams.get("error_description")
+      if (url.pathname === "/auth/callback") {
+        const code = url.searchParams.get("code")
+        const state = url.searchParams.get("state")
+        const error = url.searchParams.get("error")
+        const errorDescription = url.searchParams.get("error_description")
 
 
-      if (error) {
-        const errorMsg = errorDescription || error
-        pendingOAuth?.reject(new Error(errorMsg))
-        pendingOAuth = undefined
-        res.writeHead(200, { "Content-Type": "text/html" })
-        res.end(HTML_ERROR(errorMsg))
-        return
-      }
-
-      if (!code) {
-        const errorMsg = "Missing authorization code"
-        pendingOAuth?.reject(new Error(errorMsg))
-        pendingOAuth = undefined
-        res.writeHead(400, { "Content-Type": "text/html" })
-        res.end(HTML_ERROR(errorMsg))
-        return
-      }
+        if (error) {
+          const errorMsg = errorDescription || error
+          pendingOAuth?.reject(new Error(errorMsg))
+          pendingOAuth = undefined
+          return new Response(HTML_ERROR(errorMsg), {
+            headers: { "Content-Type": "text/html" },
+          })
+        }
 
 
-      if (!pendingOAuth || state !== pendingOAuth.state) {
-        const errorMsg = "Invalid state - potential CSRF attack"
-        pendingOAuth?.reject(new Error(errorMsg))
-        pendingOAuth = undefined
-        res.writeHead(400, { "Content-Type": "text/html" })
-        res.end(HTML_ERROR(errorMsg))
-        return
-      }
+        if (!code) {
+          const errorMsg = "Missing authorization code"
+          pendingOAuth?.reject(new Error(errorMsg))
+          pendingOAuth = undefined
+          return new Response(HTML_ERROR(errorMsg), {
+            status: 400,
+            headers: { "Content-Type": "text/html" },
+          })
+        }
 
 
-      const current = pendingOAuth
-      pendingOAuth = undefined
+        if (!pendingOAuth || state !== pendingOAuth.state) {
+          const errorMsg = "Invalid state - potential CSRF attack"
+          pendingOAuth?.reject(new Error(errorMsg))
+          pendingOAuth = undefined
+          return new Response(HTML_ERROR(errorMsg), {
+            status: 400,
+            headers: { "Content-Type": "text/html" },
+          })
+        }
 
 
-      exchangeCodeForTokens(code, `http://localhost:${OAUTH_PORT}/auth/callback`, current.pkce)
-        .then((tokens) => current.resolve(tokens))
-        .catch((err) => current.reject(err))
+        const current = pendingOAuth
+        pendingOAuth = undefined
 
 
-      res.writeHead(200, { "Content-Type": "text/html" })
-      res.end(HTML_SUCCESS)
-      return
-    }
+        exchangeCodeForTokens(code, `http://localhost:${OAUTH_PORT}/auth/callback`, current.pkce)
+          .then((tokens) => current.resolve(tokens))
+          .catch((err) => current.reject(err))
 
 
-    if (url.pathname === "/cancel") {
-      pendingOAuth?.reject(new Error("Login cancelled"))
-      pendingOAuth = undefined
-      res.writeHead(200)
-      res.end("Login cancelled")
-      return
-    }
+        return new Response(HTML_SUCCESS, {
+          headers: { "Content-Type": "text/html" },
+        })
+      }
 
 
-    res.writeHead(404)
-    res.end("Not found")
-  })
+      if (url.pathname === "/cancel") {
+        pendingOAuth?.reject(new Error("Login cancelled"))
+        pendingOAuth = undefined
+        return new Response("Login cancelled", { status: 200 })
+      }
 
 
-  await new Promise<void>((resolve, reject) => {
-    oauthServer!.listen(OAUTH_PORT, () => {
-      log.info("codex oauth server started", { port: OAUTH_PORT })
-      resolve()
-    })
-    oauthServer!.on("error", reject)
+      return new Response("Not found", { status: 404 })
+    },
   })
   })
 
 
+  log.info("codex oauth server started", { port: OAUTH_PORT })
   return { port: OAUTH_PORT, redirectUri: `http://localhost:${OAUTH_PORT}/auth/callback` }
   return { port: OAUTH_PORT, redirectUri: `http://localhost:${OAUTH_PORT}/auth/callback` }
 }
 }
 
 
 function stopOAuthServer() {
 function stopOAuthServer() {
   if (oauthServer) {
   if (oauthServer) {
-    oauthServer.close(() => {
-      log.info("codex oauth server stopped")
-    })
+    oauthServer.stop()
     oauthServer = undefined
     oauthServer = undefined
+    log.info("codex oauth server stopped")
   }
   }
 }
 }