| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- package mcpproxy
- import (
- "context"
- "errors"
- "fmt"
- "net/http"
- "time"
- "github.com/bytedance/sonic"
- mcpservers "github.com/labring/aiproxy/mcp-servers"
- "github.com/mark3labs/mcp-go/mcp"
- )
- // SSEServer implements a Server-Sent Events (SSE) based MCP server.
- // It provides real-time communication capabilities over HTTP using the SSE protocol.
- type SSEServer struct {
- server mcpservers.Server
- messageEndpoint string
- eventQueue chan string
- keepAlive bool
- keepAliveInterval time.Duration
- }
- // SSEOption defines a function type for configuring SSEServer
- type SSEOption func(*SSEServer)
- // WithMessageEndpoint sets the message endpoint path
- func WithMessageEndpoint(endpoint string) SSEOption {
- return func(s *SSEServer) {
- s.messageEndpoint = endpoint
- }
- }
- func WithKeepAliveInterval(keepAliveInterval time.Duration) SSEOption {
- return func(s *SSEServer) {
- s.keepAlive = true
- s.keepAliveInterval = keepAliveInterval
- }
- }
- func WithKeepAlive(keepAlive bool) SSEOption {
- return func(s *SSEServer) {
- s.keepAlive = keepAlive
- }
- }
- // NewSSEServer creates a new SSE server instance with the given MCP server and options.
- func NewSSEServer(server mcpservers.Server, opts ...SSEOption) *SSEServer {
- s := &SSEServer{
- server: server,
- messageEndpoint: "/message",
- keepAlive: false,
- keepAliveInterval: 30 * time.Second,
- eventQueue: make(chan string, 100),
- }
- // Apply all options
- for _, opt := range opts {
- opt(s)
- }
- return s
- }
- // handleSSE handles incoming SSE connection requests.
- // It sets up appropriate headers and creates a new session for the client.
- func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodGet {
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
- return
- }
- w.Header().Set("Content-Type", "text/event-stream")
- w.Header().Set("Cache-Control", "no-cache")
- w.Header().Set("Connection", "keep-alive")
- w.Header().Set("Access-Control-Allow-Origin", "*")
- flusher, ok := w.(http.Flusher)
- if !ok {
- http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
- return
- }
- // Start keep alive : ping
- if s.keepAlive {
- go func() {
- ticker := time.NewTicker(s.keepAliveInterval)
- defer ticker.Stop()
- id := 0
- for {
- id++
- select {
- case <-ticker.C:
- message := mcp.JSONRPCRequest{
- JSONRPC: "2.0",
- ID: mcp.NewRequestId(id),
- Request: mcp.Request{
- Method: "ping",
- },
- }
- messageBytes, _ := sonic.Marshal(message)
- pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes)
- select {
- case s.eventQueue <- pingMsg:
- case <-r.Context().Done():
- return
- }
- case <-r.Context().Done():
- return
- }
- }
- }()
- }
- // Send the initial endpoint event
- fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.messageEndpoint)
- flusher.Flush()
- // Main event loop - this runs in the HTTP handler goroutine
- for {
- select {
- case event := <-s.eventQueue:
- // Write the event to the response
- fmt.Fprint(w, event)
- flusher.Flush()
- case <-r.Context().Done():
- return
- }
- }
- }
- // handleMessage processes incoming JSON-RPC messages from clients and sends responses
- // back through both the SSE connection and HTTP response.
- func (s *SSEServer) HandleMessage(ctx context.Context, req []byte) error {
- // Process message through MCPServer
- response := s.server.HandleMessage(ctx, req)
- // Only send response if there is one (not for notifications)
- if response != nil {
- var message string
- eventData, err := sonic.Marshal(response)
- if err != nil {
- message = "event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n"
- } else {
- message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData)
- }
- // Queue the event for sending via SSE
- select {
- case s.eventQueue <- message:
- // Event queued successfully
- default:
- // Queue is full
- return errors.New("event queue is full")
- }
- }
- return nil
- }
|