sse_test.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. package mcpproxy_test
  2. import (
  3. "context"
  4. "fmt"
  5. "net/http"
  6. "net/http/httptest"
  7. "strings"
  8. "testing"
  9. "time"
  10. "github.com/labring/aiproxy/core/common/mcpproxy"
  11. )
  12. type TestSessionManager struct {
  13. m map[string]string
  14. }
  15. func (t *TestSessionManager) New() string {
  16. return "test-session-id"
  17. }
  18. // Set stores a sessionID and its corresponding backend endpoint
  19. func (t *TestSessionManager) Set(sessionID string, endpoint string) {
  20. t.m[sessionID] = endpoint
  21. }
  22. // Get retrieves the backend endpoint for a sessionID
  23. func (t *TestSessionManager) Get(sessionID string) (string, bool) {
  24. v, ok := t.m[sessionID]
  25. return v, ok
  26. }
  27. // Delete removes a sessionID from the store
  28. func (t *TestSessionManager) Delete(string) {
  29. }
  30. type TestEndpointHandler struct{}
  31. func (h *TestEndpointHandler) NewEndpoint(_ string) string {
  32. return "/message?sessionId=test-session-id"
  33. }
  34. func (h *TestEndpointHandler) LoadEndpoint(endpoint string) string {
  35. if strings.Contains(endpoint, "test-session-id") {
  36. return "test-session-id"
  37. }
  38. return ""
  39. }
  40. func TestProxySSEEndpoint(t *testing.T) {
  41. reqDone := make(chan struct{})
  42. // Setup a mock backend server
  43. backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
  44. w.Header().Set("Content-Type", "text/event-stream")
  45. w.Header().Set("Cache-Control", "no-cache")
  46. w.Header().Set("Connection", "keep-alive")
  47. flusher, ok := w.(http.Flusher)
  48. if !ok {
  49. t.Fatal("Expected ResponseWriter to be a Flusher")
  50. }
  51. // Send an endpoint event
  52. fmt.Fprintf(w, "event: endpoint\n")
  53. fmt.Fprintf(w, "data: /message?sessionId=original-session-id\n\n")
  54. flusher.Flush()
  55. close(reqDone)
  56. }))
  57. defer backendServer.Close()
  58. // Create the proxy
  59. store := &TestSessionManager{
  60. m: map[string]string{},
  61. }
  62. handler := &TestEndpointHandler{}
  63. proxy := mcpproxy.NewSSEProxy(backendServer.URL+"/sse", nil, store, handler)
  64. // Setup the proxy server
  65. proxyServer := httptest.NewServer(http.HandlerFunc(proxy.SSEHandler))
  66. defer proxyServer.Close()
  67. // Make a request to the proxy
  68. req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, proxyServer.URL, nil)
  69. if err != nil {
  70. t.Fatalf("Error making request to proxy: %v", err)
  71. }
  72. resp, err := http.DefaultClient.Do(req)
  73. if err != nil {
  74. t.Fatalf("Error making request to proxy: %v", err)
  75. }
  76. defer resp.Body.Close()
  77. if resp.StatusCode != http.StatusOK {
  78. t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
  79. }
  80. select {
  81. case <-time.NewTimer(time.Second).C:
  82. t.Error("timeout")
  83. return
  84. case <-reqDone:
  85. }
  86. // Verify the session was stored
  87. endpoint, ok := store.Get("test-session-id")
  88. if !ok {
  89. t.Error("Session was not stored")
  90. }
  91. if !strings.Contains(endpoint, "/message?sessionId=original-session-id") {
  92. t.Errorf("Endpoint does not contain expected path, got: %s", endpoint)
  93. }
  94. }