auth.ts 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import path from "path"
  2. import fs from "fs/promises"
  3. import z from "zod"
  4. import { Global } from "../global"
  5. export namespace McpAuth {
  6. export const Tokens = z.object({
  7. accessToken: z.string(),
  8. refreshToken: z.string().optional(),
  9. expiresAt: z.number().optional(),
  10. scope: z.string().optional(),
  11. })
  12. export type Tokens = z.infer<typeof Tokens>
  13. export const ClientInfo = z.object({
  14. clientId: z.string(),
  15. clientSecret: z.string().optional(),
  16. clientIdIssuedAt: z.number().optional(),
  17. clientSecretExpiresAt: z.number().optional(),
  18. })
  19. export type ClientInfo = z.infer<typeof ClientInfo>
  20. export const Entry = z.object({
  21. tokens: Tokens.optional(),
  22. clientInfo: ClientInfo.optional(),
  23. codeVerifier: z.string().optional(),
  24. oauthState: z.string().optional(),
  25. serverUrl: z.string().optional(), // Track the URL these credentials are for
  26. })
  27. export type Entry = z.infer<typeof Entry>
  28. const filepath = path.join(Global.Path.data, "mcp-auth.json")
  29. export async function get(mcpName: string): Promise<Entry | undefined> {
  30. const data = await all()
  31. return data[mcpName]
  32. }
  33. /**
  34. * Get auth entry and validate it's for the correct URL.
  35. * Returns undefined if URL has changed (credentials are invalid).
  36. */
  37. export async function getForUrl(mcpName: string, serverUrl: string): Promise<Entry | undefined> {
  38. const entry = await get(mcpName)
  39. if (!entry) return undefined
  40. // If no serverUrl is stored, this is from an old version - consider it invalid
  41. if (!entry.serverUrl) return undefined
  42. // If URL has changed, credentials are invalid
  43. if (entry.serverUrl !== serverUrl) return undefined
  44. return entry
  45. }
  46. export async function all(): Promise<Record<string, Entry>> {
  47. const file = Bun.file(filepath)
  48. return file.json().catch(() => ({}))
  49. }
  50. export async function set(mcpName: string, entry: Entry, serverUrl?: string): Promise<void> {
  51. const file = Bun.file(filepath)
  52. const data = await all()
  53. // Always update serverUrl if provided
  54. if (serverUrl) {
  55. entry.serverUrl = serverUrl
  56. }
  57. await Bun.write(file, JSON.stringify({ ...data, [mcpName]: entry }, null, 2))
  58. await fs.chmod(file.name!, 0o600)
  59. }
  60. export async function remove(mcpName: string): Promise<void> {
  61. const file = Bun.file(filepath)
  62. const data = await all()
  63. delete data[mcpName]
  64. await Bun.write(file, JSON.stringify(data, null, 2))
  65. await fs.chmod(file.name!, 0o600)
  66. }
  67. export async function updateTokens(mcpName: string, tokens: Tokens, serverUrl?: string): Promise<void> {
  68. const entry = (await get(mcpName)) ?? {}
  69. entry.tokens = tokens
  70. await set(mcpName, entry, serverUrl)
  71. }
  72. export async function updateClientInfo(mcpName: string, clientInfo: ClientInfo, serverUrl?: string): Promise<void> {
  73. const entry = (await get(mcpName)) ?? {}
  74. entry.clientInfo = clientInfo
  75. await set(mcpName, entry, serverUrl)
  76. }
  77. export async function updateCodeVerifier(mcpName: string, codeVerifier: string): Promise<void> {
  78. const entry = (await get(mcpName)) ?? {}
  79. entry.codeVerifier = codeVerifier
  80. await set(mcpName, entry)
  81. }
  82. export async function clearCodeVerifier(mcpName: string): Promise<void> {
  83. const entry = await get(mcpName)
  84. if (entry) {
  85. delete entry.codeVerifier
  86. await set(mcpName, entry)
  87. }
  88. }
  89. export async function updateOAuthState(mcpName: string, oauthState: string): Promise<void> {
  90. const entry = (await get(mcpName)) ?? {}
  91. entry.oauthState = oauthState
  92. await set(mcpName, entry)
  93. }
  94. export async function getOAuthState(mcpName: string): Promise<string | undefined> {
  95. const entry = await get(mcpName)
  96. return entry?.oauthState
  97. }
  98. export async function clearOAuthState(mcpName: string): Promise<void> {
  99. const entry = await get(mcpName)
  100. if (entry) {
  101. delete entry.oauthState
  102. await set(mcpName, entry)
  103. }
  104. }
  105. /**
  106. * Check if stored tokens are expired.
  107. * Returns null if no tokens exist, false if no expiry or not expired, true if expired.
  108. */
  109. export async function isTokenExpired(mcpName: string): Promise<boolean | null> {
  110. const entry = await get(mcpName)
  111. if (!entry?.tokens) return null
  112. if (!entry.tokens.expiresAt) return false
  113. return entry.tokens.expiresAt < Date.now() / 1000
  114. }
  115. }