فهرست منبع

fix: ignore header passthrough during channel tests

Seefs 15 ساعت پیش
والد
کامیت
c78b37662b
2فایلهای تغییر یافته به همراه112 افزوده شده و 26 حذف شده
  1. 31 26
      relay/channel/api_request.go
  2. 81 0
      relay/channel/api_request_test.go

+ 31 - 26
relay/channel/api_request.go

@@ -171,35 +171,37 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
 
 	passAll := false
 	var passthroughRegex []*regexp.Regexp
-	for k := range info.HeadersOverride {
-		key := strings.TrimSpace(k)
-		if key == "" {
-			continue
-		}
-		if key == headerPassthroughAllKey {
-			passAll = true
-			continue
-		}
+	if !info.IsChannelTest {
+		for k := range info.HeadersOverride {
+			key := strings.TrimSpace(k)
+			if key == "" {
+				continue
+			}
+			if key == headerPassthroughAllKey {
+				passAll = true
+				continue
+			}
 
-		lower := strings.ToLower(key)
-		var pattern string
-		switch {
-		case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
-			pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
-		case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
-			pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
-		default:
-			continue
-		}
+			lower := strings.ToLower(key)
+			var pattern string
+			switch {
+			case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
+				pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
+			case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
+				pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
+			default:
+				continue
+			}
 
-		if pattern == "" {
-			return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
-		}
-		compiled, err := getHeaderPassthroughRegex(pattern)
-		if err != nil {
-			return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
+			if pattern == "" {
+				return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
+			}
+			compiled, err := getHeaderPassthroughRegex(pattern)
+			if err != nil {
+				return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
+			}
+			passthroughRegex = append(passthroughRegex, compiled)
 		}
-		passthroughRegex = append(passthroughRegex, compiled)
 	}
 
 	if passAll || len(passthroughRegex) > 0 {
@@ -243,6 +245,9 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
 		if !ok {
 			return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
 		}
+		if info.IsChannelTest && strings.HasPrefix(strings.TrimSpace(str), clientHeaderPlaceholderPrefix) {
+			continue
+		}
 
 		value, include, err := applyHeaderOverridePlaceholders(str, c, info.ApiKey)
 		if err != nil {

+ 81 - 0
relay/channel/api_request_test.go

@@ -0,0 +1,81 @@
+package channel
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/require"
+)
+
+func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	recorder := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(recorder)
+	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+	ctx.Request.Header.Set("X-Trace-Id", "trace-123")
+
+	info := &relaycommon.RelayInfo{
+		IsChannelTest: true,
+		ChannelMeta: &relaycommon.ChannelMeta{
+			HeadersOverride: map[string]any{
+				"*": "",
+			},
+		},
+	}
+
+	headers, err := processHeaderOverride(info, ctx)
+	require.NoError(t, err)
+	require.Empty(t, headers)
+}
+
+func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	recorder := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(recorder)
+	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+	ctx.Request.Header.Set("X-Trace-Id", "trace-123")
+
+	info := &relaycommon.RelayInfo{
+		IsChannelTest: true,
+		ChannelMeta: &relaycommon.ChannelMeta{
+			HeadersOverride: map[string]any{
+				"X-Upstream-Trace": "{client_header:X-Trace-Id}",
+			},
+		},
+	}
+
+	headers, err := processHeaderOverride(info, ctx)
+	require.NoError(t, err)
+	_, ok := headers["X-Upstream-Trace"]
+	require.False(t, ok)
+}
+
+func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	recorder := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(recorder)
+	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+	ctx.Request.Header.Set("X-Trace-Id", "trace-123")
+
+	info := &relaycommon.RelayInfo{
+		IsChannelTest: false,
+		ChannelMeta: &relaycommon.ChannelMeta{
+			HeadersOverride: map[string]any{
+				"X-Upstream-Trace": "{client_header:X-Trace-Id}",
+			},
+		},
+	}
+
+	headers, err := processHeaderOverride(info, ctx)
+	require.NoError(t, err)
+	require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
+}