1
0

ratelimiter.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. // Copyright (C) 2019 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package common
  15. import (
  16. "errors"
  17. "fmt"
  18. "sort"
  19. "sync"
  20. "sync/atomic"
  21. "time"
  22. "golang.org/x/time/rate"
  23. "github.com/drakkan/sftpgo/v2/internal/util"
  24. )
  25. var (
  26. errNoBucket = errors.New("no bucket found")
  27. errReserve = errors.New("unable to reserve token")
  28. rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP}
  29. )
  // RateLimiterType defines the supported rate limiter types.
type RateLimiterType int

// Supported rate limiter types. RateLimiterConfig.Type holds one of these
// values as a plain int.
const (
	// rateLimiterTypeGlobal: a single limiter shared by every source.
	rateLimiterTypeGlobal RateLimiterType = 1
	// rateLimiterTypeSource: a separate limiter for each source.
	rateLimiterTypeSource RateLimiterType = 2
)
  // RateLimiterConfig defines the configuration for a rate limiter.
type RateLimiterConfig struct {
	// Average defines the maximum rate allowed. 0 means disabled
	Average int64 `json:"average" mapstructure:"average"`
	// Period defines the period as milliseconds. Default: 1000 (1 second).
	// The rate is actually defined by dividing average by period.
	// So for a rate below 1 req/s, one needs to define a period larger than a second.
	Period int64 `json:"period" mapstructure:"period"`
	// Burst is the maximum number of requests allowed to go through in the
	// same arbitrarily small period of time. Default: 1.
	Burst int `json:"burst" mapstructure:"burst"`
	// Type defines the rate limiter type:
	// - 1 (rateLimiterTypeGlobal) is a global rate limiter independent from the source
	// - 2 (rateLimiterTypeSource) is a per-source rate limiter
	Type int `json:"type" mapstructure:"type"`
	// Protocols defines the protocols for this rate limiter.
	// Allowed values are the ones listed in rateLimiterProtocolValues, i.e.
	// the values of the ProtocolSSH, ProtocolFTP, ProtocolWebDAV and
	// ProtocolHTTP constants; duplicates are removed during validation.
	// A rate limiter with no protocols defined is disabled.
	Protocols []string `json:"protocols" mapstructure:"protocols"`
	// If the rate limit is exceeded, the defender is enabled, and this is a per-source limiter,
	// a new defender event will be generated
	GenerateDefenderEvents bool `json:"generate_defender_events" mapstructure:"generate_defender_events"`
	// EntriesSoftLimit and EntriesHardLimit bound the number of per-source
	// rate limiters kept in memory: once EntriesHardLimit entries exist,
	// entries are evicted (ordered by last activity) until EntriesSoftLimit
	// remain. Per-source limiters require EntriesHardLimit > EntriesSoftLimit > 0.
	EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
	EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
}

