package channel

import (
	"net/http"
	"net/http/httptest"
	"testing"

	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/require"
)

// newHeaderOverrideTestContext builds a gin test context whose request is a
// POST to /v1/chat/completions carrying the client header
// "X-Trace-Id: trace-123", which the header-override tests below resolve
// (or deliberately skip).
func newHeaderOverrideTestContext(t *testing.T) *gin.Context {
	t.Helper()

	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
	ctx.Request.Header.Set("X-Trace-Id", "trace-123")
	return ctx
}

// TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules verifies that,
// for a channel test request, a "*" override rule yields no headers at all.
func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) {
	// gin.SetMode writes package-level state; call it before t.Parallel() so
	// it runs in the sequential phase instead of racing with the other
	// parallel tests.
	gin.SetMode(gin.TestMode)
	t.Parallel()

	ctx := newHeaderOverrideTestContext(t)
	info := &relaycommon.RelayInfo{
		IsChannelTest: true,
		ChannelMeta: &relaycommon.ChannelMeta{
			HeadersOverride: map[string]any{
				"*": "",
			},
		},
	}

	headers, err := processHeaderOverride(info, ctx)
	require.NoError(t, err)
	require.Empty(t, headers)
}

// TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder verifies
// that, for a channel test request, an override whose value is a
// {client_header:...} placeholder is not emitted.
func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) {
	// See TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules for why
	// SetMode precedes t.Parallel().
	gin.SetMode(gin.TestMode)
	t.Parallel()

	ctx := newHeaderOverrideTestContext(t)
	info := &relaycommon.RelayInfo{
		IsChannelTest: true,
		ChannelMeta: &relaycommon.ChannelMeta{
			HeadersOverride: map[string]any{
				"X-Upstream-Trace": "{client_header:X-Trace-Id}",
			},
		},
	}

	headers, err := processHeaderOverride(info, ctx)
	require.NoError(t, err)
	require.NotContains(t, headers, "X-Upstream-Trace")
}

// TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder verifies
// that, for a regular (non-test) request, a {client_header:...} placeholder
// is resolved to the value of the named client request header.
func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) {
	// See TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules for why
	// SetMode precedes t.Parallel().
	gin.SetMode(gin.TestMode)
	t.Parallel()

	ctx := newHeaderOverrideTestContext(t)
	info := &relaycommon.RelayInfo{
		IsChannelTest: false,
		ChannelMeta: &relaycommon.ChannelMeta{
			HeadersOverride: map[string]any{
				"X-Upstream-Trace": "{client_header:X-Trace-Id}",
			},
		},
	}

	headers, err := processHeaderOverride(info, ctx)
	require.NoError(t, err)
	require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
}