|
@@ -1,4 +1,5 @@
|
|
|
import { createConnection } from "net"
|
|
import { createConnection } from "net"
|
|
|
|
|
+import { createServer } from "http"
|
|
|
import { Log } from "../util/log"
|
|
import { Log } from "../util/log"
|
|
|
import { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH } from "./oauth-provider"
|
|
import { OAUTH_CALLBACK_PORT, OAUTH_CALLBACK_PATH } from "./oauth-provider"
|
|
|
|
|
|
|
@@ -52,89 +53,91 @@ interface PendingAuth {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
export namespace McpOAuthCallback {
|
|
export namespace McpOAuthCallback {
|
|
|
- let server: ReturnType<typeof Bun.serve> | undefined
|
|
|
|
|
|
|
+ let server: ReturnType<typeof createServer> | undefined
|
|
|
const pendingAuths = new Map<string, PendingAuth>()
|
|
const pendingAuths = new Map<string, PendingAuth>()
|
|
|
|
|
|
|
|
const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes
|
|
const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes
|
|
|
|
|
|
|
|
- export async function ensureRunning(): Promise<void> {
|
|
|
|
|
- if (server) return
|
|
|
|
|
|
|
+ function handleRequest(req: import("http").IncomingMessage, res: import("http").ServerResponse) {
|
|
|
|
|
+ const url = new URL(req.url || "/", `http://localhost:${OAUTH_CALLBACK_PORT}`)
|
|
|
|
|
|
|
|
- const running = await isPortInUse()
|
|
|
|
|
- if (running) {
|
|
|
|
|
- log.info("oauth callback server already running on another instance", { port: OAUTH_CALLBACK_PORT })
|
|
|
|
|
|
|
+ if (url.pathname !== OAUTH_CALLBACK_PATH) {
|
|
|
|
|
+ res.writeHead(404)
|
|
|
|
|
+ res.end("Not found")
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- server = Bun.serve({
|
|
|
|
|
- port: OAUTH_CALLBACK_PORT,
|
|
|
|
|
- fetch(req) {
|
|
|
|
|
- const url = new URL(req.url)
|
|
|
|
|
|
|
+ const code = url.searchParams.get("code")
|
|
|
|
|
+ const state = url.searchParams.get("state")
|
|
|
|
|
+ const error = url.searchParams.get("error")
|
|
|
|
|
+ const errorDescription = url.searchParams.get("error_description")
|
|
|
|
|
|
|
|
- if (url.pathname !== OAUTH_CALLBACK_PATH) {
|
|
|
|
|
- return new Response("Not found", { status: 404 })
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ log.info("received oauth callback", { hasCode: !!code, state, error })
|
|
|
|
|
|
|
|
- const code = url.searchParams.get("code")
|
|
|
|
|
- const state = url.searchParams.get("state")
|
|
|
|
|
- const error = url.searchParams.get("error")
|
|
|
|
|
- const errorDescription = url.searchParams.get("error_description")
|
|
|
|
|
-
|
|
|
|
|
- log.info("received oauth callback", { hasCode: !!code, state, error })
|
|
|
|
|
-
|
|
|
|
|
- // Enforce state parameter presence
|
|
|
|
|
- if (!state) {
|
|
|
|
|
- const errorMsg = "Missing required state parameter - potential CSRF attack"
|
|
|
|
|
- log.error("oauth callback missing state parameter", { url: url.toString() })
|
|
|
|
|
- return new Response(HTML_ERROR(errorMsg), {
|
|
|
|
|
- status: 400,
|
|
|
|
|
- headers: { "Content-Type": "text/html" },
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ // Enforce state parameter presence
|
|
|
|
|
+ if (!state) {
|
|
|
|
|
+ const errorMsg = "Missing required state parameter - potential CSRF attack"
|
|
|
|
|
+ log.error("oauth callback missing state parameter", { url: url.toString() })
|
|
|
|
|
+ res.writeHead(400, { "Content-Type": "text/html" })
|
|
|
|
|
+ res.end(HTML_ERROR(errorMsg))
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- if (error) {
|
|
|
|
|
- const errorMsg = errorDescription || error
|
|
|
|
|
- if (pendingAuths.has(state)) {
|
|
|
|
|
- const pending = pendingAuths.get(state)!
|
|
|
|
|
- clearTimeout(pending.timeout)
|
|
|
|
|
- pendingAuths.delete(state)
|
|
|
|
|
- pending.reject(new Error(errorMsg))
|
|
|
|
|
- }
|
|
|
|
|
- return new Response(HTML_ERROR(errorMsg), {
|
|
|
|
|
- headers: { "Content-Type": "text/html" },
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ if (error) {
|
|
|
|
|
+ const errorMsg = errorDescription || error
|
|
|
|
|
+ if (pendingAuths.has(state)) {
|
|
|
|
|
+ const pending = pendingAuths.get(state)!
|
|
|
|
|
+ clearTimeout(pending.timeout)
|
|
|
|
|
+ pendingAuths.delete(state)
|
|
|
|
|
+ pending.reject(new Error(errorMsg))
|
|
|
|
|
+ }
|
|
|
|
|
+ res.writeHead(200, { "Content-Type": "text/html" })
|
|
|
|
|
+ res.end(HTML_ERROR(errorMsg))
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- if (!code) {
|
|
|
|
|
- return new Response(HTML_ERROR("No authorization code provided"), {
|
|
|
|
|
- status: 400,
|
|
|
|
|
- headers: { "Content-Type": "text/html" },
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ if (!code) {
|
|
|
|
|
+ res.writeHead(400, { "Content-Type": "text/html" })
|
|
|
|
|
+ res.end(HTML_ERROR("No authorization code provided"))
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- // Validate state parameter
|
|
|
|
|
- if (!pendingAuths.has(state)) {
|
|
|
|
|
- const errorMsg = "Invalid or expired state parameter - potential CSRF attack"
|
|
|
|
|
- log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) })
|
|
|
|
|
- return new Response(HTML_ERROR(errorMsg), {
|
|
|
|
|
- status: 400,
|
|
|
|
|
- headers: { "Content-Type": "text/html" },
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ // Validate state parameter
|
|
|
|
|
+ if (!pendingAuths.has(state)) {
|
|
|
|
|
+ const errorMsg = "Invalid or expired state parameter - potential CSRF attack"
|
|
|
|
|
+ log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) })
|
|
|
|
|
+ res.writeHead(400, { "Content-Type": "text/html" })
|
|
|
|
|
+ res.end(HTML_ERROR(errorMsg))
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- const pending = pendingAuths.get(state)!
|
|
|
|
|
|
|
+ const pending = pendingAuths.get(state)!
|
|
|
|
|
|
|
|
- clearTimeout(pending.timeout)
|
|
|
|
|
- pendingAuths.delete(state)
|
|
|
|
|
- pending.resolve(code)
|
|
|
|
|
|
|
+ clearTimeout(pending.timeout)
|
|
|
|
|
+ pendingAuths.delete(state)
|
|
|
|
|
+ pending.resolve(code)
|
|
|
|
|
|
|
|
- return new Response(HTML_SUCCESS, {
|
|
|
|
|
- headers: { "Content-Type": "text/html" },
|
|
|
|
|
- })
|
|
|
|
|
- },
|
|
|
|
|
- })
|
|
|
|
|
|
|
+ res.writeHead(200, { "Content-Type": "text/html" })
|
|
|
|
|
+ res.end(HTML_SUCCESS)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ export async function ensureRunning(): Promise<void> {
|
|
|
|
|
+ if (server) return
|
|
|
|
|
|
|
|
- log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT })
|
|
|
|
|
|
|
+ const running = await isPortInUse()
|
|
|
|
|
+ if (running) {
|
|
|
|
|
+ log.info("oauth callback server already running on another instance", { port: OAUTH_CALLBACK_PORT })
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ server = createServer(handleRequest)
|
|
|
|
|
+ await new Promise<void>((resolve, reject) => {
|
|
|
|
|
+ server!.listen(OAUTH_CALLBACK_PORT, () => {
|
|
|
|
|
+ log.info("oauth callback server started", { port: OAUTH_CALLBACK_PORT })
|
|
|
|
|
+ resolve()
|
|
|
|
|
+ })
|
|
|
|
|
+ server!.on("error", reject)
|
|
|
|
|
+ })
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
export function waitForCallback(oauthState: string): Promise<string> {
|
|
export function waitForCallback(oauthState: string): Promise<string> {
|
|
@@ -174,7 +177,7 @@ export namespace McpOAuthCallback {
|
|
|
|
|
|
|
|
export async function stop(): Promise<void> {
|
|
export async function stop(): Promise<void> {
|
|
|
if (server) {
|
|
if (server) {
|
|
|
- server.stop()
|
|
|
|
|
|
|
+ await new Promise<void>((resolve) => server!.close(() => resolve()))
|
|
|
server = undefined
|
|
server = undefined
|
|
|
log.info("oauth callback server stopped")
|
|
log.info("oauth callback server stopped")
|
|
|
}
|
|
}
|