瀏覽代碼

add support for partial authentication

Multi-step authentication is activated disabling all single-step
auth methods for a given user
Nicola Murino 5 年之前
父節點
當前提交
b1c7317cf6

+ 1 - 0
README.md

@@ -11,6 +11,7 @@ Fully featured and highly configurable SFTP server, written in Go
 - SQLite, MySQL, PostgreSQL, bbolt (key/value store in pure Go) and in-memory data providers are supported.
 - SQLite, MySQL, PostgreSQL, bbolt (key/value store in pure Go) and in-memory data providers are supported.
 - Public key and password authentication. Multiple public keys per user are supported.
 - Public key and password authentication. Multiple public keys per user are supported.
 - Keyboard interactive authentication. You can easily setup a customizable multi-factor authentication.
 - Keyboard interactive authentication. You can easily setup a customizable multi-factor authentication.
+- Partial authentication. You can configure multi-step authentication requiring, for example, the user password after successful public key authentication.
 - Per user authentication methods. You can, for example, deny one or more authentication methods to one or more users.
 - Per user authentication methods. You can, for example, deny one or more authentication methods to one or more users.
 - Custom authentication via external programs is supported.
 - Custom authentication via external programs is supported.
 - Dynamic user modification before login via external programs is supported.
 - Dynamic user modification before login via external programs is supported.

+ 1 - 1
dataprovider/bolt.go

@@ -116,7 +116,7 @@ func (p BoltProvider) validateUserAndPass(username string, password string) (Use
 	return checkUserAndPass(user, password)
 	return checkUserAndPass(user, password)
 }
 }
 
 
