// oauth-browser.test.ts (7.3 KB)

  1. import { test, expect, mock, beforeEach } from "bun:test"
  2. import { EventEmitter } from "events"
  3. // Track open() calls and control failure behavior
  4. let openShouldFail = false
  5. let openCalledWith: string | undefined
  6. mock.module("open", () => ({
  7. default: async (url: string) => {
  8. openCalledWith = url
  9. // Return a mock subprocess that emits an error if openShouldFail is true
  10. const subprocess = new EventEmitter()
  11. if (openShouldFail) {
  12. // Emit error asynchronously like a real subprocess would
  13. setTimeout(() => {
  14. subprocess.emit("error", new Error("spawn xdg-open ENOENT"))
  15. }, 10)
  16. }
  17. return subprocess
  18. },
  19. }))
  20. // Mock UnauthorizedError
  21. class MockUnauthorizedError extends Error {
  22. constructor() {
  23. super("Unauthorized")
  24. this.name = "UnauthorizedError"
  25. }
  26. }
  27. // Track what options were passed to each transport constructor
  28. const transportCalls: Array<{
  29. type: "streamable" | "sse"
  30. url: string
  31. options: { authProvider?: unknown }
  32. }> = []
  33. // Mock the transport constructors
  34. mock.module("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({
  35. StreamableHTTPClientTransport: class MockStreamableHTTP {
  36. url: string
  37. authProvider: { redirectToAuthorization?: (url: URL) => Promise<void> } | undefined
  38. constructor(url: URL, options?: { authProvider?: { redirectToAuthorization?: (url: URL) => Promise<void> } }) {
  39. this.url = url.toString()
  40. this.authProvider = options?.authProvider
  41. transportCalls.push({
  42. type: "streamable",
  43. url: url.toString(),
  44. options: options ?? {},
  45. })
  46. }
  47. async start() {
  48. // Simulate OAuth redirect by calling the authProvider's redirectToAuthorization
  49. if (this.authProvider?.redirectToAuthorization) {
  50. await this.authProvider.redirectToAuthorization(new URL("https://auth.example.com/authorize?client_id=test"))
  51. }
  52. throw new MockUnauthorizedError()
  53. }
  54. async finishAuth(_code: string) {
  55. // Mock successful auth completion
  56. }
  57. },
  58. }))
  59. mock.module("@modelcontextprotocol/sdk/client/sse.js", () => ({
  60. SSEClientTransport: class MockSSE {
  61. constructor(url: URL) {
  62. transportCalls.push({
  63. type: "sse",
  64. url: url.toString(),
  65. options: {},
  66. })
  67. }
  68. async start() {
  69. throw new Error("Mock SSE transport cannot connect")
  70. }
  71. },
  72. }))
  73. // Mock the MCP SDK Client to trigger OAuth flow
  74. mock.module("@modelcontextprotocol/sdk/client/index.js", () => ({
  75. Client: class MockClient {
  76. async connect(transport: { start: () => Promise<void> }) {
  77. await transport.start()
  78. }
  79. },
  80. }))
  81. // Mock UnauthorizedError in the auth module
  82. mock.module("@modelcontextprotocol/sdk/client/auth.js", () => ({
  83. UnauthorizedError: MockUnauthorizedError,
  84. }))
  85. beforeEach(() => {
  86. openShouldFail = false
  87. openCalledWith = undefined
  88. transportCalls.length = 0
  89. })
  90. // Import modules after mocking
  91. const { MCP } = await import("../../src/mcp/index")
  92. const { Bus } = await import("../../src/bus")
  93. const { McpOAuthCallback } = await import("../../src/mcp/oauth-callback")
  94. const { Instance } = await import("../../src/project/instance")
  95. const { tmpdir } = await import("../fixture/fixture")
  96. test("BrowserOpenFailed event is published when open() throws", async () => {
  97. await using tmp = await tmpdir({
  98. init: async (dir) => {
  99. await Bun.write(
  100. `${dir}/opencode.json`,
  101. JSON.stringify({
  102. $schema: "https://opencode.ai/config.json",
  103. mcp: {
  104. "test-oauth-server": {
  105. type: "remote",
  106. url: "https://example.com/mcp",
  107. },
  108. },
  109. }),
  110. )
  111. },
  112. })
  113. await Instance.provide({
  114. directory: tmp.path,
  115. fn: async () => {
  116. openShouldFail = true
  117. const events: Array<{ mcpName: string; url: string }> = []
  118. const unsubscribe = Bus.subscribe(MCP.BrowserOpenFailed, (evt) => {
  119. events.push(evt.properties)
  120. })
  121. // Run authenticate with a timeout to avoid waiting forever for the callback
  122. // Attach a handler immediately so callback shutdown rejections
  123. // don't show up as unhandled between tests.
  124. const authPromise = MCP.authenticate("test-oauth-server").catch(() => undefined)
  125. // Config.get() can be slow in tests, so give it plenty of time.
  126. await new Promise((resolve) => setTimeout(resolve, 2_000))
  127. // Stop the callback server and cancel any pending auth
  128. await McpOAuthCallback.stop()
  129. await authPromise
  130. unsubscribe()
  131. // Verify the BrowserOpenFailed event was published
  132. expect(events.length).toBe(1)
  133. expect(events[0].mcpName).toBe("test-oauth-server")
  134. expect(events[0].url).toContain("https://")
  135. },
  136. })
  137. })
  138. test("BrowserOpenFailed event is NOT published when open() succeeds", async () => {
  139. await using tmp = await tmpdir({
  140. init: async (dir) => {
  141. await Bun.write(
  142. `${dir}/opencode.json`,
  143. JSON.stringify({
  144. $schema: "https://opencode.ai/config.json",
  145. mcp: {
  146. "test-oauth-server-2": {
  147. type: "remote",
  148. url: "https://example.com/mcp",
  149. },
  150. },
  151. }),
  152. )
  153. },
  154. })
  155. await Instance.provide({
  156. directory: tmp.path,
  157. fn: async () => {
  158. openShouldFail = false
  159. const events: Array<{ mcpName: string; url: string }> = []
  160. const unsubscribe = Bus.subscribe(MCP.BrowserOpenFailed, (evt) => {
  161. events.push(evt.properties)
  162. })
  163. // Run authenticate with a timeout to avoid waiting forever for the callback
  164. const authPromise = MCP.authenticate("test-oauth-server-2").catch(() => undefined)
  165. // Config.get() can be slow in tests; also covers the ~500ms open() error-detection window.
  166. await new Promise((resolve) => setTimeout(resolve, 2_000))
  167. // Stop the callback server and cancel any pending auth
  168. await McpOAuthCallback.stop()
  169. await authPromise
  170. unsubscribe()
  171. // Verify NO BrowserOpenFailed event was published
  172. expect(events.length).toBe(0)
  173. // Verify open() was still called
  174. expect(openCalledWith).toBeDefined()
  175. },
  176. })
  177. })
  178. test("open() is called with the authorization URL", async () => {
  179. await using tmp = await tmpdir({
  180. init: async (dir) => {
  181. await Bun.write(
  182. `${dir}/opencode.json`,
  183. JSON.stringify({
  184. $schema: "https://opencode.ai/config.json",
  185. mcp: {
  186. "test-oauth-server-3": {
  187. type: "remote",
  188. url: "https://example.com/mcp",
  189. },
  190. },
  191. }),
  192. )
  193. },
  194. })
  195. await Instance.provide({
  196. directory: tmp.path,
  197. fn: async () => {
  198. openShouldFail = false
  199. openCalledWith = undefined
  200. // Run authenticate with a timeout to avoid waiting forever for the callback
  201. const authPromise = MCP.authenticate("test-oauth-server-3").catch(() => undefined)
  202. // Config.get() can be slow in tests; also covers the ~500ms open() error-detection window.
  203. await new Promise((resolve) => setTimeout(resolve, 2_000))
  204. // Stop the callback server and cancel any pending auth
  205. await McpOAuthCallback.stop()
  206. await authPromise
  207. // Verify open was called with a URL
  208. expect(openCalledWith).toBeDefined()
  209. expect(typeof openCalledWith).toBe("string")
  210. expect(openCalledWith!).toContain("https://")
  211. },
  212. })
  213. })