| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 |
- package reqlimit
- import (
- "strings"
- "sync"
- "sync/atomic"
- "time"
- )
// windowCounts holds the requests recorded in a single one-second window
// (keyed elsewhere by its Unix-second start): normal counts requests recorded
// as within the limit, over counts requests recorded past it.
type windowCounts struct {
	normal, over int64
}
- type entry struct {
- sync.Mutex
- windows map[int64]*windowCounts
- lastAccess atomic.Value
- }
// InMemoryRecord is a process-local request counter whose entries are
// identified by a list of string keys. Create it with NewInMemoryRecord so
// that idle entries are periodically evicted; a zero value works but is
// never cleaned up.
type InMemoryRecord struct {
	// entries maps the ":"-joined keys (string) to *entry.
	entries sync.Map
}
- func NewInMemoryRecord() *InMemoryRecord {
- rl := &InMemoryRecord{
- entries: sync.Map{},
- }
- go rl.cleanupInactiveEntries(2*time.Minute, 1*time.Minute)
- return rl
- }
- func (m *InMemoryRecord) getEntry(keys []string) *entry {
- key := strings.Join(keys, ":")
- actual, _ := m.entries.LoadOrStore(key, &entry{
- windows: make(map[int64]*windowCounts),
- })
- e, _ := actual.(*entry)
- if e.lastAccess.Load() == nil {
- e.lastAccess.CompareAndSwap(nil, time.Now())
- }
- return e
- }
- func (m *InMemoryRecord) cleanupAndCount(e *entry, cutoff int64) (int64, int64) {
- normalCount := int64(0)
- overCount := int64(0)
- for ts, wc := range e.windows {
- if ts < cutoff {
- delete(e.windows, ts)
- } else {
- normalCount += wc.normal
- overCount += wc.over
- }
- }
- return normalCount, overCount
- }
- func (m *InMemoryRecord) PushRequest(
- overed int64,
- duration time.Duration,
- n int64,
- keys ...string,
- ) (normalCount, overCount, secondCount int64) {
- e := m.getEntry(keys)
- e.Lock()
- defer e.Unlock()
- now := time.Now()
- e.lastAccess.Store(now)
- windowStart := now.Unix()
- cutoff := windowStart - int64(duration.Seconds())
- normalCount, overCount = m.cleanupAndCount(e, cutoff)
- wc, exists := e.windows[windowStart]
- if !exists {
- wc = &windowCounts{}
- e.windows[windowStart] = wc
- }
- if overed == 0 || normalCount <= overed {
- wc.normal += n
- normalCount += n
- } else {
- wc.over += n
- overCount += n
- }
- return normalCount, overCount, wc.normal + wc.over
- }
- func (m *InMemoryRecord) GetRequest(
- duration time.Duration,
- keys ...string,
- ) (totalCount, secondCount int64) {
- nowSecond := time.Now().Unix()
- cutoff := nowSecond - int64(duration.Seconds())
- m.entries.Range(func(key, value any) bool {
- k, _ := key.(string)
- currentKeys := parseKeys(k)
- if matchKeys(keys, currentKeys) {
- e, _ := value.(*entry)
- e.Lock()
- normalCount, overCount := m.cleanupAndCount(e, cutoff)
- nowWindow := e.windows[nowSecond]
- e.Unlock()
- totalCount += normalCount + overCount
- if nowWindow != nil {
- secondCount += nowWindow.normal + nowWindow.over
- }
- }
- return true
- })
- return totalCount, secondCount
- }
- func (m *InMemoryRecord) cleanupInactiveEntries(interval, maxInactivity time.Duration) {
- ticker := time.NewTicker(interval)
- defer ticker.Stop()
- for range ticker.C {
- m.entries.Range(func(key, value any) bool {
- e, _ := value.(*entry)
- la := e.lastAccess.Load()
- if la == nil {
- return true
- }
- lastAccess, _ := la.(time.Time)
- if time.Since(lastAccess) > maxInactivity {
- m.entries.CompareAndDelete(key, e)
- }
- return true
- })
- }
- }
// parseKeys splits a stored map key back into its ":"-separated elements. An
// empty key yields a single empty element.
func parseKeys(key string) []string {
	const sep = ":"
	parts := strings.SplitN(key, sep, -1)
	return parts
}
// matchKeys reports whether keys matches pattern element by element. The two
// slices must have the same length, and a "*" in pattern (not in keys)
// matches any single element.
func matchKeys(pattern, keys []string) bool {
	if len(keys) != len(pattern) {
		return false
	}
	for i := range pattern {
		if pattern[i] == "*" {
			continue
		}
		if pattern[i] != keys[i] {
			return false
		}
	}
	return true
}
|