|
|
@@ -0,0 +1,81 @@
|
|
|
+package channel
|
|
|
+
|
|
|
+import (
|
|
|
+ "net/http"
|
|
|
+ "net/http/httptest"
|
|
|
+ "testing"
|
|
|
+
|
|
|
+ relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
|
+ "github.com/gin-gonic/gin"
|
|
|
+ "github.com/stretchr/testify/require"
|
|
|
+)
|
|
|
+
|
|
|
+func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ gin.SetMode(gin.TestMode)
|
|
|
+ recorder := httptest.NewRecorder()
|
|
|
+ ctx, _ := gin.CreateTestContext(recorder)
|
|
|
+ ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
|
|
+ ctx.Request.Header.Set("X-Trace-Id", "trace-123")
|
|
|
+
|
|
|
+ info := &relaycommon.RelayInfo{
|
|
|
+ IsChannelTest: true,
|
|
|
+ ChannelMeta: &relaycommon.ChannelMeta{
|
|
|
+ HeadersOverride: map[string]any{
|
|
|
+ "*": "",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ headers, err := processHeaderOverride(info, ctx)
|
|
|
+ require.NoError(t, err)
|
|
|
+ require.Empty(t, headers)
|
|
|
+}
|
|
|
+
|
|
|
+func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ gin.SetMode(gin.TestMode)
|
|
|
+ recorder := httptest.NewRecorder()
|
|
|
+ ctx, _ := gin.CreateTestContext(recorder)
|
|
|
+ ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
|
|
+ ctx.Request.Header.Set("X-Trace-Id", "trace-123")
|
|
|
+
|
|
|
+ info := &relaycommon.RelayInfo{
|
|
|
+ IsChannelTest: true,
|
|
|
+ ChannelMeta: &relaycommon.ChannelMeta{
|
|
|
+ HeadersOverride: map[string]any{
|
|
|
+ "X-Upstream-Trace": "{client_header:X-Trace-Id}",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ headers, err := processHeaderOverride(info, ctx)
|
|
|
+ require.NoError(t, err)
|
|
|
+ _, ok := headers["X-Upstream-Trace"]
|
|
|
+ require.False(t, ok)
|
|
|
+}
|
|
|
+
|
|
|
+func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ gin.SetMode(gin.TestMode)
|
|
|
+ recorder := httptest.NewRecorder()
|
|
|
+ ctx, _ := gin.CreateTestContext(recorder)
|
|
|
+ ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
|
|
+ ctx.Request.Header.Set("X-Trace-Id", "trace-123")
|
|
|
+
|
|
|
+ info := &relaycommon.RelayInfo{
|
|
|
+ IsChannelTest: false,
|
|
|
+ ChannelMeta: &relaycommon.ChannelMeta{
|
|
|
+ HeadersOverride: map[string]any{
|
|
|
+ "X-Upstream-Trace": "{client_header:X-Trace-Id}",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ headers, err := processHeaderOverride(info, ctx)
|
|
|
+ require.NoError(t, err)
|
|
|
+ require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
|
|
|
+}
|