123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244 |
- // Copyright (C) 2019 Nicola Murino
- //
- // This program is free software: you can redistribute it and/or modify
- // it under the terms of the GNU Affero General Public License as published
- // by the Free Software Foundation, version 3.
- //
- // This program is distributed in the hope that it will be useful,
- // but WITHOUT ANY WARRANTY; without even the implied warranty of
- // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- // GNU Affero General Public License for more details.
- //
- // You should have received a copy of the GNU Affero General Public License
- // along with this program. If not, see <https://www.gnu.org/licenses/>.
- package common
- import (
- "errors"
- "fmt"
- "sort"
- "sync"
- "sync/atomic"
- "time"
- "golang.org/x/time/rate"
- "github.com/drakkan/sftpgo/v2/internal/util"
- )
- var (
- errNoBucket = errors.New("no bucket found")
- errReserve = errors.New("unable to reserve token")
- rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP}
- )
// RateLimiterType identifies how a rate limiter groups the requests it limits.
type RateLimiterType int

// Supported rate limiter types. The zero value is not a valid type:
// RateLimiterConfig.validate rejects anything other than these two values.
const (
	// rateLimiterTypeGlobal uses a single token bucket shared by all sources.
	rateLimiterTypeGlobal RateLimiterType = 1
	// rateLimiterTypeSource keeps a separate token bucket for each source.
	rateLimiterTypeSource RateLimiterType = 2
)
// RateLimiterConfig defines the configuration for a rate limiter
type RateLimiterConfig struct {
	// Average defines the maximum rate allowed. 0 means disabled
	Average int64 `json:"average" mapstructure:"average"`
	// Period defines the period as milliseconds. Default: 1000 (1 second).
	// It must be >= 100.
	// The rate is actually defined by dividing average by period.
	// So for a rate below 1 req/s, one needs to define a period larger than a second.
	Period int64 `json:"period" mapstructure:"period"`
	// Burst is the maximum number of requests allowed to go through in the
	// same arbitrarily small period of time. Default: 1. It must be >= 1.
	Burst int `json:"burst" mapstructure:"burst"`
	// Type defines the rate limiter type:
	// - 1 (rateLimiterTypeGlobal) is a global rate limiter independent from the source
	// - 2 (rateLimiterTypeSource) is a per-source rate limiter
	// Any other value is rejected by validate.
	Type int `json:"type" mapstructure:"type"`
	// Protocols defines the protocols for this rate limiter.
	// Accepted values are the ones listed in rateLimiterProtocolValues:
	// ProtocolSSH, ProtocolFTP, ProtocolWebDAV and ProtocolHTTP.
	// validate normalizes the list with util.RemoveDuplicates.
	// A rate limiter with no protocols defined is disabled
	Protocols []string `json:"protocols" mapstructure:"protocols"`
	// If the rate limit is exceeded, the defender is enabled, and this is a per-source limiter,
	// a new defender event will be generated
	GenerateDefenderEvents bool `json:"generate_defender_events" mapstructure:"generate_defender_events"`
	// EntriesSoftLimit and EntriesHardLimit are used only by per-source rate
	// limiters and bound the number of per-source entries kept in memory:
	// once EntriesHardLimit entries are tracked, entries are evicted until
	// EntriesSoftLimit remain. For per-source limiters EntriesSoftLimit must
	// be > 0 and EntriesHardLimit must be greater than EntriesSoftLimit.
	EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
	EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
}
- func (r *RateLimiterConfig) isEnabled() bool {
- return r.Average > 0 && len(r.Protocols) > 0
- }
- func (r *RateLimiterConfig) validate() error {
- if r.Burst < 1 {
- return fmt.Errorf("invalid burst %v. It must be >= 1", r.Burst)
- }
- if r.Period < 100 {
- return fmt.Errorf("invalid period %v. It must be >= 100", r.Period)
- }
- if r.Type != int(rateLimiterTypeGlobal) && r.Type != int(rateLimiterTypeSource) {
- return fmt.Errorf("invalid type %v", r.Type)
- }
- if r.Type != int(rateLimiterTypeGlobal) {
- if r.EntriesSoftLimit <= 0 {
- return fmt.Errorf("invalid entries_soft_limit %v", r.EntriesSoftLimit)
- }
- if r.EntriesHardLimit <= r.EntriesSoftLimit {
- return fmt.Errorf("invalid entries_hard_limit %v must be > %v", r.EntriesHardLimit, r.EntriesSoftLimit)
- }
- }
- r.Protocols = util.RemoveDuplicates(r.Protocols, true)
- for _, protocol := range r.Protocols {
- if !util.Contains(rateLimiterProtocolValues, protocol) {
- return fmt.Errorf("invalid protocol %q", protocol)
- }
- }
- return nil
- }
- func (r *RateLimiterConfig) getLimiter() *rateLimiter {
- limiter := &rateLimiter{
- burst: r.Burst,
- globalBucket: nil,
- generateDefenderEvents: r.GenerateDefenderEvents,
- }
- var maxDelay time.Duration
- period := time.Duration(r.Period) * time.Millisecond
- rtl := float64(r.Average*int64(time.Second)) / float64(period)
- limiter.rate = rate.Limit(rtl)
- if rtl < 1 {
- maxDelay = period / 2
- } else {
- maxDelay = time.Second / (time.Duration(rtl) * 2)
- }
- if maxDelay > 10*time.Second {
- maxDelay = 10 * time.Second
- }
- limiter.maxDelay = maxDelay
- limiter.buckets = sourceBuckets{
- buckets: make(map[string]sourceRateLimiter),
- hardLimit: r.EntriesHardLimit,
- softLimit: r.EntriesSoftLimit,
- }
- if r.Type != int(rateLimiterTypeSource) {
- limiter.globalBucket = rate.NewLimiter(limiter.rate, limiter.burst)
- }
- return limiter
- }
- // RateLimiter defines a rate limiter
- type rateLimiter struct {
- rate rate.Limit
- burst int
- maxDelay time.Duration
- globalBucket *rate.Limiter
- buckets sourceBuckets
- generateDefenderEvents bool
- }
- // Wait blocks until the limit allows one event to happen
- // or returns an error if the time to wait exceeds the max
- // allowed delay
- func (rl *rateLimiter) Wait(source, protocol string) (time.Duration, error) {
- var res *rate.Reservation
- if rl.globalBucket != nil {
- res = rl.globalBucket.Reserve()
- } else {
- var err error
- res, err = rl.buckets.reserve(source)
- if err != nil {
- rateLimiter := rate.NewLimiter(rl.rate, rl.burst)
- res = rl.buckets.addAndReserve(rateLimiter, source)
- }
- }
- if !res.OK() {
- return 0, errReserve
- }
- delay := res.Delay()
- if delay > rl.maxDelay {
- res.Cancel()
- if rl.generateDefenderEvents && rl.globalBucket == nil {
- AddDefenderEvent(source, protocol, HostEventLimitExceeded)
- }
- return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
- }
- time.Sleep(delay)
- return 0, nil
- }
- type sourceRateLimiter struct {
- lastActivity *atomic.Int64
- bucket *rate.Limiter
- }
- func (s *sourceRateLimiter) updateLastActivity() {
- s.lastActivity.Store(time.Now().UnixNano())
- }
- func (s *sourceRateLimiter) getLastActivity() int64 {
- return s.lastActivity.Load()
- }
- type sourceBuckets struct {
- sync.RWMutex
- buckets map[string]sourceRateLimiter
- hardLimit int
- softLimit int
- }
- func (b *sourceBuckets) reserve(source string) (*rate.Reservation, error) {
- b.RLock()
- defer b.RUnlock()
- if src, ok := b.buckets[source]; ok {
- src.updateLastActivity()
- return src.bucket.Reserve(), nil
- }
- return nil, errNoBucket
- }
- func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Reservation {
- b.Lock()
- defer b.Unlock()
- b.cleanup()
- src := sourceRateLimiter{
- lastActivity: new(atomic.Int64),
- bucket: r,
- }
- src.updateLastActivity()
- b.buckets[source] = src
- return src.bucket.Reserve()
- }
- func (b *sourceBuckets) cleanup() {
- if len(b.buckets) >= b.hardLimit {
- numToRemove := len(b.buckets) - b.softLimit
- kvList := make(kvList, 0, len(b.buckets))
- for k, v := range b.buckets {
- kvList = append(kvList, kv{
- Key: k,
- Value: v.getLastActivity(),
- })
- }
- sort.Sort(kvList)
- for idx, kv := range kvList {
- if idx >= numToRemove {
- break
- }
- delete(b.buckets, kv.Key)
- }
- }
- }
|