ratelimiter.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. // Copyright (C) 2019-2022 Nicola Murino
  2. //
  3. // This program is free software: you can redistribute it and/or modify
  4. // it under the terms of the GNU Affero General Public License as published
  5. // by the Free Software Foundation, version 3.
  6. //
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. // GNU Affero General Public License for more details.
  11. //
  12. // You should have received a copy of the GNU Affero General Public License
  13. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. package common
  15. import (
  16. "errors"
  17. "fmt"
  18. "net"
  19. "sort"
  20. "sync"
  21. "sync/atomic"
  22. "time"
  23. "golang.org/x/time/rate"
  24. "github.com/drakkan/sftpgo/v2/internal/util"
  25. )
  26. var (
  27. errNoBucket = errors.New("no bucket found")
  28. errReserve = errors.New("unable to reserve token")
  29. rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP}
  30. )
  31. // RateLimiterType defines the supported rate limiters types
  32. type RateLimiterType int
  33. // Supported rate limiter types
  34. const (
  35. rateLimiterTypeGlobal RateLimiterType = iota + 1
  36. rateLimiterTypeSource
  37. )
  38. // RateLimiterConfig defines the configuration for a rate limiter
  39. type RateLimiterConfig struct {
  40. // Average defines the maximum rate allowed. 0 means disabled
  41. Average int64 `json:"average" mapstructure:"average"`
  42. // Period defines the period as milliseconds. Default: 1000 (1 second).
  43. // The rate is actually defined by dividing average by period.
  44. // So for a rate below 1 req/s, one needs to define a period larger than a second.
  45. Period int64 `json:"period" mapstructure:"period"`
  46. // Burst is the maximum number of requests allowed to go through in the
  47. // same arbitrarily small period of time. Default: 1.
  48. Burst int `json:"burst" mapstructure:"burst"`
  49. // Type defines the rate limiter type:
  50. // - rateLimiterTypeGlobal is a global rate limiter independent from the source
  51. // - rateLimiterTypeSource is a per-source rate limiter
  52. Type int `json:"type" mapstructure:"type"`
  53. // Protocols defines the protocols for this rate limiter.
  54. // Available protocols are: "SFTP", "FTP", "DAV".
  55. // A rate limiter with no protocols defined is disabled
  56. Protocols []string `json:"protocols" mapstructure:"protocols"`
  57. // AllowList defines a list of IP addresses and IP ranges excluded from rate limiting
  58. AllowList []string `json:"allow_list" mapstructure:"mapstructure"`
  59. // If the rate limit is exceeded, the defender is enabled, and this is a per-source limiter,
  60. // a new defender event will be generated
  61. GenerateDefenderEvents bool `json:"generate_defender_events" mapstructure:"generate_defender_events"`
  62. // The number of per-ip rate limiters kept in memory will vary between the
  63. // soft and hard limit
  64. EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
  65. EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
  66. }
  67. func (r *RateLimiterConfig) isEnabled() bool {
  68. return r.Average > 0 && len(r.Protocols) > 0
  69. }
  70. func (r *RateLimiterConfig) validate() error {
  71. if r.Burst < 1 {
  72. return fmt.Errorf("invalid burst %v. It must be >= 1", r.Burst)
  73. }
  74. if r.Period < 100 {
  75. return fmt.Errorf("invalid period %v. It must be >= 100", r.Period)
  76. }
  77. if r.Type != int(rateLimiterTypeGlobal) && r.Type != int(rateLimiterTypeSource) {
  78. return fmt.Errorf("invalid type %v", r.Type)
  79. }
  80. if r.Type != int(rateLimiterTypeGlobal) {
  81. if r.EntriesSoftLimit <= 0 {
  82. return fmt.Errorf("invalid entries_soft_limit %v", r.EntriesSoftLimit)
  83. }
  84. if r.EntriesHardLimit <= r.EntriesSoftLimit {
  85. return fmt.Errorf("invalid entries_hard_limit %v must be > %v", r.EntriesHardLimit, r.EntriesSoftLimit)
  86. }
  87. }
  88. r.Protocols = util.RemoveDuplicates(r.Protocols, true)
  89. for _, protocol := range r.Protocols {
  90. if !util.Contains(rateLimiterProtocolValues, protocol) {
  91. return fmt.Errorf("invalid protocol %#v", protocol)
  92. }
  93. }
  94. return nil
  95. }
  96. func (r *RateLimiterConfig) getLimiter() *rateLimiter {
  97. limiter := &rateLimiter{
  98. burst: r.Burst,
  99. globalBucket: nil,
  100. generateDefenderEvents: r.GenerateDefenderEvents,
  101. }
  102. var maxDelay time.Duration
  103. period := time.Duration(r.Period) * time.Millisecond
  104. rtl := float64(r.Average*int64(time.Second)) / float64(period)
  105. limiter.rate = rate.Limit(rtl)
  106. if rtl < 1 {
  107. maxDelay = period / 2
  108. } else {
  109. maxDelay = time.Second / (time.Duration(rtl) * 2)
  110. }
  111. if maxDelay > 10*time.Second {
  112. maxDelay = 10 * time.Second
  113. }
  114. limiter.maxDelay = maxDelay
  115. limiter.buckets = sourceBuckets{
  116. buckets: make(map[string]sourceRateLimiter),
  117. hardLimit: r.EntriesHardLimit,
  118. softLimit: r.EntriesSoftLimit,
  119. }
  120. if r.Type != int(rateLimiterTypeSource) {
  121. limiter.globalBucket = rate.NewLimiter(limiter.rate, limiter.burst)
  122. }
  123. return limiter
  124. }
  125. // RateLimiter defines a rate limiter
  126. type rateLimiter struct {
  127. rate rate.Limit
  128. burst int
  129. maxDelay time.Duration
  130. globalBucket *rate.Limiter
  131. buckets sourceBuckets
  132. generateDefenderEvents bool
  133. allowList []func(net.IP) bool
  134. }
  135. // Wait blocks until the limit allows one event to happen
  136. // or returns an error if the time to wait exceeds the max
  137. // allowed delay
  138. func (rl *rateLimiter) Wait(source string) (time.Duration, error) {
  139. if len(rl.allowList) > 0 {
  140. ip := net.ParseIP(source)
  141. if ip != nil {
  142. for idx := range rl.allowList {
  143. if rl.allowList[idx](ip) {
  144. return 0, nil
  145. }
  146. }
  147. }
  148. }
  149. var res *rate.Reservation
  150. if rl.globalBucket != nil {
  151. res = rl.globalBucket.Reserve()
  152. } else {
  153. var err error
  154. res, err = rl.buckets.reserve(source)
  155. if err != nil {
  156. rateLimiter := rate.NewLimiter(rl.rate, rl.burst)
  157. res = rl.buckets.addAndReserve(rateLimiter, source)
  158. }
  159. }
  160. if !res.OK() {
  161. return 0, errReserve
  162. }
  163. delay := res.Delay()
  164. if delay > rl.maxDelay {
  165. res.Cancel()
  166. if rl.generateDefenderEvents && rl.globalBucket == nil {
  167. AddDefenderEvent(source, HostEventLimitExceeded)
  168. }
  169. return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
  170. }
  171. time.Sleep(delay)
  172. return 0, nil
  173. }
  174. type sourceRateLimiter struct {
  175. lastActivity *atomic.Int64
  176. bucket *rate.Limiter
  177. }
  178. func (s *sourceRateLimiter) updateLastActivity() {
  179. s.lastActivity.Store(time.Now().UnixNano())
  180. }
  181. func (s *sourceRateLimiter) getLastActivity() int64 {
  182. return s.lastActivity.Load()
  183. }
  184. type sourceBuckets struct {
  185. sync.RWMutex
  186. buckets map[string]sourceRateLimiter
  187. hardLimit int
  188. softLimit int
  189. }
  190. func (b *sourceBuckets) reserve(source string) (*rate.Reservation, error) {
  191. b.RLock()
  192. defer b.RUnlock()
  193. if src, ok := b.buckets[source]; ok {
  194. src.updateLastActivity()
  195. return src.bucket.Reserve(), nil
  196. }
  197. return nil, errNoBucket
  198. }
  199. func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Reservation {
  200. b.Lock()
  201. defer b.Unlock()
  202. b.cleanup()
  203. src := sourceRateLimiter{
  204. lastActivity: new(atomic.Int64),
  205. bucket: r,
  206. }
  207. src.updateLastActivity()
  208. b.buckets[source] = src
  209. return src.bucket.Reserve()
  210. }
  211. func (b *sourceBuckets) cleanup() {
  212. if len(b.buckets) >= b.hardLimit {
  213. numToRemove := len(b.buckets) - b.softLimit
  214. kvList := make(kvList, 0, len(b.buckets))
  215. for k, v := range b.buckets {
  216. kvList = append(kvList, kv{
  217. Key: k,
  218. Value: v.getLastActivity(),
  219. })
  220. }
  221. sort.Sort(kvList)
  222. for idx, kv := range kvList {
  223. if idx >= numToRemove {
  224. break
  225. }
  226. delete(b.buckets, kv.Key)
  227. }
  228. }
  229. }