Browse Source

feat: move monitor to plugin (#236)

* feat: move monitor to plugin

* fix: move group monitor to plugin
zijiren 6 months ago
parent
commit
dabb5e7e94
43 changed files with 941 additions and 845 deletions
  1. 42 0
      core/common/gin.go
  2. 2 1
      core/controller/mcp/host.go
  3. 422 0
      core/controller/relay-channel.go
  4. 12 655
      core/controller/relay-controller.go
  5. 3 2
      core/middleware/auth.go
  6. 14 18
      core/middleware/ctxkey.go
  7. 16 63
      core/middleware/distributor.go
  8. 3 37
      core/middleware/log.go
  9. 2 1
      core/middleware/mcp.go
  10. 2 1
      core/middleware/reqid.go
  11. 3 2
      core/middleware/utils.go
  12. 1 2
      core/relay/adaptor/ali/embeddings.go
  13. 2 2
      core/relay/adaptor/ali/image.go
  14. 1 2
      core/relay/adaptor/ali/rerank.go
  15. 2 2
      core/relay/adaptor/ali/tts.go
  16. 1 2
      core/relay/adaptor/anthropic/main.go
  17. 1 2
      core/relay/adaptor/anthropic/openai.go
  18. 2 2
      core/relay/adaptor/aws/claude/main.go
  19. 2 2
      core/relay/adaptor/aws/llama3/main.go
  20. 2 2
      core/relay/adaptor/baidu/embeddings.go
  21. 2 2
      core/relay/adaptor/baidu/image.go
  22. 2 2
      core/relay/adaptor/baidu/main.go
  23. 2 2
      core/relay/adaptor/baidu/rerank.go
  24. 2 2
      core/relay/adaptor/cohere/main.go
  25. 3 3
      core/relay/adaptor/coze/main.go
  26. 2 2
      core/relay/adaptor/doubaoaudio/tts.go
  27. 1 2
      core/relay/adaptor/gemini/main.go
  28. 2 2
      core/relay/adaptor/jina/rerank.go
  29. 3 3
      core/relay/adaptor/minimax/tts.go
  30. 1 2
      core/relay/adaptor/ollama/main.go
  31. 2 3
      core/relay/adaptor/openai/chat.go
  32. 1 2
      core/relay/adaptor/openai/image.go
  33. 1 2
      core/relay/adaptor/openai/moderations.go
  34. 1 2
      core/relay/adaptor/openai/rerank.go
  35. 2 2
      core/relay/adaptor/openai/stt.go
  36. 1 2
      core/relay/adaptor/openai/tts.go
  37. 2 3
      core/relay/adaptor/openai/video.go
  38. 1 2
      core/relay/adaptor/text-embeddings-inference/rerank.go
  39. 2 3
      core/relay/controller/dohelper.go
  40. 2 2
      core/relay/controller/handle.go
  41. 2 2
      core/relay/controller/stt.go
  42. 117 0
      core/relay/plugin/monitor/group.go
  43. 252 0
      core/relay/plugin/monitor/monitor.go

+ 42 - 0
core/common/gin.go

@@ -8,9 +8,12 @@ import (
 	"io"
 	"net/http"
 	"strings"
+	"sync"
 
 	"github.com/bytedance/sonic"
 	"github.com/bytedance/sonic/ast"
+	"github.com/gin-gonic/gin"
+	"github.com/sirupsen/logrus"
 )
 
 type RequestBodyKey struct{}
@@ -104,3 +107,42 @@ func UnmarshalBody2Node(req *http.Request) (ast.Node, error) {
 	}
 	return sonic.Get(requestBody)
 }
+
+var fieldsPool = sync.Pool{
+	New: func() any {
+		return make(logrus.Fields, 6)
+	},
+}
+
+func GetLogFields() logrus.Fields {
+	fields, ok := fieldsPool.Get().(logrus.Fields)
+	if !ok {
+		panic(fmt.Sprintf("fields pool type error: %T, %v", fields, fields))
+	}
+	return fields
+}
+
+func PutLogFields(fields logrus.Fields) {
+	clear(fields)
+	fieldsPool.Put(fields)
+}
+
+func GetLogger(c *gin.Context) *logrus.Entry {
+	if log, ok := c.Get("log"); ok {
+		v, ok := log.(*logrus.Entry)
+		if !ok {
+			panic(fmt.Sprintf("log type error: %T, %v", v, v))
+		}
+		return v
+	}
+	entry := NewLogger()
+	c.Set("log", entry)
+	return entry
+}
+
+func NewLogger() *logrus.Entry {
+	return &logrus.Entry{
+		Logger: logrus.StandardLogger(),
+		Data:   GetLogFields(),
+	}
+}

+ 2 - 1
core/controller/mcp/host.go

