2
0

sse.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. package mcpproxy
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "time"
  8. "github.com/bytedance/sonic"
  9. mcpservers "github.com/labring/aiproxy/mcp-servers"
  10. "github.com/mark3labs/mcp-go/mcp"
  11. )
  12. // SSEServer implements a Server-Sent Events (SSE) based MCP server.
  13. // It provides real-time communication capabilities over HTTP using the SSE protocol.
  14. type SSEServer struct {
  15. server mcpservers.Server
  16. messageEndpoint string
  17. eventQueue chan string
  18. keepAlive bool
  19. keepAliveInterval time.Duration
  20. }
  21. // SSEOption defines a function type for configuring SSEServer
  22. type SSEOption func(*SSEServer)
  23. // WithMessageEndpoint sets the message endpoint path
  24. func WithMessageEndpoint(endpoint string) SSEOption {
  25. return func(s *SSEServer) {
  26. s.messageEndpoint = endpoint
  27. }
  28. }
  29. func WithKeepAliveInterval(keepAliveInterval time.Duration) SSEOption {
  30. return func(s *SSEServer) {
  31. s.keepAlive = true
  32. s.keepAliveInterval = keepAliveInterval
  33. }
  34. }
  35. func WithKeepAlive(keepAlive bool) SSEOption {
  36. return func(s *SSEServer) {
  37. s.keepAlive = keepAlive
  38. }
  39. }
  40. // NewSSEServer creates a new SSE server instance with the given MCP server and options.
  41. func NewSSEServer(server mcpservers.Server, opts ...SSEOption) *SSEServer {
  42. s := &SSEServer{
  43. server: server,
  44. messageEndpoint: "/message",
  45. keepAlive: false,
  46. keepAliveInterval: 30 * time.Second,
  47. eventQueue: make(chan string, 100),
  48. }
  49. // Apply all options
  50. for _, opt := range opts {
  51. opt(s)
  52. }
  53. return s
  54. }
  55. // handleSSE handles incoming SSE connection requests.
  56. // It sets up appropriate headers and creates a new session for the client.
  57. func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  58. if r.Method != http.MethodGet {
  59. http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
  60. return
  61. }
  62. w.Header().Set("Content-Type", "text/event-stream")
  63. w.Header().Set("Cache-Control", "no-cache")
  64. w.Header().Set("Connection", "keep-alive")
  65. w.Header().Set("Access-Control-Allow-Origin", "*")
  66. flusher, ok := w.(http.Flusher)
  67. if !ok {
  68. http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
  69. return
  70. }
  71. // Start keep alive : ping
  72. if s.keepAlive {
  73. go func() {
  74. ticker := time.NewTicker(s.keepAliveInterval)
  75. defer ticker.Stop()
  76. id := 0
  77. for {
  78. id++
  79. select {
  80. case <-ticker.C:
  81. message := mcp.JSONRPCRequest{
  82. JSONRPC: "2.0",
  83. ID: mcp.NewRequestId(id),
  84. Request: mcp.Request{
  85. Method: "ping",
  86. },
  87. }
  88. messageBytes, _ := sonic.Marshal(message)
  89. pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes)
  90. select {
  91. case s.eventQueue <- pingMsg:
  92. case <-r.Context().Done():
  93. return
  94. }
  95. case <-r.Context().Done():
  96. return
  97. }
  98. }
  99. }()
  100. }
  101. // Send the initial endpoint event
  102. fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.messageEndpoint)
  103. flusher.Flush()
  104. // Main event loop - this runs in the HTTP handler goroutine
  105. for {
  106. select {
  107. case event := <-s.eventQueue:
  108. // Write the event to the response
  109. fmt.Fprint(w, event)
  110. flusher.Flush()
  111. case <-r.Context().Done():
  112. return
  113. }
  114. }
  115. }
  116. // handleMessage processes incoming JSON-RPC messages from clients and sends responses
  117. // back through both the SSE connection and HTTP response.
  118. func (s *SSEServer) HandleMessage(ctx context.Context, req []byte) error {
  119. // Process message through MCPServer
  120. response := s.server.HandleMessage(ctx, req)
  121. // Only send response if there is one (not for notifications)
  122. if response != nil {
  123. var message string
  124. eventData, err := sonic.Marshal(response)
  125. if err != nil {
  126. message = "event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n"
  127. } else {
  128. message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData)
  129. }
  130. // Queue the event for sending via SSE
  131. select {
  132. case s.eventQueue <- message:
  133. // Event queued successfully
  134. default:
  135. // Queue is full
  136. return errors.New("event queue is full")
  137. }
  138. }
  139. return nil
  140. }