Browse Source

sftpd: fix duplicate defender error introduced in the previous commit

improve the defender test cases by verifying the expected score

Signed-off-by: Nicola Murino <[email protected]>
Nicola Murino 2 years ago
parent
commit
9c9c9fa3a5

+ 27 - 3
internal/ftpd/ftpd_test.go

@@ -1724,8 +1724,10 @@ func TestDefender(t *testing.T) {
 
 	cfg := config.GetCommonConfig()
 	cfg.DefenderConfig.Enabled = true
-	cfg.DefenderConfig.Threshold = 3
+	cfg.DefenderConfig.Threshold = 4
 	cfg.DefenderConfig.ScoreLimitExceeded = 2
+	cfg.DefenderConfig.ScoreNoAuth = 1
+	cfg.DefenderConfig.ScoreValid = 1
 
 	err := common.Initialize(cfg, 0)
 	assert.NoError(t, err)
@@ -1739,9 +1741,31 @@ func TestDefender(t *testing.T) {
 		err = client.Quit()
 		assert.NoError(t, err)
 	}
+	// just dial without login
+	ftpOptions := []ftp.DialOption{ftp.DialWithTimeout(5 * time.Second)}
+	client, err = ftp.Dial(ftpServerAddr, ftpOptions...)
+	assert.NoError(t, err)
+	err = client.Quit()
+	assert.NoError(t, err)
+	hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
+	assert.NoError(t, err)
+	if assert.Len(t, hosts, 1) {
+		host := hosts[0]
+		assert.Empty(t, host.GetBanTime())
+		assert.Equal(t, 1, host.Score)
+	}
+	user.Password = "wrong_pwd"
+	_, err = getFTPClient(user, false, nil)
+	assert.Error(t, err)
+	hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
+	assert.NoError(t, err)
+	if assert.Len(t, hosts, 1) {
+		host := hosts[0]
+		assert.Empty(t, host.GetBanTime())
+		assert.Equal(t, 2, host.Score)
+	}
 
-	for i := 0; i < 3; i++ {
-		user.Password = "wrong_pwd"
+	for i := 0; i < 2; i++ {
 		_, err = getFTPClient(user, false, nil)
 		assert.Error(t, err)
 	}

+ 10 - 5
internal/sftpd/internal_test.go

@@ -2301,19 +2301,24 @@ func TestCanReadSymlink(t *testing.T) {
 }
 
 func TestAuthenticationErrors(t *testing.T) {
-	err := newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", util.NewRecordNotFoundError("not found")))
+	loginMethod := dataprovider.SSHLoginMethodPassword
+	err := newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", util.NewRecordNotFoundError("not found")), loginMethod)
 	assert.ErrorIs(t, err, sftpAuthError)
 	assert.ErrorIs(t, err, util.ErrNotFound)
-	err = newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", fs.ErrPermission))
+	var sftpAuthErr *authenticationError
+	if assert.ErrorAs(t, err, &sftpAuthErr) {
+		assert.Equal(t, loginMethod, sftpAuthErr.getLoginMethod())
+	}
+	err = newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", fs.ErrPermission), loginMethod)
 	assert.ErrorIs(t, err, sftpAuthError)
 	assert.NotErrorIs(t, err, util.ErrNotFound)
-	err = newAuthenticationError(fmt.Errorf("cert has wrong type %d", ssh.HostCert))
+	err = newAuthenticationError(fmt.Errorf("cert has wrong type %d", ssh.HostCert), loginMethod)
 	assert.ErrorIs(t, err, sftpAuthError)
 	assert.NotErrorIs(t, err, util.ErrNotFound)
-	err = newAuthenticationError(errors.New("ssh: certificate signed by unrecognized authority"))
+	err = newAuthenticationError(errors.New("ssh: certificate signed by unrecognized authority"), loginMethod)
 	assert.ErrorIs(t, err, sftpAuthError)
 	assert.NotErrorIs(t, err, util.ErrNotFound)
-	err = newAuthenticationError(nil)
+	err = newAuthenticationError(nil, loginMethod)
 	assert.ErrorIs(t, err, sftpAuthError)
 	assert.NotErrorIs(t, err, util.ErrNotFound)
 }

+ 27 - 15
internal/sftpd/server.go

@@ -93,7 +93,7 @@ var (
 		certs: map[string]bool{},
 	}
 
-	sftpAuthError = newAuthenticationError(nil)
+	sftpAuthError = newAuthenticationError(nil, "")
 )
 
 // Binding defines the configuration for a network listener
@@ -210,7 +210,8 @@ type Configuration struct {
 }
 
 type authenticationError struct {
-	err error
+	err         error
+	loginMethod string
 }
 
 func (e *authenticationError) Error() string {
@@ -228,8 +229,12 @@ func (e *authenticationError) Unwrap() error {
 	return e.err
 }
 
-func newAuthenticationError(err error) *authenticationError {
-	return &authenticationError{err: err}
+func (e *authenticationError) getLoginMethod() string {
+	return e.loginMethod
+}
+
+func newAuthenticationError(err error, loginMethod string) *authenticationError {
+	return &authenticationError{err: err, loginMethod: loginMethod}
 }
 
 // ShouldBind returns true if there is at least a valid binding
@@ -253,7 +258,8 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
 				return sp, err
 			}
 			if err != nil {
-				return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err))
+				return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err),
+					dataprovider.SSHLoginMethodPublicKey)
 			}
 
 			return sp, nil
