|
|
@@ -56,177 +56,175 @@ interface PendingAuth {
|
|
|
timeout: ReturnType<typeof setTimeout>
|
|
|
}
|
|
|
|
|
|
-export namespace McpOAuthCallback {
|
|
|
- let server: ReturnType<typeof createServer> | undefined
|
|
|
- const pendingAuths = new Map<string, PendingAuth>()
|
|
|
- // Reverse index: mcpName → oauthState, so cancelPending(mcpName) can
|
|
|
- // find the right entry in pendingAuths (which is keyed by oauthState).
|
|
|
- const mcpNameToState = new Map<string, string>()
|
|
|
-
|
|
|
- const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes
|
|
|
-
|
|
|
- function cleanupStateIndex(oauthState: string) {
|
|
|
- for (const [name, state] of mcpNameToState) {
|
|
|
- if (state === oauthState) {
|
|
|
- mcpNameToState.delete(name)
|
|
|
- break
|
|
|
- }
|
|
|
+let server: ReturnType<typeof createServer> | undefined
|
|
|
+const pendingAuths = new Map<string, PendingAuth>()
|
|
|
+// Reverse index: mcpName → oauthState, so cancelPending(mcpName) can
|
|
|
+// find the right entry in pendingAuths (which is keyed by oauthState).
|
|
|
+const mcpNameToState = new Map<string, string>()
|
|
|
+
|
|
|
+const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes
|
|
|
+
|
|
|
+function cleanupStateIndex(oauthState: string) {
|
|
|
+ for (const [name, state] of mcpNameToState) {
|
|
|
+ if (state === oauthState) {
|
|
|
+ mcpNameToState.delete(name)
|
|
|
+ break
|
|
|
}
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- function handleRequest(req: import("http").IncomingMessage, res: import("http").ServerResponse) {
|
|
|
- const url = new URL(req.url || "/", `http://localhost:${currentPort}`)
|
|
|
+function handleRequest(req: import("http").IncomingMessage, res: import("http").ServerResponse) {
|
|
|
+ const url = new URL(req.url || "/", `http://localhost:${currentPort}`)
|
|
|
|
|
|
- if (url.pathname !== currentPath) {
|
|
|
- res.writeHead(404)
|
|
|
- res.end("Not found")
|
|
|
- return
|
|
|
- }
|
|
|
+ if (url.pathname !== currentPath) {
|
|
|
+ res.writeHead(404)
|
|
|
+ res.end("Not found")
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- const code = url.searchParams.get("code")
|
|
|
- const state = url.searchParams.get("state")
|
|
|
- const error = url.searchParams.get("error")
|
|
|
- const errorDescription = url.searchParams.get("error_description")
|
|
|
+ const code = url.searchParams.get("code")
|
|
|
+ const state = url.searchParams.get("state")
|
|
|
+ const error = url.searchParams.get("error")
|
|
|
+ const errorDescription = url.searchParams.get("error_description")
|
|
|
|
|
|
- log.info("received oauth callback", { hasCode: !!code, state, error })
|
|
|
+ log.info("received oauth callback", { hasCode: !!code, state, error })
|
|
|
|
|
|
- // Enforce state parameter presence
|
|
|
- if (!state) {
|
|
|
- const errorMsg = "Missing required state parameter - potential CSRF attack"
|
|
|
- log.error("oauth callback missing state parameter", { url: url.toString() })
|
|
|
- res.writeHead(400, { "Content-Type": "text/html" })
|
|
|
- res.end(HTML_ERROR(errorMsg))
|
|
|
- return
|
|
|
- }
|
|
|
+ // Enforce state parameter presence
|
|
|
+ if (!state) {
|
|
|
+ const errorMsg = "Missing required state parameter - potential CSRF attack"
|
|
|
+ log.error("oauth callback missing state parameter", { url: url.toString() })
|
|
|
+ res.writeHead(400, { "Content-Type": "text/html" })
|
|
|
+ res.end(HTML_ERROR(errorMsg))
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- if (error) {
|
|
|
- const errorMsg = errorDescription || error
|
|
|
- if (pendingAuths.has(state)) {
|
|
|
- const pending = pendingAuths.get(state)!
|
|
|
- clearTimeout(pending.timeout)
|
|
|
- pendingAuths.delete(state)
|
|
|
- cleanupStateIndex(state)
|
|
|
- pending.reject(new Error(errorMsg))
|
|
|
- }
|
|
|
- res.writeHead(200, { "Content-Type": "text/html" })
|
|
|
- res.end(HTML_ERROR(errorMsg))
|
|
|
- return
|
|
|
+ if (error) {
|
|
|
+ const errorMsg = errorDescription || error
|
|
|
+ if (pendingAuths.has(state)) {
|
|
|
+ const pending = pendingAuths.get(state)!
|
|
|
+ clearTimeout(pending.timeout)
|
|
|
+ pendingAuths.delete(state)
|
|
|
+ cleanupStateIndex(state)
|
|
|
+ pending.reject(new Error(errorMsg))
|
|
|
}
|
|
|
+ res.writeHead(200, { "Content-Type": "text/html" })
|
|
|
+ res.end(HTML_ERROR(errorMsg))
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- if (!code) {
|
|
|
- res.writeHead(400, { "Content-Type": "text/html" })
|
|
|
- res.end(HTML_ERROR("No authorization code provided"))
|
|
|
- return
|
|
|
- }
|
|
|
+ if (!code) {
|
|
|
+ res.writeHead(400, { "Content-Type": "text/html" })
|
|
|
+ res.end(HTML_ERROR("No authorization code provided"))
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- // Validate state parameter
|
|
|
- if (!pendingAuths.has(state)) {
|
|
|
- const errorMsg = "Invalid or expired state parameter - potential CSRF attack"
|
|
|
- log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) })
|
|
|
- res.writeHead(400, { "Content-Type": "text/html" })
|
|
|
- res.end(HTML_ERROR(errorMsg))
|
|
|
- return
|
|
|
- }
|
|
|
+ // Validate state parameter
|
|
|
+ if (!pendingAuths.has(state)) {
|
|
|
+ const errorMsg = "Invalid or expired state parameter - potential CSRF attack"
|
|
|
+ log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) })
|
|
|
+ res.writeHead(400, { "Content-Type": "text/html" })
|
|
|
+ res.end(HTML_ERROR(errorMsg))
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- const pending = pendingAuths.get(state)!
|
|
|
+ const pending = pendingAuths.get(state)!
|
|
|
|
|
|
- clearTimeout(pending.timeout)
|
|
|
- pendingAuths.delete(state)
|
|
|
- cleanupStateIndex(state)
|
|
|
- pending.resolve(code)
|
|
|
+ clearTimeout(pending.timeout)
|
|
|
+ pendingAuths.delete(state)
|
|
|
+ cleanupStateIndex(state)
|
|
|
+ pending.resolve(code)
|
|
|
|
|
|
- res.writeHead(200, { "Content-Type": "text/html" })
|
|
|
- res.end(HTML_SUCCESS)
|
|
|
- }
|
|
|
+ res.writeHead(200, { "Content-Type": "text/html" })
|
|
|
+ res.end(HTML_SUCCESS)
|
|
|
+}
|
|
|
|
|
|
- export async function ensureRunning(redirectUri?: string): Promise<void> {
|
|
|
- // Parse the redirect URI to get port and path (uses defaults if not provided)
|
|
|
- const { port, path } = parseRedirectUri(redirectUri)
|
|
|
+export async function ensureRunning(redirectUri?: string): Promise<void> {
|
|
|
+ // Parse the redirect URI to get port and path (uses defaults if not provided)
|
|
|
+ const { port, path } = parseRedirectUri(redirectUri)
|
|
|
|
|
|
- // If server is running on a different port/path, stop it first
|
|
|
- if (server && (currentPort !== port || currentPath !== path)) {
|
|
|
- log.info("stopping oauth callback server to reconfigure", { oldPort: currentPort, newPort: port })
|
|
|
- await stop()
|
|
|
- }
|
|
|
+ // If server is running on a different port/path, stop it first
|
|
|
+ if (server && (currentPort !== port || currentPath !== path)) {
|
|
|
+ log.info("stopping oauth callback server to reconfigure", { oldPort: currentPort, newPort: port })
|
|
|
+ await stop()
|
|
|
+ }
|
|
|
|
|
|
- if (server) return
|
|
|
+ if (server) return
|
|
|
|
|
|
- const running = await isPortInUse(port)
|
|
|
- if (running) {
|
|
|
- log.info("oauth callback server already running on another instance", { port })
|
|
|
- return
|
|
|
- }
|
|
|
+ const running = await isPortInUse(port)
|
|
|
+ if (running) {
|
|
|
+ log.info("oauth callback server already running on another instance", { port })
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- currentPort = port
|
|
|
- currentPath = path
|
|
|
+ currentPort = port
|
|
|
+ currentPath = path
|
|
|
|
|
|
- server = createServer(handleRequest)
|
|
|
- await new Promise<void>((resolve, reject) => {
|
|
|
- server!.listen(currentPort, () => {
|
|
|
- log.info("oauth callback server started", { port: currentPort, path: currentPath })
|
|
|
- resolve()
|
|
|
- })
|
|
|
- server!.on("error", reject)
|
|
|
+ server = createServer(handleRequest)
|
|
|
+ await new Promise<void>((resolve, reject) => {
|
|
|
+ server!.listen(currentPort, () => {
|
|
|
+ log.info("oauth callback server started", { port: currentPort, path: currentPath })
|
|
|
+ resolve()
|
|
|
})
|
|
|
- }
|
|
|
+ server!.on("error", reject)
|
|
|
+ })
|
|
|
+}
|
|
|
|
|
|
- export function waitForCallback(oauthState: string, mcpName?: string): Promise<string> {
|
|
|
- if (mcpName) mcpNameToState.set(mcpName, oauthState)
|
|
|
- return new Promise((resolve, reject) => {
|
|
|
- const timeout = setTimeout(() => {
|
|
|
- if (pendingAuths.has(oauthState)) {
|
|
|
- pendingAuths.delete(oauthState)
|
|
|
- if (mcpName) mcpNameToState.delete(mcpName)
|
|
|
- reject(new Error("OAuth callback timeout - authorization took too long"))
|
|
|
- }
|
|
|
- }, CALLBACK_TIMEOUT_MS)
|
|
|
-
|
|
|
- pendingAuths.set(oauthState, { resolve, reject, timeout })
|
|
|
- })
|
|
|
- }
|
|
|
+export function waitForCallback(oauthState: string, mcpName?: string): Promise<string> {
|
|
|
+ if (mcpName) mcpNameToState.set(mcpName, oauthState)
|
|
|
+ return new Promise((resolve, reject) => {
|
|
|
+ const timeout = setTimeout(() => {
|
|
|
+ if (pendingAuths.has(oauthState)) {
|
|
|
+ pendingAuths.delete(oauthState)
|
|
|
+ if (mcpName) mcpNameToState.delete(mcpName)
|
|
|
+ reject(new Error("OAuth callback timeout - authorization took too long"))
|
|
|
+ }
|
|
|
+ }, CALLBACK_TIMEOUT_MS)
|
|
|
|
|
|
- export function cancelPending(mcpName: string): void {
|
|
|
- // Look up the oauthState for this mcpName via the reverse index
|
|
|
- const oauthState = mcpNameToState.get(mcpName)
|
|
|
- const key = oauthState ?? mcpName
|
|
|
- const pending = pendingAuths.get(key)
|
|
|
- if (pending) {
|
|
|
- clearTimeout(pending.timeout)
|
|
|
- pendingAuths.delete(key)
|
|
|
- mcpNameToState.delete(mcpName)
|
|
|
- pending.reject(new Error("Authorization cancelled"))
|
|
|
- }
|
|
|
- }
|
|
|
+ pendingAuths.set(oauthState, { resolve, reject, timeout })
|
|
|
+ })
|
|
|
+}
|
|
|
|
|
|
- export async function isPortInUse(port: number = OAUTH_CALLBACK_PORT): Promise<boolean> {
|
|
|
- return new Promise((resolve) => {
|
|
|
- const socket = createConnection(port, "127.0.0.1")
|
|
|
- socket.on("connect", () => {
|
|
|
- socket.destroy()
|
|
|
- resolve(true)
|
|
|
- })
|
|
|
- socket.on("error", () => {
|
|
|
- resolve(false)
|
|
|
- })
|
|
|
- })
|
|
|
+export function cancelPending(mcpName: string): void {
|
|
|
+ // Look up the oauthState for this mcpName via the reverse index
|
|
|
+ const oauthState = mcpNameToState.get(mcpName)
|
|
|
+ const key = oauthState ?? mcpName
|
|
|
+ const pending = pendingAuths.get(key)
|
|
|
+ if (pending) {
|
|
|
+ clearTimeout(pending.timeout)
|
|
|
+ pendingAuths.delete(key)
|
|
|
+ mcpNameToState.delete(mcpName)
|
|
|
+ pending.reject(new Error("Authorization cancelled"))
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- export async function stop(): Promise<void> {
|
|
|
- if (server) {
|
|
|
- await new Promise<void>((resolve) => server!.close(() => resolve()))
|
|
|
- server = undefined
|
|
|
- log.info("oauth callback server stopped")
|
|
|
- }
|
|
|
+export async function isPortInUse(port: number = OAUTH_CALLBACK_PORT): Promise<boolean> {
|
|
|
+ return new Promise((resolve) => {
|
|
|
+ const socket = createConnection(port, "127.0.0.1")
|
|
|
+ socket.on("connect", () => {
|
|
|
+ socket.destroy()
|
|
|
+ resolve(true)
|
|
|
+ })
|
|
|
+ socket.on("error", () => {
|
|
|
+ resolve(false)
|
|
|
+ })
|
|
|
+ })
|
|
|
+}
|
|
|
|
|
|
- for (const [_name, pending] of pendingAuths) {
|
|
|
- clearTimeout(pending.timeout)
|
|
|
- pending.reject(new Error("OAuth callback server stopped"))
|
|
|
- }
|
|
|
- pendingAuths.clear()
|
|
|
- mcpNameToState.clear()
|
|
|
+export async function stop(): Promise<void> {
|
|
|
+ if (server) {
|
|
|
+ await new Promise<void>((resolve) => server!.close(() => resolve()))
|
|
|
+ server = undefined
|
|
|
+ log.info("oauth callback server stopped")
|
|
|
}
|
|
|
|
|
|
- export function isRunning(): boolean {
|
|
|
- return server !== undefined
|
|
|
+ for (const [_name, pending] of pendingAuths) {
|
|
|
+ clearTimeout(pending.timeout)
|
|
|
+ pending.reject(new Error("OAuth callback server stopped"))
|
|
|
}
|
|
|
+ pendingAuths.clear()
|
|
|
+ mcpNameToState.clear()
|
|
|
+}
|
|
|
+
|
|
|
+export function isRunning(): boolean {
|
|
|
+ return server !== undefined
|
|
|
}
|