Selaa lähdekoodia

lib/connections: Parallel dials in the same priority (fixes #4456)

Well Tested(TM)

Introduces a potential issue where we always pick some connectable but dodgy connection that breaks
soon after the TLS handshake.

GitHub-Pull-Request: https://github.com/syncthing/syncthing/pull/4489
AudriusButkevicius 8 vuotta sitten
vanhempi
sitoutus
aecd7c64ce
2 muutettua tiedostoa jossa 91 lisäystä ja 9 poistoa
  1. 77 9
      lib/connections/service.go
  2. 14 0
      lib/connections/structs.go

+ 77 - 9
lib/connections/service.go

@@ -12,6 +12,7 @@ import (
 	"fmt"
 	"net"
 	"net/url"
+	"sort"
 	"strings"
 	"time"
 
@@ -318,7 +319,6 @@ func (s *Service) connect() {
 		now := time.Now()
 		var seen []string
 
-	nextDevice:
 		for _, deviceCfg := range cfg.Devices {
 			deviceID := deviceCfg.DeviceID
 			if deviceID == s.myID {
@@ -357,6 +357,8 @@ func (s *Service) connect() {
 
 			seen = append(seen, addrs...)
 
+			dialTargets := make([]dialTarget, 0)
+
 			for _, addr := range addrs {
 				nextDialAt, ok := nextDial[addr]
 				if ok && initialRampup >= sleep && nextDialAt.After(now) {
@@ -390,23 +392,27 @@ func (s *Service) connect() {
 					continue
 				}
 
-				if priorityKnown && dialerFactory.Priority() >= ct.internalConn.priority {
+				prio := dialerFactory.Priority()
+
+				if priorityKnown && prio >= ct.internalConn.priority {
 					l.Debugf("Not dialing using %s as priority is less than current connection (%d >= %d)", dialerFactory, dialerFactory.Priority(), ct.internalConn.priority)
 					continue
 				}
 
 				dialer := dialerFactory.New(s.cfg, s.tlsCfg)
-				l.Debugln("dial", deviceCfg.DeviceID, uri)
 				nextDial[addr] = now.Add(dialer.RedialFrequency())
 
-				conn, err := dialer.Dial(deviceID, uri)
-				if err != nil {
-					l.Debugf("%v for %v at %v: %v", dialerFactory, deviceCfg.DeviceID, uri, err)
-					continue
-				}
+				dialTargets = append(dialTargets, dialTarget{
+					dialer:   dialer,
+					priority: prio,
+					deviceID: deviceID,
+					uri:      uri,
+				})
+			}
 
+			conn, ok := dialParallel(deviceCfg.DeviceID, dialTargets)
+			if ok {
 				s.conns <- conn
-				continue nextDevice
 			}
 		}
 
@@ -710,3 +716,65 @@ func IsAllowedNetwork(host string, allowed []string) bool {
 
 	return false
 }
+
+func dialParallel(deviceID protocol.DeviceID, dialTargets []dialTarget) (internalConn, bool) {
+	// Group targets into buckets by priority
+	dialTargetBuckets := make(map[int][]dialTarget, len(dialTargets))
+	for _, tgt := range dialTargets {
+		dialTargetBuckets[tgt.priority] = append(dialTargetBuckets[tgt.priority], tgt)
+	}
+
+	// Get all available priorities
+	priorities := make([]int, 0, len(dialTargetBuckets))
+	for prio := range dialTargetBuckets {
+		priorities = append(priorities, prio)
+	}
+
+	// Sort the priorities so that we dial lowest first (which means highest...)
+	sort.Ints(priorities)
+
+	for _, prio := range priorities {
+		tgts := dialTargetBuckets[prio]
+		res := make(chan internalConn, len(tgts))
+		wg := sync.NewWaitGroup()
+		for _, tgt := range tgts {
+			wg.Add(1)
+			go func() {
+				conn, err := tgt.Dial()
+				if err == nil {
+					res <- conn
+				}
+				wg.Done()
+			}()
+		}
+
+		// Spawn a routine which will unblock main routine in case we fail
+		// to connect to anyone.
+		go func() {
+			wg.Wait()
+			close(res)
+		}()
+
+		// Wait for the first connection, or for channel closure.
+		conn, ok := <-res
+
+		// Got a connection, means more might come back, hence spawn a
+		// routine that will do the discarding.
+		if ok {
+			l.Debugln("connected to", deviceID, prio, "using", conn, conn.priority)
+			go func(deviceID protocol.DeviceID, prio int) {
+				wg.Wait()
+				l.Debugln("discarding", len(res), "connections while connecting to", deviceID, prio)
+				for conn := range res {
+					conn.Close()
+				}
+			}(deviceID, prio)
+		} else {
+			// Failed to connect, report that fact.
+			l.Debugln("failed to connect to", deviceID, prio)
+		}
+
+		return conn, ok
+	}
+	return internalConn{}, false
+}

+ 14 - 0
lib/connections/structs.go

@@ -177,3 +177,17 @@ func (o *onAddressesChangedNotifier) notifyAddressesChanged(l genericListener) {
 		callback(l)
 	}
 }
+
+type dialTarget struct {
+	dialer   genericDialer
+	priority int
+	uri      *url.URL
+	deviceID protocol.DeviceID
+}
+
+func (t dialTarget) Dial() (internalConn, error) {
+	l.Debugln("dialing", t.deviceID, t.uri, "prio", t.priority)
+	conn, err := t.dialer.Dial(t.deviceID, t.uri)
+	l.Debugln("dialing", t.deviceID, t.uri, "outcome", conn, err)
+	return conn, err
+}