|
|
@@ -2,13 +2,11 @@ package controller
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
- "context"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"math/rand/v2"
|
|
|
"net/http"
|
|
|
- "slices"
|
|
|
"strconv"
|
|
|
"time"
|
|
|
|
|
|
@@ -19,12 +17,8 @@ import (
|
|
|
"github.com/labring/aiproxy/core/common/config"
|
|
|
"github.com/labring/aiproxy/core/common/consume"
|
|
|
"github.com/labring/aiproxy/core/common/conv"
|
|
|
- "github.com/labring/aiproxy/core/common/notify"
|
|
|
- "github.com/labring/aiproxy/core/common/reqlimit"
|
|
|
- "github.com/labring/aiproxy/core/common/trylock"
|
|
|
"github.com/labring/aiproxy/core/middleware"
|
|
|
"github.com/labring/aiproxy/core/model"
|
|
|
- "github.com/labring/aiproxy/core/monitor"
|
|
|
"github.com/labring/aiproxy/core/relay/adaptor"
|
|
|
"github.com/labring/aiproxy/core/relay/adaptors"
|
|
|
"github.com/labring/aiproxy/core/relay/controller"
|
|
|
@@ -33,9 +27,9 @@ import (
|
|
|
relaymodel "github.com/labring/aiproxy/core/relay/model"
|
|
|
"github.com/labring/aiproxy/core/relay/plugin"
|
|
|
"github.com/labring/aiproxy/core/relay/plugin/cache"
|
|
|
+ monitorplugin "github.com/labring/aiproxy/core/relay/plugin/monitor"
|
|
|
"github.com/labring/aiproxy/core/relay/plugin/thinksplit"
|
|
|
websearch "github.com/labring/aiproxy/core/relay/plugin/web-search"
|
|
|
- log "github.com/sirupsen/logrus"
|
|
|
)
|
|
|
|
|
|
// https://platform.openai.com/docs/api-reference/chat
|
|
|
@@ -52,114 +46,6 @@ type RelayController struct {
|
|
|
Handler RelayHandler
|
|
|
}
|
|
|
|
|
|
-// TODO: convert to plugin
|
|
|
-type wrapAdaptor struct {
|
|
|
- adaptor.Adaptor
|
|
|
-}
|
|
|
-
|
|
|
-const (
|
|
|
- MetaChannelModelKeyRPM = "channel_model_rpm"
|
|
|
- MetaChannelModelKeyRPS = "channel_model_rps"
|
|
|
- MetaChannelModelKeyTPM = "channel_model_tpm"
|
|
|
- MetaChannelModelKeyTPS = "channel_model_tps"
|
|
|
-)
|
|
|
-
|
|
|
-func getChannelModelRequestRate(c *gin.Context, meta *meta.Meta) model.RequestRate {
|
|
|
- rate := model.RequestRate{}
|
|
|
-
|
|
|
- if rpm, ok := meta.Get(MetaChannelModelKeyRPM); ok {
|
|
|
- rate.RPM, _ = rpm.(int64)
|
|
|
- rate.RPS = meta.GetInt64(MetaChannelModelKeyRPS)
|
|
|
- } else {
|
|
|
- rpm, rps := reqlimit.GetChannelModelRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
|
|
|
- rate.RPM = rpm
|
|
|
- rate.RPS = rps
|
|
|
- updateChannelModelRequestRate(c, meta, rpm, rps)
|
|
|
- }
|
|
|
-
|
|
|
- if tpm, ok := meta.Get(MetaChannelModelKeyTPM); ok {
|
|
|
- rate.TPM, _ = tpm.(int64)
|
|
|
- rate.TPS = meta.GetInt64(MetaChannelModelKeyTPS)
|
|
|
- } else {
|
|
|
- tpm, tps := reqlimit.GetChannelModelTokensRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
|
|
|
- rate.TPM = tpm
|
|
|
- rate.TPS = tps
|
|
|
- updateChannelModelTokensRequestRate(c, meta, tpm, tps)
|
|
|
- }
|
|
|
-
|
|
|
- return rate
|
|
|
-}
|
|
|
-
|
|
|
-func updateChannelModelRequestRate(c *gin.Context, meta *meta.Meta, rpm, rps int64) {
|
|
|
- meta.Set(MetaChannelModelKeyRPM, rpm)
|
|
|
- meta.Set(MetaChannelModelKeyRPS, rps)
|
|
|
- log := middleware.GetLogger(c)
|
|
|
- log.Data["ch_rpm"] = rpm
|
|
|
- log.Data["ch_rps"] = rps
|
|
|
-}
|
|
|
-
|
|
|
-func updateChannelModelTokensRequestRate(c *gin.Context, meta *meta.Meta, tpm, tps int64) {
|
|
|
- meta.Set(MetaChannelModelKeyTPM, tpm)
|
|
|
- meta.Set(MetaChannelModelKeyTPS, tps)
|
|
|
- log := middleware.GetLogger(c)
|
|
|
- log.Data["ch_tpm"] = tpm
|
|
|
- log.Data["ch_tps"] = tps
|
|
|
-}
|
|
|
-
|
|
|
-func (w *wrapAdaptor) DoRequest(
|
|
|
- meta *meta.Meta,
|
|
|
- store adaptor.Store,
|
|
|
- c *gin.Context,
|
|
|
- req *http.Request,
|
|
|
-) (*http.Response, error) {
|
|
|
- count, overLimitCount, secondCount := reqlimit.PushChannelModelRequest(
|
|
|
- context.Background(),
|
|
|
- strconv.Itoa(meta.Channel.ID),
|
|
|
- meta.OriginModel,
|
|
|
- )
|
|
|
- updateChannelModelRequestRate(c, meta, count+overLimitCount, secondCount)
|
|
|
- return w.Adaptor.DoRequest(meta, store, c, req)
|
|
|
-}
|
|
|
-
|
|
|
-func (w *wrapAdaptor) DoResponse(
|
|
|
- meta *meta.Meta,
|
|
|
- store adaptor.Store,
|
|
|
- c *gin.Context,
|
|
|
- resp *http.Response,
|
|
|
-) (model.Usage, adaptor.Error) {
|
|
|
- usage, relayErr := w.Adaptor.DoResponse(meta, store, c, resp)
|
|
|
-
|
|
|
- if usage.TotalTokens > 0 {
|
|
|
- count, overLimitCount, secondCount := reqlimit.PushChannelModelTokensRequest(
|
|
|
- context.Background(),
|
|
|
- strconv.Itoa(meta.Channel.ID),
|
|
|
- meta.OriginModel,
|
|
|
- int64(usage.TotalTokens),
|
|
|
- )
|
|
|
- updateChannelModelTokensRequestRate(c, meta, count+overLimitCount, secondCount)
|
|
|
-
|
|
|
- count, overLimitCount, secondCount = reqlimit.PushGroupModelTokensRequest(
|
|
|
- context.Background(),
|
|
|
- meta.Group.ID,
|
|
|
- meta.OriginModel,
|
|
|
- meta.ModelConfig.TPM,
|
|
|
- int64(usage.TotalTokens),
|
|
|
- )
|
|
|
- middleware.UpdateGroupModelTokensRequest(c, meta.Group, count+overLimitCount, secondCount)
|
|
|
-
|
|
|
- count, overLimitCount, secondCount = reqlimit.PushGroupModelTokennameTokensRequest(
|
|
|
- context.Background(),
|
|
|
- meta.Group.ID,
|
|
|
- meta.OriginModel,
|
|
|
- meta.Token.Name,
|
|
|
- int64(usage.TotalTokens),
|
|
|
- )
|
|
|
- middleware.UpdateGroupModelTokennameTokensRequest(c, count+overLimitCount, secondCount)
|
|
|
- }
|
|
|
-
|
|
|
- return usage, relayErr
|
|
|
-}
|
|
|
-
|
|
|
var adaptorStore adaptor.Store = &storeImpl{}
|
|
|
|
|
|
type storeImpl struct{}
|
|
|
@@ -192,7 +78,7 @@ func (s *storeImpl) SaveStore(store adaptor.StoreCache) error {
|
|
|
}
|
|
|
|
|
|
func relayHandler(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
|
|
|
- log := middleware.GetLogger(c)
|
|
|
+ log := common.GetLogger(c)
|
|
|
middleware.SetLogFieldsFromMeta(meta, log.Data)
|
|
|
|
|
|
adaptor, ok := adaptors.GetAdaptor(meta.Channel.Type)
|
|
|
@@ -206,12 +92,14 @@ func relayHandler(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- a := plugin.WrapperAdaptor(&wrapAdaptor{adaptor},
|
|
|
+ a := plugin.WrapperAdaptor(adaptor,
|
|
|
+ monitorplugin.NewGroupMonitorPlugin(),
|
|
|
cache.NewCachePlugin(common.RDB),
|
|
|
websearch.NewWebSearchPlugin(func(modelName string) (*model.Channel, error) {
|
|
|
return getWebSearchChannel(c, modelName)
|
|
|
}),
|
|
|
thinksplit.NewThinkPlugin(),
|
|
|
+ monitorplugin.NewChannelMonitorPlugin(),
|
|
|
)
|
|
|
|
|
|
return controller.Handle(a, c, meta, adaptorStore)
|
|
|
@@ -259,123 +147,6 @@ func relayController(m mode.Mode) RelayController {
|
|
|
return c
|
|
|
}
|
|
|
|
|
|
-const (
|
|
|
- AIProxyChannelHeader = "Aiproxy-Channel"
|
|
|
-)
|
|
|
-
|
|
|
-func GetChannelFromHeader(
|
|
|
- header string,
|
|
|
- mc *model.ModelCaches,
|
|
|
- availableSet []string,
|
|
|
- model string,
|
|
|
- m mode.Mode,
|
|
|
-) (*model.Channel, error) {
|
|
|
- channelIDInt, err := strconv.ParseInt(header, 10, 64)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
-
|
|
|
- for _, set := range availableSet {
|
|
|
- enabledChannels := mc.EnabledModel2ChannelsBySet[set][model]
|
|
|
- if len(enabledChannels) > 0 {
|
|
|
- for _, channel := range enabledChannels {
|
|
|
- if int64(channel.ID) == channelIDInt {
|
|
|
- a, ok := adaptors.GetAdaptor(channel.Type)
|
|
|
- if !ok {
|
|
|
- return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
|
|
|
- }
|
|
|
- if !a.SupportMode(m) {
|
|
|
- return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
|
|
|
- }
|
|
|
- return channel, nil
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- disabledChannels := mc.DisabledModel2ChannelsBySet[set][model]
|
|
|
- if len(disabledChannels) > 0 {
|
|
|
- for _, channel := range disabledChannels {
|
|
|
- if int64(channel.ID) == channelIDInt {
|
|
|
- a, ok := adaptors.GetAdaptor(channel.Type)
|
|
|
- if !ok {
|
|
|
- return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
|
|
|
- }
|
|
|
- if !a.SupportMode(m) {
|
|
|
- return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
|
|
|
- }
|
|
|
- return channel, nil
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- return nil, fmt.Errorf("channel %d not found for model `%s`", channelIDInt, model)
|
|
|
-}
|
|
|
-
|
|
|
-func GetChannelFromRequest(
|
|
|
- c *gin.Context,
|
|
|
- mc *model.ModelCaches,
|
|
|
- availableSet []string,
|
|
|
- modelName string,
|
|
|
- m mode.Mode,
|
|
|
-) (*model.Channel, error) {
|
|
|
- switch m {
|
|
|
- case mode.VideoGenerationsGetJobs,
|
|
|
- mode.VideoGenerationsContent:
|
|
|
- channelID := middleware.GetChannelID(c)
|
|
|
- if channelID == 0 {
|
|
|
- return nil, errors.New("channel id is required")
|
|
|
- }
|
|
|
- for _, set := range availableSet {
|
|
|
- enabledChannels := mc.EnabledModel2ChannelsBySet[set][modelName]
|
|
|
- if len(enabledChannels) > 0 {
|
|
|
- for _, channel := range enabledChannels {
|
|
|
- if channel.ID == channelID {
|
|
|
- a, ok := adaptors.GetAdaptor(channel.Type)
|
|
|
- if !ok {
|
|
|
- return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
|
|
|
- }
|
|
|
- if !a.SupportMode(m) {
|
|
|
- return nil, fmt.Errorf(
|
|
|
- "channel %d not supported by adaptor",
|
|
|
- channel.ID,
|
|
|
- )
|
|
|
- }
|
|
|
- return channel, nil
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- return nil, fmt.Errorf("channel %d not found for model `%s`", channelID, modelName)
|
|
|
- default:
|
|
|
- channelID := middleware.GetChannelID(c)
|
|
|
- if channelID == 0 {
|
|
|
- return nil, nil
|
|
|
- }
|
|
|
- for _, set := range availableSet {
|
|
|
- enabledChannels := mc.EnabledModel2ChannelsBySet[set][modelName]
|
|
|
- if len(enabledChannels) > 0 {
|
|
|
- for _, channel := range enabledChannels {
|
|
|
- if channel.ID == channelID {
|
|
|
- a, ok := adaptors.GetAdaptor(channel.Type)
|
|
|
- if !ok {
|
|
|
- return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
|
|
|
- }
|
|
|
- if !a.SupportMode(m) {
|
|
|
- return nil, fmt.Errorf(
|
|
|
- "channel %d not supported by adaptor",
|
|
|
- channel.ID,
|
|
|
- )
|
|
|
- }
|
|
|
- return channel, nil
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- return nil, nil
|
|
|
-}
|
|
|
-
|
|
|
func RelayHelper(
|
|
|
c *gin.Context,
|
|
|
meta *meta.Meta,
|
|
|
@@ -383,267 +154,9 @@ func RelayHelper(
|
|
|
) (*controller.HandleResult, bool) {
|
|
|
result := handel(c, meta)
|
|
|
if result.Error == nil {
|
|
|
- if _, _, err := monitor.AddRequest(
|
|
|
- context.Background(),
|
|
|
- meta.OriginModel,
|
|
|
- int64(meta.Channel.ID),
|
|
|
- false,
|
|
|
- false,
|
|
|
- meta.ModelConfig.MaxErrorRate,
|
|
|
- ); err != nil {
|
|
|
- log.Errorf("add request failed: %+v", err)
|
|
|
- }
|
|
|
return result, false
|
|
|
}
|
|
|
- shouldRetry := shouldRetry(c, result.Error)
|
|
|
- if shouldRetry {
|
|
|
- hasPermission := channelHasPermission(result.Error)
|
|
|
- beyondThreshold, banExecution, err := monitor.AddRequest(
|
|
|
- context.Background(),
|
|
|
- meta.OriginModel,
|
|
|
- int64(meta.Channel.ID),
|
|
|
- true,
|
|
|
- !hasPermission,
|
|
|
- meta.ModelConfig.MaxErrorRate,
|
|
|
- )
|
|
|
- if err != nil {
|
|
|
- log.Errorf("add request failed: %+v", err)
|
|
|
- }
|
|
|
- switch {
|
|
|
- case banExecution:
|
|
|
- notifyChannelIssue(c, meta, "autoBanned", "Auto Banned", result.Error)
|
|
|
- case beyondThreshold:
|
|
|
- notifyChannelIssue(
|
|
|
- c,
|
|
|
- meta,
|
|
|
- "beyondThreshold",
|
|
|
- "Error Rate Beyond Threshold",
|
|
|
- result.Error,
|
|
|
- )
|
|
|
- case !hasPermission:
|
|
|
- notifyChannelIssue(c, meta, "channelHasPermission", "No Permission", result.Error)
|
|
|
- }
|
|
|
- }
|
|
|
- return result, shouldRetry
|
|
|
-}
|
|
|
-
|
|
|
-func notifyChannelIssue(
|
|
|
- c *gin.Context,
|
|
|
- meta *meta.Meta,
|
|
|
- issueType, titleSuffix string,
|
|
|
- err adaptor.Error,
|
|
|
-) {
|
|
|
- var notifyFunc func(title, message string)
|
|
|
-
|
|
|
- lockKey := fmt.Sprintf("%s:%d:%s", issueType, meta.Channel.ID, meta.OriginModel)
|
|
|
- switch issueType {
|
|
|
- case "beyondThreshold":
|
|
|
- notifyFunc = func(title, message string) {
|
|
|
- notify.WarnThrottle(lockKey, time.Minute, title, message)
|
|
|
- }
|
|
|
- default:
|
|
|
- notifyFunc = func(title, message string) {
|
|
|
- notify.ErrorThrottle(lockKey, time.Minute, title, message)
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- respBody, _ := err.MarshalJSON()
|
|
|
-
|
|
|
- message := fmt.Sprintf(
|
|
|
- "channel: %s (type: %d, type name: %s, id: %d)\nmodel: %s\nmode: %s\nstatus code: %d\ndetail: %s\nrequest id: %s",
|
|
|
- meta.Channel.Name,
|
|
|
- meta.Channel.Type,
|
|
|
- meta.Channel.Type.String(),
|
|
|
- meta.Channel.ID,
|
|
|
- meta.OriginModel,
|
|
|
- meta.Mode,
|
|
|
- err.StatusCode(),
|
|
|
- conv.BytesToString(respBody),
|
|
|
- meta.RequestID,
|
|
|
- )
|
|
|
-
|
|
|
- if err.StatusCode() == http.StatusTooManyRequests {
|
|
|
- if !trylock.Lock(lockKey, time.Minute) {
|
|
|
- return
|
|
|
- }
|
|
|
- switch issueType {
|
|
|
- case "beyondThreshold":
|
|
|
- notifyFunc = notify.Warn
|
|
|
- default:
|
|
|
- notifyFunc = notify.Error
|
|
|
- }
|
|
|
-
|
|
|
- rate := getChannelModelRequestRate(c, meta)
|
|
|
- message += fmt.Sprintf(
|
|
|
- "\nrpm: %d\nrps: %d\ntpm: %d\ntps: %d",
|
|
|
- rate.RPM,
|
|
|
- rate.RPS,
|
|
|
- rate.TPM,
|
|
|
- rate.TPS,
|
|
|
- )
|
|
|
- }
|
|
|
-
|
|
|
- notifyFunc(
|
|
|
- fmt.Sprintf("%s `%s` %s", meta.Channel.Name, meta.OriginModel, titleSuffix),
|
|
|
- message,
|
|
|
- )
|
|
|
-}
|
|
|
-
|
|
|
-func filterChannels(
|
|
|
- channels []*model.Channel,
|
|
|
- mode mode.Mode,
|
|
|
- ignoreChannel ...int64,
|
|
|
-) []*model.Channel {
|
|
|
- filtered := make([]*model.Channel, 0)
|
|
|
- for _, channel := range channels {
|
|
|
- if channel.Status != model.ChannelStatusEnabled {
|
|
|
- continue
|
|
|
- }
|
|
|
- a, ok := adaptors.GetAdaptor(channel.Type)
|
|
|
- if !ok {
|
|
|
- continue
|
|
|
- }
|
|
|
- if !a.SupportMode(mode) {
|
|
|
- continue
|
|
|
- }
|
|
|
- if slices.Contains(ignoreChannel, int64(channel.ID)) {
|
|
|
- continue
|
|
|
- }
|
|
|
- filtered = append(filtered, channel)
|
|
|
- }
|
|
|
- return filtered
|
|
|
-}
|
|
|
-
|
|
|
-var (
|
|
|
- ErrChannelsNotFound = errors.New("channels not found")
|
|
|
- ErrChannelsExhausted = errors.New("channels exhausted")
|
|
|
-)
|
|
|
-
|
|
|
-func GetRandomChannel(
|
|
|
- mc *model.ModelCaches,
|
|
|
- availableSet []string,
|
|
|
- modelName string,
|
|
|
- mode mode.Mode,
|
|
|
- errorRates map[int64]float64,
|
|
|
- ignoreChannel ...int64,
|
|
|
-) (*model.Channel, []*model.Channel, error) {
|
|
|
- channelMap := make(map[int]*model.Channel)
|
|
|
- if len(availableSet) != 0 {
|
|
|
- for _, set := range availableSet {
|
|
|
- for _, channel := range mc.EnabledModel2ChannelsBySet[set][modelName] {
|
|
|
- a, ok := adaptors.GetAdaptor(channel.Type)
|
|
|
- if !ok {
|
|
|
- continue
|
|
|
- }
|
|
|
- if !a.SupportMode(mode) {
|
|
|
- continue
|
|
|
- }
|
|
|
- channelMap[channel.ID] = channel
|
|
|
- }
|
|
|
- }
|
|
|
- } else {
|
|
|
- for _, sets := range mc.EnabledModel2ChannelsBySet {
|
|
|
- for _, channel := range sets[modelName] {
|
|
|
- a, ok := adaptors.GetAdaptor(channel.Type)
|
|
|
- if !ok {
|
|
|
- continue
|
|
|
- }
|
|
|
- if !a.SupportMode(mode) {
|
|
|
- continue
|
|
|
- }
|
|
|
- channelMap[channel.ID] = channel
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- migratedChannels := make([]*model.Channel, 0, len(channelMap))
|
|
|
- for _, channel := range channelMap {
|
|
|
- migratedChannels = append(migratedChannels, channel)
|
|
|
- }
|
|
|
- channel, err := getRandomChannel(migratedChannels, mode, errorRates, ignoreChannel...)
|
|
|
- return channel, migratedChannels, err
|
|
|
-}
|
|
|
-
|
|
|
-func getPriority(channel *model.Channel, errorRate float64) int32 {
|
|
|
- priority := channel.GetPriority()
|
|
|
- if errorRate > 1 {
|
|
|
- errorRate = 1
|
|
|
- } else if errorRate < 0.1 {
|
|
|
- errorRate = 0.1
|
|
|
- }
|
|
|
- return int32(float64(priority) / errorRate)
|
|
|
-}
|
|
|
-
|
|
|
-func getRandomChannel(
|
|
|
- channels []*model.Channel,
|
|
|
- mode mode.Mode,
|
|
|
- errorRates map[int64]float64,
|
|
|
- ignoreChannel ...int64,
|
|
|
-) (*model.Channel, error) {
|
|
|
- if len(channels) == 0 {
|
|
|
- return nil, ErrChannelsNotFound
|
|
|
- }
|
|
|
-
|
|
|
- channels = filterChannels(channels, mode, ignoreChannel...)
|
|
|
- if len(channels) == 0 {
|
|
|
- return nil, ErrChannelsExhausted
|
|
|
- }
|
|
|
-
|
|
|
- if len(channels) == 1 {
|
|
|
- return channels[0], nil
|
|
|
- }
|
|
|
-
|
|
|
- var totalWeight int32
|
|
|
- cachedPrioritys := make([]int32, len(channels))
|
|
|
- for i, ch := range channels {
|
|
|
- priority := getPriority(ch, errorRates[int64(ch.ID)])
|
|
|
- totalWeight += priority
|
|
|
- cachedPrioritys[i] = priority
|
|
|
- }
|
|
|
-
|
|
|
- if totalWeight == 0 {
|
|
|
- return channels[rand.IntN(len(channels))], nil
|
|
|
- }
|
|
|
-
|
|
|
- r := rand.Int32N(totalWeight)
|
|
|
- for i, ch := range channels {
|
|
|
- r -= cachedPrioritys[i]
|
|
|
- if r < 0 {
|
|
|
- return ch, nil
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- return channels[rand.IntN(len(channels))], nil
|
|
|
-}
|
|
|
-
|
|
|
-func getChannelWithFallback(
|
|
|
- cache *model.ModelCaches,
|
|
|
- availableSet []string,
|
|
|
- modelName string,
|
|
|
- mode mode.Mode,
|
|
|
- errorRates map[int64]float64,
|
|
|
- ignoreChannelIDs ...int64,
|
|
|
-) (*model.Channel, []*model.Channel, error) {
|
|
|
- channel, migratedChannels, err := GetRandomChannel(
|
|
|
- cache,
|
|
|
- availableSet,
|
|
|
- modelName,
|
|
|
- mode,
|
|
|
- errorRates,
|
|
|
- ignoreChannelIDs...)
|
|
|
- if err == nil {
|
|
|
- return channel, migratedChannels, nil
|
|
|
- }
|
|
|
- if !errors.Is(err, ErrChannelsExhausted) {
|
|
|
- return nil, migratedChannels, err
|
|
|
- }
|
|
|
- channel, migratedChannels, err = GetRandomChannel(
|
|
|
- cache,
|
|
|
- availableSet,
|
|
|
- modelName,
|
|
|
- mode,
|
|
|
- errorRates,
|
|
|
- )
|
|
|
- return channel, migratedChannels, err
|
|
|
+ return result, monitorplugin.ShouldRetry(result.Error)
|
|
|
}
|
|
|
|
|
|
func NewRelay(mode mode.Mode) func(c *gin.Context) {
|
|
|
@@ -784,7 +297,7 @@ func recordResult(
|
|
|
price,
|
|
|
)
|
|
|
if amount > 0 {
|
|
|
- log := middleware.GetLogger(c)
|
|
|
+ log := common.GetLogger(c)
|
|
|
log.Data["amount"] = strconv.FormatFloat(amount, 'f', -1, 64)
|
|
|
}
|
|
|
|
|
|
@@ -802,7 +315,7 @@ func recordResult(
|
|
|
downstreamResult,
|
|
|
user,
|
|
|
metadata,
|
|
|
- getChannelModelRequestRate(c, meta),
|
|
|
+ monitorplugin.GetChannelModelRequestRate(c, meta),
|
|
|
middleware.GetGroupModelTokenRequestRate(c),
|
|
|
)
|
|
|
}
|
|
|
@@ -821,113 +334,6 @@ type retryState struct {
|
|
|
migratedChannels []*model.Channel
|
|
|
}
|
|
|
|
|
|
-type initialChannel struct {
|
|
|
- channel *model.Channel
|
|
|
- designatedChannel bool
|
|
|
- ignoreChannelIDs []int64
|
|
|
- errorRates map[int64]float64
|
|
|
- migratedChannels []*model.Channel
|
|
|
-}
|
|
|
-
|
|
|
-func getInitialChannel(c *gin.Context, modelName string, m mode.Mode) (*initialChannel, error) {
|
|
|
- log := middleware.GetLogger(c)
|
|
|
-
|
|
|
- group := middleware.GetGroup(c)
|
|
|
- availableSet := group.GetAvailableSets()
|
|
|
-
|
|
|
- if channelHeader := c.Request.Header.Get(AIProxyChannelHeader); channelHeader != "" {
|
|
|
- if group.Status != model.GroupStatusInternal {
|
|
|
- return nil, errors.New("channel header is not allowed in non-internal group")
|
|
|
- }
|
|
|
- channel, err := GetChannelFromHeader(
|
|
|
- channelHeader,
|
|
|
- middleware.GetModelCaches(c),
|
|
|
- availableSet,
|
|
|
- modelName,
|
|
|
- m,
|
|
|
- )
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- log.Data["designated_channel"] = "true"
|
|
|
- return &initialChannel{channel: channel, designatedChannel: true}, nil
|
|
|
- }
|
|
|
-
|
|
|
- channel, err := GetChannelFromRequest(
|
|
|
- c,
|
|
|
- middleware.GetModelCaches(c),
|
|
|
- availableSet,
|
|
|
- modelName,
|
|
|
- m,
|
|
|
- )
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- if channel != nil {
|
|
|
- return &initialChannel{channel: channel, designatedChannel: true}, nil
|
|
|
- }
|
|
|
-
|
|
|
- mc := middleware.GetModelCaches(c)
|
|
|
-
|
|
|
- ids, err := monitor.GetBannedChannelsWithModel(c.Request.Context(), modelName)
|
|
|
- if err != nil {
|
|
|
- log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
|
|
|
- }
|
|
|
- log.Debugf("%s model banned channels: %+v", modelName, ids)
|
|
|
-
|
|
|
- errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
|
|
|
- if err != nil {
|
|
|
- log.Errorf("get channel model error rates failed: %+v", err)
|
|
|
- }
|
|
|
-
|
|
|
- channel, migratedChannels, err := getChannelWithFallback(
|
|
|
- mc,
|
|
|
- availableSet,
|
|
|
- modelName,
|
|
|
- m,
|
|
|
- errorRates,
|
|
|
- ids...)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
-
|
|
|
- return &initialChannel{
|
|
|
- channel: channel,
|
|
|
- ignoreChannelIDs: ids,
|
|
|
- errorRates: errorRates,
|
|
|
- migratedChannels: migratedChannels,
|
|
|
- }, nil
|
|
|
-}
|
|
|
-
|
|
|
-func getWebSearchChannel(c *gin.Context, modelName string) (*model.Channel, error) {
|
|
|
- log := middleware.GetLogger(c)
|
|
|
- mc := middleware.GetModelCaches(c)
|
|
|
-
|
|
|
- ids, err := monitor.GetBannedChannelsWithModel(c.Request.Context(), modelName)
|
|
|
- if err != nil {
|
|
|
- log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
|
|
|
- }
|
|
|
- log.Debugf("%s model banned channels: %+v", modelName, ids)
|
|
|
-
|
|
|
- errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
|
|
|
- if err != nil {
|
|
|
- log.Errorf("get channel model error rates failed: %+v", err)
|
|
|
- }
|
|
|
-
|
|
|
- channel, _, err := getChannelWithFallback(
|
|
|
- mc,
|
|
|
- nil,
|
|
|
- modelName,
|
|
|
- mode.ChatCompletions,
|
|
|
- errorRates,
|
|
|
- ids...)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
-
|
|
|
- return channel, nil
|
|
|
-}
|
|
|
-
|
|
|
func handleRelayResult(
|
|
|
c *gin.Context,
|
|
|
bizErr adaptor.Error,
|
|
|
@@ -968,7 +374,7 @@ func initRetryState(
|
|
|
state.exhausted = true
|
|
|
}
|
|
|
|
|
|
- if !channelHasPermission(result.Error) {
|
|
|
+ if !monitorplugin.ChannelHasPermission(result.Error) {
|
|
|
state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(channel.channel.ID))
|
|
|
} else {
|
|
|
state.lastHasPermissionChannel = channel.channel
|
|
|
@@ -978,7 +384,7 @@ func initRetryState(
|
|
|
}
|
|
|
|
|
|
func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayController RelayHandler) {
|
|
|
- log := middleware.GetLogger(c)
|
|
|
+ log := common.GetLogger(c)
|
|
|
|
|
|
// do not use for i := range state.retryTimes, because the retryTimes is constant
|
|
|
i := 0
|
|
|
@@ -1072,30 +478,6 @@ func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayControlle
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func getRetryChannel(state *retryState) (*model.Channel, error) {
|
|
|
- if state.exhausted {
|
|
|
- if state.lastHasPermissionChannel == nil {
|
|
|
- return nil, ErrChannelsExhausted
|
|
|
- }
|
|
|
- return state.lastHasPermissionChannel, nil
|
|
|
- }
|
|
|
-
|
|
|
- newChannel, err := getRandomChannel(
|
|
|
- state.migratedChannels,
|
|
|
- state.meta.Mode,
|
|
|
- state.errorRates,
|
|
|
- state.ignoreChannelIDs...)
|
|
|
- if err != nil {
|
|
|
- if !errors.Is(err, ErrChannelsExhausted) || state.lastHasPermissionChannel == nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- state.exhausted = true
|
|
|
- return state.lastHasPermissionChannel, nil
|
|
|
- }
|
|
|
-
|
|
|
- return newChannel, nil
|
|
|
-}
|
|
|
-
|
|
|
func prepareRetry(c *gin.Context) error {
|
|
|
requestBody, err := common.GetRequestBody(c.Request)
|
|
|
if err != nil {
|
|
|
@@ -1118,7 +500,7 @@ func handleRetryResult(
|
|
|
return true
|
|
|
}
|
|
|
|
|
|
- hasPermission := channelHasPermission(state.result.Error)
|
|
|
+ hasPermission := monitorplugin.ChannelHasPermission(state.result.Error)
|
|
|
|
|
|
if state.exhausted {
|
|
|
if !hasPermission {
|
|
|
@@ -1136,31 +518,6 @@ func handleRetryResult(
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
-var channelNoRetryStatusCodesMap = map[int]struct{}{
|
|
|
- http.StatusBadRequest: {},
|
|
|
- http.StatusRequestEntityTooLarge: {},
|
|
|
- http.StatusUnprocessableEntity: {},
|
|
|
- http.StatusUnavailableForLegalReasons: {},
|
|
|
-}
|
|
|
-
|
|
|
-// 仅当是channel错误时,才需要记录,用户请求参数错误时,不需要记录
|
|
|
-func shouldRetry(_ *gin.Context, relayErr adaptor.Error) bool {
|
|
|
- _, ok := channelNoRetryStatusCodesMap[relayErr.StatusCode()]
|
|
|
- return !ok
|
|
|
-}
|
|
|
-
|
|
|
-var channelNoPermissionStatusCodesMap = map[int]struct{}{
|
|
|
- http.StatusUnauthorized: {},
|
|
|
- http.StatusPaymentRequired: {},
|
|
|
- http.StatusForbidden: {},
|
|
|
- http.StatusNotFound: {},
|
|
|
-}
|
|
|
-
|
|
|
-func channelHasPermission(relayErr adaptor.Error) bool {
|
|
|
- _, ok := channelNoPermissionStatusCodesMap[relayErr.StatusCode()]
|
|
|
- return !ok
|
|
|
-}
|
|
|
-
|
|
|
// shouldDelay checks if we need to add a delay before retrying
|
|
|
// Only adds delay when retrying with the same channel for rate limiting issues
|
|
|
func shouldDelay(statusCode, lastChannelID, newChannelID int) bool {
|
|
|
@@ -1193,7 +550,7 @@ func ErrorWithRequestID(c *gin.Context, relayErr adaptor.Error) {
|
|
|
c.JSON(relayErr.StatusCode(), relayErr)
|
|
|
return
|
|
|
}
|
|
|
- log := middleware.GetLogger(c)
|
|
|
+ log := common.GetLogger(c)
|
|
|
data, err := relayErr.MarshalJSON()
|
|
|
if err != nil {
|
|
|
log.Errorf("marshal error failed: %+v", err)
|