@@ -273,7 +279,8 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
 		serverConfig.PasswordCallback = func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
 			sp, err := c.validatePasswordCredentials(conn, pass)
 			if err != nil {
-				return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err))
+				return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err),
+					dataprovider.SSHLoginMethodPassword)
 			}
 
 			return sp, nil
@@ -487,7 +494,8 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve
 	serverConfig.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
 		sp, err := c.validateKeyboardInteractiveCredentials(conn, client)
 		if err != nil {
-			return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err))
+			return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err),
+				dataprovider.SSHLoginMethodKeyboardInteractive)
 		}
 
 		return sp, nil
@@ -561,7 +569,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
 	}
 
 	logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
-		"User %#v logged in with %#v, from ip %#v, client version %#v", user.Username, loginType,
+		"User %q logged in with %q, from ip %q, client version %q", user.Username, loginType,
 		ipAddr, string(sconn.ClientVersion()))
 	dataprovider.UpdateLastLogin(&user)
 
@@ -683,16 +691,20 @@ func (c *Configuration) createHandlers(connection *Connection) sftp.Handlers {
 
 func checkAuthError(ip string, err error) {
 	if authErrors, ok := err.(*ssh.ServerAuthError); ok {
-		event := common.HostEventLoginFailed
+		// check public key auth errors here
 		for _, err := range authErrors.Errors {
-			if errors.Is(err, sftpAuthError) {
-				if errors.Is(err, util.ErrNotFound) {
-					event = common.HostEventUserNotFound
+			var sftpAuthErr *authenticationError
+			if errors.As(err, &sftpAuthErr) {
+				if sftpAuthErr.getLoginMethod() == dataprovider.SSHLoginMethodPublicKey {
+					event := common.HostEventLoginFailed
+					if errors.Is(err, util.ErrNotFound) {
+						event = common.HostEventUserNotFound
+					}
+					common.AddDefenderEvent(ip, event)
+					return
 				}
-				break
 			}
 		}
-		common.AddDefenderEvent(ip, event)
 	} else {
 		logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error())
 		metric.AddNoAuthTryed()
@@ -1078,7 +1090,7 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK
 				cert.KeyId, cert.Serial, cert.Type(), ssh.FingerprintSHA256(cert.SignatureKey))
 		}
 		if user.IsPartialAuth(method) {
-			logger.Debug(logSender, connectionID, "user %#v authenticated with partial success", conn.User())
+			logger.Debug(logSender, connectionID, "user %q authenticated with partial success", conn.User())
 			return certPerm, ssh.ErrPartialSuccess
 		}
 		sshPerm, err = loginUser(&user, method, keyID, conn)

+ 14 - 2
internal/sftpd/sftpd_test.go

@@ -966,6 +966,7 @@ func TestDefender(t *testing.T) {
 	cfg.DefenderConfig.Enabled = true
 	cfg.DefenderConfig.Threshold = 3
 	cfg.DefenderConfig.ScoreLimitExceeded = 2
+	cfg.DefenderConfig.ScoreValid = 1
 
 	err := common.Initialize(cfg, 0)
 	assert.NoError(t, err)
@@ -977,12 +978,23 @@ func TestDefender(t *testing.T) {
 	if assert.NoError(t, err) {
 		defer conn.Close()
 		defer client.Close()
+
 		err = checkBasicSFTP(client)
 		assert.NoError(t, err)
 	}
 
-	for i := 0; i < 3; i++ {
-		user.Password = "wrong_pwd"
+	user.Password = "wrong_pwd"
+	_, _, err = getSftpClient(user, usePubKey)
+	assert.Error(t, err)
+	hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
+	assert.NoError(t, err)
+	if assert.Len(t, hosts, 1) {
+		host := hosts[0]
+		assert.Empty(t, host.GetBanTime())
+		assert.Equal(t, 1, host.Score)
+	}
+
+	for i := 0; i < 2; i++ {
 		_, _, err = getSftpClient(user, usePubKey)
 		assert.Error(t, err)
 	}

+ 13 - 2
internal/webdavd/webdavd_test.go

@@ -995,6 +995,7 @@ func TestDefender(t *testing.T) {
 	cfg.DefenderConfig.Enabled = true
 	cfg.DefenderConfig.Threshold = 3
 	cfg.DefenderConfig.ScoreLimitExceeded = 2
+	cfg.DefenderConfig.ScoreValid = 1
 
 	err := common.Initialize(cfg, 0)
 	assert.NoError(t, err)
@@ -1004,8 +1005,18 @@ func TestDefender(t *testing.T) {
 	client := getWebDavClient(user, true, nil)
 	assert.NoError(t, checkBasicFunc(client))
 
-	for i := 0; i < 3; i++ {
-		user.Password = "wrong_pwd"
+	user.Password = "wrong_pwd"
+	client = getWebDavClient(user, false, nil)
+	assert.Error(t, checkBasicFunc(client))
+	hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
+	assert.NoError(t, err)
+	if assert.Len(t, hosts, 1) {
+		host := hosts[0]
+		assert.Empty(t, host.GetBanTime())
+		assert.Equal(t, 1, host.Score)
+	}
+
+	for i := 0; i < 2; i++ {
 		client = getWebDavClient(user, false, nil)
 		assert.Error(t, checkBasicFunc(client))
 	}