-func (p BoltProvider) validateUserAndPubKey(username string, pubKey string) (User, string, error) {
+func (p BoltProvider) validateUserAndPubKey(username string, pubKey []byte) (User, string, error) {
 	var user User
 	var user User
 	if len(pubKey) == 0 {
 	if len(pubKey) == 0 {
 		return user, "", errors.New("Credentials cannot be null or empty")
 		return user, "", errors.New("Credentials cannot be null or empty")

+ 18 - 15
dataprovider/dataprovider.go

@@ -73,18 +73,21 @@ const (
 )
 )
 
 
 var (
 var (
-	// SupportedProviders data provider configured in the sftpgo.conf file must match of these strings
+	// SupportedProviders defines the supported data providers
 	SupportedProviders = []string{SQLiteDataProviderName, PGSQLDataProviderName, MySQLDataProviderName,
 	SupportedProviders = []string{SQLiteDataProviderName, PGSQLDataProviderName, MySQLDataProviderName,
 		BoltDataProviderName, MemoryDataProviderName}
 		BoltDataProviderName, MemoryDataProviderName}
-	// ValidPerms list that contains all the valid permissions for a user
+	// ValidPerms defines all the valid permissions for a user
 	ValidPerms = []string{PermAny, PermListItems, PermDownload, PermUpload, PermOverwrite, PermRename, PermDelete,
 	ValidPerms = []string{PermAny, PermListItems, PermDownload, PermUpload, PermOverwrite, PermRename, PermDelete,
 		PermCreateDirs, PermCreateSymlinks, PermChmod, PermChown, PermChtimes}
 		PermCreateDirs, PermCreateSymlinks, PermChmod, PermChown, PermChtimes}
-	// ValidSSHLoginMethods list that contains all the valid SSH login methods
-	ValidSSHLoginMethods = []string{SSHLoginMethodPublicKey, SSHLoginMethodPassword, SSHLoginMethodKeyboardInteractive}
-	config               Config
-	provider             Provider
-	sqlPlaceholders      []string
-	hashPwdPrefixes      = []string{argonPwdPrefix, bcryptPwdPrefix, pbkdf2SHA1Prefix, pbkdf2SHA256Prefix,
+	// ValidSSHLoginMethods defines all the valid SSH login methods
+	ValidSSHLoginMethods = []string{SSHLoginMethodPublicKey, SSHLoginMethodPassword, SSHLoginMethodKeyboardInteractive,
+		SSHLoginMethodKeyAndPassword, SSHLoginMethodKeyAndKeyboardInt}
+	// SSHMultiStepsLoginMethods defines the supported Multi-Step Authentications
+	SSHMultiStepsLoginMethods = []string{SSHLoginMethodKeyAndPassword, SSHLoginMethodKeyAndKeyboardInt}
+	config                    Config
+	provider                  Provider
+	sqlPlaceholders           []string
+	hashPwdPrefixes           = []string{argonPwdPrefix, bcryptPwdPrefix, pbkdf2SHA1Prefix, pbkdf2SHA256Prefix,
 		pbkdf2SHA512Prefix, md5cryptPwdPrefix, md5cryptApr1PwdPrefix, sha512cryptPwdPrefix}
 		pbkdf2SHA512Prefix, md5cryptPwdPrefix, md5cryptApr1PwdPrefix, sha512cryptPwdPrefix}
 	pbkdfPwdPrefixes       = []string{pbkdf2SHA1Prefix, pbkdf2SHA256Prefix, pbkdf2SHA512Prefix}
 	pbkdfPwdPrefixes       = []string{pbkdf2SHA1Prefix, pbkdf2SHA256Prefix, pbkdf2SHA512Prefix}
 	unixPwdPrefixes        = []string{md5cryptPwdPrefix, md5cryptApr1PwdPrefix, sha512cryptPwdPrefix}
 	unixPwdPrefixes        = []string{md5cryptPwdPrefix, md5cryptApr1PwdPrefix, sha512cryptPwdPrefix}
@@ -311,7 +314,7 @@ func GetQuotaTracking() int {
 // Provider interface that data providers must implement.
 // Provider interface that data providers must implement.
 type Provider interface {
 type Provider interface {
 	validateUserAndPass(username string, password string) (User, error)
 	validateUserAndPass(username string, password string) (User, error)
-	validateUserAndPubKey(username string, pubKey string) (User, string, error)
+	validateUserAndPubKey(username string, pubKey []byte) (User, string, error)
 	updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error
 	updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error
 	getUsedQuota(username string) (int, int64, error)
 	getUsedQuota(username string) (int, int64, error)
 	userExists(username string) (User, error)
 	userExists(username string) (User, error)
@@ -401,7 +404,7 @@ func InitializeDatabase(cnf Config, basePath string) error {
 // CheckUserAndPass retrieves the SFTP user with the given username and password if a match is found or an error
 // CheckUserAndPass retrieves the SFTP user with the given username and password if a match is found or an error
 func CheckUserAndPass(p Provider, username string, password string) (User, error) {
 func CheckUserAndPass(p Provider, username string, password string) (User, error) {
 	if len(config.ExternalAuthHook) > 0 && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&1 != 0) {
 	if len(config.ExternalAuthHook) > 0 && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&1 != 0) {
-		user, err := doExternalAuth(username, password, "", "")
+		user, err := doExternalAuth(username, password, nil, "")
 		if err != nil {
 		if err != nil {
 			return user, err
 			return user, err
 		}
 		}
@@ -418,7 +421,7 @@ func CheckUserAndPass(p Provider, username string, password string) (User, error
 }
 }
 
 
 // CheckUserAndPubKey retrieves the SFTP user with the given username and public key if a match is found or an error
 // CheckUserAndPubKey retrieves the SFTP user with the given username and public key if a match is found or an error
-func CheckUserAndPubKey(p Provider, username string, pubKey string) (User, string, error) {
+func CheckUserAndPubKey(p Provider, username string, pubKey []byte) (User, string, error) {
 	if len(config.ExternalAuthHook) > 0 && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&2 != 0) {
 	if len(config.ExternalAuthHook) > 0 && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&2 != 0) {
 		user, err := doExternalAuth(username, "", pubKey, "")
 		user, err := doExternalAuth(username, "", pubKey, "")
 		if err != nil {
 		if err != nil {
@@ -442,7 +445,7 @@ func CheckKeyboardInteractiveAuth(p Provider, username, authHook string, client
 	var user User
 	var user User
 	var err error
 	var err error
 	if len(config.ExternalAuthHook) > 0 && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&4 != 0) {
 	if len(config.ExternalAuthHook) > 0 && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&4 != 0) {
-		user, err = doExternalAuth(username, "", "", "1")
+		user, err = doExternalAuth(username, "", nil, "1")
 	} else if len(config.PreLoginHook) > 0 {
 	} else if len(config.PreLoginHook) > 0 {
 		user, err = executePreLoginHook(username, SSHLoginMethodKeyboardInteractive)
 		user, err = executePreLoginHook(username, SSHLoginMethodKeyboardInteractive)
 	} else {
 	} else {
@@ -934,7 +937,7 @@ func checkUserAndPass(user User, password string) (User, error) {
 	return user, err
 	return user, err
 }
 }
 
 
-func checkUserAndPubKey(user User, pubKey string) (User, string, error) {
+func checkUserAndPubKey(user User, pubKey []byte) (User, string, error) {
 	err := checkLoginConditions(user)
 	err := checkLoginConditions(user)
 	if err != nil {
 	if err != nil {
 		return user, "", err
 		return user, "", err
@@ -948,7 +951,7 @@ func checkUserAndPubKey(user User, pubKey string) (User, string, error) {
 			providerLog(logger.LevelWarn, "error parsing stored public key %d for user %v: %v", i, user.Username, err)
 			providerLog(logger.LevelWarn, "error parsing stored public key %d for user %v: %v", i, user.Username, err)
 			return user, "", err
 			return user, "", err
 		}
 		}
-		if string(storedPubKey.Marshal()) == pubKey {
+		if bytes.Equal(storedPubKey.Marshal(), pubKey) {
 			fp := ssh.FingerprintSHA256(storedPubKey)
 			fp := ssh.FingerprintSHA256(storedPubKey)
 			return user, fp + ":" + comment, nil
 			return user, fp + ":" + comment, nil
 		}
 		}
@@ -1451,7 +1454,7 @@ func getExternalAuthResponse(username, password, pkey, keyboardInteractive strin
 	return cmd.Output()
 	return cmd.Output()
 }
 }
 
 
-func doExternalAuth(username, password, pubKey, keyboardInteractive string) (User, error) {
+func doExternalAuth(username, password string, pubKey []byte, keyboardInteractive string) (User, error) {
 	var user User
 	var user User
 	pkey := ""
 	pkey := ""
 	if len(pubKey) > 0 {
 	if len(pubKey) > 0 {

+ 1 - 1
dataprovider/memory.go

@@ -91,7 +91,7 @@ func (p MemoryProvider) validateUserAndPass(username string, password string) (U
 	return checkUserAndPass(user, password)
 	return checkUserAndPass(user, password)
 }
 }
 
 
-func (p MemoryProvider) validateUserAndPubKey(username string, pubKey string) (User, string, error) {
+func (p MemoryProvider) validateUserAndPubKey(username string, pubKey []byte) (User, string, error) {
 	var user User
 	var user User
 	if len(pubKey) == 0 {
 	if len(pubKey) == 0 {
 		return user, "", errors.New("Credentials cannot be null or empty")
 		return user, "", errors.New("Credentials cannot be null or empty")

+ 1 - 1
dataprovider/mysql.go

@@ -66,7 +66,7 @@ func (p MySQLProvider) validateUserAndPass(username string, password string) (Us
 	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 }
 }
 
 
-func (p MySQLProvider) validateUserAndPubKey(username string, publicKey string) (User, string, error) {
+func (p MySQLProvider) validateUserAndPubKey(username string, publicKey []byte) (User, string, error) {
 	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 }
 }
 
 

+ 1 - 1
dataprovider/pgsql.go

@@ -64,7 +64,7 @@ func (p PGSQLProvider) validateUserAndPass(username string, password string) (Us
 	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 }
 }
 
 
-func (p PGSQLProvider) validateUserAndPubKey(username string, publicKey string) (User, string, error) {
+func (p PGSQLProvider) validateUserAndPubKey(username string, publicKey []byte) (User, string, error) {
 	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 }
 }
 
 

+ 1 - 1
dataprovider/sqlcommon.go

@@ -44,7 +44,7 @@ func sqlCommonValidateUserAndPass(username string, password string, dbHandle *sq
 	return checkUserAndPass(user, password)
 	return checkUserAndPass(user, password)
 }
 }
 
 
-func sqlCommonValidateUserAndPubKey(username string, pubKey string, dbHandle *sql.DB) (User, string, error) {
+func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sql.DB) (User, string, error) {
 	var user User
 	var user User
 	if len(pubKey) == 0 {
 	if len(pubKey) == 0 {
 		return user, "", errors.New("Credentials cannot be null or empty")
 		return user, "", errors.New("Credentials cannot be null or empty")

+ 1 - 1
dataprovider/sqlite.go

@@ -63,7 +63,7 @@ func (p SQLiteProvider) validateUserAndPass(username string, password string) (U
 	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 }
 }
 
 
-func (p SQLiteProvider) validateUserAndPubKey(username string, publicKey string) (User, string, error) {
+func (p SQLiteProvider) validateUserAndPubKey(username string, publicKey []byte) (User, string, error) {
 	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 }
 }
 
 

+ 61 - 3
dataprovider/user.go

@@ -50,6 +50,8 @@ const (
 	SSHLoginMethodPublicKey           = "publickey"
 	SSHLoginMethodPublicKey           = "publickey"
 	SSHLoginMethodPassword            = "password"
 	SSHLoginMethodPassword            = "password"
 	SSHLoginMethodKeyboardInteractive = "keyboard-interactive"
 	SSHLoginMethodKeyboardInteractive = "keyboard-interactive"
+	SSHLoginMethodKeyAndPassword      = "publickey+password"
+	SSHLoginMethodKeyAndKeyboardInt   = "publickey+keyboard-interactive"
 )
 )
 
 
 // ExtensionsFilter defines filters based on file extensions.
 // ExtensionsFilter defines filters based on file extensions.
@@ -246,17 +248,73 @@ func (u *User) HasPerms(permissions []string, path string) bool {
 	return true
 	return true
 }
 }
 
 
-// IsLoginMethodAllowed returns true if the specified login method is allowed for the user
-func (u *User) IsLoginMethodAllowed(loginMetod string) bool {
+// IsLoginMethodAllowed returns true if the specified login method is allowed
+func (u *User) IsLoginMethodAllowed(loginMethod string, partialSuccessMethods []string) bool {
 	if len(u.Filters.DeniedLoginMethods) == 0 {
 	if len(u.Filters.DeniedLoginMethods) == 0 {
 		return true
 		return true
 	}
 	}
-	if utils.IsStringInSlice(loginMetod, u.Filters.DeniedLoginMethods) {
+	if len(partialSuccessMethods) == 1 {
+		for _, method := range u.GetNextAuthMethods(partialSuccessMethods) {
+			if method == loginMethod {
+				return true
+			}
+		}
+	}
+	if utils.IsStringInSlice(loginMethod, u.Filters.DeniedLoginMethods) {
 		return false
 		return false
 	}
 	}
 	return true
 	return true
 }
 }
 
 
+// GetNextAuthMethods returns the list of authentications methods that
+// can continue for multi-step authentication
+func (u *User) GetNextAuthMethods(partialSuccessMethods []string) []string {
+	var methods []string
+	if len(partialSuccessMethods) != 1 {
+		return methods
+	}
+	if partialSuccessMethods[0] != SSHLoginMethodPublicKey {
+		return methods
+	}
+	for _, method := range u.GetAllowedLoginMethods() {
+		if method == SSHLoginMethodKeyAndPassword {
+			methods = append(methods, SSHLoginMethodPassword)
+		}
+		if method == SSHLoginMethodKeyAndKeyboardInt {
+			methods = append(methods, SSHLoginMethodKeyboardInteractive)
+		}
+	}
+	return methods
+}
+
+// IsPartialAuth returns true if the specified login method is a step for
+// a multi-step Authentication.
+// We support publickey+password and publickey+keyboard-interactive, so
+// only publickey can returns partial success.
+// We can have partial success if only multi-step Auth methods are enabled
+func (u *User) IsPartialAuth(loginMethod string) bool {
+	if loginMethod != SSHLoginMethodPublicKey {
+		return false
+	}
+	for _, method := range u.GetAllowedLoginMethods() {
+		if !utils.IsStringInSlice(method, SSHMultiStepsLoginMethods) {
+			return false
+		}
+	}
+	return true
+}
+
+// GetAllowedLoginMethods returns the allowed login methods
+func (u *User) GetAllowedLoginMethods() []string {
+	var allowedMethods []string
+	for _, method := range ValidSSHLoginMethods {
+		if !utils.IsStringInSlice(method, u.Filters.DeniedLoginMethods) {
+			allowedMethods = append(allowedMethods, method)
+		}
+	}
+	return allowedMethods
+}
+
 // IsFileAllowed returns true if the specified file is allowed by the file restrictions filters
 // IsFileAllowed returns true if the specified file is allowed by the file restrictions filters
 func (u *User) IsFileAllowed(sftpPath string) bool {
 func (u *User) IsFileAllowed(sftpPath string) bool {
 	if len(u.Filters.FileExtensions) == 0 {
 	if len(u.Filters.FileExtensions) == 0 {

+ 3 - 1
docs/account.md

@@ -30,10 +30,12 @@ For each account, the following properties can be configured:
 - `download_bandwidth` maximum download bandwidth as KB/s, 0 means unlimited.
 - `download_bandwidth` maximum download bandwidth as KB/s, 0 means unlimited.
 - `allowed_ip`, List of IP/Mask allowed to login. Any IP address not contained in this list cannot login. IP/Mask must be in CIDR notation as defined in RFC 4632 and RFC 4291, for example "192.0.2.0/24" or "2001:db8::/32"
 - `allowed_ip`, List of IP/Mask allowed to login. Any IP address not contained in this list cannot login. IP/Mask must be in CIDR notation as defined in RFC 4632 and RFC 4291, for example "192.0.2.0/24" or "2001:db8::/32"
 - `denied_ip`, List of IP/Mask not allowed to login. If an IP address is both allowed and denied then login will be denied
 - `denied_ip`, List of IP/Mask not allowed to login. If an IP address is both allowed and denied then login will be denied
-- `denied_login_methods`, List of login methods not allowed. The following login methods are supported:
+- `denied_login_methods`, List of login methods not allowed. To enable multi-step authentication you have to allow only multi-step login methods. The following login methods are supported:
   - `publickey`
   - `publickey`
   - `password`
   - `password`
   - `keyboard-interactive`
   - `keyboard-interactive`
+  - `publickey+password`
+  - `publickey+keyboard-interactive`
 - `file_extensions`, list of struct. These restrictions do not apply to files listing for performance reasons, so a denied file cannot be downloaded/overwritten/renamed but it will still be listed in the list of files. Please note that these restrictions can be easily bypassed. Each struct contains the following fields:
 - `file_extensions`, list of struct. These restrictions do not apply to files listing for performance reasons, so a denied file cannot be downloaded/overwritten/renamed but it will still be listed in the list of files. Please note that these restrictions can be easily bypassed. Each struct contains the following fields:
   - `allowed_extensions`, list of, case insensitive, allowed files extension. Shell like expansion is not supported so you have to specify `.jpg` and not `*.jpg`. Any file that does not end with this suffix will be denied
   - `allowed_extensions`, list of, case insensitive, allowed files extension. Shell like expansion is not supported so you have to specify `.jpg` and not `*.jpg`. Any file that does not end with this suffix will be denied
   - `denied_extensions`, list of, case insensitive, denied files extension. Denied file extensions are evaluated before the allowed ones
   - `denied_extensions`, list of, case insensitive, denied files extension. Denied file extensions are evaluated before the allowed ones

+ 1 - 1
docs/metrics.md

@@ -10,7 +10,7 @@ Several counters and gauges are available, for example:
 - Total SSH command errors
 - Total SSH command errors
 - Number of active connections
 - Number of active connections
 - Data provider availability
 - Data provider availability
-- Total successful and failed logins using password, public key or keyboard interactive authentication
+- Total successful and failed logins using password, public key, keyboard interactive authentication or supported multi-step authentications
 - Total HTTP requests served and totals for response code
 - Total HTTP requests served and totals for response code
 - Go's runtime details about GC, number of gouroutines and OS threads
 - Go's runtime details about GC, number of gouroutines and OS threads
 - Process information like CPU, memory, file descriptor usage and start time
 - Process information like CPU, memory, file descriptor usage and start time

+ 2 - 2
go.mod

@@ -43,6 +43,6 @@ require (
 )
 )
 
 
 replace (
 replace (
-	github.com/eikenb/pipeat => github.com/drakkan/pipeat v0.0.0-20200315002837-010186aaa07d
-	golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20200403175740-a4c92a934d79
+	github.com/eikenb/pipeat => github.com/drakkan/pipeat v0.0.0-20200327213700-f3a27a751cdc
+	golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20200409210311-95730af1ff98
 )
 )

+ 4 - 4
go.sum

@@ -61,10 +61,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
 github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
 github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
 github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
-github.com/drakkan/crypto v0.0.0-20200403175740-a4c92a934d79 h1:dARy4jBz2xeVoGq3CTZ8dsQTXB4G833M5cIq7rMtEuw=
-github.com/drakkan/crypto v0.0.0-20200403175740-a4c92a934d79/go.mod h1:v3bhWOXGYda7H5d2s5t9XA6th3fxW3s0MQxU1R96G/w=
-github.com/drakkan/pipeat v0.0.0-20200315002837-010186aaa07d h1:qD1b7ZnrTUscSof+W+Pa3D9hN4jmQ/UcoZ05q7W96rA=
-github.com/drakkan/pipeat v0.0.0-20200315002837-010186aaa07d/go.mod h1:wNYvIpR5rIhoezOYcpxcXz4HbIEOu7A45EqlQCA+h+w=
+github.com/drakkan/crypto v0.0.0-20200409210311-95730af1ff98 h1:5yTRWoqE1GhPfctw9E0wh1d3+xHajMyhQAQuxuNqwQ8=
+github.com/drakkan/crypto v0.0.0-20200409210311-95730af1ff98/go.mod h1:v3bhWOXGYda7H5d2s5t9XA6th3fxW3s0MQxU1R96G/w=
+github.com/drakkan/pipeat v0.0.0-20200327213700-f3a27a751cdc h1:aDQn2DUfujideR+efH3697MZk+9YY4T95P7BujKkWiY=
+github.com/drakkan/pipeat v0.0.0-20200327213700-f3a27a751cdc/go.mod h1:+JPhBw5JdJrSF80r6xsSg1TYHjyAGxYs4X24VyUdMZU=
 github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
 github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=

+ 1 - 2
httpd/httpd_test.go

@@ -327,8 +327,7 @@ func TestAddUserInvalidFilters(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Errorf("unexpected error adding user with invalid filters: %v", err)
 		t.Errorf("unexpected error adding user with invalid filters: %v", err)
 	}
 	}
-	u.Filters.DeniedLoginMethods = []string{dataprovider.SSHLoginMethodKeyboardInteractive,
-		dataprovider.SSHLoginMethodPassword, dataprovider.SSHLoginMethodPublicKey}
+	u.Filters.DeniedLoginMethods = dataprovider.ValidSSHLoginMethods
 	_, _, err = httpd.AddUser(u, http.StatusBadRequest)
 	_, _, err = httpd.AddUser(u, http.StatusBadRequest)
 	if err != nil {
 	if err != nil {
 		t.Errorf("unexpected error adding user with invalid filters: %v", err)
 		t.Errorf("unexpected error adding user with invalid filters: %v", err)

+ 2 - 0
httpd/schema/openapi.yaml

@@ -933,6 +933,8 @@ components:
         - 'publickey'
         - 'publickey'
         - 'password'
         - 'password'
         - 'keyboard-interactive'
         - 'keyboard-interactive'
+        - 'publickey+password'
+        - 'publickey+keyboard-interactive'
     ExtensionsFilter:
     ExtensionsFilter:
       type: object
       type: object
       properties:
       properties:

+ 54 - 0
metrics/metrics.go

@@ -148,6 +148,48 @@ var (
 		Help: "The total number of failed logins using keyboard interactive authentication",
 		Help: "The total number of failed logins using keyboard interactive authentication",
 	})
 	})
 
 
+	// totalKeyAndPasswordLoginAttempts is the metric that reports the total number of
+	// login attempts using public key + password multi steps auth
+	totalKeyAndPasswordLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{
+		Name: "sftpgo_key_and_password_login_attempts_total",
+		Help: "The total number of login attempts using public key + password",
+	})
+
+	// totalKeyAndPasswordLoginOK is the metric that reports the total number of
+	// successful logins using public key + password multi steps auth
+	totalKeyAndPasswordLoginOK = promauto.NewCounter(prometheus.CounterOpts{
+		Name: "sftpgo_key_and_password_login_ok_total",
+		Help: "The total number of successful logins using public key + password",
+	})
+
+	// totalKeyAndPasswordLoginFailed is the metric that reports the total number of
+	// failed logins using public key + password multi steps auth
+	totalKeyAndPasswordLoginFailed = promauto.NewCounter(prometheus.CounterOpts{
+		Name: "sftpgo_key_and_password_login_ko_total",
+		Help: "The total number of failed logins using  public key + password",
+	})
+
+	// totalKeyAndKeyIntLoginAttempts is the metric that reports the total number of
+	// login attempts using public key + keyboard interactive multi steps auth
+	totalKeyAndKeyIntLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{
+		Name: "sftpgo_key_and_keyboard_int_login_attempts_total",
+		Help: "The total number of login attempts using public key + keyboard interactive",
+	})
+
+	// totalKeyAndKeyIntLoginOK is the metric that reports the total number of
+	// successful logins using public key + keyboard interactive multi steps auth
+	totalKeyAndKeyIntLoginOK = promauto.NewCounter(prometheus.CounterOpts{
+		Name: "sftpgo_key_and_keyboard_int_login_ok_total",
+		Help: "The total number of successful logins using public key + keyboard interactive",
+	})
+
+	// totalKeyAndKeyIntLoginFailed is the metric that reports the total number of
+	// failed logins using public key + keyboard interactive multi steps auth
+	totalKeyAndKeyIntLoginFailed = promauto.NewCounter(prometheus.CounterOpts{
+		Name: "sftpgo_key_and_keyboard_int_login_ko_total",
+		Help: "The total number of failed logins using  public key + keyboard interactive",
+	})
+
 	totalHTTPRequests = promauto.NewCounter(prometheus.CounterOpts{
 	totalHTTPRequests = promauto.NewCounter(prometheus.CounterOpts{
 		Name: "sftpgo_http_req_total",
 		Name: "sftpgo_http_req_total",
 		Help: "The total number of HTTP requests served",
 		Help: "The total number of HTTP requests served",
@@ -498,6 +540,10 @@ func AddLoginAttempt(authMethod string) {
 		totalKeyLoginAttempts.Inc()
 		totalKeyLoginAttempts.Inc()
 	case "keyboard-interactive":
 	case "keyboard-interactive":
 		totalInteractiveLoginAttempts.Inc()
 		totalInteractiveLoginAttempts.Inc()
+	case "publickey+password":
+		totalKeyAndPasswordLoginAttempts.Inc()
+	case "publickey+keyboard-interactive":
+		totalKeyAndKeyIntLoginAttempts.Inc()
 	default:
 	default:
 		totalPasswordLoginAttempts.Inc()
 		totalPasswordLoginAttempts.Inc()
 	}
 	}
@@ -512,6 +558,10 @@ func AddLoginResult(authMethod string, err error) {
 			totalKeyLoginOK.Inc()
 			totalKeyLoginOK.Inc()
 		case "keyboard-interactive":
 		case "keyboard-interactive":
 			totalInteractiveLoginOK.Inc()
 			totalInteractiveLoginOK.Inc()
+		case "publickey+password":
+			totalKeyAndPasswordLoginOK.Inc()
+		case "publickey+keyboard-interactive":
+			totalKeyAndKeyIntLoginOK.Inc()
 		default:
 		default:
 			totalPasswordLoginOK.Inc()
 			totalPasswordLoginOK.Inc()
 		}
 		}
@@ -522,6 +572,10 @@ func AddLoginResult(authMethod string, err error) {
 			totalKeyLoginFailed.Inc()
 			totalKeyLoginFailed.Inc()
 		case "keyboard-interactive":
 		case "keyboard-interactive":
 			totalInteractiveLoginFailed.Inc()
 			totalInteractiveLoginFailed.Inc()
+		case "publickey+password":
+			totalKeyAndPasswordLoginFailed.Inc()
+		case "publickey+keyboard-interactive":
+			totalKeyAndKeyIntLoginFailed.Inc()
 		default:
 		default:
 			totalPasswordLoginFailed.Inc()
 			totalPasswordLoginFailed.Inc()
 		}
 		}

+ 2 - 1
scripts/sftpgo_api_cli.py

@@ -503,7 +503,8 @@ def addCommonUserArguments(parser):
 							'create_symlinks', 'chmod', 'chown', 'chtimes'], help='Permissions for the root directory '
 							'create_symlinks', 'chmod', 'chown', 'chtimes'], help='Permissions for the root directory '
 							+'(/). Default: %(default)s')
 							+'(/). Default: %(default)s')
 	parser.add_argument('-L', '--denied-login-methods', type=str, nargs='+', default=[],
 	parser.add_argument('-L', '--denied-login-methods', type=str, nargs='+', default=[],
-					choices=['', 'publickey', 'password', 'keyboard-interactive'], help='Default: %(default)s')
+					choices=['', 'publickey', 'password', 'keyboard-interactive', 'publickey+password',
+							'publickey+keyboard-interactive'], help='Default: %(default)s')
 	parser.add_argument('--subdirs-permissions', type=str, nargs='*', default=[], help='Permissions for subdirs. '
 	parser.add_argument('--subdirs-permissions', type=str, nargs='*', default=[], help='Permissions for subdirs. '
 					+'For example: "/somedir::list,download" "/otherdir/subdir::*" Default: %(default)s')
 					+'For example: "/somedir::list,download" "/otherdir/subdir::*" Default: %(default)s')
 	parser.add_argument('--virtual-folders', type=str, nargs='*', default=[], help='Virtual folder mapping. For example: '
 	parser.add_argument('--virtual-folders', type=str, nargs='*', default=[], help='Virtual folder mapping. For example: '

+ 1 - 1
sftpd/internal_test.go

@@ -484,7 +484,7 @@ func TestUploadFiles(t *testing.T) {
 func TestWithInvalidHome(t *testing.T) {
 func TestWithInvalidHome(t *testing.T) {
 	u := dataprovider.User{}
 	u := dataprovider.User{}
 	u.HomeDir = "home_rel_path"
 	u.HomeDir = "home_rel_path"
-	_, err := loginUser(u, dataprovider.SSHLoginMethodPassword, "", "")
+	_, err := loginUser(u, dataprovider.SSHLoginMethodPassword, "", nil)
 	if err == nil {
 	if err == nil {
 		t.Errorf("login a user with an invalid home_dir must fail")
 		t.Errorf("login a user with an invalid home_dir must fail")
 	}
 	}

+ 66 - 39
sftpd/server.go

@@ -158,42 +158,32 @@ func (c Configuration) Initialize(configDir string) error {
 			return sp, nil
 			return sp, nil
 		},
 		},
 		PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
 		PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
-			sp, err := c.validatePublicKeyCredentials(conn, string(pubKey.Marshal()))
+			sp, err := c.validatePublicKeyCredentials(conn, pubKey.Marshal())
+			if err == ssh.ErrPartialSuccess {
+				return nil, err
+			}
 			if err != nil {
 			if err != nil {
 				return nil, &authenticationError{err: fmt.Sprintf("could not validate public key credentials: %v", err)}
 				return nil, &authenticationError{err: fmt.Sprintf("could not validate public key credentials: %v", err)}
 			}
 			}
 
 
 			return sp, nil
 			return sp, nil
 		},
 		},
-		ServerVersion: "SSH-2.0-" + c.Banner,
+		NextAuthMethodsCallback: func(conn ssh.ConnMetadata) []string {
+			var nextMethods []string
+			user, err := dataprovider.UserExists(dataProvider, conn.User())
+			if err == nil {
+				nextMethods = user.GetNextAuthMethods(conn.PartialSuccessMethods())
+			}
+			return nextMethods
+		},
+		ServerVersion: fmt.Sprintf("SSH-2.0-%v", c.Banner),
 	}
 	}
 
 
-	err = c.checkHostKeys(configDir)
+	err = c.checkAndLoadHostKeys(configDir, serverConfig)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	for _, k := range c.Keys {
-		privateFile := k.PrivateKey
-		if !filepath.IsAbs(privateFile) {
-			privateFile = filepath.Join(configDir, privateFile)
-		}
-		logger.Info(logSender, "", "Loading private key: %s", privateFile)
-
-		privateBytes, err := ioutil.ReadFile(privateFile)
-		if err != nil {
-			return err
-		}
-
-		private, err := ssh.ParsePrivateKey(privateBytes)
-		if err != nil {
-			return err
-		}
-
-		// Add private key to the server configuration.
-		serverConfig.AddHostKey(private)
-	}
-
 	c.configureSecurityOptions(serverConfig)
 	c.configureSecurityOptions(serverConfig)
 	c.configureKeyboardInteractiveAuth(serverConfig)
 	c.configureKeyboardInteractiveAuth(serverConfig)
 	c.configureLoginBanner(serverConfig, configDir)
 	c.configureLoginBanner(serverConfig, configDir)
@@ -285,9 +275,10 @@ func (c Configuration) configureLoginBanner(serverConfig *ssh.ServerConfig, conf
 		if !filepath.IsAbs(bannerFilePath) {
 		if !filepath.IsAbs(bannerFilePath) {
 			bannerFilePath = filepath.Join(configDir, bannerFilePath)
 			bannerFilePath = filepath.Join(configDir, bannerFilePath)
 		}
 		}
-		var banner []byte
-		banner, err = ioutil.ReadFile(bannerFilePath)
+		var bannerContent []byte
+		bannerContent, err = ioutil.ReadFile(bannerFilePath)
 		if err == nil {
 		if err == nil {
+			banner := string(bannerContent)
 			serverConfig.BannerCallback = func(conn ssh.ConnMetadata) string {
 			serverConfig.BannerCallback = func(conn ssh.ConnMetadata) string {
 				return string(banner)
 				return string(banner)
 			}
 			}
@@ -459,9 +450,13 @@ func (c Configuration) createHandler(connection Connection) sftp.Handlers {
 	}
 	}
 }
 }
 
 
-func loginUser(user dataprovider.User, loginMethod, remoteAddr, publicKey string) (*ssh.Permissions, error) {
+func loginUser(user dataprovider.User, loginMethod, publicKey string, conn ssh.ConnMetadata) (*ssh.Permissions, error) {
+	connectionID := ""
+	if conn != nil {
+		connectionID = hex.EncodeToString(conn.SessionID())
+	}
 	if !filepath.IsAbs(user.HomeDir) {
 	if !filepath.IsAbs(user.HomeDir) {
-		logger.Warn(logSender, "", "user %#v has an invalid home dir: %#v. Home dir must be an absolute path, login not allowed",
+		logger.Warn(logSender, connectionID, "user %#v has an invalid home dir: %#v. Home dir must be an absolute path, login not allowed",
 			user.Username, user.HomeDir)
 			user.Username, user.HomeDir)
 		return nil, fmt.Errorf("cannot login user with invalid home dir: %#v", user.HomeDir)
 		return nil, fmt.Errorf("cannot login user with invalid home dir: %#v", user.HomeDir)
 	}
 	}
@@ -473,18 +468,19 @@ func loginUser(user dataprovider.User, loginMethod, remoteAddr, publicKey string
 			return nil, fmt.Errorf("too many open sessions: %v", activeSessions)
 			return nil, fmt.Errorf("too many open sessions: %v", activeSessions)
 		}
 		}
 	}
 	}
-	if !user.IsLoginMethodAllowed(loginMethod) {
-		logger.Debug(logSender, "", "cannot login user %#v, login method %#v is not allowed", user.Username, loginMethod)
+	if !user.IsLoginMethodAllowed(loginMethod, conn.PartialSuccessMethods()) {
+		logger.Debug(logSender, connectionID, "cannot login user %#v, login method %#v is not allowed", user.Username, loginMethod)
 		return nil, fmt.Errorf("Login method %#v is not allowed for user %#v", loginMethod, user.Username)
 		return nil, fmt.Errorf("Login method %#v is not allowed for user %#v", loginMethod, user.Username)
 	}
 	}
+	remoteAddr := conn.RemoteAddr().String()
 	if !user.IsLoginFromAddrAllowed(remoteAddr) {
 	if !user.IsLoginFromAddrAllowed(remoteAddr) {
-		logger.Debug(logSender, "", "cannot login user %#v, remote address is not allowed: %v", user.Username, remoteAddr)
+		logger.Debug(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, remoteAddr)
 		return nil, fmt.Errorf("Login for user %#v is not allowed from this address: %v", user.Username, remoteAddr)
 		return nil, fmt.Errorf("Login for user %#v is not allowed from this address: %v", user.Username, remoteAddr)
 	}
 	}
 
 
 	json, err := json.Marshal(user)
 	json, err := json.Marshal(user)
 	if err != nil {
 	if err != nil {
-		logger.Warn(logSender, "", "error serializing user info: %v, authentication rejected", err)
+		logger.Warn(logSender, connectionID, "error serializing user info: %v, authentication rejected", err)
 		return nil, err
 		return nil, err
 	}
 	}
 	if len(publicKey) > 0 {
 	if len(publicKey) > 0 {
@@ -514,8 +510,8 @@ func (c *Configuration) checkSSHCommands() {
 	c.EnabledSSHCommands = sshCommands
 	c.EnabledSSHCommands = sshCommands
 }
 }
 
 
-// If no host keys are defined we try to use or generate the default one.
-func (c *Configuration) checkHostKeys(configDir string) error {
+// If no host keys are defined we try to use or generate the default ones.
+func (c *Configuration) checkAndLoadHostKeys(configDir string, serverConfig *ssh.ServerConfig) error {
 	if len(c.Keys) == 0 {
 	if len(c.Keys) == 0 {
 		defaultKeys := []string{defaultPrivateRSAKeyName, defaultPrivateECDSAKeyName}
 		defaultKeys := []string{defaultPrivateRSAKeyName, defaultPrivateECDSAKeyName}
 		for _, k := range defaultKeys {
 		for _, k := range defaultKeys {
@@ -535,20 +531,45 @@ func (c *Configuration) checkHostKeys(configDir string) error {
 			c.Keys = append(c.Keys, Key{PrivateKey: k})
 			c.Keys = append(c.Keys, Key{PrivateKey: k})
 		}
 		}
 	}
 	}
+	for _, k := range c.Keys {
+		privateFile := k.PrivateKey
+		if !filepath.IsAbs(privateFile) {
+			privateFile = filepath.Join(configDir, privateFile)
+		}
+		logger.Info(logSender, "", "Loading private key: %s", privateFile)
+
+		privateBytes, err := ioutil.ReadFile(privateFile)
+		if err != nil {
+			return err
+		}
+
+		private, err := ssh.ParsePrivateKey(privateBytes)
+		if err != nil {
+			return err
+		}
+
+		// Add private key to the server configuration.
+		serverConfig.AddHostKey(private)
+	}
 	return nil
 	return nil
 }
 }
 
 
-func (c Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubKey string) (*ssh.Permissions, error) {
+func (c Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubKey []byte) (*ssh.Permissions, error) {
 	var err error
 	var err error
 	var user dataprovider.User
 	var user dataprovider.User
 	var keyID string
 	var keyID string
 	var sshPerm *ssh.Permissions
 	var sshPerm *ssh.Permissions
 
 
+	connectionID := hex.EncodeToString(conn.SessionID())
 	method := dataprovider.SSHLoginMethodPublicKey
 	method := dataprovider.SSHLoginMethodPublicKey
-	metrics.AddLoginAttempt(method)
 	if user, keyID, err = dataprovider.CheckUserAndPubKey(dataProvider, conn.User(), pubKey); err == nil {
 	if user, keyID, err = dataprovider.CheckUserAndPubKey(dataProvider, conn.User(), pubKey); err == nil {
-		sshPerm, err = loginUser(user, method, conn.RemoteAddr().String(), keyID)
+		if user.IsPartialAuth(method) {
+			logger.Debug(logSender, connectionID, "user %#v authenticated with partial success", conn.User())
+			return nil, ssh.ErrPartialSuccess
+		}
+		sshPerm, err = loginUser(user, method, keyID, conn)
 	}
 	}
+	metrics.AddLoginAttempt(method)
 	if err != nil {
 	if err != nil {
 		logger.ConnectionFailedLog(conn.User(), utils.GetIPFromRemoteAddress(conn.RemoteAddr().String()), method, err.Error())
 		logger.ConnectionFailedLog(conn.User(), utils.GetIPFromRemoteAddress(conn.RemoteAddr().String()), method, err.Error())
 	}
 	}
@@ -562,9 +583,12 @@ func (c Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass [
 	var sshPerm *ssh.Permissions
 	var sshPerm *ssh.Permissions
 
 
 	method := dataprovider.SSHLoginMethodPassword
 	method := dataprovider.SSHLoginMethodPassword
+	if len(conn.PartialSuccessMethods()) == 1 {
+		method = dataprovider.SSHLoginMethodKeyAndPassword
+	}
 	metrics.AddLoginAttempt(method)
 	metrics.AddLoginAttempt(method)
 	if user, err = dataprovider.CheckUserAndPass(dataProvider, conn.User(), string(pass)); err == nil {
 	if user, err = dataprovider.CheckUserAndPass(dataProvider, conn.User(), string(pass)); err == nil {
-		sshPerm, err = loginUser(user, method, conn.RemoteAddr().String(), "")
+		sshPerm, err = loginUser(user, method, "", conn)
 	}
 	}
 	if err != nil {
 	if err != nil {
 		logger.ConnectionFailedLog(conn.User(), utils.GetIPFromRemoteAddress(conn.RemoteAddr().String()), method, err.Error())
 		logger.ConnectionFailedLog(conn.User(), utils.GetIPFromRemoteAddress(conn.RemoteAddr().String()), method, err.Error())
@@ -579,9 +603,12 @@ func (c Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMetad
 	var sshPerm *ssh.Permissions
 	var sshPerm *ssh.Permissions
 
 
 	method := dataprovider.SSHLoginMethodKeyboardInteractive
 	method := dataprovider.SSHLoginMethodKeyboardInteractive
+	if len(conn.PartialSuccessMethods()) == 1 {
+		method = dataprovider.SSHLoginMethodKeyAndKeyboardInt
+	}
 	metrics.AddLoginAttempt(method)
 	metrics.AddLoginAttempt(method)
 	if user, err = dataprovider.CheckKeyboardInteractiveAuth(dataProvider, conn.User(), c.KeyboardInteractiveHook, client); err == nil {
 	if user, err = dataprovider.CheckKeyboardInteractiveAuth(dataProvider, conn.User(), c.KeyboardInteractiveHook, client); err == nil {
-		sshPerm, err = loginUser(user, method, conn.RemoteAddr().String(), "")
+		sshPerm, err = loginUser(user, method, "", conn)
 	}
 	}
 	if err != nil {
 	if err != nil {
 		logger.ConnectionFailedLog(conn.User(), utils.GetIPFromRemoteAddress(conn.RemoteAddr().String()), method, err.Error())
 		logger.ConnectionFailedLog(conn.User(), utils.GetIPFromRemoteAddress(conn.RemoteAddr().String()), method, err.Error())

+ 276 - 0
sftpd/sftpd_test.go

@@ -999,6 +999,120 @@ func TestLogin(t *testing.T) {
 	os.RemoveAll(user.GetHomeDir())
 	os.RemoveAll(user.GetHomeDir())
 }
 }
 
 
+func TestMultiStepLoginKeyAndPwd(t *testing.T) {
+	u := getTestUser(true)
+	u.Password = defaultPassword
+	u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, []string{
+		dataprovider.SSHLoginMethodKeyAndKeyboardInt,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}...)
+	user, _, err := httpd.AddUser(u, http.StatusOK)
+	if err != nil {
+		t.Errorf("unable to add user: %v", err)
+	}
+	_, err = getSftpClient(user, true)
+	if err == nil {
+		t.Error("login with public key is disallowed and must fail")
+	}
+	_, err = getSftpClient(user, true)
+	if err == nil {
+		t.Error("login with password is disallowed and must fail")
+	}
+	key, _ := ssh.ParsePrivateKey([]byte(testPrivateKey))
+	authMethods := []ssh.AuthMethod{
+		ssh.PublicKeys(key),
+		ssh.Password(defaultPassword),
+	}
+	client, err := getCustomAuthSftpClient(user, authMethods)
+	if err != nil {
+		t.Errorf("unable to create sftp client: %v", err)
+	} else {
+		defer client.Close()
+		_, err := client.Getwd()
+		if err != nil {
+			t.Errorf("unexpected error: %v", err)
+		}
+	}
+	authMethods = []ssh.AuthMethod{
+		ssh.Password(defaultPassword),
+		ssh.PublicKeys(key),
+	}
+	_, err = getCustomAuthSftpClient(user, authMethods)
+	if err == nil {
+		t.Error("multi step auth login with wrong order must fail")
+	}
+	_, err = httpd.RemoveUser(user, http.StatusOK)
+	if err != nil {
+		t.Errorf("unable to remove user: %v", err)
+	}
+	os.RemoveAll(user.GetHomeDir())
+}
+
+func TestMultiStepLoginKeyAndKeyInt(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("this test is not available on Windows")
+	}
+	u := getTestUser(true)
+	u.Password = defaultPassword
+	u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, []string{
+		dataprovider.SSHLoginMethodKeyAndPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}...)
+	user, _, err := httpd.AddUser(u, http.StatusOK)
+	if err != nil {
+		t.Errorf("unable to add user: %v", err)
+	}
+	ioutil.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 0, false, 1), 0755)
+	_, err = getSftpClient(user, true)
+	if err == nil {
+		t.Error("login with public key is disallowed and must fail")
+	}
+	key, _ := ssh.ParsePrivateKey([]byte(testPrivateKey))
+	authMethods := []ssh.AuthMethod{
+		ssh.PublicKeys(key),
+		ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
+			return []string{"1", "2"}, nil
+		}),
+	}
+	client, err := getCustomAuthSftpClient(user, authMethods)
+	if err != nil {
+		t.Errorf("unable to create sftp client: %v", err)
+	} else {
+		defer client.Close()
+		_, err := client.Getwd()
+		if err != nil {
+			t.Errorf("unexpected error: %v", err)
+		}
+	}
+	authMethods = []ssh.AuthMethod{
+		ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
+			return []string{"1", "2"}, nil
+		}),
+		ssh.PublicKeys(key),
+	}
+	_, err = getCustomAuthSftpClient(user, authMethods)
+	if err == nil {
+		t.Error("multi step auth login with wrong order must fail")
+	}
+	authMethods = []ssh.AuthMethod{
+		ssh.PublicKeys(key),
+		ssh.Password(defaultPassword),
+	}
+	_, err = getCustomAuthSftpClient(user, authMethods)
+	if err == nil {
+		t.Error("multi step auth login with wrong method must fail")
+	}
+	_, err = httpd.RemoveUser(user, http.StatusOK)
+	if err != nil {
+		t.Errorf("unable to remove user: %v", err)
+	}
+	os.RemoveAll(user.GetHomeDir())
+}
+
 func TestLoginUserStatus(t *testing.T) {
 func TestLoginUserStatus(t *testing.T) {
 	usePubKey := true
 	usePubKey := true
 	user, _, err := httpd.AddUser(getTestUser(usePubKey), http.StatusOK)
 	user, _, err := httpd.AddUser(getTestUser(usePubKey), http.StatusOK)
@@ -3823,6 +3937,151 @@ func TestFilterFileExtensions(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestUserAllowedLoginMethods(t *testing.T) {
+	user := getTestUser(true)
+	user.Filters.DeniedLoginMethods = dataprovider.ValidSSHLoginMethods
+	allowedMethods := user.GetAllowedLoginMethods()
+	if len(allowedMethods) != 0 {
+		t.Errorf("unexpected allowed methods: %+v", allowedMethods)
+	}
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}
+	allowedMethods = user.GetAllowedLoginMethods()
+	if len(allowedMethods) != 2 {
+		t.Errorf("unexpected allowed methods: %+v", allowedMethods)
+	}
+	if !utils.IsStringInSlice(dataprovider.SSHLoginMethodKeyAndKeyboardInt, allowedMethods) {
+		t.Errorf("unexpected allowed methods: %+v", allowedMethods)
+	}
+	if !utils.IsStringInSlice(dataprovider.SSHLoginMethodKeyAndPassword, allowedMethods) {
+		t.Errorf("unexpected allowed methods: %+v", allowedMethods)
+	}
+}
+
+func TestUserPartialAuth(t *testing.T) {
+	user := getTestUser(true)
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}
+	if user.IsPartialAuth(dataprovider.SSHLoginMethodPassword) {
+		t.Error("unexpected partial auth method")
+	}
+	if user.IsPartialAuth(dataprovider.SSHLoginMethodKeyboardInteractive) {
+		t.Error("unexpected partial auth method")
+	}
+	if !user.IsPartialAuth(dataprovider.SSHLoginMethodPublicKey) {
+		t.Error("public key must be a partial auth method with this configuration")
+	}
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}
+	if user.IsPartialAuth(dataprovider.SSHLoginMethodPublicKey) {
+		t.Error("public key must not be a partial auth method with this configuration")
+	}
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+	}
+	if user.IsPartialAuth(dataprovider.SSHLoginMethodPublicKey) {
+		t.Error("public key must not be a partial auth method with this configuration")
+	}
+}
+
+func TestUserGetNextAuthMethods(t *testing.T) {
+	user := getTestUser(true)
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}
+	methods := user.GetNextAuthMethods(nil)
+	if len(methods) != 0 {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	methods = user.GetNextAuthMethods([]string{dataprovider.SSHLoginMethodPassword})
+	if len(methods) != 0 {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	methods = user.GetNextAuthMethods([]string{dataprovider.SSHLoginMethodKeyboardInteractive})
+	if len(methods) != 0 {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	methods = user.GetNextAuthMethods([]string{
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	})
+	if len(methods) != 0 {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	methods = user.GetNextAuthMethods([]string{dataprovider.SSHLoginMethodPublicKey})
+	if len(methods) != 2 {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	if !utils.IsStringInSlice(dataprovider.SSHLoginMethodPassword, methods) {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	if !utils.IsStringInSlice(dataprovider.SSHLoginMethodKeyboardInteractive, methods) {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+		dataprovider.SSHLoginMethodKeyAndKeyboardInt,
+	}
+	methods = user.GetNextAuthMethods([]string{dataprovider.SSHLoginMethodPublicKey})
+	if len(methods) != 1 {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	if !utils.IsStringInSlice(dataprovider.SSHLoginMethodPassword, methods) {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+		dataprovider.SSHLoginMethodKeyAndPassword,
+	}
+	methods = user.GetNextAuthMethods([]string{dataprovider.SSHLoginMethodPublicKey})
+	if len(methods) != 1 {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	if !utils.IsStringInSlice(dataprovider.SSHLoginMethodKeyboardInteractive, methods) {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+}
+
+func TestUserIsLoginMethodAllowed(t *testing.T) {
+	user := getTestUser(true)
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}
+	if user.IsLoginMethodAllowed(dataprovider.SSHLoginMethodPassword, nil) {
+		t.Error("unexpected login method allowed")
+	}
+	if !user.IsLoginMethodAllowed(dataprovider.SSHLoginMethodPassword, []string{dataprovider.SSHLoginMethodPublicKey}) {
+		t.Error("unexpected login method denied")
+	}
+	if !user.IsLoginMethodAllowed(dataprovider.SSHLoginMethodKeyboardInteractive, []string{dataprovider.SSHLoginMethodPublicKey}) {
+		t.Error("unexpected login method denied")
+	}
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}
+	if !user.IsLoginMethodAllowed(dataprovider.SSHLoginMethodPassword, nil) {
+		t.Error("unexpected login method denied")
+	}
+}
+
 func TestUserEmptySubDirPerms(t *testing.T) {
 func TestUserEmptySubDirPerms(t *testing.T) {
 	user := getTestUser(true)
 	user := getTestUser(true)
 	user.Permissions = make(map[string][]string)
 	user.Permissions = make(map[string][]string)
@@ -5041,6 +5300,23 @@ func getKeyboardInteractiveSftpClient(user dataprovider.User, answers []string)
 	return sftpClient, err
 	return sftpClient, err
 }
 }
 
 
+func getCustomAuthSftpClient(user dataprovider.User, authMethods []ssh.AuthMethod) (*sftp.Client, error) {
+	var sftpClient *sftp.Client
+	config := &ssh.ClientConfig{
+		User: user.Username,
+		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
+			return nil
+		},
+		Auth: authMethods,
+	}
+	conn, err := ssh.Dial("tcp", sftpServerAddr, config)
+	if err != nil {
+		return sftpClient, err
+	}
+	sftpClient, err = sftp.NewClient(conn)
+	return sftpClient, err
+}
+
 func createTestFile(path string, size int64) error {
 func createTestFile(path string, size int64) error {
 	baseDir := filepath.Dir(path)
 	baseDir := filepath.Dir(path)
 	if _, err := os.Stat(baseDir); os.IsNotExist(err) {
 	if _, err := os.Stat(baseDir); os.IsNotExist(err) {