ratelimiter_test.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package common
  2. import (
  3. "testing"
  4. "time"
  5. "github.com/stretchr/testify/assert"
  6. "github.com/stretchr/testify/require"
  7. "github.com/drakkan/sftpgo/v2/util"
  8. )
  9. func TestRateLimiterConfig(t *testing.T) {
  10. config := RateLimiterConfig{}
  11. err := config.validate()
  12. require.Error(t, err)
  13. config.Burst = 1
  14. config.Period = 10
  15. err = config.validate()
  16. require.Error(t, err)
  17. config.Period = 1000
  18. config.Type = 100
  19. err = config.validate()
  20. require.Error(t, err)
  21. config.Type = int(rateLimiterTypeSource)
  22. config.EntriesSoftLimit = 0
  23. err = config.validate()
  24. require.Error(t, err)
  25. config.EntriesSoftLimit = 150
  26. config.EntriesHardLimit = 0
  27. err = config.validate()
  28. require.Error(t, err)
  29. config.EntriesHardLimit = 200
  30. config.Protocols = []string{"unsupported protocol"}
  31. err = config.validate()
  32. require.Error(t, err)
  33. config.Protocols = rateLimiterProtocolValues
  34. err = config.validate()
  35. require.NoError(t, err)
  36. limiter := config.getLimiter()
  37. require.Equal(t, 500*time.Millisecond, limiter.maxDelay)
  38. require.Nil(t, limiter.globalBucket)
  39. config.Type = int(rateLimiterTypeGlobal)
  40. config.Average = 1
  41. config.Period = 10000
  42. limiter = config.getLimiter()
  43. require.Equal(t, 5*time.Second, limiter.maxDelay)
  44. require.NotNil(t, limiter.globalBucket)
  45. config.Period = 100000
  46. limiter = config.getLimiter()
  47. require.Equal(t, 10*time.Second, limiter.maxDelay)
  48. config.Period = 500
  49. config.Average = 1
  50. limiter = config.getLimiter()
  51. require.Equal(t, 250*time.Millisecond, limiter.maxDelay)
  52. }
  53. func TestRateLimiter(t *testing.T) {
  54. config := RateLimiterConfig{
  55. Average: 1,
  56. Period: 1000,
  57. Burst: 1,
  58. Type: int(rateLimiterTypeGlobal),
  59. Protocols: rateLimiterProtocolValues,
  60. }
  61. limiter := config.getLimiter()
  62. _, err := limiter.Wait("")
  63. require.NoError(t, err)
  64. _, err = limiter.Wait("")
  65. require.Error(t, err)
  66. config.Type = int(rateLimiterTypeSource)
  67. config.GenerateDefenderEvents = true
  68. config.EntriesSoftLimit = 5
  69. config.EntriesHardLimit = 10
  70. limiter = config.getLimiter()
  71. source := "192.168.1.2"
  72. _, err = limiter.Wait(source)
  73. require.NoError(t, err)
  74. _, err = limiter.Wait(source)
  75. require.Error(t, err)
  76. // a different source should work
  77. _, err = limiter.Wait(source + "1")
  78. require.NoError(t, err)
  79. allowList := []string{"192.168.1.0/24"}
  80. allowFuncs, err := util.ParseAllowedIPAndRanges(allowList)
  81. assert.NoError(t, err)
  82. limiter.allowList = allowFuncs
  83. for i := 0; i < 5; i++ {
  84. _, err = limiter.Wait(source)
  85. require.NoError(t, err)
  86. }
  87. _, err = limiter.Wait("not an ip")
  88. require.NoError(t, err)
  89. config.Burst = 0
  90. limiter = config.getLimiter()
  91. _, err = limiter.Wait(source)
  92. require.ErrorIs(t, err, errReserve)
  93. }
  94. func TestLimiterCleanup(t *testing.T) {
  95. config := RateLimiterConfig{
  96. Average: 100,
  97. Period: 1000,
  98. Burst: 1,
  99. Type: int(rateLimiterTypeSource),
  100. Protocols: rateLimiterProtocolValues,
  101. EntriesSoftLimit: 1,
  102. EntriesHardLimit: 3,
  103. }
  104. limiter := config.getLimiter()
  105. source1 := "10.8.0.1"
  106. source2 := "10.8.0.2"
  107. source3 := "10.8.0.3"
  108. source4 := "10.8.0.4"
  109. _, err := limiter.Wait(source1)
  110. assert.NoError(t, err)
  111. time.Sleep(20 * time.Millisecond)
  112. _, err = limiter.Wait(source2)
  113. assert.NoError(t, err)
  114. time.Sleep(20 * time.Millisecond)
  115. assert.Len(t, limiter.buckets.buckets, 2)
  116. _, ok := limiter.buckets.buckets[source1]
  117. assert.True(t, ok)
  118. _, ok = limiter.buckets.buckets[source2]
  119. assert.True(t, ok)
  120. _, err = limiter.Wait(source3)
  121. assert.NoError(t, err)
  122. assert.Len(t, limiter.buckets.buckets, 3)
  123. _, ok = limiter.buckets.buckets[source1]
  124. assert.True(t, ok)
  125. _, ok = limiter.buckets.buckets[source2]
  126. assert.True(t, ok)
  127. _, ok = limiter.buckets.buckets[source3]
  128. assert.True(t, ok)
  129. time.Sleep(20 * time.Millisecond)
  130. _, err = limiter.Wait(source4)
  131. assert.NoError(t, err)
  132. assert.Len(t, limiter.buckets.buckets, 2)
  133. _, ok = limiter.buckets.buckets[source3]
  134. assert.True(t, ok)
  135. _, ok = limiter.buckets.buckets[source4]
  136. assert.True(t, ok)
  137. }