memmodel.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. package monitor
  2. import (
  3. "context"
  4. "sync"
  5. "time"
  6. )
// memModelMonitor is the package-level singleton monitor instance.
// NOTE(review): created in init(), which (via NewMemModelMonitor) also starts
// a background cleanup goroutine with no stop mechanism — confirm this
// process-lifetime goroutine is intended.
var memModelMonitor *MemModelMonitor

func init() {
	memModelMonitor = NewMemModelMonitor()
}
const (
	// timeWindow is the width of a single statistics slice.
	timeWindow = 10 * time.Second
	// maxSliceCount caps how many slices are retained; total history is
	// timeWindow * maxSliceCount (2 minutes with the values below).
	maxSliceCount = 12
	// banDuration is how long a channel stays banned once a ban executes.
	banDuration = 5 * time.Minute
	// minRequestCount is the minimum sample size before an error rate is
	// considered meaningful (smaller samples report 0 / never trigger bans).
	minRequestCount = 20
	// cleanupInterval is how often expired per-model/per-channel data is purged.
	cleanupInterval = time.Minute
)
// MemModelMonitor tracks request/error statistics per model and per channel
// entirely in memory, and manages temporary channel bans.
type MemModelMonitor struct {
	mu     sync.RWMutex          // guards models and all nested map access
	models map[string]*ModelData // keyed by model name
}
// ModelData holds per-channel statistics plus an aggregate across all
// channels for one model.
type ModelData struct {
	channels   map[int64]*ChannelStats // keyed by channel ID
	totalStats *TimeWindowStats        // aggregate over every channel of this model
}
// ChannelStats is the per-channel sliding-window statistics and ban state.
type ChannelStats struct {
	timeWindows *TimeWindowStats
	bannedUntil time.Time // zero value means never banned
}
// TimeWindowStats is a sliding window of request/error counters, bucketed
// into fixed-width time slices. Safe for concurrent use via its own mutex.
type TimeWindowStats struct {
	slices []*timeSlice // at most maxSliceCount entries, pruned lazily
	mu     sync.Mutex   // guards slices
}
// timeSlice accumulates counts for one timeWindow-aligned bucket.
type timeSlice struct {
	windowStart time.Time // start of the bucket (now.Truncate(timeWindow))
	requests    int
	errors      int
}
  39. func NewTimeWindowStats() *TimeWindowStats {
  40. return &TimeWindowStats{
  41. slices: make([]*timeSlice, 0, maxSliceCount),
  42. }
  43. }
// NewMemModelMonitor creates a monitor and starts its background cleanup
// loop.
// NOTE(review): the cleanup goroutine has no context/stop channel and can
// never be terminated or waited on — acceptable only for a process-lifetime
// singleton; confirm no short-lived instances are created (e.g. in tests).
func NewMemModelMonitor() *MemModelMonitor {
	mm := &MemModelMonitor{
		models: make(map[string]*ModelData),
	}
	go mm.periodicCleanup()
	return mm
}
  51. func (m *MemModelMonitor) periodicCleanup() {
  52. ticker := time.NewTicker(cleanupInterval)
  53. defer ticker.Stop()
  54. for range ticker.C {
  55. m.cleanupExpiredData()
  56. }
  57. }
  58. func (m *MemModelMonitor) cleanupExpiredData() {
  59. m.mu.Lock()
  60. defer m.mu.Unlock()
  61. now := time.Now()
  62. for modelName, modelData := range m.models {
  63. for channelID, channelStats := range modelData.channels {
  64. hasValidSlices := channelStats.timeWindows.HasValidSlices()
  65. if !hasValidSlices && !channelStats.bannedUntil.After(now) {
  66. delete(modelData.channels, channelID)
  67. }
  68. }
  69. hasValidSlices := modelData.totalStats.HasValidSlices()
  70. if !hasValidSlices && len(modelData.channels) == 0 {
  71. delete(m.models, modelName)
  72. }
  73. }
  74. }
  75. func (m *MemModelMonitor) AddRequest(
  76. model string,
  77. channelID int64,
  78. isError, tryBan bool,
  79. maxErrorRate float64,
  80. ) (beyondThreshold, banExecution bool) {
  81. m.mu.Lock()
  82. defer m.mu.Unlock()
  83. now := time.Now()
  84. var (
  85. modelData *ModelData
  86. exists bool
  87. )
  88. if modelData, exists = m.models[model]; !exists {
  89. modelData = &ModelData{
  90. channels: make(map[int64]*ChannelStats),
  91. totalStats: NewTimeWindowStats(),
  92. }
  93. m.models[model] = modelData
  94. }
  95. var channel *ChannelStats
  96. if channel, exists = modelData.channels[channelID]; !exists {
  97. channel = &ChannelStats{
  98. timeWindows: NewTimeWindowStats(),
  99. }
  100. modelData.channels[channelID] = channel
  101. }
  102. modelData.totalStats.AddRequest(now, isError)
  103. channel.timeWindows.AddRequest(now, isError)
  104. return m.checkAndBan(now, channel, tryBan, maxErrorRate)
  105. }
