|
|
@@ -12,6 +12,29 @@ import (
|
|
|
"github.com/labring/aiproxy/core/common/mcpproxy"
|
|
|
)
|
|
|
|
|
|
+type TestSessionManager struct {
|
|
|
+ m map[string]string
|
|
|
+}
|
|
|
+
|
|
|
+func (t *TestSessionManager) New() string {
|
|
|
+ return "test-session-id"
|
|
|
+}
|
|
|
+
|
|
|
+// Set stores a sessionID and its corresponding backend endpoint
|
|
|
+func (t *TestSessionManager) Set(sessionID string, endpoint string) {
|
|
|
+ t.m[sessionID] = endpoint
|
|
|
+}
|
|
|
+
|
|
|
+// Get retrieves the backend endpoint for a sessionID
|
|
|
+func (t *TestSessionManager) Get(sessionID string) (string, bool) {
|
|
|
+ v, ok := t.m[sessionID]
|
|
|
+ return v, ok
|
|
|
+}
|
|
|
+
|
|
|
+// Delete removes a sessionID from the store
|
|
|
+func (t *TestSessionManager) Delete(string) {
|
|
|
+}
|
|
|
+
|
|
|
type TestEndpointHandler struct{}
|
|
|
|
|
|
func (h *TestEndpointHandler) NewEndpoint(_ string) string {
|
|
|
@@ -26,6 +49,7 @@ func (h *TestEndpointHandler) LoadEndpoint(endpoint string) string {
|
|
|
}
|
|
|
|
|
|
func TestProxySSEEndpoint(t *testing.T) {
|
|
|
+ reqDone := make(chan struct{})
|
|
|
// Setup a mock backend server
|
|
|
backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
|
@@ -42,13 +66,14 @@ func TestProxySSEEndpoint(t *testing.T) {
|
|
|
fmt.Fprintf(w, "data: /message?sessionId=original-session-id\n\n")
|
|
|
flusher.Flush()
|
|
|
|
|
|
- // Keep the connection open for a bit
|
|
|
- time.Sleep(100 * time.Millisecond)
|
|
|
+ close(reqDone)
|
|
|
}))
|
|
|
defer backendServer.Close()
|
|
|
|
|
|
// Create the proxy
|
|
|
- store := mcpproxy.NewMemStore()
|
|
|
+ store := &TestSessionManager{
|
|
|
+ m: map[string]string{},
|
|
|
+ }
|
|
|
handler := &TestEndpointHandler{}
|
|
|
proxy := mcpproxy.NewSSEProxy(backendServer.URL+"/sse", nil, store, handler)
|
|
|
|
|
|
@@ -70,6 +95,13 @@ func TestProxySSEEndpoint(t *testing.T) {
|
|
|
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
|
|
|
}
|
|
|
|
|
|
+ select {
|
|
|
+ case <-time.NewTimer(time.Second).C:
|
|
|
+ t.Error("timeout")
|
|
|
+ return
|
|
|
+ case <-reqDone:
|
|
|
+ }
|
|
|
+
|
|
|
// Verify the session was stored
|
|
|
endpoint, ok := store.Get("test-session-id")
|
|
|
if !ok {
|