ratelimiter_test.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. package common
  2. import (
  3. "testing"
  4. "time"
  5. "github.com/stretchr/testify/assert"
  6. "github.com/stretchr/testify/require"
  7. )
  8. func TestRateLimiterConfig(t *testing.T) {
  9. config := RateLimiterConfig{}
  10. err := config.validate()
  11. require.Error(t, err)
  12. config.Burst = 1
  13. config.Period = 10
  14. err = config.validate()
  15. require.Error(t, err)
  16. config.Period = 1000
  17. config.Type = 100
  18. err = config.validate()
  19. require.Error(t, err)
  20. config.Type = int(rateLimiterTypeSource)
  21. config.EntriesSoftLimit = 0
  22. err = config.validate()
  23. require.Error(t, err)
  24. config.EntriesSoftLimit = 150
  25. config.EntriesHardLimit = 0
  26. err = config.validate()
  27. require.Error(t, err)
  28. config.EntriesHardLimit = 200
  29. config.Protocols = []string{"unsupported protocol"}
  30. err = config.validate()
  31. require.Error(t, err)
  32. config.Protocols = rateLimiterProtocolValues
  33. err = config.validate()
  34. require.NoError(t, err)
  35. limiter := config.getLimiter()
  36. require.Equal(t, 500*time.Millisecond, limiter.maxDelay)
  37. require.Nil(t, limiter.globalBucket)
  38. config.Type = int(rateLimiterTypeGlobal)
  39. config.Average = 1
  40. config.Period = 10000
  41. limiter = config.getLimiter()
  42. require.Equal(t, 5*time.Second, limiter.maxDelay)
  43. require.NotNil(t, limiter.globalBucket)
  44. config.Period = 100000
  45. limiter = config.getLimiter()
  46. require.Equal(t, 10*time.Second, limiter.maxDelay)
  47. config.Period = 500
  48. config.Average = 1
  49. limiter = config.getLimiter()
  50. require.Equal(t, 250*time.Millisecond, limiter.maxDelay)
  51. }
  52. func TestRateLimiter(t *testing.T) {
  53. config := RateLimiterConfig{
  54. Average: 1,
  55. Period: 1000,
  56. Burst: 1,
  57. Type: int(rateLimiterTypeGlobal),
  58. Protocols: rateLimiterProtocolValues,
  59. }
  60. limiter := config.getLimiter()
  61. _, err := limiter.Wait("")
  62. require.NoError(t, err)
  63. _, err = limiter.Wait("")
  64. require.Error(t, err)
  65. config.Type = int(rateLimiterTypeSource)
  66. config.GenerateDefenderEvents = true
  67. config.EntriesSoftLimit = 5
  68. config.EntriesHardLimit = 10
  69. limiter = config.getLimiter()
  70. source := "192.168.1.2"
  71. _, err = limiter.Wait(source)
  72. require.NoError(t, err)
  73. _, err = limiter.Wait(source)
  74. require.Error(t, err)
  75. // a different source should work
  76. _, err = limiter.Wait(source + "1")
  77. require.NoError(t, err)
  78. config.Burst = 0
  79. limiter = config.getLimiter()
  80. _, err = limiter.Wait(source)
  81. require.ErrorIs(t, err, errReserve)
  82. }
  83. func TestLimiterCleanup(t *testing.T) {
  84. config := RateLimiterConfig{
  85. Average: 100,
  86. Period: 1000,
  87. Burst: 1,
  88. Type: int(rateLimiterTypeSource),
  89. Protocols: rateLimiterProtocolValues,
  90. EntriesSoftLimit: 1,
  91. EntriesHardLimit: 3,
  92. }
  93. limiter := config.getLimiter()
  94. source1 := "10.8.0.1"
  95. source2 := "10.8.0.2"
  96. source3 := "10.8.0.3"
  97. source4 := "10.8.0.4"
  98. _, err := limiter.Wait(source1)
  99. assert.NoError(t, err)
  100. time.Sleep(20 * time.Millisecond)
  101. _, err = limiter.Wait(source2)
  102. assert.NoError(t, err)
  103. time.Sleep(20 * time.Millisecond)
  104. assert.Len(t, limiter.buckets.buckets, 2)
  105. _, ok := limiter.buckets.buckets[source1]
  106. assert.True(t, ok)
  107. _, ok = limiter.buckets.buckets[source2]
  108. assert.True(t, ok)
  109. _, err = limiter.Wait(source3)
  110. assert.NoError(t, err)
  111. assert.Len(t, limiter.buckets.buckets, 3)
  112. _, ok = limiter.buckets.buckets[source1]
  113. assert.True(t, ok)
  114. _, ok = limiter.buckets.buckets[source2]
  115. assert.True(t, ok)
  116. _, ok = limiter.buckets.buckets[source3]
  117. assert.True(t, ok)
  118. time.Sleep(20 * time.Millisecond)
  119. _, err = limiter.Wait(source4)
  120. assert.NoError(t, err)
  121. assert.Len(t, limiter.buckets.buckets, 2)
  122. _, ok = limiter.buckets.buckets[source3]
  123. assert.True(t, ok)
  124. _, ok = limiter.buckets.buckets[source4]
  125. assert.True(t, ok)
  126. }