Browse Source

ssh: add username to sftp auth errors

Signed-off-by: Nicola Murino <[email protected]>
Nicola Murino 1 năm trước cách đây
mục cha
commit
be2ed1089c
4 tập tin đã thay đổi với 25 bổ sung16 xóa
  1. 1 1
      go.mod
  2. 2 2
      go.sum
  3. 8 5
      internal/sftpd/internal_test.go
  4. 14 8
      internal/sftpd/server.go

+ 1 - 1
go.mod

@@ -187,5 +187,5 @@ replace (
 	github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20240210102745-f1ffc43f78d2
 	github.com/pkg/sftp => github.com/drakkan/sftp v0.0.0-20240214104840-fbb0b8bdb30c
 	github.com/robfig/cron/v3 => github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0
-	golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20240224191538-9f4629f0732c
+	golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20240225143451-17f1441a9706
 )

+ 2 - 2
go.sum

@@ -109,8 +109,8 @@ github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
 github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
 github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 h1:EW9gIJRmt9lzk66Fhh4S8VEtURA6QHZqGeSRE9Nb2/U=
 github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
-github.com/drakkan/crypto v0.0.0-20240224191538-9f4629f0732c h1:Jwxc0vhD7t484x2GDv8B5WbvrSwxML9V+A8esl7vkh8=
-github.com/drakkan/crypto v0.0.0-20240224191538-9f4629f0732c/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
+github.com/drakkan/crypto v0.0.0-20240225143451-17f1441a9706 h1:udyaktWAkqPhMN9iVkpZ4kxdW7clkq7Ia4etPilRMRk=
+github.com/drakkan/crypto v0.0.0-20240225143451-17f1441a9706/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
 github.com/drakkan/ftp v0.0.0-20240210102745-f1ffc43f78d2 h1:ufiGMPFBjndWSQOst9FNP11IuMqPblI2NXbpRMUWNhk=
 github.com/drakkan/ftp v0.0.0-20240210102745-f1ffc43f78d2/go.mod h1:4p8lUl4vQ80L598CygL+3IFtm+3nggvvW/palOlViwE=
 github.com/drakkan/ftpserverlib v0.0.0-20240212100826-a241365cb085 h1:LAKYR9z9USKeyEQK91sRWldmMOjEHLOt2NuLDx+x1UQ=

+ 8 - 5
internal/sftpd/internal_test.go

@@ -2227,23 +2227,26 @@ func TestCanReadSymlink(t *testing.T) {
 
 func TestAuthenticationErrors(t *testing.T) {
 	loginMethod := dataprovider.SSHLoginMethodPassword
-	err := newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", util.NewRecordNotFoundError("not found")), loginMethod)
+	username := "test user"
+	err := newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", util.NewRecordNotFoundError("not found")),
+		loginMethod, username)
 	assert.ErrorIs(t, err, sftpAuthError)
 	assert.ErrorIs(t, err, util.ErrNotFound)
 	var sftpAuthErr *authenticationError
 	if assert.ErrorAs(t, err, &sftpAuthErr) {
 		assert.Equal(t, loginMethod, sftpAuthErr.getLoginMethod())
+		assert.Equal(t, username, sftpAuthErr.getUsername())
 	}
-	err = newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", fs.ErrPermission), loginMethod)
+	err = newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", fs.ErrPermission), loginMethod, username)
 	assert.ErrorIs(t, err, sftpAuthError)
 	assert.NotErrorIs(t, err, util.ErrNotFound)
-	err = newAuthenticationError(fmt.Errorf("cert has wrong type %d", ssh.HostCert), loginMethod)
+	err = newAuthenticationError(fmt.Errorf("cert has wrong type %d", ssh.HostCert), loginMethod, username)
 	assert.ErrorIs(t, err, sftpAuthError)
 	assert.NotErrorIs(t, err, util.ErrNotFound)
-	err = newAuthenticationError(errors.New("ssh: certificate signed by unrecognized authority"), loginMethod)
+	err = newAuthenticationError(errors.New("ssh: certificate signed by unrecognized authority"), loginMethod, username)
 	assert.ErrorIs(t, err, sftpAuthError)
 	assert.NotErrorIs(t, err, util.ErrNotFound)
-	err = newAuthenticationError(nil, loginMethod)
+	err = newAuthenticationError(nil, loginMethod, username)
 	assert.ErrorIs(t, err, sftpAuthError)
 	assert.NotErrorIs(t, err, util.ErrNotFound)
 }

+ 14 - 8
internal/sftpd/server.go

@@ -77,7 +77,7 @@ var (
 		certs: map[string]bool{},
 	}
 
-	sftpAuthError = newAuthenticationError(nil, "")
+	sftpAuthError = newAuthenticationError(nil, "", "")
 )
 
 // Binding defines the configuration for a network listener
@@ -192,6 +192,7 @@ type Configuration struct {
 type authenticationError struct {
 	err         error
 	loginMethod string
+	username    string
 }
 
 func (e *authenticationError) Error() string {
@@ -213,8 +214,12 @@ func (e *authenticationError) getLoginMethod() string {
 	return e.loginMethod
 }
 
-func newAuthenticationError(err error, loginMethod string) *authenticationError {
-	return &authenticationError{err: err, loginMethod: loginMethod}
+func (e *authenticationError) getUsername() string {
+	return e.username
+}
+
+func newAuthenticationError(err error, loginMethod, username string) *authenticationError {
+	return &authenticationError{err: err, loginMethod: loginMethod, username: username}
 }
 
 // ShouldBind returns true if there is at least a valid binding
@@ -240,7 +245,7 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
 			}
 			if err != nil {
 				return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err),
-					dataprovider.SSHLoginMethodPublicKey)
+					dataprovider.SSHLoginMethodPublicKey, conn.User())
 			}
 
 			return sp, nil
@@ -751,7 +756,8 @@ func discardAllChannels(in <-chan ssh.NewChannel, message, connectionID string)
 }
 
 func checkAuthError(ip string, err error) {
-	if authErrors, ok := err.(*ssh.ServerAuthError); ok {
+	var authErrors *ssh.ServerAuthError
+	if errors.As(err, &authErrors) {
 		// check public key auth errors here
 		for _, err := range authErrors.Errors {
 			var sftpAuthErr *authenticationError
@@ -764,7 +770,7 @@ func checkAuthError(ip string, err error) {
 						logEv = notifier.LogEventTypeLoginNoUser
 					}
 					common.AddDefenderEvent(ip, common.ProtocolSSH, event)
-					plugin.Handler.NotifyLogEvent(logEv, common.ProtocolSSH, "", ip, "", err)
+					plugin.Handler.NotifyLogEvent(logEv, common.ProtocolSSH, sftpAuthErr.getUsername(), ip, "", err)
 					return
 				}
 			}
@@ -1215,7 +1221,7 @@ func (c *Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass
 	user.Username = conn.User()
 	updateLoginMetrics(&user, ipAddr, method, err)
 	if err != nil {
-		return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err), method)
+		return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err), method, conn.User())
 	}
 	return sshPerm, nil
 }
@@ -1235,7 +1241,7 @@ func (c *Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMeta
 	user.Username = conn.User()
 	updateLoginMetrics(&user, ipAddr, method, err)
 	if err != nil {
-		return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err), method)
+		return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err), method, conn.User())
 	}
 	return sshPerm, nil
 }