sse_test.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. package mcpproxy_test
  2. import (
  3. "context"
  4. "fmt"
  5. "net/http"
  6. "net/http/httptest"
  7. "strings"
  8. "testing"
  9. "time"
  10. "github.com/labring/aiproxy/core/common/mcpproxy"
  11. )
  12. type TestEndpointHandler struct{}
  13. func (h *TestEndpointHandler) NewEndpoint() (string, string) {
  14. return "test-session-id", "/message?sessionId=test-session-id"
  15. }
  16. func (h *TestEndpointHandler) LoadEndpoint(endpoint string) string {
  17. if strings.Contains(endpoint, "test-session-id") {
  18. return "test-session-id"
  19. }
  20. return ""
  21. }
  22. func TestProxySSEEndpoint(t *testing.T) {
  23. // Setup a mock backend server
  24. backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
  25. w.Header().Set("Content-Type", "text/event-stream")
  26. w.Header().Set("Cache-Control", "no-cache")
  27. w.Header().Set("Connection", "keep-alive")
  28. flusher, ok := w.(http.Flusher)
  29. if !ok {
  30. t.Fatal("Expected ResponseWriter to be a Flusher")
  31. }
  32. // Send an endpoint event
  33. fmt.Fprintf(w, "event: endpoint\n")
  34. fmt.Fprintf(w, "data: /message?sessionId=original-session-id\n\n")
  35. flusher.Flush()
  36. // Keep the connection open for a bit
  37. time.Sleep(100 * time.Millisecond)
  38. }))
  39. defer backendServer.Close()
  40. // Create the proxy
  41. store := mcpproxy.NewMemStore()
  42. handler := &TestEndpointHandler{}
  43. proxy := mcpproxy.NewProxy(backendServer.URL+"/sse", nil, store, handler)
  44. // Setup the proxy server
  45. proxyServer := httptest.NewServer(http.HandlerFunc(proxy.SSEHandler))
  46. defer proxyServer.Close()
  47. // Make a request to the proxy
  48. req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, proxyServer.URL, nil)
  49. if err != nil {
  50. t.Fatalf("Error making request to proxy: %v", err)
  51. }
  52. resp, err := http.DefaultClient.Do(req)
  53. if err != nil {
  54. t.Fatalf("Error making request to proxy: %v", err)
  55. }
  56. defer resp.Body.Close()
  57. if resp.StatusCode != http.StatusOK {
  58. t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
  59. }
  60. // Verify the session was stored
  61. endpoint, ok := store.Get("test-session-id")
  62. if !ok {
  63. t.Error("Session was not stored")
  64. }
  65. if !strings.Contains(endpoint, "/message?sessionId=original-session-id") {
  66. t.Errorf("Endpoint does not contain expected path, got: %s", endpoint)
  67. }
  68. }