Przeglądaj źródła

feat: add max retry error rate on choice channel (#415)

* feat: add max retry error rate on choice channel

* fix: not use field

* chore: function args

* feat: split log record and summary update

* fix: add default fuzzy token threshold to protect cpu usage
zijiren 1 miesiąc temu
rodzic
commit
dae5d6fe89

+ 21 - 0
core/common/consume/consume.go

@@ -117,6 +117,27 @@ func Consume(
 	}
 }
 
+func Summary(
+	code int,
+	firstByteAt time.Time,
+	meta *meta.Meta,
+	usage model.Usage,
+	modelPrice model.Price,
+	downstreamResult bool,
+) {
+	amount := CalculateAmount(code, usage, modelPrice)
+
+	recordSummary(
+		time.Now(),
+		meta,
+		code,
+		firstByteAt,
+		usage,
+		amount,
+		downstreamResult,
+	)
+}
+
 func checkNeedRecordConsume(code int, meta *meta.Meta) bool {
 	switch meta.Mode {
 	case mode.VideoGenerationsGetJobs,

+ 25 - 0
core/common/consume/record.go

@@ -49,3 +49,28 @@ func recordConsume(
 		metadata,
 	)
 }
+
+func recordSummary(
+	now time.Time,
+	meta *meta.Meta,
+	code int,
+	firstByteAt time.Time,
+	usage model.Usage,
+	amount float64,
+	downstreamResult bool,
+) {
+	model.BatchUpdateSummary(
+		now,
+		meta.RequestAt,
+		firstByteAt,
+		meta.Group.ID,
+		code,
+		meta.Channel.ID,
+		meta.OriginModel,
+		meta.Token.ID,
+		meta.Token.Name,
+		downstreamResult,
+		usage,
+		amount,
+	)
+}

+ 46 - 8
core/controller/relay-channel.go

@@ -18,6 +18,9 @@ import (
 
 const (
 	AIProxyChannelHeader = "Aiproxy-Channel"
+	// maxRetryErrorRate is the maximum error rate threshold for channel retry selection
+	// Channels with error rate higher than this will be filtered out during retry
+	maxRetryErrorRate = 0.75
 )
 
 func GetChannelFromHeader(
@@ -142,7 +145,8 @@ func GetRandomChannel(
 	modelName string,
 	mode mode.Mode,
 	errorRates map[int64]float64,
-	ignoreChannelMap map[int64]struct{},
+	maxErrorRate float64,
+	ignoreChannelMap ...map[int64]struct{},
 ) (*model.Channel, []*model.Channel, error) {
 	channelMap := make(map[int]*model.Channel)
 	if len(availableSet) != 0 {
@@ -182,7 +186,13 @@ func GetRandomChannel(
 		migratedChannels = append(migratedChannels, channel)
 	}
 
-	channel, err := ignoreChannel(migratedChannels, mode, errorRates, ignoreChannelMap, nil)
+	channel, err := ignoreChannel(
+		migratedChannels,
+		mode,
+		errorRates,
+		maxErrorRate,
+		ignoreChannelMap...,
+	)
 
 	return channel, migratedChannels, err
 }
@@ -203,13 +213,14 @@ func ignoreChannel(
 	channels []*model.Channel,
 	mode mode.Mode,
 	errorRates map[int64]float64,
+	maxErrorRate float64,
 	ignoreChannelIDs ...map[int64]struct{},
 ) (*model.Channel, error) {
 	if len(channels) == 0 {
 		return nil, ErrChannelsNotFound
 	}
 
-	channels = filterChannels(channels, mode, ignoreChannelIDs...)
+	channels = filterChannels(channels, mode, errorRates, maxErrorRate, ignoreChannelIDs...)
 	if len(channels) == 0 {
 		return nil, ErrChannelsExhausted
 	}
@@ -256,7 +267,9 @@ func getChannelWithFallback(
 		modelName,
 		mode,
 		errorRates,
-		ignoreChannelIDs)
+		maxRetryErrorRate,
+		ignoreChannelIDs,
+	)
 	if err == nil {
 		return channel, migratedChannels, nil
 	}
@@ -265,16 +278,14 @@ func getChannelWithFallback(
 		return nil, migratedChannels, err
 	}
 
-	channel, migratedChannels, err = GetRandomChannel(
+	return GetRandomChannel(
 		cache,
 		availableSet,
 		modelName,
 		mode,
 		errorRates,
-		nil,
+		0,
 	)
-
-	return channel, migratedChannels, err
 }
 
 type initialChannel struct {
@@ -388,6 +399,14 @@ func getRetryChannel(state *retryState, currentRetry, totalRetries int) (*model.
 		if state.lastHasPermissionChannel == nil {
 			return nil, ErrChannelsExhausted
 		}
+
+		// Check if lastHasPermissionChannel has high error rate
+		// If so, return exhausted to prevent retrying with a bad channel
+		channelID := int64(state.lastHasPermissionChannel.ID)
+		if errorRate, ok := state.errorRates[channelID]; ok && errorRate > maxRetryErrorRate {
+			return nil, ErrChannelsExhausted
+		}
+
 		return state.lastHasPermissionChannel, nil
 	}
 
@@ -398,6 +417,7 @@ func getRetryChannel(state *retryState, currentRetry, totalRetries int) (*model.
 			state.migratedChannels,
 			state.meta.Mode,
 			state.errorRates,
+			maxRetryErrorRate,
 			state.ignoreChannelIDs,
 			state.failedChannelIDs,
 		)
@@ -411,6 +431,7 @@ func getRetryChannel(state *retryState, currentRetry, totalRetries int) (*model.
 		state.migratedChannels,
 		state.meta.Mode,
 		state.errorRates,
+		maxRetryErrorRate,
 		state.ignoreChannelIDs,
 	)
 	if err != nil {
@@ -418,6 +439,12 @@ func getRetryChannel(state *retryState, currentRetry, totalRetries int) (*model.
 			return nil, err
 		}
 
+		// Check if lastHasPermissionChannel has high error rate before using it
+		channelID := int64(state.lastHasPermissionChannel.ID)
+		if errorRate, ok := state.errorRates[channelID]; ok && errorRate > maxRetryErrorRate {
+			return nil, ErrChannelsExhausted
+		}
+
 		state.exhausted = true
 
 		return state.lastHasPermissionChannel, nil
@@ -429,6 +456,8 @@ func getRetryChannel(state *retryState, currentRetry, totalRetries int) (*model.
 func filterChannels(
 	channels []*model.Channel,
 	mode mode.Mode,
+	errorRates map[int64]float64,
+	maxErrorRate float64,
 	ignoreChannel ...map[int64]struct{},
 ) []*model.Channel {
 	filtered := make([]*model.Channel, 0)
@@ -447,6 +476,15 @@ func filterChannels(
 		}
 
 		chid := int64(channel.ID)
+
+		if maxErrorRate != 0 {
+			// Filter out channels with error rate higher than threshold
+			// This avoids amplifying attacks and retrying with bad channels
+			if errorRate, ok := errorRates[chid]; ok && errorRate > maxErrorRate {
+				continue
+			}
+		}
+
 		needIgnore := false
 
 		for _, ignores := range ignoreChannel {

+ 2 - 8
core/middleware/distributor.go

@@ -454,20 +454,14 @@ func distribute(c *gin.Context, mode mode.Mode) {
 
 	if err := checkGroupModelRPMAndTPM(c, group, mc, token.Name); err != nil {
 		errMsg := err.Error()
-		consume.AsyncConsume(
-			nil,
+
+		consume.Summary(
 			http.StatusTooManyRequests,
 			time.Time{},
 			NewMetaByContext(c, nil, mode),
 			model.Usage{},
 			model.Price{},
-			errMsg,
-			c.ClientIP(),
-			0,
-			nil,
 			true,
-			user,
-			metadata,
 		)
 		AbortLogWithMessage(c, http.StatusTooManyRequests, errMsg)
 

+ 37 - 3
core/model/batch.go

@@ -377,6 +377,42 @@ func BatchRecordLogs(
 		}
 	}
 
+	BatchUpdateSummary(
+		now,
+		requestAt,
+		firstByteAt,
+		group,
+		code,
+		channelID,
+		modelName,
+		tokenID,
+		tokenName,
+		downstreamResult,
+		usage,
+		amount,
+	)
+
+	return err
+}
+
+func BatchUpdateSummary(
+	now time.Time,
+	requestAt time.Time,
+	firstByteAt time.Time,
+	group string,
+	code int,
+	channelID int,
+	modelName string,
+	tokenID int,
+	tokenName string,
+	downstreamResult bool,
+	usage Usage,
+	amount float64,
+) {
+	if now.IsZero() {
+		now = time.Now()
+	}
+
 	amountDecimal := decimal.NewFromFloat(amount)
 
 	batchData.Lock()
@@ -412,7 +448,7 @@ func BatchRecordLogs(
 
 	// group related data only records downstream result
 	if !downstreamResult {
-		return err
+		return
 	}
 
 	updateGroupData(group, amount, amountDecimal)
@@ -444,8 +480,6 @@ func BatchRecordLogs(
 			usage,
 		)
 	}
-
-	return err
 }
 
 func updateChannelData(

+ 6 - 0
core/relay/adaptor/openai/token.go

@@ -13,9 +13,15 @@ import (
 	"github.com/tiktoken-go/tokenizer"
 )
 
+const defaultMaxFuzzyTokenThreshold = 2048000
+
 func getTokenNum(tokenEncoder tokenizer.Codec, text string) int64 {
 	// Check fuzzy token threshold
 	threshold := config.GetFuzzyTokenThreshold()
+	if threshold == 0 || threshold > defaultMaxFuzzyTokenThreshold {
+		threshold = defaultMaxFuzzyTokenThreshold
+	}
+
 	textLen := len(text)
 
 	// If threshold is set and text length exceeds it, use fuzzy calculation