|  | @@ -22,6 +22,7 @@ import (
 | 
	
		
			
				|  |  |  	"fmt"
 | 
	
		
			
				|  |  |  	"io"
 | 
	
		
			
				|  |  |  	"io/fs"
 | 
	
		
			
				|  |  | +	"maps"
 | 
	
		
			
				|  |  |  	"net"
 | 
	
		
			
				|  |  |  	"os"
 | 
	
		
			
				|  |  |  	"path/filepath"
 | 
	
	
		
			
				|  | @@ -32,7 +33,6 @@ import (
 | 
	
		
			
				|  |  |  	"time"
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	"github.com/pkg/sftp"
 | 
	
		
			
				|  |  | -	"github.com/rs/xid"
 | 
	
		
			
				|  |  |  	"github.com/sftpgo/sdk/plugin/notifier"
 | 
	
		
			
				|  |  |  	"golang.org/x/crypto/ssh"
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -52,6 +52,10 @@ const (
 | 
	
		
			
				|  |  |  	defaultPrivateEd25519KeyName      = "id_ed25519"
 | 
	
		
			
				|  |  |  	sourceAddressCriticalOption       = "source-address"
 | 
	
		
			
				|  |  |  	keyExchangeCurve25519SHA256LibSSH = "[email protected]"
 | 
	
		
			
				|  |  | +	extraDataPartialSuccessErrKey     = "partialSuccessErr"
 | 
	
		
			
				|  |  | +	extraDataUserKey                  = "user"
 | 
	
		
			
				|  |  | +	extraDataKeyIDKey                 = "keyID"
 | 
	
		
			
				|  |  | +	extraDataLoginMethodKey           = "login_method"
 | 
	
		
			
				|  |  |  )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  var (
 | 
	
	
		
			
				|  | @@ -233,10 +237,6 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
 | 
	
		
			
				|  |  |  		MaxAuthTries: c.MaxAuthTries,
 | 
	
		
			
				|  |  |  		PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
 | 
	
		
			
				|  |  |  			sp, err := c.validatePublicKeyCredentials(conn, pubKey)
 | 
	
		
			
				|  |  | -			var partialSuccess *ssh.PartialSuccessError
 | 
	
		
			
				|  |  | -			if errors.As(err, &partialSuccess) {
 | 
	
		
			
				|  |  | -				return sp, err
 | 
	
		
			
				|  |  | -			}
 | 
	
		
			
				|  |  |  			if err != nil {
 | 
	
		
			
				|  |  |  				return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err),
 | 
	
		
			
				|  |  |  					dataprovider.SSHLoginMethodPublicKey, conn.User())
 | 
	
	
		
			
				|  | @@ -244,6 +244,35 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  			return sp, nil
 | 
	
		
			
				|  |  |  		},
 | 
	
		
			
				|  |  | +		VerifiedPublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey, permissions *ssh.Permissions, signatureAlgorithm string) (*ssh.Permissions, error) {
 | 
	
		
			
				|  |  | +			if partialErr, ok := permissions.ExtraData[extraDataPartialSuccessErrKey]; ok {
 | 
	
		
			
				|  |  | +				logger.Info(logSender, hex.EncodeToString(conn.SessionID()), "user %q authenticated with partial success, signature algorithm %q",
 | 
	
		
			
				|  |  | +					conn.User(), signatureAlgorithm)
 | 
	
		
			
				|  |  | +				return nil, partialErr.(error)
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +			method := dataprovider.SSHLoginMethodPublicKey
 | 
	
		
			
				|  |  | +			user := permissions.ExtraData[extraDataUserKey].(dataprovider.User)
 | 
	
		
			
				|  |  | +			keyID := permissions.ExtraData[extraDataKeyIDKey].(string)
 | 
	
		
			
				|  |  | +			sshPerm, err := loginUser(&user, method, fmt.Sprintf("%s (%s)", keyID, signatureAlgorithm), conn)
 | 
	
		
			
				|  |  | +			if err == nil {
 | 
	
		
			
				|  |  | +				// if we have a SSH user cert we need to merge certificate permissions with our ones
 | 
	
		
			
				|  |  | +				// we only set Extensions, so CriticalOptions are always the ones from the certificate
 | 
	
		
			
				|  |  | +				sshPerm.CriticalOptions = permissions.CriticalOptions
 | 
	
		
			
				|  |  | +				if permissions.Extensions != nil {
 | 
	
		
			
				|  |  | +					if sshPerm.Extensions == nil {
 | 
	
		
			
				|  |  | +						sshPerm.Extensions = make(map[string]string)
 | 
	
		
			
				|  |  | +					}
 | 
	
		
			
				|  |  | +					maps.Copy(sshPerm.Extensions, permissions.Extensions)
 | 
	
		
			
				|  |  | +				}
 | 
	
		
			
				|  |  | +				if sshPerm.ExtraData == nil {
 | 
	
		
			
				|  |  | +					sshPerm.ExtraData = make(map[any]any)
 | 
	
		
			
				|  |  | +				}
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +			user.Username = conn.User()
 | 
	
		
			
				|  |  | +			ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String())
 | 
	
		
			
				|  |  | +			updateLoginMetrics(&user, ipAddr, method, err)
 | 
	
		
			
				|  |  | +			return sshPerm, err
 | 
	
		
			
				|  |  | +		},
 | 
	
		
			
				|  |  |  		ServerVersion: fmt.Sprintf("SSH-2.0-%s", version.GetServerVersion("_", false)),
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -591,13 +620,9 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	defer sconn.Close()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	var user dataprovider.User
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -	// Unmarshal cannot fails here and even if it fails we'll have a user with no permissions
 | 
	
		
			
				|  |  | -	json.Unmarshal(util.StringToBytes(sconn.Permissions.Extensions["sftpgo_user"]), &user) //nolint:errcheck
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -	loginType := sconn.Permissions.Extensions["sftpgo_login_method"]
 | 
	
		
			
				|  |  | -	connectionID := xid.New().String()
 | 
	
		
			
				|  |  | +	user := sconn.Permissions.ExtraData[extraDataUserKey].(dataprovider.User)
 | 
	
		
			
				|  |  | +	loginType := sconn.Permissions.ExtraData[extraDataLoginMethodKey].(string)
 | 
	
		
			
				|  |  | +	connectionID := hex.EncodeToString(sconn.SessionID())
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	defer user.CloseFs() //nolint:errcheck
 | 
	
		
			
				|  |  |  	if err = user.CheckFsRoot(connectionID); err != nil {
 | 
	
	
		
			
				|  | @@ -815,18 +840,13 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.
 | 
	
		
			
				|  |  |  		return nil, fmt.Errorf("login for user %q is not allowed from this address: %v", user.Username, remoteAddr)
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	json, err := json.Marshal(user)
 | 
	
		
			
				|  |  | -	if err != nil {
 | 
	
		
			
				|  |  | -		logger.Warn(logSender, connectionID, "error serializing user info: %v, authentication rejected", err)
 | 
	
		
			
				|  |  | -		return nil, err
 | 
	
		
			
				|  |  | -	}
 | 
	
		
			
				|  |  |  	if publicKey != "" {
 | 
	
		
			
				|  |  |  		loginMethod = fmt.Sprintf("%v: %v", loginMethod, publicKey)
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  	p := &ssh.Permissions{}
 | 
	
		
			
				|  |  | -	p.Extensions = make(map[string]string)
 | 
	
		
			
				|  |  | -	p.Extensions["sftpgo_user"] = util.BytesToString(json)
 | 
	
		
			
				|  |  | -	p.Extensions["sftpgo_login_method"] = loginMethod
 | 
	
		
			
				|  |  | +	p.ExtraData = make(map[any]any)
 | 
	
		
			
				|  |  | +	p.ExtraData[extraDataUserKey] = *user
 | 
	
		
			
				|  |  | +	p.ExtraData[extraDataLoginMethodKey] = loginMethod
 | 
	
		
			
				|  |  |  	return p, nil
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -1114,13 +1134,9 @@ func (c *Configuration) getPartialSuccessError(nextAuthMethods []string) error {
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
 | 
	
		
			
				|  |  | -	var err error
 | 
	
		
			
				|  |  |  	var user dataprovider.User
 | 
	
		
			
				|  |  | -	var keyID string
 | 
	
		
			
				|  |  | -	var sshPerm *ssh.Permissions
 | 
	
		
			
				|  |  |  	var certPerm *ssh.Permissions
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	connectionID := hex.EncodeToString(conn.SessionID())
 | 
	
		
			
				|  |  |  	method := dataprovider.SSHLoginMethodPublicKey
 | 
	
		
			
				|  |  |  	ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String())
 | 
	
		
			
				|  |  |  	cert, ok := pubKey.(*ssh.Certificate)
 | 
	
	
		
			
				|  | @@ -1128,25 +1144,25 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK
 | 
	
		
			
				|  |  |  	if ok {
 | 
	
		
			
				|  |  |  		certFingerprint = ssh.FingerprintSHA256(cert.Key)
 | 
	
		
			
				|  |  |  		if cert.CertType != ssh.UserCert {
 | 
	
		
			
				|  |  | -			err = fmt.Errorf("ssh: cert has type %d", cert.CertType)
 | 
	
		
			
				|  |  | +			err := fmt.Errorf("ssh: cert has type %d", cert.CertType)
 | 
	
		
			
				|  |  |  			user.Username = conn.User()
 | 
	
		
			
				|  |  |  			updateLoginMetrics(&user, ipAddr, method, err)
 | 
	
		
			
				|  |  |  			return nil, err
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  |  		if !c.certChecker.IsUserAuthority(cert.SignatureKey) {
 | 
	
		
			
				|  |  | -			err = errors.New("ssh: certificate signed by unrecognized authority")
 | 
	
		
			
				|  |  | +			err := errors.New("ssh: certificate signed by unrecognized authority")
 | 
	
		
			
				|  |  |  			user.Username = conn.User()
 | 
	
		
			
				|  |  |  			updateLoginMetrics(&user, ipAddr, method, err)
 | 
	
		
			
				|  |  |  			return nil, err
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  |  		if len(cert.ValidPrincipals) == 0 {
 | 
	
		
			
				|  |  | -			err = fmt.Errorf("ssh: certificate %s has no valid principals, user: \"%s\"", certFingerprint, conn.User())
 | 
	
		
			
				|  |  | +			err := fmt.Errorf("ssh: certificate %s has no valid principals, user: \"%s\"", certFingerprint, conn.User())
 | 
	
		
			
				|  |  |  			user.Username = conn.User()
 | 
	
		
			
				|  |  |  			updateLoginMetrics(&user, ipAddr, method, err)
 | 
	
		
			
				|  |  |  			return nil, err
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  |  		if revokedCertManager.isRevoked(certFingerprint) {
 | 
	
		
			
				|  |  | -			err = fmt.Errorf("ssh: certificate %s is revoked", certFingerprint)
 | 
	
		
			
				|  |  | +			err := fmt.Errorf("ssh: certificate %s is revoked", certFingerprint)
 | 
	
		
			
				|  |  |  			user.Username = conn.User()
 | 
	
		
			
				|  |  |  			updateLoginMetrics(&user, ipAddr, method, err)
 | 
	
		
			
				|  |  |  			return nil, err
 | 
	
	
		
			
				|  | @@ -1158,30 +1174,26 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  |  		certPerm = &cert.Permissions
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	if user, keyID, err = dataprovider.CheckUserAndPubKey(conn.User(), pubKey.Marshal(), ipAddr, common.ProtocolSSH, ok); err == nil {
 | 
	
		
			
				|  |  | -		if ok {
 | 
	
		
			
				|  |  | -			keyID = fmt.Sprintf("%s: ID: %s, serial: %v, CA %s %s", certFingerprint,
 | 
	
		
			
				|  |  | -				cert.KeyId, cert.Serial, cert.Type(), ssh.FingerprintSHA256(cert.SignatureKey))
 | 
	
		
			
				|  |  | -		}
 | 
	
		
			
				|  |  | -		if user.IsPartialAuth() {
 | 
	
		
			
				|  |  | -			logger.Debug(logSender, connectionID, "user %q authenticated with partial success", conn.User())
 | 
	
		
			
				|  |  | -			return certPerm, c.getPartialSuccessError(user.GetNextAuthMethods())
 | 
	
		
			
				|  |  | -		}
 | 
	
		
			
				|  |  | -		sshPerm, err = loginUser(&user, method, keyID, conn)
 | 
	
		
			
				|  |  | -		if err == nil && certPerm != nil {
 | 
	
		
			
				|  |  | -			// if we have a SSH user cert we need to merge certificate permissions with our ones
 | 
	
		
			
				|  |  | -			// we only set Extensions, so CriticalOptions are always the ones from the certificate
 | 
	
		
			
				|  |  | -			sshPerm.CriticalOptions = certPerm.CriticalOptions
 | 
	
		
			
				|  |  | -			if certPerm.Extensions != nil {
 | 
	
		
			
				|  |  | -				for k, v := range certPerm.Extensions {
 | 
	
		
			
				|  |  | -					sshPerm.Extensions[k] = v
 | 
	
		
			
				|  |  | -				}
 | 
	
		
			
				|  |  | -			}
 | 
	
		
			
				|  |  | -		}
 | 
	
		
			
				|  |  | +	user, keyID, err := dataprovider.CheckUserAndPubKey(conn.User(), pubKey.Marshal(), ipAddr, common.ProtocolSSH, ok)
 | 
	
		
			
				|  |  | +	if err != nil {
 | 
	
		
			
				|  |  | +		user.Username = conn.User()
 | 
	
		
			
				|  |  | +		updateLoginMetrics(&user, ipAddr, method, err)
 | 
	
		
			
				|  |  | +		return nil, err
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	user.Username = conn.User()
 | 
	
		
			
				|  |  | -	updateLoginMetrics(&user, ipAddr, method, err)
 | 
	
		
			
				|  |  | -	return sshPerm, err
 | 
	
		
			
				|  |  | +	if ok {
 | 
	
		
			
				|  |  | +		keyID = fmt.Sprintf("%s: ID: %s, serial: %v, CA %s %s", certFingerprint,
 | 
	
		
			
				|  |  | +			cert.KeyId, cert.Serial, cert.Type(), ssh.FingerprintSHA256(cert.SignatureKey))
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	if certPerm == nil {
 | 
	
		
			
				|  |  | +		certPerm = &ssh.Permissions{}
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	certPerm.ExtraData = make(map[any]any)
 | 
	
		
			
				|  |  | +	certPerm.ExtraData[extraDataKeyIDKey] = keyID
 | 
	
		
			
				|  |  | +	certPerm.ExtraData[extraDataUserKey] = user
 | 
	
		
			
				|  |  | +	if user.IsPartialAuth() {
 | 
	
		
			
				|  |  | +		certPerm.ExtraData[extraDataPartialSuccessErrKey] = c.getPartialSuccessError(user.GetNextAuthMethods())
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	return certPerm, nil
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  func (c *Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass []byte, method string) (*ssh.Permissions, error) {
 |