|
|
@@ -5,7 +5,6 @@ import (
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"math/rand/v2"
|
|
|
- "slices"
|
|
|
"strconv"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
@@ -143,7 +142,7 @@ func GetRandomChannel(
|
|
|
modelName string,
|
|
|
mode mode.Mode,
|
|
|
errorRates map[int64]float64,
|
|
|
- ignoreChannel ...int64,
|
|
|
+ ignoreChannelMap map[int64]struct{},
|
|
|
) (*model.Channel, []*model.Channel, error) {
|
|
|
channelMap := make(map[int]*model.Channel)
|
|
|
if len(availableSet) != 0 {
|
|
|
@@ -183,7 +182,7 @@ func GetRandomChannel(
|
|
|
migratedChannels = append(migratedChannels, channel)
|
|
|
}
|
|
|
|
|
|
- channel, err := getRandomChannel(migratedChannels, mode, errorRates, ignoreChannel...)
|
|
|
+ channel, err := ignoreChannel(migratedChannels, mode, errorRates, ignoreChannelMap, nil)
|
|
|
|
|
|
return channel, migratedChannels, err
|
|
|
}
|
|
|
@@ -200,17 +199,17 @@ func getPriority(channel *model.Channel, errorRate float64) int32 {
|
|
|
return int32(float64(priority) / errorRate)
|
|
|
}
|
|
|
|
|
|
-func getRandomChannel(
|
|
|
+func ignoreChannel(
|
|
|
channels []*model.Channel,
|
|
|
mode mode.Mode,
|
|
|
errorRates map[int64]float64,
|
|
|
- ignoreChannel ...int64,
|
|
|
+ ignoreChannelIDs ...map[int64]struct{},
|
|
|
) (*model.Channel, error) {
|
|
|
if len(channels) == 0 {
|
|
|
return nil, ErrChannelsNotFound
|
|
|
}
|
|
|
|
|
|
- channels = filterChannels(channels, mode, ignoreChannel...)
|
|
|
+ channels = filterChannels(channels, mode, ignoreChannelIDs...)
|
|
|
if len(channels) == 0 {
|
|
|
return nil, ErrChannelsExhausted
|
|
|
}
|
|
|
@@ -249,7 +248,7 @@ func getChannelWithFallback(
|
|
|
modelName string,
|
|
|
mode mode.Mode,
|
|
|
errorRates map[int64]float64,
|
|
|
- ignoreChannelIDs ...int64,
|
|
|
+ ignoreChannelIDs map[int64]struct{},
|
|
|
) (*model.Channel, []*model.Channel, error) {
|
|
|
channel, migratedChannels, err := GetRandomChannel(
|
|
|
cache,
|
|
|
@@ -257,7 +256,7 @@ func getChannelWithFallback(
|
|
|
modelName,
|
|
|
mode,
|
|
|
errorRates,
|
|
|
- ignoreChannelIDs...)
|
|
|
+ ignoreChannelIDs)
|
|
|
if err == nil {
|
|
|
return channel, migratedChannels, nil
|
|
|
}
|
|
|
@@ -272,6 +271,7 @@ func getChannelWithFallback(
|
|
|
modelName,
|
|
|
mode,
|
|
|
errorRates,
|
|
|
+ nil,
|
|
|
)
|
|
|
|
|
|
return channel, migratedChannels, err
|
|
|
@@ -280,7 +280,7 @@ func getChannelWithFallback(
|
|
|
type initialChannel struct {
|
|
|
channel *model.Channel
|
|
|
designatedChannel bool
|
|
|
- ignoreChannelIDs []int64
|
|
|
+ ignoreChannelIDs map[int64]struct{}
|
|
|
errorRates map[int64]float64
|
|
|
migratedChannels []*model.Channel
|
|
|
}
|
|
|
@@ -329,12 +329,12 @@ func getInitialChannel(c *gin.Context, modelName string, m mode.Mode) (*initialC
|
|
|
|
|
|
mc := middleware.GetModelCaches(c)
|
|
|
|
|
|
- ids, err := monitor.GetBannedChannelsWithModel(c.Request.Context(), modelName)
|
|
|
+ ignoreChannelIDs, err := monitor.GetBannedChannelsMapWithModel(c.Request.Context(), modelName)
|
|
|
if err != nil {
|
|
|
log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
|
|
|
}
|
|
|
|
|
|
- log.Debugf("%s model banned channels: %+v", modelName, ids)
|
|
|
+ log.Debugf("%s model banned channels: %+v", modelName, ignoreChannelIDs)
|
|
|
|
|
|
errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
|
|
|
if err != nil {
|
|
|
@@ -347,14 +347,15 @@ func getInitialChannel(c *gin.Context, modelName string, m mode.Mode) (*initialC
|
|
|
modelName,
|
|
|
m,
|
|
|
errorRates,
|
|
|
- ids...)
|
|
|
+ ignoreChannelIDs,
|
|
|
+ )
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
return &initialChannel{
|
|
|
channel: channel,
|
|
|
- ignoreChannelIDs: ids,
|
|
|
+ ignoreChannelIDs: ignoreChannelIDs,
|
|
|
errorRates: errorRates,
|
|
|
migratedChannels: migratedChannels,
|
|
|
}, nil
|
|
|
@@ -365,7 +366,7 @@ func getWebSearchChannel(
|
|
|
mc *model.ModelCaches,
|
|
|
modelName string,
|
|
|
) (*model.Channel, error) {
|
|
|
- ids, _ := monitor.GetBannedChannelsWithModel(ctx, modelName)
|
|
|
+ ignoreChannelIDs, _ := monitor.GetBannedChannelsMapWithModel(ctx, modelName)
|
|
|
errorRates, _ := monitor.GetModelChannelErrorRate(ctx, modelName)
|
|
|
|
|
|
channel, _, err := getChannelWithFallback(
|
|
|
@@ -374,7 +375,7 @@ func getWebSearchChannel(
|
|
|
modelName,
|
|
|
mode.ChatCompletions,
|
|
|
errorRates,
|
|
|
- ids...)
|
|
|
+ ignoreChannelIDs)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
@@ -382,7 +383,7 @@ func getWebSearchChannel(
|
|
|
return channel, nil
|
|
|
}
|
|
|
|
|
|
-func getRetryChannel(state *retryState) (*model.Channel, error) {
|
|
|
+func getRetryChannel(state *retryState, currentRetry, totalRetries int) (*model.Channel, error) {
|
|
|
if state.exhausted {
|
|
|
if state.lastHasPermissionChannel == nil {
|
|
|
return nil, ErrChannelsExhausted
|
|
|
@@ -390,11 +391,28 @@ func getRetryChannel(state *retryState) (*model.Channel, error) {
|
|
|
return state.lastHasPermissionChannel, nil
|
|
|
}
|
|
|
|
|
|
- newChannel, err := getRandomChannel(
|
|
|
+ // For the last retry, filter out all previously failed channels if there are other options
|
|
|
+ if currentRetry == totalRetries-1 && len(state.failedChannelIDs) > 0 {
|
|
|
+ // Check if there are channels available after filtering out failed channels
|
|
|
+ newChannel, err := ignoreChannel(
|
|
|
+ state.migratedChannels,
|
|
|
+ state.meta.Mode,
|
|
|
+ state.errorRates,
|
|
|
+ state.ignoreChannelIDs,
|
|
|
+ state.failedChannelIDs,
|
|
|
+ )
|
|
|
+ if err == nil {
|
|
|
+ return newChannel, nil
|
|
|
+ }
|
|
|
+ // If no channels available after filtering, fall back to not using failed channels filter
|
|
|
+ }
|
|
|
+
|
|
|
+ newChannel, err := ignoreChannel(
|
|
|
state.migratedChannels,
|
|
|
state.meta.Mode,
|
|
|
state.errorRates,
|
|
|
- state.ignoreChannelIDs...)
|
|
|
+ state.ignoreChannelIDs,
|
|
|
+ )
|
|
|
if err != nil {
|
|
|
if !errors.Is(err, ErrChannelsExhausted) || state.lastHasPermissionChannel == nil {
|
|
|
return nil, err
|
|
|
@@ -411,7 +429,7 @@ func getRetryChannel(state *retryState) (*model.Channel, error) {
|
|
|
func filterChannels(
|
|
|
channels []*model.Channel,
|
|
|
mode mode.Mode,
|
|
|
- ignoreChannel ...int64,
|
|
|
+ ignoreChannel ...map[int64]struct{},
|
|
|
) []*model.Channel {
|
|
|
filtered := make([]*model.Channel, 0)
|
|
|
for _, channel := range channels {
|
|
|
@@ -428,7 +446,21 @@ func filterChannels(
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- if slices.Contains(ignoreChannel, int64(channel.ID)) {
|
|
|
+ chid := int64(channel.ID)
|
|
|
+ needIgnore := false
|
|
|
+
|
|
|
+ for _, ignores := range ignoreChannel {
|
|
|
+ if ignores == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ _, needIgnore = ignores[chid]
|
|
|
+ if needIgnore {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if needIgnore {
|
|
|
continue
|
|
|
}
|
|
|
|