ratelimiter.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. package common
  2. import (
  3. "errors"
  4. "fmt"
  5. "net"
  6. "sort"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. "golang.org/x/time/rate"
  11. "github.com/drakkan/sftpgo/v2/util"
  12. )
  13. var (
  14. errNoBucket = errors.New("no bucket found")
  15. errReserve = errors.New("unable to reserve token")
  16. rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP}
  17. )
  18. // RateLimiterType defines the supported rate limiters types
  19. type RateLimiterType int
  20. // Supported rate limiter types
  21. const (
  22. rateLimiterTypeGlobal RateLimiterType = iota + 1
  23. rateLimiterTypeSource
  24. )
  25. // RateLimiterConfig defines the configuration for a rate limiter
  26. type RateLimiterConfig struct {
  27. // Average defines the maximum rate allowed. 0 means disabled
  28. Average int64 `json:"average" mapstructure:"average"`
  29. // Period defines the period as milliseconds. Default: 1000 (1 second).
  30. // The rate is actually defined by dividing average by period.
  31. // So for a rate below 1 req/s, one needs to define a period larger than a second.
  32. Period int64 `json:"period" mapstructure:"period"`
  33. // Burst is the maximum number of requests allowed to go through in the
  34. // same arbitrarily small period of time. Default: 1.
  35. Burst int `json:"burst" mapstructure:"burst"`
  36. // Type defines the rate limiter type:
  37. // - rateLimiterTypeGlobal is a global rate limiter independent from the source
  38. // - rateLimiterTypeSource is a per-source rate limiter
  39. Type int `json:"type" mapstructure:"type"`
  40. // Protocols defines the protocols for this rate limiter.
  41. // Available protocols are: "SFTP", "FTP", "DAV".
  42. // A rate limiter with no protocols defined is disabled
  43. Protocols []string `json:"protocols" mapstructure:"protocols"`
  44. // AllowList defines a list of IP addresses and IP ranges excluded from rate limiting
  45. AllowList []string `json:"allow_list" mapstructure:"mapstructure"`
  46. // If the rate limit is exceeded, the defender is enabled, and this is a per-source limiter,
  47. // a new defender event will be generated
  48. GenerateDefenderEvents bool `json:"generate_defender_events" mapstructure:"generate_defender_events"`
  49. // The number of per-ip rate limiters kept in memory will vary between the
  50. // soft and hard limit
  51. EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
  52. EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
  53. }
  54. func (r *RateLimiterConfig) isEnabled() bool {
  55. return r.Average > 0 && len(r.Protocols) > 0
  56. }
  57. func (r *RateLimiterConfig) validate() error {
  58. if r.Burst < 1 {
  59. return fmt.Errorf("invalid burst %v. It must be >= 1", r.Burst)
  60. }
  61. if r.Period < 100 {
  62. return fmt.Errorf("invalid period %v. It must be >= 100", r.Period)
  63. }
  64. if r.Type != int(rateLimiterTypeGlobal) && r.Type != int(rateLimiterTypeSource) {
  65. return fmt.Errorf("invalid type %v", r.Type)
  66. }
  67. if r.Type != int(rateLimiterTypeGlobal) {
  68. if r.EntriesSoftLimit <= 0 {
  69. return fmt.Errorf("invalid entries_soft_limit %v", r.EntriesSoftLimit)
  70. }
  71. if r.EntriesHardLimit <= r.EntriesSoftLimit {
  72. return fmt.Errorf("invalid entries_hard_limit %v must be > %v", r.EntriesHardLimit, r.EntriesSoftLimit)
  73. }
  74. }
  75. r.Protocols = util.RemoveDuplicates(r.Protocols)
  76. for _, protocol := range r.Protocols {
  77. if !util.IsStringInSlice(protocol, rateLimiterProtocolValues) {
  78. return fmt.Errorf("invalid protocol %#v", protocol)
  79. }
  80. }
  81. return nil
  82. }
  83. func (r *RateLimiterConfig) getLimiter() *rateLimiter {
  84. limiter := &rateLimiter{
  85. burst: r.Burst,
  86. globalBucket: nil,
  87. generateDefenderEvents: r.GenerateDefenderEvents,
  88. }
  89. var maxDelay time.Duration
  90. period := time.Duration(r.Period) * time.Millisecond
  91. rtl := float64(r.Average*int64(time.Second)) / float64(period)
  92. limiter.rate = rate.Limit(rtl)
  93. if rtl < 1 {
  94. maxDelay = period / 2
  95. } else {
  96. maxDelay = time.Second / (time.Duration(rtl) * 2)
  97. }
  98. if maxDelay > 10*time.Second {
  99. maxDelay = 10 * time.Second
  100. }
  101. limiter.maxDelay = maxDelay
  102. limiter.buckets = sourceBuckets{
  103. buckets: make(map[string]sourceRateLimiter),
  104. hardLimit: r.EntriesHardLimit,
  105. softLimit: r.EntriesSoftLimit,
  106. }
  107. if r.Type != int(rateLimiterTypeSource) {
  108. limiter.globalBucket = rate.NewLimiter(limiter.rate, limiter.burst)
  109. }
  110. return limiter
  111. }
  112. // RateLimiter defines a rate limiter
  113. type rateLimiter struct {
  114. rate rate.Limit
  115. burst int
  116. maxDelay time.Duration
  117. globalBucket *rate.Limiter
  118. buckets sourceBuckets
  119. generateDefenderEvents bool
  120. allowList []func(net.IP) bool
  121. }
  122. // Wait blocks until the limit allows one event to happen
  123. // or returns an error if the time to wait exceeds the max
  124. // allowed delay
  125. func (rl *rateLimiter) Wait(source string) (time.Duration, error) {
  126. if len(rl.allowList) > 0 {
  127. ip := net.ParseIP(source)
  128. if ip != nil {
  129. for idx := range rl.allowList {
  130. if rl.allowList[idx](ip) {
  131. return 0, nil
  132. }
  133. }
  134. }
  135. }
  136. var res *rate.Reservation
  137. if rl.globalBucket != nil {
  138. res = rl.globalBucket.Reserve()
  139. } else {
  140. var err error
  141. res, err = rl.buckets.reserve(source)
  142. if err != nil {
  143. rateLimiter := rate.NewLimiter(rl.rate, rl.burst)
  144. res = rl.buckets.addAndReserve(rateLimiter, source)
  145. }
  146. }
  147. if !res.OK() {
  148. return 0, errReserve
  149. }
  150. delay := res.Delay()
  151. if delay > rl.maxDelay {
  152. res.Cancel()
  153. if rl.generateDefenderEvents && rl.globalBucket == nil {
  154. AddDefenderEvent(source, HostEventLimitExceeded)
  155. }
  156. return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
  157. }
  158. time.Sleep(delay)
  159. return 0, nil
  160. }
  161. type sourceRateLimiter struct {
  162. lastActivity int64
  163. bucket *rate.Limiter
  164. }
  165. func (s *sourceRateLimiter) updateLastActivity() {
  166. atomic.StoreInt64(&s.lastActivity, time.Now().UnixNano())
  167. }
  168. func (s *sourceRateLimiter) getLastActivity() int64 {
  169. return atomic.LoadInt64(&s.lastActivity)
  170. }
  171. type sourceBuckets struct {
  172. sync.RWMutex
  173. buckets map[string]sourceRateLimiter
  174. hardLimit int
  175. softLimit int
  176. }
  177. func (b *sourceBuckets) reserve(source string) (*rate.Reservation, error) {
  178. b.RLock()
  179. defer b.RUnlock()
  180. if src, ok := b.buckets[source]; ok {
  181. src.updateLastActivity()
  182. return src.bucket.Reserve(), nil
  183. }
  184. return nil, errNoBucket
  185. }
  186. func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Reservation {
  187. b.Lock()
  188. defer b.Unlock()
  189. b.cleanup()
  190. src := sourceRateLimiter{
  191. bucket: r,
  192. }
  193. src.updateLastActivity()
  194. b.buckets[source] = src
  195. return src.bucket.Reserve()
  196. }
  197. func (b *sourceBuckets) cleanup() {
  198. if len(b.buckets) >= b.hardLimit {
  199. numToRemove := len(b.buckets) - b.softLimit
  200. kvList := make(kvList, 0, len(b.buckets))
  201. for k, v := range b.buckets {
  202. kvList = append(kvList, kv{
  203. Key: k,
  204. Value: v.getLastActivity(),
  205. })
  206. }
  207. sort.Sort(kvList)
  208. for idx, kv := range kvList {
  209. if idx >= numToRemove {
  210. break
  211. }
  212. delete(b.buckets, kv.Key)
  213. }
  214. }
  215. }