|
|
@@ -18,6 +18,9 @@ import (
|
|
|
|
|
|
const (
|
|
|
AIProxyChannelHeader = "Aiproxy-Channel"
|
|
|
+ // maxRetryErrorRate is the maximum error rate threshold for channel retry selection
|
|
|
+ // Channels with error rate higher than this will be filtered out during retry
|
|
|
+ maxRetryErrorRate = 0.75
|
|
|
)
|
|
|
|
|
|
func GetChannelFromHeader(
|
|
|
@@ -142,7 +145,8 @@ func GetRandomChannel(
|
|
|
modelName string,
|
|
|
mode mode.Mode,
|
|
|
errorRates map[int64]float64,
|
|
|
- ignoreChannelMap map[int64]struct{},
|
|
|
+ maxErrorRate float64,
|
|
|
+ ignoreChannelMap ...map[int64]struct{},
|
|
|
) (*model.Channel, []*model.Channel, error) {
|
|
|
channelMap := make(map[int]*model.Channel)
|
|
|
if len(availableSet) != 0 {
|
|
|
@@ -182,7 +186,13 @@ func GetRandomChannel(
|
|
|
migratedChannels = append(migratedChannels, channel)
|
|
|
}
|
|
|
|
|
|
- channel, err := ignoreChannel(migratedChannels, mode, errorRates, ignoreChannelMap, nil)
|
|
|
+ channel, err := ignoreChannel(
|
|
|
+ migratedChannels,
|
|
|
+ mode,
|
|
|
+ errorRates,
|
|
|
+ maxErrorRate,
|
|
|
+ ignoreChannelMap...,
|
|
|
+ )
|
|
|
|
|
|
return channel, migratedChannels, err
|
|
|
}
|
|
|
@@ -203,13 +213,14 @@ func ignoreChannel(
|
|
|
channels []*model.Channel,
|
|
|
mode mode.Mode,
|
|
|
errorRates map[int64]float64,
|
|
|
+ maxErrorRate float64,
|
|
|
ignoreChannelIDs ...map[int64]struct{},
|
|
|
) (*model.Channel, error) {
|
|
|
if len(channels) == 0 {
|
|
|
return nil, ErrChannelsNotFound
|
|
|
}
|
|
|
|
|
|
- channels = filterChannels(channels, mode, ignoreChannelIDs...)
|
|
|
+ channels = filterChannels(channels, mode, errorRates, maxErrorRate, ignoreChannelIDs...)
|
|
|
if len(channels) == 0 {
|
|
|
return nil, ErrChannelsExhausted
|
|
|
}
|
|
|
@@ -256,7 +267,9 @@ func getChannelWithFallback(
|
|
|
modelName,
|
|
|
mode,
|
|
|
errorRates,
|
|
|
- ignoreChannelIDs)
|
|
|
+ maxRetryErrorRate,
|
|
|
+ ignoreChannelIDs,
|
|
|
+ )
|
|
|
if err == nil {
|
|
|
return channel, migratedChannels, nil
|
|
|
}
|
|
|
@@ -265,16 +278,14 @@ func getChannelWithFallback(
|
|
|
return nil, migratedChannels, err
|
|
|
}
|
|
|
|
|
|
- channel, migratedChannels, err = GetRandomChannel(
|
|
|
+ return GetRandomChannel(
|
|
|
cache,
|
|
|
availableSet,
|
|
|
modelName,
|
|
|
mode,
|
|
|
errorRates,
|
|
|
- nil,
|
|
|
+ 0,
|
|
|
)
|
|
|
-
|
|
|
- return channel, migratedChannels, err
|
|
|
}
|
|
|
|
|
|
type initialChannel struct {
|
|
|
@@ -388,6 +399,14 @@ func getRetryChannel(state *retryState, currentRetry, totalRetries int) (*model.
|
|
|
if state.lastHasPermissionChannel == nil {
|
|
|
return nil, ErrChannelsExhausted
|
|
|
}
|
|
|
+
|
|
|
+ // Check if lastHasPermissionChannel has high error rate
|
|
|
+ // If so, return exhausted to prevent retrying with a bad channel
|
|
|
+ channelID := int64(state.lastHasPermissionChannel.ID)
|
|
|
+ if errorRate, ok := state.errorRates[channelID]; ok && errorRate > maxRetryErrorRate {
|
|
|
+ return nil, ErrChannelsExhausted
|
|
|
+ }
|
|
|
+
|
|
|
return state.lastHasPermissionChannel, nil
|
|
|
}
|
|
|
|
|
|
@@ -398,6 +417,7 @@ func getRetryChannel(state *retryState, currentRetry, totalRetries int) (*model.
|
|
|
state.migratedChannels,
|
|
|
state.meta.Mode,
|
|
|
state.errorRates,
|
|
|
+ maxRetryErrorRate,
|
|
|
state.ignoreChannelIDs,
|
|
|
state.failedChannelIDs,
|
|
|
)
|
|
|
@@ -411,6 +431,7 @@ func getRetryChannel(state *retryState, currentRetry, totalRetries int) (*model.
|
|
|
state.migratedChannels,
|
|
|
state.meta.Mode,
|
|
|
state.errorRates,
|
|
|
+ maxRetryErrorRate,
|
|
|
state.ignoreChannelIDs,
|
|
|
)
|
|
|
if err != nil {
|
|
|
@@ -418,6 +439,12 @@ func getRetryChannel(state *retryState, currentRetry, totalRetries int) (*model.
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
+ // Check if lastHasPermissionChannel has high error rate before using it
|
|
|
+ channelID := int64(state.lastHasPermissionChannel.ID)
|
|
|
+ if errorRate, ok := state.errorRates[channelID]; ok && errorRate > maxRetryErrorRate {
|
|
|
+ return nil, ErrChannelsExhausted
|
|
|
+ }
|
|
|
+
|
|
|
state.exhausted = true
|
|
|
|
|
|
return state.lastHasPermissionChannel, nil
|
|
|
@@ -429,6 +456,8 @@ func getRetryChannel(state *retryState, currentRetry, totalRetries int) (*model.
|
|
|
func filterChannels(
|
|
|
channels []*model.Channel,
|
|
|
mode mode.Mode,
|
|
|
+ errorRates map[int64]float64,
|
|
|
+ maxErrorRate float64,
|
|
|
ignoreChannel ...map[int64]struct{},
|
|
|
) []*model.Channel {
|
|
|
filtered := make([]*model.Channel, 0)
|
|
|
@@ -447,6 +476,15 @@ func filterChannels(
|
|
|
}
|
|
|
|
|
|
chid := int64(channel.ID)
|
|
|
+
|
|
|
+ if maxErrorRate != 0 {
|
|
|
+ // Filter out channels with error rate higher than threshold
|
|
|
+ // This avoids amplifying attacks and retrying with bad channels
|
|
|
+ if errorRate, ok := errorRates[chid]; ok && errorRate > maxErrorRate {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
needIgnore := false
|
|
|
|
|
|
for _, ignores := range ignoreChannel {
|