sse.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. package mcpproxy
  2. import (
  3. "bufio"
  4. "context"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/url"
  9. "strings"
  10. "time"
  11. )
  // EndpointProvider translates between proxy session IDs and the endpoint
// URLs that clients of the proxy send their requests to.
//
// NewEndpoint returns the client-facing endpoint for a newly created session;
// SSEHandler advertises it to the client in place of the backend's endpoint.
// LoadEndpoint is the reverse mapping: given the URL of an incoming request it
// returns the session ID, or "" when no session can be derived from it.
type EndpointProvider interface {
	NewEndpoint(newSession string) string
	LoadEndpoint(endpoint string) string
}
  16. // SSEAProxy represents the proxy object that handles SSE and HTTP requests
  17. type SSEAProxy struct {
  18. store SessionManager
  19. endpoint EndpointProvider
  20. backend string
  21. headers map[string]string
  22. }
  23. // NewSSEProxy creates a new proxy with the given backend and endpoint handler
  24. func NewSSEProxy(backend string, headers map[string]string, store SessionManager, endpoint EndpointProvider) *SSEAProxy {
  25. return &SSEAProxy{
  26. store: store,
  27. endpoint: endpoint,
  28. backend: backend,
  29. headers: headers,
  30. }
  31. }
  32. func (p *SSEAProxy) SSEHandler(w http.ResponseWriter, r *http.Request) {
  33. SSEHandler(w, r, p.store, p.endpoint, p.backend, p.headers)
  34. }
  35. func SSEHandler(
  36. w http.ResponseWriter,
  37. r *http.Request,
  38. store SessionManager,
  39. endpoint EndpointProvider,
  40. backend string,
  41. headers map[string]string,
  42. ) {
  43. // Create a request to the backend SSE endpoint
  44. req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, backend, nil)
  45. if err != nil {
  46. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  47. return
  48. }
  49. // Copy headers from original request
  50. for name, value := range headers {
  51. req.Header.Set(name, value)
  52. }
  53. // Set necessary headers for SSE
  54. req.Header.Set("Accept", "text/event-stream")
  55. req.Header.Set("Cache-Control", "no-cache")
  56. req.Header.Set("Connection", "keep-alive")
  57. // Make the request to the backend
  58. //nolint:bodyclose
  59. resp, err := http.DefaultClient.Do(req)
  60. if err != nil {
  61. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  62. return
  63. }
  64. defer resp.Body.Close()
  65. // Set SSE headers for the client response
  66. w.Header().Set("Content-Type", "text/event-stream")
  67. w.Header().Set("Cache-Control", "no-cache")
  68. w.Header().Set("Connection", "keep-alive")
  69. w.Header().Set("Access-Control-Allow-Origin", "*")
  70. // Create a context that cancels when the client disconnects
  71. ctx, cancel := context.WithCancel(r.Context())
  72. defer cancel()
  73. // Monitor client disconnection
  74. go func() {
  75. <-ctx.Done()
  76. resp.Body.Close()
  77. }()
  78. // Parse the SSE stream and extract sessionID
  79. reader := bufio.NewReader(resp.Body)
  80. flusher, ok := w.(http.Flusher)
  81. if !ok {
  82. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  83. return
  84. }
  85. for {
  86. line, err := reader.ReadString('\n')
  87. if err != nil {
  88. if err == io.EOF {
  89. break
  90. }
  91. return
  92. }
  93. // Write the line to the client
  94. fmt.Fprint(w, line)
  95. flusher.Flush()
  96. // Check if this is an endpoint event with sessionID
  97. if strings.HasPrefix(line, "event: endpoint") {
  98. // Next line should contain the data
  99. dataLine, err := reader.ReadString('\n')
  100. if err != nil {
  101. return
  102. }
  103. newSession := store.New()
  104. newEndpoint := endpoint.NewEndpoint(newSession)
  105. defer func() {
  106. store.Delete(newSession)
  107. }()
  108. // Extract sessionID from data line
  109. // Example: data: /message?sessionId=3088a771-7961-44e8-9bdf-21953889f694
  110. if strings.HasPrefix(dataLine, "data: ") {
  111. endpoint := strings.TrimSpace(strings.TrimPrefix(dataLine, "data: "))
  112. copyURL := *req.URL
  113. backendHostURL := &copyURL
  114. backendHostURL.Path = ""
  115. backendHostURL.RawQuery = ""
  116. store.Set(newSession, backendHostURL.String()+endpoint)
  117. } else {
  118. break
  119. }
  120. // Write the data line to the client
  121. _, _ = fmt.Fprintf(w, "data: %s\n", newEndpoint)
  122. flusher.Flush()
  123. }
  124. }
  125. }
  126. func (p *SSEAProxy) ProxyHandler(w http.ResponseWriter, r *http.Request) {
  127. SSEProxyHandler(w, r, p.store, p.endpoint)
  128. }
  129. func SSEProxyHandler(
  130. w http.ResponseWriter,
  131. r *http.Request,
  132. store SessionManager,
  133. endpoint EndpointProvider,
  134. ) {
  135. // Extract sessionID from the request
  136. sessionID := endpoint.LoadEndpoint(r.URL.String())
  137. if sessionID == "" {
  138. http.Error(w, "Missing sessionId", http.StatusBadRequest)
  139. return
  140. }
  141. // Look up the backend endpoint
  142. backendEndpoint, ok := store.Get(sessionID)
  143. if !ok {
  144. http.Error(w, "Invalid or expired sessionId", http.StatusNotFound)
  145. return
  146. }
  147. u, err := url.Parse(backendEndpoint)
  148. if err != nil || (u.Scheme != "http" && u.Scheme != "https") {
  149. http.Error(w, "Invalid backend", http.StatusBadRequest)
  150. return
  151. }
  152. // Create a request to the backend
  153. req, err := http.NewRequestWithContext(r.Context(), r.Method, backendEndpoint, r.Body)
  154. if err != nil {
  155. http.Error(w, "Failed to create backend request", http.StatusInternalServerError)
  156. return
  157. }
  158. // Copy headers from original request
  159. for name, values := range r.Header {
  160. for _, value := range values {
  161. req.Header.Add(name, value)
  162. }
  163. }
  164. // Make the request to the backend
  165. client := &http.Client{
  166. Timeout: time.Second * 30,
  167. }
  168. resp, err := client.Do(req)
  169. if err != nil {
  170. http.Error(w, "Failed to connect to backend", http.StatusInternalServerError)
  171. return
  172. }
  173. defer resp.Body.Close()
  174. // Copy response headers
  175. for name, values := range resp.Header {
  176. for _, value := range values {
  177. w.Header().Add(name, value)
  178. }
  179. }
  180. // Set response status code
  181. w.WriteHeader(resp.StatusCode)
  182. // Copy response body
  183. _, _ = io.Copy(w, resp.Body)
  184. }