streamable.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. package mcpproxy
  2. import (
  3. "bufio"
  4. "context"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "time"
  10. )
  // headerKeySessionID names the HTTP header that carries the MCP session
// identifier. The proxy reads it from clients, sets it on backend requests,
// reads it from backend responses and sets it on responses to clients.
const headerKeySessionID = "Mcp-Session-Id"
  14. // StreamableProxy represents a proxy for the MCP Streamable HTTP transport
  15. type StreamableProxy struct {
  16. store SessionManager
  17. backend string
  18. headers map[string]string
  19. }
  20. // NewStreamableProxy creates a new proxy for the Streamable HTTP transport
  21. func NewStreamableProxy(
  22. backend string,
  23. headers map[string]string,
  24. store SessionManager,
  25. ) *StreamableProxy {
  26. return &StreamableProxy{
  27. store: store,
  28. backend: backend,
  29. headers: headers,
  30. }
  31. }
  // ServeHTTP is the proxy's http.Handler entry point.
//
// Every response carries permissive CORS headers, and those headers expose
// Mcp-Session-Id to browser clients. Responses by method:
//   - OPTIONS (CORS preflight): 200 with no body.
//   - GET, POST and DELETE: dispatched to their handlers.
//   - Anything else: 405 Method Not Allowed.
func (p *StreamableProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	hdr := w.Header()
	hdr.Set("Access-Control-Allow-Origin", "*")
	hdr.Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
	hdr.Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id")
	hdr.Set("Access-Control-Expose-Headers", "Mcp-Session-Id")

	switch r.Method {
	case http.MethodOptions:
		// Preflight: the headers above are the whole answer.
		w.WriteHeader(http.StatusOK)
	case http.MethodGet:
		p.handleGetRequest(w, r)
	case http.MethodPost:
		p.handlePostRequest(w, r)
	case http.MethodDelete:
		p.handleDeleteRequest(w, r)
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
  55. // handleGetRequest handles GET requests for SSE streaming
  56. func (p *StreamableProxy) handleGetRequest(w http.ResponseWriter, r *http.Request) {
  57. // Check if Accept header includes text/event-stream
  58. acceptHeader := r.Header.Get("Accept")
  59. if !strings.Contains(acceptHeader, "text/event-stream") {
  60. http.Error(w, "Accept header must include text/event-stream", http.StatusBadRequest)
  61. return
  62. }
  63. // Get proxy session ID from header
  64. proxySessionID := r.Header.Get(headerKeySessionID)
  65. if proxySessionID == "" {
  66. // This might be an initialization request
  67. p.proxyInitialOrNoSessionRequest(w, r)
  68. return
  69. }
  70. // Look up the backend endpoint and session ID
  71. backendInfo, ok := p.store.Get(proxySessionID)
  72. if !ok {
  73. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  74. return
  75. }
  76. // Create a request to the backend
  77. req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, backendInfo, nil)
  78. if err != nil {
  79. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  80. return
  81. }
  82. // Extract the real backend session ID from the stored URL
  83. parts := strings.Split(backendInfo, "?sessionId=")
  84. if len(parts) > 1 {
  85. req.Header.Set(headerKeySessionID, parts[1])
  86. }
  87. // Add any additional headers
  88. for name, value := range p.headers {
  89. req.Header.Set(name, value)
  90. }
  91. //nolint:bodyclose
  92. resp, err := http.DefaultClient.Do(req)
  93. if err != nil {
  94. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  95. return
  96. }
  97. defer resp.Body.Close()
  98. // Check if we got an SSE response
  99. if resp.StatusCode != http.StatusOK ||
  100. !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
  101. // Add our proxy session ID
  102. w.Header().Set(headerKeySessionID, proxySessionID)
  103. w.WriteHeader(resp.StatusCode)
  104. _, _ = io.Copy(w, resp.Body)
  105. return
  106. }
  107. // Set SSE headers for the client response
  108. w.Header().Set("Content-Type", "text/event-stream")
  109. w.Header().Set("Cache-Control", "no-cache")
  110. w.Header().Set("Connection", "keep-alive")
  111. // Create a context that cancels when the client disconnects
  112. ctx, cancel := context.WithCancel(r.Context())
  113. defer cancel()
  114. // Monitor client disconnection
  115. go func() {
  116. <-ctx.Done()
  117. resp.Body.Close()
  118. }()
  119. // Stream the SSE events to the client
  120. reader := bufio.NewReader(resp.Body)
  121. flusher, ok := w.(http.Flusher)
  122. if !ok {
  123. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  124. return
  125. }
  126. for {
  127. line, err := reader.ReadString('\n')
  128. if err != nil {
  129. if err == io.EOF {
  130. break
  131. }
  132. return
  133. }
  134. // Write the line to the client
  135. fmt.Fprint(w, line)
  136. flusher.Flush()
  137. }
  138. }
  139. // handlePostRequest handles POST requests for JSON-RPC messages
  140. func (p *StreamableProxy) handlePostRequest(w http.ResponseWriter, r *http.Request) {
  141. // Check if this is an initialization request
  142. proxySessionID := r.Header.Get(headerKeySessionID)
  143. if proxySessionID == "" {
  144. p.proxyInitialOrNoSessionRequest(w, r)
  145. return
  146. }
  147. // Look up the backend endpoint and session ID
  148. backendInfo, ok := p.store.Get(proxySessionID)
  149. if !ok {
  150. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  151. return
  152. }
  153. // Create a request to the backend
  154. req, err := http.NewRequestWithContext(r.Context(), http.MethodPost, backendInfo, r.Body)
  155. if err != nil {
  156. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  157. return
  158. }
  159. // Extract the real backend session ID from the stored URL
  160. parts := strings.Split(backendInfo, "?sessionId=")
  161. if len(parts) > 1 {
  162. req.Header.Set(headerKeySessionID, parts[1])
  163. }
  164. // Add any additional headers
  165. for name, value := range p.headers {
  166. req.Header.Set(name, value)
  167. }
  168. //nolint:bodyclose
  169. resp, err := http.DefaultClient.Do(req)
  170. if err != nil {
  171. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  172. return
  173. }
  174. defer resp.Body.Close()
  175. // Add our proxy session ID
  176. w.Header().Set(headerKeySessionID, proxySessionID)
  177. contentType := resp.Header.Get("Content-Type")
  178. w.Header().Set("Content-Type", contentType)
  179. // Set response status code
  180. w.WriteHeader(resp.StatusCode)
  181. // Check if the response is an SSE stream
  182. if strings.Contains(contentType, "text/event-stream") {
  183. // Handle SSE response
  184. reader := bufio.NewReader(resp.Body)
  185. flusher, ok := w.(http.Flusher)
  186. if !ok {
  187. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  188. return
  189. }
  190. // Create a context that cancels when the client disconnects
  191. ctx, cancel := context.WithCancel(r.Context())
  192. defer cancel()
  193. // Monitor client disconnection
  194. go func() {
  195. <-ctx.Done()
  196. resp.Body.Close()
  197. }()
  198. for {
  199. line, err := reader.ReadString('\n')
  200. if err != nil {
  201. if err == io.EOF {
  202. break
  203. }
  204. return
  205. }
  206. // Write the line to the client
  207. _, _ = fmt.Fprint(w, line)
  208. flusher.Flush()
  209. }
  210. } else {
  211. // Copy regular response body
  212. _, _ = io.Copy(w, resp.Body)
  213. }
  214. }
  215. // handleDeleteRequest handles DELETE requests for session termination
  216. func (p *StreamableProxy) handleDeleteRequest(w http.ResponseWriter, r *http.Request) {
  217. // Get proxy session ID from header
  218. proxySessionID := r.Header.Get(headerKeySessionID)
  219. if proxySessionID == "" {
  220. http.Error(w, "Missing session ID", http.StatusBadRequest)
  221. return
  222. }
  223. // Look up the backend endpoint and session ID
  224. backendInfo, ok := p.store.Get(proxySessionID)
  225. if !ok {
  226. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  227. return
  228. }
  229. // Create a request to the backend
  230. req, err := http.NewRequestWithContext(r.Context(), http.MethodDelete, backendInfo, nil)
  231. if err != nil {
  232. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  233. return
  234. }
  235. // Extract the real backend session ID from the stored URL
  236. parts := strings.Split(backendInfo, "?sessionId=")
  237. if len(parts) > 1 {
  238. req.Header.Set(headerKeySessionID, parts[1])
  239. }
  240. // Add any additional headers
  241. for name, value := range p.headers {
  242. req.Header.Set(name, value)
  243. }
  244. // Make the request to the backend
  245. client := &http.Client{
  246. Timeout: time.Second * 10,
  247. }
  248. resp, err := client.Do(req)
  249. if err != nil {
  250. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  251. return
  252. }
  253. defer resp.Body.Close()
  254. // Remove the session from our store
  255. p.store.Delete(proxySessionID)
  256. contentType := resp.Header.Get("Content-Type")
  257. w.Header().Set("Content-Type", contentType)
  258. // Set response status code
  259. w.WriteHeader(resp.StatusCode)
  260. // Copy response body
  261. _, _ = io.Copy(w, resp.Body)
  262. }
  263. // proxyInitialOrNoSessionRequest handles the initial request that doesn't have a session ID yet
  264. func (p *StreamableProxy) proxyInitialOrNoSessionRequest(w http.ResponseWriter, r *http.Request) {
  265. // Create a request to the backend
  266. req, err := http.NewRequestWithContext(r.Context(), r.Method, p.backend, r.Body)
  267. if err != nil {
  268. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  269. return
  270. }
  271. // Add any additional headers
  272. for name, value := range p.headers {
  273. req.Header.Set(name, value)
  274. }
  275. //nolint:bodyclose
  276. resp, err := http.DefaultClient.Do(req)
  277. if err != nil {
  278. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  279. return
  280. }
  281. defer resp.Body.Close()
  282. // Check if we received a session ID from the backend
  283. backendSessionID := resp.Header.Get(headerKeySessionID)
  284. if backendSessionID != "" {
  285. // Generate a new proxy session ID
  286. proxySessionID := p.store.New()
  287. // Store the mapping between our proxy session ID and the backend endpoint with its session
  288. // ID
  289. backendURL := p.backend
  290. if strings.Contains(backendURL, "?") {
  291. backendURL += "&sessionId=" + backendSessionID
  292. } else {
  293. backendURL += "?sessionId=" + backendSessionID
  294. }
  295. p.store.Set(proxySessionID, backendURL)
  296. // Replace the backend session ID with our proxy session ID in the response
  297. w.Header().Set(headerKeySessionID, proxySessionID)
  298. }
  299. contentType := resp.Header.Get("Content-Type")
  300. w.Header().Set("Content-Type", contentType)
  301. // Set response status code
  302. w.WriteHeader(resp.StatusCode)
  303. // Check if the response is an SSE stream
  304. if strings.Contains(contentType, "text/event-stream") {
  305. // Handle SSE response
  306. reader := bufio.NewReader(resp.Body)
  307. flusher, ok := w.(http.Flusher)
  308. if !ok {
  309. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  310. return
  311. }
  312. // Create a context that cancels when the client disconnects
  313. ctx, cancel := context.WithCancel(r.Context())
  314. defer cancel()
  315. // Monitor client disconnection
  316. go func() {
  317. <-ctx.Done()
  318. resp.Body.Close()
  319. }()
  320. for {
  321. line, err := reader.ReadString('\n')
  322. if err != nil {
  323. if err == io.EOF {
  324. break
  325. }
  326. return
  327. }
  328. // Write the line to the client
  329. fmt.Fprint(w, line)
  330. flusher.Flush()
  331. }
  332. } else {
  333. // Copy regular response body
  334. _, _ = io.Copy(w, resp.Body)
  335. }
  336. }