Просмотр исходного кода

sftpd: use VerifiedPublicKeyCallback

Signed-off-by: Nicola Murino <[email protected]>
Nicola Murino 3 недель назад
Родитель
Сommit
f4092b9f9e
3 измененных файлов с 66 добавлено и 54 удалено
  1. 1 1
      go.mod
  2. 2 2
      go.sum
  3. 63 51
      internal/sftpd/server.go

+ 1 - 1
go.mod

@@ -64,7 +64,7 @@ require (
 	github.com/yl2chen/cidranger v1.0.3-0.20210928021809-d1cb2c52f37a
 	go.etcd.io/bbolt v1.4.3
 	gocloud.dev v0.43.0
-	golang.org/x/crypto v0.42.0
+	golang.org/x/crypto v0.42.1-0.20250927194341-2beaa59a3c99
 	golang.org/x/net v0.44.0
 	golang.org/x/oauth2 v0.31.0
 	golang.org/x/sys v0.36.0

+ 2 - 2
go.sum

@@ -420,8 +420,8 @@ golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf
 golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
 golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
 golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
-golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
-golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
+golang.org/x/crypto v0.42.1-0.20250927194341-2beaa59a3c99 h1:zYtc2MFSothOSrO593KfVa10ypmGMc4q31CirElMBFQ=
+golang.org/x/crypto v0.42.1-0.20250927194341-2beaa59a3c99/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
 golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
 golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=

+ 63 - 51
internal/sftpd/server.go

@@ -22,6 +22,7 @@ import (
 	"fmt"
 	"io"
 	"io/fs"
+	"maps"
 	"net"
 	"os"
 	"path/filepath"
@@ -32,7 +33,6 @@ import (
 	"time"
 
 	"github.com/pkg/sftp"
-	"github.com/rs/xid"
 	"github.com/sftpgo/sdk/plugin/notifier"
 	"golang.org/x/crypto/ssh"
 
@@ -52,6 +52,10 @@ const (
 	defaultPrivateEd25519KeyName      = "id_ed25519"
 	sourceAddressCriticalOption       = "source-address"
 	keyExchangeCurve25519SHA256LibSSH = "[email protected]"
+	extraDataPartialSuccessErrKey     = "partialSuccessErr"
+	extraDataUserKey                  = "user"
+	extraDataKeyIDKey                 = "keyID"
+	extraDataLoginMethodKey           = "login_method"
 )
 
 var (
@@ -233,10 +237,6 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
 		MaxAuthTries: c.MaxAuthTries,
 		PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
 			sp, err := c.validatePublicKeyCredentials(conn, pubKey)
-			var partialSuccess *ssh.PartialSuccessError
-			if errors.As(err, &partialSuccess) {
-				return sp, err
-			}
 			if err != nil {
 				return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err),
 					dataprovider.SSHLoginMethodPublicKey, conn.User())
@@ -244,6 +244,35 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
 
 			return sp, nil
 		},
+		VerifiedPublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey, permissions *ssh.Permissions, signatureAlgorithm string) (*ssh.Permissions, error) {
+			if partialErr, ok := permissions.ExtraData[extraDataPartialSuccessErrKey]; ok {
+				logger.Info(logSender, hex.EncodeToString(conn.SessionID()), "user %q authenticated with partial success, signature algorithm %q",
+					conn.User(), signatureAlgorithm)
+				return nil, partialErr.(error)
+			}
+			method := dataprovider.SSHLoginMethodPublicKey
+			user := permissions.ExtraData[extraDataUserKey].(dataprovider.User)
+			keyID := permissions.ExtraData[extraDataKeyIDKey].(string)
+			sshPerm, err := loginUser(&user, method, fmt.Sprintf("%s (%s)", keyID, signatureAlgorithm), conn)
+			if err == nil {
+				// if we have a SSH user cert we need to merge certificate permissions with our ones
+				// we only set Extensions, so CriticalOptions are always the ones from the certificate
+				sshPerm.CriticalOptions = permissions.CriticalOptions
+				if permissions.Extensions != nil {
+					if sshPerm.Extensions == nil {
+						sshPerm.Extensions = make(map[string]string)
+					}
+					maps.Copy(sshPerm.Extensions, permissions.Extensions)
+				}
+				if sshPerm.ExtraData == nil {
+					sshPerm.ExtraData = make(map[any]any)
+				}
+			}
+			user.Username = conn.User()
+			ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String())
+			updateLoginMetrics(&user, ipAddr, method, err)
+			return sshPerm, err
+		},
 		ServerVersion: fmt.Sprintf("SSH-2.0-%s", version.GetServerVersion("_", false)),
 	}
 
@@ -591,13 +620,9 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
 
 	defer sconn.Close()
 
-	var user dataprovider.User
-
-	// Unmarshal cannot fails here and even if it fails we'll have a user with no permissions
-	json.Unmarshal(util.StringToBytes(sconn.Permissions.Extensions["sftpgo_user"]), &user) //nolint:errcheck
-
-	loginType := sconn.Permissions.Extensions["sftpgo_login_method"]
-	connectionID := xid.New().String()
+	user := sconn.Permissions.ExtraData[extraDataUserKey].(dataprovider.User)
+	loginType := sconn.Permissions.ExtraData[extraDataLoginMethodKey].(string)
+	connectionID := hex.EncodeToString(sconn.SessionID())
 
 	defer user.CloseFs() //nolint:errcheck
 	if err = user.CheckFsRoot(connectionID); err != nil {
@@ -815,18 +840,13 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.
 		return nil, fmt.Errorf("login for user %q is not allowed from this address: %v", user.Username, remoteAddr)
 	}
 
