ratelimiter.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. package common
  2. import (
  3. "errors"
  4. "fmt"
  5. "sort"
  6. "sync"
  7. "sync/atomic"
  8. "time"
  9. "golang.org/x/time/rate"
  10. "github.com/drakkan/sftpgo/utils"
  11. )
  12. var (
  13. errNoBucket = errors.New("no bucket found")
  14. errReserve = errors.New("unable to reserve token")
  15. rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP}
  16. )
  // RateLimiterType enumerates the supported kinds of rate limiter.
  type RateLimiterType int

  // Supported rate limiter types. RateLimiterConfig.Type holds these values as
  // a plain int.
  const (
  	// rateLimiterTypeGlobal uses a single bucket shared by every source.
  	rateLimiterTypeGlobal RateLimiterType = 1
  	// rateLimiterTypeSource uses a separate bucket for each source.
  	rateLimiterTypeSource RateLimiterType = 2
  )
  // RateLimiterConfig defines the configuration for a rate limiter
  type RateLimiterConfig struct {
  	// Average defines the maximum rate allowed. 0 means disabled
  	Average int64 `json:"average" mapstructure:"average"`
  	// Period defines the period as milliseconds. Default: 1000 (1 second).
  	// The rate is actually defined by dividing average by period.
  	// So for a rate below 1 req/s, one needs to define a period larger than a second.
  	// validate rejects values below 100.
  	Period int64 `json:"period" mapstructure:"period"`
  	// Burst is the maximum number of requests allowed to go through in the
  	// same arbitrarily small period of time. Default: 1. Must be >= 1.
  	Burst int `json:"burst" mapstructure:"burst"`
  	// Type defines the rate limiter type:
  	// - 1 (rateLimiterTypeGlobal) is a global rate limiter independent from the source
  	// - 2 (rateLimiterTypeSource) is a per-source rate limiter
  	Type int `json:"type" mapstructure:"type"`
  	// Protocols defines the protocols for this rate limiter.
  	// Accepted values are those of ProtocolSSH, ProtocolFTP, ProtocolWebDAV
  	// and ProtocolHTTP (see rateLimiterProtocolValues); duplicates are removed
  	// by validate.
  	// A rate limiter with no protocols defined is disabled
  	Protocols []string `json:"protocols" mapstructure:"protocols"`
  	// If the rate limit is exceeded, the defender is enabled, and this is a per-source limiter,
  	// a new defender event will be generated
  	GenerateDefenderEvents bool `json:"generate_defender_events" mapstructure:"generate_defender_events"`
  	// The number of per-source rate limiters kept in memory will vary between
  	// the soft and hard limit: once the hard limit is reached, the least
  	// recently active entries are evicted until only the soft limit remains.
  	// For per-source limiters validate requires 0 < EntriesSoftLimit < EntriesHardLimit.
  	EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
  	EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
  }
  51. func (r *RateLimiterConfig) isEnabled() bool {
  52. return r.Average > 0 && len(r.Protocols) > 0
  53. }
  // validate checks the configuration and returns the first problem found.
  // Burst must be >= 1, Period >= 100 and Type one of the supported
  // RateLimiterType values; per-source limiters additionally require
  // 0 < EntriesSoftLimit < EntriesHardLimit. On success so far, Protocols is
  // replaced by the result of utils.RemoveDuplicates and every remaining entry
  // must belong to rateLimiterProtocolValues.
  func (r *RateLimiterConfig) validate() error {
  	if r.Burst < 1 {
  		return fmt.Errorf("invalid burst %v. It must be >= 1", r.Burst)
  	}
  	if r.Period < 100 {
  		return fmt.Errorf("invalid period %v. It must be >= 100", r.Period)
  	}
  	switch RateLimiterType(r.Type) {
  	case rateLimiterTypeGlobal:
  		// no per-source entry limits to check
  	case rateLimiterTypeSource:
  		if r.EntriesSoftLimit <= 0 {
  			return fmt.Errorf("invalid entries_soft_limit %v", r.EntriesSoftLimit)
  		}
  		if r.EntriesHardLimit <= r.EntriesSoftLimit {
  			return fmt.Errorf("invalid entries_hard_limit %v must be > %v", r.EntriesHardLimit, r.EntriesSoftLimit)
  		}
  	default:
  		return fmt.Errorf("invalid type %v", r.Type)
  	}
  	r.Protocols = utils.RemoveDuplicates(r.Protocols)
  	for _, p := range r.Protocols {
  		if !utils.IsStringInSlice(p, rateLimiterProtocolValues) {
  			return fmt.Errorf("invalid protocol %#v", p)
  		}
  	}
  	return nil
  }
  // getLimiter builds the rateLimiter described by this configuration.
  //
  // The token rate is Average events every Period milliseconds, converted to
  // events per second. maxDelay, the longest sleep Wait accepts, is half of the
  // period when that rate is below 1/s; otherwise it is one second divided by
  // twice the integer part of the rate (so it drops to 0 above ~5e8 events/s).
  // In both cases it is capped at 10 seconds. A single shared bucket is created
  // unless Type is rateLimiterTypeSource, in which case globalBucket stays nil
  // and buckets are created per source on demand.
  func (r *RateLimiterConfig) getLimiter() *rateLimiter {
  	limiter := &rateLimiter{
  		burst:                  r.Burst,
  		globalBucket:           nil,
  		generateDefenderEvents: r.GenerateDefenderEvents,
  	}
  	period := time.Duration(r.Period) * time.Millisecond
  	// Multiply in floating point: the integer product Average*int64(time.Second)
  	// overflows int64 for Average above ~9.2e9 and would silently produce a
  	// wrapped (possibly negative) rate. For smaller values both forms yield the
  	// same float64.
  	rtl := float64(r.Average) * float64(time.Second) / float64(period)
  	limiter.rate = rate.Limit(rtl)

  	var maxDelay time.Duration
  	if rtl < 1 {
  		maxDelay = period / 2
  	} else {
  		maxDelay = time.Second / (time.Duration(rtl) * 2)
  	}
  	if maxDelay > 10*time.Second {
  		maxDelay = 10 * time.Second
  	}
  	limiter.maxDelay = maxDelay
  	limiter.buckets = sourceBuckets{
  		buckets:   make(map[string]sourceRateLimiter),
  		hardLimit: r.EntriesHardLimit,
  		softLimit: r.EntriesSoftLimit,
  	}
  	if r.Type != int(rateLimiterTypeSource) {
  		limiter.globalBucket = rate.NewLimiter(limiter.rate, limiter.burst)
  	}
  	return limiter
  }
  109. // RateLimiter defines a rate limiter
  110. type rateLimiter struct {
  111. rate rate.Limit
  112. burst int
  113. maxDelay time.Duration
  114. globalBucket *rate.Limiter
  115. buckets sourceBuckets
  116. generateDefenderEvents bool
  117. }
  // Wait reserves one token for source, from the global bucket if there is one
  // or else from the source's own bucket (created on first use).
  //
  // If the reservation cannot be satisfied it returns errReserve. If the
  // required delay exceeds maxDelay, the reservation is cancelled, a defender
  // event is recorded when enabled and the limiter is per-source, and the
  // delay is returned together with an error. Otherwise it sleeps for the
  // delay and returns 0, nil.
  func (rl *rateLimiter) Wait(source string) (time.Duration, error) {
  	perSource := rl.globalBucket == nil

  	var reservation *rate.Reservation
  	if perSource {
  		r, err := rl.buckets.reserve(source)
  		if err != nil {
  			r = rl.buckets.addAndReserve(rate.NewLimiter(rl.rate, rl.burst), source)
  		}
  		reservation = r
  	} else {
  		reservation = rl.globalBucket.Reserve()
  	}

  	if !reservation.OK() {
  		return 0, errReserve
  	}

  	wait := reservation.Delay()
  	if wait <= rl.maxDelay {
  		time.Sleep(wait)
  		return 0, nil
  	}

  	reservation.Cancel()
  	if perSource && rl.generateDefenderEvents {
  		AddDefenderEvent(source, HostEventRateExceeded)
  	}
  	return wait, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", wait, rl.maxDelay)
  }
  147. type sourceRateLimiter struct {
  148. lastActivity int64
  149. bucket *rate.Limiter
  150. }
  151. func (s *sourceRateLimiter) updateLastActivity() {
  152. atomic.StoreInt64(&s.lastActivity, time.Now().UnixNano())
  153. }
  154. func (s *sourceRateLimiter) getLastActivity() int64 {
  155. return atomic.LoadInt64(&s.lastActivity)
  156. }
  157. type sourceBuckets struct {
  158. sync.RWMutex
  159. buckets map[string]sourceRateLimiter
  160. hardLimit int
  161. softLimit int
  162. }
  163. func (b *sourceBuckets) reserve(source string) (*rate.Reservation, error) {
  164. b.RLock()
  165. defer b.RUnlock()
  166. if src, ok := b.buckets[source]; ok {
  167. src.updateLastActivity()
  168. return src.bucket.Reserve(), nil
  169. }
  170. return nil, errNoBucket
  171. }
  172. func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Reservation {
  173. b.Lock()
  174. defer b.Unlock()
  175. b.cleanup()
  176. src := sourceRateLimiter{
  177. bucket: r,
  178. }
  179. src.updateLastActivity()
  180. b.buckets[source] = src
  181. return src.bucket.Reserve()
  182. }
  // cleanup evicts entries once the map has reached hardLimit, shrinking it
  // back to softLimit entries. Entries are collected into a kvList keyed by
  // last activity, ordered with sort.Sort, and the first len-softLimit of that
  // order are deleted (presumably the least recently active ones, depending on
  // kvList.Less, which is defined outside this view). It mutates b.buckets, so
  // the caller must hold the write lock (addAndReserve does).
  func (b *sourceBuckets) cleanup() {
  	if len(b.buckets) < b.hardLimit {
  		return
  	}
  	excess := len(b.buckets) - b.softLimit
  	entries := make(kvList, 0, len(b.buckets))
  	for source, limiter := range b.buckets {
  		entries = append(entries, kv{Key: source, Value: limiter.getLastActivity()})
  	}
  	sort.Sort(entries)
  	for i := 0; i < excess && i < len(entries); i++ {
  		delete(b.buckets, entries[i].Key)
  	}
  }