Procházet zdrojové kódy

add support for limiting max concurrent client connections

Nicola Murino před 4 roky
rodič
revize
f34462e3c3

+ 15 - 1
common/common.go

@@ -247,7 +247,9 @@ type Configuration struct {
 	// Absolute path to an external program or an HTTP URL to invoke after a user connects
 	// and before he tries to login. It allows you to reject the connection based on the source
 	// ip address. Leave empty do disable.
-	PostConnectHook       string `json:"post_connect_hook" mapstructure:"post_connect_hook"`
+	PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"`
+	// Maximum number of concurrent client connections. 0 means unlimited
+	MaxTotalConnections   int `json:"max_total_connections" mapstructure:"max_total_connections"`
 	idleTimeoutAsDuration time.Duration
 	idleLoginTimeout      time.Duration
 }
@@ -544,6 +546,18 @@ func (conns *ActiveConnections) checkIdles() {
 	conns.RUnlock()
 }
 
+// IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
+func (conns *ActiveConnections) IsNewConnectionAllowed() bool {
+	if Config.MaxTotalConnections == 0 {
+		return true
+	}
+
+	conns.RLock()
+	defer conns.RUnlock()
+
+	return len(conns.connections) < Config.MaxTotalConnections
+}
+
 // GetStats returns stats for active connections
 func (conns *ActiveConnections) GetStats() []ConnectionStatus {
 	conns.RLock()

+ 21 - 0
common/common_test.go

@@ -225,6 +225,26 @@ func TestSSHConnections(t *testing.T) {
 	assert.NoError(t, sshConn3.Close())
 }
 
+func TestMaxConnections(t *testing.T) {
+	oldValue := Config.MaxTotalConnections
+	Config.MaxTotalConnections = 1
+
+	assert.True(t, Connections.IsNewConnectionAllowed())
+	c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}, nil)
+	fakeConn := &fakeConnection{
+		BaseConnection: c,
+	}
+	Connections.Add(fakeConn)
+	assert.Len(t, Connections.GetStats(), 1)
+	assert.False(t, Connections.IsNewConnectionAllowed())
+
+	res := Connections.Close(fakeConn.GetID())
+	assert.True(t, res)
+	assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
+
+	Config.MaxTotalConnections = oldValue
+}
+
 func TestIdleConnections(t *testing.T) {
 	configCopy := Config
 
@@ -310,6 +330,7 @@ func TestCloseConnection(t *testing.T) {
 	fakeConn := &fakeConnection{
 		BaseConnection: c,
 	}
+	assert.True(t, Connections.IsNewConnectionAllowed())
 	Connections.Add(fakeConn)
 	assert.Len(t, Connections.GetStats(), 1)
 	res := Connections.Close(fakeConn.GetID())

+ 6 - 3
config/config.go

@@ -65,9 +65,11 @@ func Init() {
 				ExecuteOn: []string{},
 				Hook:      "",
 			},
-			SetstatMode:   0,
-			ProxyProtocol: 0,
-			ProxyAllowed:  []string{},
+			SetstatMode:         0,
+			ProxyProtocol:       0,
+			ProxyAllowed:        []string{},
+			PostConnectHook:     "",
+			MaxTotalConnections: 0,
 		},
 		SFTPD: sftpd.Configuration{
 			Banner:                  defaultSFTPDBanner,
@@ -413,6 +415,7 @@ func setViperDefaults() {
 	viper.SetDefault("common.proxy_protocol", globalConf.Common.ProxyProtocol)
 	viper.SetDefault("common.proxy_allowed", globalConf.Common.ProxyAllowed)
 	viper.SetDefault("common.post_connect_hook", globalConf.Common.PostConnectHook)
+	viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections)
 	viper.SetDefault("sftpd.bind_port", globalConf.SFTPD.BindPort)
 	viper.SetDefault("sftpd.bind_address", globalConf.SFTPD.BindAddress)
 	viper.SetDefault("sftpd.max_auth_tries", globalConf.SFTPD.MaxAuthTries)

+ 1 - 0
docs/full-configuration.md

@@ -63,6 +63,7 @@ The configuration file contains the following sections:
     - If `proxy_protocol` is set to 1 and we receive a proxy header from an IP that is not in the list then the connection will be accepted and the header will be ignored
     - If `proxy_protocol` is set to 2 and we receive a proxy header from an IP that is not in the list then the connection will be rejected
   - `post_connect_hook`, string. Absolute path to the command to execute or HTTP URL to notify. See [Post connect hook](./post-connect-hook.md) for more details. Leave empty to disable
+  - `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited
 - **"sftpd"**, the configuration for the SFTP server
   - `bind_port`, integer. The port used for serving SFTP requests. 0 means disabled. Default: 2022
   - `bind_address`, string. Leave blank to listen on all available network interfaces. Default: ""

+ 23 - 0
ftpd/ftpd_test.go

@@ -502,6 +502,29 @@ func TestPostConnectHook(t *testing.T) {
 	common.Config.PostConnectHook = ""
 }
 
+func TestMaxConnections(t *testing.T) {
+	oldValue := common.Config.MaxTotalConnections
+	common.Config.MaxTotalConnections = 1
+
+	user, _, err := httpd.AddUser(getTestUser(), http.StatusOK)
+	assert.NoError(t, err)
+	client, err := getFTPClient(user, true)
+	if assert.NoError(t, err) {
+		err = checkBasicFTP(client)
+		assert.NoError(t, err)
+		_, err = getFTPClient(user, false)
+		assert.Error(t, err)
+		err = client.Quit()
+		assert.NoError(t, err)
+	}
+	_, err = httpd.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+
+	common.Config.MaxTotalConnections = oldValue
+}
+
 func TestMaxSessions(t *testing.T) {
 	u := getTestUser()
 	u.MaxSessions = 1

+ 5 - 1
ftpd/server.go

@@ -98,8 +98,12 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) {
 
 // ClientConnected is called to send the very first welcome message
 func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
+	if !common.Connections.IsNewConnectionAllowed() {
+		logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
+		return "", common.ErrConnectionDenied
+	}
 	if err := common.Config.ExecutePostConnectHook(cc.RemoteAddr().String(), common.ProtocolFTP); err != nil {
-		return common.ErrConnectionDenied.Error(), err
+		return "", err
 	}
 	connID := fmt.Sprintf("%v", cc.ID())
 	user := dataprovider.User{}

+ 21 - 13
sftpd/server.go

@@ -277,23 +277,22 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
 			logger.Error(logSender, "", "panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
 		}
 	}()
+	if !common.Connections.IsNewConnectionAllowed() {
+		logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached")
+		conn.Close()
+		return
+	}
 	// Before beginning a handshake must be performed on the incoming net.Conn
 	// we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
 	conn.SetDeadline(time.Now().Add(handshakeTimeout)) //nolint:errcheck
-	remoteAddr := conn.RemoteAddr()
-	if err := common.Config.ExecutePostConnectHook(remoteAddr.String(), common.ProtocolSSH); err != nil {
+	if err := common.Config.ExecutePostConnectHook(conn.RemoteAddr().String(), common.ProtocolSSH); err != nil {
 		conn.Close()
 		return
 	}
 	sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
 	if err != nil {
 		logger.Debug(logSender, "", "failed to accept an incoming connection: %v", err)
-		if _, ok := err.(*ssh.ServerAuthError); !ok {
-			ip := utils.GetIPFromRemoteAddress(remoteAddr.String())
-			logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error())
-			metrics.AddNoAuthTryed()
-			dataprovider.ExecutePostLoginHook("", dataprovider.LoginMethodNoAuthTryed, ip, common.ProtocolSSH, err)
-		}
+		checkAuthError(conn, err)
 		return
 	}
 	// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
@@ -315,7 +314,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
 
 	logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
 		"User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
-		user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String())
+		user.ID, loginType, user.Username, user.HomeDir, conn.RemoteAddr().String())
 	dataprovider.UpdateLastLogin(user) //nolint:errcheck
 
 	sshConnection := common.NewSSHConnection(connectionID, conn)
@@ -354,13 +353,13 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
 				switch req.Type {
 				case "subsystem":
 					if string(req.Payload[4:]) == "sftp" {
-						fs, err := user.GetFilesystem(connectionID)
+						fs, err := user.GetFilesystem(connID)
 						if err == nil {
 							ok = true
 							connection := Connection{
 								BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user, fs),
 								ClientVersion:  string(sconn.ClientVersion()),
-								RemoteAddr:     remoteAddr,
+								RemoteAddr:     conn.RemoteAddr(),
 								channel:        channel,
 							}
 							go c.handleSftpConnection(channel, &connection)
@@ -368,12 +367,12 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
 					}
 				case "exec":
 					// protocol will be set later inside processSSHCommand it could be SSH or SCP
-					fs, err := user.GetFilesystem(connectionID)
+					fs, err := user.GetFilesystem(connID)
 					if err == nil {
 						connection := Connection{
 							BaseConnection: common.NewBaseConnection(connID, "sshd_exec", user, fs),
 							ClientVersion:  string(sconn.ClientVersion()),
-							RemoteAddr:     remoteAddr,
+							RemoteAddr:     conn.RemoteAddr(),
 							channel:        channel,
 						}
 						ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands)
@@ -420,6 +419,15 @@ func (c *Configuration) createHandler(connection *Connection) sftp.Handlers {
 	}
 }
 
+func checkAuthError(conn net.Conn, err error) {
+	if _, ok := err.(*ssh.ServerAuthError); !ok {
+		ip := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
+		logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error())
+		metrics.AddNoAuthTryed()
+		dataprovider.ExecutePostLoginHook("", dataprovider.LoginMethodNoAuthTryed, ip, common.ProtocolSSH, err)
+	}
+}
+
 func checkRootPath(user *dataprovider.User, connectionID string) error {
 	if user.FsConfig.Provider != dataprovider.SFTPFilesystemProvider {
 		// for sftp fs check root path does nothing so don't open a useless SFTP connection

+ 25 - 0
sftpd/sftpd_test.go

@@ -2441,6 +2441,31 @@ func TestQuotaDisabledError(t *testing.T) {
 	assert.NoError(t, err)
 }
 
+func TestMaxConnections(t *testing.T) {
+	oldValue := common.Config.MaxTotalConnections
+	common.Config.MaxTotalConnections = 1
+
+	usePubKey := true
+	u := getTestUser(usePubKey)
+	user, _, err := httpd.AddUser(u, http.StatusOK)
+	assert.NoError(t, err)
+	client, err := getSftpClient(user, usePubKey)
+	if assert.NoError(t, err) {
+		defer client.Close()
+		assert.NoError(t, checkBasicSFTP(client))
+		c, err := getSftpClient(user, usePubKey)
+		if !assert.Error(t, err, "max sessions exceeded, new login should not succeed") {
+			c.Close()
+		}
+	}
+	_, err = httpd.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+
+	common.Config.MaxTotalConnections = oldValue
+}
+
 func TestMaxSessions(t *testing.T) {
 	usePubKey := false
 	u := getTestUser(usePubKey)

+ 2 - 1
sftpgo.json

@@ -9,7 +9,8 @@
     "setstat_mode": 0,
     "proxy_protocol": 0,
     "proxy_allowed": [],
-    "post_connect_hook": ""
+    "post_connect_hook": "",
+    "max_total_connections": 0
   },
   "sftpd": {
     "bind_port": 2022,

+ 5 - 0
webdavd/server.go

@@ -112,6 +112,11 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 			http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError)
 		}
 	}()
+	if !common.Connections.IsNewConnectionAllowed() {
+		logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
+		http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable)
+		return
+	}
 	checkRemoteAddress(r)
 	if err := common.Config.ExecutePostConnectHook(r.RemoteAddr, common.ProtocolWebDAV); err != nil {
 		http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)

+ 25 - 0
webdavd/webdavd_test.go

@@ -650,6 +650,31 @@ func TestPostConnectHook(t *testing.T) {
 	common.Config.PostConnectHook = ""
 }
 
+func TestMaxConnections(t *testing.T) {
+	oldValue := common.Config.MaxTotalConnections
+	common.Config.MaxTotalConnections = 1
+
+	user, _, err := httpd.AddUser(getTestUser(), http.StatusOK)
+	assert.NoError(t, err)
+	client := getWebDavClient(user)
+	assert.NoError(t, checkBasicFunc(client))
+	// now add a fake connection
+	fs := vfs.NewOsFs("id", os.TempDir(), nil)
+	connection := &webdavd.Connection{
+		BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
+	}
+	common.Connections.Add(connection)
+	assert.Error(t, checkBasicFunc(client))
+	common.Connections.Remove(connection.GetID())
+	_, err = httpd.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+	assert.Len(t, common.Connections.GetStats(), 0)
+
+	common.Config.MaxTotalConnections = oldValue
+}
+
 func TestMaxSessions(t *testing.T) {
 	u := getTestUser()
 	u.MaxSessions = 1