streamable.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. package mcpproxy
  2. import (
  3. "bufio"
  4. "context"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "time"
  10. )
  11. const (
  12. headerKeySessionID = "Mcp-Session-Id"
  13. )
  14. // StreamableProxy represents a proxy for the MCP Streamable HTTP transport
  15. type StreamableProxy struct {
  16. store SessionManager
  17. backend string
  18. headers map[string]string
  19. }
  20. // NewStreamableProxy creates a new proxy for the Streamable HTTP transport
  21. func NewStreamableProxy(
  22. backend string,
  23. headers map[string]string,
  24. store SessionManager,
  25. ) *StreamableProxy {
  26. return &StreamableProxy{
  27. store: store,
  28. backend: backend,
  29. headers: headers,
  30. }
  31. }
  32. // ServeHTTP handles both GET and POST requests for the Streamable HTTP transport
  33. func (p *StreamableProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  34. // Add CORS headers
  35. w.Header().Set("Access-Control-Allow-Origin", "*")
  36. w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
  37. w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id")
  38. w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id")
  39. // Handle preflight requests
  40. if r.Method == http.MethodOptions {
  41. w.WriteHeader(http.StatusOK)
  42. return
  43. }
  44. switch r.Method {
  45. case http.MethodGet:
  46. p.handleGetRequest(w, r)
  47. case http.MethodPost:
  48. p.handlePostRequest(w, r)
  49. case http.MethodDelete:
  50. p.handleDeleteRequest(w, r)
  51. default:
  52. http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
  53. }
  54. }
  55. // handleGetRequest handles GET requests for SSE streaming
  56. func (p *StreamableProxy) handleGetRequest(w http.ResponseWriter, r *http.Request) {
  57. // Check if Accept header includes text/event-stream
  58. acceptHeader := r.Header.Get("Accept")
  59. if !strings.Contains(acceptHeader, "text/event-stream") {
  60. http.Error(w, "Accept header must include text/event-stream", http.StatusBadRequest)
  61. return
  62. }
  63. // Get proxy session ID from header
  64. proxySessionID := r.Header.Get(headerKeySessionID)
  65. if proxySessionID == "" {
  66. // This might be an initialization request
  67. p.proxyInitialOrNoSessionRequest(w, r)
  68. return
  69. }
  70. // Look up the backend endpoint and session ID
  71. backendInfo, ok := p.store.Get(proxySessionID)
  72. if !ok {
  73. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  74. return
  75. }
  76. // Create a request to the backend
  77. req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, backendInfo, nil)
  78. if err != nil {
  79. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  80. return
  81. }
  82. // Extract the real backend session ID from the stored URL
  83. parts := strings.Split(backendInfo, "|sessionId=")
  84. if len(parts) > 1 {
  85. req.Header.Set(headerKeySessionID, parts[1])
  86. }
  87. // Add any additional headers
  88. for name, value := range p.headers {
  89. req.Header.Set(name, value)
  90. }
  91. req.Header.Set("Content-Type", r.Header.Get("Content-Type"))
  92. //nolint:bodyclose
  93. resp, err := http.DefaultClient.Do(req)
  94. if err != nil {
  95. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  96. return
  97. }
  98. defer resp.Body.Close()
  99. // Check if we got an SSE response
  100. if resp.StatusCode != http.StatusOK ||
  101. !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
  102. // Add our proxy session ID
  103. w.Header().Set(headerKeySessionID, proxySessionID)
  104. w.WriteHeader(resp.StatusCode)
  105. _, _ = io.Copy(w, resp.Body)
  106. return
  107. }
  108. // Set SSE headers for the client response
  109. w.Header().Set("Content-Type", "text/event-stream")
  110. w.Header().Set("Cache-Control", "no-cache")
  111. w.Header().Set("Connection", "keep-alive")
  112. // Create a context that cancels when the client disconnects
  113. ctx, cancel := context.WithCancel(r.Context())
  114. defer cancel()
  115. // Monitor client disconnection
  116. go func() {
  117. <-ctx.Done()
  118. resp.Body.Close()
  119. }()
  120. // Stream the SSE events to the client
  121. reader := bufio.NewReader(resp.Body)
  122. flusher, ok := w.(http.Flusher)
  123. if !ok {
  124. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  125. return
  126. }
  127. for {
  128. line, err := reader.ReadString('\n')
  129. if err != nil {
  130. if err == io.EOF {
  131. break
  132. }
  133. return
  134. }
  135. // Write the line to the client
  136. fmt.Fprint(w, line)
  137. flusher.Flush()
  138. }
  139. }
  140. // handlePostRequest handles POST requests for JSON-RPC messages
  141. func (p *StreamableProxy) handlePostRequest(w http.ResponseWriter, r *http.Request) {
  142. // Check if this is an initialization request
  143. proxySessionID := r.Header.Get(headerKeySessionID)
  144. if proxySessionID == "" {
  145. p.proxyInitialOrNoSessionRequest(w, r)
  146. return
  147. }
  148. // Look up the backend endpoint and session ID
  149. backendInfo, ok := p.store.Get(proxySessionID)
  150. if !ok {
  151. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  152. return
  153. }
  154. // Extract the real backend session ID from the stored URL
  155. parts := strings.Split(backendInfo, "|sessionId=")
  156. if len(parts) != 2 {
  157. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  158. return
  159. }
  160. backend := parts[0]
  161. sessionID := parts[1]
  162. // Create a request to the backend
  163. req, err := http.NewRequestWithContext(r.Context(), http.MethodPost, backend, r.Body)
  164. if err != nil {
  165. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  166. return
  167. }
  168. // Add any additional headers
  169. for name, value := range p.headers {
  170. req.Header.Set(name, value)
  171. }
  172. req.Header.Set(headerKeySessionID, sessionID)
  173. req.Header.Set("Accept", "application/json, text/event-stream")
  174. req.Header.Set("Content-Type", r.Header.Get("Content-Type"))
  175. //nolint:bodyclose
  176. resp, err := http.DefaultClient.Do(req)
  177. if err != nil {
  178. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  179. return
  180. }
  181. defer resp.Body.Close()
  182. // Add our proxy session ID
  183. w.Header().Set(headerKeySessionID, proxySessionID)
  184. contentType := resp.Header.Get("Content-Type")
  185. w.Header().Set("Content-Type", contentType)
  186. // Set response status code
  187. w.WriteHeader(resp.StatusCode)
  188. // Check if the response is an SSE stream
  189. if strings.Contains(contentType, "text/event-stream") {
  190. // Handle SSE response
  191. reader := bufio.NewReader(resp.Body)
  192. flusher, ok := w.(http.Flusher)
  193. if !ok {
  194. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  195. return
  196. }
  197. // Create a context that cancels when the client disconnects
  198. ctx, cancel := context.WithCancel(r.Context())
  199. defer cancel()
  200. // Monitor client disconnection
  201. go func() {
  202. <-ctx.Done()
  203. resp.Body.Close()
  204. }()
  205. for {
  206. line, err := reader.ReadString('\n')
  207. if err != nil {
  208. if err == io.EOF {
  209. break
  210. }
  211. return
  212. }
  213. // Write the line to the client
  214. _, _ = fmt.Fprint(w, line)
  215. flusher.Flush()
  216. }
  217. } else {
  218. // Copy regular response body
  219. _, _ = io.Copy(w, resp.Body)
  220. }
  221. }
  222. // handleDeleteRequest handles DELETE requests for session termination
  223. func (p *StreamableProxy) handleDeleteRequest(w http.ResponseWriter, r *http.Request) {
  224. // Get proxy session ID from header
  225. proxySessionID := r.Header.Get(headerKeySessionID)
  226. if proxySessionID == "" {
  227. http.Error(w, "Missing session ID", http.StatusBadRequest)
  228. return
  229. }
  230. // Look up the backend endpoint and session ID
  231. backendInfo, ok := p.store.Get(proxySessionID)
  232. if !ok {
  233. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  234. return
  235. }
  236. // Create a request to the backend
  237. req, err := http.NewRequestWithContext(r.Context(), http.MethodDelete, backendInfo, nil)
  238. if err != nil {
  239. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  240. return
  241. }
  242. // Extract the real backend session ID from the stored URL
  243. parts := strings.Split(backendInfo, "|sessionId=")
  244. if len(parts) > 1 {
  245. req.Header.Set(headerKeySessionID, parts[1])
  246. }
  247. // Add any additional headers
  248. for name, value := range p.headers {
  249. req.Header.Set(name, value)
  250. }
  251. // Make the request to the backend
  252. client := &http.Client{
  253. Timeout: time.Second * 10,
  254. }
  255. resp, err := client.Do(req)
  256. if err != nil {
  257. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  258. return
  259. }
  260. defer resp.Body.Close()
  261. // Remove the session from our store
  262. p.store.Delete(proxySessionID)
  263. contentType := resp.Header.Get("Content-Type")
  264. w.Header().Set("Content-Type", contentType)
  265. // Set response status code
  266. w.WriteHeader(resp.StatusCode)
  267. // Copy response body
  268. _, _ = io.Copy(w, resp.Body)
  269. }
  270. // proxyInitialOrNoSessionRequest handles the initial request that doesn't have a session ID yet
  271. func (p *StreamableProxy) proxyInitialOrNoSessionRequest(w http.ResponseWriter, r *http.Request) {
  272. // Create a request to the backend
  273. req, err := http.NewRequestWithContext(r.Context(), r.Method, p.backend, r.Body)
  274. if err != nil {
  275. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  276. return
  277. }
  278. // Add any additional headers
  279. for name, value := range p.headers {
  280. req.Header.Set(name, value)
  281. }
  282. req.Header.Set("Accept", "application/json, text/event-stream")
  283. req.Header.Set("Content-Type", r.Header.Get("Content-Type"))
  284. //nolint:bodyclose
  285. resp, err := http.DefaultClient.Do(req)
  286. if err != nil {
  287. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  288. return
  289. }
  290. defer resp.Body.Close()
  291. // Check if we received a session ID from the backend
  292. backendSessionID := resp.Header.Get(headerKeySessionID)
  293. if backendSessionID != "" {
  294. // Generate a new proxy session ID
  295. proxySessionID := p.store.New()
  296. // Store the mapping between our proxy session ID and the backend endpoint with its session
  297. // ID
  298. backendURL := p.backend
  299. backendURL += "|sessionId=" + backendSessionID
  300. p.store.Set(proxySessionID, backendURL)
  301. // Replace the backend session ID with our proxy session ID in the response
  302. w.Header().Set(headerKeySessionID, proxySessionID)
  303. }
  304. contentType := resp.Header.Get("Content-Type")
  305. w.Header().Set("Content-Type", contentType)
  306. // Set response status code
  307. w.WriteHeader(resp.StatusCode)
  308. // Check if the response is an SSE stream
  309. if strings.Contains(contentType, "text/event-stream") {
  310. // Handle SSE response
  311. reader := bufio.NewReader(resp.Body)
  312. flusher, ok := w.(http.Flusher)
  313. if !ok {
  314. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  315. return
  316. }
  317. // Create a context that cancels when the client disconnects
  318. ctx, cancel := context.WithCancel(r.Context())
  319. defer cancel()
  320. // Monitor client disconnection
  321. go func() {
  322. <-ctx.Done()
  323. resp.Body.Close()
  324. }()
  325. for {
  326. line, err := reader.ReadString('\n')
  327. if err != nil {
  328. if err == io.EOF {
  329. break
  330. }
  331. return
  332. }
  333. // Write the line to the client
  334. fmt.Fprint(w, line)
  335. flusher.Flush()
  336. }
  337. } else {
  338. // Copy regular response body
  339. _, _ = io.Copy(w, resp.Body)
  340. }
  341. }