瀏覽代碼

add rate limiting support for REST API/web admin too

Nicola Murino 4 年之前
父節點
當前提交
f45c89fc46

+ 5 - 4
common/common.go

@@ -70,6 +70,7 @@ const (
 	ProtocolSSH    = "SSH"
 	ProtocolFTP    = "FTP"
 	ProtocolWebDAV = "DAV"
+	ProtocolHTTP   = "HTTP"
 )
 
 // Upload modes
@@ -144,14 +145,14 @@ func Initialize(c Configuration) error {
 // allow one event to happen.
 // It returns an error if the time to wait exceeds the max
 // allowed delay
-func LimitRate(protocol, ip string) error {
+func LimitRate(protocol, ip string) (time.Duration, error) {
 	for _, limiter := range rateLimiters[protocol] {
-		if err := limiter.Wait(ip); err != nil {
+		if delay, err := limiter.Wait(ip); err != nil {
 			logger.Debug(logSender, "", "protocol %v ip %v: %v", protocol, ip, err)
-			return err
+			return delay, err
 		}
 	}
-	return nil
+	return 0, nil
 }
 
 // ReloadDefender reloads the defender's block and safe lists

+ 9 - 8
common/common_test.go

@@ -194,30 +194,31 @@ func TestRateLimitersIntegration(t *testing.T) {
 	err = Initialize(Config)
 	assert.NoError(t, err)
 
-	assert.Len(t, rateLimiters, 3)
+	assert.Len(t, rateLimiters, 4)
 	assert.Len(t, rateLimiters[ProtocolSSH], 1)
 	assert.Len(t, rateLimiters[ProtocolFTP], 2)
 	assert.Len(t, rateLimiters[ProtocolWebDAV], 2)
+	assert.Len(t, rateLimiters[ProtocolHTTP], 1)
 
 	source1 := "127.1.1.1"
 	source2 := "127.1.1.2"
 
-	err = LimitRate(ProtocolSSH, source1)
+	_, err = LimitRate(ProtocolSSH, source1)
 	assert.NoError(t, err)
-	err = LimitRate(ProtocolFTP, source1)
+	_, err = LimitRate(ProtocolFTP, source1)
 	assert.NoError(t, err)
 	// sleep to allow the add configured burst to the token.
 	// This sleep is not enough to add the per-source burst
 	time.Sleep(20 * time.Millisecond)
-	err = LimitRate(ProtocolWebDAV, source2)
+	_, err = LimitRate(ProtocolWebDAV, source2)
 	assert.NoError(t, err)
-	err = LimitRate(ProtocolFTP, source1)
+	_, err = LimitRate(ProtocolFTP, source1)
 	assert.Error(t, err)
-	err = LimitRate(ProtocolWebDAV, source2)
+	_, err = LimitRate(ProtocolWebDAV, source2)
 	assert.Error(t, err)
-	err = LimitRate(ProtocolSSH, source1)
+	_, err = LimitRate(ProtocolSSH, source1)
 	assert.NoError(t, err)
-	err = LimitRate(ProtocolSSH, source2)
+	_, err = LimitRate(ProtocolSSH, source2)
 	assert.NoError(t, err)
 
 	Config = configCopy

+ 5 - 5
common/ratelimiter.go

@@ -16,7 +16,7 @@ import (
 var (
 	errNoBucket               = errors.New("no bucket found")
 	errReserve                = errors.New("unable to reserve token")
-	rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV}
+	rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP}
 )
 
 // RateLimiterType defines the supported rate limiters types
@@ -130,7 +130,7 @@ type rateLimiter struct {
 // Wait blocks until the limit allows one event to happen
 // or returns an error if the time to wait exceeds the max
 // allowed delay
-func (rl *rateLimiter) Wait(source string) error {
+func (rl *rateLimiter) Wait(source string) (time.Duration, error) {
 	var res *rate.Reservation
 	if rl.globalBucket != nil {
 		res = rl.globalBucket.Reserve()
@@ -143,7 +143,7 @@ func (rl *rateLimiter) Wait(source string) error {
 		}
 	}
 	if !res.OK() {
-		return errReserve
+		return 0, errReserve
 	}
 	delay := res.Delay()
 	if delay > rl.maxDelay {
@@ -151,10 +151,10 @@ func (rl *rateLimiter) Wait(source string) error {
 		if rl.generateDefenderEvents && rl.globalBucket == nil {
 			AddDefenderEvent(source, HostEventRateExceeded)
 		}
-		return fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
+		return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
 	}
 	time.Sleep(delay)
-	return nil
+	return 0, nil
 }
 
 type sourceRateLimiter struct {

+ 10 - 10
common/ratelimiter_test.go

@@ -63,9 +63,9 @@ func TestRateLimiter(t *testing.T) {
 		Protocols: rateLimiterProtocolValues,
 	}
 	limiter := config.getLimiter()
-	err := limiter.Wait("")
+	_, err := limiter.Wait("")
 	require.NoError(t, err)
-	err = limiter.Wait("")
+	_, err = limiter.Wait("")
 	require.Error(t, err)
 
 	config.Type = int(rateLimiterTypeSource)
@@ -75,17 +75,17 @@ func TestRateLimiter(t *testing.T) {
 	limiter = config.getLimiter()
 
 	source := "192.168.1.2"
-	err = limiter.Wait(source)
+	_, err = limiter.Wait(source)
 	require.NoError(t, err)
-	err = limiter.Wait(source)
+	_, err = limiter.Wait(source)
 	require.Error(t, err)
 	// a different source should work
-	err = limiter.Wait(source + "1")
+	_, err = limiter.Wait(source + "1")
 	require.NoError(t, err)
 
 	config.Burst = 0
 	limiter = config.getLimiter()
-	err = limiter.Wait(source)
+	_, err = limiter.Wait(source)
 	require.ErrorIs(t, err, errReserve)
 }
 
@@ -104,10 +104,10 @@ func TestLimiterCleanup(t *testing.T) {
 	source2 := "10.8.0.2"
 	source3 := "10.8.0.3"
 	source4 := "10.8.0.4"
-	err := limiter.Wait(source1)
+	_, err := limiter.Wait(source1)
 	assert.NoError(t, err)
 	time.Sleep(20 * time.Millisecond)
-	err = limiter.Wait(source2)
+	_, err = limiter.Wait(source2)
 	assert.NoError(t, err)
 	time.Sleep(20 * time.Millisecond)
 	assert.Len(t, limiter.buckets.buckets, 2)
@@ -115,7 +115,7 @@ func TestLimiterCleanup(t *testing.T) {
 	assert.True(t, ok)
 	_, ok = limiter.buckets.buckets[source2]
 	assert.True(t, ok)
-	err = limiter.Wait(source3)
+	_, err = limiter.Wait(source3)
 	assert.NoError(t, err)
 	assert.Len(t, limiter.buckets.buckets, 3)
 	_, ok = limiter.buckets.buckets[source1]
@@ -125,7 +125,7 @@ func TestLimiterCleanup(t *testing.T) {
 	_, ok = limiter.buckets.buckets[source3]
 	assert.True(t, ok)
 	time.Sleep(20 * time.Millisecond)
-	err = limiter.Wait(source4)
+	_, err = limiter.Wait(source4)
 	assert.NoError(t, err)
 	assert.Len(t, limiter.buckets.buckets, 2)
 	_, ok = limiter.buckets.buckets[source3]

+ 1 - 1
config/config.go

@@ -74,7 +74,7 @@ var (
 		Period:                 1000,
 		Burst:                  1,
 		Type:                   2,
-		Protocols:              []string{common.ProtocolSSH, common.ProtocolFTP, common.ProtocolWebDAV},
+		Protocols:              []string{common.ProtocolSSH, common.ProtocolFTP, common.ProtocolWebDAV, common.ProtocolHTTP},
 		GenerateDefenderEvents: false,
 		EntriesSoftLimit:       100,
 		EntriesHardLimit:       150,

+ 2 - 1
config/config_test.go

@@ -474,10 +474,11 @@ func TestRateLimitersFromEnv(t *testing.T) {
 	require.Equal(t, 1, limiters[1].Burst)
 	require.Equal(t, 2, limiters[1].Type)
 	protocols = limiters[1].Protocols
-	require.Len(t, protocols, 3)
+	require.Len(t, protocols, 4)
 	require.True(t, utils.IsStringInSlice(common.ProtocolFTP, protocols))
 	require.True(t, utils.IsStringInSlice(common.ProtocolSSH, protocols))
 	require.True(t, utils.IsStringInSlice(common.ProtocolWebDAV, protocols))
+	require.True(t, utils.IsStringInSlice(common.ProtocolHTTP, protocols))
 	require.False(t, limiters[1].GenerateDefenderEvents)
 	require.Equal(t, 100, limiters[1].EntriesSoftLimit)
 	require.Equal(t, 150, limiters[1].EntriesHardLimit)

+ 1 - 1
docs/full-configuration.md

@@ -83,7 +83,7 @@ The configuration file contains the following sections:
     - `period`, integer. Period defines the period as milliseconds. The rate is actually defined by dividing average by period Default: 1000 (1 second).
     - `burst`, integer. Burst defines the maximum number of requests allowed to go through in the same arbitrarily small period of time. Default: 1
     - `type`, integer. 1 means a global rate limiter, independent from the source host. 2 means a per-ip rate limiter. Default: 2
-    - `protocols`, list of strings. Available protocols are `SSH`, `FTP`, `DAV`. By default all supported protocols are enabled
+    - `protocols`, list of strings. Available protocols are `SSH`, `FTP`, `DAV`, `HTTP`. By default all supported protocols are enabled
     - `generate_defender_events`, boolean. If `true`, the defender is enabled, and this is not a global rate limiter, a new defender event will be generated each time the configured limit is exceeded. Default `false`
     - `entries_soft_limit`, integer.
     - `entries_hard_limit`, integer. The number of per-ip rate limiters kept in memory will vary between the soft and hard limit

+ 12 - 4
docs/rate-limiting.md

@@ -1,6 +1,6 @@
 # Rate limiting
 
-Rate limiting allows to control the number of requests going to the configured services.
+Rate limiting allows to control the number of requests going to the SFTPGo services.
 
 SFTPGo implements a [token bucket](https://en.wikipedia.org/wiki/Token_bucket) initially full and refilled at the configured rate. The `burst` configuration parameter defines the size of the bucket. The rate is defined by dividing `average` by `period`, so for a rate below 1 req/s, one needs to define a period larger than a second.
 
@@ -8,9 +8,16 @@ Requests that exceed the configured limit will be delayed or denied if they exce
 
 SFTPGo allows to define per-protocol rate limiters so you can have different configurations for different protocols.
 
+The supported protocols are:
+
+- `SSH`, includes SFTP and SSH commands
+- `FTP`, includes FTP, FTPES, FTPS
+- `DAV`, WebDAV
+- `HTTP`, REST API and web admin
+
 You can also define two types of rate limiters:
 
-- global, it is independent from the source host and therefore define a limit for the configured protocol/s
+- global, it is independent from the source host and therefore define an aggregate limit for the configured protocol/s
 - per-host, this type of rate limiter can be connected to the built-in [defender](./defender.md) and generate `score_rate_exceeded` events and thus hosts that repeatedly exceed the configured limit can be automatically blocked
 
 If you configure a per-host rate limiter, SFTPGo will keep a rate limiter in memory for each host that connects to the service, you can limit the memory usage using the `entries_soft_limit` and `entries_hard_limit` configuration keys.
@@ -27,7 +34,8 @@ You can defines how many rate limiters as you want, but keep in mind that if you
       "protocols": [
         "SSH",
         "FTP",
-        "DAV"
+        "DAV",
+        "HTTP"
       ],
       "generate_defender_events": false,
       "entries_soft_limit": 100,
@@ -48,6 +56,6 @@ You can defines how many rate limiters as you want, but keep in mind that if you
 ]
 ```
 
-we have a global rate limiter that limit the rate for the whole service to 100 req/s and an additional rate limiter that limits the `FTP` protocol to 10 req/s per host.
+we have a global rate limiter that limit the aggregate rate for the all the services to 100 req/s and an additional rate limiter that limits the `FTP` protocol to 10 req/s per host.
 With this configuration, when a client connects via FTP it will be limited first by the global rate limiter and then by the per host rate limiter.
 Clients connecting via SFTP/WebDAV will be checked only against the global rate limiter.

+ 2 - 1
ftpd/server.go

@@ -144,7 +144,8 @@ func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
 		logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
 		return "Access denied: max allowed connection exceeded", common.ErrConnectionDenied
 	}
-	if err := common.LimitRate(common.ProtocolFTP, ipAddr); err != nil {
+	_, err := common.LimitRate(common.ProtocolFTP, ipAddr)
+	if err != nil {
 		return fmt.Sprintf("Access denied: %v", err.Error()), err
 	}
 	if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolFTP); err != nil {

+ 38 - 0
httpd/httpd_test.go

@@ -3118,6 +3118,44 @@ func TestLoaddataMode(t *testing.T) {
 	assert.NoError(t, err)
 }
 
+func TestRateLimiter(t *testing.T) {
+	oldConfig := config.GetCommonConfig()
+
+	cfg := config.GetCommonConfig()
+	cfg.RateLimitersConfig = []common.RateLimiterConfig{
+		{
+			Average:   1,
+			Period:    1000,
+			Burst:     1,
+			Type:      1,
+			Protocols: []string{common.ProtocolHTTP},
+		},
+	}
+
+	err := common.Initialize(cfg)
+	assert.NoError(t, err)
+
+	client := &http.Client{
+		Timeout: 5 * time.Second,
+	}
+	resp, err := client.Get(httpBaseURL + healthzPath)
+	assert.NoError(t, err)
+	assert.Equal(t, http.StatusOK, resp.StatusCode)
+	err = resp.Body.Close()
+	assert.NoError(t, err)
+
+	resp, err = client.Get(httpBaseURL + healthzPath)
+	assert.NoError(t, err)
+	assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+	assert.NotEmpty(t, resp.Header.Get("Retry-After"))
+	assert.NotEmpty(t, resp.Header.Get("X-Retry-In"))
+	err = resp.Body.Close()
+	assert.NoError(t, err)
+
+	err = common.Initialize(oldConfig)
+	assert.NoError(t, err)
+}
+
 func TestHTTPSConnection(t *testing.T) {
 	client := &http.Client{
 		Timeout: 5 * time.Second,

+ 14 - 0
httpd/middleware.go

@@ -3,11 +3,13 @@ package httpd
 import (
 	"context"
 	"errors"
+	"fmt"
 	"net/http"
 
 	"github.com/go-chi/jwtauth/v5"
 	"github.com/lestrrat-go/jwx/jwt"
 
+	"github.com/drakkan/sftpgo/common"
 	"github.com/drakkan/sftpgo/logger"
 	"github.com/drakkan/sftpgo/utils"
 )
@@ -141,3 +143,15 @@ func verifyCSRFHeader(next http.Handler) http.Handler {
 		next.ServeHTTP(w, r)
 	})
 }
+
+func rateLimiter(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if delay, err := common.LimitRate(common.ProtocolHTTP, utils.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil {
+			w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds()))
+			w.Header().Set("X-Retry-In", delay.String())
+			sendAPIResponse(w, r, err, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
+			return
+		}
+		next.ServeHTTP(w, r)
+	})
+}

+ 2 - 1
httpd/server.go

@@ -259,6 +259,8 @@ func (s *httpdServer) initializeRouter() {
 	s.router.Use(saveConnectionAddress)
 	s.router.Use(middleware.GetHead)
 	s.router.Use(middleware.StripSlashes)
+	s.router.Use(middleware.RealIP)
+	s.router.Use(rateLimiter)
 
 	s.router.Group(func(r chi.Router) {
 		r.Get(healthzPath, func(w http.ResponseWriter, r *http.Request) {
@@ -268,7 +270,6 @@ func (s *httpdServer) initializeRouter() {
 
 	s.router.Group(func(router chi.Router) {
 		router.Use(middleware.RequestID)
-		router.Use(middleware.RealIP)
 		router.Use(logger.NewStructuredLogger(logger.GetLogger()))
 		router.Use(middleware.Recoverer)
 

+ 2 - 1
sftpd/server.go

@@ -360,7 +360,8 @@ func canAcceptConnection(ip string) bool {
 		logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached")
 		return false
 	}
-	if err := common.LimitRate(common.ProtocolSSH, ip); err != nil {
+	_, err := common.LimitRate(common.ProtocolSSH, ip)
+	if err != nil {
 		return false
 	}
 	if err := common.Config.ExecutePostConnectHook(ip, common.ProtocolSSH); err != nil {

+ 2 - 1
sftpgo.json

@@ -35,7 +35,8 @@
         "protocols": [
           "SSH",
           "FTP",
-          "DAV"
+          "DAV",
+          "HTTP"
         ],
         "generate_defender_events": false,
         "entries_soft_limit": 100,

+ 4 - 1
webdavd/server.go

@@ -158,7 +158,10 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
 		return
 	}
-	if err := common.LimitRate(common.ProtocolWebDAV, ipAddr); err != nil {
+	delay, err := common.LimitRate(common.ProtocolWebDAV, ipAddr)
+	if err != nil {
+		w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds()))
+		w.Header().Set("X-Retry-In", delay.String())
 		http.Error(w, err.Error(), http.StatusTooManyRequests)
 		return
 	}