|
|
@@ -5,6 +5,7 @@ import type {
|
|
|
OAuthClientInformation,
|
|
|
OAuthClientInformationFull,
|
|
|
} from "@modelcontextprotocol/sdk/shared/auth.js"
|
|
|
+import { Effect } from "effect"
|
|
|
import { McpAuth } from "./auth"
|
|
|
import { Log } from "../util/log"
|
|
|
|
|
|
@@ -30,6 +31,7 @@ export class McpOAuthProvider implements OAuthClientProvider {
|
|
|
private serverUrl: string,
|
|
|
private config: McpOAuthConfig,
|
|
|
private callbacks: McpOAuthCallbacks,
|
|
|
+ private auth: McpAuth.Interface,
|
|
|
) {}
|
|
|
|
|
|
get redirectUrl(): string {
|
|
|
@@ -61,7 +63,7 @@ export class McpOAuthProvider implements OAuthClientProvider {
|
|
|
|
|
|
// Check stored client info (from dynamic registration)
|
|
|
// Use getForUrl to validate credentials are for the current server URL
|
|
|
- const entry = await McpAuth.getForUrl(this.mcpName, this.serverUrl)
|
|
|
+ const entry = await Effect.runPromise(this.auth.getForUrl(this.mcpName, this.serverUrl))
|
|
|
if (entry?.clientInfo) {
|
|
|
// Check if client secret has expired
|
|
|
if (entry.clientInfo.clientSecretExpiresAt && entry.clientInfo.clientSecretExpiresAt < Date.now() / 1000) {
|
|
|
@@ -79,15 +81,17 @@ export class McpOAuthProvider implements OAuthClientProvider {
|
|
|
}
|
|
|
|
|
|
async saveClientInformation(info: OAuthClientInformationFull): Promise<void> {
|
|
|
- await McpAuth.updateClientInfo(
|
|
|
- this.mcpName,
|
|
|
- {
|
|
|
- clientId: info.client_id,
|
|
|
- clientSecret: info.client_secret,
|
|
|
- clientIdIssuedAt: info.client_id_issued_at,
|
|
|
- clientSecretExpiresAt: info.client_secret_expires_at,
|
|
|
- },
|
|
|
- this.serverUrl,
|
|
|
+ await Effect.runPromise(
|
|
|
+ this.auth.updateClientInfo(
|
|
|
+ this.mcpName,
|
|
|
+ {
|
|
|
+ clientId: info.client_id,
|
|
|
+ clientSecret: info.client_secret,
|
|
|
+ clientIdIssuedAt: info.client_id_issued_at,
|
|
|
+ clientSecretExpiresAt: info.client_secret_expires_at,
|
|
|
+ },
|
|
|
+ this.serverUrl,
|
|
|
+ ),
|
|
|
)
|
|
|
log.info("saved dynamically registered client", {
|
|
|
mcpName: this.mcpName,
|
|
|
@@ -97,7 +101,7 @@ export class McpOAuthProvider implements OAuthClientProvider {
|
|
|
|
|
|
async tokens(): Promise<OAuthTokens | undefined> {
|
|
|
// Use getForUrl to validate tokens are for the current server URL
|
|
|
- const entry = await McpAuth.getForUrl(this.mcpName, this.serverUrl)
|
|
|
+ const entry = await Effect.runPromise(this.auth.getForUrl(this.mcpName, this.serverUrl))
|
|
|
if (!entry?.tokens) return undefined
|
|
|
|
|
|
return {
|
|
|
@@ -112,15 +116,17 @@ export class McpOAuthProvider implements OAuthClientProvider {
|
|
|
}
|
|
|
|
|
|
async saveTokens(tokens: OAuthTokens): Promise<void> {
|
|
|
- await McpAuth.updateTokens(
|
|
|
- this.mcpName,
|
|
|
- {
|
|
|
- accessToken: tokens.access_token,
|
|
|
- refreshToken: tokens.refresh_token,
|
|
|
- expiresAt: tokens.expires_in ? Date.now() / 1000 + tokens.expires_in : undefined,
|
|
|
- scope: tokens.scope,
|
|
|
- },
|
|
|
- this.serverUrl,
|
|
|
+ await Effect.runPromise(
|
|
|
+ this.auth.updateTokens(
|
|
|
+ this.mcpName,
|
|
|
+ {
|
|
|
+ accessToken: tokens.access_token,
|
|
|
+ refreshToken: tokens.refresh_token,
|
|
|
+ expiresAt: tokens.expires_in ? Date.now() / 1000 + tokens.expires_in : undefined,
|
|
|
+ scope: tokens.scope,
|
|
|
+ },
|
|
|
+ this.serverUrl,
|
|
|
+ ),
|
|
|
)
|
|
|
log.info("saved oauth tokens", { mcpName: this.mcpName })
|
|
|
}
|
|
|
@@ -131,11 +137,11 @@ export class McpOAuthProvider implements OAuthClientProvider {
|
|
|
}
|
|
|
|
|
|
async saveCodeVerifier(codeVerifier: string): Promise<void> {
|
|
|
- await McpAuth.updateCodeVerifier(this.mcpName, codeVerifier)
|
|
|
+ await Effect.runPromise(this.auth.updateCodeVerifier(this.mcpName, codeVerifier))
|
|
|
}
|
|
|
|
|
|
async codeVerifier(): Promise<string> {
|
|
|
- const entry = await McpAuth.get(this.mcpName)
|
|
|
+ const entry = await Effect.runPromise(this.auth.get(this.mcpName))
|
|
|
if (!entry?.codeVerifier) {
|
|
|
throw new Error(`No code verifier saved for MCP server: ${this.mcpName}`)
|
|
|
}
|
|
|
@@ -143,11 +149,11 @@ export class McpOAuthProvider implements OAuthClientProvider {
|
|
|
}
|
|
|
|
|
|
async saveState(state: string): Promise<void> {
|
|
|
- await McpAuth.updateOAuthState(this.mcpName, state)
|
|
|
+ await Effect.runPromise(this.auth.updateOAuthState(this.mcpName, state))
|
|
|
}
|
|
|
|
|
|
async state(): Promise<string> {
|
|
|
- const entry = await McpAuth.get(this.mcpName)
|
|
|
+ const entry = await Effect.runPromise(this.auth.get(this.mcpName))
|
|
|
if (entry?.oauthState) {
|
|
|
return entry.oauthState
|
|
|
}
|
|
|
@@ -159,28 +165,28 @@ export class McpOAuthProvider implements OAuthClientProvider {
|
|
|
const newState = Array.from(crypto.getRandomValues(new Uint8Array(32)))
|
|
|
.map((b) => b.toString(16).padStart(2, "0"))
|
|
|
.join("")
|
|
|
- await McpAuth.updateOAuthState(this.mcpName, newState)
|
|
|
+ await Effect.runPromise(this.auth.updateOAuthState(this.mcpName, newState))
|
|
|
return newState
|
|
|
}
|
|
|
|
|
|
async invalidateCredentials(type: "all" | "client" | "tokens"): Promise<void> {
|
|
|
log.info("invalidating credentials", { mcpName: this.mcpName, type })
|
|
|
- const entry = await McpAuth.get(this.mcpName)
|
|
|
+ const entry = await Effect.runPromise(this.auth.get(this.mcpName))
|
|
|
if (!entry) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
switch (type) {
|
|
|
case "all":
|
|
|
- await McpAuth.remove(this.mcpName)
|
|
|
+ await Effect.runPromise(this.auth.remove(this.mcpName))
|
|
|
break
|
|
|
case "client":
|
|
|
delete entry.clientInfo
|
|
|
- await McpAuth.set(this.mcpName, entry)
|
|
|
+ await Effect.runPromise(this.auth.set(this.mcpName, entry))
|
|
|
break
|
|
|
case "tokens":
|
|
|
delete entry.tokens
|
|
|
- await McpAuth.set(this.mcpName, entry)
|
|
|
+ await Effect.runPromise(this.auth.set(this.mcpName, entry))
|
|
|
break
|
|
|
}
|
|
|
}
|