streamable.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. package mcpproxy
  2. import (
  3. "bufio"
  4. "context"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "time"
  10. )
  11. // StreamableProxy represents a proxy for the MCP Streamable HTTP transport
  12. type StreamableProxy struct {
  13. store SessionManager
  14. backend string
  15. headers map[string]string
  16. }
  17. // NewStreamableProxy creates a new proxy for the Streamable HTTP transport
  18. func NewStreamableProxy(backend string, headers map[string]string, store SessionManager) *StreamableProxy {
  19. return &StreamableProxy{
  20. store: store,
  21. backend: backend,
  22. headers: headers,
  23. }
  24. }
  25. // ServeHTTP handles both GET and POST requests for the Streamable HTTP transport
  26. func (p *StreamableProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  27. // Add CORS headers
  28. w.Header().Set("Access-Control-Allow-Origin", "*")
  29. w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
  30. w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id")
  31. w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id")
  32. // Handle preflight requests
  33. if r.Method == http.MethodOptions {
  34. w.WriteHeader(http.StatusOK)
  35. return
  36. }
  37. switch r.Method {
  38. case http.MethodGet:
  39. p.handleGetRequest(w, r)
  40. case http.MethodPost:
  41. p.handlePostRequest(w, r)
  42. case http.MethodDelete:
  43. p.handleDeleteRequest(w, r)
  44. default:
  45. http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
  46. }
  47. }
  48. // handleGetRequest handles GET requests for SSE streaming
  49. func (p *StreamableProxy) handleGetRequest(w http.ResponseWriter, r *http.Request) {
  50. // Check if Accept header includes text/event-stream
  51. acceptHeader := r.Header.Get("Accept")
  52. if !strings.Contains(acceptHeader, "text/event-stream") {
  53. http.Error(w, "Accept header must include text/event-stream", http.StatusBadRequest)
  54. return
  55. }
  56. // Get proxy session ID from header
  57. proxySessionID := r.Header.Get("Mcp-Session-Id")
  58. if proxySessionID == "" {
  59. // This might be an initialization request
  60. p.proxyInitialOrNoSessionRequest(w, r)
  61. return
  62. }
  63. // Look up the backend endpoint and session ID
  64. backendInfo, ok := p.store.Get(proxySessionID)
  65. if !ok {
  66. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  67. return
  68. }
  69. // Create a request to the backend
  70. req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, backendInfo, nil)
  71. if err != nil {
  72. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  73. return
  74. }
  75. // Copy headers from original request, but replace the session ID
  76. for name, values := range r.Header {
  77. if name == "Mcp-Session-Id" {
  78. continue // Skip the proxy session ID
  79. }
  80. for _, value := range values {
  81. req.Header.Add(name, value)
  82. }
  83. }
  84. // Extract the real backend session ID from the stored URL
  85. parts := strings.Split(backendInfo, "?sessionId=")
  86. if len(parts) > 1 {
  87. req.Header.Set("Mcp-Session-Id", parts[1])
  88. }
  89. // Add any additional headers
  90. for name, value := range p.headers {
  91. req.Header.Set(name, value)
  92. }
  93. //nolint:bodyclose
  94. resp, err := http.DefaultClient.Do(req)
  95. if err != nil {
  96. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  97. return
  98. }
  99. defer resp.Body.Close()
  100. // Check if we got an SSE response
  101. if resp.StatusCode != http.StatusOK || !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
  102. // Copy response headers, but not the backend session ID
  103. for name, values := range resp.Header {
  104. if name == "Mcp-Session-Id" {
  105. continue
  106. }
  107. for _, value := range values {
  108. w.Header().Add(name, value)
  109. }
  110. }
  111. // Add our proxy session ID
  112. w.Header().Set("Mcp-Session-Id", proxySessionID)
  113. w.WriteHeader(resp.StatusCode)
  114. _, _ = io.Copy(w, resp.Body)
  115. return
  116. }
  117. // Set SSE headers for the client response
  118. w.Header().Set("Content-Type", "text/event-stream")
  119. w.Header().Set("Cache-Control", "no-cache")
  120. w.Header().Set("Connection", "keep-alive")
  121. // Create a context that cancels when the client disconnects
  122. ctx, cancel := context.WithCancel(r.Context())
  123. defer cancel()
  124. // Monitor client disconnection
  125. go func() {
  126. <-ctx.Done()
  127. resp.Body.Close()
  128. }()
  129. // Stream the SSE events to the client
  130. reader := bufio.NewReader(resp.Body)
  131. flusher, ok := w.(http.Flusher)
  132. if !ok {
  133. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  134. return
  135. }
  136. for {
  137. line, err := reader.ReadString('\n')
  138. if err != nil {
  139. if err == io.EOF {
  140. break
  141. }
  142. return
  143. }
  144. // Write the line to the client
  145. fmt.Fprint(w, line)
  146. flusher.Flush()
  147. }
  148. }
  149. // handlePostRequest handles POST requests for JSON-RPC messages
  150. func (p *StreamableProxy) handlePostRequest(w http.ResponseWriter, r *http.Request) {
  151. // Check if this is an initialization request
  152. proxySessionID := r.Header.Get("Mcp-Session-Id")
  153. if proxySessionID == "" {
  154. p.proxyInitialOrNoSessionRequest(w, r)
  155. return
  156. }
  157. // Look up the backend endpoint and session ID
  158. backendInfo, ok := p.store.Get(proxySessionID)
  159. if !ok {
  160. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  161. return
  162. }
  163. // Create a request to the backend
  164. req, err := http.NewRequestWithContext(r.Context(), http.MethodPost, backendInfo, r.Body)
  165. if err != nil {
  166. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  167. return
  168. }
  169. // Copy headers from original request, but replace the session ID
  170. for name, values := range r.Header {
  171. if name == "Mcp-Session-Id" {
  172. continue // Skip the proxy session ID
  173. }
  174. for _, value := range values {
  175. req.Header.Add(name, value)
  176. }
  177. }
  178. // Extract the real backend session ID from the stored URL
  179. parts := strings.Split(backendInfo, "?sessionId=")
  180. if len(parts) > 1 {
  181. req.Header.Set("Mcp-Session-Id", parts[1])
  182. }
  183. // Add any additional headers
  184. for name, value := range p.headers {
  185. req.Header.Set(name, value)
  186. }
  187. //nolint:bodyclose
  188. resp, err := http.DefaultClient.Do(req)
  189. if err != nil {
  190. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  191. return
  192. }
  193. defer resp.Body.Close()
  194. // Copy response headers, but not the backend session ID
  195. for name, values := range resp.Header {
  196. if name == "Mcp-Session-Id" {
  197. continue
  198. }
  199. for _, value := range values {
  200. w.Header().Add(name, value)
  201. }
  202. }
  203. // Add our proxy session ID
  204. w.Header().Set("Mcp-Session-Id", proxySessionID)
  205. // Set response status code
  206. w.WriteHeader(resp.StatusCode)
  207. // Check if the response is an SSE stream
  208. if strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
  209. // Handle SSE response
  210. reader := bufio.NewReader(resp.Body)
  211. flusher, ok := w.(http.Flusher)
  212. if !ok {
  213. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  214. return
  215. }
  216. // Create a context that cancels when the client disconnects
  217. ctx, cancel := context.WithCancel(r.Context())
  218. defer cancel()
  219. // Monitor client disconnection
  220. go func() {
  221. <-ctx.Done()
  222. resp.Body.Close()
  223. }()
  224. for {
  225. line, err := reader.ReadString('\n')
  226. if err != nil {
  227. if err == io.EOF {
  228. break
  229. }
  230. return
  231. }
  232. // Write the line to the client
  233. _, _ = fmt.Fprint(w, line)
  234. flusher.Flush()
  235. }
  236. } else {
  237. // Copy regular response body
  238. _, _ = io.Copy(w, resp.Body)
  239. }
  240. }
  241. // handleDeleteRequest handles DELETE requests for session termination
  242. func (p *StreamableProxy) handleDeleteRequest(w http.ResponseWriter, r *http.Request) {
  243. // Get proxy session ID from header
  244. proxySessionID := r.Header.Get("Mcp-Session-Id")
  245. if proxySessionID == "" {
  246. http.Error(w, "Missing session ID", http.StatusBadRequest)
  247. return
  248. }
  249. // Look up the backend endpoint and session ID
  250. backendInfo, ok := p.store.Get(proxySessionID)
  251. if !ok {
  252. http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
  253. return
  254. }
  255. // Create a request to the backend
  256. req, err := http.NewRequestWithContext(r.Context(), http.MethodDelete, backendInfo, nil)
  257. if err != nil {
  258. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  259. return
  260. }
  261. // Copy headers from original request, but replace the session ID
  262. for name, values := range r.Header {
  263. if name == "Mcp-Session-Id" {
  264. continue // Skip the proxy session ID
  265. }
  266. for _, value := range values {
  267. req.Header.Add(name, value)
  268. }
  269. }
  270. // Extract the real backend session ID from the stored URL
  271. parts := strings.Split(backendInfo, "?sessionId=")
  272. if len(parts) > 1 {
  273. req.Header.Set("Mcp-Session-Id", parts[1])
  274. }
  275. // Add any additional headers
  276. for name, value := range p.headers {
  277. req.Header.Set(name, value)
  278. }
  279. // Make the request to the backend
  280. client := &http.Client{
  281. Timeout: time.Second * 10,
  282. }
  283. resp, err := client.Do(req)
  284. if err != nil {
  285. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  286. return
  287. }
  288. defer resp.Body.Close()
  289. // Remove the session from our store
  290. p.store.Delete(proxySessionID)
  291. // Copy response headers, but not the backend session ID
  292. for name, values := range resp.Header {
  293. if name == "Mcp-Session-Id" {
  294. continue
  295. }
  296. for _, value := range values {
  297. w.Header().Add(name, value)
  298. }
  299. }
  300. // Set response status code
  301. w.WriteHeader(resp.StatusCode)
  302. // Copy response body
  303. _, _ = io.Copy(w, resp.Body)
  304. }
  305. // proxyInitialOrNoSessionRequest handles the initial request that doesn't have a session ID yet
  306. func (p *StreamableProxy) proxyInitialOrNoSessionRequest(w http.ResponseWriter, r *http.Request) {
  307. // Create a request to the backend
  308. req, err := http.NewRequestWithContext(r.Context(), r.Method, p.backend, r.Body)
  309. if err != nil {
  310. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  311. return
  312. }
  313. // Copy headers from original request
  314. for name, values := range r.Header {
  315. for _, value := range values {
  316. req.Header.Add(name, value)
  317. }
  318. }
  319. // Add any additional headers
  320. for name, value := range p.headers {
  321. req.Header.Set(name, value)
  322. }
  323. //nolint:bodyclose
  324. resp, err := http.DefaultClient.Do(req)
  325. if err != nil {
  326. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  327. return
  328. }
  329. defer resp.Body.Close()
  330. // Check if we received a session ID from the backend
  331. backendSessionID := resp.Header.Get("Mcp-Session-Id")
  332. if backendSessionID != "" {
  333. // Generate a new proxy session ID
  334. proxySessionID := p.store.New()
  335. // Store the mapping between our proxy session ID and the backend endpoint with its session ID
  336. backendURL := p.backend
  337. if strings.Contains(backendURL, "?") {
  338. backendURL += "&sessionId=" + backendSessionID
  339. } else {
  340. backendURL += "?sessionId=" + backendSessionID
  341. }
  342. p.store.Set(proxySessionID, backendURL)
  343. // Replace the backend session ID with our proxy session ID in the response
  344. w.Header().Set("Mcp-Session-Id", proxySessionID)
  345. }
  346. // Copy other response headers
  347. for name, values := range resp.Header {
  348. if name != "Mcp-Session-Id" { // Skip the original session ID
  349. for _, value := range values {
  350. w.Header().Add(name, value)
  351. }
  352. }
  353. }
  354. // Set response status code
  355. w.WriteHeader(resp.StatusCode)
  356. // Check if the response is an SSE stream
  357. if strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
  358. // Handle SSE response
  359. reader := bufio.NewReader(resp.Body)
  360. flusher, ok := w.(http.Flusher)
  361. if !ok {
  362. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  363. return
  364. }
  365. // Create a context that cancels when the client disconnects
  366. ctx, cancel := context.WithCancel(r.Context())
  367. defer cancel()
  368. // Monitor client disconnection
  369. go func() {
  370. <-ctx.Done()
  371. resp.Body.Close()
  372. }()
  373. for {
  374. line, err := reader.ReadString('\n')
  375. if err != nil {
  376. if err == io.EOF {
  377. break
  378. }
  379. return
  380. }
  381. // Write the line to the client
  382. fmt.Fprint(w, line)
  383. flusher.Flush()
  384. }
  385. } else {
  386. // Copy regular response body
  387. _, _ = io.Copy(w, resp.Body)
  388. }
  389. }