sse.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. package mcpproxy
  2. import (
  3. "bufio"
  4. "context"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/url"
  9. "strings"
  10. "time"
  11. )
  12. type EndpointProvider interface {
  13. NewEndpoint() (newSession string, newEndpoint string)
  14. LoadEndpoint(endpoint string) (session string)
  15. }
  16. // Proxy represents the proxy object that handles SSE and HTTP requests
  17. type Proxy struct {
  18. store SessionManager
  19. endpoint EndpointProvider
  20. backend string
  21. headers map[string]string
  22. }
  23. // NewProxy creates a new proxy with the given backend and endpoint handler
  24. func NewProxy(backend string, headers map[string]string, store SessionManager, endpoint EndpointProvider) *Proxy {
  25. return &Proxy{
  26. store: store,
  27. endpoint: endpoint,
  28. backend: backend,
  29. headers: headers,
  30. }
  31. }
  32. func (p *Proxy) SSEHandler(w http.ResponseWriter, r *http.Request) {
  33. SSEHandler(w, r, p.store, p.endpoint, p.backend, p.headers)
  34. }
  35. func SSEHandler(
  36. w http.ResponseWriter,
  37. r *http.Request,
  38. store SessionManager,
  39. endpoint EndpointProvider,
  40. backend string,
  41. headers map[string]string,
  42. ) {
  43. // Create a request to the backend SSE endpoint
  44. req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, backend, nil)
  45. if err != nil {
  46. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  47. return
  48. }
  49. // Copy headers from original request
  50. for name, value := range headers {
  51. req.Header.Set(name, value)
  52. }
  53. // Set necessary headers for SSE
  54. req.Header.Set("Accept", "text/event-stream")
  55. req.Header.Set("Cache-Control", "no-cache")
  56. req.Header.Set("Connection", "keep-alive")
  57. // Make the request to the backend
  58. //nolint:bodyclose
  59. resp, err := http.DefaultClient.Do(req)
  60. if err != nil {
  61. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  62. return
  63. }
  64. defer resp.Body.Close()
  65. // Set SSE headers for the client response
  66. w.Header().Set("Content-Type", "text/event-stream")
  67. w.Header().Set("Cache-Control", "no-cache")
  68. w.Header().Set("Connection", "keep-alive")
  69. w.Header().Set("Access-Control-Allow-Origin", "*")
  70. // Create a context that cancels when the client disconnects
  71. ctx, cancel := context.WithCancel(r.Context())
  72. defer cancel()
  73. // Monitor client disconnection
  74. go func() {
  75. <-ctx.Done()
  76. resp.Body.Close()
  77. }()
  78. // Parse the SSE stream and extract sessionID
  79. reader := bufio.NewReader(resp.Body)
  80. flusher, ok := w.(http.Flusher)
  81. if !ok {
  82. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  83. return
  84. }
  85. for {
  86. line, err := reader.ReadString('\n')
  87. if err != nil {
  88. if err == io.EOF {
  89. break
  90. }
  91. return
  92. }
  93. // Write the line to the client
  94. fmt.Fprint(w, line)
  95. flusher.Flush()
  96. // Check if this is an endpoint event with sessionID
  97. if strings.HasPrefix(line, "event: endpoint") {
  98. // Next line should contain the data
  99. dataLine, err := reader.ReadString('\n')
  100. if err != nil {
  101. return
  102. }
  103. newSession, newEndpoint := endpoint.NewEndpoint()
  104. defer func() {
  105. store.Delete(newSession)
  106. }()
  107. // Extract sessionID from data line
  108. // Example: data: /message?sessionId=3088a771-7961-44e8-9bdf-21953889f694
  109. if strings.HasPrefix(dataLine, "data: ") {
  110. endpoint := strings.TrimSpace(strings.TrimPrefix(dataLine, "data: "))
  111. copyURL := *req.URL
  112. backendHostURL := &copyURL
  113. backendHostURL.Path = ""
  114. backendHostURL.RawQuery = ""
  115. store.Set(newSession, backendHostURL.String()+endpoint)
  116. } else {
  117. break
  118. }
  119. // Write the data line to the client
  120. fmt.Fprintf(w, "data: %s\n", newEndpoint)
  121. flusher.Flush()
  122. }
  123. }
  124. }
  125. func (p *Proxy) ProxyHandler(w http.ResponseWriter, r *http.Request) {
  126. ProxyHandler(w, r, p.store, p.endpoint)
  127. }
  128. func ProxyHandler(
  129. w http.ResponseWriter,
  130. r *http.Request,
  131. store SessionManager,
  132. endpoint EndpointProvider,
  133. ) {
  134. // Extract sessionID from the request
  135. sessionID := endpoint.LoadEndpoint(r.URL.String())
  136. if sessionID == "" {
  137. http.Error(w, "Missing sessionId", http.StatusBadRequest)
  138. return
  139. }
  140. // Look up the backend endpoint
  141. backendEndpoint, ok := store.Get(sessionID)
  142. if !ok {
  143. http.Error(w, "Invalid or expired sessionId", http.StatusNotFound)
  144. return
  145. }
  146. u, err := url.Parse(backendEndpoint)
  147. if err != nil || (u.Scheme != "http" && u.Scheme != "https") {
  148. http.Error(w, "Invalid backend", http.StatusBadRequest)
  149. return
  150. }
  151. // Create a request to the backend
  152. req, err := http.NewRequestWithContext(r.Context(), r.Method, backendEndpoint, r.Body)
  153. if err != nil {
  154. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  155. return
  156. }
  157. // Copy headers from original request
  158. for name, values := range r.Header {
  159. for _, value := range values {
  160. req.Header.Add(name, value)
  161. }
  162. }
  163. // Make the request to the backend
  164. client := &http.Client{
  165. Timeout: time.Second * 30,
  166. }
  167. resp, err := client.Do(req)
  168. if err != nil {
  169. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  170. return
  171. }
  172. defer resp.Body.Close()
  173. // Copy response headers
  174. for name, values := range resp.Header {
  175. for _, value := range values {
  176. w.Header().Add(name, value)
  177. }
  178. }
  179. // Set response status code
  180. w.WriteHeader(resp.StatusCode)
  181. // Copy response body
  182. _, _ = io.Copy(w, resp.Body)
  183. }