streamable.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. package mcpproxy
  2. import (
  3. "bufio"
  4. "context"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "time"
  10. )
  11. // StreamableProxy represents a proxy for the MCP Streamable HTTP transport
  12. type StreamableProxy struct {
  13. store SessionManager
  14. backend string
  15. headers map[string]string
  16. }
  17. // NewStreamableProxy creates a new proxy for the Streamable HTTP transport
  18. func NewStreamableProxy(
  19. backend string,
  20. headers map[string]string,
  21. store SessionManager,
  22. ) *StreamableProxy {
  23. return &StreamableProxy{
  24. store: store,
  25. backend: backend,
  26. headers: headers,
  27. }
  28. }
  29. // ServeHTTP handles both GET and POST requests for the Streamable HTTP transport
  30. func (p *StreamableProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  31. // Add CORS headers
  32. w.Header().Set("Access-Control-Allow-Origin", "*")
  33. w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
  34. w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id")
  35. w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id")
  36. // Handle preflight requests
  37. if r.Method == http.MethodOptions {
  38. w.WriteHeader(http.StatusOK)
  39. return
  40. }
  41. switch r.Method {
  42. case http.MethodGet:
  43. p.handleGetRequest(w, r)
  44. case http.MethodPost:
  45. p.handlePostRequest(w, r)
  46. case http.MethodDelete:
  47. p.handleDeleteRequest(w, r)
  48. default:
  49. http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
  50. }
  51. }
  52. // handleGetRequest handles GET requests for SSE streaming
  53. func (p *StreamableProxy) handleGetRequest(w http.ResponseWriter, r *http.Request) {
  54. // Check if Accept header includes text/event-stream
  55. acceptHeader := r.Header.Get("Accept")
  56. if !strings.Contains(acceptHeader, "text/event-stream") {
  57. http.Error(w, "Accept header must include text/event-stream", http.StatusBadRequest)
  58. return
  59. }
  60. // Get proxy session ID from header
  61. proxySessionID := r.Header.Get("Mcp-Session-Id")
  62. if proxySessionID == "" {
  63. // This might be an initialization request
  64. p.proxyInitialOrNoSessionRequest(w, r)
  65. return
  66. }
  67. // Look up the backend endpoint and session ID
  68. backendInfo, ok := p.store.Get(proxySessionID)
  69. if !ok {
  70. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  71. return
  72. }
  73. // Create a request to the backend
  74. req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, backendInfo, nil)
  75. if err != nil {
  76. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  77. return
  78. }
  79. // Copy headers from original request, but replace the session ID
  80. for name, values := range r.Header {
  81. if name == "Mcp-Session-Id" {
  82. continue // Skip the proxy session ID
  83. }
  84. for _, value := range values {
  85. req.Header.Add(name, value)
  86. }
  87. }
  88. // Extract the real backend session ID from the stored URL
  89. parts := strings.Split(backendInfo, "?sessionId=")
  90. if len(parts) > 1 {
  91. req.Header.Set("Mcp-Session-Id", parts[1])
  92. }
  93. // Add any additional headers
  94. for name, value := range p.headers {
  95. req.Header.Set(name, value)
  96. }
  97. //nolint:bodyclose
  98. resp, err := http.DefaultClient.Do(req)
  99. if err != nil {
  100. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  101. return
  102. }
  103. defer resp.Body.Close()
  104. // Check if we got an SSE response
  105. if resp.StatusCode != http.StatusOK ||
  106. !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
  107. // Copy response headers, but not the backend session ID
  108. for name, values := range resp.Header {
  109. if name == "Mcp-Session-Id" {
  110. continue
  111. }
  112. for _, value := range values {
  113. w.Header().Add(name, value)
  114. }
  115. }
  116. // Add our proxy session ID
  117. w.Header().Set("Mcp-Session-Id", proxySessionID)
  118. w.WriteHeader(resp.StatusCode)
  119. _, _ = io.Copy(w, resp.Body)
  120. return
  121. }
  122. // Set SSE headers for the client response
  123. w.Header().Set("Content-Type", "text/event-stream")
  124. w.Header().Set("Cache-Control", "no-cache")
  125. w.Header().Set("Connection", "keep-alive")
  126. // Create a context that cancels when the client disconnects
  127. ctx, cancel := context.WithCancel(r.Context())
  128. defer cancel()
  129. // Monitor client disconnection
  130. go func() {
  131. <-ctx.Done()
  132. resp.Body.Close()
  133. }()
  134. // Stream the SSE events to the client
  135. reader := bufio.NewReader(resp.Body)
  136. flusher, ok := w.(http.Flusher)
  137. if !ok {
  138. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  139. return
  140. }
  141. for {
  142. line, err := reader.ReadString('\n')
  143. if err != nil {
  144. if err == io.EOF {
  145. break
  146. }
  147. return
  148. }
  149. // Write the line to the client
  150. fmt.Fprint(w, line)
  151. flusher.Flush()
  152. }
  153. }
  154. // handlePostRequest handles POST requests for JSON-RPC messages
  155. func (p *StreamableProxy) handlePostRequest(w http.ResponseWriter, r *http.Request) {
  156. // Check if this is an initialization request
  157. proxySessionID := r.Header.Get("Mcp-Session-Id")
  158. if proxySessionID == "" {
  159. p.proxyInitialOrNoSessionRequest(w, r)
  160. return
  161. }
  162. // Look up the backend endpoint and session ID
  163. backendInfo, ok := p.store.Get(proxySessionID)
  164. if !ok {
  165. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  166. return
  167. }
  168. // Create a request to the backend
  169. req, err := http.NewRequestWithContext(r.Context(), http.MethodPost, backendInfo, r.Body)
  170. if err != nil {
  171. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  172. return
  173. }
  174. // Copy headers from original request, but replace the session ID
  175. for name, values := range r.Header {
  176. if name == "Mcp-Session-Id" {
  177. continue // Skip the proxy session ID
  178. }
  179. for _, value := range values {
  180. req.Header.Add(name, value)
  181. }
  182. }
  183. // Extract the real backend session ID from the stored URL
  184. parts := strings.Split(backendInfo, "?sessionId=")
  185. if len(parts) > 1 {
  186. req.Header.Set("Mcp-Session-Id", parts[1])
  187. }
  188. // Add any additional headers
  189. for name, value := range p.headers {
  190. req.Header.Set(name, value)
  191. }
  192. //nolint:bodyclose
  193. resp, err := http.DefaultClient.Do(req)
  194. if err != nil {
  195. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  196. return
  197. }
  198. defer resp.Body.Close()
  199. // Copy response headers, but not the backend session ID
  200. for name, values := range resp.Header {
  201. if name == "Mcp-Session-Id" {
  202. continue
  203. }
  204. for _, value := range values {
  205. w.Header().Add(name, value)
  206. }
  207. }
  208. // Add our proxy session ID
  209. w.Header().Set("Mcp-Session-Id", proxySessionID)
  210. contentType := resp.Header.Get("Content-Type")
  211. w.Header().Set("Content-Type", contentType)
  212. // Set response status code
  213. w.WriteHeader(resp.StatusCode)
  214. // Check if the response is an SSE stream
  215. if strings.Contains(contentType, "text/event-stream") {
  216. // Handle SSE response
  217. reader := bufio.NewReader(resp.Body)
  218. flusher, ok := w.(http.Flusher)
  219. if !ok {
  220. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  221. return
  222. }
  223. // Create a context that cancels when the client disconnects
  224. ctx, cancel := context.WithCancel(r.Context())
  225. defer cancel()
  226. // Monitor client disconnection
  227. go func() {
  228. <-ctx.Done()
  229. resp.Body.Close()
  230. }()
  231. for {
  232. line, err := reader.ReadString('\n')
  233. if err != nil {
  234. if err == io.EOF {
  235. break
  236. }
  237. return
  238. }
  239. // Write the line to the client
  240. _, _ = fmt.Fprint(w, line)
  241. flusher.Flush()
  242. }
  243. } else {
  244. // Copy regular response body
  245. _, _ = io.Copy(w, resp.Body)
  246. }
  247. }
  248. // handleDeleteRequest handles DELETE requests for session termination
  249. func (p *StreamableProxy) handleDeleteRequest(w http.ResponseWriter, r *http.Request) {
  250. // Get proxy session ID from header
  251. proxySessionID := r.Header.Get("Mcp-Session-Id")
  252. if proxySessionID == "" {
  253. http.Error(w, "Missing session ID", http.StatusBadRequest)
  254. return
  255. }
  256. // Look up the backend endpoint and session ID
  257. backendInfo, ok := p.store.Get(proxySessionID)
  258. if !ok {
  259. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  260. return
  261. }
  262. // Create a request to the backend
  263. req, err := http.NewRequestWithContext(r.Context(), http.MethodDelete, backendInfo, nil)
  264. if err != nil {
  265. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  266. return
  267. }
  268. // Copy headers from original request, but replace the session ID
  269. for name, values := range r.Header {
  270. if name == "Mcp-Session-Id" {
  271. continue // Skip the proxy session ID
  272. }
  273. for _, value := range values {
  274. req.Header.Add(name, value)
  275. }
  276. }
  277. // Extract the real backend session ID from the stored URL
  278. parts := strings.Split(backendInfo, "?sessionId=")
  279. if len(parts) > 1 {
  280. req.Header.Set("Mcp-Session-Id", parts[1])
  281. }
  282. // Add any additional headers
  283. for name, value := range p.headers {
  284. req.Header.Set(name, value)
  285. }
  286. // Make the request to the backend
  287. client := &http.Client{
  288. Timeout: time.Second * 10,
  289. }
  290. resp, err := client.Do(req)
  291. if err != nil {
  292. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  293. return
  294. }
  295. defer resp.Body.Close()
  296. // Remove the session from our store
  297. p.store.Delete(proxySessionID)
  298. // Copy response headers, but not the backend session ID
  299. for name, values := range resp.Header {
  300. if name == "Mcp-Session-Id" {
  301. continue
  302. }
  303. for _, value := range values {
  304. w.Header().Add(name, value)
  305. }
  306. }
  307. contentType := resp.Header.Get("Content-Type")
  308. w.Header().Set("Content-Type", contentType)
  309. // Set response status code
  310. w.WriteHeader(resp.StatusCode)
  311. // Copy response body
  312. _, _ = io.Copy(w, resp.Body)
  313. }
  314. // proxyInitialOrNoSessionRequest handles the initial request that doesn't have a session ID yet
  315. func (p *StreamableProxy) proxyInitialOrNoSessionRequest(w http.ResponseWriter, r *http.Request) {
  316. // Create a request to the backend
  317. req, err := http.NewRequestWithContext(r.Context(), r.Method, p.backend, r.Body)
  318. if err != nil {
  319. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  320. return
  321. }
  322. // Add any additional headers
  323. for name, value := range p.headers {
  324. req.Header.Set(name, value)
  325. }
  326. //nolint:bodyclose
  327. resp, err := http.DefaultClient.Do(req)
  328. if err != nil {
  329. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  330. return
  331. }
  332. defer resp.Body.Close()
  333. // Check if we received a session ID from the backend
  334. backendSessionID := resp.Header.Get("Mcp-Session-Id")
  335. if backendSessionID != "" {
  336. // Generate a new proxy session ID
  337. proxySessionID := p.store.New()
  338. // Store the mapping between our proxy session ID and the backend endpoint with its session
  339. // ID
  340. backendURL := p.backend
  341. if strings.Contains(backendURL, "?") {
  342. backendURL += "&sessionId=" + backendSessionID
  343. } else {
  344. backendURL += "?sessionId=" + backendSessionID
  345. }
  346. p.store.Set(proxySessionID, backendURL)
  347. // Replace the backend session ID with our proxy session ID in the response
  348. w.Header().Set("Mcp-Session-Id", proxySessionID)
  349. }
  350. contentType := resp.Header.Get("Content-Type")
  351. w.Header().Set("Content-Type", contentType)
  352. // Set response status code
  353. w.WriteHeader(resp.StatusCode)
  354. // Check if the response is an SSE stream
  355. if strings.Contains(contentType, "text/event-stream") {
  356. // Handle SSE response
  357. reader := bufio.NewReader(resp.Body)
  358. flusher, ok := w.(http.Flusher)
  359. if !ok {
  360. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  361. return
  362. }
  363. // Create a context that cancels when the client disconnects
  364. ctx, cancel := context.WithCancel(r.Context())
  365. defer cancel()
  366. // Monitor client disconnection
  367. go func() {
  368. <-ctx.Done()
  369. resp.Body.Close()
  370. }()
  371. for {
  372. line, err := reader.ReadString('\n')
  373. if err != nil {
  374. if err == io.EOF {
  375. break
  376. }
  377. return
  378. }
  379. // Write the line to the client
  380. fmt.Fprint(w, line)
  381. flusher.Flush()
  382. }
  383. } else {
  384. // Copy regular response body
  385. _, _ = io.Copy(w, resp.Body)
  386. }
  387. }