sse_test.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. package mcpproxy_test
  2. import (
  3. "fmt"
  4. "net/http"
  5. "net/http/httptest"
  6. "strings"
  7. "testing"
  8. "time"
  9. "github.com/labring/aiproxy/core/common/mcpproxy"
  10. )
  // TestSessionManager is an in-memory session store used by the proxy test.
  // It always hands out the same session ID and keeps the mapping from session
  // ID to backend endpoint in a plain map. The map must be initialized by the
  // caller (Set writes into it without a nil check).
  // NOTE(review): the map is not guarded by a lock; the test is expected to
  // order its reads after the proxy's writes.
  type TestSessionManager struct {
  	m map[string]string
  }

  // New returns the fixed session ID "test-session-id" on every call.
  func (s *TestSessionManager) New() string {
  	const fixedSessionID = "test-session-id"
  	return fixedSessionID
  }

  // Set records endpoint as the backend endpoint for sessionID, replacing any
  // previous value.
  func (s *TestSessionManager) Set(sessionID, endpoint string) {
  	s.m[sessionID] = endpoint
  }

  // Get looks up the backend endpoint stored for sessionID; the boolean reports
  // whether an entry exists.
  func (s *TestSessionManager) Get(sessionID string) (string, bool) {
  	endpoint, found := s.m[sessionID]
  	return endpoint, found
  }

  // Delete deliberately does nothing: entries are kept so the test can still
  // inspect a session's mapping after the proxy is done with it.
  // NOTE(review): whether the proxy calls Delete when the stream closes is not
  // visible from this file.
  func (s *TestSessionManager) Delete(_ string) {}
  // TestEndpointHandler is a fixed-output endpoint handler for the proxy test:
  // every session is exposed to clients under the same message endpoint.
  type TestEndpointHandler struct{}

  // NewEndpoint ignores its argument and always returns the client-facing
  // message endpoint for the "test-session-id" session.
  func (*TestEndpointHandler) NewEndpoint(_ string) string {
  	return "/message?sessionId=test-session-id"
  }

  // LoadEndpoint maps a client-facing endpoint back to its session ID. Any
  // endpoint containing "test-session-id" yields that ID; anything else yields
  // the empty string.
  func (*TestEndpointHandler) LoadEndpoint(endpoint string) string {
  	const sessionID = "test-session-id"
  	if !strings.Contains(endpoint, sessionID) {
  		return ""
  	}
  	return sessionID
  }
  39. func TestProxySSEEndpoint(t *testing.T) {
  40. reqDone := make(chan struct{})
  41. // Setup a mock backend server
  42. backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
  43. w.Header().Set("Content-Type", "text/event-stream")
  44. w.Header().Set("Cache-Control", "no-cache")
  45. w.Header().Set("Connection", "keep-alive")
  46. flusher, ok := w.(http.Flusher)
  47. if !ok {
  48. t.Fatal("Expected ResponseWriter to be a Flusher")
  49. }
  50. // Send an endpoint event
  51. fmt.Fprintf(w, "event: endpoint\n")
  52. fmt.Fprintf(w, "data: /message?sessionId=original-session-id\n\n")
  53. flusher.Flush()
  54. close(reqDone)
  55. }))
  56. defer backendServer.Close()
  57. // Create the proxy
  58. store := &TestSessionManager{
  59. m: map[string]string{},
  60. }
  61. handler := &TestEndpointHandler{}
  62. proxy := mcpproxy.NewSSEProxy(backendServer.URL+"/sse", nil, store, handler)
  63. // Setup the proxy server
  64. proxyServer := httptest.NewServer(http.HandlerFunc(proxy.SSEHandler))
  65. defer proxyServer.Close()
  66. // Make a request to the proxy
  67. req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, proxyServer.URL, nil)
  68. if err != nil {
  69. t.Fatalf("Error making request to proxy: %v", err)
  70. }
  71. resp, err := http.DefaultClient.Do(req)
  72. if err != nil {
  73. t.Fatalf("Error making request to proxy: %v", err)
  74. }
  75. defer resp.Body.Close()
  76. if resp.StatusCode != http.StatusOK {
  77. t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
  78. }
  79. select {
  80. case <-time.NewTimer(time.Second).C:
  81. t.Error("timeout")
  82. return
  83. case <-reqDone:
  84. }
  85. // Verify the session was stored
  86. endpoint, ok := store.Get("test-session-id")
  87. if !ok {
  88. t.Error("Session was not stored")
  89. }
  90. if !strings.Contains(endpoint, "/message?sessionId=original-session-id") {
  91. t.Errorf("Endpoint does not contain expected path, got: %s", endpoint)
  92. }
  93. }