Sfoglia il codice sorgente

lib/connections: Refactor connection loop (#7177)

This breaks out some methods from the connection loop to make it simpler
to manage and understand.

Some slight simplifications to remove the `seen` variable (we can filter
`nextDial` based on times are in the future or not, so we don't need to
track `seen`) and adding a minimum loop interval (5s) in case some
dialer goes haywire and requests a 0s redial interval or such.

Otherwise no significant behavioral changes.
Jakob Borg 5 anni fa
parent
commit
05f25e600e
1 ha cambiato i file con 171 aggiunte e 140 eliminazioni
  1. 171 140
      lib/connections/service.go

+ 171 - 140
lib/connections/service.go

@@ -10,6 +10,7 @@ import (
 	"context"
 	"crypto/tls"
 	"fmt"
+	"math"
 	"net"
 	"net/url"
 	"sort"
@@ -56,6 +57,9 @@ const (
 	perDeviceWarningIntv    = 15 * time.Minute
 	tlsHandshakeTimeout     = 10 * time.Second
 	minConnectionReplaceAge = 10 * time.Second
+	minConnectionLoopSleep  = 5 * time.Second
+	stdConnectionLoopSleep  = time.Minute
+	worstDialerPriority     = math.MaxInt32
 )
 
 // From go/src/crypto/tls/cipher_suites.go
@@ -342,165 +346,197 @@ func (s *service) handle(ctx context.Context) error {
 }
 
 func (s *service) connect(ctx context.Context) error {
-	nextDial := make(map[string]time.Time)
+	// Map of when to earliest dial each given device + address again
+	nextDialAt := make(map[string]time.Time)
 
-	// Used as delay for the first few connection attempts, increases
-	// exponentially
+	// Used as delay for the first few connection attempts (adjusted up to
+	// minConnectionLoopSleep), increased exponentially until it reaches
+	// stdConnectionLoopSleep, at which time the normal sleep mechanism
+	// kicks in.
 	initialRampup := time.Second
 
-	// Calculated from actual dialers reconnectInterval
-	var sleep time.Duration
-
 	for {
 		cfg := s.cfg.RawCopy()
+		bestDialerPriority := s.bestDialerPriority(cfg)
+		isInitialRampup := initialRampup < stdConnectionLoopSleep
 
-		bestDialerPrio := 1<<31 - 1 // worse prio won't build on 32 bit
-		for _, df := range dialers {
-			if df.Valid(cfg) != nil {
-				continue
-			}
-			if prio := df.Priority(); prio < bestDialerPrio {
-				bestDialerPrio = prio
-			}
+		l.Debugln("Connection loop")
+		if isInitialRampup {
+			l.Debugln("Connection loop in initial rampup")
 		}
 
-		l.Debugln("Reconnect loop")
-
+		// Used for consistency throughout this loop run, as time passes
+		// while we try connections etc.
 		now := time.Now()
-		var seen []string
-
-		for _, deviceCfg := range cfg.Devices {
-			select {
-			case <-ctx.Done():
-				return ctx.Err()
-			default:
-			}
-
-			deviceID := deviceCfg.DeviceID
-			if deviceID == s.myID {
-				continue
-			}
 
-			if deviceCfg.Paused {
-				continue
-			}
+		// Attempt to dial all devices that are unconnected or can be connection-upgraded
+		s.dialDevices(ctx, now, cfg, bestDialerPriority, nextDialAt, isInitialRampup)
 
-			ct, connected := s.model.Connection(deviceID)
+		var sleep time.Duration
+		if isInitialRampup {
+			// We are in the initial rampup time, so we slowly, statically
+			// increase the sleep time.
+			sleep = initialRampup
+			initialRampup *= 2
+		} else {
+			// The sleep time is until the next dial scheduled in nextDialAt,
+			// clamped by stdConnectionLoopSleep as we don't want to sleep too
+			// long (config changes might happen).
+			sleep = filterAndFindSleepDuration(nextDialAt, now)
+		}
 
-			if connected && ct.Priority() == bestDialerPrio {
-				// Things are already as good as they can get.
-				continue
-			}
+		// ... while making sure not to loop too quickly either.
+		if sleep < minConnectionLoopSleep {
+			sleep = minConnectionLoopSleep
+		}
 
-			var addrs []string
-			for _, addr := range deviceCfg.Addresses {
-				if addr == "dynamic" {
-					if s.discoverer != nil {
-						if t, err := s.discoverer.Lookup(ctx, deviceID); err == nil {
-							addrs = append(addrs, t...)
-						}
-					}
-				} else {
-					addrs = append(addrs, addr)
-				}
-			}
+		l.Debugln("Next connection loop in", sleep)
 
-			addrs = util.UniqueTrimmedStrings(addrs)
+		select {
+		case <-time.After(sleep):
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+	}
+}
 
-			l.Debugln("Reconnect loop for", deviceID, addrs)
+func (s *service) bestDialerPriority(cfg config.Configuration) int {
+	bestDialerPriority := worstDialerPriority
+	for _, df := range dialers {
+		if df.Valid(cfg) != nil {
+			continue
+		}
+		if prio := df.Priority(); prio < bestDialerPriority {
+			bestDialerPriority = prio
+		}
+	}
+	return bestDialerPriority
+}
 
-			dialTargets := make([]dialTarget, 0)
+func (s *service) dialDevices(ctx context.Context, now time.Time, cfg config.Configuration, bestDialerPriority int, nextDialAt map[string]time.Time, initial bool) {
+	for _, deviceCfg := range cfg.Devices {
+		// Don't attempt to connect to ourselves...
+		if deviceCfg.DeviceID == s.myID {
+			continue
+		}
 
-			for _, addr := range addrs {
-				// Use a special key that is more than just the address, as you might have two devices connected to the same relay
-				nextDialKey := deviceID.String() + "/" + addr
-				seen = append(seen, nextDialKey)
-				nextDialAt, ok := nextDial[nextDialKey]
-				if ok && initialRampup >= sleep && nextDialAt.After(now) {
-					l.Debugf("Not dialing %s via %v as sleep is %v, next dial is at %s and current time is %s", deviceID, addr, sleep, nextDialAt, now)
-					continue
-				}
-				// If we fail at any step before actually getting the dialer
-				// retry in a minute
-				nextDial[nextDialKey] = now.Add(time.Minute)
+		// Don't attempt to connect to paused devices...
+		if deviceCfg.Paused {
+			continue
+		}
 
-				uri, err := url.Parse(addr)
-				if err != nil {
-					s.setConnectionStatus(addr, err)
-					l.Infof("Parsing dialer address %s: %v", addr, err)
-					continue
-				}
+		// See if we are already connected and, if so, what our cutoff is
+		// for dialer priority.
+		priorityCutoff := worstDialerPriority
+		connection, connected := s.model.Connection(deviceCfg.DeviceID)
+		if connected {
+			priorityCutoff = connection.Priority()
+			if bestDialerPriority >= priorityCutoff {
+				// Our best dialer is not any better than what we already
+				// have, so nothing to do here.
+				continue
+			}
+		}
 
-				if len(deviceCfg.AllowedNetworks) > 0 {
-					if !IsAllowedNetwork(uri.Host, deviceCfg.AllowedNetworks) {
-						s.setConnectionStatus(addr, errors.New("network disallowed"))
-						l.Debugln("Network for", uri, "is disallowed")
-						continue
-					}
-				}
+		dialTargets := s.resolveDialTargets(ctx, now, cfg, deviceCfg, nextDialAt, initial, priorityCutoff)
+		if conn, ok := s.dialParallel(ctx, deviceCfg.DeviceID, dialTargets); ok {
+			s.conns <- conn
+		}
+	}
+}
 
-				dialerFactory, err := getDialerFactory(cfg, uri)
-				if err != nil {
-					s.setConnectionStatus(addr, err)
-				}
-				if errors.Is(err, errUnsupported) {
-					l.Debugf("Dialer for %v: %v", uri, err)
-					continue
-				} else if err != nil {
-					l.Infof("Dialer for %v: %v", uri, err)
-					continue
-				}
+func (s *service) resolveDialTargets(ctx context.Context, now time.Time, cfg config.Configuration, deviceCfg config.DeviceConfiguration, nextDialAt map[string]time.Time, initial bool, priorityCutoff int) []dialTarget {
+	deviceID := deviceCfg.DeviceID
 
-				priority := dialerFactory.Priority()
+	addrs := s.resolveDeviceAddrs(ctx, deviceCfg)
+	l.Debugln("Resolved device", deviceID, "addresses:", addrs)
 
-				if connected && priority >= ct.Priority() {
-					l.Debugf("Not dialing using %s as priority is less than current connection (%d >= %d)", dialerFactory, dialerFactory.Priority(), ct.Priority())
-					continue
-				}
+	dialTargets := make([]dialTarget, 0, len(addrs))
+	for _, addr := range addrs {
+		// Use a special key that is more than just the address, as you
+		// might have two devices connected to the same relay
+		nextDialKey := deviceID.String() + "/" + addr
+		when, ok := nextDialAt[nextDialKey]
+		if ok && !initial && when.After(now) {
+			l.Debugf("Not dialing %s via %v as it's not time yet", deviceID, addr)
+			continue
+		}
 
-				dialer := dialerFactory.New(s.cfg.Options(), s.tlsCfg)
-				nextDial[nextDialKey] = now.Add(dialer.RedialFrequency())
+		// If we fail at any step before actually getting the dialer
+		// retry in a minute
+		nextDialAt[nextDialKey] = now.Add(time.Minute)
 
-				// For LAN addresses, increase the priority so that we
-				// try these first.
-				switch {
-				case dialerFactory.AlwaysWAN():
-					// Do nothing.
-				case s.isLANHost(uri.Host):
-					priority -= 1
-				}
+		uri, err := url.Parse(addr)
+		if err != nil {
+			s.setConnectionStatus(addr, err)
+			l.Infof("Parsing dialer address %s: %v", addr, err)
+			continue
+		}
 
-				dialTargets = append(dialTargets, dialTarget{
-					addr:     addr,
-					dialer:   dialer,
-					priority: priority,
-					deviceID: deviceID,
-					uri:      uri,
-				})
+		if len(deviceCfg.AllowedNetworks) > 0 {
+			if !IsAllowedNetwork(uri.Host, deviceCfg.AllowedNetworks) {
+				s.setConnectionStatus(addr, errors.New("network disallowed"))
+				l.Debugln("Network for", uri, "is disallowed")
+				continue
 			}
+		}
 
-			conn, ok := s.dialParallel(ctx, deviceCfg.DeviceID, dialTargets)
-			if ok {
-				s.conns <- conn
-			}
+		dialerFactory, err := getDialerFactory(cfg, uri)
+		if err != nil {
+			s.setConnectionStatus(addr, err)
+		}
+		if errors.Is(err, errUnsupported) {
+			l.Debugf("Dialer for %v: %v", uri, err)
+			continue
+		} else if err != nil {
+			l.Infof("Dialer for %v: %v", uri, err)
+			continue
 		}
 
-		nextDial, sleep = filterAndFindSleepDuration(nextDial, seen, now)
+		priority := dialerFactory.Priority()
+		if priority >= priorityCutoff {
+			l.Debugf("Not dialing using %s as priority is not better than current connection (%d >= %d)", dialerFactory, dialerFactory.Priority(), priorityCutoff)
+			continue
+		}
 
-		if initialRampup < sleep {
-			l.Debugln("initial rampup; sleep", initialRampup, "and update to", initialRampup*2)
-			sleep = initialRampup
-			initialRampup *= 2
-		} else {
-			l.Debugln("sleep until next dial", sleep)
+		dialer := dialerFactory.New(s.cfg.Options(), s.tlsCfg)
+		nextDialAt[nextDialKey] = now.Add(dialer.RedialFrequency())
+
+		// For LAN addresses, increase the priority so that we
+		// try these first.
+		switch {
+		case dialerFactory.AlwaysWAN():
+			// Do nothing.
+		case s.isLANHost(uri.Host):
+			priority--
 		}
 
-		select {
-		case <-time.After(sleep):
-		case <-ctx.Done():
-			return ctx.Err()
+		dialTargets = append(dialTargets, dialTarget{
+			addr:     addr,
+			dialer:   dialer,
+			priority: priority,
+			deviceID: deviceID,
+			uri:      uri,
+		})
+	}
+
+	return dialTargets
+}
+
+func (s *service) resolveDeviceAddrs(ctx context.Context, cfg config.DeviceConfiguration) []string {
+	var addrs []string
+	for _, addr := range cfg.Addresses {
+		if addr == "dynamic" {
+			if s.discoverer != nil {
+				if t, err := s.discoverer.Lookup(ctx, cfg.DeviceID); err == nil {
+					addrs = append(addrs, t...)
+				}
+			}
+		} else {
+			addrs = append(addrs, addr)
 		}
 	}
+	return util.UniqueTrimmedStrings(addrs)
 }
 
 func (s *service) isLANHost(host string) bool {
@@ -778,24 +814,19 @@ func getListenerFactory(cfg config.Configuration, uri *url.URL) (listenerFactory
 	return listenerFactory, nil
 }
 
-func filterAndFindSleepDuration(nextDial map[string]time.Time, seen []string, now time.Time) (map[string]time.Time, time.Duration) {
-	newNextDial := make(map[string]time.Time)
-
-	for _, addr := range seen {
-		nextDialAt, ok := nextDial[addr]
-		if ok {
-			newNextDial[addr] = nextDialAt
+func filterAndFindSleepDuration(nextDialAt map[string]time.Time, now time.Time) time.Duration {
+	sleep := stdConnectionLoopSleep
+	for key, next := range nextDialAt {
+		if next.Before(now) {
+			// Expired entry, address was not seen in last pass(es)
+			delete(nextDialAt, key)
+			continue
 		}
-	}
-
-	min := time.Minute
-	for _, next := range newNextDial {
-		cur := next.Sub(now)
-		if cur < min {
-			min = cur
+		if cur := next.Sub(now); cur < sleep {
+			sleep = cur
 		}
 	}
-	return newNextDial, min
+	return sleep
 }
 
 func urlsToStrings(urls []*url.URL) []string {