1
0

ratelimiter_test.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. // Copyright (C) 2019-2022 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package common
  15. import (
  16. "testing"
  17. "time"
  18. "github.com/stretchr/testify/assert"
  19. "github.com/stretchr/testify/require"
  20. "github.com/drakkan/sftpgo/v2/internal/util"
  21. )
  22. func TestRateLimiterConfig(t *testing.T) {
  23. config := RateLimiterConfig{}
  24. err := config.validate()
  25. require.Error(t, err)
  26. config.Burst = 1
  27. config.Period = 10
  28. err = config.validate()
  29. require.Error(t, err)
  30. config.Period = 1000
  31. config.Type = 100
  32. err = config.validate()
  33. require.Error(t, err)
  34. config.Type = int(rateLimiterTypeSource)
  35. config.EntriesSoftLimit = 0
  36. err = config.validate()
  37. require.Error(t, err)
  38. config.EntriesSoftLimit = 150
  39. config.EntriesHardLimit = 0
  40. err = config.validate()
  41. require.Error(t, err)
  42. config.EntriesHardLimit = 200
  43. config.Protocols = []string{"unsupported protocol"}
  44. err = config.validate()
  45. require.Error(t, err)
  46. config.Protocols = rateLimiterProtocolValues
  47. err = config.validate()
  48. require.NoError(t, err)
  49. limiter := config.getLimiter()
  50. require.Equal(t, 500*time.Millisecond, limiter.maxDelay)
  51. require.Nil(t, limiter.globalBucket)
  52. config.Type = int(rateLimiterTypeGlobal)
  53. config.Average = 1
  54. config.Period = 10000
  55. limiter = config.getLimiter()
  56. require.Equal(t, 5*time.Second, limiter.maxDelay)
  57. require.NotNil(t, limiter.globalBucket)
  58. config.Period = 100000
  59. limiter = config.getLimiter()
  60. require.Equal(t, 10*time.Second, limiter.maxDelay)
  61. config.Period = 500
  62. config.Average = 1
  63. limiter = config.getLimiter()
  64. require.Equal(t, 250*time.Millisecond, limiter.maxDelay)
  65. }
  66. func TestRateLimiter(t *testing.T) {
  67. config := RateLimiterConfig{
  68. Average: 1,
  69. Period: 1000,
  70. Burst: 1,
  71. Type: int(rateLimiterTypeGlobal),
  72. Protocols: rateLimiterProtocolValues,
  73. }
  74. limiter := config.getLimiter()
  75. _, err := limiter.Wait("")
  76. require.NoError(t, err)
  77. _, err = limiter.Wait("")
  78. require.Error(t, err)
  79. config.Type = int(rateLimiterTypeSource)
  80. config.GenerateDefenderEvents = true
  81. config.EntriesSoftLimit = 5
  82. config.EntriesHardLimit = 10
  83. limiter = config.getLimiter()
  84. source := "192.168.1.2"
  85. _, err = limiter.Wait(source)
  86. require.NoError(t, err)
  87. _, err = limiter.Wait(source)
  88. require.Error(t, err)
  89. // a different source should work
  90. _, err = limiter.Wait(source + "1")
  91. require.NoError(t, err)
  92. allowList := []string{"192.168.1.0/24"}
  93. allowFuncs, err := util.ParseAllowedIPAndRanges(allowList)
  94. assert.NoError(t, err)
  95. limiter.allowList = allowFuncs
  96. for i := 0; i < 5; i++ {
  97. _, err = limiter.Wait(source)
  98. require.NoError(t, err)
  99. }
  100. _, err = limiter.Wait("not an ip")
  101. require.NoError(t, err)
  102. config.Burst = 0
  103. limiter = config.getLimiter()
  104. _, err = limiter.Wait(source)
  105. require.ErrorIs(t, err, errReserve)
  106. }
  107. func TestLimiterCleanup(t *testing.T) {
  108. config := RateLimiterConfig{
  109. Average: 100,
  110. Period: 1000,
  111. Burst: 1,
  112. Type: int(rateLimiterTypeSource),
  113. Protocols: rateLimiterProtocolValues,
  114. EntriesSoftLimit: 1,
  115. EntriesHardLimit: 3,
  116. }
  117. limiter := config.getLimiter()
  118. source1 := "10.8.0.1"
  119. source2 := "10.8.0.2"
  120. source3 := "10.8.0.3"
  121. source4 := "10.8.0.4"
  122. _, err := limiter.Wait(source1)
  123. assert.NoError(t, err)
  124. time.Sleep(20 * time.Millisecond)
  125. _, err = limiter.Wait(source2)
  126. assert.NoError(t, err)
  127. time.Sleep(20 * time.Millisecond)
  128. assert.Len(t, limiter.buckets.buckets, 2)
  129. _, ok := limiter.buckets.buckets[source1]
  130. assert.True(t, ok)
  131. _, ok = limiter.buckets.buckets[source2]
  132. assert.True(t, ok)
  133. _, err = limiter.Wait(source3)
  134. assert.NoError(t, err)
  135. assert.Len(t, limiter.buckets.buckets, 3)
  136. _, ok = limiter.buckets.buckets[source1]
  137. assert.True(t, ok)
  138. _, ok = limiter.buckets.buckets[source2]
  139. assert.True(t, ok)
  140. _, ok = limiter.buckets.buckets[source3]
  141. assert.True(t, ok)
  142. time.Sleep(20 * time.Millisecond)
  143. _, err = limiter.Wait(source4)
  144. assert.NoError(t, err)
  145. assert.Len(t, limiter.buckets.buckets, 2)
  146. _, ok = limiter.buckets.buckets[source3]
  147. assert.True(t, ok)
  148. _, ok = limiter.buckets.buckets[source4]
  149. assert.True(t, ok)
  150. }