ratelimiter.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. // Copyright (C) 2019 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package common
  15. import (
  16. "errors"
  17. "fmt"
  18. "slices"
  19. "sort"
  20. "sync"
  21. "sync/atomic"
  22. "time"
  23. "golang.org/x/time/rate"
  24. "github.com/drakkan/sftpgo/v2/internal/util"
  25. )
  26. var (
  27. errNoBucket = errors.New("no bucket found")
  28. errReserve = errors.New("unable to reserve token")
  29. rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP}
  30. )
  // RateLimiterType identifies the kind of rate limiter to build.
  type RateLimiterType int

  // Supported rate limiter types. Numbering starts at 1, so the zero value
  // does not match any supported type.
  const (
  	// rateLimiterTypeGlobal is a single limiter shared by every source.
  	rateLimiterTypeGlobal RateLimiterType = 1
  	// rateLimiterTypeSource is a limiter kept separately for each source.
  	rateLimiterTypeSource RateLimiterType = 2
  )
  // RateLimiterConfig holds the user-facing settings for one rate limiter.
  type RateLimiterConfig struct {
  	// Average is the maximum number of events allowed per Period.
  	// A value of 0 disables the limiter.
  	Average int64 `json:"average" mapstructure:"average"`
  	// Period is the reference window, in milliseconds. Default: 1000 (1 second).
  	// The effective rate is Average divided by Period, so a rate below
  	// 1 req/s requires a period longer than one second.
  	// validate rejects values below 100.
  	Period int64 `json:"period" mapstructure:"period"`
  	// Burst is the maximum number of requests allowed to go through in the
  	// same arbitrarily small period of time. Default: 1.
  	// validate rejects values below 1.
  	Burst int `json:"burst" mapstructure:"burst"`
  	// Type selects the limiter kind:
  	// - 1 (rateLimiterTypeGlobal): one limiter shared by all sources
  	// - 2 (rateLimiterTypeSource): one limiter per source
  	Type int `json:"type" mapstructure:"type"`
  	// Protocols lists the protocols this limiter applies to. Each entry
  	// must be one of rateLimiterProtocolValues (the SSH, FTP, WebDAV and
  	// HTTP protocol constants). A limiter with no protocols is disabled.
  	Protocols []string `json:"protocols" mapstructure:"protocols"`
  	// GenerateDefenderEvents, when set on a per-source limiter, makes an
  	// exceeded limit produce a defender event for the offending source.
  	GenerateDefenderEvents bool `json:"generate_defender_events" mapstructure:"generate_defender_events"`
  	// EntriesSoftLimit and EntriesHardLimit bound the number of per-source
  	// limiters kept in memory: when the hard limit is reached the set is
  	// trimmed back to the soft limit. For per-source limiters validate
  	// requires soft > 0 and hard > soft.
  	EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
  	EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
  }
  65. func (r *RateLimiterConfig) isEnabled() bool {
  66. return r.Average > 0 && len(r.Protocols) > 0
  67. }
  68. func (r *RateLimiterConfig) validate() error {
  69. if r.Burst < 1 {
  70. return fmt.Errorf("invalid burst %v. It must be >= 1", r.Burst)
  71. }
  72. if r.Period < 100 {
  73. return fmt.Errorf("invalid period %v. It must be >= 100", r.Period)
  74. }
  75. if r.Type != int(rateLimiterTypeGlobal) && r.Type != int(rateLimiterTypeSource) {
  76. return fmt.Errorf("invalid type %v", r.Type)
  77. }
  78. if r.Type != int(rateLimiterTypeGlobal) {
  79. if r.EntriesSoftLimit <= 0 {
  80. return fmt.Errorf("invalid entries_soft_limit %v", r.EntriesSoftLimit)
  81. }
  82. if r.EntriesHardLimit <= r.EntriesSoftLimit {
  83. return fmt.Errorf("invalid entries_hard_limit %v must be > %v", r.EntriesHardLimit, r.EntriesSoftLimit)
  84. }
  85. }
  86. r.Protocols = util.RemoveDuplicates(r.Protocols, true)
  87. for _, protocol := range r.Protocols {
  88. if !slices.Contains(rateLimiterProtocolValues, protocol) {
  89. return fmt.Errorf("invalid protocol %q", protocol)
  90. }
  91. }
  92. return nil
  93. }
  94. func (r *RateLimiterConfig) getLimiter() *rateLimiter {
  95. limiter := &rateLimiter{
  96. burst: r.Burst,
  97. globalBucket: nil,
  98. generateDefenderEvents: r.GenerateDefenderEvents,
  99. }
  100. var maxDelay time.Duration
  101. period := time.Duration(r.Period) * time.Millisecond
  102. rtl := float64(r.Average*int64(time.Second)) / float64(period)
  103. limiter.rate = rate.Limit(rtl)
  104. if rtl < 1 {
  105. maxDelay = period / 2
  106. } else {
  107. maxDelay = time.Second / (time.Duration(rtl) * 2)
  108. }
  109. if maxDelay > 10*time.Second {
  110. maxDelay = 10 * time.Second
  111. }
  112. limiter.maxDelay = maxDelay
  113. limiter.buckets = sourceBuckets{
  114. buckets: make(map[string]sourceRateLimiter),
  115. hardLimit: r.EntriesHardLimit,
  116. softLimit: r.EntriesSoftLimit,
  117. }
  118. if r.Type != int(rateLimiterTypeSource) {
  119. limiter.globalBucket = rate.NewLimiter(limiter.rate, limiter.burst)
  120. }
  121. return limiter
  122. }
  123. // RateLimiter defines a rate limiter
  124. type rateLimiter struct {
  125. rate rate.Limit
  126. burst int
  127. maxDelay time.Duration
  128. globalBucket *rate.Limiter
  129. buckets sourceBuckets
  130. generateDefenderEvents bool
  131. }
  132. // Wait blocks until the limit allows one event to happen
  133. // or returns an error if the time to wait exceeds the max
  134. // allowed delay
  135. func (rl *rateLimiter) Wait(source, protocol string) (time.Duration, error) {
  136. var res *rate.Reservation
  137. if rl.globalBucket != nil {
  138. res = rl.globalBucket.Reserve()
  139. } else {
  140. var err error
  141. res, err = rl.buckets.reserve(source)
  142. if err != nil {
  143. rateLimiter := rate.NewLimiter(rl.rate, rl.burst)
  144. res = rl.buckets.addAndReserve(rateLimiter, source)
  145. }
  146. }
  147. if !res.OK() {
  148. return 0, errReserve
  149. }
  150. delay := res.Delay()
  151. if delay > rl.maxDelay {
  152. res.Cancel()
  153. if rl.generateDefenderEvents && rl.globalBucket == nil {
  154. AddDefenderEvent(source, protocol, HostEventLimitExceeded)
  155. }
  156. return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
  157. }
  158. time.Sleep(delay)
  159. return 0, nil
  160. }
  161. type sourceRateLimiter struct {
  162. lastActivity *atomic.Int64
  163. bucket *rate.Limiter
  164. }
  165. func (s *sourceRateLimiter) updateLastActivity() {
  166. s.lastActivity.Store(time.Now().UnixNano())
  167. }
  168. func (s *sourceRateLimiter) getLastActivity() int64 {
  169. return s.lastActivity.Load()
  170. }
  171. type sourceBuckets struct {
  172. sync.RWMutex
  173. buckets map[string]sourceRateLimiter
  174. hardLimit int
  175. softLimit int
  176. }
  177. func (b *sourceBuckets) reserve(source string) (*rate.Reservation, error) {
  178. b.RLock()
  179. defer b.RUnlock()
  180. if src, ok := b.buckets[source]; ok {
  181. src.updateLastActivity()
  182. return src.bucket.Reserve(), nil
  183. }
  184. return nil, errNoBucket
  185. }
  186. func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Reservation {
  187. b.Lock()
  188. defer b.Unlock()
  189. b.cleanup()
  190. src := sourceRateLimiter{
  191. lastActivity: new(atomic.Int64),
  192. bucket: r,
  193. }
  194. src.updateLastActivity()
  195. b.buckets[source] = src
  196. return src.bucket.Reserve()
  197. }
  198. func (b *sourceBuckets) cleanup() {
  199. if len(b.buckets) >= b.hardLimit {
  200. numToRemove := len(b.buckets) - b.softLimit
  201. kvList := make(kvList, 0, len(b.buckets))
  202. for k, v := range b.buckets {
  203. kvList = append(kvList, kv{
  204. Key: k,
  205. Value: v.getLastActivity(),
  206. })
  207. }
  208. sort.Sort(kvList)
  209. for idx, kv := range kvList {
  210. if idx >= numToRemove {
  211. break
  212. }
  213. delete(b.buckets, kv.Key)
  214. }
  215. }
  216. }