Răsfoiți Sursa

Merge pull request #3333 from seefs001/fix/channel-affinity-disable

fix: honor channel affinity skip-retry when channel is disabled
Calcium-Ion 2 săptămâni în urmă
părinte
comite
c667e4706a

+ 7 - 2
middleware/distributor.go

@@ -101,8 +101,13 @@ func Distribute() func(c *gin.Context) {
 
 				if preferredChannelID, found := service.GetPreferredChannelByAffinity(c, modelRequest.Model, usingGroup); found {
 					preferred, err := model.CacheGetChannel(preferredChannelID)
-					if err == nil && preferred != nil && preferred.Status == common.ChannelStatusEnabled {
-						if usingGroup == "auto" {
+					if err == nil && preferred != nil {
+						if preferred.Status != common.ChannelStatusEnabled {
+							if service.ShouldSkipRetryAfterChannelAffinityFailure(c) {
+								abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorChannelDisabled))
+								return
+							}
+						} else if usingGroup == "auto" {
 							userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
 							autoGroups := service.GetUserAutoGroup(userGroup)
 							for _, g := range autoGroups {

+ 7 - 4
service/channel_affinity.go

@@ -610,14 +610,17 @@ func ShouldSkipRetryAfterChannelAffinityFailure(c *gin.Context) bool {
 		return false
 	}
 	v, ok := c.Get(ginKeyChannelAffinitySkipRetry)
-	if !ok {
-		return false
+	if ok {
+		b, ok := v.(bool)
+		if ok {
+			return b
+		}
 	}
-	b, ok := v.(bool)
+	meta, ok := getChannelAffinityMeta(c)
 	if !ok {
 		return false
 	}
-	return b
+	return meta.SkipRetry
 }
 
 func MarkChannelAffinityUsed(c *gin.Context, selectedGroup string, channelID int) {

+ 60 - 0
service/channel_affinity_template_test.go

@@ -116,6 +116,66 @@ func TestApplyChannelAffinityOverrideTemplate_MergeOperations(t *testing.T) {
 	require.Equal(t, "trim_prefix", secondOp["mode"])
 }
 
+func TestShouldSkipRetryAfterChannelAffinityFailure(t *testing.T) {
+	tests := []struct {
+		name string
+		ctx  func() *gin.Context
+		want bool
+	}{
+		{
+			name: "nil context",
+			ctx: func() *gin.Context {
+				return nil
+			},
+			want: false,
+		},
+		{
+			name: "explicit skip retry flag in context",
+			ctx: func() *gin.Context {
+				ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
+					RuleName:   "rule-explicit-flag",
+					SkipRetry:  false,
+					UsingGroup: "default",
+					ModelName:  "gpt-5",
+				})
+				ctx.Set(ginKeyChannelAffinitySkipRetry, true)
+				return ctx
+			},
+			want: true,
+		},
+		{
+			name: "fallback to matched rule meta",
+			ctx: func() *gin.Context {
+				return buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
+					RuleName:   "rule-skip-retry",
+					SkipRetry:  true,
+					UsingGroup: "default",
+					ModelName:  "gpt-5",
+				})
+			},
+			want: true,
+		},
+		{
+			name: "no flag and no skip retry meta",
+			ctx: func() *gin.Context {
+				return buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
+					RuleName:   "rule-no-skip-retry",
+					SkipRetry:  false,
+					UsingGroup: "default",
+					ModelName:  "gpt-5",
+				})
+			},
+			want: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			require.Equal(t, tt.want, ShouldSkipRetryAfterChannelAffinityFailure(tt.ctx()))
+		})
+	}
+}
+
 func TestChannelAffinityHitCodexTemplatePassHeadersEffective(t *testing.T) {
 	gin.SetMode(gin.TestMode)