Explorar el Código

lib/connections: Trigger dialer when connection gets closed (#7753)

Simon Frei hace 4 años
padre
commit
857caf3637

+ 65 - 0
lib/connections/connections_test.go

@@ -231,6 +231,71 @@ func TestConnectionStatus(t *testing.T) {
 	check(nil, nil)
 }
 
+func TestNextDialRegistryCleanup(t *testing.T) {
+	now := time.Now()
+	firsts := []time.Time{
+		now.Add(-dialCoolDownInterval + time.Second),
+		now.Add(-dialCoolDownDelay + time.Second),
+		now.Add(-2 * dialCoolDownDelay),
+	}
+
+	r := make(nextDialRegistry)
+
+	// Cases where the device should be cleaned up
+
+	r[protocol.LocalDeviceID] = nextDialDevice{}
+	r.sleepDurationAndCleanup(now)
+	if l := len(r); l > 0 {
+		t.Errorf("Expected empty to be cleaned up, got length %v", l)
+	}
+	for _, dev := range []nextDialDevice{
+		// attempts below threshold, outside of interval
+		{
+			attempts:              1,
+			coolDownIntervalStart: firsts[1],
+		},
+		{
+			attempts:              1,
+			coolDownIntervalStart: firsts[2],
+		},
+		// Threshold reached, but outside of cooldown delay
+		{
+			attempts:              dialCoolDownMaxAttemps,
+			coolDownIntervalStart: firsts[2],
+		},
+	} {
+		r[protocol.LocalDeviceID] = dev
+		r.sleepDurationAndCleanup(now)
+		if l := len(r); l > 0 {
+			t.Errorf("attempts: %v, start: %v: Expected all cleaned up, got length %v", dev.attempts, dev.coolDownIntervalStart, l)
+		}
+	}
+
+	// Cases where the device should stay monitored
+	for _, dev := range []nextDialDevice{
+		// attempts below threshold, inside of interval
+		{
+			attempts:              1,
+			coolDownIntervalStart: firsts[0],
+		},
+		// attempts at threshold, inside delay
+		{
+			attempts:              dialCoolDownMaxAttemps,
+			coolDownIntervalStart: firsts[0],
+		},
+		{
+			attempts:              dialCoolDownMaxAttemps,
+			coolDownIntervalStart: firsts[1],
+		},
+	} {
+		r[protocol.LocalDeviceID] = dev
+		r.sleepDurationAndCleanup(now)
+		if l := len(r); l != 1 {
+			t.Errorf("attempts: %v, start: %v: Expected device still tracked, got length %v", dev.attempts, dev.coolDownIntervalStart, l)
+		}
+	}
+}
+
 func BenchmarkConnections(pb *testing.B) {
 	addrs := []string{
 		"tcp://127.0.0.1:0",

+ 139 - 46
lib/connections/service.go

@@ -142,10 +142,13 @@ type service struct {
 	natService           *nat.Service
 	evLogger             events.Logger
 
-	deviceAddressesChanged chan struct{}
-	listenersMut           sync.RWMutex
-	listeners              map[string]genericListener
-	listenerTokens         map[string]suture.ServiceToken
+	dialNow           chan struct{}
+	dialNowDevices    map[protocol.DeviceID]struct{}
+	dialNowDevicesMut sync.Mutex
+
+	listenersMut   sync.RWMutex
+	listeners      map[string]genericListener
+	listenerTokens map[string]suture.ServiceToken
 }
 
 func NewService(cfg config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *tls.Config, discoverer discover.Finder, bepProtocolName string, tlsDefaultCommonName string, evLogger events.Logger) Service {
@@ -166,10 +169,13 @@ func NewService(cfg config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *t
 		natService:           nat.NewService(myID, cfg),
 		evLogger:             evLogger,
 
-		deviceAddressesChanged: make(chan struct{}, 1),
-		listenersMut:           sync.NewRWMutex(),
-		listeners:              make(map[string]genericListener),
-		listenerTokens:         make(map[string]suture.ServiceToken),
+		dialNowDevicesMut: sync.NewMutex(),
+		dialNow:           make(chan struct{}, 1),
+		dialNowDevices:    make(map[protocol.DeviceID]struct{}),
+
+		listenersMut:   sync.NewRWMutex(),
+		listeners:      make(map[string]genericListener),
+		listenerTokens: make(map[string]suture.ServiceToken),
 	}
 	cfg.Subscribe(service)
 
@@ -324,6 +330,13 @@ func (s *service) handle(ctx context.Context) error {
 		rd, wr := s.limiter.getLimiters(remoteID, c, isLAN)
 
 		protoConn := protocol.NewConnection(remoteID, rd, wr, c, s.model, c, deviceCfg.Compression, s.cfg.FolderPasswords(remoteID))
+		go func() {
+			<-protoConn.Closed()
+			s.dialNowDevicesMut.Lock()
+			s.dialNowDevices[remoteID] = struct{}{}
+			s.scheduleDialNow()
+			s.dialNowDevicesMut.Unlock()
+		}()
 
 		l.Infof("Established secure connection to %s at %s", remoteID, c)
 
@@ -334,7 +347,7 @@ func (s *service) handle(ctx context.Context) error {
 
 func (s *service) connect(ctx context.Context) error {
 	// Map of when to earliest dial each given device + address again
-	nextDialAt := make(map[string]time.Time)
+	nextDialAt := make(nextDialRegistry)
 
 	// Used as delay for the first few connection attempts (adjusted up to
 	// minConnectionLoopSleep), increased exponentially until it reaches
@@ -369,7 +382,7 @@ func (s *service) connect(ctx context.Context) error {
 			// The sleep time is until the next dial scheduled in nextDialAt,
 			// clamped by stdConnectionLoopSleep as we don't want to sleep too
 			// long (config changes might happen).
-			sleep = filterAndFindSleepDuration(nextDialAt, now)
+			sleep = nextDialAt.sleepDurationAndCleanup(now)
 		}
 
 		// ... while making sure not to loop too quickly either.
@@ -379,9 +392,20 @@ func (s *service) connect(ctx context.Context) error {
 
 		l.Debugln("Next connection loop in", sleep)
 
+		timeout := time.NewTimer(sleep)
 		select {
-		case <-s.deviceAddressesChanged:
-		case <-time.After(sleep):
+		case <-s.dialNow:
+			// Remove affected devices from nextDialAt to dial immediately,
+			// regardless of when we last dialed it (there's cool down in the
+			// registry for too many repeat dials).
+			s.dialNowDevicesMut.Lock()
+			for device := range s.dialNowDevices {
+				nextDialAt.redialDevice(device, now)
+			}
+			s.dialNowDevices = make(map[protocol.DeviceID]struct{})
+			s.dialNowDevicesMut.Unlock()
+			timeout.Stop()
+		case <-timeout.C:
 		case <-ctx.Done():
 			return ctx.Err()
 		}
@@ -401,7 +425,7 @@ func (s *service) bestDialerPriority(cfg config.Configuration) int {
 	return bestDialerPriority
 }
 
-func (s *service) dialDevices(ctx context.Context, now time.Time, cfg config.Configuration, bestDialerPriority int, nextDialAt map[string]time.Time, initial bool) {
+func (s *service) dialDevices(ctx context.Context, now time.Time, cfg config.Configuration, bestDialerPriority int, nextDialAt nextDialRegistry, initial bool) {
 	// Figure out current connection limits up front to see if there's any
 	// point in resolving devices and such at all.
 	allowAdditional := 0 // no limit
@@ -477,7 +501,7 @@ func (s *service) dialDevices(ctx context.Context, now time.Time, cfg config.Con
 	}
 }
 
-func (s *service) resolveDialTargets(ctx context.Context, now time.Time, cfg config.Configuration, deviceCfg config.DeviceConfiguration, nextDialAt map[string]time.Time, initial bool, priorityCutoff int) []dialTarget {
+func (s *service) resolveDialTargets(ctx context.Context, now time.Time, cfg config.Configuration, deviceCfg config.DeviceConfiguration, nextDialAt nextDialRegistry, initial bool, priorityCutoff int) []dialTarget {
 	deviceID := deviceCfg.DeviceID
 
 	addrs := s.resolveDeviceAddrs(ctx, deviceCfg)
@@ -485,18 +509,16 @@ func (s *service) resolveDialTargets(ctx context.Context, now time.Time, cfg con
 
 	dialTargets := make([]dialTarget, 0, len(addrs))
 	for _, addr := range addrs {
-		// Use a special key that is more than just the address, as you
-		// might have two devices connected to the same relay
-		nextDialKey := deviceID.String() + "/" + addr
-		when, ok := nextDialAt[nextDialKey]
-		if ok && !initial && when.After(now) {
+		// Use both device and address, as you might have two devices connected
+		// to the same relay
+		if !initial && nextDialAt.get(deviceID, addr).After(now) {
 			l.Debugf("Not dialing %s via %v as it's not time yet", deviceID, addr)
 			continue
 		}
 
 		// If we fail at any step before actually getting the dialer
 		// retry in a minute
-		nextDialAt[nextDialKey] = now.Add(time.Minute)
+		nextDialAt.set(deviceID, addr, now.Add(time.Minute))
 
 		uri, err := url.Parse(addr)
 		if err != nil {
@@ -532,7 +554,7 @@ func (s *service) resolveDialTargets(ctx context.Context, now time.Time, cfg con
 		}
 
 		dialer := dialerFactory.New(s.cfg.Options(), s.tlsCfg)
-		nextDialAt[nextDialKey] = now.Add(dialer.RedialFrequency())
+		nextDialAt.set(deviceID, addr, now.Add(dialer.RedialFrequency()))
 
 		// For LAN addresses, increase the priority so that we
 		// try these first.
@@ -735,24 +757,24 @@ func (s *service) CommitConfiguration(from, to config.Configuration) bool {
 }
 
 func (s *service) checkAndSignalConnectLoopOnUpdatedDevices(from, to config.Configuration) {
-	oldDevices := make(map[protocol.DeviceID]config.DeviceConfiguration, len(from.Devices))
-	for _, dev := range from.Devices {
-		oldDevices[dev.DeviceID] = dev
-	}
-
+	oldDevices := from.DeviceMap()
 	for _, dev := range to.Devices {
 		oldDev, ok := oldDevices[dev.DeviceID]
 		if !ok || !util.EqualStrings(oldDev.Addresses, dev.Addresses) {
-			select {
-			case s.deviceAddressesChanged <- struct{}{}:
-			default:
-				// channel is blocked - a config update is already pending for the connection loop.
-			}
+			s.scheduleDialNow()
 			break
 		}
 	}
 }
 
+func (s *service) scheduleDialNow() {
+	select {
+	case s.dialNow <- struct{}{}:
+	default:
+		// channel is blocked - a config update is already pending for the connection loop.
+	}
+}
+
 func (s *service) AllAddresses() []string {
 	s.listenersMut.RLock()
 	var addrs []string
@@ -877,21 +899,6 @@ func getListenerFactory(cfg config.Configuration, uri *url.URL) (listenerFactory
 	return listenerFactory, nil
 }
 
-func filterAndFindSleepDuration(nextDialAt map[string]time.Time, now time.Time) time.Duration {
-	sleep := stdConnectionLoopSleep
-	for key, next := range nextDialAt {
-		if next.Before(now) {
-			// Expired entry, address was not seen in last pass(es)
-			delete(nextDialAt, key)
-			continue
-		}
-		if cur := next.Sub(now); cur < sleep {
-			sleep = cur
-		}
-	}
-	return sleep
-}
-
 func urlsToStrings(urls []*url.URL) []string {
 	strings := make([]string, len(urls))
 	for i, url := range urls {
@@ -1050,3 +1057,89 @@ func (s *service) validateIdentity(c internalConn, expectedID protocol.DeviceID)
 
 	return nil
 }
+
+type nextDialRegistry map[protocol.DeviceID]nextDialDevice
+
+type nextDialDevice struct {
+	nextDial              map[string]time.Time
+	coolDownIntervalStart time.Time
+	attempts              int
+}
+
+func (r nextDialRegistry) get(device protocol.DeviceID, addr string) time.Time {
+	return r[device].nextDial[addr]
+}
+
+const (
+	dialCoolDownInterval   = 2 * time.Minute
+	dialCoolDownDelay      = 5 * time.Minute
+	dialCoolDownMaxAttemps = 3
+)
+
+// redialDevice marks the device for immediate redial, unless the remote keeps
+// dropping established connections. Thus we keep track of when the first forced
+// re-dial happened, and how many attempts happen in the dialCoolDownInterval
+// after that. If it's more than dialCoolDownMaxAttempts, don't force-redial
+// that device for dialCoolDownDelay (regular dials still happen).
+func (r nextDialRegistry) redialDevice(device protocol.DeviceID, now time.Time) {
+	dev, ok := r[device]
+	if !ok {
+		r[device] = nextDialDevice{
+			coolDownIntervalStart: now,
+			attempts:              1,
+		}
+		return
+	}
+	if dev.attempts == 0 || now.Before(dev.coolDownIntervalStart.Add(dialCoolDownInterval)) {
+		if dev.attempts >= dialCoolDownMaxAttemps {
+			// Device has been force redialed too often - let it cool down.
+			return
+		}
+		if dev.attempts == 0 {
+			dev.coolDownIntervalStart = now
+		}
+		dev.attempts++
+		dev.nextDial = make(map[string]time.Time)
+		return
+	}
+	if dev.attempts >= dialCoolDownMaxAttemps && now.Before(dev.coolDownIntervalStart.Add(dialCoolDownDelay)) {
+		return // Still cooling down
+	}
+	delete(r, device)
+}
+
+func (r nextDialRegistry) set(device protocol.DeviceID, addr string, next time.Time) {
+	if _, ok := r[device]; !ok {
+		r[device] = nextDialDevice{nextDial: make(map[string]time.Time)}
+	}
+	r[device].nextDial[addr] = next
+}
+
+func (r nextDialRegistry) sleepDurationAndCleanup(now time.Time) time.Duration {
+	sleep := stdConnectionLoopSleep
+	for id, dev := range r {
+		for address, next := range dev.nextDial {
+			if next.Before(now) {
+				// Expired entry, address was not seen in last pass(es)
+				delete(dev.nextDial, address)
+				continue
+			}
+			if cur := next.Sub(now); cur < sleep {
+				sleep = cur
+			}
+		}
+		if dev.attempts > 0 {
+			interval := dialCoolDownInterval
+			if dev.attempts >= dialCoolDownMaxAttemps {
+				interval = dialCoolDownDelay
+			}
+			if now.After(dev.coolDownIntervalStart.Add(interval)) {
+				dev.attempts = 0
+			}
+		}
+		if len(dev.nextDial) == 0 && dev.attempts == 0 {
+			delete(r, id)
+		}
+	}
+	return sleep
+}

+ 7 - 1
lib/model/fakeconns_test.go

@@ -27,14 +27,18 @@ func newFakeConnection(id protocol.DeviceID, model Model) *fakeConnection {
 		Connection: new(protocolmocks.Connection),
 		id:         id,
 		model:      model,
+		closed:     make(chan struct{}),
 	}
 	f.RequestCalls(func(ctx context.Context, folder, name string, blockNo int, offset int64, size int, hash []byte, weakHash uint32, fromTemporary bool) ([]byte, error) {
 		return f.fileData[name], nil
 	})
 	f.IDReturns(id)
 	f.CloseCalls(func(err error) {
+		f.closeOnce.Do(func() {
+			close(f.closed)
+		})
 		model.Closed(id, err)
-		f.ClosedReturns(true)
+		f.ClosedReturns(f.closed)
 	})
 	return f
 }
@@ -47,6 +51,8 @@ type fakeConnection struct {
 	fileData                 map[string][]byte
 	folder                   string
 	model                    Model
+	closed                   chan struct{}
+	closeOnce                sync.Once
 	mut                      sync.Mutex
 }
 

+ 6 - 2
lib/model/model_test.go

@@ -2245,8 +2245,10 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) {
 		t.Error("not shared with device2")
 	}
 
-	if conn2.Closed() {
+	select {
+	case <-conn2.Closed():
 		t.Error("conn already closed")
+	default:
 	}
 
 	if _, err := wcfg.RemoveDevice(device2); err != nil {
@@ -2271,7 +2273,9 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) {
 		}
 	}
 
-	if !conn2.Closed() {
+	select {
+	case <-conn2.Closed():
+	default:
 		t.Error("connection not closed")
 	}
 

+ 1 - 1
lib/protocol/encryption.go

@@ -224,7 +224,7 @@ func (e encryptedConnection) Close(err error) {
 	e.conn.Close(err)
 }
 
-func (e encryptedConnection) Closed() bool {
+func (e encryptedConnection) Closed() <-chan struct{} {
 	return e.conn.Closed()
 }
 

+ 10 - 10
lib/protocol/mocks/connection.go

@@ -16,15 +16,15 @@ type Connection struct {
 	closeArgsForCall []struct {
 		arg1 error
 	}
-	ClosedStub        func() bool
+	ClosedStub        func() <-chan struct{}
 	closedMutex       sync.RWMutex
 	closedArgsForCall []struct {
 	}
 	closedReturns struct {
-		result1 bool
+		result1 <-chan struct{}
 	}
 	closedReturnsOnCall map[int]struct {
-		result1 bool
+		result1 <-chan struct{}
 	}
 	ClusterConfigStub        func(protocol.ClusterConfig)
 	clusterConfigMutex       sync.RWMutex
@@ -220,7 +220,7 @@ func (fake *Connection) CloseArgsForCall(i int) error {
 	return argsForCall.arg1
 }
 
-func (fake *Connection) Closed() bool {
+func (fake *Connection) Closed() <-chan struct{} {
 	fake.closedMutex.Lock()
 	ret, specificReturn := fake.closedReturnsOnCall[len(fake.closedArgsForCall)]
 	fake.closedArgsForCall = append(fake.closedArgsForCall, struct {
@@ -244,32 +244,32 @@ func (fake *Connection) ClosedCallCount() int {
 	return len(fake.closedArgsForCall)
 }
 
-func (fake *Connection) ClosedCalls(stub func() bool) {
+func (fake *Connection) ClosedCalls(stub func() <-chan struct{}) {
 	fake.closedMutex.Lock()
 	defer fake.closedMutex.Unlock()
 	fake.ClosedStub = stub
 }
 
-func (fake *Connection) ClosedReturns(result1 bool) {
+func (fake *Connection) ClosedReturns(result1 <-chan struct{}) {
 	fake.closedMutex.Lock()
 	defer fake.closedMutex.Unlock()
 	fake.ClosedStub = nil
 	fake.closedReturns = struct {
-		result1 bool
+		result1 <-chan struct{}
 	}{result1}
 }
 
-func (fake *Connection) ClosedReturnsOnCall(i int, result1 bool) {
+func (fake *Connection) ClosedReturnsOnCall(i int, result1 <-chan struct{}) {
 	fake.closedMutex.Lock()
 	defer fake.closedMutex.Unlock()
 	fake.ClosedStub = nil
 	if fake.closedReturnsOnCall == nil {
 		fake.closedReturnsOnCall = make(map[int]struct {
-			result1 bool
+			result1 <-chan struct{}
 		})
 	}
 	fake.closedReturnsOnCall[i] = struct {
-		result1 bool
+		result1 <-chan struct{}
 	}{result1}
 }
 

+ 3 - 8
lib/protocol/protocol.go

@@ -151,7 +151,7 @@ type Connection interface {
 	ClusterConfig(config ClusterConfig)
 	DownloadProgress(ctx context.Context, folder string, updates []FileDownloadProgressUpdate)
 	Statistics() Statistics
-	Closed() bool
+	Closed() <-chan struct{}
 	ConnectionInfo
 }
 
@@ -380,13 +380,8 @@ func (c *rawConnection) ClusterConfig(config ClusterConfig) {
 	}
 }
 
-func (c *rawConnection) Closed() bool {
-	select {
-	case <-c.closed:
-		return true
-	default:
-		return false
-	}
+func (c *rawConnection) Closed() <-chan struct{} {
+	return c.closed
 }
 
 // DownloadProgress sends the progress updates for the files that are currently being downloaded.