|
|
@@ -0,0 +1,451 @@
|
|
|
+package mcpproxy
|
|
|
+
|
|
|
+import (
|
|
|
+ "bufio"
|
|
|
+ "context"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "net/http"
|
|
|
+ "strings"
|
|
|
+ "time"
|
|
|
+)
|
|
|
+
|
|
|
+// StreamableProxy represents a proxy for the MCP Streamable HTTP transport
|
|
|
+type StreamableProxy struct {
|
|
|
+ store SessionManager
|
|
|
+ backend string
|
|
|
+ headers map[string]string
|
|
|
+}
|
|
|
+
|
|
|
+// NewStreamableProxy creates a new proxy for the Streamable HTTP transport
|
|
|
+func NewStreamableProxy(backend string, headers map[string]string, store SessionManager) *StreamableProxy {
|
|
|
+ return &StreamableProxy{
|
|
|
+ store: store,
|
|
|
+ backend: backend,
|
|
|
+ headers: headers,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// ServeHTTP handles both GET and POST requests for the Streamable HTTP transport
|
|
|
+func (p *StreamableProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
|
+ // Add CORS headers
|
|
|
+ w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
|
+ w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
|
|
|
+ w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id")
|
|
|
+ w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id")
|
|
|
+
|
|
|
+ // Handle preflight requests
|
|
|
+ if r.Method == http.MethodOptions {
|
|
|
+ w.WriteHeader(http.StatusOK)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ switch r.Method {
|
|
|
+ case http.MethodGet:
|
|
|
+ p.handleGetRequest(w, r)
|
|
|
+ case http.MethodPost:
|
|
|
+ p.handlePostRequest(w, r)
|
|
|
+ case http.MethodDelete:
|
|
|
+ p.handleDeleteRequest(w, r)
|
|
|
+ default:
|
|
|
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// handleGetRequest handles GET requests for SSE streaming
|
|
|
+func (p *StreamableProxy) handleGetRequest(w http.ResponseWriter, r *http.Request) {
|
|
|
+ // Check if Accept header includes text/event-stream
|
|
|
+ acceptHeader := r.Header.Get("Accept")
|
|
|
+ if !strings.Contains(acceptHeader, "text/event-stream") {
|
|
|
+ http.Error(w, "Accept header must include text/event-stream", http.StatusBadRequest)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Get proxy session ID from header
|
|
|
+ proxySessionID := r.Header.Get("Mcp-Session-Id")
|
|
|
+ if proxySessionID == "" {
|
|
|
+ // This might be an initialization request
|
|
|
+ p.proxyInitialOrNoSessionRequest(w, r)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Look up the backend endpoint and session ID
|
|
|
+ backendInfo, ok := p.store.Get(proxySessionID)
|
|
|
+ if !ok {
|
|
|
+ http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Create a request to the backend
|
|
|
+ req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, backendInfo, nil)
|
|
|
+ if err != nil {
|
|
|
+ http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Copy headers from original request, but replace the session ID
|
|
|
+ for name, values := range r.Header {
|
|
|
+ if name == "Mcp-Session-Id" {
|
|
|
+ continue // Skip the proxy session ID
|
|
|
+ }
|
|
|
+ for _, value := range values {
|
|
|
+ req.Header.Add(name, value)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Extract the real backend session ID from the stored URL
|
|
|
+ parts := strings.Split(backendInfo, "?sessionId=")
|
|
|
+ if len(parts) > 1 {
|
|
|
+ req.Header.Set("Mcp-Session-Id", parts[1])
|
|
|
+ }
|
|
|
+
|
|
|
+ // Add any additional headers
|
|
|
+ for name, value := range p.headers {
|
|
|
+ req.Header.Set(name, value)
|
|
|
+ }
|
|
|
+
|
|
|
+ //nolint:bodyclose
|
|
|
+ resp, err := http.DefaultClient.Do(req)
|
|
|
+ if err != nil {
|
|
|
+ http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ defer resp.Body.Close()
|
|
|
+
|
|
|
+ // Check if we got an SSE response
|
|
|
+ if resp.StatusCode != http.StatusOK || !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
|
|
|
+ // Copy response headers, but not the backend session ID
|
|
|
+ for name, values := range resp.Header {
|
|
|
+ if name == "Mcp-Session-Id" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ for _, value := range values {
|
|
|
+ w.Header().Add(name, value)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // Add our proxy session ID
|
|
|
+ w.Header().Set("Mcp-Session-Id", proxySessionID)
|
|
|
+
|
|
|
+ w.WriteHeader(resp.StatusCode)
|
|
|
+ _, _ = io.Copy(w, resp.Body)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Set SSE headers for the client response
|
|
|
+ w.Header().Set("Content-Type", "text/event-stream")
|
|
|
+ w.Header().Set("Cache-Control", "no-cache")
|
|
|
+ w.Header().Set("Connection", "keep-alive")
|
|
|
+
|
|
|
+ // Create a context that cancels when the client disconnects
|
|
|
+ ctx, cancel := context.WithCancel(r.Context())
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ // Monitor client disconnection
|
|
|
+ go func() {
|
|
|
+ <-ctx.Done()
|
|
|
+ resp.Body.Close()
|
|
|
+ }()
|
|
|
+
|
|
|
+ // Stream the SSE events to the client
|
|
|
+ reader := bufio.NewReader(resp.Body)
|
|
|
+ flusher, ok := w.(http.Flusher)
|
|
|
+ if !ok {
|
|
|
+ http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ for {
|
|
|
+ line, err := reader.ReadString('\n')
|
|
|
+ if err != nil {
|
|
|
+ if err == io.EOF {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Write the line to the client
|
|
|
+ fmt.Fprint(w, line)
|
|
|
+ flusher.Flush()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// handlePostRequest handles POST requests for JSON-RPC messages
|
|
|
+func (p *StreamableProxy) handlePostRequest(w http.ResponseWriter, r *http.Request) {
|
|
|
+ // Check if this is an initialization request
|
|
|
+ proxySessionID := r.Header.Get("Mcp-Session-Id")
|
|
|
+ if proxySessionID == "" {
|
|
|
+ p.proxyInitialOrNoSessionRequest(w, r)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Look up the backend endpoint and session ID
|
|
|
+ backendInfo, ok := p.store.Get(proxySessionID)
|
|
|
+ if !ok {
|
|
|
+ http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Create a request to the backend
|
|
|
+ req, err := http.NewRequestWithContext(r.Context(), http.MethodPost, backendInfo, r.Body)
|
|
|
+ if err != nil {
|
|
|
+ http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Copy headers from original request, but replace the session ID
|
|
|
+ for name, values := range r.Header {
|
|
|
+ if name == "Mcp-Session-Id" {
|
|
|
+ continue // Skip the proxy session ID
|
|
|
+ }
|
|
|
+ for _, value := range values {
|
|
|
+ req.Header.Add(name, value)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Extract the real backend session ID from the stored URL
|
|
|
+ parts := strings.Split(backendInfo, "?sessionId=")
|
|
|
+ if len(parts) > 1 {
|
|
|
+ req.Header.Set("Mcp-Session-Id", parts[1])
|
|
|
+ }
|
|
|
+
|
|
|
+ // Add any additional headers
|
|
|
+ for name, value := range p.headers {
|
|
|
+ req.Header.Set(name, value)
|
|
|
+ }
|
|
|
+
|
|
|
+ //nolint:bodyclose
|
|
|
+ resp, err := http.DefaultClient.Do(req)
|
|
|
+ if err != nil {
|
|
|
+ http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ defer resp.Body.Close()
|
|
|
+
|
|
|
+ // Copy response headers, but not the backend session ID
|
|
|
+ for name, values := range resp.Header {
|
|
|
+ if name == "Mcp-Session-Id" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ for _, value := range values {
|
|
|
+ w.Header().Add(name, value)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // Add our proxy session ID
|
|
|
+ w.Header().Set("Mcp-Session-Id", proxySessionID)
|
|
|
+
|
|
|
+ // Set response status code
|
|
|
+ w.WriteHeader(resp.StatusCode)
|
|
|
+
|
|
|
+ // Check if the response is an SSE stream
|
|
|
+ if strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
|
|
|
+ // Handle SSE response
|
|
|
+ reader := bufio.NewReader(resp.Body)
|
|
|
+ flusher, ok := w.(http.Flusher)
|
|
|
+ if !ok {
|
|
|
+ http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Create a context that cancels when the client disconnects
|
|
|
+ ctx, cancel := context.WithCancel(r.Context())
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ // Monitor client disconnection
|
|
|
+ go func() {
|
|
|
+ <-ctx.Done()
|
|
|
+ resp.Body.Close()
|
|
|
+ }()
|
|
|
+
|
|
|
+ for {
|
|
|
+ line, err := reader.ReadString('\n')
|
|
|
+ if err != nil {
|
|
|
+ if err == io.EOF {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Write the line to the client
|
|
|
+ _, _ = fmt.Fprint(w, line)
|
|
|
+ flusher.Flush()
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ // Copy regular response body
|
|
|
+ _, _ = io.Copy(w, resp.Body)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// handleDeleteRequest handles DELETE requests for session termination
|
|
|
+func (p *StreamableProxy) handleDeleteRequest(w http.ResponseWriter, r *http.Request) {
|
|
|
+ // Get proxy session ID from header
|
|
|
+ proxySessionID := r.Header.Get("Mcp-Session-Id")
|
|
|
+ if proxySessionID == "" {
|
|
|
+ http.Error(w, "Missing session ID", http.StatusBadRequest)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Look up the backend endpoint and session ID
|
|
|
+ backendInfo, ok := p.store.Get(proxySessionID)
|
|
|
+ if !ok {
|
|
|
+ http.Error(w, "Invalid or expired session ID", http.StatusNotFound)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Create a request to the backend
|
|
|
+ req, err := http.NewRequestWithContext(r.Context(), http.MethodDelete, backendInfo, nil)
|
|
|
+ if err != nil {
|
|
|
+ http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Copy headers from original request, but replace the session ID
|
|
|
+ for name, values := range r.Header {
|
|
|
+ if name == "Mcp-Session-Id" {
|
|
|
+ continue // Skip the proxy session ID
|
|
|
+ }
|
|
|
+ for _, value := range values {
|
|
|
+ req.Header.Add(name, value)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Extract the real backend session ID from the stored URL
|
|
|
+ parts := strings.Split(backendInfo, "?sessionId=")
|
|
|
+ if len(parts) > 1 {
|
|
|
+ req.Header.Set("Mcp-Session-Id", parts[1])
|
|
|
+ }
|
|
|
+
|
|
|
+ // Add any additional headers
|
|
|
+ for name, value := range p.headers {
|
|
|
+ req.Header.Set(name, value)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Make the request to the backend
|
|
|
+ client := &http.Client{
|
|
|
+ Timeout: time.Second * 10,
|
|
|
+ }
|
|
|
+
|
|
|
+ resp, err := client.Do(req)
|
|
|
+ if err != nil {
|
|
|
+ http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ defer resp.Body.Close()
|
|
|
+
|
|
|
+ // Remove the session from our store
|
|
|
+ p.store.Delete(proxySessionID)
|
|
|
+
|
|
|
+ // Copy response headers, but not the backend session ID
|
|
|
+ for name, values := range resp.Header {
|
|
|
+ if name == "Mcp-Session-Id" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ for _, value := range values {
|
|
|
+ w.Header().Add(name, value)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Set response status code
|
|
|
+ w.WriteHeader(resp.StatusCode)
|
|
|
+
|
|
|
+ // Copy response body
|
|
|
+ _, _ = io.Copy(w, resp.Body)
|
|
|
+}
|
|
|
+
|
|
|
+// proxyInitialOrNoSessionRequest handles the initial request that doesn't have a session ID yet
|
|
|
+func (p *StreamableProxy) proxyInitialOrNoSessionRequest(w http.ResponseWriter, r *http.Request) {
|
|
|
+ // Create a request to the backend
|
|
|
+ req, err := http.NewRequestWithContext(r.Context(), r.Method, p.backend, r.Body)
|
|
|
+ if err != nil {
|
|
|
+ http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Copy headers from original request
|
|
|
+ for name, values := range r.Header {
|
|
|
+ for _, value := range values {
|
|
|
+ req.Header.Add(name, value)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Add any additional headers
|
|
|
+ for name, value := range p.headers {
|
|
|
+ req.Header.Set(name, value)
|
|
|
+ }
|
|
|
+
|
|
|
+ //nolint:bodyclose
|
|
|
+ resp, err := http.DefaultClient.Do(req)
|
|
|
+ if err != nil {
|
|
|
+ http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ defer resp.Body.Close()
|
|
|
+
|
|
|
+ // Check if we received a session ID from the backend
|
|
|
+ backendSessionID := resp.Header.Get("Mcp-Session-Id")
|
|
|
+ if backendSessionID != "" {
|
|
|
+ // Generate a new proxy session ID
|
|
|
+ proxySessionID := p.store.New()
|
|
|
+
|
|
|
+ // Store the mapping between our proxy session ID and the backend endpoint with its session ID
|
|
|
+ backendURL := p.backend
|
|
|
+ if strings.Contains(backendURL, "?") {
|
|
|
+ backendURL += "&sessionId=" + backendSessionID
|
|
|
+ } else {
|
|
|
+ backendURL += "?sessionId=" + backendSessionID
|
|
|
+ }
|
|
|
+ p.store.Set(proxySessionID, backendURL)
|
|
|
+
|
|
|
+ // Replace the backend session ID with our proxy session ID in the response
|
|
|
+ w.Header().Set("Mcp-Session-Id", proxySessionID)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Copy other response headers
|
|
|
+ for name, values := range resp.Header {
|
|
|
+ if name != "Mcp-Session-Id" { // Skip the original session ID
|
|
|
+ for _, value := range values {
|
|
|
+ w.Header().Add(name, value)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Set response status code
|
|
|
+ w.WriteHeader(resp.StatusCode)
|
|
|
+
|
|
|
+ // Check if the response is an SSE stream
|
|
|
+ if strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
|
|
|
+ // Handle SSE response
|
|
|
+ reader := bufio.NewReader(resp.Body)
|
|
|
+ flusher, ok := w.(http.Flusher)
|
|
|
+ if !ok {
|
|
|
+ http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Create a context that cancels when the client disconnects
|
|
|
+ ctx, cancel := context.WithCancel(r.Context())
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ // Monitor client disconnection
|
|
|
+ go func() {
|
|
|
+ <-ctx.Done()
|
|
|
+ resp.Body.Close()
|
|
|
+ }()
|
|
|
+
|
|
|
+ for {
|
|
|
+ line, err := reader.ReadString('\n')
|
|
|
+ if err != nil {
|
|
|
+ if err == io.EOF {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Write the line to the client
|
|
|
+ fmt.Fprint(w, line)
|
|
|
+ flusher.Flush()
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ // Copy regular response body
|
|
|
+ _, _ = io.Copy(w, resp.Body)
|
|
|
+ }
|
|
|
+}
|