1
0
Эх сурвалжийг харах

add support for partial authentication

Multi-step authentication is activated disabling all single-step
auth methods for a given user
Nicola Murino 5 жил өмнө
parent
commit
b1c7317cf6

+ 1 - 0
README.md

@@ -11,6 +11,7 @@ Fully featured and highly configurable SFTP server, written in Go
 - SQLite, MySQL, PostgreSQL, bbolt (key/value store in pure Go) and in-memory data providers are supported.
 - Public key and password authentication. Multiple public keys per user are supported.
 - Keyboard interactive authentication. You can easily setup a customizable multi-factor authentication.
+- Partial authentication. You can configure multi-step authentication requiring, for example, the user password after successful public key authentication.
 - Per user authentication methods. You can, for example, deny one or more authentication methods to one or more users.
 - Custom authentication via external programs is supported.
 - Dynamic user modification before login via external programs is supported.

+ 1 - 1
dataprovider/bolt.go

@@ -116,7 +116,7 @@ func (p BoltProvider) validateUserAndPass(username string, password string) (Use
 	return checkUserAndPass(user, password)
 }
 
-func (p BoltProvider) validateUserAndPubKey(username string, pubKey string) (User, string, error) {
+func (p BoltProvider) validateUserAndPubKey(username string, pubKey []byte) (User, string, error) {
 	var user User
 	if len(pubKey) == 0 {
 		return user, "", errors.New("Credentials cannot be null or empty")

+ 18 - 15
dataprovider/dataprovider.go

@@ -73,18 +73,21 @@ const (
 )
 
 var (
-	// SupportedProviders data provider configured in the sftpgo.conf file must match of these strings
+	// SupportedProviders defines the supported data providers
 	SupportedProviders = []string{SQLiteDataProviderName, PGSQLDataProviderName, MySQLDataProviderName,
 		BoltDataProviderName, MemoryDataProviderName}
-	// ValidPerms list that contains all the valid permissions for a user
+	// ValidPerms defines all the valid permissions for a user
 	ValidPerms = []string{PermAny, PermListItems, PermDownload, PermUpload, PermOverwrite, PermRename, PermDelete,
 		PermCreateDirs, PermCreateSymlinks, PermChmod, PermChown, PermChtimes}
-	// ValidSSHLoginMethods list that contains all the valid SSH login methods
-	ValidSSHLoginMethods = []string{SSHLoginMethodPublicKey, SSHLoginMethodPassword, SSHLoginMethodKeyboardInteractive}
-	config               Config
-	provider             Provider
-	sqlPlaceholders      []string
-	hashPwdPrefixes      = []string{argonPwdPrefix, bcryptPwdPrefix, pbkdf2SHA1Prefix, pbkdf2SHA256Prefix,
+	// ValidSSHLoginMethods defines all the valid SSH login methods
+	ValidSSHLoginMethods = []string{SSHLoginMethodPublicKey, SSHLoginMethodPassword, SSHLoginMethodKeyboardInteractive,
+		SSHLoginMethodKeyAndPassword, SSHLoginMethodKeyAndKeyboardInt}
+	// SSHMultiStepsLoginMethods defines the supported Multi-Step Authentications
+	SSHMultiStepsLoginMethods = []string{SSHLoginMethodKeyAndPassword, SSHLoginMethodKeyAndKeyboardInt}
+	config                    Config
+	provider                  Provider
+	sqlPlaceholders           []string
+	hashPwdPrefixes           = []string{argonPwdPrefix, bcryptPwdPrefix, pbkdf2SHA1Prefix, pbkdf2SHA256Prefix,
 		pbkdf2SHA512Prefix, md5cryptPwdPrefix, md5cryptApr1PwdPrefix, sha512cryptPwdPrefix}
 	pbkdfPwdPrefixes       = []string{pbkdf2SHA1Prefix, pbkdf2SHA256Prefix, pbkdf2SHA512Prefix}
 	unixPwdPrefixes        = []string{md5cryptPwdPrefix, md5cryptApr1PwdPrefix, sha512cryptPwdPrefix}
@@ -311,7 +314,7 @@ func GetQuotaTracking() int {
 // Provider interface that data providers must implement.
 type Provider interface {
 	validateUserAndPass(username string, password string) (User, error)
-	validateUserAndPubKey(username string, pubKey string) (User, string, error)
+	validateUserAndPubKey(username string, pubKey []byte) (User, string, error)
 	updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error
 	getUsedQuota(username string) (int, int64, error)
 	userExists(username string) (User, error)
@@ -401,7 +404,7 @@ func InitializeDatabase(cnf Config, basePath string) error {
 // CheckUserAndPass retrieves the SFTP user with the given username and password if a match is found or an error
 func CheckUserAndPass(p Provider, username string, password string) (User, error) {
 	if len(config.ExternalAuthHook) > 0 && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&1 != 0) {
-		user, err := doExternalAuth(username, password, "", "")
+		user, err := doExternalAuth(username, password, nil, "")
 		if err != nil {
 			return user, err
 		}
@@ -418,7 +421,7 @@ func CheckUserAndPass(p Provider, username string, password string) (User, error
 }
 
 // CheckUserAndPubKey retrieves the SFTP user with the given username and public key if a match is found or an error
-func CheckUserAndPubKey(p Provider, username string, pubKey string) (User, string, error) {
+func CheckUserAndPubKey(p Provider, username string, pubKey []byte) (User, string, error) {
 	if len(config.ExternalAuthHook) > 0 && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&2 != 0) {
 		user, err := doExternalAuth(username, "", pubKey, "")
 		if err != nil {
@@ -442,7 +445,7 @@ func CheckKeyboardInteractiveAuth(p Provider, username, authHook string, client
 	var user User
 	var err error
 	if len(config.ExternalAuthHook) > 0 && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&4 != 0) {
-		user, err = doExternalAuth(username, "", "", "1")
+		user, err = doExternalAuth(username, "", nil, "1")
 	} else if len(config.PreLoginHook) > 0 {
 		user, err = executePreLoginHook(username, SSHLoginMethodKeyboardInteractive)
 	} else {
@@ -934,7 +937,7 @@ func checkUserAndPass(user User, password string) (User, error) {
 	return user, err
 }
 
-func checkUserAndPubKey(user User, pubKey string) (User, string, error) {
+func checkUserAndPubKey(user User, pubKey []byte) (User, string, error) {
 	err := checkLoginConditions(user)
 	if err != nil {
 		return user, "", err
@@ -948,7 +951,7 @@ func checkUserAndPubKey(user User, pubKey string) (User, string, error) {
 			providerLog(logger.LevelWarn, "error parsing stored public key %d for user %v: %v", i, user.Username, err)
 			return user, "", err
 		}
-		if string(storedPubKey.Marshal()) == pubKey {
+		if bytes.Equal(storedPubKey.Marshal(), pubKey) {
 			fp := ssh.FingerprintSHA256(storedPubKey)
 			return user, fp + ":" + comment, nil
 		}
@@ -1451,7 +1454,7 @@ func getExternalAuthResponse(username, password, pkey, keyboardInteractive strin
 	return cmd.Output()
 }
 
-func doExternalAuth(username, password, pubKey, keyboardInteractive string) (User, error) {
+func doExternalAuth(username, password string, pubKey []byte, keyboardInteractive string) (User, error) {
 	var user User
 	pkey := ""
 	if len(pubKey) > 0 {

+ 1 - 1
dataprovider/memory.go

@@ -91,7 +91,7 @@ func (p MemoryProvider) validateUserAndPass(username string, password string) (U
 	return checkUserAndPass(user, password)
 }
 
-func (p MemoryProvider) validateUserAndPubKey(username string, pubKey string) (User, string, error) {
+func (p MemoryProvider) validateUserAndPubKey(username string, pubKey []byte) (User, string, error) {
 	var user User
 	if len(pubKey) == 0 {
 		return user, "", errors.New("Credentials cannot be null or empty")

+ 1 - 1
dataprovider/mysql.go

@@ -66,7 +66,7 @@ func (p MySQLProvider) validateUserAndPass(username string, password string) (Us
 	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 }
 
-func (p MySQLProvider) validateUserAndPubKey(username string, publicKey string) (User, string, error) {
+func (p MySQLProvider) validateUserAndPubKey(username string, publicKey []byte) (User, string, error) {
 	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 }
 

+ 1 - 1
dataprovider/pgsql.go

@@ -64,7 +64,7 @@ func (p PGSQLProvider) validateUserAndPass(username string, password string) (Us
 	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 }
 
-func (p PGSQLProvider) validateUserAndPubKey(username string, publicKey string) (User, string, error) {
+func (p PGSQLProvider) validateUserAndPubKey(username string, publicKey []byte) (User, string, error) {
 	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 }
 

+ 1 - 1
dataprovider/sqlcommon.go

@@ -44,7 +44,7 @@ func sqlCommonValidateUserAndPass(username string, password string, dbHandle *sq
 	return checkUserAndPass(user, password)
 }
 
-func sqlCommonValidateUserAndPubKey(username string, pubKey string, dbHandle *sql.DB) (User, string, error) {
+func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sql.DB) (User, string, error) {
 	var user User
 	if len(pubKey) == 0 {
 		return user, "", errors.New("Credentials cannot be null or empty")

+ 1 - 1
dataprovider/sqlite.go

@@ -63,7 +63,7 @@ func (p SQLiteProvider) validateUserAndPass(username string, password string) (U
 	return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
 }
 
-func (p SQLiteProvider) validateUserAndPubKey(username string, publicKey string) (User, string, error) {
+func (p SQLiteProvider) validateUserAndPubKey(username string, publicKey []byte) (User, string, error) {
 	return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
 }
 

+ 61 - 3
dataprovider/user.go

@@ -50,6 +50,8 @@ const (
 	SSHLoginMethodPublicKey           = "publickey"
 	SSHLoginMethodPassword            = "password"
 	SSHLoginMethodKeyboardInteractive = "keyboard-interactive"
+	SSHLoginMethodKeyAndPassword      = "publickey+password"
+	SSHLoginMethodKeyAndKeyboardInt   = "publickey+keyboard-interactive"
 )
 
 // ExtensionsFilter defines filters based on file extensions.
@@ -246,17 +248,73 @@ func (u *User) HasPerms(permissions []string, path string) bool {
 	return true
 }
 
-// IsLoginMethodAllowed returns true if the specified login method is allowed for the user
-func (u *User) IsLoginMethodAllowed(loginMetod string) bool {
+// IsLoginMethodAllowed returns true if the specified login method is allowed
+func (u *User) IsLoginMethodAllowed(loginMethod string, partialSuccessMethods []string) bool {
 	if len(u.Filters.DeniedLoginMethods) == 0 {
 		return true
 	}
-	if utils.IsStringInSlice(loginMetod, u.Filters.DeniedLoginMethods) {
+	if len(partialSuccessMethods) == 1 {
+		for _, method := range u.GetNextAuthMethods(partialSuccessMethods) {
+			if method == loginMethod {
+				return true
+			}
+		}
+	}
+	if utils.IsStringInSlice(loginMethod, u.Filters.DeniedLoginMethods) {
 		return false
 	}
 	return true
 }
 
+// GetNextAuthMethods returns the list of authentications methods that
+// can continue for multi-step authentication
+func (u *User) GetNextAuthMethods(partialSuccessMethods []string) []string {
+	var methods []string
+	if len(partialSuccessMethods) != 1 {
+		return methods
+	}
+	if partialSuccessMethods[0] != SSHLoginMethodPublicKey {
+		return methods
+	}
+	for _, method := range u.GetAllowedLoginMethods() {
+		if method == SSHLoginMethodKeyAndPassword {
+			methods = append(methods, SSHLoginMethodPassword)
+		}
+		if method == SSHLoginMethodKeyAndKeyboardInt {
+			methods = append(methods, SSHLoginMethodKeyboardInteractive)
+		}
+	}
+	return methods
+}
+
+// IsPartialAuth returns true if the specified login method is a step for
+// a multi-step Authentication.
+// We support publickey+password and publickey+keyboard-interactive, so
+// only publickey can returns partial success.
+// We can have partial success if only multi-step Auth methods are enabled
+func (u *User) IsPartialAuth(loginMethod string) bool {
+	if loginMethod != SSHLoginMethodPublicKey {
+		return false
+	}
+	for _, method := range u.GetAllowedLoginMethods() {
+		if !utils.IsStringInSlice(method, SSHMultiStepsLoginMethods) {
+			return false
+		}
+	}
+	return true
+}
+
+// GetAllowedLoginMethods returns the allowed login methods
+func (u *User) GetAllowedLoginMethods() []string {
+	var allowedMethods []string
+	for _, method := range ValidSSHLoginMethods {
+		if !utils.IsStringInSlice(method, u.Filters.DeniedLoginMethods) {
+			allowedMethods = append(allowedMethods, method)
+		}
+	}
+	return allowedMethods
+}
+
 // IsFileAllowed returns true if the specified file is allowed by the file restrictions filters
 func (u *User) IsFileAllowed(sftpPath string) bool {
 	if len(u.Filters.FileExtensions) == 0 {

+ 3 - 1
docs/account.md

@@ -30,10 +30,12 @@ For each account, the following properties can be configured:
 - `download_bandwidth` maximum download bandwidth as KB/s, 0 means unlimited.
 - `allowed_ip`, List of IP/Mask allowed to login. Any IP address not contained in this list cannot login. IP/Mask must be in CIDR notation as defined in RFC 4632 and RFC 4291, for example "192.0.2.0/24" or "2001:db8::/32"
 - `denied_ip`, List of IP/Mask not allowed to login. If an IP address is both allowed and denied then login will be denied
-- `denied_login_methods`, List of login methods not allowed. The following login methods are supported:
+- `denied_login_methods`, List of login methods not allowed. To enable multi-step authentication you have to allow only multi-step login methods. The following login methods are supported:
   - `publickey`
   - `password`
   - `keyboard-interactive`
+  - `publickey+password`
+  - `publickey+keyboard-interactive`
 - `file_extensions`, list of struct. These restrictions do not apply to files listing for performance reasons, so a denied file cannot be downloaded/overwritten/renamed but it will still be listed in the list of files. Please note that these restrictions can be easily bypassed. Each struct contains the following fields:
   - `allowed_extensions`, list of, case insensitive, allowed files extension. Shell like expansion is not supported so you have to specify `.jpg` and not `*.jpg`. Any file that does not end with this suffix will be denied
   - `denied_extensions`, list of, case insensitive, denied files extension. Denied file extensions are evaluated before the allowed ones

+ 1 - 1
docs/metrics.md

@@ -10,7 +10,7 @@ Several counters and gauges are available, for example:
 - Total SSH command errors
 - Number of active connections
 - Data provider availability
-- Total successful and failed logins using password, public key or keyboard interactive authentication
+- Total successful and failed logins using password, public key, keyboard interactive authentication or supported multi-step authentications
 - Total HTTP requests served and totals for response code
 - Go's runtime details about GC, number of gouroutines and OS threads
 - Process information like CPU, memory, file descriptor usage and start time

+ 2 - 2
go.mod

@@ -43,6 +43,6 @@ require (
 )
 
 replace (
-	github.com/eikenb/pipeat => github.com/drakkan/pipeat v0.0.0-20200315002837-010186aaa07d
-	golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20200403175740-a4c92a934d79
+	github.com/eikenb/pipeat => github.com/drakkan/pipeat v0.0.0-20200327213700-f3a27a751cdc
+	golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20200409210311-95730af1ff98
 )

+ 4 - 4
go.sum

@@ -61,10 +61,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
 github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
-github.com/drakkan/crypto v0.0.0-20200403175740-a4c92a934d79 h1:dARy4jBz2xeVoGq3CTZ8dsQTXB4G833M5cIq7rMtEuw=
-github.com/drakkan/crypto v0.0.0-20200403175740-a4c92a934d79/go.mod h1:v3bhWOXGYda7H5d2s5t9XA6th3fxW3s0MQxU1R96G/w=
-github.com/drakkan/pipeat v0.0.0-20200315002837-010186aaa07d h1:qD1b7ZnrTUscSof+W+Pa3D9hN4jmQ/UcoZ05q7W96rA=
-github.com/drakkan/pipeat v0.0.0-20200315002837-010186aaa07d/go.mod h1:wNYvIpR5rIhoezOYcpxcXz4HbIEOu7A45EqlQCA+h+w=
+github.com/drakkan/crypto v0.0.0-20200409210311-95730af1ff98 h1:5yTRWoqE1GhPfctw9E0wh1d3+xHajMyhQAQuxuNqwQ8=
+github.com/drakkan/crypto v0.0.0-20200409210311-95730af1ff98/go.mod h1:v3bhWOXGYda7H5d2s5t9XA6th3fxW3s0MQxU1R96G/w=
+github.com/drakkan/pipeat v0.0.0-20200327213700-f3a27a751cdc h1:aDQn2DUfujideR+efH3697MZk+9YY4T95P7BujKkWiY=
+github.com/drakkan/pipeat v0.0.0-20200327213700-f3a27a751cdc/go.mod h1:+JPhBw5JdJrSF80r6xsSg1TYHjyAGxYs4X24VyUdMZU=
 github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=

+ 1 - 2
httpd/httpd_test.go

@@ -327,8 +327,7 @@ func TestAddUserInvalidFilters(t *testing.T) {
 	if err != nil {
 		t.Errorf("unexpected error adding user with invalid filters: %v", err)
 	}
-	u.Filters.DeniedLoginMethods = []string{dataprovider.SSHLoginMethodKeyboardInteractive,
-		dataprovider.SSHLoginMethodPassword, dataprovider.SSHLoginMethodPublicKey}
+	u.Filters.DeniedLoginMethods = dataprovider.ValidSSHLoginMethods
 	_, _, err = httpd.AddUser(u, http.StatusBadRequest)
 	if err != nil {
 		t.Errorf("unexpected error adding user with invalid filters: %v", err)

+ 2 - 0
httpd/schema/openapi.yaml

@@ -933,6 +933,8 @@ components:
         - 'publickey'
         - 'password'
         - 'keyboard-interactive'
+        - 'publickey+password'
+        - 'publickey+keyboard-interactive'
     ExtensionsFilter:
       type: object
       properties:

+ 54 - 0
metrics/metrics.go

@@ -148,6 +148,48 @@ var (
 		Help: "The total number of failed logins using keyboard interactive authentication",
 	})
 
+	// totalKeyAndPasswordLoginAttempts is the metric that reports the total number of
+	// login attempts using public key + password multi steps auth
+	totalKeyAndPasswordLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{
+		Name: "sftpgo_key_and_password_login_attempts_total",
+		Help: "The total number of login attempts using public key + password",
+	})
+
+	// totalKeyAndPasswordLoginOK is the metric that reports the total number of
+	// successful logins using public key + password multi steps auth
+	totalKeyAndPasswordLoginOK = promauto.NewCounter(prometheus.CounterOpts{
+		Name: "sftpgo_key_and_password_login_ok_total",
+		Help: "The total number of successful logins using public key + password",
+	})
+
+	// totalKeyAndPasswordLoginFailed is the metric that reports the total number of
+	// failed logins using public key + password multi steps auth
+	totalKeyAndPasswordLoginFailed = promauto.NewCounter(prometheus.CounterOpts{
+		Name: "sftpgo_key_and_password_login_ko_total",
+		Help: "The total number of failed logins using  public key + password",
+	})
+
+	// totalKeyAndKeyIntLoginAttempts is the metric that reports the total number of
+	// login attempts using public key + keyboard interactive multi steps auth
+	totalKeyAndKeyIntLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{
+		Name: "sftpgo_key_and_keyboard_int_login_attempts_total",
+		Help: "The total number of login attempts using public key + keyboard interactive",
+	})
+
+	// totalKeyAndKeyIntLoginOK is the metric that reports the total number of
+	// successful logins using public key + keyboard interactive multi steps auth
+	totalKeyAndKeyIntLoginOK = promauto.NewCounter(prometheus.CounterOpts{
+		Name: "sftpgo_key_and_keyboard_int_login_ok_total",
+		Help: "The total number of successful logins using public key + keyboard interactive",
+	})
+
+	// totalKeyAndKeyIntLoginFailed is the metric that reports the total number of
+	// failed logins using public key + keyboard interactive multi steps auth
+	totalKeyAndKeyIntLoginFailed = promauto.NewCounter(prometheus.CounterOpts{
+		Name: "sftpgo_key_and_keyboard_int_login_ko_total",
+		Help: "The total number of failed logins using  public key + keyboard interactive",
+	})
+
 	totalHTTPRequests = promauto.NewCounter(prometheus.CounterOpts{
 		Name: "sftpgo_http_req_total",
 		Help: "The total number of HTTP requests served",
@@ -498,6 +540,10 @@ func AddLoginAttempt(authMethod string) {
 		totalKeyLoginAttempts.Inc()
 	case "keyboard-interactive":
 		totalInteractiveLoginAttempts.Inc()
+	case "publickey+password":
+		totalKeyAndPasswordLoginAttempts.Inc()
+	case "publickey+keyboard-interactive":
+		totalKeyAndKeyIntLoginAttempts.Inc()
 	default:
 		totalPasswordLoginAttempts.Inc()
 	}
@@ -512,6 +558,10 @@ func AddLoginResult(authMethod string, err error) {
 			totalKeyLoginOK.Inc()
 		case "keyboard-interactive":
 			totalInteractiveLoginOK.Inc()
+		case "publickey+password":
+			totalKeyAndPasswordLoginOK.Inc()
+		case "publickey+keyboard-interactive":
+			totalKeyAndKeyIntLoginOK.Inc()
 		default:
 			totalPasswordLoginOK.Inc()
 		}
@@ -522,6 +572,10 @@ func AddLoginResult(authMethod string, err error) {
 			totalKeyLoginFailed.Inc()
 		case "keyboard-interactive":
 			totalInteractiveLoginFailed.Inc()
+		case "publickey+password":
+			totalKeyAndPasswordLoginFailed.Inc()
+		case "publickey+keyboard-interactive":
+			totalKeyAndKeyIntLoginFailed.Inc()
 		default:
 			totalPasswordLoginFailed.Inc()
 		}

+ 2 - 1
scripts/sftpgo_api_cli.py

@@ -503,7 +503,8 @@ def addCommonUserArguments(parser):
 							'create_symlinks', 'chmod', 'chown', 'chtimes'], help='Permissions for the root directory '
 							+'(/). Default: %(default)s')
 	parser.add_argument('-L', '--denied-login-methods', type=str, nargs='+', default=[],
-					choices=['', 'publickey', 'password', 'keyboard-interactive'], help='Default: %(default)s')
+					choices=['', 'publickey', 'password', 'keyboard-interactive', 'publickey+password',
+							'publickey+keyboard-interactive'], help='Default: %(default)s')
 	parser.add_argument('--subdirs-permissions', type=str, nargs='*', default=[], help='Permissions for subdirs. '
 					+'For example: "/somedir::list,download" "/otherdir/subdir::*" Default: %(default)s')
 	parser.add_argument('--virtual-folders', type=str, nargs='*', default=[], help='Virtual folder mapping. For example: '

+ 1 - 1
sftpd/internal_test.go

@@ -484,7 +484,7 @@ func TestUploadFiles(t *testing.T) {
 func TestWithInvalidHome(t *testing.T) {
 	u := dataprovider.User{}
 	u.HomeDir = "home_rel_path"
-	_, err := loginUser(u, dataprovider.SSHLoginMethodPassword, "", "")
+	_, err := loginUser(u, dataprovider.SSHLoginMethodPassword, "", nil)
 	if err == nil {
 		t.Errorf("login a user with an invalid home_dir must fail")
 	}

+ 66 - 39
sftpd/server.go

@@ -158,42 +158,32 @@ func (c Configuration) Initialize(configDir string) error {
 			return sp, nil
 		},
 		PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
-			sp, err := c.validatePublicKeyCredentials(conn, string(pubKey.Marshal()))
+			sp, err := c.validatePublicKeyCredentials(conn, pubKey.Marshal())
+			if err == ssh.ErrPartialSuccess {
+				return nil, err
+			}
 			if err != nil {
 				return nil, &authenticationError{err: fmt.Sprintf("could not validate public key credentials: %v", err)}
 			}
 
 			return sp, nil
 		},
-		ServerVersion: "SSH-2.0-" + c.Banner,
+		NextAuthMethodsCallback: func(conn ssh.ConnMetadata) []string {
+			var nextMethods []string
+			user, err := dataprovider.UserExists(dataProvider, conn.User())
+			if err == nil {
+				nextMethods = user.GetNextAuthMethods(conn.PartialSuccessMethods())
+			}
+			return nextMethods
+		},
+		ServerVersion: fmt.Sprintf("SSH-2.0-%v", c.Banner),
 	}
 
-	err = c.checkHostKeys(configDir)
+	err = c.checkAndLoadHostKeys(configDir, serverConfig)
 	if err != nil {
 		return err
 	}
 
-	for _, k := range c.Keys {
-		privateFile := k.PrivateKey
-		if !filepath.IsAbs(privateFile) {
-			privateFile = filepath.Join(configDir, privateFile)
-		}
-		logger.Info(logSender, "", "Loading private key: %s", privateFile)
-
-		privateBytes, err := ioutil.ReadFile(privateFile)
-		if err != nil {
-			return err
-		}
-
-		private, err := ssh.ParsePrivateKey(privateBytes)
-		if err != nil {
-			return err
-		}
-
-		// Add private key to the server configuration.
-		serverConfig.AddHostKey(private)
-	}
-
 	c.configureSecurityOptions(serverConfig)
 	c.configureKeyboardInteractiveAuth(serverConfig)
 	c.configureLoginBanner(serverConfig, configDir)
@@ -285,9 +275,10 @@ func (c Configuration) configureLoginBanner(serverConfig *ssh.ServerConfig, conf
 		if !filepath.IsAbs(bannerFilePath) {
 			bannerFilePath = filepath.Join(configDir, bannerFilePath)
 		}
-		var banner []byte
-		banner, err = ioutil.ReadFile(bannerFilePath)
+		var bannerContent []byte
+		bannerContent, err = ioutil.ReadFile(bannerFilePath)
 		if err == nil {
+			banner := string(bannerContent)
 			serverConfig.BannerCallback = func(conn ssh.ConnMetadata) string {
 				return string(banner)
 			}
@@ -459,9 +450,13 @@ func (c Configuration) createHandler(connection Connection) sftp.Handlers {
 	}
 }
 
-func loginUser(user dataprovider.User, loginMethod, remoteAddr, publicKey string) (*ssh.Permissions, error) {
+func loginUser(user dataprovider.User, loginMethod, publicKey string, conn ssh.ConnMetadata) (*ssh.Permissions, error) {
+	connectionID := ""
+	if conn != nil {
+		connectionID = hex.EncodeToString(conn.SessionID())
+	}
 	if !filepath.IsAbs(user.HomeDir) {
-		logger.Warn(logSender, "", "user %#v has an invalid home dir: %#v. Home dir must be an absolute path, login not allowed",
+		logger.Warn(logSender, connectionID, "user %#v has an invalid home dir: %#v. Home dir must be an absolute path, login not allowed",
 			user.Username, user.HomeDir)
 		return nil, fmt.Errorf("cannot login user with invalid home dir: %#v", user.HomeDir)
 	}
@@ -473,18 +468,19 @@ func loginUser(user dataprovider.User, loginMethod, remoteAddr, publicKey string
 			return nil, fmt.Errorf("too many open sessions: %v", activeSessions)
 		}
 	}
-	if !user.IsLoginMethodAllowed(loginMethod) {
-		logger.Debug(logSender, "", "cannot login user %#v, login method %#v is not allowed", user.Username, loginMethod)
+	if !user.IsLoginMethodAllowed(loginMethod, conn.PartialSuccessMethods()) {
+		logger.Debug(logSender, connectionID, "cannot login user %#v, login method %#v is not allowed", user.Username, loginMethod)
 		return nil, fmt.Errorf("Login method %#v is not allowed for user %#v", loginMethod, user.Username)
 	}
+	remoteAddr := conn.RemoteAddr().String()
 	if !user.IsLoginFromAddrAllowed(remoteAddr) {
-		logger.Debug(logSender, "", "cannot login user %#v, remote address is not allowed: %v", user.Username, remoteAddr)
+		logger.Debug(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, remoteAddr)
 		return nil, fmt.Errorf("Login for user %#v is not allowed from this address: %v", user.Username, remoteAddr)
 	}
 
 	json, err := json.Marshal(user)
 	if err != nil {
-		logger.Warn(logSender, "", "error serializing user info: %v, authentication rejected", err)
+		logger.Warn(logSender, connectionID, "error serializing user info: %v, authentication rejected", err)
 		return nil, err
 	}
 	if len(publicKey) > 0 {
@@ -514,8 +510,8 @@ func (c *Configuration) checkSSHCommands() {
 	c.EnabledSSHCommands = sshCommands
 }
 
-// If no host keys are defined we try to use or generate the default one.
-func (c *Configuration) checkHostKeys(configDir string) error {
+// If no host keys are defined we try to use or generate the default ones.
+func (c *Configuration) checkAndLoadHostKeys(configDir string, serverConfig *ssh.ServerConfig) error {
 	if len(c.Keys) == 0 {
 		defaultKeys := []string{defaultPrivateRSAKeyName, defaultPrivateECDSAKeyName}
 		for _, k := range defaultKeys {
@@ -535,20 +531,45 @@ func (c *Configuration) checkHostKeys(configDir string) error {
 			c.Keys = append(c.Keys, Key{PrivateKey: k})
 		}
 	}
+	for _, k := range c.Keys {
+		privateFile := k.PrivateKey
+		if !filepath.IsAbs(privateFile) {
+			privateFile = filepath.Join(configDir, privateFile)
+		}
+		logger.Info(logSender, "", "Loading private key: %s", privateFile)
+
+		privateBytes, err := ioutil.ReadFile(privateFile)
+		if err != nil {
+			return err
+		}
+
+		private, err := ssh.ParsePrivateKey(privateBytes)
+		if err != nil {
+			return err
+		}
+
+		// Add private key to the server configuration.
+		serverConfig.AddHostKey(private)
+	}
 	return nil
 }
 
-func (c Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubKey string) (*ssh.Permissions, error) {
+func (c Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubKey []byte) (*ssh.Permissions, error) {
 	var err error
 	var user dataprovider.User
 	var keyID string
 	var sshPerm *ssh.Permissions
 
+	connectionID := hex.EncodeToString(conn.SessionID())
 	method := dataprovider.SSHLoginMethodPublicKey
-	metrics.AddLoginAttempt(method)
 	if user, keyID, err = dataprovider.CheckUserAndPubKey(dataProvider, conn.User(), pubKey); err == nil {
-		sshPerm, err = loginUser(user, method, conn.RemoteAddr().String(), keyID)
+		if user.IsPartialAuth(method) {
+			logger.Debug(logSender, connectionID, "user %#v authenticated with partial success", conn.User())
+			return nil, ssh.ErrPartialSuccess
+		}
+		sshPerm, err = loginUser(user, method, keyID, conn)
 	}
+	metrics.AddLoginAttempt(method)
 	if err != nil {
 		logger.ConnectionFailedLog(conn.User(), utils.GetIPFromRemoteAddress(conn.RemoteAddr().String()), method, err.Error())
 	}
@@ -562,9 +583,12 @@ func (c Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass [
 	var sshPerm *ssh.Permissions
 
 	method := dataprovider.SSHLoginMethodPassword
+	if len(conn.PartialSuccessMethods()) == 1 {
+		method = dataprovider.SSHLoginMethodKeyAndPassword
+	}
 	metrics.AddLoginAttempt(method)
 	if user, err = dataprovider.CheckUserAndPass(dataProvider, conn.User(), string(pass)); err == nil {
-		sshPerm, err = loginUser(user, method, conn.RemoteAddr().String(), "")
+		sshPerm, err = loginUser(user, method, "", conn)
 	}
 	if err != nil {
 		logger.ConnectionFailedLog(conn.User(), utils.GetIPFromRemoteAddress(conn.RemoteAddr().String()), method, err.Error())
@@ -579,9 +603,12 @@ func (c Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMetad
 	var sshPerm *ssh.Permissions
 
 	method := dataprovider.SSHLoginMethodKeyboardInteractive
+	if len(conn.PartialSuccessMethods()) == 1 {
+		method = dataprovider.SSHLoginMethodKeyAndKeyboardInt
+	}
 	metrics.AddLoginAttempt(method)
 	if user, err = dataprovider.CheckKeyboardInteractiveAuth(dataProvider, conn.User(), c.KeyboardInteractiveHook, client); err == nil {
-		sshPerm, err = loginUser(user, method, conn.RemoteAddr().String(), "")
+		sshPerm, err = loginUser(user, method, "", conn)
 	}
 	if err != nil {
 		logger.ConnectionFailedLog(conn.User(), utils.GetIPFromRemoteAddress(conn.RemoteAddr().String()), method, err.Error())

+ 276 - 0
sftpd/sftpd_test.go

@@ -999,6 +999,120 @@ func TestLogin(t *testing.T) {
 	os.RemoveAll(user.GetHomeDir())
 }
 
+func TestMultiStepLoginKeyAndPwd(t *testing.T) {
+	u := getTestUser(true)
+	u.Password = defaultPassword
+	u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, []string{
+		dataprovider.SSHLoginMethodKeyAndKeyboardInt,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}...)
+	user, _, err := httpd.AddUser(u, http.StatusOK)
+	if err != nil {
+		t.Errorf("unable to add user: %v", err)
+	}
+	_, err = getSftpClient(user, true)
+	if err == nil {
+		t.Error("login with public key is disallowed and must fail")
+	}
+	_, err = getSftpClient(user, true)
+	if err == nil {
+		t.Error("login with password is disallowed and must fail")
+	}
+	key, _ := ssh.ParsePrivateKey([]byte(testPrivateKey))
+	authMethods := []ssh.AuthMethod{
+		ssh.PublicKeys(key),
+		ssh.Password(defaultPassword),
+	}
+	client, err := getCustomAuthSftpClient(user, authMethods)
+	if err != nil {
+		t.Errorf("unable to create sftp client: %v", err)
+	} else {
+		defer client.Close()
+		_, err := client.Getwd()
+		if err != nil {
+			t.Errorf("unexpected error: %v", err)
+		}
+	}
+	authMethods = []ssh.AuthMethod{
+		ssh.Password(defaultPassword),
+		ssh.PublicKeys(key),
+	}
+	_, err = getCustomAuthSftpClient(user, authMethods)
+	if err == nil {
+		t.Error("multi step auth login with wrong order must fail")
+	}
+	_, err = httpd.RemoveUser(user, http.StatusOK)
+	if err != nil {
+		t.Errorf("unable to remove user: %v", err)
+	}
+	os.RemoveAll(user.GetHomeDir())
+}
+
+func TestMultiStepLoginKeyAndKeyInt(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("this test is not available on Windows")
+	}
+	u := getTestUser(true)
+	u.Password = defaultPassword
+	u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, []string{
+		dataprovider.SSHLoginMethodKeyAndPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}...)
+	user, _, err := httpd.AddUser(u, http.StatusOK)
+	if err != nil {
+		t.Errorf("unable to add user: %v", err)
+	}
+	ioutil.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 0, false, 1), 0755)
+	_, err = getSftpClient(user, true)
+	if err == nil {
+		t.Error("login with public key is disallowed and must fail")
+	}
+	key, _ := ssh.ParsePrivateKey([]byte(testPrivateKey))
+	authMethods := []ssh.AuthMethod{
+		ssh.PublicKeys(key),
+		ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
+			return []string{"1", "2"}, nil
+		}),
+	}
+	client, err := getCustomAuthSftpClient(user, authMethods)
+	if err != nil {
+		t.Errorf("unable to create sftp client: %v", err)
+	} else {
+		defer client.Close()
+		_, err := client.Getwd()
+		if err != nil {
+			t.Errorf("unexpected error: %v", err)
+		}
+	}
+	authMethods = []ssh.AuthMethod{
+		ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
+			return []string{"1", "2"}, nil
+		}),
+		ssh.PublicKeys(key),
+	}
+	_, err = getCustomAuthSftpClient(user, authMethods)
+	if err == nil {
+		t.Error("multi step auth login with wrong order must fail")
+	}
+	authMethods = []ssh.AuthMethod{
+		ssh.PublicKeys(key),
+		ssh.Password(defaultPassword),
+	}
+	_, err = getCustomAuthSftpClient(user, authMethods)
+	if err == nil {
+		t.Error("multi step auth login with wrong method must fail")
+	}
+	_, err = httpd.RemoveUser(user, http.StatusOK)
+	if err != nil {
+		t.Errorf("unable to remove user: %v", err)
+	}
+	os.RemoveAll(user.GetHomeDir())
+}
+
 func TestLoginUserStatus(t *testing.T) {
 	usePubKey := true
 	user, _, err := httpd.AddUser(getTestUser(usePubKey), http.StatusOK)
@@ -3823,6 +3937,151 @@ func TestFilterFileExtensions(t *testing.T) {
 	}
 }
 
+func TestUserAllowedLoginMethods(t *testing.T) {
+	user := getTestUser(true)
+	user.Filters.DeniedLoginMethods = dataprovider.ValidSSHLoginMethods
+	allowedMethods := user.GetAllowedLoginMethods()
+	if len(allowedMethods) != 0 {
+		t.Errorf("unexpected allowed methods: %+v", allowedMethods)
+	}
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}
+	allowedMethods = user.GetAllowedLoginMethods()
+	if len(allowedMethods) != 2 {
+		t.Errorf("unexpected allowed methods: %+v", allowedMethods)
+	}
+	if !utils.IsStringInSlice(dataprovider.SSHLoginMethodKeyAndKeyboardInt, allowedMethods) {
+		t.Errorf("unexpected allowed methods: %+v", allowedMethods)
+	}
+	if !utils.IsStringInSlice(dataprovider.SSHLoginMethodKeyAndPassword, allowedMethods) {
+		t.Errorf("unexpected allowed methods: %+v", allowedMethods)
+	}
+}
+
+func TestUserPartialAuth(t *testing.T) {
+	user := getTestUser(true)
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}
+	if user.IsPartialAuth(dataprovider.SSHLoginMethodPassword) {
+		t.Error("unexpected partial auth method")
+	}
+	if user.IsPartialAuth(dataprovider.SSHLoginMethodKeyboardInteractive) {
+		t.Error("unexpected partial auth method")
+	}
+	if !user.IsPartialAuth(dataprovider.SSHLoginMethodPublicKey) {
+		t.Error("public key must be a partial auth method with this configuration")
+	}
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}
+	if user.IsPartialAuth(dataprovider.SSHLoginMethodPublicKey) {
+		t.Error("public key must not be a partial auth method with this configuration")
+	}
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+	}
+	if user.IsPartialAuth(dataprovider.SSHLoginMethodPublicKey) {
+		t.Error("public key must not be a partial auth method with this configuration")
+	}
+}
+
+func TestUserGetNextAuthMethods(t *testing.T) {
+	user := getTestUser(true)
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}
+	methods := user.GetNextAuthMethods(nil)
+	if len(methods) != 0 {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	methods = user.GetNextAuthMethods([]string{dataprovider.SSHLoginMethodPassword})
+	if len(methods) != 0 {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	methods = user.GetNextAuthMethods([]string{dataprovider.SSHLoginMethodKeyboardInteractive})
+	if len(methods) != 0 {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	methods = user.GetNextAuthMethods([]string{
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	})
+	if len(methods) != 0 {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	methods = user.GetNextAuthMethods([]string{dataprovider.SSHLoginMethodPublicKey})
+	if len(methods) != 2 {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	if !utils.IsStringInSlice(dataprovider.SSHLoginMethodPassword, methods) {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	if !utils.IsStringInSlice(dataprovider.SSHLoginMethodKeyboardInteractive, methods) {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+		dataprovider.SSHLoginMethodKeyAndKeyboardInt,
+	}
+	methods = user.GetNextAuthMethods([]string{dataprovider.SSHLoginMethodPublicKey})
+	if len(methods) != 1 {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	if !utils.IsStringInSlice(dataprovider.SSHLoginMethodPassword, methods) {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+		dataprovider.SSHLoginMethodKeyAndPassword,
+	}
+	methods = user.GetNextAuthMethods([]string{dataprovider.SSHLoginMethodPublicKey})
+	if len(methods) != 1 {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+	if !utils.IsStringInSlice(dataprovider.SSHLoginMethodKeyboardInteractive, methods) {
+		t.Errorf("unexpected next auth methods: %+v", methods)
+	}
+}
+
+func TestUserIsLoginMethodAllowed(t *testing.T) {
+	user := getTestUser(true)
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPassword,
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}
+	if user.IsLoginMethodAllowed(dataprovider.SSHLoginMethodPassword, nil) {
+		t.Error("unexpected login method allowed")
+	}
+	if !user.IsLoginMethodAllowed(dataprovider.SSHLoginMethodPassword, []string{dataprovider.SSHLoginMethodPublicKey}) {
+		t.Error("unexpected login method denied")
+	}
+	if !user.IsLoginMethodAllowed(dataprovider.SSHLoginMethodKeyboardInteractive, []string{dataprovider.SSHLoginMethodPublicKey}) {
+		t.Error("unexpected login method denied")
+	}
+	user.Filters.DeniedLoginMethods = []string{
+		dataprovider.SSHLoginMethodPublicKey,
+		dataprovider.SSHLoginMethodKeyboardInteractive,
+	}
+	if !user.IsLoginMethodAllowed(dataprovider.SSHLoginMethodPassword, nil) {
+		t.Error("unexpected login method denied")
+	}
+}
+
 func TestUserEmptySubDirPerms(t *testing.T) {
 	user := getTestUser(true)
 	user.Permissions = make(map[string][]string)
@@ -5041,6 +5300,23 @@ func getKeyboardInteractiveSftpClient(user dataprovider.User, answers []string)
 	return sftpClient, err
 }
 
+func getCustomAuthSftpClient(user dataprovider.User, authMethods []ssh.AuthMethod) (*sftp.Client, error) {
+	var sftpClient *sftp.Client
+	config := &ssh.ClientConfig{
+		User: user.Username,
+		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
+			return nil
+		},
+		Auth: authMethods,
+	}
+	conn, err := ssh.Dial("tcp", sftpServerAddr, config)
+	if err != nil {
+		return sftpClient, err
+	}
+	sftpClient, err = sftp.NewClient(conn)
+	return sftpClient, err
+}
+
 func createTestFile(path string, size int64) error {
 	baseDir := filepath.Dir(path)
 	if _, err := os.Stat(baseDir); os.IsNotExist(err) {