2
0

memmodel.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. package monitor
  2. import (
  3. "context"
  4. "sync"
  5. "time"
  6. "github.com/labring/aiproxy/core/common/config"
  7. )
  8. var memModelMonitor *MemModelMonitor
  9. func init() {
  10. memModelMonitor = NewMemModelMonitor()
  11. }
  12. const (
  13. timeWindow = 10 * time.Second
  14. maxSliceCount = 12
  15. banDuration = 5 * time.Minute
  16. minRequestCount = 20
  17. cleanupInterval = time.Minute
  18. )
  19. type MemModelMonitor struct {
  20. mu sync.RWMutex
  21. models map[string]*ModelData
  22. }
  23. type ModelData struct {
  24. channels map[int64]*ChannelStats
  25. totalStats *TimeWindowStats
  26. }
  27. type ChannelStats struct {
  28. timeWindows *TimeWindowStats
  29. bannedUntil time.Time
  30. }
  31. type TimeWindowStats struct {
  32. slices []*timeSlice
  33. mu sync.Mutex
  34. }
  35. type timeSlice struct {
  36. windowStart time.Time
  37. requests int
  38. errors int
  39. }
  40. func NewTimeWindowStats() *TimeWindowStats {
  41. return &TimeWindowStats{
  42. slices: make([]*timeSlice, 0, maxSliceCount),
  43. }
  44. }
  45. func NewMemModelMonitor() *MemModelMonitor {
  46. mm := &MemModelMonitor{
  47. models: make(map[string]*ModelData),
  48. }
  49. go mm.periodicCleanup()
  50. return mm
  51. }
  52. func (m *MemModelMonitor) periodicCleanup() {
  53. ticker := time.NewTicker(cleanupInterval)
  54. defer ticker.Stop()
  55. for range ticker.C {
  56. m.cleanupExpiredData()
  57. }
  58. }
  59. func (m *MemModelMonitor) cleanupExpiredData() {
  60. m.mu.Lock()
  61. defer m.mu.Unlock()
  62. now := time.Now()
  63. for modelName, modelData := range m.models {
  64. for channelID, channelStats := range modelData.channels {
  65. hasValidSlices := channelStats.timeWindows.HasValidSlices()
  66. if !hasValidSlices && !channelStats.bannedUntil.After(now) {
  67. delete(modelData.channels, channelID)
  68. }
  69. }
  70. hasValidSlices := modelData.totalStats.HasValidSlices()
  71. if !hasValidSlices && len(modelData.channels) == 0 {
  72. delete(m.models, modelName)
  73. }
  74. }
  75. }
  76. func (m *MemModelMonitor) AddRequest(
  77. model string,
  78. channelID int64,
  79. isError, tryBan bool,
  80. warnErrorRate,
  81. maxErrorRate float64,
  82. ) (beyondThreshold, banExecution bool) {
  83. // Set default warning threshold if not specified
  84. if warnErrorRate <= 0 {
  85. warnErrorRate = config.GetDefaultWarnNotifyErrorRate()
  86. }
  87. m.mu.Lock()
  88. defer m.mu.Unlock()
  89. now := time.Now()
  90. var (
  91. modelData *ModelData
  92. exists bool
  93. )
  94. if modelData, exists = m.models[model]; !exists {
  95. modelData = &ModelData{
  96. channels: make(map[int64]*ChannelStats),
  97. totalStats: NewTimeWindowStats(),
  98. }
  99. m.models[model] = modelData
  100. }
  101. var channel *ChannelStats
  102. if channel, exists = modelData.channels[channelID]; !exists {
  103. channel = &ChannelStats{
  104. timeWindows: NewTimeWindowStats(),
  105. }
  106. modelData.channels[channelID] = channel
  107. }
  108. modelData.totalStats.AddRequest(now, isError)
  109. channel.timeWindows.AddRequest(now, isError)
  110. return m.checkAndBan(now, channel, tryBan, warnErrorRate, maxErrorRate)
  111. }
  112. func (m *MemModelMonitor) checkAndBan(
  113. now time.Time,
  114. channel *ChannelStats,
  115. tryBan bool,
  116. warnErrorRate,
  117. maxErrorRate float64,
  118. ) (beyondThreshold, banExecution bool) {
  119. canBan := maxErrorRate > 0
  120. if tryBan && canBan {
  121. if channel.bannedUntil.After(now) {
  122. return false, false
  123. }
  124. channel.bannedUntil = now.Add(banDuration)
  125. return false, true
  126. }
  127. req, err := channel.timeWindows.GetStats()
  128. if req < minRequestCount {
  129. return false, false
  130. }
  131. errorRate := float64(err) / float64(req)
  132. // Check if error rate exceeds warning threshold
  133. exceedsWarning := errorRate >= warnErrorRate
  134. // Check if we should ban (only if maxErrorRate is set and exceeded)
  135. if canBan && errorRate >= maxErrorRate {
  136. if channel.bannedUntil.After(now) {
  137. return true, false // Already banned
  138. }
  139. channel.bannedUntil = now.Add(banDuration)
  140. return false, true // Ban executed
  141. } else if exceedsWarning {
  142. return true, false // Beyond warning threshold but not banning
  143. }
  144. return false, false
  145. }
  146. func getErrorRateFromStats(stats *TimeWindowStats) float64 {
  147. req, err := stats.GetStats()
  148. if req < minRequestCount {
  149. return 0
  150. }
  151. return float64(err) / float64(req)
  152. }
  153. func (m *MemModelMonitor) GetModelsErrorRate(_ context.Context) (map[string]float64, error) {
  154. m.mu.RLock()
  155. defer m.mu.RUnlock()
  156. result := make(map[string]float64)
  157. for model, data := range m.models {
  158. result[model] = getErrorRateFromStats(data.totalStats)
  159. }
  160. return result, nil
  161. }
  162. func (m *MemModelMonitor) GetModelChannelErrorRate(
  163. _ context.Context,
  164. model string,
  165. ) (map[int64]float64, error) {
  166. m.mu.RLock()
  167. defer m.mu.RUnlock()
  168. result := make(map[int64]float64)
  169. if data, exists := m.models[model]; exists {
  170. for channelID, channel := range data.channels {
  171. result[channelID] = getErrorRateFromStats(channel.timeWindows)
  172. }
  173. }
  174. return result, nil
  175. }
  176. func (m *MemModelMonitor) GetChannelModelErrorRates(
  177. _ context.Context,
  178. channelID int64,
  179. ) (map[string]float64, error) {
  180. m.mu.RLock()
  181. defer m.mu.RUnlock()
  182. result := make(map[string]float64)
  183. for model, data := range m.models {
  184. if channel, exists := data.channels[channelID]; exists {
  185. result[model] = getErrorRateFromStats(channel.timeWindows)
  186. }
  187. }
  188. return result, nil
  189. }
  190. func (m *MemModelMonitor) GetAllChannelModelErrorRates(
  191. _ context.Context,
  192. ) (map[int64]map[string]float64, error) {
  193. m.mu.RLock()
  194. defer m.mu.RUnlock()
  195. result := make(map[int64]map[string]float64)
  196. for model, data := range m.models {
  197. for channelID, channel := range data.channels {
  198. if _, exists := result[channelID]; !exists {
  199. result[channelID] = make(map[string]float64)
  200. }
  201. result[channelID][model] = getErrorRateFromStats(channel.timeWindows)
  202. }
  203. }
  204. return result, nil
  205. }
  206. func (m *MemModelMonitor) GetBannedChannelsWithModel(
  207. _ context.Context,
  208. model string,
  209. ) ([]int64, error) {
  210. m.mu.RLock()
  211. defer m.mu.RUnlock()
  212. var banned []int64
  213. if data, exists := m.models[model]; exists {
  214. now := time.Now()
  215. for channelID, channel := range data.channels {
  216. if channel.bannedUntil.After(now) {
  217. banned = append(banned, channelID)
  218. }
  219. }
  220. }
  221. return banned, nil
  222. }
  223. func (m *MemModelMonitor) GetBannedChannelsMapWithModel(
  224. _ context.Context,
  225. model string,
  226. ) (map[int64]struct{}, error) {
  227. m.mu.RLock()
  228. defer m.mu.RUnlock()
  229. banned := make(map[int64]struct{})
  230. if data, exists := m.models[model]; exists {
  231. now := time.Now()
  232. for channelID, channel := range data.channels {
  233. if channel.bannedUntil.After(now) {
  234. banned[channelID] = struct{}{}
  235. }
  236. }
  237. }
  238. return banned, nil
  239. }
  240. func (m *MemModelMonitor) GetAllBannedModelChannels(_ context.Context) (map[string][]int64, error) {
  241. m.mu.RLock()
  242. defer m.mu.RUnlock()
  243. result := make(map[string][]int64)
  244. now := time.Now()
  245. for model, data := range m.models {
  246. for channelID, channel := range data.channels {
  247. if channel.bannedUntil.After(now) {
  248. if _, exists := result[model]; !exists {
  249. result[model] = []int64{}
  250. }
  251. result[model] = append(result[model], channelID)
  252. }
  253. }
  254. }
  255. return result, nil
  256. }
  257. func (m *MemModelMonitor) ClearChannelModelErrors(
  258. _ context.Context,
  259. model string,
  260. channelID int,
  261. ) error {
  262. m.mu.Lock()
  263. defer m.mu.Unlock()
  264. if data, exists := m.models[model]; exists {
  265. delete(data.channels, int64(channelID))
  266. }
  267. return nil
  268. }
  269. func (m *MemModelMonitor) ClearChannelAllModelErrors(_ context.Context, channelID int) error {
  270. m.mu.Lock()
  271. defer m.mu.Unlock()
  272. for _, data := range m.models {
  273. delete(data.channels, int64(channelID))
  274. }
  275. return nil
  276. }
  277. func (m *MemModelMonitor) ClearAllModelErrors(_ context.Context) error {
  278. m.mu.Lock()
  279. defer m.mu.Unlock()
  280. m.models = make(map[string]*ModelData)
  281. return nil
  282. }
  283. func (t *TimeWindowStats) cleanupLocked(callback func(slice *timeSlice)) {
  284. cutoff := time.Now().Add(-timeWindow * time.Duration(maxSliceCount))
  285. validSlices := t.slices[:0]
  286. for _, s := range t.slices {
  287. if s.windowStart.After(cutoff) || s.windowStart.Equal(cutoff) {
  288. validSlices = append(validSlices, s)
  289. if callback != nil {
  290. callback(s)
  291. }
  292. }
  293. }
  294. t.slices = validSlices
  295. }
  296. func (t *TimeWindowStats) AddRequest(now time.Time, isError bool) {
  297. t.mu.Lock()
  298. defer t.mu.Unlock()
  299. t.cleanupLocked(nil)
  300. currentWindow := now.Truncate(timeWindow)
  301. var slice *timeSlice
  302. for i := range t.slices {
  303. if t.slices[i].windowStart.Equal(currentWindow) {
  304. slice = t.slices[i]
  305. break
  306. }
  307. }
  308. if slice == nil {
  309. slice = &timeSlice{windowStart: currentWindow}
  310. t.slices = append(t.slices, slice)
  311. }
  312. slice.requests++
  313. if isError {
  314. slice.errors++
  315. }
  316. }
  317. func (t *TimeWindowStats) GetStats() (totalReq, totalErr int) {
  318. t.mu.Lock()
  319. defer t.mu.Unlock()
  320. t.cleanupLocked(func(slice *timeSlice) {
  321. totalReq += slice.requests
  322. totalErr += slice.errors
  323. })
  324. return totalReq, totalErr
  325. }
  326. func (t *TimeWindowStats) HasValidSlices() bool {
  327. t.mu.Lock()
  328. defer t.mu.Unlock()
  329. t.cleanupLocked(nil)
  330. return len(t.slices) > 0
  331. }