@@ -7,6 +7,7 @@ import (
 	"strings"
 
 	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/config"
 	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
@@ -44,7 +45,7 @@ func routeHostMCP(
 	c *gin.Context,
 	publicHandler, groupHandler func(c *gin.Context, mcpID string),
 ) {
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 	host := c.Request.Host
 
 	log.Debugf("route host mcp: %s", host)

+ 422 - 0
core/controller/relay-channel.go

@@ -0,0 +1,422 @@
+package controller
+
+import (
+	"errors"
+	"fmt"
+	"math/rand/v2"
+	"slices"
+	"strconv"
+
+	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
+	"github.com/labring/aiproxy/core/middleware"
+	"github.com/labring/aiproxy/core/model"
+	"github.com/labring/aiproxy/core/monitor"
+	"github.com/labring/aiproxy/core/relay/adaptors"
+	"github.com/labring/aiproxy/core/relay/mode"
+)
+
+const (
+	AIProxyChannelHeader = "Aiproxy-Channel"
+)
+
+func GetChannelFromHeader(
+	header string,
+	mc *model.ModelCaches,
+	availableSet []string,
+	model string,
+	m mode.Mode,
+) (*model.Channel, error) {
+	channelIDInt, err := strconv.ParseInt(header, 10, 64)
+	if err != nil {
+		return nil, err
+	}
+
+	for _, set := range availableSet {
+		enabledChannels := mc.EnabledModel2ChannelsBySet[set][model]
+		if len(enabledChannels) > 0 {
+			for _, channel := range enabledChannels {
+				if int64(channel.ID) == channelIDInt {
+					a, ok := adaptors.GetAdaptor(channel.Type)
+					if !ok {
+						return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
+					}
+					if !a.SupportMode(m) {
+						return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
+					}
+					return channel, nil
+				}
+			}
+		}
+
+		disabledChannels := mc.DisabledModel2ChannelsBySet[set][model]
+		if len(disabledChannels) > 0 {
+			for _, channel := range disabledChannels {
+				if int64(channel.ID) == channelIDInt {
+					a, ok := adaptors.GetAdaptor(channel.Type)
+					if !ok {
+						return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
+					}
+					if !a.SupportMode(m) {
+						return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
+					}
+					return channel, nil
+				}
+			}
+		}
+	}
+
+	return nil, fmt.Errorf("channel %d not found for model `%s`", channelIDInt, model)
+}
+
+func GetChannelFromRequest(
+	c *gin.Context,
+	mc *model.ModelCaches,
+	availableSet []string,
+	modelName string,
+	m mode.Mode,
+) (*model.Channel, error) {
+	switch m {
+	case mode.VideoGenerationsGetJobs,
+		mode.VideoGenerationsContent:
+		channelID := middleware.GetChannelID(c)
+		if channelID == 0 {
+			return nil, errors.New("channel id is required")
+		}
+		for _, set := range availableSet {
+			enabledChannels := mc.EnabledModel2ChannelsBySet[set][modelName]
+			if len(enabledChannels) > 0 {
+				for _, channel := range enabledChannels {
+					if channel.ID == channelID {
+						a, ok := adaptors.GetAdaptor(channel.Type)
+						if !ok {
+							return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
+						}
+						if !a.SupportMode(m) {
+							return nil, fmt.Errorf(
+								"channel %d not supported by adaptor",
+								channel.ID,
+							)
+						}
+						return channel, nil
+					}
+				}
+			}
+		}
+		return nil, fmt.Errorf("channel %d not found for model `%s`", channelID, modelName)
+	default:
+		channelID := middleware.GetChannelID(c)
+		if channelID == 0 {
+			return nil, nil
+		}
+		for _, set := range availableSet {
+			enabledChannels := mc.EnabledModel2ChannelsBySet[set][modelName]
+			if len(enabledChannels) > 0 {
+				for _, channel := range enabledChannels {
+					if channel.ID == channelID {
+						a, ok := adaptors.GetAdaptor(channel.Type)
+						if !ok {
+							return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
+						}
+						if !a.SupportMode(m) {
+							return nil, fmt.Errorf(
+								"channel %d not supported by adaptor",
+								channel.ID,
+							)
+						}
+						return channel, nil
+					}
+				}
+			}
+		}
+	}
+	return nil, nil
+}
+
+var (
+	ErrChannelsNotFound  = errors.New("channels not found")
+	ErrChannelsExhausted = errors.New("channels exhausted")
+)
+
+func GetRandomChannel(
+	mc *model.ModelCaches,
+	availableSet []string,
+	modelName string,
+	mode mode.Mode,
+	errorRates map[int64]float64,
+	ignoreChannel ...int64,
+) (*model.Channel, []*model.Channel, error) {
+	channelMap := make(map[int]*model.Channel)
+	if len(availableSet) != 0 {
+		for _, set := range availableSet {
+			for _, channel := range mc.EnabledModel2ChannelsBySet[set][modelName] {
+				a, ok := adaptors.GetAdaptor(channel.Type)
+				if !ok {
+					continue
+				}
+				if !a.SupportMode(mode) {
+					continue
+				}
+				channelMap[channel.ID] = channel
+			}
+		}
+	} else {
+		for _, sets := range mc.EnabledModel2ChannelsBySet {
+			for _, channel := range sets[modelName] {
+				a, ok := adaptors.GetAdaptor(channel.Type)
+				if !ok {
+					continue
+				}
+				if !a.SupportMode(mode) {
+					continue
+				}
+				channelMap[channel.ID] = channel
+			}
+		}
+	}
+	migratedChannels := make([]*model.Channel, 0, len(channelMap))
+	for _, channel := range channelMap {
+		migratedChannels = append(migratedChannels, channel)
+	}
+	channel, err := getRandomChannel(migratedChannels, mode, errorRates, ignoreChannel...)
+	return channel, migratedChannels, err
+}
+
+func getPriority(channel *model.Channel, errorRate float64) int32 {
+	priority := channel.GetPriority()
+	if errorRate > 1 {
+		errorRate = 1
+	} else if errorRate < 0.1 {
+		errorRate = 0.1
+	}
+	return int32(float64(priority) / errorRate)
+}
+
+func getRandomChannel(
+	channels []*model.Channel,
+	mode mode.Mode,
+	errorRates map[int64]float64,
+	ignoreChannel ...int64,
+) (*model.Channel, error) {
+	if len(channels) == 0 {
+		return nil, ErrChannelsNotFound
+	}
+
+	channels = filterChannels(channels, mode, ignoreChannel...)
+	if len(channels) == 0 {
+		return nil, ErrChannelsExhausted
+	}
+
+	if len(channels) == 1 {
+		return channels[0], nil
+	}
+
+	var totalWeight int32
+	cachedPrioritys := make([]int32, len(channels))
+	for i, ch := range channels {
+		priority := getPriority(ch, errorRates[int64(ch.ID)])
+		totalWeight += priority
+		cachedPrioritys[i] = priority
+	}
+
+	if totalWeight == 0 {
+		return channels[rand.IntN(len(channels))], nil
+	}
+
+	r := rand.Int32N(totalWeight)
+	for i, ch := range channels {
+		r -= cachedPrioritys[i]
+		if r < 0 {
+			return ch, nil
+		}
+	}
+
+	return channels[rand.IntN(len(channels))], nil
+}
+
+func getChannelWithFallback(
+	cache *model.ModelCaches,
+	availableSet []string,
+	modelName string,
+	mode mode.Mode,
+	errorRates map[int64]float64,
+	ignoreChannelIDs ...int64,
+) (*model.Channel, []*model.Channel, error) {
+	channel, migratedChannels, err := GetRandomChannel(
+		cache,
+		availableSet,
+		modelName,
+		mode,
+		errorRates,
+		ignoreChannelIDs...)
+	if err == nil {
+		return channel, migratedChannels, nil
+	}
+	if !errors.Is(err, ErrChannelsExhausted) {
+		return nil, migratedChannels, err
+	}
+	channel, migratedChannels, err = GetRandomChannel(
+		cache,
+		availableSet,
+		modelName,
+		mode,
+		errorRates,
+	)
+	return channel, migratedChannels, err
+}
+
+type initialChannel struct {
+	channel           *model.Channel
+	designatedChannel bool
+	ignoreChannelIDs  []int64
+	errorRates        map[int64]float64
+	migratedChannels  []*model.Channel
+}
+
+func getInitialChannel(c *gin.Context, modelName string, m mode.Mode) (*initialChannel, error) {
+	log := common.GetLogger(c)
+
+	group := middleware.GetGroup(c)
+	availableSet := group.GetAvailableSets()
+
+	if channelHeader := c.Request.Header.Get(AIProxyChannelHeader); channelHeader != "" {
+		if group.Status != model.GroupStatusInternal {
+			return nil, errors.New("channel header is not allowed in non-internal group")
+		}
+		channel, err := GetChannelFromHeader(
+			channelHeader,
+			middleware.GetModelCaches(c),
+			availableSet,
+			modelName,
+			m,
+		)
+		if err != nil {
+			return nil, err
+		}
+		log.Data["designated_channel"] = "true"
+		return &initialChannel{channel: channel, designatedChannel: true}, nil
+	}
+
+	channel, err := GetChannelFromRequest(
+		c,
+		middleware.GetModelCaches(c),
+		availableSet,
+		modelName,
+		m,
+	)
+	if err != nil {
+		return nil, err
+	}
+	if channel != nil {
+		return &initialChannel{channel: channel, designatedChannel: true}, nil
+	}
+
+	mc := middleware.GetModelCaches(c)
+
+	ids, err := monitor.GetBannedChannelsWithModel(c.Request.Context(), modelName)
+	if err != nil {
+		log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
+	}
+	log.Debugf("%s model banned channels: %+v", modelName, ids)
+
+	errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
+	if err != nil {
+		log.Errorf("get channel model error rates failed: %+v", err)
+	}
+
+	channel, migratedChannels, err := getChannelWithFallback(
+		mc,
+		availableSet,
+		modelName,
+		m,
+		errorRates,
+		ids...)
+	if err != nil {
+		return nil, err
+	}
+
+	return &initialChannel{
+		channel:          channel,
+		ignoreChannelIDs: ids,
+		errorRates:       errorRates,
+		migratedChannels: migratedChannels,
+	}, nil
+}
+
+func getWebSearchChannel(c *gin.Context, modelName string) (*model.Channel, error) {
+	log := common.GetLogger(c)
+	mc := middleware.GetModelCaches(c)
+
+	ids, err := monitor.GetBannedChannelsWithModel(c.Request.Context(), modelName)
+	if err != nil {
+		log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
+	}
+	log.Debugf("%s model banned channels: %+v", modelName, ids)
+
+	errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
+	if err != nil {
+		log.Errorf("get channel model error rates failed: %+v", err)
+	}
+
+	channel, _, err := getChannelWithFallback(
+		mc,
+		nil,
+		modelName,
+		mode.ChatCompletions,
+		errorRates,
+		ids...)
+	if err != nil {
+		return nil, err
+	}
+
+	return channel, nil
+}
+
+func getRetryChannel(state *retryState) (*model.Channel, error) {
+	if state.exhausted {
+		if state.lastHasPermissionChannel == nil {
+			return nil, ErrChannelsExhausted
+		}
+		return state.lastHasPermissionChannel, nil
+	}
+
+	newChannel, err := getRandomChannel(
+		state.migratedChannels,
+		state.meta.Mode,
+		state.errorRates,
+		state.ignoreChannelIDs...)
+	if err != nil {
+		if !errors.Is(err, ErrChannelsExhausted) || state.lastHasPermissionChannel == nil {
+			return nil, err
+		}
+		state.exhausted = true
+		return state.lastHasPermissionChannel, nil
+	}
+
+	return newChannel, nil
+}
+
+func filterChannels(
+	channels []*model.Channel,
+	mode mode.Mode,
+	ignoreChannel ...int64,
+) []*model.Channel {
+	filtered := make([]*model.Channel, 0)
+	for _, channel := range channels {
+		if channel.Status != model.ChannelStatusEnabled {
+			continue
+		}
+		a, ok := adaptors.GetAdaptor(channel.Type)
+		if !ok {
+			continue
+		}
+		if !a.SupportMode(mode) {
+			continue
+		}
+		if slices.Contains(ignoreChannel, int64(channel.ID)) {
+			continue
+		}
+		filtered = append(filtered, channel)
+	}
+	return filtered
+}

+ 12 - 655
core/controller/relay-controller.go

@@ -2,13 +2,11 @@ package controller
 
 import (
 	"bytes"
-	"context"
 	"errors"
 	"fmt"
 	"io"
 	"math/rand/v2"
 	"net/http"
-	"slices"
 	"strconv"
 	"time"
 
@@ -19,12 +17,8 @@ import (
 	"github.com/labring/aiproxy/core/common/config"
 	"github.com/labring/aiproxy/core/common/consume"
 	"github.com/labring/aiproxy/core/common/conv"
-	"github.com/labring/aiproxy/core/common/notify"
-	"github.com/labring/aiproxy/core/common/reqlimit"
-	"github.com/labring/aiproxy/core/common/trylock"
 	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
-	"github.com/labring/aiproxy/core/monitor"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/adaptors"
 	"github.com/labring/aiproxy/core/relay/controller"
@@ -33,9 +27,9 @@ import (
 	relaymodel "github.com/labring/aiproxy/core/relay/model"
 	"github.com/labring/aiproxy/core/relay/plugin"
 	"github.com/labring/aiproxy/core/relay/plugin/cache"
+	monitorplugin "github.com/labring/aiproxy/core/relay/plugin/monitor"
 	"github.com/labring/aiproxy/core/relay/plugin/thinksplit"
 	websearch "github.com/labring/aiproxy/core/relay/plugin/web-search"
-	log "github.com/sirupsen/logrus"
 )
 
 // https://platform.openai.com/docs/api-reference/chat
@@ -52,114 +46,6 @@ type RelayController struct {
 	Handler         RelayHandler
 }
 
-// TODO: convert to plugin
-type wrapAdaptor struct {
-	adaptor.Adaptor
-}
-
-const (
-	MetaChannelModelKeyRPM = "channel_model_rpm"
-	MetaChannelModelKeyRPS = "channel_model_rps"
-	MetaChannelModelKeyTPM = "channel_model_tpm"
-	MetaChannelModelKeyTPS = "channel_model_tps"
-)
-
-func getChannelModelRequestRate(c *gin.Context, meta *meta.Meta) model.RequestRate {
-	rate := model.RequestRate{}
-
-	if rpm, ok := meta.Get(MetaChannelModelKeyRPM); ok {
-		rate.RPM, _ = rpm.(int64)
-		rate.RPS = meta.GetInt64(MetaChannelModelKeyRPS)
-	} else {
-		rpm, rps := reqlimit.GetChannelModelRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
-		rate.RPM = rpm
-		rate.RPS = rps
-		updateChannelModelRequestRate(c, meta, rpm, rps)
-	}
-
-	if tpm, ok := meta.Get(MetaChannelModelKeyTPM); ok {
-		rate.TPM, _ = tpm.(int64)
-		rate.TPS = meta.GetInt64(MetaChannelModelKeyTPS)
-	} else {
-		tpm, tps := reqlimit.GetChannelModelTokensRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
-		rate.TPM = tpm
-		rate.TPS = tps
-		updateChannelModelTokensRequestRate(c, meta, tpm, tps)
-	}
-
-	return rate
-}
-
-func updateChannelModelRequestRate(c *gin.Context, meta *meta.Meta, rpm, rps int64) {
-	meta.Set(MetaChannelModelKeyRPM, rpm)
-	meta.Set(MetaChannelModelKeyRPS, rps)
-	log := middleware.GetLogger(c)
-	log.Data["ch_rpm"] = rpm
-	log.Data["ch_rps"] = rps
-}
-
-func updateChannelModelTokensRequestRate(c *gin.Context, meta *meta.Meta, tpm, tps int64) {
-	meta.Set(MetaChannelModelKeyTPM, tpm)
-	meta.Set(MetaChannelModelKeyTPS, tps)
-	log := middleware.GetLogger(c)
-	log.Data["ch_tpm"] = tpm
-	log.Data["ch_tps"] = tps
-}
-
-func (w *wrapAdaptor) DoRequest(
-	meta *meta.Meta,
-	store adaptor.Store,
-	c *gin.Context,
-	req *http.Request,
-) (*http.Response, error) {
-	count, overLimitCount, secondCount := reqlimit.PushChannelModelRequest(
-		context.Background(),
-		strconv.Itoa(meta.Channel.ID),
-		meta.OriginModel,
-	)
-	updateChannelModelRequestRate(c, meta, count+overLimitCount, secondCount)
-	return w.Adaptor.DoRequest(meta, store, c, req)
-}
-
-func (w *wrapAdaptor) DoResponse(
-	meta *meta.Meta,
-	store adaptor.Store,
-	c *gin.Context,
-	resp *http.Response,
-) (model.Usage, adaptor.Error) {
-	usage, relayErr := w.Adaptor.DoResponse(meta, store, c, resp)
-
-	if usage.TotalTokens > 0 {
-		count, overLimitCount, secondCount := reqlimit.PushChannelModelTokensRequest(
-			context.Background(),
-			strconv.Itoa(meta.Channel.ID),
-			meta.OriginModel,
-			int64(usage.TotalTokens),
-		)
-		updateChannelModelTokensRequestRate(c, meta, count+overLimitCount, secondCount)
-
-		count, overLimitCount, secondCount = reqlimit.PushGroupModelTokensRequest(
-			context.Background(),
-			meta.Group.ID,
-			meta.OriginModel,
-			meta.ModelConfig.TPM,
-			int64(usage.TotalTokens),
-		)
-		middleware.UpdateGroupModelTokensRequest(c, meta.Group, count+overLimitCount, secondCount)
-
-		count, overLimitCount, secondCount = reqlimit.PushGroupModelTokennameTokensRequest(
-			context.Background(),
-			meta.Group.ID,
-			meta.OriginModel,
-			meta.Token.Name,
-			int64(usage.TotalTokens),
-		)
-		middleware.UpdateGroupModelTokennameTokensRequest(c, count+overLimitCount, secondCount)
-	}
-
-	return usage, relayErr
-}
-
 var adaptorStore adaptor.Store = &storeImpl{}
 
 type storeImpl struct{}
@@ -192,7 +78,7 @@ func (s *storeImpl) SaveStore(store adaptor.StoreCache) error {
 }
 
 func relayHandler(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 	middleware.SetLogFieldsFromMeta(meta, log.Data)
 
 	adaptor, ok := adaptors.GetAdaptor(meta.Channel.Type)
@@ -206,12 +92,14 @@ func relayHandler(c *gin.Context, meta *meta.Meta) *controller.HandleResult {
 		}
 	}
 
-	a := plugin.WrapperAdaptor(&wrapAdaptor{adaptor},
+	a := plugin.WrapperAdaptor(adaptor,
+		monitorplugin.NewGroupMonitorPlugin(),
 		cache.NewCachePlugin(common.RDB),
 		websearch.NewWebSearchPlugin(func(modelName string) (*model.Channel, error) {
 			return getWebSearchChannel(c, modelName)
 		}),
 		thinksplit.NewThinkPlugin(),
+		monitorplugin.NewChannelMonitorPlugin(),
 	)
 
 	return controller.Handle(a, c, meta, adaptorStore)
@@ -259,123 +147,6 @@ func relayController(m mode.Mode) RelayController {
 	return c
 }
 
-const (
-	AIProxyChannelHeader = "Aiproxy-Channel"
-)
-
-func GetChannelFromHeader(
-	header string,
-	mc *model.ModelCaches,
-	availableSet []string,
-	model string,
-	m mode.Mode,
-) (*model.Channel, error) {
-	channelIDInt, err := strconv.ParseInt(header, 10, 64)
-	if err != nil {
-		return nil, err
-	}
-
-	for _, set := range availableSet {
-		enabledChannels := mc.EnabledModel2ChannelsBySet[set][model]
-		if len(enabledChannels) > 0 {
-			for _, channel := range enabledChannels {
-				if int64(channel.ID) == channelIDInt {
-					a, ok := adaptors.GetAdaptor(channel.Type)
-					if !ok {
-						return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
-					}
-					if !a.SupportMode(m) {
-						return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
-					}
-					return channel, nil
-				}
-			}
-		}
-
-		disabledChannels := mc.DisabledModel2ChannelsBySet[set][model]
-		if len(disabledChannels) > 0 {
-			for _, channel := range disabledChannels {
-				if int64(channel.ID) == channelIDInt {
-					a, ok := adaptors.GetAdaptor(channel.Type)
-					if !ok {
-						return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
-					}
-					if !a.SupportMode(m) {
-						return nil, fmt.Errorf("channel %d not supported by adaptor", channel.ID)
-					}
-					return channel, nil
-				}
-			}
-		}
-	}
-
-	return nil, fmt.Errorf("channel %d not found for model `%s`", channelIDInt, model)
-}
-
-func GetChannelFromRequest(
-	c *gin.Context,
-	mc *model.ModelCaches,
-	availableSet []string,
-	modelName string,
-	m mode.Mode,
-) (*model.Channel, error) {
-	switch m {
-	case mode.VideoGenerationsGetJobs,
-		mode.VideoGenerationsContent:
-		channelID := middleware.GetChannelID(c)
-		if channelID == 0 {
-			return nil, errors.New("channel id is required")
-		}
-		for _, set := range availableSet {
-			enabledChannels := mc.EnabledModel2ChannelsBySet[set][modelName]
-			if len(enabledChannels) > 0 {
-				for _, channel := range enabledChannels {
-					if channel.ID == channelID {
-						a, ok := adaptors.GetAdaptor(channel.Type)
-						if !ok {
-							return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
-						}
-						if !a.SupportMode(m) {
-							return nil, fmt.Errorf(
-								"channel %d not supported by adaptor",
-								channel.ID,
-							)
-						}
-						return channel, nil
-					}
-				}
-			}
-		}
-		return nil, fmt.Errorf("channel %d not found for model `%s`", channelID, modelName)
-	default:
-		channelID := middleware.GetChannelID(c)
-		if channelID == 0 {
-			return nil, nil
-		}
-		for _, set := range availableSet {
-			enabledChannels := mc.EnabledModel2ChannelsBySet[set][modelName]
-			if len(enabledChannels) > 0 {
-				for _, channel := range enabledChannels {
-					if channel.ID == channelID {
-						a, ok := adaptors.GetAdaptor(channel.Type)
-						if !ok {
-							return nil, fmt.Errorf("adaptor not found for channel %d", channel.ID)
-						}
-						if !a.SupportMode(m) {
-							return nil, fmt.Errorf(
-								"channel %d not supported by adaptor",
-								channel.ID,
-							)
-						}
-						return channel, nil
-					}
-				}
-			}
-		}
-	}
-	return nil, nil
-}
-
 func RelayHelper(
 	c *gin.Context,
 	meta *meta.Meta,
@@ -383,267 +154,9 @@ func RelayHelper(
 ) (*controller.HandleResult, bool) {
 	result := handel(c, meta)
 	if result.Error == nil {
-		if _, _, err := monitor.AddRequest(
-			context.Background(),
-			meta.OriginModel,
-			int64(meta.Channel.ID),
-			false,
-			false,
-			meta.ModelConfig.MaxErrorRate,
-		); err != nil {
-			log.Errorf("add request failed: %+v", err)
-		}
 		return result, false
 	}
-	shouldRetry := shouldRetry(c, result.Error)
-	if shouldRetry {
-		hasPermission := channelHasPermission(result.Error)
-		beyondThreshold, banExecution, err := monitor.AddRequest(
-			context.Background(),
-			meta.OriginModel,
-			int64(meta.Channel.ID),
-			true,
-			!hasPermission,
-			meta.ModelConfig.MaxErrorRate,
-		)
-		if err != nil {
-			log.Errorf("add request failed: %+v", err)
-		}
-		switch {
-		case banExecution:
-			notifyChannelIssue(c, meta, "autoBanned", "Auto Banned", result.Error)
-		case beyondThreshold:
-			notifyChannelIssue(
-				c,
-				meta,
-				"beyondThreshold",
-				"Error Rate Beyond Threshold",
-				result.Error,
-			)
-		case !hasPermission:
-			notifyChannelIssue(c, meta, "channelHasPermission", "No Permission", result.Error)
-		}
-	}
-	return result, shouldRetry
-}
-
-func notifyChannelIssue(
-	c *gin.Context,
-	meta *meta.Meta,
-	issueType, titleSuffix string,
-	err adaptor.Error,
-) {
-	var notifyFunc func(title, message string)
-
-	lockKey := fmt.Sprintf("%s:%d:%s", issueType, meta.Channel.ID, meta.OriginModel)
-	switch issueType {
-	case "beyondThreshold":
-		notifyFunc = func(title, message string) {
-			notify.WarnThrottle(lockKey, time.Minute, title, message)
-		}
-	default:
-		notifyFunc = func(title, message string) {
-			notify.ErrorThrottle(lockKey, time.Minute, title, message)
-		}
-	}
-
-	respBody, _ := err.MarshalJSON()
-
-	message := fmt.Sprintf(
-		"channel: %s (type: %d, type name: %s, id: %d)\nmodel: %s\nmode: %s\nstatus code: %d\ndetail: %s\nrequest id: %s",
-		meta.Channel.Name,
-		meta.Channel.Type,
-		meta.Channel.Type.String(),
-		meta.Channel.ID,
-		meta.OriginModel,
-		meta.Mode,
-		err.StatusCode(),
-		conv.BytesToString(respBody),
-		meta.RequestID,
-	)
-
-	if err.StatusCode() == http.StatusTooManyRequests {
-		if !trylock.Lock(lockKey, time.Minute) {
-			return
-		}
-		switch issueType {
-		case "beyondThreshold":
-			notifyFunc = notify.Warn
-		default:
-			notifyFunc = notify.Error
-		}
-
-		rate := getChannelModelRequestRate(c, meta)
-		message += fmt.Sprintf(
-			"\nrpm: %d\nrps: %d\ntpm: %d\ntps: %d",
-			rate.RPM,
-			rate.RPS,
-			rate.TPM,
-			rate.TPS,
-		)
-	}
-
-	notifyFunc(
-		fmt.Sprintf("%s `%s` %s", meta.Channel.Name, meta.OriginModel, titleSuffix),
-		message,
-	)
-}
-
-func filterChannels(
-	channels []*model.Channel,
-	mode mode.Mode,
-	ignoreChannel ...int64,
-) []*model.Channel {
-	filtered := make([]*model.Channel, 0)
-	for _, channel := range channels {
-		if channel.Status != model.ChannelStatusEnabled {
-			continue
-		}
-		a, ok := adaptors.GetAdaptor(channel.Type)
-		if !ok {
-			continue
-		}
-		if !a.SupportMode(mode) {
-			continue
-		}
-		if slices.Contains(ignoreChannel, int64(channel.ID)) {
-			continue
-		}
-		filtered = append(filtered, channel)
-	}
-	return filtered
-}
-
-var (
-	ErrChannelsNotFound  = errors.New("channels not found")
-	ErrChannelsExhausted = errors.New("channels exhausted")
-)
-
-func GetRandomChannel(
-	mc *model.ModelCaches,
-	availableSet []string,
-	modelName string,
-	mode mode.Mode,
-	errorRates map[int64]float64,
-	ignoreChannel ...int64,
-) (*model.Channel, []*model.Channel, error) {
-	channelMap := make(map[int]*model.Channel)
-	if len(availableSet) != 0 {
-		for _, set := range availableSet {
-			for _, channel := range mc.EnabledModel2ChannelsBySet[set][modelName] {
-				a, ok := adaptors.GetAdaptor(channel.Type)
-				if !ok {
-					continue
-				}
-				if !a.SupportMode(mode) {
-					continue
-				}
-				channelMap[channel.ID] = channel
-			}
-		}
-	} else {
-		for _, sets := range mc.EnabledModel2ChannelsBySet {
-			for _, channel := range sets[modelName] {
-				a, ok := adaptors.GetAdaptor(channel.Type)
-				if !ok {
-					continue
-				}
-				if !a.SupportMode(mode) {
-					continue
-				}
-				channelMap[channel.ID] = channel
-			}
-		}
-	}
-	migratedChannels := make([]*model.Channel, 0, len(channelMap))
-	for _, channel := range channelMap {
-		migratedChannels = append(migratedChannels, channel)
-	}
-	channel, err := getRandomChannel(migratedChannels, mode, errorRates, ignoreChannel...)
-	return channel, migratedChannels, err
-}
-
-func getPriority(channel *model.Channel, errorRate float64) int32 {
-	priority := channel.GetPriority()
-	if errorRate > 1 {
-		errorRate = 1
-	} else if errorRate < 0.1 {
-		errorRate = 0.1
-	}
-	return int32(float64(priority) / errorRate)
-}
-
-func getRandomChannel(
-	channels []*model.Channel,
-	mode mode.Mode,
-	errorRates map[int64]float64,
-	ignoreChannel ...int64,
-) (*model.Channel, error) {
-	if len(channels) == 0 {
-		return nil, ErrChannelsNotFound
-	}
-
-	channels = filterChannels(channels, mode, ignoreChannel...)
-	if len(channels) == 0 {
-		return nil, ErrChannelsExhausted
-	}
-
-	if len(channels) == 1 {
-		return channels[0], nil
-	}
-
-	var totalWeight int32
-	cachedPrioritys := make([]int32, len(channels))
-	for i, ch := range channels {
-		priority := getPriority(ch, errorRates[int64(ch.ID)])
-		totalWeight += priority
-		cachedPrioritys[i] = priority
-	}
-
-	if totalWeight == 0 {
-		return channels[rand.IntN(len(channels))], nil
-	}
-
-	r := rand.Int32N(totalWeight)
-	for i, ch := range channels {
-		r -= cachedPrioritys[i]
-		if r < 0 {
-			return ch, nil
-		}
-	}
-
-	return channels[rand.IntN(len(channels))], nil
-}
-
-func getChannelWithFallback(
-	cache *model.ModelCaches,
-	availableSet []string,
-	modelName string,
-	mode mode.Mode,
-	errorRates map[int64]float64,
-	ignoreChannelIDs ...int64,
-) (*model.Channel, []*model.Channel, error) {
-	channel, migratedChannels, err := GetRandomChannel(
-		cache,
-		availableSet,
-		modelName,
-		mode,
-		errorRates,
-		ignoreChannelIDs...)
-	if err == nil {
-		return channel, migratedChannels, nil
-	}
-	if !errors.Is(err, ErrChannelsExhausted) {
-		return nil, migratedChannels, err
-	}
-	channel, migratedChannels, err = GetRandomChannel(
-		cache,
-		availableSet,
-		modelName,
-		mode,
-		errorRates,
-	)
-	return channel, migratedChannels, err
+	return result, monitorplugin.ShouldRetry(result.Error)
 }
 
 func NewRelay(mode mode.Mode) func(c *gin.Context) {
@@ -784,7 +297,7 @@ func recordResult(
 		price,
 	)
 	if amount > 0 {
-		log := middleware.GetLogger(c)
+		log := common.GetLogger(c)
 		log.Data["amount"] = strconv.FormatFloat(amount, 'f', -1, 64)
 	}
 
@@ -802,7 +315,7 @@ func recordResult(
 		downstreamResult,
 		user,
 		metadata,
-		getChannelModelRequestRate(c, meta),
+		monitorplugin.GetChannelModelRequestRate(c, meta),
 		middleware.GetGroupModelTokenRequestRate(c),
 	)
 }
@@ -821,113 +334,6 @@ type retryState struct {
 	migratedChannels []*model.Channel
 }
 
-type initialChannel struct {
-	channel           *model.Channel
-	designatedChannel bool
-	ignoreChannelIDs  []int64
-	errorRates        map[int64]float64
-	migratedChannels  []*model.Channel
-}
-
-func getInitialChannel(c *gin.Context, modelName string, m mode.Mode) (*initialChannel, error) {
-	log := middleware.GetLogger(c)
-
-	group := middleware.GetGroup(c)
-	availableSet := group.GetAvailableSets()
-
-	if channelHeader := c.Request.Header.Get(AIProxyChannelHeader); channelHeader != "" {
-		if group.Status != model.GroupStatusInternal {
-			return nil, errors.New("channel header is not allowed in non-internal group")
-		}
-		channel, err := GetChannelFromHeader(
-			channelHeader,
-			middleware.GetModelCaches(c),
-			availableSet,
-			modelName,
-			m,
-		)
-		if err != nil {
-			return nil, err
-		}
-		log.Data["designated_channel"] = "true"
-		return &initialChannel{channel: channel, designatedChannel: true}, nil
-	}
-
-	channel, err := GetChannelFromRequest(
-		c,
-		middleware.GetModelCaches(c),
-		availableSet,
-		modelName,
-		m,
-	)
-	if err != nil {
-		return nil, err
-	}
-	if channel != nil {
-		return &initialChannel{channel: channel, designatedChannel: true}, nil
-	}
-
-	mc := middleware.GetModelCaches(c)
-
-	ids, err := monitor.GetBannedChannelsWithModel(c.Request.Context(), modelName)
-	if err != nil {
-		log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
-	}
-	log.Debugf("%s model banned channels: %+v", modelName, ids)
-
-	errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
-	if err != nil {
-		log.Errorf("get channel model error rates failed: %+v", err)
-	}
-
-	channel, migratedChannels, err := getChannelWithFallback(
-		mc,
-		availableSet,
-		modelName,
-		m,
-		errorRates,
-		ids...)
-	if err != nil {
-		return nil, err
-	}
-
-	return &initialChannel{
-		channel:          channel,
-		ignoreChannelIDs: ids,
-		errorRates:       errorRates,
-		migratedChannels: migratedChannels,
-	}, nil
-}
-
-func getWebSearchChannel(c *gin.Context, modelName string) (*model.Channel, error) {
-	log := middleware.GetLogger(c)
-	mc := middleware.GetModelCaches(c)
-
-	ids, err := monitor.GetBannedChannelsWithModel(c.Request.Context(), modelName)
-	if err != nil {
-		log.Errorf("get %s auto banned channels failed: %+v", modelName, err)
-	}
-	log.Debugf("%s model banned channels: %+v", modelName, ids)
-
-	errorRates, err := monitor.GetModelChannelErrorRate(c.Request.Context(), modelName)
-	if err != nil {
-		log.Errorf("get channel model error rates failed: %+v", err)
-	}
-
-	channel, _, err := getChannelWithFallback(
-		mc,
-		nil,
-		modelName,
-		mode.ChatCompletions,
-		errorRates,
-		ids...)
-	if err != nil {
-		return nil, err
-	}
-
-	return channel, nil
-}
-
 func handleRelayResult(
 	c *gin.Context,
 	bizErr adaptor.Error,
@@ -968,7 +374,7 @@ func initRetryState(
 		state.exhausted = true
 	}
 
-	if !channelHasPermission(result.Error) {
+	if !monitorplugin.ChannelHasPermission(result.Error) {
 		state.ignoreChannelIDs = append(state.ignoreChannelIDs, int64(channel.channel.ID))
 	} else {
 		state.lastHasPermissionChannel = channel.channel
@@ -978,7 +384,7 @@ func initRetryState(
 }
 
 func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayController RelayHandler) {
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	// do not use for i := range state.retryTimes, because the retryTimes is constant
 	i := 0
@@ -1072,30 +478,6 @@ func retryLoop(c *gin.Context, mode mode.Mode, state *retryState, relayControlle
 	}
 }
 
-func getRetryChannel(state *retryState) (*model.Channel, error) {
-	if state.exhausted {
-		if state.lastHasPermissionChannel == nil {
-			return nil, ErrChannelsExhausted
-		}
-		return state.lastHasPermissionChannel, nil
-	}
-
-	newChannel, err := getRandomChannel(
-		state.migratedChannels,
-		state.meta.Mode,
-		state.errorRates,
-		state.ignoreChannelIDs...)
-	if err != nil {
-		if !errors.Is(err, ErrChannelsExhausted) || state.lastHasPermissionChannel == nil {
-			return nil, err
-		}
-		state.exhausted = true
-		return state.lastHasPermissionChannel, nil
-	}
-
-	return newChannel, nil
-}
-
 func prepareRetry(c *gin.Context) error {
 	requestBody, err := common.GetRequestBody(c.Request)
 	if err != nil {
@@ -1118,7 +500,7 @@ func handleRetryResult(
 		return true
 	}
 
-	hasPermission := channelHasPermission(state.result.Error)
+	hasPermission := monitorplugin.ChannelHasPermission(state.result.Error)
 
 	if state.exhausted {
 		if !hasPermission {
@@ -1136,31 +518,6 @@ func handleRetryResult(
 	return false
 }
 
-var channelNoRetryStatusCodesMap = map[int]struct{}{
-	http.StatusBadRequest:                 {},
-	http.StatusRequestEntityTooLarge:      {},
-	http.StatusUnprocessableEntity:        {},
-	http.StatusUnavailableForLegalReasons: {},
-}
-
-// 仅当是channel错误时,才需要记录,用户请求参数错误时,不需要记录
-func shouldRetry(_ *gin.Context, relayErr adaptor.Error) bool {
-	_, ok := channelNoRetryStatusCodesMap[relayErr.StatusCode()]
-	return !ok
-}
-
-var channelNoPermissionStatusCodesMap = map[int]struct{}{
-	http.StatusUnauthorized:    {},
-	http.StatusPaymentRequired: {},
-	http.StatusForbidden:       {},
-	http.StatusNotFound:        {},
-}
-
-func channelHasPermission(relayErr adaptor.Error) bool {
-	_, ok := channelNoPermissionStatusCodesMap[relayErr.StatusCode()]
-	return !ok
-}
-
 // shouldDelay checks if we need to add a delay before retrying
 // Only adds delay when retrying with the same channel for rate limiting issues
 func shouldDelay(statusCode, lastChannelID, newChannelID int) bool {
@@ -1193,7 +550,7 @@ func ErrorWithRequestID(c *gin.Context, relayErr adaptor.Error) {
 		c.JSON(relayErr.StatusCode(), relayErr)
 		return
 	}
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 	data, err := relayErr.MarshalJSON()
 	if err != nil {
 		log.Errorf("marshal error failed: %+v", err)

+ 3 - 2
core/middleware/auth.go

@@ -8,6 +8,7 @@ import (
 	"strings"
 
 	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/config"
 	"github.com/labring/aiproxy/core/common/network"
 	"github.com/labring/aiproxy/core/model"
@@ -62,7 +63,7 @@ func AdminAuth(c *gin.Context) {
 
 	group := c.Param("group")
 	if group != "" {
-		log := GetLogger(c)
+		log := common.GetLogger(c)
 		log.Data["gid"] = group
 	}
 
@@ -70,7 +71,7 @@ func AdminAuth(c *gin.Context) {
 }
 
 func TokenAuth(c *gin.Context) {
-	log := GetLogger(c)
+	log := common.GetLogger(c)
 	key := c.Request.Header.Get("Authorization")
 	if key == "" {
 		key = c.Request.Header.Get("X-Api-Key")

+ 14 - 18
core/middleware/ctxkey.go

@@ -1,22 +1,18 @@
 package middleware
 
 const (
-	ChannelID          = "channel_id"
-	GroupModelTokenRPM = "group_model_token_rpm"
-	GroupModelTokenRPS = "group_model_token_rps"
-	GroupModelTokenTPM = "group_model_token_tpm"
-	GroupModelTokenTPS = "group_model_token_tps"
-	Group              = "group"
-	Token              = "token"
-	GroupBalance       = "group_balance"
-	RequestModel       = "request_model"
-	RequestUser        = "request_user"
-	RequestMetadata    = "request_metadata"
-	RequestAt          = "request_at"
-	RequestID          = "request_id"
-	ModelCaches        = "model_caches"
-	ModelConfig        = "model_config"
-	Mode               = "mode"
-	JobID              = "job_id"
-	GenerationID       = "generation_id"
+	ChannelID       = "channel_id"
+	Group           = "group"
+	Token           = "token"
+	GroupBalance    = "group_balance"
+	RequestModel    = "request_model"
+	RequestUser     = "request_user"
+	RequestMetadata = "request_metadata"
+	RequestAt       = "request_at"
+	RequestID       = "request_id"
+	ModelCaches     = "model_caches"
+	ModelConfig     = "model_config"
+	Mode            = "mode"
+	JobID           = "job_id"
+	GenerationID    = "generation_id"
 )

+ 16 - 63
core/middleware/distributor.go

@@ -20,6 +20,7 @@ import (
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/meta"
 	"github.com/labring/aiproxy/core/relay/mode"
+	monitorplugin "github.com/labring/aiproxy/core/relay/plugin/monitor"
 )
 
 func calculateGroupConsumeLevelRatio(usedAmount float64) float64 {
@@ -96,49 +97,13 @@ func setTpmHeaders(c *gin.Context, tpm, remainingRequests int64) {
 	c.Header(XRateLimitResetTokens, "1m0s")
 }
 
-func UpdateGroupModelRequest(c *gin.Context, group model.GroupCache, rpm, rps int64) {
-	if group.Status == model.GroupStatusInternal {
-		return
-	}
-
-	log := GetLogger(c)
-	log.Data["group_rpm"] = strconv.FormatInt(rpm, 10)
-	log.Data["group_rps"] = strconv.FormatInt(rps, 10)
-}
-
-func UpdateGroupModelTokensRequest(c *gin.Context, group model.GroupCache, tpm, tps int64) {
-	if group.Status == model.GroupStatusInternal {
-		return
-	}
-
-	log := GetLogger(c)
-	log.Data["group_tpm"] = strconv.FormatInt(tpm, 10)
-	log.Data["group_tps"] = strconv.FormatInt(tps, 10)
-}
-
-func UpdateGroupModelTokennameRequest(c *gin.Context, rpm, rps int64) {
-	c.Set(GroupModelTokenRPM, rpm)
-	c.Set(GroupModelTokenRPS, rps)
-	// log := GetLogger(c)
-	// log.Data["rpm"] = strconv.FormatInt(rpm, 10)
-	// log.Data["rps"] = strconv.FormatInt(rps, 10)
-}
-
-func UpdateGroupModelTokennameTokensRequest(c *gin.Context, tpm, tps int64) {
-	c.Set(GroupModelTokenTPM, tpm)
-	c.Set(GroupModelTokenTPS, tps)
-	// log := GetLogger(c)
-	// log.Data["tpm"] = strconv.FormatInt(tpm, 10)
-	// log.Data["tps"] = strconv.FormatInt(tps, 10)
-}
-
 func checkGroupModelRPMAndTPM(
 	c *gin.Context,
 	group model.GroupCache,
 	mc model.ModelConfig,
 	tokenName string,
 ) error {
-	log := GetLogger(c)
+	log := common.GetLogger(c)
 
 	adjustedModelConfig := GetGroupAdjustedModelConfig(group, mc)
 
@@ -148,7 +113,7 @@ func checkGroupModelRPMAndTPM(
 		mc.Model,
 		adjustedModelConfig.RPM,
 	)
-	UpdateGroupModelRequest(
+	monitorplugin.UpdateGroupModelRequest(
 		c,
 		group,
 		groupModelCount+groupModelOverLimitCount,
@@ -161,7 +126,7 @@ func checkGroupModelRPMAndTPM(
 		mc.Model,
 		tokenName,
 	)
-	UpdateGroupModelTokennameRequest(
+	monitorplugin.UpdateGroupModelTokennameRequest(
 		c,
 		groupModelTokenCount+groupModelTokenOverLimitCount,
 		groupModelTokenSecondCount,
@@ -182,7 +147,7 @@ func checkGroupModelRPMAndTPM(
 		group.ID,
 		mc.Model,
 	)
-	UpdateGroupModelTokensRequest(c, group, groupModelCountTPM, groupModelCountTPS)
+	monitorplugin.UpdateGroupModelTokensRequest(c, group, groupModelCountTPM, groupModelCountTPS)
 
 	groupModelTokenCountTPM, groupModelTokenCountTPS := reqlimit.GetGroupModelTokennameTokensRequest(
 		c.Request.Context(),
@@ -190,7 +155,11 @@ func checkGroupModelRPMAndTPM(
 		mc.Model,
 		tokenName,
 	)
-	UpdateGroupModelTokennameTokensRequest(c, groupModelTokenCountTPM, groupModelTokenCountTPS)
+	monitorplugin.UpdateGroupModelTokennameTokensRequest(
+		c,
+		groupModelTokenCountTPM,
+		groupModelTokenCountTPS,
+	)
 
 	if group.Status != model.GroupStatusInternal &&
 		adjustedModelConfig.TPM > 0 {
@@ -242,7 +211,7 @@ func GetGroupBalanceConsumer(
 			Consumer: nil,
 		}
 	} else {
-		log := GetLogger(c)
+		log := common.GetLogger(c)
 		groupBalance, consumer, err := balance.GetGroupRemainBalance(c.Request.Context(), group)
 		if err != nil {
 			return nil, err
@@ -356,7 +325,7 @@ func distribute(c *gin.Context, mode mode.Mode) {
 		return
 	}
 
-	log := GetLogger(c)
+	log := common.GetLogger(c)
 
 	group := GetGroup(c)
 	token := GetToken(c)
@@ -464,29 +433,13 @@ func distribute(c *gin.Context, mode mode.Mode) {
 
 func GetGroupModelTokenRequestRate(c *gin.Context) model.RequestRate {
 	return model.RequestRate{
-		RPM: GetGroupModelTokenRPM(c),
-		RPS: GetGroupModelTokenRPS(c),
-		TPM: GetGroupModelTokenTPM(c),
-		TPS: GetGroupModelTokenTPS(c),
+		RPM: monitorplugin.GetGroupModelTokenRPM(c),
+		RPS: monitorplugin.GetGroupModelTokenRPS(c),
+		TPM: monitorplugin.GetGroupModelTokenTPM(c),
+		TPS: monitorplugin.GetGroupModelTokenTPS(c),
 	}
 }
 
-func GetGroupModelTokenRPM(c *gin.Context) int64 {
-	return c.GetInt64(GroupModelTokenRPM)
-}
-
-func GetGroupModelTokenRPS(c *gin.Context) int64 {
-	return c.GetInt64(GroupModelTokenRPS)
-}
-
-func GetGroupModelTokenTPM(c *gin.Context) int64 {
-	return c.GetInt64(GroupModelTokenTPM)
-}
-
-func GetGroupModelTokenTPS(c *gin.Context) int64 {
-	return c.GetInt64(GroupModelTokenTPS)
-}
-
 func GetRequestModel(c *gin.Context) string {
 	return c.GetString(RequestModel)
 }

+ 3 - 37
core/middleware/log.go

@@ -3,28 +3,18 @@ package middleware
 import (
 	"fmt"
 	"net/http"
-	"sync"
 	"time"
 
 	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/sirupsen/logrus"
 )
 
-var fieldsPool = sync.Pool{
-	New: func() any {
-		return make(logrus.Fields, 6)
-	},
-}
-
 func NewLog(l *logrus.Logger) gin.HandlerFunc {
 	return func(c *gin.Context) {
-		fields, ok := fieldsPool.Get().(logrus.Fields)
-		if !ok {
-			panic(fmt.Sprintf("fields pool type error: %T, %v", fields, fields))
-		}
+		fields := common.GetLogFields()
 		defer func() {
-			clear(fields)
-			fieldsPool.Put(fields)
+			common.PutLogFields(fields)
 		}()
 
 		entry := &logrus.Entry{
@@ -95,27 +85,3 @@ func formatter(param gin.LogFormatterParams) string {
 		param.ErrorMessage,
 	)
 }
-
-func GetLogger(c *gin.Context) *logrus.Entry {
-	if log, ok := c.Get("log"); ok {
-		v, ok := log.(*logrus.Entry)
-		if !ok {
-			panic(fmt.Sprintf("log type error: %T, %v", v, v))
-		}
-		return v
-	}
-	entry := NewLogger()
-	c.Set("log", entry)
-	return entry
-}
-
-func NewLogger() *logrus.Entry {
-	fields, ok := fieldsPool.Get().(logrus.Fields)
-	if !ok {
-		panic(fmt.Sprintf("fields pool type error: %T, %v", fields, fields))
-	}
-	return &logrus.Entry{
-		Logger: logrus.StandardLogger(),
-		Data:   fields,
-	}
-}

+ 2 - 1
core/middleware/mcp.go

@@ -6,13 +6,14 @@ import (
 	"strings"
 
 	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/config"
 	"github.com/labring/aiproxy/core/common/network"
 	"github.com/labring/aiproxy/core/model"
 )
 
 func MCPAuth(c *gin.Context) {
-	log := GetLogger(c)
+	log := common.GetLogger(c)
 	key := c.Request.Header.Get("Authorization")
 	if key == "" {
 		key, _ = c.GetQuery("key")

+ 2 - 1
core/middleware/reqid.go

@@ -5,6 +5,7 @@ import (
 	"time"
 
 	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
 )
 
 func GenRequestID(t time.Time) string {
@@ -18,7 +19,7 @@ const (
 func SetRequestID(c *gin.Context, id string) {
 	c.Set(RequestID, id)
 	c.Header(RequestIDHeader, id)
-	log := GetLogger(c)
+	log := common.GetLogger(c)
 	SetLogRequestIDField(log.Data, id)
 }
 

+ 3 - 2
core/middleware/utils.go

@@ -4,6 +4,7 @@ import (
 	"fmt"
 
 	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/relay/mode"
 	relaymodel "github.com/labring/aiproxy/core/relay/model"
 )
@@ -15,7 +16,7 @@ func AbortLogWithMessageWithMode(
 	message string,
 	typ ...string,
 ) {
-	GetLogger(c).Error(message)
+	common.GetLogger(c).Error(message)
 	AbortWithMessageWithMode(m, c, statusCode, message, typ...)
 }
 
@@ -33,7 +34,7 @@ func AbortWithMessageWithMode(
 }
 
 func AbortLogWithMessage(c *gin.Context, statusCode int, message string, typ ...string) {
-	GetLogger(c).Error(message)
+	common.GetLogger(c).Error(message)
 	AbortWithMessage(c, statusCode, message, typ...)
 }
 

+ 1 - 2
core/relay/adaptor/ali/embeddings.go

@@ -10,7 +10,6 @@ import (
 	"github.com/bytedance/sonic"
 	"github.com/gin-gonic/gin"
 	"github.com/labring/aiproxy/core/common"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -91,7 +90,7 @@ func EmbeddingsHandler(
 ) (model.Usage, adaptor.Error) {
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {

+ 2 - 2
core/relay/adaptor/ali/image.go

@@ -12,8 +12,8 @@ import (
 
 	"github.com/bytedance/sonic"
 	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/image"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/adaptor/openai"
@@ -69,7 +69,7 @@ func ImageHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	responseFormat, _ := meta.MustGet(MetaResponseFormat).(string)
 

+ 1 - 2
core/relay/adaptor/ali/rerank.go

@@ -9,7 +9,6 @@ import (
 	"github.com/bytedance/sonic"
 	"github.com/gin-gonic/gin"
 	"github.com/labring/aiproxy/core/common"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/adaptor/openai"
@@ -78,7 +77,7 @@ func RerankHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {

+ 2 - 2
core/relay/adaptor/ali/tts.go

@@ -11,7 +11,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/google/uuid"
 	"github.com/gorilla/websocket"
-	"github.com/labring/aiproxy/core/middleware"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -203,7 +203,7 @@ func TTSDoResponse(
 	c *gin.Context,
 	_ *http.Response,
 ) (usage model.Usage, err adaptor.Error) {
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	conn, ok := meta.MustGet("ws_conn").(*websocket.Conn)
 	if !ok {

+ 1 - 2
core/relay/adaptor/anthropic/main.go

@@ -16,7 +16,6 @@ import (
 	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/conv"
 	"github.com/labring/aiproxy/core/common/image"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/adaptor/openai"
@@ -169,7 +168,7 @@ func StreamHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	scanner := bufio.NewScanner(resp.Body)
 	buf := openai.GetScannerBuffer()

+ 1 - 2
core/relay/adaptor/anthropic/openai.go

@@ -17,7 +17,6 @@ import (
 	"github.com/labring/aiproxy/core/common/conv"
 	"github.com/labring/aiproxy/core/common/image"
 	"github.com/labring/aiproxy/core/common/render"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/adaptor/openai"
@@ -459,7 +458,7 @@ func OpenAIStreamHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	scanner := bufio.NewScanner(resp.Body)
 	buf := openai.GetScannerBuffer()

+ 2 - 2
core/relay/adaptor/aws/claude/main.go

@@ -14,8 +14,8 @@ import (
 	"github.com/bytedance/sonic"
 	"github.com/gin-gonic/gin"
 	"github.com/jinzhu/copier"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/render"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/adaptor/anthropic"
@@ -181,7 +181,7 @@ func Handler(meta *meta.Meta, c *gin.Context) (model.Usage, adaptor.Error) {
 }
 
 func StreamHandler(m *meta.Meta, c *gin.Context) (model.Usage, adaptor.Error) {
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 	awsModelID, err := awsModelID(m.ActualModel)
 	if err != nil {
 		return model.Usage{}, relaymodel.WrapperOpenAIErrorWithMessage(

+ 2 - 2
core/relay/adaptor/aws/llama3/main.go

@@ -14,8 +14,8 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
 	"github.com/bytedance/sonic"
 	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/render"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/adaptor/aws/utils"
@@ -200,7 +200,7 @@ func ResponseLlama2OpenAI(meta *meta.Meta, llamaResponse Response) relaymodel.Te
 }
 
 func StreamHandler(meta *meta.Meta, c *gin.Context) (model.Usage, adaptor.Error) {
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	createdTime := time.Now().Unix()
 	awsModelID, err := awsModelID(meta.ActualModel)

+ 2 - 2
core/relay/adaptor/baidu/embeddings.go

@@ -7,7 +7,7 @@ import (
 
 	"github.com/bytedance/sonic"
 	"github.com/gin-gonic/gin"
-	"github.com/labring/aiproxy/core/middleware"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -26,7 +26,7 @@ func EmbeddingsHandler(
 ) (model.Usage, adaptor.Error) {
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	body, err := io.ReadAll(resp.Body)
 	if err != nil {

+ 2 - 2
core/relay/adaptor/baidu/image.go

@@ -7,7 +7,7 @@ import (
 
 	"github.com/bytedance/sonic"
 	"github.com/gin-gonic/gin"
-	"github.com/labring/aiproxy/core/middleware"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -28,7 +28,7 @@ type ImageResponse struct {
 func ImageHandler(_ *meta.Meta, c *gin.Context, resp *http.Response) (model.Usage, adaptor.Error) {
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	body, err := io.ReadAll(resp.Body)
 	if err != nil {

+ 2 - 2
core/relay/adaptor/baidu/main.go

@@ -8,9 +8,9 @@ import (
 
 	"github.com/bytedance/sonic"
 	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/conv"
 	"github.com/labring/aiproxy/core/common/render"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/adaptor/openai"
@@ -139,7 +139,7 @@ func StreamHandler(
 ) (model.Usage, adaptor.Error) {
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	var usage relaymodel.Usage
 	scanner := bufio.NewScanner(resp.Body)

+ 2 - 2
core/relay/adaptor/baidu/rerank.go

@@ -7,7 +7,7 @@ import (
 
 	"github.com/bytedance/sonic"
 	"github.com/gin-gonic/gin"
-	"github.com/labring/aiproxy/core/middleware"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -26,7 +26,7 @@ func RerankHandler(
 ) (model.Usage, adaptor.Error) {
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	respBody, err := io.ReadAll(resp.Body)
 	if err != nil {

+ 2 - 2
core/relay/adaptor/cohere/main.go

@@ -9,9 +9,9 @@ import (
 
 	"github.com/bytedance/sonic"
 	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/conv"
 	"github.com/labring/aiproxy/core/common/render"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/adaptor/openai"
@@ -163,7 +163,7 @@ func StreamHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	scanner := bufio.NewScanner(resp.Body)
 	buf := openai.GetScannerBuffer()

+ 3 - 3
core/relay/adaptor/coze/main.go

@@ -9,9 +9,9 @@ import (
 
 	"github.com/bytedance/sonic"
 	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/conv"
 	"github.com/labring/aiproxy/core/common/render"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/adaptor/coze/constant/messagetype"
@@ -104,7 +104,7 @@ func StreamHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	responseText := strings.Builder{}
 	createdTime := time.Now().Unix()
@@ -166,7 +166,7 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (model.Usage,
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	var cozeResponse Response
 	err := sonic.ConfigDefault.NewDecoder(resp.Body).Decode(&cozeResponse)

+ 2 - 2
core/relay/adaptor/doubaoaudio/tts.go

@@ -14,8 +14,8 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/google/uuid"
 	"github.com/gorilla/websocket"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/conv"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -181,7 +181,7 @@ func TTSDoResponse(
 	c *gin.Context,
 	_ *http.Response,
 ) (model.Usage, adaptor.Error) {
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	conn, ok := meta.MustGet("ws_conn").(*websocket.Conn)
 	if !ok {

+ 1 - 2
core/relay/adaptor/gemini/main.go

@@ -18,7 +18,6 @@ import (
 	"github.com/labring/aiproxy/core/common/conv"
 	"github.com/labring/aiproxy/core/common/image"
 	"github.com/labring/aiproxy/core/common/render"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/adaptor/openai"
@@ -645,7 +644,7 @@ func StreamHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	responseText := strings.Builder{}
 

+ 2 - 2
core/relay/adaptor/jina/rerank.go

@@ -8,7 +8,7 @@ import (
 	"github.com/bytedance/sonic"
 	"github.com/bytedance/sonic/ast"
 	"github.com/gin-gonic/gin"
-	"github.com/labring/aiproxy/core/middleware"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -26,7 +26,7 @@ func RerankHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {

+ 3 - 3
core/relay/adaptor/minimax/tts.go

@@ -11,7 +11,7 @@ import (
 
 	"github.com/bytedance/sonic"
 	"github.com/gin-gonic/gin"
-	"github.com/labring/aiproxy/core/middleware"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/adaptor/openai"
@@ -129,7 +129,7 @@ func TTSHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	body, err := io.ReadAll(resp.Body)
 	if err != nil {
@@ -196,7 +196,7 @@ func ttsStreamHandler(
 
 	resp.Header.Set("Content-Type", "application/octet-stream")
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	scanner := bufio.NewScanner(resp.Body)
 	buf := openai.GetScannerBuffer()

+ 1 - 2
core/relay/adaptor/ollama/main.go

@@ -12,7 +12,6 @@ import (
 	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/image"
 	"github.com/labring/aiproxy/core/common/render"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/adaptor/openai"
@@ -216,7 +215,7 @@ func StreamHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	var usage *relaymodel.Usage
 	scanner := bufio.NewScanner(resp.Body)

+ 2 - 3
core/relay/adaptor/openai/chat.go

@@ -19,7 +19,6 @@ import (
 	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/conv"
 	"github.com/labring/aiproxy/core/common/render"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -212,7 +211,7 @@ func StreamHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	responseText := strings.Builder{}
 
@@ -347,7 +346,7 @@ func Handler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {

+ 1 - 2
core/relay/adaptor/openai/image.go

@@ -13,7 +13,6 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/image"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -129,7 +128,7 @@ func ImagesHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {

+ 1 - 2
core/relay/adaptor/openai/moderations.go

@@ -10,7 +10,6 @@ import (
 	"github.com/bytedance/sonic/ast"
 	"github.com/gin-gonic/gin"
 	"github.com/labring/aiproxy/core/common"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -55,7 +54,7 @@ func ModerationsHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	body, err := io.ReadAll(resp.Body)
 	if err != nil {

+ 1 - 2
core/relay/adaptor/openai/rerank.go

@@ -10,7 +10,6 @@ import (
 	"github.com/bytedance/sonic/ast"
 	"github.com/gin-gonic/gin"
 	"github.com/labring/aiproxy/core/common"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -55,7 +54,7 @@ func RerankHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {

+ 2 - 2
core/relay/adaptor/openai/stt.go

@@ -12,8 +12,8 @@ import (
 
 	"github.com/bytedance/sonic"
 	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/conv"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -96,7 +96,7 @@ func STTHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	responseFormat := meta.GetString(MetaResponseFormat)
 

+ 1 - 2
core/relay/adaptor/openai/tts.go

@@ -10,7 +10,6 @@ import (
 	"github.com/bytedance/sonic/ast"
 	"github.com/gin-gonic/gin"
 	"github.com/labring/aiproxy/core/common"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -66,7 +65,7 @@ func TTSHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	for k, v := range resp.Header {
 		c.Writer.Header().Set(k, v[0])

+ 2 - 3
core/relay/adaptor/openai/video.go

@@ -11,7 +11,6 @@ import (
 	"github.com/bytedance/sonic/ast"
 	"github.com/gin-gonic/gin"
 	"github.com/labring/aiproxy/core/common"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -103,7 +102,7 @@ func VideoHandler(
 		ExpiresAt: time.Now().Add(time.Hour * 24),
 	})
 	if err != nil {
-		log := middleware.GetLogger(c)
+		log := common.GetLogger(c)
 		log.Errorf("save store failed: %v", err)
 	}
 
@@ -165,7 +164,7 @@ func VideoGetJobsHandler(
 			ExpiresAt: time.Unix(expiresAt, 0),
 		})
 		if err != nil {
-			log := middleware.GetLogger(c)
+			log := common.GetLogger(c)
 			log.Errorf("save store failed: %v", err)
 		}
 		return true

+ 1 - 2
core/relay/adaptor/text-embeddings-inference/rerank.go

@@ -11,7 +11,6 @@ import (
 	"github.com/bytedance/sonic/ast"
 	"github.com/gin-gonic/gin"
 	"github.com/labring/aiproxy/core/common"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -127,7 +126,7 @@ func RerankHandler(
 
 	defer resp.Body.Close()
 
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	respSlice := RerankResponse{}
 	err := sonic.ConfigDefault.NewDecoder(resp.Body).Decode(&respSlice)

+ 2 - 3
core/relay/controller/dohelper.go

@@ -14,7 +14,6 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/conv"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -137,7 +136,7 @@ func DoHelper(
 	}
 
 	// 5. Update usage metrics
-	updateUsageMetrics(usage, middleware.GetLogger(c))
+	updateUsageMetrics(usage, common.GetLogger(c))
 
 	return usage, &detail, nil
 }
@@ -174,7 +173,7 @@ func prepareAndDoRequest(
 	meta *meta.Meta,
 	store adaptor.Store,
 ) (*http.Response, adaptor.Error) {
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	convertResult, err := a.ConvertRequest(meta, store, c.Request)
 	if err != nil {

+ 2 - 2
core/relay/controller/handle.go

@@ -2,8 +2,8 @@ package controller
 
 import (
 	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/config"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 	"github.com/labring/aiproxy/core/relay/adaptor"
 	"github.com/labring/aiproxy/core/relay/meta"
@@ -22,7 +22,7 @@ func Handle(
 	meta *meta.Meta,
 	store adaptor.Store,
 ) *HandleResult {
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 
 	usage, detail, respErr := DoHelper(adaptor, c, meta, store)
 	if respErr != nil {

+ 2 - 2
core/relay/controller/stt.go

@@ -8,8 +8,8 @@ import (
 	"os"
 
 	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
 	"github.com/labring/aiproxy/core/common/audio"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
 )
 
@@ -29,7 +29,7 @@ func GetSTTRequestUsage(c *gin.Context, _ model.ModelConfig) (model.Usage, error
 	}
 
 	durationInt := int64(math.Ceil(duration))
-	log := middleware.GetLogger(c)
+	log := common.GetLogger(c)
 	log.Data["duration"] = durationInt
 
 	return model.Usage{

+ 117 - 0
core/relay/plugin/monitor/group.go

@@ -0,0 +1,117 @@
+package monitor
+
+import (
+	"context"
+	"net/http"
+	"strconv"
+
+	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
+	"github.com/labring/aiproxy/core/common/reqlimit"
+	"github.com/labring/aiproxy/core/model"
+	"github.com/labring/aiproxy/core/relay/adaptor"
+	"github.com/labring/aiproxy/core/relay/meta"
+	"github.com/labring/aiproxy/core/relay/plugin"
+	"github.com/labring/aiproxy/core/relay/plugin/noop"
+)
+
// Context keys under which the per-token-name request/token rates are
// stored on the gin.Context for later retrieval by the Get* helpers below.
const (
	GroupModelTokenRPM = "group_model_token_rpm"
	GroupModelTokenRPS = "group_model_token_rps"
	GroupModelTokenTPM = "group_model_token_tpm"
	GroupModelTokenTPS = "group_model_token_tps"
)
+
+var _ plugin.Plugin = (*GroupMonitor)(nil)
+
+type GroupMonitor struct {
+	noop.Noop
+}
+
+func NewGroupMonitorPlugin() plugin.Plugin {
+	return &GroupMonitor{}
+}
+
+func (m *GroupMonitor) DoResponse(
+	meta *meta.Meta,
+	store adaptor.Store,
+	c *gin.Context,
+	resp *http.Response,
+	do adaptor.DoResponse,
+) (model.Usage, adaptor.Error) {
+	usage, relayErr := do.DoResponse(meta, store, c, resp)
+
+	if usage.TotalTokens > 0 {
+		count, overLimitCount, secondCount := reqlimit.PushGroupModelTokensRequest(
+			context.Background(),
+			meta.Group.ID,
+			meta.OriginModel,
+			meta.ModelConfig.TPM,
+			int64(usage.TotalTokens),
+		)
+		UpdateGroupModelTokensRequest(c, meta.Group, count+overLimitCount, secondCount)
+
+		count, overLimitCount, secondCount = reqlimit.PushGroupModelTokennameTokensRequest(
+			context.Background(),
+			meta.Group.ID,
+			meta.OriginModel,
+			meta.Token.Name,
+			int64(usage.TotalTokens),
+		)
+		UpdateGroupModelTokennameTokensRequest(c, count+overLimitCount, secondCount)
+	}
+
+	return usage, relayErr
+}
+
+func UpdateGroupModelRequest(c *gin.Context, group model.GroupCache, rpm, rps int64) {
+	if group.Status == model.GroupStatusInternal {
+		return
+	}
+
+	log := common.GetLogger(c)
+	log.Data["group_rpm"] = strconv.FormatInt(rpm, 10)
+	log.Data["group_rps"] = strconv.FormatInt(rps, 10)
+}
+
+func UpdateGroupModelTokensRequest(c *gin.Context, group model.GroupCache, tpm, tps int64) {
+	if group.Status == model.GroupStatusInternal {
+		return
+	}
+
+	log := common.GetLogger(c)
+	log.Data["group_tpm"] = strconv.FormatInt(tpm, 10)
+	log.Data["group_tps"] = strconv.FormatInt(tps, 10)
+}
+
+func UpdateGroupModelTokennameRequest(c *gin.Context, rpm, rps int64) {
+	c.Set(GroupModelTokenRPM, rpm)
+	c.Set(GroupModelTokenRPS, rps)
+	// log := common.GetLogger(c)
+	// log.Data["rpm"] = strconv.FormatInt(rpm, 10)
+	// log.Data["rps"] = strconv.FormatInt(rps, 10)
+}
+
+func UpdateGroupModelTokennameTokensRequest(c *gin.Context, tpm, tps int64) {
+	c.Set(GroupModelTokenTPM, tpm)
+	c.Set(GroupModelTokenTPS, tps)
+	// log := common.GetLogger(c)
+	// log.Data["tpm"] = strconv.FormatInt(tpm, 10)
+	// log.Data["tps"] = strconv.FormatInt(tps, 10)
+}
+
+func GetGroupModelTokenRPM(c *gin.Context) int64 {
+	return c.GetInt64(GroupModelTokenRPM)
+}
+
+func GetGroupModelTokenRPS(c *gin.Context) int64 {
+	return c.GetInt64(GroupModelTokenRPS)
+}
+
+func GetGroupModelTokenTPM(c *gin.Context) int64 {
+	return c.GetInt64(GroupModelTokenTPM)
+}
+
+func GetGroupModelTokenTPS(c *gin.Context) int64 {
+	return c.GetInt64(GroupModelTokenTPS)
+}

+ 252 - 0
core/relay/plugin/monitor/monitor.go

@@ -0,0 +1,252 @@
+package monitor
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"strconv"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
+	"github.com/labring/aiproxy/core/common/conv"
+	"github.com/labring/aiproxy/core/common/notify"
+	"github.com/labring/aiproxy/core/common/reqlimit"
+	"github.com/labring/aiproxy/core/common/trylock"
+	"github.com/labring/aiproxy/core/model"
+	"github.com/labring/aiproxy/core/monitor"
+	"github.com/labring/aiproxy/core/relay/adaptor"
+	"github.com/labring/aiproxy/core/relay/meta"
+	"github.com/labring/aiproxy/core/relay/plugin"
+	"github.com/labring/aiproxy/core/relay/plugin/noop"
+)
+
+var _ plugin.Plugin = (*ChannelMonitor)(nil)
+
+type ChannelMonitor struct {
+	noop.Noop
+}
+
+func NewChannelMonitorPlugin() plugin.Plugin {
+	return &ChannelMonitor{}
+}
+
+var channelNoRetryStatusCodesMap = map[int]struct{}{
+	http.StatusBadRequest:                 {},
+	http.StatusRequestEntityTooLarge:      {},
+	http.StatusUnprocessableEntity:        {},
+	http.StatusUnavailableForLegalReasons: {},
+}
+
+func ShouldRetry(relayErr adaptor.Error) bool {
+	_, ok := channelNoRetryStatusCodesMap[relayErr.StatusCode()]
+	return !ok
+}
+
+var channelNoPermissionStatusCodesMap = map[int]struct{}{
+	http.StatusUnauthorized:    {},
+	http.StatusPaymentRequired: {},
+	http.StatusForbidden:       {},
+	http.StatusNotFound:        {},
+}
+
+func ChannelHasPermission(relayErr adaptor.Error) bool {
+	_, ok := channelNoPermissionStatusCodesMap[relayErr.StatusCode()]
+	return !ok
+}
+
+func (m *ChannelMonitor) DoRequest(
+	meta *meta.Meta,
+	store adaptor.Store,
+	c *gin.Context,
+	req *http.Request,
+	do adaptor.DoRequest,
+) (*http.Response, error) {
+	count, overLimitCount, secondCount := reqlimit.PushChannelModelRequest(
+		context.Background(),
+		strconv.Itoa(meta.Channel.ID),
+		meta.OriginModel,
+	)
+	updateChannelModelRequestRate(c, meta, count+overLimitCount, secondCount)
+	return do.DoRequest(meta, store, c, req)
+}
+
+func (m *ChannelMonitor) DoResponse(
+	meta *meta.Meta,
+	store adaptor.Store,
+	c *gin.Context,
+	resp *http.Response,
+	do adaptor.DoResponse,
+) (model.Usage, adaptor.Error) {
+	log := common.GetLogger(c)
+
+	usage, relayErr := do.DoResponse(meta, store, c, resp)
+
+	if usage.TotalTokens > 0 {
+		count, overLimitCount, secondCount := reqlimit.PushChannelModelTokensRequest(
+			context.Background(),
+			strconv.Itoa(meta.Channel.ID),
+			meta.OriginModel,
+			int64(usage.TotalTokens),
+		)
+		updateChannelModelTokensRequestRate(c, meta, count+overLimitCount, secondCount)
+	}
+
+	if relayErr == nil {
+		if _, _, err := monitor.AddRequest(
+			context.Background(),
+			meta.OriginModel,
+			int64(meta.Channel.ID),
+			false,
+			false,
+			meta.ModelConfig.MaxErrorRate,
+		); err != nil {
+			log.Errorf("add request failed: %+v", err)
+		}
+		return usage, nil
+	}
+
+	if !ShouldRetry(relayErr) {
+		return usage, relayErr
+	}
+
+	hasPermission := ChannelHasPermission(relayErr)
+	beyondThreshold, banExecution, err := monitor.AddRequest(
+		context.Background(),
+		meta.OriginModel,
+		int64(meta.Channel.ID),
+		true,
+		!hasPermission,
+		meta.ModelConfig.MaxErrorRate,
+	)
+	if err != nil {
+		log.Errorf("add request failed: %+v", err)
+	}
+	switch {
+	case banExecution:
+		notifyChannelIssue(c, meta, "autoBanned", "Auto Banned", relayErr)
+	case beyondThreshold:
+		notifyChannelIssue(
+			c,
+			meta,
+			"beyondThreshold",
+			"Error Rate Beyond Threshold",
+			relayErr,
+		)
+	case !hasPermission:
+		notifyChannelIssue(c, meta, "channelHasPermission", "No Permission", relayErr)
+	}
+
+	return usage, relayErr
+}
+
+func notifyChannelIssue(
+	c *gin.Context,
+	meta *meta.Meta,
+	issueType, titleSuffix string,
+	err adaptor.Error,
+) {
+	var notifyFunc func(title, message string)
+
+	lockKey := fmt.Sprintf("%s:%d:%s", issueType, meta.Channel.ID, meta.OriginModel)
+	switch issueType {
+	case "beyondThreshold":
+		notifyFunc = func(title, message string) {
+			notify.WarnThrottle(lockKey, time.Minute, title, message)
+		}
+	default:
+		notifyFunc = func(title, message string) {
+			notify.ErrorThrottle(lockKey, time.Minute, title, message)
+		}
+	}
+
+	respBody, _ := err.MarshalJSON()
+
+	message := fmt.Sprintf(
+		"channel: %s (type: %d, type name: %s, id: %d)\nmodel: %s\nmode: %s\nstatus code: %d\ndetail: %s\nrequest id: %s",
+		meta.Channel.Name,
+		meta.Channel.Type,
+		meta.Channel.Type.String(),
+		meta.Channel.ID,
+		meta.OriginModel,
+		meta.Mode,
+		err.StatusCode(),
+		conv.BytesToString(respBody),
+		meta.RequestID,
+	)
+
+	if err.StatusCode() == http.StatusTooManyRequests {
+		if !trylock.Lock(lockKey, time.Minute) {
+			return
+		}
+		switch issueType {
+		case "beyondThreshold":
+			notifyFunc = notify.Warn
+		default:
+			notifyFunc = notify.Error
+		}
+
+		rate := GetChannelModelRequestRate(c, meta)
+		message += fmt.Sprintf(
+			"\nrpm: %d\nrps: %d\ntpm: %d\ntps: %d",
+			rate.RPM,
+			rate.RPS,
+			rate.TPM,
+			rate.TPS,
+		)
+	}
+
+	notifyFunc(
+		fmt.Sprintf("%s `%s` %s", meta.Channel.Name, meta.OriginModel, titleSuffix),
+		message,
+	)
+}
+
// Meta keys under which the channel+model request/token rates are cached
// for the lifetime of a single relay request.
const (
	MetaChannelModelKeyRPM = "channel_model_rpm"
	MetaChannelModelKeyRPS = "channel_model_rps"
	MetaChannelModelKeyTPM = "channel_model_tpm"
	MetaChannelModelKeyTPS = "channel_model_tps"
)
+
+func GetChannelModelRequestRate(c *gin.Context, meta *meta.Meta) model.RequestRate {
+	rate := model.RequestRate{}
+
+	if rpm, ok := meta.Get(MetaChannelModelKeyRPM); ok {
+		rate.RPM, _ = rpm.(int64)
+		rate.RPS = meta.GetInt64(MetaChannelModelKeyRPS)
+	} else {
+		rpm, rps := reqlimit.GetChannelModelRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
+		rate.RPM = rpm
+		rate.RPS = rps
+		updateChannelModelRequestRate(c, meta, rpm, rps)
+	}
+
+	if tpm, ok := meta.Get(MetaChannelModelKeyTPM); ok {
+		rate.TPM, _ = tpm.(int64)
+		rate.TPS = meta.GetInt64(MetaChannelModelKeyTPS)
+	} else {
+		tpm, tps := reqlimit.GetChannelModelTokensRequest(context.Background(), strconv.Itoa(meta.Channel.ID), meta.OriginModel)
+		rate.TPM = tpm
+		rate.TPS = tps
+		updateChannelModelTokensRequestRate(c, meta, tpm, tps)
+	}
+
+	return rate
+}
+
+func updateChannelModelRequestRate(c *gin.Context, meta *meta.Meta, rpm, rps int64) {
+	meta.Set(MetaChannelModelKeyRPM, rpm)
+	meta.Set(MetaChannelModelKeyRPS, rps)
+	log := common.GetLogger(c)
+	log.Data["ch_rpm"] = rpm
+	log.Data["ch_rps"] = rps
+}
+
+func updateChannelModelTokensRequestRate(c *gin.Context, meta *meta.Meta, tpm, tps int64) {
+	meta.Set(MetaChannelModelKeyTPM, tpm)
+	meta.Set(MetaChannelModelKeyTPS, tps)
+	log := common.GetLogger(c)
+	log.Data["ch_tpm"] = tpm
+	log.Data["ch_tps"] = tps
+}