| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510 |
- package controller
- import (
- "context"
- "errors"
- "fmt"
- "math/rand/v2"
- "strconv"
- "github.com/gin-gonic/gin"
- "github.com/labring/aiproxy/core/common"
- "github.com/labring/aiproxy/core/middleware"
- "github.com/labring/aiproxy/core/model"
- "github.com/labring/aiproxy/core/monitor"
- "github.com/labring/aiproxy/core/relay/adaptors"
- "github.com/labring/aiproxy/core/relay/mode"
- )
const (
	// AIProxyChannelHeader is the name of the request header that
	// getInitialChannel reads to obtain an explicitly designated channel ID.
	AIProxyChannelHeader = "Aiproxy-Channel"

	// maxRetryErrorRate is the maximum error rate threshold for channel retry selection.
	// Channels with error rate higher than this will be filtered out during retry.
	maxRetryErrorRate = 0.75
)
- func GetChannelFromHeader(
- header string,
- mc *model.ModelCaches,
- availableSet []string,
- model string,
- m mode.Mode,
- ) (*model.Channel, error) {
- channelIDInt, err := strconv.ParseInt(header, 10, 64)
- if err != nil {
- return nil, err
- }
- for _, set := range availableSet {
- enabledChannels := mc.EnabledModel2ChannelsBySet[set][model]
- if len(enabledChannels) > 0 {
- for _, channel := range enabledChannels {
- if int64(channel.ID) == channelIDInt {
- a, ok := adaptors.GetAdaptor(channel.Type)
- if !ok {
- return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
- }
- if !a.SupportMode(m) {
- return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
- }
- return channel, nil
- }
- }
- }
- disabledChannels := mc.DisabledModel2ChannelsBySet[set][model]
- if len(disabledChannels) > 0 {
- for _, channel := range disabledChannels {
- if int64(channel.ID) == channelIDInt {
- a, ok := adaptors.GetAdaptor(channel.Type)
- if !ok {
- return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
- }
- if !a.SupportMode(m) {
- return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
- }
- return channel, nil
- }
- }
- }
- }
- return nil, fmt.Errorf("channel %d not found for model `%s`", channelIDInt, model)
- }
- func needPinChannel(m mode.Mode) bool {
- switch m {
- case mode.VideoGenerationsGetJobs,
- mode.VideoGenerationsContent,
- mode.ResponsesGet,
- mode.ResponsesDelete,
- mode.ResponsesCancel,
- mode.ResponsesInputItems:
- return true
- default:
- return false
- }
- }
- func GetChannelFromRequest(
- c *gin.Context,
- mc *model.ModelCaches,
- availableSet []string,
- modelName string,
- m mode.Mode,
- ) (*model.Channel, error) {
- channelID := middleware.GetChannelID(c)
- if channelID == 0 {
- if needPinChannel(m) {
- return nil, fmt.Errorf("%s need pinned channel", m)
- }
- return nil, nil
- }
- for _, set := range availableSet {
- enabledChannels := mc.EnabledModel2ChannelsBySet[set][modelName]
- if len(enabledChannels) > 0 {
- for _, channel := range enabledChannels {
- if channel.ID == channelID {
- a, ok := adaptors.GetAdaptor(channel.Type)
- if !ok {
- return nil, fmt.Errorf(
- "adaptor not found for pinned channel %d",
- channel.ID,
- )
- }
- if !a.SupportMode(m) {
- return nil, fmt.Errorf(
- "pinned channel %d not supported by adaptor",
- channel.ID,
- )
- }
- return channel, nil
- }
- }
- }
- }
- return nil, fmt.Errorf("pinned channel %d not found for model `%s`", channelID, modelName)
- }
var (
	// ErrChannelsNotFound is returned by ignoreChannel when the candidate
	// list is already empty before any filtering is applied.
	ErrChannelsNotFound = errors.New("channels not found")

	// ErrChannelsExhausted is returned when candidates existed but none of
	// them survived filtering, and by getRetryChannel when no acceptable
	// channel is left to retry with.
	ErrChannelsExhausted = errors.New("channels exhausted")
)
- func getRandomChannel(
- mc *model.ModelCaches,
- availableSet []string,
- modelName string,
- mode mode.Mode,
- errorRates map[int64]float64,
- maxErrorRate float64,
- ignoreChannelMap ...map[int64]struct{},
- ) (*model.Channel, []*model.Channel, error) {
- channelMap := make(map[int]*model.Channel)
- if len(availableSet) != 0 {
- for _, set := range availableSet {
- channels := mc.EnabledModel2ChannelsBySet[set][modelName]
- for _, channel := range channels {
- a, ok := adaptors.GetAdaptor(channel.Type)
- if !ok {
- continue
- }
- if !a.SupportMode(mode) {
- continue
- }
- channelMap[channel.ID] = channel
- }
- }
- } else {
- for _, sets := range mc.EnabledModel2ChannelsBySet {
- for _, channel := range sets[modelName] {
- a, ok := adaptors.GetAdaptor(channel.Type)
- if !ok {
- continue
- }
- if !a.SupportMode(mode) {
- continue
- }
- channelMap[channel.ID] = channel
- }
- }
- }
- migratedChannels := make([]*model.Channel, 0, len(channelMap))
- for _, channel := range channelMap {
- migratedChannels = append(migratedChannels, channel)
- }
- channel, err := ignoreChannel(
- migratedChannels,
- mode,
- errorRates,
- maxErrorRate,
- ignoreChannelMap...,
- )
- return channel, migratedChannels, err
- }
- func getPriority(channel *model.Channel, errorRate float64) int32 {
- priority := channel.GetPriority()
- if errorRate > 1 {
- errorRate = 1
- } else if errorRate < 0.1 {
- errorRate = 0.1
- }
- return int32(float64(priority) / errorRate)
- }
- func ignoreChannel(
- channels []*model.Channel,
- mode mode.Mode,
- errorRates map[int64]float64,
- maxErrorRate float64,
- ignoreChannelIDs ...map[int64]struct{},
- ) (*model.Channel, error) {
- if len(channels) == 0 {
- return nil, ErrChannelsNotFound
- }
- channels = filterChannels(channels, mode, errorRates, maxErrorRate, ignoreChannelIDs...)
- if len(channels) == 0 {
- return nil, ErrChannelsExhausted
- }
- if len(channels) == 1 {
- return channels[0], nil
- }
- var totalWeight int32
- cachedPrioritys := make([]int32, len(channels))
- for i, ch := range channels {
- priority := getPriority(ch, errorRates[int64(ch.ID)])
- totalWeight += priority
- cachedPrioritys[i] = priority
- }
- if totalWeight == 0 {
- return channels[rand.IntN(len(channels))], nil
- }
- r := rand.Int32N(totalWeight)
- for i, ch := range channels {
- r -= cachedPrioritys[i]
- if r < 0 {
- return ch, nil
- }
- }
- return channels[rand.IntN(len(channels))], nil
- }
- func getChannelWithFallback(
- cache *model.ModelCaches,
- availableSet []string,
- modelName string,
- mode mode.Mode,
- errorRates map[int64]float64,
- ignoreChannelIDs map[int64]struct{},
- ) (*model.Channel, []*model.Channel, error) {
- channel, migratedChannels, err := getRandomChannel(
- cache,
- availableSet,
- modelName,
- mode,
- errorRates,
- maxRetryErrorRate,
- ignoreChannelIDs,
- )
- if err == nil {
- return channel, migratedChannels, nil
- }
- if !errors.Is(err, ErrChannelsExhausted) {
- return nil, migratedChannels, err
- }
- return getRandomChannel(
- cache,
- availableSet,
- modelName,
- mode,
- errorRates,
- 0,
- )
- }
// initialChannel is the result of getInitialChannel: the channel chosen for
// the first attempt plus the data that was used to choose it.
type initialChannel struct {
	// channel is the selected channel.
	channel *model.Channel
	// designatedChannel is true when the channel was named explicitly, either
	// through the AIProxyChannelHeader header or as a pinned channel of the
	// request, rather than picked automatically.
	designatedChannel bool
	// ignoreChannelIDs is the banned-channel set obtained from the monitor;
	// only populated for automatically picked channels.
	ignoreChannelIDs map[int64]struct{}
	// errorRates holds the per-channel error rates obtained from the monitor;
	// only populated for automatically picked channels.
	errorRates map[int64]float64
	// migratedChannels is the candidate list the channel was picked from;
	// only populated for automatically picked channels.
	migratedChannels []*model.Channel
}
- func getInitialChannel(c *gin.Context, modelName string, m mode.Mode) (*initialChannel, error) {
- log := common.GetLogger(c)
- group := middleware.GetGroup(c)
- availableSet := group.GetAvailableSets()
- if channelHeader := c.Request.Header.Get(AIProxyChannelHeader); channelHeader != "" {
- if group.Status != model.GroupStatusInternal {
- return nil, errors.New("channel header is not allowed in non-internal group")
- }
- channel, err := GetChannelFromHeader(
- channelHeader,
- middleware.GetModelCaches(c),
- availableSet,
- modelName,
- m,
- )
- if err != nil {
- return nil, err
- }
- log.Data["designated_channel"] = "true"
- return &initialChannel{channel: channel, designatedChannel: true}, nil
- }
- channel, err := GetChannelFromRequest(
- c,
- middleware.GetModelCaches(c),
- availableSet,
- modelName,
- m,
- )
- if err != nil {
- return nil, err
- }
- if channel != nil {
- return &initialChannel{channel: channel, designatedChannel: true}, nil
- }
- mc := middleware.GetModelCaches(c)
- ignoreChannelIDs, err := monitor.GetBannedChannelsMapWithModel(c.Request.Context(), modelName)
- if err != nil {
- log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
- }
- log.Debugf("%s model banned channels: %+v", modelName, ignoreChannelIDs)
- errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
- if err != nil {
- log.Errorf("get channel model error rates failed: %+v", err)
- }
- channel, migratedChannels, err := getChannelWithFallback(
- mc,
- availableSet,
- modelName,
- m,
- errorRates,
- ignoreChannelIDs,
- )
- if err != nil {
- return nil, err
- }
- return &initialChannel{
- channel: channel,
- ignoreChannelIDs: ignoreChannelIDs,
- errorRates: errorRates,
- migratedChannels: migratedChannels,
- }, nil
- }
- func getWebSearchChannel(
- ctx context.Context,
- mc *model.ModelCaches,
- modelName string,
- ) (*model.Channel, error) {
- ignoreChannelIDs, _ := monitor.GetBannedChannelsMapWithModel(ctx, modelName)
- errorRates, _ := monitor.GetModelChannelErrorRate(ctx, modelName)
- channel, _, err := getChannelWithFallback(
- mc,
- nil,
- modelName,
- mode.ChatCompletions,
- errorRates,
- ignoreChannelIDs)
- if err != nil {
- return nil, err
- }
- return channel, nil
- }
- func getRetryChannel(state *retryState, currentRetry, totalRetries int) (*model.Channel, error) {
- if state.exhausted {
- if state.lastHasPermissionChannel == nil {
- return nil, ErrChannelsExhausted
- }
- // Check if lastHasPermissionChannel has high error rate
- // If so, return exhausted to prevent retrying with a bad channel
- channelID := int64(state.lastHasPermissionChannel.ID)
- if errorRate, ok := state.errorRates[channelID]; ok && errorRate > maxRetryErrorRate {
- return nil, ErrChannelsExhausted
- }
- return state.lastHasPermissionChannel, nil
- }
- // For the last retry, filter out all previously failed channels if there are other options
- if currentRetry == totalRetries-1 && len(state.failedChannelIDs) > 0 {
- // Check if there are channels available after filtering out failed channels
- newChannel, err := ignoreChannel(
- state.migratedChannels,
- state.meta.Mode,
- state.errorRates,
- maxRetryErrorRate,
- state.ignoreChannelIDs,
- state.failedChannelIDs,
- )
- if err == nil {
- return newChannel, nil
- }
- // If no channels available after filtering, fall back to not using failed channels filter
- }
- newChannel, err := ignoreChannel(
- state.migratedChannels,
- state.meta.Mode,
- state.errorRates,
- maxRetryErrorRate,
- state.ignoreChannelIDs,
- )
- if err != nil {
- if !errors.Is(err, ErrChannelsExhausted) || state.lastHasPermissionChannel == nil {
- return nil, err
- }
- // Check if lastHasPermissionChannel has high error rate before using it
- channelID := int64(state.lastHasPermissionChannel.ID)
- if errorRate, ok := state.errorRates[channelID]; ok && errorRate > maxRetryErrorRate {
- return nil, ErrChannelsExhausted
- }
- state.exhausted = true
- return state.lastHasPermissionChannel, nil
- }
- return newChannel, nil
- }
- func filterChannels(
- channels []*model.Channel,
- mode mode.Mode,
- errorRates map[int64]float64,
- maxErrorRate float64,
- ignoreChannel ...map[int64]struct{},
- ) []*model.Channel {
- filtered := make([]*model.Channel, 0)
- for _, channel := range channels {
- if channel.Status != model.ChannelStatusEnabled {
- continue
- }
- a, ok := adaptors.GetAdaptor(channel.Type)
- if !ok {
- continue
- }
- if !a.SupportMode(mode) {
- continue
- }
- chid := int64(channel.ID)
- if maxErrorRate != 0 {
- // Filter out channels with error rate higher than threshold
- // This avoids amplifying attacks and retrying with bad channels
- if errorRate, ok := errorRates[chid]; ok && errorRate > maxErrorRate {
- continue
- }
- }
- needIgnore := false
- for _, ignores := range ignoreChannel {
- if ignores == nil {
- continue
- }
- _, needIgnore = ignores[chid]
- if needIgnore {
- break
- }
- }
- if needIgnore {
- continue
- }
- filtered = append(filtered, channel)
- }
- return filtered
- }
|