| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420 |
- package monitor
- import (
- "context"
- "sync"
- "time"
- "github.com/labring/aiproxy/core/common/config"
- )
- var memModelMonitor *MemModelMonitor
- func init() {
- memModelMonitor = NewMemModelMonitor()
- }
// Tuning parameters for the in-memory monitor.
const (
	// timeWindow is the width of one statistics slice; request timestamps
	// are truncated to this granularity.
	timeWindow = 10 * time.Second
	// banDuration is how long a channel stays banned once a ban is applied.
	banDuration = 5 * time.Minute
	// cleanupInterval is the period of the background cleanup ticker.
	cleanupInterval = time.Minute

	// maxSliceCount is the number of slices retained; together with
	// timeWindow it defines the retention horizon (timeWindow * maxSliceCount).
	maxSliceCount = 12
	// minRequestCount is the minimum number of requests in the retained
	// window before an error rate is evaluated at all.
	minRequestCount = 20
)
- type MemModelMonitor struct {
- mu sync.RWMutex
- models map[string]*ModelData
- }
- type ModelData struct {
- channels map[int64]*ChannelStats
- totalStats *TimeWindowStats
- }
- type ChannelStats struct {
- timeWindows *TimeWindowStats
- bannedUntil time.Time
- }
- type TimeWindowStats struct {
- slices []*timeSlice
- mu sync.Mutex
- }
// timeSlice counts the requests and errors that fell into one window.
type timeSlice struct {
	// windowStart is the start of the window this slice covers.
	windowStart time.Time
	// requests is the number of requests recorded in the window.
	requests int
	// errors is the number of those requests that were errors.
	errors int
}
- func NewTimeWindowStats() *TimeWindowStats {
- return &TimeWindowStats{
- slices: make([]*timeSlice, 0, maxSliceCount),
- }
- }
- func NewMemModelMonitor() *MemModelMonitor {
- mm := &MemModelMonitor{
- models: make(map[string]*ModelData),
- }
- go mm.periodicCleanup()
- return mm
- }
- func (m *MemModelMonitor) periodicCleanup() {
- ticker := time.NewTicker(cleanupInterval)
- defer ticker.Stop()
- for range ticker.C {
- m.cleanupExpiredData()
- }
- }
- func (m *MemModelMonitor) cleanupExpiredData() {
- m.mu.Lock()
- defer m.mu.Unlock()
- now := time.Now()
- for modelName, modelData := range m.models {
- for channelID, channelStats := range modelData.channels {
- hasValidSlices := channelStats.timeWindows.HasValidSlices()
- if !hasValidSlices && !channelStats.bannedUntil.After(now) {
- delete(modelData.channels, channelID)
- }
- }
- hasValidSlices := modelData.totalStats.HasValidSlices()
- if !hasValidSlices && len(modelData.channels) == 0 {
- delete(m.models, modelName)
- }
- }
- }
- func (m *MemModelMonitor) AddRequest(
- model string,
- channelID int64,
- isError, tryBan bool,
- warnErrorRate,
- maxErrorRate float64,
- ) (beyondThreshold, banExecution bool) {
- // Set default warning threshold if not specified
- if warnErrorRate <= 0 {
- warnErrorRate = config.GetDefaultWarnNotifyErrorRate()
- }
- m.mu.Lock()
- defer m.mu.Unlock()
- now := time.Now()
- var (
- modelData *ModelData
- exists bool
- )
- if modelData, exists = m.models[model]; !exists {
- modelData = &ModelData{
- channels: make(map[int64]*ChannelStats),
- totalStats: NewTimeWindowStats(),
- }
- m.models[model] = modelData
- }
- var channel *ChannelStats
- if channel, exists = modelData.channels[channelID]; !exists {
- channel = &ChannelStats{
- timeWindows: NewTimeWindowStats(),
- }
- modelData.channels[channelID] = channel
- }
- modelData.totalStats.AddRequest(now, isError)
- channel.timeWindows.AddRequest(now, isError)
- return m.checkAndBan(now, channel, tryBan, warnErrorRate, maxErrorRate)
- }
- func (m *MemModelMonitor) checkAndBan(
- now time.Time,
- channel *ChannelStats,
- tryBan bool,
- warnErrorRate,
- maxErrorRate float64,
- ) (beyondThreshold, banExecution bool) {
- canBan := maxErrorRate > 0
- if tryBan && canBan {
- if channel.bannedUntil.After(now) {
- return false, false
- }
- channel.bannedUntil = now.Add(banDuration)
- return false, true
- }
- req, err := channel.timeWindows.GetStats()
- if req < minRequestCount {
- return false, false
- }
- errorRate := float64(err) / float64(req)
- // Check if error rate exceeds warning threshold
- exceedsWarning := errorRate >= warnErrorRate
- // Check if we should ban (only if maxErrorRate is set and exceeded)
- if canBan && errorRate >= maxErrorRate {
- if channel.bannedUntil.After(now) {
- return true, false // Already banned
- }
- channel.bannedUntil = now.Add(banDuration)
- return false, true // Ban executed
- } else if exceedsWarning {
- return true, false // Beyond warning threshold but not banning
- }
- return false, false
- }
- func getErrorRateFromStats(stats *TimeWindowStats) float64 {
- req, err := stats.GetStats()
- if req < minRequestCount {
- return 0
- }
- return float64(err) / float64(req)
- }
- func (m *MemModelMonitor) GetModelsErrorRate(_ context.Context) (map[string]float64, error) {
- m.mu.RLock()
- defer m.mu.RUnlock()
- result := make(map[string]float64)
- for model, data := range m.models {
- result[model] = getErrorRateFromStats(data.totalStats)
- }
- return result, nil
- }
- func (m *MemModelMonitor) GetModelChannelErrorRate(
- _ context.Context,
- model string,
- ) (map[int64]float64, error) {
- m.mu.RLock()
- defer m.mu.RUnlock()
- result := make(map[int64]float64)
- if data, exists := m.models[model]; exists {
- for channelID, channel := range data.channels {
- result[channelID] = getErrorRateFromStats(channel.timeWindows)
- }
- }
- return result, nil
- }
- func (m *MemModelMonitor) GetChannelModelErrorRates(
- _ context.Context,
- channelID int64,
- ) (map[string]float64, error) {
- m.mu.RLock()
- defer m.mu.RUnlock()
- result := make(map[string]float64)
- for model, data := range m.models {
- if channel, exists := data.channels[channelID]; exists {
- result[model] = getErrorRateFromStats(channel.timeWindows)
- }
- }
- return result, nil
- }
- func (m *MemModelMonitor) GetAllChannelModelErrorRates(
- _ context.Context,
- ) (map[int64]map[string]float64, error) {
- m.mu.RLock()
- defer m.mu.RUnlock()
- result := make(map[int64]map[string]float64)
- for model, data := range m.models {
- for channelID, channel := range data.channels {
- if _, exists := result[channelID]; !exists {
- result[channelID] = make(map[string]float64)
- }
- result[channelID][model] = getErrorRateFromStats(channel.timeWindows)
- }
- }
- return result, nil
- }
- func (m *MemModelMonitor) GetBannedChannelsWithModel(
- _ context.Context,
- model string,
- ) ([]int64, error) {
- m.mu.RLock()
- defer m.mu.RUnlock()
- var banned []int64
- if data, exists := m.models[model]; exists {
- now := time.Now()
- for channelID, channel := range data.channels {
- if channel.bannedUntil.After(now) {
- banned = append(banned, channelID)
- }
- }
- }
- return banned, nil
- }
- func (m *MemModelMonitor) GetBannedChannelsMapWithModel(
- _ context.Context,
- model string,
- ) (map[int64]struct{}, error) {
- m.mu.RLock()
- defer m.mu.RUnlock()
- banned := make(map[int64]struct{})
- if data, exists := m.models[model]; exists {
- now := time.Now()
- for channelID, channel := range data.channels {
- if channel.bannedUntil.After(now) {
- banned[channelID] = struct{}{}
- }
- }
- }
- return banned, nil
- }
- func (m *MemModelMonitor) GetAllBannedModelChannels(_ context.Context) (map[string][]int64, error) {
- m.mu.RLock()
- defer m.mu.RUnlock()
- result := make(map[string][]int64)
- now := time.Now()
- for model, data := range m.models {
- for channelID, channel := range data.channels {
- if channel.bannedUntil.After(now) {
- if _, exists := result[model]; !exists {
- result[model] = []int64{}
- }
- result[model] = append(result[model], channelID)
- }
- }
- }
- return result, nil
- }
- func (m *MemModelMonitor) ClearChannelModelErrors(
- _ context.Context,
- model string,
- channelID int,
- ) error {
- m.mu.Lock()
- defer m.mu.Unlock()
- if data, exists := m.models[model]; exists {
- delete(data.channels, int64(channelID))
- }
- return nil
- }
- func (m *MemModelMonitor) ClearChannelAllModelErrors(_ context.Context, channelID int) error {
- m.mu.Lock()
- defer m.mu.Unlock()
- for _, data := range m.models {
- delete(data.channels, int64(channelID))
- }
- return nil
- }
- func (m *MemModelMonitor) ClearAllModelErrors(_ context.Context) error {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.models = make(map[string]*ModelData)
- return nil
- }
- func (t *TimeWindowStats) cleanupLocked(callback func(slice *timeSlice)) {
- cutoff := time.Now().Add(-timeWindow * time.Duration(maxSliceCount))
- validSlices := t.slices[:0]
- for _, s := range t.slices {
- if s.windowStart.After(cutoff) || s.windowStart.Equal(cutoff) {
- validSlices = append(validSlices, s)
- if callback != nil {
- callback(s)
- }
- }
- }
- t.slices = validSlices
- }
- func (t *TimeWindowStats) AddRequest(now time.Time, isError bool) {
- t.mu.Lock()
- defer t.mu.Unlock()
- t.cleanupLocked(nil)
- currentWindow := now.Truncate(timeWindow)
- var slice *timeSlice
- for i := range t.slices {
- if t.slices[i].windowStart.Equal(currentWindow) {
- slice = t.slices[i]
- break
- }
- }
- if slice == nil {
- slice = &timeSlice{windowStart: currentWindow}
- t.slices = append(t.slices, slice)
- }
- slice.requests++
- if isError {
- slice.errors++
- }
- }
- func (t *TimeWindowStats) GetStats() (totalReq, totalErr int) {
- t.mu.Lock()
- defer t.mu.Unlock()
- t.cleanupLocked(func(slice *timeSlice) {
- totalReq += slice.requests
- totalErr += slice.errors
- })
- return totalReq, totalErr
- }
- func (t *TimeWindowStats) HasValidSlices() bool {
- t.mu.Lock()
- defer t.mu.Unlock()
- t.cleanupLocked(nil)
- return len(t.slices) > 0
- }
|