// isEnabled reports whether the limiter should be active: it needs a
// positive Average and at least one protocol.
func (r *RateLimiterConfig) isEnabled() bool {
	return r.Average > 0 && len(r.Protocols) > 0
}
  67. func (r *RateLimiterConfig) validate() error {
  68. if r.Burst < 1 {
  69. return fmt.Errorf("invalid burst %v. It must be >= 1", r.Burst)
  70. }
  71. if r.Period < 100 {
  72. return fmt.Errorf("invalid period %v. It must be >= 100", r.Period)
  73. }
  74. if r.Type != int(rateLimiterTypeGlobal) && r.Type != int(rateLimiterTypeSource) {
  75. return fmt.Errorf("invalid type %v", r.Type)
  76. }
  77. if r.Type != int(rateLimiterTypeGlobal) {
  78. if r.EntriesSoftLimit <= 0 {
  79. return fmt.Errorf("invalid entries_soft_limit %v", r.EntriesSoftLimit)
  80. }
  81. if r.EntriesHardLimit <= r.EntriesSoftLimit {
  82. return fmt.Errorf("invalid entries_hard_limit %v must be > %v", r.EntriesHardLimit, r.EntriesSoftLimit)
  83. }
  84. }
  85. r.Protocols = util.RemoveDuplicates(r.Protocols, true)
  86. for _, protocol := range r.Protocols {
  87. if !util.Contains(rateLimiterProtocolValues, protocol) {
  88. return fmt.Errorf("invalid protocol %q", protocol)
  89. }
  90. }
  91. return nil
  92. }
  93. func (r *RateLimiterConfig) getLimiter() *rateLimiter {
  94. limiter := &rateLimiter{
  95. burst: r.Burst,
  96. globalBucket: nil,
  97. generateDefenderEvents: r.GenerateDefenderEvents,
  98. }
  99. var maxDelay time.Duration
  100. period := time.Duration(r.Period) * time.Millisecond
  101. rtl := float64(r.Average*int64(time.Second)) / float64(period)
  102. limiter.rate = rate.Limit(rtl)
  103. if rtl < 1 {
  104. maxDelay = period / 2
  105. } else {
  106. maxDelay = time.Second / (time.Duration(rtl) * 2)
  107. }
  108. if maxDelay > 10*time.Second {
  109. maxDelay = 10 * time.Second
  110. }
  111. limiter.maxDelay = maxDelay
  112. limiter.buckets = sourceBuckets{
  113. buckets: make(map[string]sourceRateLimiter),
  114. hardLimit: r.EntriesHardLimit,
  115. softLimit: r.EntriesSoftLimit,
  116. }
  117. if r.Type != int(rateLimiterTypeSource) {
  118. limiter.globalBucket = rate.NewLimiter(limiter.rate, limiter.burst)
  119. }
  120. return limiter
  121. }
  122. // RateLimiter defines a rate limiter
  123. type rateLimiter struct {
  124. rate rate.Limit
  125. burst int
  126. maxDelay time.Duration
  127. globalBucket *rate.Limiter
  128. buckets sourceBuckets
  129. generateDefenderEvents bool
  130. }
  131. // Wait blocks until the limit allows one event to happen
  132. // or returns an error if the time to wait exceeds the max
  133. // allowed delay
  134. func (rl *rateLimiter) Wait(source, protocol string) (time.Duration, error) {
  135. var res *rate.Reservation
  136. if rl.globalBucket != nil {
  137. res = rl.globalBucket.Reserve()
  138. } else {
  139. var err error
  140. res, err = rl.buckets.reserve(source)
  141. if err != nil {
  142. rateLimiter := rate.NewLimiter(rl.rate, rl.burst)
  143. res = rl.buckets.addAndReserve(rateLimiter, source)
  144. }
  145. }
  146. if !res.OK() {
  147. return 0, errReserve
  148. }
  149. delay := res.Delay()
  150. if delay > rl.maxDelay {
  151. res.Cancel()
  152. if rl.generateDefenderEvents && rl.globalBucket == nil {
  153. AddDefenderEvent(source, protocol, HostEventLimitExceeded)
  154. }
  155. return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
  156. }
  157. time.Sleep(delay)
  158. return 0, nil
  159. }
  160. type sourceRateLimiter struct {
  161. lastActivity *atomic.Int64
  162. bucket *rate.Limiter
  163. }
  164. func (s *sourceRateLimiter) updateLastActivity() {
  165. s.lastActivity.Store(time.Now().UnixNano())
  166. }
  167. func (s *sourceRateLimiter) getLastActivity() int64 {
  168. return s.lastActivity.Load()
  169. }
  170. type sourceBuckets struct {
  171. sync.RWMutex
  172. buckets map[string]sourceRateLimiter
  173. hardLimit int
  174. softLimit int
  175. }
  176. func (b *sourceBuckets) reserve(source string) (*rate.Reservation, error) {
  177. b.RLock()
  178. defer b.RUnlock()
  179. if src, ok := b.buckets[source]; ok {
  180. src.updateLastActivity()
  181. return src.bucket.Reserve(), nil
  182. }
  183. return nil, errNoBucket
  184. }
  185. func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Reservation {
  186. b.Lock()
  187. defer b.Unlock()
  188. b.cleanup()
  189. src := sourceRateLimiter{
  190. lastActivity: new(atomic.Int64),
  191. bucket: r,
  192. }
  193. src.updateLastActivity()
  194. b.buckets[source] = src
  195. return src.bucket.Reserve()
  196. }
  197. func (b *sourceBuckets) cleanup() {
  198. if len(b.buckets) >= b.hardLimit {
  199. numToRemove := len(b.buckets) - b.softLimit
  200. kvList := make(kvList, 0, len(b.buckets))
  201. for k, v := range b.buckets {
  202. kvList = append(kvList, kv{
  203. Key: k,
  204. Value: v.getLastActivity(),
  205. })
  206. }
  207. sort.Sort(kvList)
  208. for idx, kv := range kvList {
  209. if idx >= numToRemove {
  210. break
  211. }
  212. delete(b.buckets, kv.Key)
  213. }
  214. }
  215. }