// checkAndBan decides whether the channel should be banned or flagged.
//
// Paths:
//   - tryBan with banning enabled (maxErrorRate > 0): ban immediately
//     regardless of error rate, unless already banned.
//   - otherwise: once at least minRequestCount requests are in the window,
//     compare the error rate against maxErrorRate; ban if allowed, or just
//     report beyondThreshold when banning is disabled or already active.
//
// NOTE(review): with maxErrorRate == 0 the comparison rate >= 0 is always
// true for any window with >= minRequestCount requests, so beyondThreshold
// is reported even with zero errors — confirm callers never pass 0 expecting
// "disabled".
func (m *MemModelMonitor) checkAndBan(
	now time.Time,
	channel *ChannelStats,
	tryBan bool,
	maxErrorRate float64,
) (beyondThreshold, banExecution bool) {
	canBan := maxErrorRate > 0
	if tryBan && canBan {
		if channel.bannedUntil.After(now) {
			// Already banned: do not extend or re-report.
			return false, false
		}
		channel.bannedUntil = now.Add(banDuration)
		return false, true
	}
	// err here is the error COUNT, not a Go error value.
	req, err := channel.timeWindows.GetStats()
	if req < minRequestCount {
		// Sample too small to judge.
		return false, false
	}
	if float64(err)/float64(req) >= maxErrorRate {
		if !canBan || channel.bannedUntil.After(now) {
			return true, false
		}
		channel.bannedUntil = now.Add(banDuration)
		return false, true
	}
	return false, false
}
  133. func getErrorRateFromStats(stats *TimeWindowStats) float64 {
  134. req, err := stats.GetStats()
  135. if req < minRequestCount {
  136. return 0
  137. }
  138. return float64(err) / float64(req)
  139. }
  140. func (m *MemModelMonitor) GetModelsErrorRate(_ context.Context) (map[string]float64, error) {
  141. m.mu.RLock()
  142. defer m.mu.RUnlock()
  143. result := make(map[string]float64)
  144. for model, data := range m.models {
  145. result[model] = getErrorRateFromStats(data.totalStats)
  146. }
  147. return result, nil
  148. }
  149. func (m *MemModelMonitor) GetModelChannelErrorRate(
  150. _ context.Context,
  151. model string,
  152. ) (map[int64]float64, error) {
  153. m.mu.RLock()
  154. defer m.mu.RUnlock()
  155. result := make(map[int64]float64)
  156. if data, exists := m.models[model]; exists {
  157. for channelID, channel := range data.channels {
  158. result[channelID] = getErrorRateFromStats(channel.timeWindows)
  159. }
  160. }
  161. return result, nil
  162. }
  163. func (m *MemModelMonitor) GetChannelModelErrorRates(
  164. _ context.Context,
  165. channelID int64,
  166. ) (map[string]float64, error) {
  167. m.mu.RLock()
  168. defer m.mu.RUnlock()
  169. result := make(map[string]float64)
  170. for model, data := range m.models {
  171. if channel, exists := data.channels[channelID]; exists {
  172. result[model] = getErrorRateFromStats(channel.timeWindows)
  173. }
  174. }
  175. return result, nil
  176. }
  177. func (m *MemModelMonitor) GetAllChannelModelErrorRates(
  178. _ context.Context,
  179. ) (map[int64]map[string]float64, error) {
  180. m.mu.RLock()
  181. defer m.mu.RUnlock()
  182. result := make(map[int64]map[string]float64)
  183. for model, data := range m.models {
  184. for channelID, channel := range data.channels {
  185. if _, exists := result[channelID]; !exists {
  186. result[channelID] = make(map[string]float64)
  187. }
  188. result[channelID][model] = getErrorRateFromStats(channel.timeWindows)
  189. }
  190. }
  191. return result, nil
  192. }
  193. func (m *MemModelMonitor) GetBannedChannelsWithModel(
  194. _ context.Context,
  195. model string,
  196. ) ([]int64, error) {
  197. m.mu.RLock()
  198. defer m.mu.RUnlock()
  199. var banned []int64
  200. if data, exists := m.models[model]; exists {
  201. now := time.Now()
  202. for channelID, channel := range data.channels {
  203. if channel.bannedUntil.After(now) {
  204. banned = append(banned, channelID)
  205. }
  206. }
  207. }
  208. return banned, nil
  209. }
  210. func (m *MemModelMonitor) GetAllBannedModelChannels(_ context.Context) (map[string][]int64, error) {
  211. m.mu.RLock()
  212. defer m.mu.RUnlock()
  213. result := make(map[string][]int64)
  214. now := time.Now()
  215. for model, data := range m.models {
  216. for channelID, channel := range data.channels {
  217. if channel.bannedUntil.After(now) {
  218. if _, exists := result[model]; !exists {
  219. result[model] = []int64{}
  220. }
  221. result[model] = append(result[model], channelID)
  222. }
  223. }
  224. }
  225. return result, nil
  226. }
  227. func (m *MemModelMonitor) ClearChannelModelErrors(
  228. _ context.Context,
  229. model string,
  230. channelID int,
  231. ) error {
  232. m.mu.Lock()
  233. defer m.mu.Unlock()
  234. if data, exists := m.models[model]; exists {
  235. delete(data.channels, int64(channelID))
  236. }
  237. return nil
  238. }
  239. func (m *MemModelMonitor) ClearChannelAllModelErrors(_ context.Context, channelID int) error {
  240. m.mu.Lock()
  241. defer m.mu.Unlock()
  242. for _, data := range m.models {
  243. delete(data.channels, int64(channelID))
  244. }
  245. return nil
  246. }
  247. func (m *MemModelMonitor) ClearAllModelErrors(_ context.Context) error {
  248. m.mu.Lock()
  249. defer m.mu.Unlock()
  250. m.models = make(map[string]*ModelData)
  251. return nil
  252. }
  253. func (t *TimeWindowStats) cleanupLocked(callback func(slice *timeSlice)) {
  254. cutoff := time.Now().Add(-timeWindow * time.Duration(maxSliceCount))
  255. validSlices := t.slices[:0]
  256. for _, s := range t.slices {
  257. if s.windowStart.After(cutoff) || s.windowStart.Equal(cutoff) {
  258. validSlices = append(validSlices, s)
  259. if callback != nil {
  260. callback(s)
  261. }
  262. }
  263. }
  264. t.slices = validSlices
  265. }
  266. func (t *TimeWindowStats) AddRequest(now time.Time, isError bool) {
  267. t.mu.Lock()
  268. defer t.mu.Unlock()
  269. t.cleanupLocked(nil)
  270. currentWindow := now.Truncate(timeWindow)
  271. var slice *timeSlice
  272. for i := range t.slices {
  273. if t.slices[i].windowStart.Equal(currentWindow) {
  274. slice = t.slices[i]
  275. break
  276. }
  277. }
  278. if slice == nil {
  279. slice = &timeSlice{windowStart: currentWindow}
  280. t.slices = append(t.slices, slice)
  281. }
  282. slice.requests++
  283. if isError {
  284. slice.errors++
  285. }
  286. }
  287. func (t *TimeWindowStats) GetStats() (totalReq, totalErr int) {
  288. t.mu.Lock()
  289. defer t.mu.Unlock()
  290. t.cleanupLocked(func(slice *timeSlice) {
  291. totalReq += slice.requests
  292. totalErr += slice.errors
  293. })
  294. return
  295. }
  296. func (t *TimeWindowStats) HasValidSlices() bool {
  297. t.mu.Lock()
  298. defer t.mu.Unlock()
  299. t.cleanupLocked(nil)
  300. return len(t.slices) > 0
  301. }