ratelimiter_test.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. // Copyright (C) 2019 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package common
  15. import (
  16. "testing"
  17. "time"
  18. "github.com/stretchr/testify/assert"
  19. "github.com/stretchr/testify/require"
  20. )
  21. func TestRateLimiterConfig(t *testing.T) {
  22. config := RateLimiterConfig{}
  23. err := config.validate()
  24. require.Error(t, err)
  25. config.Burst = 1
  26. config.Period = 10
  27. err = config.validate()
  28. require.Error(t, err)
  29. config.Period = 1000
  30. config.Type = 100
  31. err = config.validate()
  32. require.Error(t, err)
  33. config.Type = int(rateLimiterTypeSource)
  34. config.EntriesSoftLimit = 0
  35. err = config.validate()
  36. require.Error(t, err)
  37. config.EntriesSoftLimit = 150
  38. config.EntriesHardLimit = 0
  39. err = config.validate()
  40. require.Error(t, err)
  41. config.EntriesHardLimit = 200
  42. config.Protocols = []string{"unsupported protocol"}
  43. err = config.validate()
  44. require.Error(t, err)
  45. config.Protocols = rateLimiterProtocolValues
  46. err = config.validate()
  47. require.NoError(t, err)
  48. limiter := config.getLimiter()
  49. require.Equal(t, 500*time.Millisecond, limiter.maxDelay)
  50. require.Nil(t, limiter.globalBucket)
  51. config.Type = int(rateLimiterTypeGlobal)
  52. config.Average = 1
  53. config.Period = 10000
  54. limiter = config.getLimiter()
  55. require.Equal(t, 5*time.Second, limiter.maxDelay)
  56. require.NotNil(t, limiter.globalBucket)
  57. config.Period = 100000
  58. limiter = config.getLimiter()
  59. require.Equal(t, 10*time.Second, limiter.maxDelay)
  60. config.Period = 500
  61. config.Average = 1
  62. limiter = config.getLimiter()
  63. require.Equal(t, 250*time.Millisecond, limiter.maxDelay)
  64. }
  65. func TestRateLimiter(t *testing.T) {
  66. config := RateLimiterConfig{
  67. Average: 1,
  68. Period: 1000,
  69. Burst: 1,
  70. Type: int(rateLimiterTypeGlobal),
  71. Protocols: rateLimiterProtocolValues,
  72. }
  73. limiter := config.getLimiter()
  74. _, err := limiter.Wait("", ProtocolFTP)
  75. require.NoError(t, err)
  76. _, err = limiter.Wait("", ProtocolSSH)
  77. require.Error(t, err)
  78. config.Type = int(rateLimiterTypeSource)
  79. config.GenerateDefenderEvents = true
  80. config.EntriesSoftLimit = 5
  81. config.EntriesHardLimit = 10
  82. limiter = config.getLimiter()
  83. source := "192.168.1.2"
  84. _, err = limiter.Wait(source, ProtocolSSH)
  85. require.NoError(t, err)
  86. _, err = limiter.Wait(source, ProtocolSSH)
  87. require.Error(t, err)
  88. // a different source should work
  89. _, err = limiter.Wait(source+"1", ProtocolSSH)
  90. require.NoError(t, err)
  91. config.Burst = 0
  92. limiter = config.getLimiter()
  93. _, err = limiter.Wait(source, ProtocolSSH)
  94. require.ErrorIs(t, err, errReserve)
  95. }
  96. func TestLimiterCleanup(t *testing.T) {
  97. config := RateLimiterConfig{
  98. Average: 100,
  99. Period: 1000,
  100. Burst: 1,
  101. Type: int(rateLimiterTypeSource),
  102. Protocols: rateLimiterProtocolValues,
  103. EntriesSoftLimit: 1,
  104. EntriesHardLimit: 3,
  105. }
  106. limiter := config.getLimiter()
  107. source1 := "10.8.0.1"
  108. source2 := "10.8.0.2"
  109. source3 := "10.8.0.3"
  110. source4 := "10.8.0.4"
  111. _, err := limiter.Wait(source1, ProtocolSSH)
  112. assert.NoError(t, err)
  113. time.Sleep(20 * time.Millisecond)
  114. _, err = limiter.Wait(source2, ProtocolSSH)
  115. assert.NoError(t, err)
  116. time.Sleep(20 * time.Millisecond)
  117. assert.Len(t, limiter.buckets.buckets, 2)
  118. _, ok := limiter.buckets.buckets[source1]
  119. assert.True(t, ok)
  120. _, ok = limiter.buckets.buckets[source2]
  121. assert.True(t, ok)
  122. _, err = limiter.Wait(source3, ProtocolSSH)
  123. assert.NoError(t, err)
  124. assert.Len(t, limiter.buckets.buckets, 3)
  125. _, ok = limiter.buckets.buckets[source1]
  126. assert.True(t, ok)
  127. _, ok = limiter.buckets.buckets[source2]
  128. assert.True(t, ok)
  129. _, ok = limiter.buckets.buckets[source3]
  130. assert.True(t, ok)
  131. time.Sleep(20 * time.Millisecond)
  132. _, err = limiter.Wait(source4, ProtocolSSH)
  133. assert.NoError(t, err)
  134. assert.Len(t, limiter.buckets.buckets, 2)
  135. _, ok = limiter.buckets.buckets[source3]
  136. assert.True(t, ok)
  137. _, ok = limiter.buckets.buckets[source4]
  138. assert.True(t, ok)
  139. }