Browse Source

try to fix a randomly failing test case

Signed-off-by: Nicola Murino <[email protected]>
Nicola Murino 2 years ago
parent
commit
324d695d93
1 changed files with 28 additions and 21 deletions
  1. 28 21
      internal/sftpd/server.go

+ 28 - 21
internal/sftpd/server.go

@@ -571,25 +571,6 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve
 	serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive)
 }
 
-func canAcceptConnection(ip string) bool {
-	if common.IsBanned(ip, common.ProtocolSSH) {
-		logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, ip %q is banned", ip)
-		return false
-	}
-	if err := common.Connections.IsNewConnectionAllowed(ip, common.ProtocolSSH); err != nil {
-		logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection not allowed from ip %q: %v", ip, err)
-		return false
-	}
-	_, err := common.LimitRate(common.ProtocolSSH, ip)
-	if err != nil {
-		return false
-	}
-	if err := common.Config.ExecutePostConnectHook(ip, common.ProtocolSSH); err != nil {
-		return false
-	}
-	return true
-}
-
 // AcceptInboundConnection handles an inbound connection to the server instance and determines if the request should be served or not.
 func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
 	defer func() {
@@ -618,6 +599,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
 	}
 	// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
 	conn.SetDeadline(time.Time{}) //nolint:errcheck
+	go ssh.DiscardRequests(reqs)
 
 	defer conn.Close()
 
@@ -632,6 +614,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
 	defer user.CloseFs() //nolint:errcheck
 	if err = user.CheckFsRoot(connectionID); err != nil {
 		logger.Warn(logSender, connectionID, "unable to check fs root for user %q: %v", user.Username, err)
+		go discardAllChannels(chans, "invalid root fs", connectionID)
 		return
 	}
 
@@ -645,8 +628,6 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
 
 	defer common.Connections.RemoveSSHConnection(connectionID)
 
-	go ssh.DiscardRequests(reqs)
-
 	channelCounter := int64(0)
 	for newChannel := range chans {
 		// If its not a session channel we just move on because its not something we
@@ -756,6 +737,32 @@ func (c *Configuration) createHandlers(connection *Connection) sftp.Handlers {
 	}
 }
 
+func canAcceptConnection(ip string) bool {
+	if common.IsBanned(ip, common.ProtocolSSH) {
+		logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, ip %q is banned", ip)
+		return false
+	}
+	if err := common.Connections.IsNewConnectionAllowed(ip, common.ProtocolSSH); err != nil {
+		logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection not allowed from ip %q: %v", ip, err)
+		return false
+	}
+	_, err := common.LimitRate(common.ProtocolSSH, ip)
+	if err != nil {
+		return false
+	}
+	if err := common.Config.ExecutePostConnectHook(ip, common.ProtocolSSH); err != nil {
+		return false
+	}
+	return true
+}
+
+func discardAllChannels(in <-chan ssh.NewChannel, message, connectionID string) {
+	for req := range in {
+		err := req.Reject(ssh.ConnectionFailed, message)
+		logger.Debug(logSender, connectionID, "discarded channel request, message %q err: %v", message, err)
+	}
+}
+
 func checkAuthError(ip string, err error) {
 	if authErrors, ok := err.(*ssh.ServerAuthError); ok {
 		// check public key auth errors here