streamable.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. package mcpproxy
  2. import (
  3. "bufio"
  4. "context"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "time"
  10. )
  11. // StreamableProxy represents a proxy for the MCP Streamable HTTP transport
  12. type StreamableProxy struct {
  13. store SessionManager
  14. backend string
  15. headers map[string]string
  16. }
  17. // NewStreamableProxy creates a new proxy for the Streamable HTTP transport
  18. func NewStreamableProxy(
  19. backend string,
  20. headers map[string]string,
  21. store SessionManager,
  22. ) *StreamableProxy {
  23. return &StreamableProxy{
  24. store: store,
  25. backend: backend,
  26. headers: headers,
  27. }
  28. }
  29. // ServeHTTP handles both GET and POST requests for the Streamable HTTP transport
  30. func (p *StreamableProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  31. // Add CORS headers
  32. w.Header().Set("Access-Control-Allow-Origin", "*")
  33. w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
  34. w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id")
  35. w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id")
  36. // Handle preflight requests
  37. if r.Method == http.MethodOptions {
  38. w.WriteHeader(http.StatusOK)
  39. return
  40. }
  41. switch r.Method {
  42. case http.MethodGet:
  43. p.handleGetRequest(w, r)
  44. case http.MethodPost:
  45. p.handlePostRequest(w, r)
  46. case http.MethodDelete:
  47. p.handleDeleteRequest(w, r)
  48. default:
  49. http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
  50. }
  51. }
  52. // handleGetRequest handles GET requests for SSE streaming
  53. func (p *StreamableProxy) handleGetRequest(w http.ResponseWriter, r *http.Request) {
  54. // Check if Accept header includes text/event-stream
  55. acceptHeader := r.Header.Get("Accept")
  56. if !strings.Contains(acceptHeader, "text/event-stream") {
  57. http.Error(w, "Accept header must include text/event-stream", http.StatusBadRequest)
  58. return
  59. }
  60. // Get proxy session ID from header
  61. proxySessionID := r.Header.Get("Mcp-Session-Id")
  62. if proxySessionID == "" {
  63. // This might be an initialization request
  64. p.proxyInitialOrNoSessionRequest(w, r)
  65. return
  66. }
  67. // Look up the backend endpoint and session ID
  68. backendInfo, ok := p.store.Get(proxySessionID)
  69. if !ok {
  70. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  71. return
  72. }
  73. // Create a request to the backend
  74. req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, backendInfo, nil)
  75. if err != nil {
  76. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  77. return
  78. }
  79. // Copy headers from original request, but replace the session ID
  80. for name, values := range r.Header {
  81. if name == "Mcp-Session-Id" {
  82. continue // Skip the proxy session ID
  83. }
  84. for _, value := range values {
  85. req.Header.Add(name, value)
  86. }
  87. }
  88. // Extract the real backend session ID from the stored URL
  89. parts := strings.Split(backendInfo, "?sessionId=")
  90. if len(parts) > 1 {
  91. req.Header.Set("Mcp-Session-Id", parts[1])
  92. }
  93. // Add any additional headers
  94. for name, value := range p.headers {
  95. req.Header.Set(name, value)
  96. }
  97. //nolint:bodyclose
  98. resp, err := http.DefaultClient.Do(req)
  99. if err != nil {
  100. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  101. return
  102. }
  103. defer resp.Body.Close()
  104. // Check if we got an SSE response
  105. if resp.StatusCode != http.StatusOK ||
  106. !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
  107. // Copy response headers, but not the backend session ID
  108. for name, values := range resp.Header {
  109. if name == "Mcp-Session-Id" {
  110. continue
  111. }
  112. for _, value := range values {
  113. w.Header().Add(name, value)
  114. }
  115. }
  116. // Add our proxy session ID
  117. w.Header().Set("Mcp-Session-Id", proxySessionID)
  118. w.WriteHeader(resp.StatusCode)
  119. _, _ = io.Copy(w, resp.Body)
  120. return
  121. }
  122. // Set SSE headers for the client response
  123. w.Header().Set("Content-Type", "text/event-stream")
  124. w.Header().Set("Cache-Control", "no-cache")
  125. w.Header().Set("Connection", "keep-alive")
  126. // Create a context that cancels when the client disconnects
  127. ctx, cancel := context.WithCancel(r.Context())
  128. defer cancel()
  129. // Monitor client disconnection
  130. go func() {
  131. <-ctx.Done()
  132. resp.Body.Close()
  133. }()
  134. // Stream the SSE events to the client
  135. reader := bufio.NewReader(resp.Body)
  136. flusher, ok := w.(http.Flusher)
  137. if !ok {
  138. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  139. return
  140. }
  141. for {
  142. line, err := reader.ReadString('\n')
  143. if err != nil {
  144. if err == io.EOF {
  145. break
  146. }
  147. return
  148. }
  149. // Write the line to the client
  150. fmt.Fprint(w, line)
  151. flusher.Flush()
  152. }
  153. }
  154. // handlePostRequest handles POST requests for JSON-RPC messages
  155. func (p *StreamableProxy) handlePostRequest(w http.ResponseWriter, r *http.Request) {
  156. // Check if this is an initialization request
  157. proxySessionID := r.Header.Get("Mcp-Session-Id")
  158. if proxySessionID == "" {
  159. p.proxyInitialOrNoSessionRequest(w, r)
  160. return
  161. }
  162. // Look up the backend endpoint and session ID
  163. backendInfo, ok := p.store.Get(proxySessionID)
  164. if !ok {
  165. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  166. return
  167. }
  168. // Create a request to the backend
  169. req, err := http.NewRequestWithContext(r.Context(), http.MethodPost, backendInfo, r.Body)
  170. if err != nil {
  171. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  172. return
  173. }
  174. // Copy headers from original request, but replace the session ID
  175. for name, values := range r.Header {
  176. if name == "Mcp-Session-Id" {
  177. continue // Skip the proxy session ID
  178. }
  179. for _, value := range values {
  180. req.Header.Add(name, value)
  181. }
  182. }
  183. // Extract the real backend session ID from the stored URL
  184. parts := strings.Split(backendInfo, "?sessionId=")
  185. if len(parts) > 1 {
  186. req.Header.Set("Mcp-Session-Id", parts[1])
  187. }
  188. // Add any additional headers
  189. for name, value := range p.headers {
  190. req.Header.Set(name, value)
  191. }
  192. //nolint:bodyclose
  193. resp, err := http.DefaultClient.Do(req)
  194. if err != nil {
  195. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  196. return
  197. }
  198. defer resp.Body.Close()
  199. // Copy response headers, but not the backend session ID
  200. for name, values := range resp.Header {
  201. if name == "Mcp-Session-Id" {
  202. continue
  203. }
  204. for _, value := range values {
  205. w.Header().Add(name, value)
  206. }
  207. }
  208. // Add our proxy session ID
  209. w.Header().Set("Mcp-Session-Id", proxySessionID)
  210. // Set response status code
  211. w.WriteHeader(resp.StatusCode)
  212. // Check if the response is an SSE stream
  213. if strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
  214. // Handle SSE response
  215. reader := bufio.NewReader(resp.Body)
  216. flusher, ok := w.(http.Flusher)
  217. if !ok {
  218. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  219. return
  220. }
  221. // Create a context that cancels when the client disconnects
  222. ctx, cancel := context.WithCancel(r.Context())
  223. defer cancel()
  224. // Monitor client disconnection
  225. go func() {
  226. <-ctx.Done()
  227. resp.Body.Close()
  228. }()
  229. for {
  230. line, err := reader.ReadString('\n')
  231. if err != nil {
  232. if err == io.EOF {
  233. break
  234. }
  235. return
  236. }
  237. // Write the line to the client
  238. _, _ = fmt.Fprint(w, line)
  239. flusher.Flush()
  240. }
  241. } else {
  242. // Copy regular response body
  243. _, _ = io.Copy(w, resp.Body)
  244. }
  245. }
  246. // handleDeleteRequest handles DELETE requests for session termination
  247. func (p *StreamableProxy) handleDeleteRequest(w http.ResponseWriter, r *http.Request) {
  248. // Get proxy session ID from header
  249. proxySessionID := r.Header.Get("Mcp-Session-Id")
  250. if proxySessionID == "" {
  251. http.Error(w, "Missing session ID", http.StatusBadRequest)
  252. return
  253. }
  254. // Look up the backend endpoint and session ID
  255. backendInfo, ok := p.store.Get(proxySessionID)
  256. if !ok {
  257. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  258. return
  259. }
  260. // Create a request to the backend
  261. req, err := http.NewRequestWithContext(r.Context(), http.MethodDelete, backendInfo, nil)
  262. if err != nil {
  263. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  264. return
  265. }
  266. // Copy headers from original request, but replace the session ID
  267. for name, values := range r.Header {
  268. if name == "Mcp-Session-Id" {
  269. continue // Skip the proxy session ID
  270. }
  271. for _, value := range values {
  272. req.Header.Add(name, value)
  273. }
  274. }
  275. // Extract the real backend session ID from the stored URL
  276. parts := strings.Split(backendInfo, "?sessionId=")
  277. if len(parts) > 1 {
  278. req.Header.Set("Mcp-Session-Id", parts[1])
  279. }
  280. // Add any additional headers
  281. for name, value := range p.headers {
  282. req.Header.Set(name, value)
  283. }
  284. // Make the request to the backend
  285. client := &http.Client{
  286. Timeout: time.Second * 10,
  287. }
  288. resp, err := client.Do(req)
  289. if err != nil {
  290. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  291. return
  292. }
  293. defer resp.Body.Close()
  294. // Remove the session from our store
  295. p.store.Delete(proxySessionID)
  296. // Copy response headers, but not the backend session ID
  297. for name, values := range resp.Header {
  298. if name == "Mcp-Session-Id" {
  299. continue
  300. }
  301. for _, value := range values {
  302. w.Header().Add(name, value)
  303. }
  304. }
  305. // Set response status code
  306. w.WriteHeader(resp.StatusCode)
  307. // Copy response body
  308. _, _ = io.Copy(w, resp.Body)
  309. }
  310. // proxyInitialOrNoSessionRequest handles the initial request that doesn't have a session ID yet
  311. func (p *StreamableProxy) proxyInitialOrNoSessionRequest(w http.ResponseWriter, r *http.Request) {
  312. // Create a request to the backend
  313. req, err := http.NewRequestWithContext(r.Context(), r.Method, p.backend, r.Body)
  314. if err != nil {
  315. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  316. return
  317. }
  318. // Copy headers from original request
  319. for name, values := range r.Header {
  320. for _, value := range values {
  321. req.Header.Add(name, value)
  322. }
  323. }
  324. // Add any additional headers
  325. for name, value := range p.headers {
  326. req.Header.Set(name, value)
  327. }
  328. //nolint:bodyclose
  329. resp, err := http.DefaultClient.Do(req)
  330. if err != nil {
  331. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  332. return
  333. }
  334. defer resp.Body.Close()
  335. // Check if we received a session ID from the backend
  336. backendSessionID := resp.Header.Get("Mcp-Session-Id")
  337. if backendSessionID != "" {
  338. // Generate a new proxy session ID
  339. proxySessionID := p.store.New()
  340. // Store the mapping between our proxy session ID and the backend endpoint with its session
  341. // ID
  342. backendURL := p.backend
  343. if strings.Contains(backendURL, "?") {
  344. backendURL += "&sessionId=" + backendSessionID
  345. } else {
  346. backendURL += "?sessionId=" + backendSessionID
  347. }
  348. p.store.Set(proxySessionID, backendURL)
  349. // Replace the backend session ID with our proxy session ID in the response
  350. w.Header().Set("Mcp-Session-Id", proxySessionID)
  351. }
  352. // Copy other response headers
  353. for name, values := range resp.Header {
  354. if name != "Mcp-Session-Id" { // Skip the original session ID
  355. for _, value := range values {
  356. w.Header().Add(name, value)
  357. }
  358. }
  359. }
  360. // Set response status code
  361. w.WriteHeader(resp.StatusCode)
  362. // Check if the response is an SSE stream
  363. if strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
  364. // Handle SSE response
  365. reader := bufio.NewReader(resp.Body)
  366. flusher, ok := w.(http.Flusher)
  367. if !ok {
  368. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  369. return
  370. }
  371. // Create a context that cancels when the client disconnects
  372. ctx, cancel := context.WithCancel(r.Context())
  373. defer cancel()
  374. // Monitor client disconnection
  375. go func() {
  376. <-ctx.Done()
  377. resp.Body.Close()
  378. }()
  379. for {
  380. line, err := reader.ReadString('\n')
  381. if err != nil {
  382. if err == io.EOF {
  383. break
  384. }
  385. return
  386. }
  387. // Write the line to the client
  388. fmt.Fprint(w, line)
  389. flusher.Flush()
  390. }
  391. } else {
  392. // Copy regular response body
  393. _, _ = io.Copy(w, resp.Body)
  394. }
  395. }