Browse Source

lib/connections, lib/model: Improve new conn handling (#8253)

Simon Frei 3 years ago
parent
commit
072fa46bfd
2 changed files with 105 additions and 53 deletions
  1. 97 23
      lib/connections/service.go
  2. 8 30
      lib/model/model.go

+ 97 - 23
lib/connections/service.go

@@ -12,6 +12,7 @@ package connections
 import (
 	"context"
 	"crypto/tls"
+	"crypto/x509"
 	"fmt"
 	"math"
 	"net"
@@ -55,6 +56,13 @@ var (
 	errDisabled   = fmt.Errorf("%w: disabled by configuration", errUnsupported)
 	errDeprecated = fmt.Errorf("%w: deprecated", errUnsupported)
 	errNotInBuild = fmt.Errorf("%w: disabled at build time", errUnsupported)
+
+	// Various reasons to reject a connection
+	errNetworkNotAllowed      = errors.New("network not allowed")
+	errDeviceAlreadyConnected = errors.New("already connected to this device")
+	errDeviceIgnored          = errors.New("device is ignored")
+	errConnLimitReached       = errors.New("connection limit reached")
+	errDevicePaused           = errors.New("device is paused")
 )
 
 const (
@@ -128,6 +136,14 @@ type ConnectionStatusEntry struct {
 	Error *string   `json:"error"`
 }
 
+type connWithHello struct { // A connection bundled with the outcome of its hello exchange; sent over service.hellos.
+	c          internalConn      // the connection the hello exchange was run on
+	hello      protocol.Hello    // hello value returned by protocol.ExchangeHello
+	err        error             // error returned by protocol.ExchangeHello, checked by handleHellos
+	remoteID   protocol.DeviceID // device ID handed over by handleConns
+	remoteCert *x509.Certificate // certificate value handed over by handleConns
+}
+
 type service struct {
 	*suture.Supervisor
 	connectionStatusHandler
@@ -138,6 +154,7 @@ type service struct {
 	tlsCfg               *tls.Config
 	discoverer           discover.Finder
 	conns                chan internalConn
+	hellos               chan *connWithHello
 	bepProtocolName      string
 	tlsDefaultCommonName string
 	limiter              *limiter
@@ -194,7 +211,8 @@ func NewService(cfg config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *t
 	// incoming or outgoing.
 
 	service.Add(svcutil.AsService(service.connect, fmt.Sprintf("%s/connect", service)))
-	service.Add(svcutil.AsService(service.handle, fmt.Sprintf("%s/handle", service)))
+	service.Add(svcutil.AsService(service.handleConns, fmt.Sprintf("%s/handleConns", service)))
+	service.Add(svcutil.AsService(service.handleHellos, fmt.Sprintf("%s/handleHellos", service)))
 	service.Add(service.natService)
 
 	svcutil.OnSupervisorDone(service.Supervisor, func() {
@@ -205,7 +223,7 @@ func NewService(cfg config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *t
 	return service
 }
 
-func (s *service) handle(ctx context.Context) error {
+func (s *service) handleConns(ctx context.Context) error {
 	var c internalConn
 	for {
 		select {
@@ -245,8 +263,84 @@ func (s *service) handle(ctx context.Context) error {
 			continue
 		}
 
+		if err := s.connectionCheckEarly(remoteID, c); err != nil {
+			l.Infof("Connection from %s at %s (%s) rejected: %v", remoteID, c.RemoteAddr(), c.Type(), err)
+			c.Close()
+			continue
+		}
+
 		_ = c.SetDeadline(time.Now().Add(20 * time.Second))
-		hello, err := protocol.ExchangeHello(c, s.model.GetHello(remoteID))
+		go func() {
+			hello, err := protocol.ExchangeHello(c, s.model.GetHello(remoteID))
+			select {
+			case s.hellos <- &connWithHello{c, hello, err, remoteID, remoteCert}:
+			case <-ctx.Done():
+			}
+		}()
+	}
+}
+
+func (s *service) connectionCheckEarly(remoteID protocol.DeviceID, c internalConn) error { // Checks run before the hello exchange; a non-nil error makes the caller reject and close c.
+	if s.cfg.IgnoredDevice(remoteID) {
+		return errDeviceIgnored
+	}
+
+	if max := s.cfg.Options().ConnectionLimitMax; max > 0 && s.model.NumConnections() >= max {
+		// A positive ConnectionLimitMax is configured and already reached.
+		return errConnLimitReached
+	}
+
+	cfg, ok := s.cfg.Device(remoteID)
+	if !ok {
+		// Device is not in the config: let the hello exchange go ahead to get information about it.
+		return nil
+	}
+
+	if cfg.Paused {
+		return errDevicePaused
+	}
+
+	if len(cfg.AllowedNetworks) > 0 && !IsAllowedNetwork(c.RemoteAddr().String(), cfg.AllowedNetworks) {
+		// AllowedNetworks is set for this device and the remote address is not in it.
+		return errNetworkNotAllowed
+	}
+
+	// Already connected to this device? A lower priority value is better, just like nice etc.
+	if ct, ok := s.model.Connection(remoteID); ok {
+		if ct.Priority() > c.priority || time.Since(ct.Statistics().StartedAt) > minConnectionReplaceAge { // existing conn has worse priority or is older than the replace age
+			l.Debugf("Switching connections %s (existing: %s new: %s)", remoteID, ct, c)
+		} else {
+			// We should not already be connected to the other party. TODO: This
+			// could use some better handling. If the old connection is dead but
+			// hasn't timed out yet we may want to drop *that* connection and keep
+			// this one. But in case we are two devices connecting to each other
+			// in parallel we don't want to do that or we end up with no
+			// connections still established...
+			return errDeviceAlreadyConnected
+		}
+	}
+
+	return nil
+}
+
+func (s *service) handleHellos(ctx context.Context) error {
+	var c internalConn
+	var hello protocol.Hello
+	var err error
+	var remoteID protocol.DeviceID
+	var remoteCert *x509.Certificate
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case withHello := <-s.hellos:
+			c = withHello.c
+			hello = withHello.hello
+			err = withHello.err
+			remoteID = withHello.remoteID
+			remoteCert = withHello.remoteCert
+		}
+
 		if err != nil {
 			if protocol.IsVersionMismatch(err) {
 				// The error will be a relatively user friendly description
@@ -279,25 +373,6 @@ func (s *service) handle(ctx context.Context) error {
 			continue
 		}
 
-		// If we have a relay connection, and the new incoming connection is
-		// not a relay connection, we should drop that, and prefer this one.
-		ct, connected := s.model.Connection(remoteID)
-
-		// Lower priority is better, just like nice etc.
-		if connected && (ct.Priority() > c.priority || time.Since(ct.Statistics().StartedAt) > minConnectionReplaceAge) {
-			l.Debugf("Switching connections %s (existing: %s new: %s)", remoteID, ct, c)
-		} else if connected {
-			// We should not already be connected to the other party. TODO: This
-			// could use some better handling. If the old connection is dead but
-			// hasn't timed out yet we may want to drop *that* connection and keep
-			// this one. But in case we are two devices connecting to each other
-			// in parallel we don't want to do that or we end up with no
-			// connections still established...
-			l.Infof("Connected to already connected device %s (existing: %s new: %s)", remoteID, ct, c)
-			c.Close()
-			continue
-		}
-
 		deviceCfg, ok := s.cfg.Device(remoteID)
 		if !ok {
 			l.Infof("Device %s removed from config during connection attempt at %s", remoteID, c)
@@ -346,7 +421,6 @@ func (s *service) handle(ctx context.Context) error {
 		continue
 	}
 }
-
 func (s *service) connect(ctx context.Context) error {
 	// Map of when to earliest dial each given device + address again
 	nextDialAt := make(nextDialRegistry)

+ 8 - 30
lib/model/model.go

@@ -177,15 +177,13 @@ var (
 )
 
 var (
-	errDeviceUnknown     = errors.New("unknown device")
-	errDevicePaused      = errors.New("device is paused")
-	errDeviceIgnored     = errors.New("device is ignored")
-	errDeviceRemoved     = errors.New("device has been removed")
-	ErrFolderPaused      = errors.New("folder is paused")
-	ErrFolderNotRunning  = errors.New("folder is not running")
-	ErrFolderMissing     = errors.New("no such folder")
-	errNetworkNotAllowed = errors.New("network not allowed")
-	errNoVersioner       = errors.New("folder has no versioner")
+	errDeviceUnknown    = errors.New("unknown device")
+	errDevicePaused     = errors.New("device is paused")
+	errDeviceRemoved    = errors.New("device has been removed")
+	ErrFolderPaused     = errors.New("folder is paused")
+	ErrFolderNotRunning = errors.New("folder is not running")
+	ErrFolderMissing    = errors.New("no such folder")
+	errNoVersioner      = errors.New("folder has no versioner")
 	// errors about why a connection is closed
 	errReplacingConnection                = errors.New("replacing connection")
 	errStopped                            = errors.New("Syncthing is being stopped")
@@ -2114,12 +2112,7 @@ func (m *model) setIgnores(cfg config.FolderConfiguration, content []string) err
 // This allows us to extract some information from the Hello message
 // and add it to a list of known devices ahead of any checks.
 func (m *model) OnHello(remoteID protocol.DeviceID, addr net.Addr, hello protocol.Hello) error {
-	if m.cfg.IgnoredDevice(remoteID) {
-		return errDeviceIgnored
-	}
-
-	cfg, ok := m.cfg.Device(remoteID)
-	if !ok {
+	if _, ok := m.cfg.Device(remoteID); !ok {
 		if err := m.db.AddOrUpdatePendingDevice(remoteID, hello.DeviceName, addr.String()); err != nil {
 			l.Warnf("Failed to persist pending device entry to database: %v", err)
 		}
@@ -2138,21 +2131,6 @@ func (m *model) OnHello(remoteID protocol.DeviceID, addr net.Addr, hello protoco
 		})
 		return errDeviceUnknown
 	}
-
-	if cfg.Paused {
-		return errDevicePaused
-	}
-
-	if len(cfg.AllowedNetworks) > 0 && !connections.IsAllowedNetwork(addr.String(), cfg.AllowedNetworks) {
-		// The connection is not from an allowed network.
-		return errNetworkNotAllowed
-	}
-
-	if max := m.cfg.Options().ConnectionLimitMax; max > 0 && m.NumConnections() >= max {
-		// We're not allowed to accept any more connections.
-		return errConnLimitReached
-	}
-
 	return nil
 }