auth.ts 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import path from "path"
  2. import z from "zod"
  3. import { Global } from "../global"
  /**
   * File-backed store for per-MCP-server auth data.
   *
   * Entries are keyed by MCP name and kept in a single JSON file
   * (`mcp-auth.json` under `Global.Path.data`). Every mutation is a plain
   * read-modify-write of the whole file; there is no locking, so concurrent
   * writers can overwrite each other's changes.
   */
  export namespace McpAuth {
    /** Token set stored for a server. */
    export const Tokens = z.object({
      accessToken: z.string(),
      refreshToken: z.string().optional(),
      // Compared against `Date.now() / 1000` in `isTokenExpired`, i.e. treated as Unix seconds.
      expiresAt: z.number().optional(),
      scope: z.string().optional(),
    })
    export type Tokens = z.infer<typeof Tokens>

    /** Client registration details stored for a server. */
    export const ClientInfo = z.object({
      clientId: z.string(),
      clientSecret: z.string().optional(),
      clientIdIssuedAt: z.number().optional(),
      clientSecretExpiresAt: z.number().optional(),
    })
    export type ClientInfo = z.infer<typeof ClientInfo>

    /** Everything persisted for one MCP name. All fields are optional. */
    export const Entry = z.object({
      tokens: Tokens.optional(),
      clientInfo: ClientInfo.optional(),
      codeVerifier: z.string().optional(),
      oauthState: z.string().optional(),
      serverUrl: z.string().optional(), // Track the URL these credentials are for
    })
    export type Entry = z.infer<typeof Entry>

    const filepath = path.join(Global.Path.data, "mcp-auth.json")

    /**
     * Look up the stored entry for an MCP name.
     *
     * @returns The entry, or `undefined` when nothing is stored under that name.
     */
    export async function get(mcpName: string): Promise<Entry | undefined> {
      const data = await all()
      // Own-property check: a name such as "constructor" or "toString" must not
      // resolve to a member inherited from Object.prototype.
      return Object.prototype.hasOwnProperty.call(data, mcpName) ? data[mcpName] : undefined
    }

    /**
     * Get auth entry and validate it's for the correct URL.
     * Returns undefined if URL has changed (credentials are invalid).
     */
    export async function getForUrl(mcpName: string, serverUrl: string): Promise<Entry | undefined> {
      const entry = await get(mcpName)
      if (!entry) return undefined
      // If no serverUrl is stored, this is from an old version - consider it invalid
      if (!entry.serverUrl) return undefined
      // If URL has changed, credentials are invalid
      if (entry.serverUrl !== serverUrl) return undefined
      return entry
    }

    /**
     * Read every stored entry.
     *
     * Best-effort: any read or parse failure, and any file whose top-level JSON
     * value is not an object, yields an empty record instead of throwing.
     * Individual entries are NOT validated against the `Entry` schema.
     */
    export async function all(): Promise<Record<string, Entry>> {
      const parsed: unknown = await Bun.file(filepath)
        .json()
        .catch(() => undefined)
      if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) return {}
      // Shape of the individual values is trusted, matching the previous behavior.
      return parsed as Record<string, Entry>
    }

    /**
     * Store `entry` under `mcpName`, replacing any previous entry for that name.
     *
     * @param serverUrl When non-empty, it is recorded as the entry's `serverUrl`,
     *   overriding whatever `entry.serverUrl` holds. The caller's `entry` object
     *   is not modified.
     */
    export async function set(mcpName: string, entry: Entry, serverUrl?: string): Promise<void> {
      const data = await all()
      // Always update serverUrl if provided (on a copy, so the argument stays untouched)
      const stored: Entry = serverUrl ? { ...entry, serverUrl } : entry
      // Mode 0o600 = owner read/write only; the file holds secrets.
      await Bun.write(Bun.file(filepath), JSON.stringify({ ...data, [mcpName]: stored }, null, 2), { mode: 0o600 })
    }

    /**
     * Delete the entry stored under `mcpName` and rewrite the file.
     * The file is rewritten even when no such entry existed.
     */
    export async function remove(mcpName: string): Promise<void> {
      const data = await all()
      delete data[mcpName]
      await Bun.write(Bun.file(filepath), JSON.stringify(data, null, 2), { mode: 0o600 })
    }

    /** Merge `fields` into the entry for `mcpName` (creating it if absent) and persist it. */
    async function patch(mcpName: string, fields: Partial<Entry>, serverUrl?: string): Promise<void> {
      const current = (await get(mcpName)) ?? {}
      await set(mcpName, { ...current, ...fields }, serverUrl)
    }

    /** Replace the stored tokens for `mcpName`, optionally recording `serverUrl`. */
    export async function updateTokens(mcpName: string, tokens: Tokens, serverUrl?: string): Promise<void> {
      await patch(mcpName, { tokens }, serverUrl)
    }

    /** Replace the stored client info for `mcpName`, optionally recording `serverUrl`. */
    export async function updateClientInfo(mcpName: string, clientInfo: ClientInfo, serverUrl?: string): Promise<void> {
      await patch(mcpName, { clientInfo }, serverUrl)
    }

    /** Store `codeVerifier` on the entry for `mcpName`. */
    export async function updateCodeVerifier(mcpName: string, codeVerifier: string): Promise<void> {
      await patch(mcpName, { codeVerifier })
    }

    /** Remove the stored code verifier. Does nothing (no write) when no entry exists. */
    export async function clearCodeVerifier(mcpName: string): Promise<void> {
      const entry = await get(mcpName)
      if (!entry) return
      const next: Entry = { ...entry }
      delete next.codeVerifier
      await set(mcpName, next)
    }

    /** Store `oauthState` on the entry for `mcpName`. */
    export async function updateOAuthState(mcpName: string, oauthState: string): Promise<void> {
      await patch(mcpName, { oauthState })
    }

    /** Return the stored `oauthState` for `mcpName`, or `undefined` when none is stored. */
    export async function getOAuthState(mcpName: string): Promise<string | undefined> {
      const entry = await get(mcpName)
      return entry?.oauthState
    }

    /** Remove the stored OAuth state. Does nothing (no write) when no entry exists. */
    export async function clearOAuthState(mcpName: string): Promise<void> {
      const entry = await get(mcpName)
      if (!entry) return
      const next: Entry = { ...entry }
      delete next.oauthState
      await set(mcpName, next)
    }

    /**
     * Check if stored tokens are expired.
     * Returns null if no tokens exist, false if no expiry or not expired, true if expired.
     */
    export async function isTokenExpired(mcpName: string): Promise<boolean | null> {
      const entry = await get(mcpName)
      if (!entry?.tokens) return null
      const { expiresAt } = entry.tokens
      // `== null` rather than a truthiness test: an expiry of 0 is a real
      // (long past) timestamp, not "no expiry". Also tolerates a JSON `null`.
      if (expiresAt == null) return false
      return expiresAt < Date.now() / 1000
    }
  }