-	json, err := json.Marshal(user)
-	if err != nil {
-		logger.Warn(logSender, connectionID, "error serializing user info: %v, authentication rejected", err)
-		return nil, err
-	}
 	if publicKey != "" {
 		loginMethod = fmt.Sprintf("%v: %v", loginMethod, publicKey)
 	}
 	p := &ssh.Permissions{}
-	p.Extensions = make(map[string]string)
-	p.Extensions["sftpgo_user"] = util.BytesToString(json)
-	p.Extensions["sftpgo_login_method"] = loginMethod
+	p.ExtraData = make(map[any]any)
+	p.ExtraData[extraDataUserKey] = *user
+	p.ExtraData[extraDataLoginMethodKey] = loginMethod
 	return p, nil
 }
 
@@ -1114,13 +1134,9 @@ func (c *Configuration) getPartialSuccessError(nextAuthMethods []string) error {
 }
 
 func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
-	var err error
 	var user dataprovider.User
-	var keyID string
-	var sshPerm *ssh.Permissions
 	var certPerm *ssh.Permissions
 
-	connectionID := hex.EncodeToString(conn.SessionID())
 	method := dataprovider.SSHLoginMethodPublicKey
 	ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String())
 	cert, ok := pubKey.(*ssh.Certificate)
@@ -1128,25 +1144,25 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK
 	if ok {
 		certFingerprint = ssh.FingerprintSHA256(cert.Key)
 		if cert.CertType != ssh.UserCert {
-			err = fmt.Errorf("ssh: cert has type %d", cert.CertType)
+			err := fmt.Errorf("ssh: cert has type %d", cert.CertType)
 			user.Username = conn.User()
 			updateLoginMetrics(&user, ipAddr, method, err)
 			return nil, err
 		}
 		if !c.certChecker.IsUserAuthority(cert.SignatureKey) {
-			err = errors.New("ssh: certificate signed by unrecognized authority")
+			err := errors.New("ssh: certificate signed by unrecognized authority")
 			user.Username = conn.User()
 			updateLoginMetrics(&user, ipAddr, method, err)
 			return nil, err
 		}
 		if len(cert.ValidPrincipals) == 0 {
-			err = fmt.Errorf("ssh: certificate %s has no valid principals, user: \"%s\"", certFingerprint, conn.User())
+			err := fmt.Errorf("ssh: certificate %s has no valid principals, user: \"%s\"", certFingerprint, conn.User())
 			user.Username = conn.User()
 			updateLoginMetrics(&user, ipAddr, method, err)
 			return nil, err
 		}
 		if revokedCertManager.isRevoked(certFingerprint) {
-			err = fmt.Errorf("ssh: certificate %s is revoked", certFingerprint)
+			err := fmt.Errorf("ssh: certificate %s is revoked", certFingerprint)
 			user.Username = conn.User()
 			updateLoginMetrics(&user, ipAddr, method, err)
 			return nil, err
@@ -1158,30 +1174,26 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK
 		}
 		certPerm = &cert.Permissions
 	}
-	if user, keyID, err = dataprovider.CheckUserAndPubKey(conn.User(), pubKey.Marshal(), ipAddr, common.ProtocolSSH, ok); err == nil {
-		if ok {
-			keyID = fmt.Sprintf("%s: ID: %s, serial: %v, CA %s %s", certFingerprint,
-				cert.KeyId, cert.Serial, cert.Type(), ssh.FingerprintSHA256(cert.SignatureKey))
-		}
-		if user.IsPartialAuth() {
-			logger.Debug(logSender, connectionID, "user %q authenticated with partial success", conn.User())
-			return certPerm, c.getPartialSuccessError(user.GetNextAuthMethods())
-		}
-		sshPerm, err = loginUser(&user, method, keyID, conn)
-		if err == nil && certPerm != nil {
-			// if we have a SSH user cert we need to merge certificate permissions with our ones
-			// we only set Extensions, so CriticalOptions are always the ones from the certificate
-			sshPerm.CriticalOptions = certPerm.CriticalOptions
-			if certPerm.Extensions != nil {
-				for k, v := range certPerm.Extensions {
-					sshPerm.Extensions[k] = v
-				}
-			}
-		}
+	user, keyID, err := dataprovider.CheckUserAndPubKey(conn.User(), pubKey.Marshal(), ipAddr, common.ProtocolSSH, ok)
+	if err != nil {
+		user.Username = conn.User()
+		updateLoginMetrics(&user, ipAddr, method, err)
+		return nil, err
 	}
-	user.Username = conn.User()
-	updateLoginMetrics(&user, ipAddr, method, err)
-	return sshPerm, err
+	if ok {
+		keyID = fmt.Sprintf("%s: ID: %s, serial: %v, CA %s %s", certFingerprint,
+			cert.KeyId, cert.Serial, cert.Type(), ssh.FingerprintSHA256(cert.SignatureKey))
+	}
+	if certPerm == nil {
+		certPerm = &ssh.Permissions{}
+	}
+	certPerm.ExtraData = make(map[any]any)
+	certPerm.ExtraData[extraDataKeyIDKey] = keyID
+	certPerm.ExtraData[extraDataUserKey] = user
+	if user.IsPartialAuth() {
+		certPerm.ExtraData[extraDataPartialSuccessErrKey] = c.getPartialSuccessError(user.GetNextAuthMethods())
+	}
+	return certPerm, nil
 }
 
 func (c *Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass []byte, method string) (*ssh.Permissions, error) {