Kaynağa Gözat

feat: ignore all failed channel when last retry (#343)

* feat: ignore all failed channel when last retry

* fix: ci lint
zijiren 4 ay önce
ebeveyn
işleme
2c6ed9993f

+ 52 - 20
core/controller/relay-channel.go

@@ -5,7 +5,6 @@ import (
 	"errors"
 	"fmt"
 	"math/rand/v2"
-	"slices"
 	"strconv"
 
 	"github.com/gin-gonic/gin"
@@ -143,7 +142,7 @@ func GetRandomChannel(
 	modelName string,
 	mode mode.Mode,
 	errorRates map[int64]float64,
-	ignoreChannel ...int64,
+	ignoreChannelMap map[int64]struct{},
 ) (*model.Channel, []*model.Channel, error) {
 	channelMap := make(map[int]*model.Channel)
 	if len(availableSet) != 0 {
@@ -183,7 +182,7 @@ func GetRandomChannel(
 		migratedChannels = append(migratedChannels, channel)
 	}
 
-	channel, err := getRandomChannel(migratedChannels, mode, errorRates, ignoreChannel...)
+	channel, err := ignoreChannel(migratedChannels, mode, errorRates, ignoreChannelMap, nil)
 
 	return channel, migratedChannels, err
 }
@@ -200,17 +199,17 @@ func getPriority(channel *model.Channel, errorRate float64) int32 {
 	return int32(float64(priority) / errorRate)
 }
 
-func getRandomChannel(
+func ignoreChannel(
 	channels []*model.Channel,
 	mode mode.Mode,
 	errorRates map[int64]float64,
-	ignoreChannel ...int64,
+	ignoreChannelIDs ...map[int64]struct{},
 ) (*model.Channel, error) {
 	if len(channels) == 0 {
 		return nil, ErrChannelsNotFound
 	}
 
-	channels = filterChannels(channels, mode, ignoreChannel...)
+	channels = filterChannels(channels, mode, ignoreChannelIDs...)
 	if len(channels) == 0 {
 		return nil, ErrChannelsExhausted
 	}
@@ -249,7 +248,7 @@ func getChannelWithFallback(
 	modelName string,
 	mode mode.Mode,
 	errorRates map[int64]float64,
-	ignoreChannelIDs ...int64,
+	ignoreChannelIDs map[int64]struct{},
 ) (*model.Channel, []*model.Channel, error) {
 	channel, migratedChannels, err := GetRandomChannel(
 		cache,
@@ -257,7 +256,7 @@ func getChannelWithFallback(
 		modelName,
 		mode,
 		errorRates,
-		ignoreChannelIDs...)
+		ignoreChannelIDs)
 	if err == nil {
 		return channel, migratedChannels, nil
 	}
@@ -272,6 +271,7 @@ func getChannelWithFallback(
 		modelName,
 		mode,
 		errorRates,
+		nil,
 	)
 
 	return channel, migratedChannels, err
@@ -280,7 +280,7 @@ func getChannelWithFallback(
 type initialChannel struct {
 	channel           *model.Channel
 	designatedChannel bool
-	ignoreChannelIDs  []int64
+	ignoreChannelIDs  map[int64]struct{}
 	errorRates        map[int64]float64
 	migratedChannels  []*model.Channel
 }
@@ -329,12 +329,12 @@ func getInitialChannel(c *gin.Context, modelName string, m mode.Mode) (*initialC
 
 	mc := middleware.GetModelCaches(c)
 
-	ids, err := monitor.GetBannedChannelsWithModel(c.Request.Context(), modelName)
+	ignoreChannelIDs, err := monitor.GetBannedChannelsMapWithModel(c.Request.Context(), modelName)
 	if err != nil {
 		log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
 	}
 
-	log.Debugf("%s model banned channels: %+v", modelName, ids)
+	log.Debugf("%s model banned channels: %+v", modelName, ignoreChannelIDs)
 
 	errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
 	if err != nil {
@@ -347,14 +347,15 @@ func getInitialChannel(c *gin.Context, modelName string, m mode.Mode) (*initialC
 		modelName,
 		m,
 		errorRates,
-		ids...)
+		ignoreChannelIDs,
+	)
 	if err != nil {
 		return nil, err
 	}
 
 	return &initialChannel{
 		channel:          channel,
-		ignoreChannelIDs: ids,
+		ignoreChannelIDs: ignoreChannelIDs,
 		errorRates:       errorRates,
 		migratedChannels: migratedChannels,
 	}, nil
@@ -365,7 +366,7 @@ func getWebSearchChannel(
 	mc *model.ModelCaches,
 	modelName string,
 ) (*model.Channel, error) {
-	ids, _ := monitor.GetBannedChannelsWithModel(ctx, modelName)
+	ignoreChannelIDs, _ := monitor.GetBannedChannelsMapWithModel(ctx, modelName)
 	errorRates, _ := monitor.GetModelChannelErrorRate(ctx, modelName)
 
 	channel, _, err := getChannelWithFallback(
@@ -374,7 +375,7 @@ func getWebSearchChannel(
 		modelName,
 		mode.ChatCompletions,
 		errorRates,
-		ids...)
+		ignoreChannelIDs)
 	if err != nil {
 		return nil, err
 	}
@@ -382,7 +383,7 @@ func getWebSearchChannel(
 	return channel, nil
 }
 
-func getRetryChannel(state *retryState) (*model.Channel, error) {
+func getRetryChannel(state *retryState, currentRetry, totalRetries int) (*model.Channel, error) {
 	if state.exhausted {
 		if state.lastHasPermissionChannel == nil {
 			return nil, ErrChannelsExhausted
@@ -390,11 +391,28 @@ func getRetryChannel(state *retryState) (*model.Channel, error) {
 		return state.lastHasPermissionChannel, nil
 	}
 
-	newChannel, err := getRandomChannel(
+	// For the last retry, filter out all previously failed channels if there are other options
+	if currentRetry == totalRetries-1 && len(state.failedChannelIDs) > 0 {
+		// Check if there are channels available after filtering out failed channels
+		newChannel, err := ignoreChannel(
+			state.migratedChannels,
+			state.meta.Mode,
+			state.errorRates,
+			state.ignoreChannelIDs,
+			state.failedChannelIDs,
+		)
+		if err == nil {
+			return newChannel, nil
+		}
+		// If no channels available after filtering, fall back to not using failed channels filter
+	}
+
+	newChannel, err := ignoreChannel(
 		state.migratedChannels,
 		state.meta.Mode,
 		state.errorRates,
-		state.ignoreChannelIDs...)
+		state.ignoreChannelIDs,
+	)
 	if err != nil {
 		if !errors.Is(err, ErrChannelsExhausted) || state.lastHasPermissionChannel == nil {
 			return nil, err
@@ -411,7 +429,7 @@ func getRetryChannel(state *retryState) (*model.Channel, error) {
 func filterChannels(
 	channels []*model.Channel,
 	mode mode.Mode,
-	ignoreChannel ...int64,
+	ignoreChannel ...map[int64]struct{},
 ) []*model.Channel {
 	filtered := make([]*model.Channel, 0)
 	for _, channel := range channels {
@@ -428,7 +446,21 @@ func filterChannels(
 			continue
 		}
 
-		if slices.Contains(ignoreChannel, int64(channel.ID)) {
+		chid := int64(channel.ID)
+		needIgnore := false
+
+		for _, ignores := range ignoreChannel {
+			if ignores == nil {
+				continue
+			}
+
+			_, needIgnore = ignores[chid]
+			if needIgnore {
+				break
+			}
+		}
+
+		if needIgnore {
 			continue
 		}
 

+ 23 - 4
core/controller/relay-controller.go

@@ -342,9 +342,10 @@ func recordResult(
 type retryState struct {
 	retryTimes               int
 	lastHasPermissionChannel *model.Channel
-	ignoreChannelIDs         []int64
+	ignoreChannelIDs         map[int64]struct{}
 	errorRates               map[int64]float64
 	exhausted                bool
+	failedChannelIDs         map[int64]struct{} // Track all failed channels in this request
 
 	meta             *meta.Meta
 	price            model.Price
@@ -389,14 +390,22 @@ func initRetryState(
 		price:            price,
 		requestUsage:     meta.RequestUsage,
 		migratedChannels: channel.migratedChannels,
+		failedChannelIDs: make(map[int64]struct{}),
 	}
 
+	// Record initial failed channel
+	state.failedChannelIDs[int64(meta.Channel.ID)] = struct{}{}
+
 	if channel.designatedChannel {
 		state.exhausted = true
 	}
 
 	if !monitorplugin.ChannelHasPermission(result.Error) {
-		state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(channel.channel.ID))
+		if state.ignoreChannelIDs == nil {
+			state.ignoreChannelIDs = make(map[int64]struct{})
+		}
+
+		state.ignoreChannelIDs[int64(channel.channel.ID)] = struct{}{}
 	} else {
 		state.lastHasPermissionChannel = channel.channel
 	}
@@ -414,7 +423,7 @@ func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayControlle
 		lastStatusCode := state.result.Error.StatusCode()
 		lastChannelID := state.meta.Channel.ID
 
-		newChannel, err := getRetryChannel(state)
+		newChannel, err := getRetryChannel(state, i, state.retryTimes)
 		if err == nil {
 			err = prepareRetry(c)
 		}
@@ -482,6 +491,12 @@ func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayControlle
 		state.result, retry = RelayHelper(c, state.meta, relayController)
 
 		done := handleRetryResult(c, retry, newChannel, state)
+
+		// Record failed channel if retry is needed
+		if !done && state.result.Error != nil {
+			state.failedChannelIDs[int64(newChannel.ID)] = struct{}{}
+		}
+
 		if done || i == state.retryTimes-1 {
 			recordResult(
 				c,
@@ -538,7 +553,11 @@ func handleRetryResult(
 		}
 	} else {
 		if !hasPermission {
-			state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(newChannel.ID))
+			if state.ignoreChannelIDs == nil {
+				state.ignoreChannelIDs = make(map[int64]struct{})
+			}
+
+			state.ignoreChannelIDs[int64(newChannel.ID)] = struct{}{}
 			state.retryTimes++
 		} else {
 			state.lastHasPermissionChannel = newChannel

+ 20 - 0
core/monitor/memmodel.go

@@ -276,6 +276,26 @@ func (m *MemModelMonitor) GetBannedChannelsWithModel(
 	return banned, nil
 }
 
+func (m *MemModelMonitor) GetBannedChannelsMapWithModel(
+	_ context.Context,
+	model string,
+) (map[int64]struct{}, error) {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+
+	banned := make(map[int64]struct{})
+	if data, exists := m.models[model]; exists {
+		now := time.Now()
+		for channelID, channel := range data.channels {
+			if channel.bannedUntil.After(now) {
+				banned[channelID] = struct{}{}
+			}
+		}
+	}
+
+	return banned, nil
+}
+
 func (m *MemModelMonitor) GetAllBannedModelChannels(_ context.Context) (map[string][]int64, error) {
 	m.mu.RLock()
 	defer m.mu.RUnlock()

+ 30 - 0
core/monitor/model.go

@@ -263,6 +263,36 @@ func GetBannedChannelsWithModel(ctx context.Context, model string) ([]int64, err
 	return result, nil
 }
 
+// GetBannedChannelsMapWithModel gets banned channels for a specific model as a map for efficient lookups
+func GetBannedChannelsMapWithModel(ctx context.Context, model string) (map[int64]struct{}, error) {
+	if !common.RedisEnabled {
+		return memModelMonitor.GetBannedChannelsMapWithModel(ctx, model)
+	}
+
+	result := make(map[int64]struct{})
+	prefix := modelKeyPrefix() + model + channelKeyPart
+	pattern := prefix + "*" + bannedKeySuffix
+	iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
+
+	for iter.Next(ctx) {
+		key := iter.Val()
+		channelIDStr := strings.TrimSuffix(strings.TrimPrefix(key, prefix), bannedKeySuffix)
+
+		channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
+		if err != nil {
+			continue
+		}
+
+		result[channelID] = struct{}{}
+	}
+
+	if err := iter.Err(); err != nil {
+		return nil, err
+	}
+
+	return result, nil
+}
+
 // ClearChannelModelErrors clears errors for a specific channel and model
 func ClearChannelModelErrors(ctx context.Context, model string, channelID int) error {
 	if !common.RedisEnabled {