Ver código fonte

lib/nat: Fix race condition in Mapping (#8042)

The locking protocol in nat.Mapping was racy:

* Mapping.addressMap RLock'd, but then returned a map shared between
  caller and Mapping, so the lock didn't do anything.

* Operations inside Service.{verifyExistingMappings,acquireNewMappings}
  would lock the map for every update, but that means callers to
  Mapping.ExternalAddresses can be looping over the map while the
  Service methods are concurrently modifying it. When the Go runtime
  detects that happening, it panics.

* Mapping.expires was read and updated without locking.

The Service methods now lock the map once and release the lock only when
done.

Also, subscribers no longer get the added and removed addresses, because
none of them were using the information. This was changed for a previous
attempt to retain the fine-grained locking and not reverted because it
simplifies the code.
greatroar 4 anos atrás
pai
commit
8265dac127
4 arquivos alterados com 43 adições e 61 exclusões
  1. 1 1
      cmd/strelaysrv/main.go
  2. 1 1
      lib/connections/tcp_listen.go
  3. 31 35
      lib/nat/service.go
  4. 10 24
      lib/nat/structs.go

+ 1 - 1
cmd/strelaysrv/main.go

@@ -200,7 +200,7 @@ func main() {
 		go natSvc.Serve(ctx)
 		defer cancel()
 		found := make(chan struct{})
-		mapping.OnChanged(func(_ *nat.Mapping, _, _ []nat.Address) {
+		mapping.OnChanged(func() {
 			select {
 			case found <- struct{}{}:
 			default:

+ 1 - 1
lib/connections/tcp_listen.go

@@ -76,7 +76,7 @@ func (t *tcpListener) serve(ctx context.Context) error {
 	defer l.Infof("TCP listener (%v) shutting down", tcaddr)
 
 	mapping := t.natService.NewMapping(nat.TCP, tcaddr.IP, tcaddr.Port)
-	mapping.OnChanged(func(_ *nat.Mapping, _, _ []nat.Address) {
+	mapping.OnChanged(func() {
 		t.notifyAddressesChanged(t)
 	})
 	// Should be called after t.mapping is nil'ed out.

+ 31 - 35
lib/nat/service.go

@@ -125,11 +125,15 @@ func (s *Service) process(ctx context.Context) (int, time.Duration) {
 
 	s.mut.RLock()
 	for _, mapping := range s.mappings {
-		if mapping.expires.Before(time.Now()) {
+		mapping.mut.RLock()
+		expires := mapping.expires
+		mapping.mut.RUnlock()
+
+		if expires.Before(time.Now()) {
 			toRenew = append(toRenew, mapping)
 		} else {
 			toUpdate = append(toUpdate, mapping)
-			mappingRenewIn := time.Until(mapping.expires)
+			mappingRenewIn := time.Until(expires)
 			if mappingRenewIn < renewIn {
 				renewIn = mappingRenewIn
 			}
@@ -206,41 +210,36 @@ func (s *Service) RemoveMapping(mapping *Mapping) {
 // Optionally takes renew flag which indicates whether or not we should renew
 // mappings with existing natds
 func (s *Service) updateMapping(ctx context.Context, mapping *Mapping, nats map[string]Device, renew bool) {
-	var added, removed []Address
-
 	renewalTime := time.Duration(s.cfg.Options().NATRenewalM) * time.Minute
-	mapping.expires = time.Now().Add(renewalTime)
 
-	newAdded, newRemoved := s.verifyExistingMappings(ctx, mapping, nats, renew)
-	added = append(added, newAdded...)
-	removed = append(removed, newRemoved...)
+	mapping.mut.Lock()
 
-	newAdded, newRemoved = s.acquireNewMappings(ctx, mapping, nats)
-	added = append(added, newAdded...)
-	removed = append(removed, newRemoved...)
+	mapping.expires = time.Now().Add(renewalTime)
+	change := s.verifyExistingLocked(ctx, mapping, nats, renew)
+	add := s.acquireNewLocked(ctx, mapping, nats)
 
-	if len(added) > 0 || len(removed) > 0 {
-		mapping.notify(added, removed)
+	mapping.mut.Unlock()
+
+	if change || add {
+		mapping.notify()
 	}
 }
 
-func (s *Service) verifyExistingMappings(ctx context.Context, mapping *Mapping, nats map[string]Device, renew bool) ([]Address, []Address) {
-	var added, removed []Address
-
+func (s *Service) verifyExistingLocked(ctx context.Context, mapping *Mapping, nats map[string]Device, renew bool) (change bool) {
 	leaseTime := time.Duration(s.cfg.Options().NATLeaseM) * time.Minute
 
-	for id, address := range mapping.addressMap() {
+	for id, address := range mapping.extAddresses {
 		select {
 		case <-ctx.Done():
-			return nil, nil
+			return false
 		default:
 		}
 
 		// Delete addresses for NATDevice's that do not exist anymore
 		nat, ok := nats[id]
 		if !ok {
-			mapping.removeAddress(id)
-			removed = append(removed, address)
+			mapping.removeAddressLocked(id)
+			change = true
 			continue
 		} else if renew {
 			// Only perform renewals on the nat's that have the right local IP
@@ -256,35 +255,32 @@ func (s *Service) verifyExistingMappings(ctx context.Context, mapping *Mapping,
 			addr, err := s.tryNATDevice(ctx, nat, mapping.address.Port, address.Port, leaseTime)
 			if err != nil {
 				l.Debugf("Failed to renew %s -> mapping on %s", mapping, address, id)
-				mapping.removeAddress(id)
-				removed = append(removed, address)
+				mapping.removeAddressLocked(id)
+				change = true
 				continue
 			}
 
 			l.Debugf("Renewed %s -> %s mapping on %s", mapping, address, id)
 
 			if !addr.Equal(address) {
-				mapping.removeAddress(id)
-				mapping.setAddress(id, addr)
-				removed = append(removed, address)
-				added = append(added, address)
+				mapping.removeAddressLocked(id)
+				mapping.setAddressLocked(id, addr)
+				change = true
 			}
 		}
 	}
 
-	return added, removed
+	return change
 }
 
-func (s *Service) acquireNewMappings(ctx context.Context, mapping *Mapping, nats map[string]Device) ([]Address, []Address) {
-	var added, removed []Address
-
+func (s *Service) acquireNewLocked(ctx context.Context, mapping *Mapping, nats map[string]Device) (change bool) {
 	leaseTime := time.Duration(s.cfg.Options().NATLeaseM) * time.Minute
-	addrMap := mapping.addressMap()
+	addrMap := mapping.extAddresses
 
 	for id, nat := range nats {
 		select {
 		case <-ctx.Done():
-			return nil, nil
+			return false
 		default:
 		}
 
@@ -310,11 +306,11 @@ func (s *Service) acquireNewMappings(ctx context.Context, mapping *Mapping, nats
 
 		l.Debugf("Acquired %s -> %s mapping on %s", mapping, addr, id)
 
-		mapping.setAddress(id, addr)
-		added = append(added, addr)
+		mapping.setAddressLocked(id, addr)
+		change = true
 	}
 
-	return added, removed
+	return change
 }
 
 // tryNATDevice tries to acquire a port mapping for the given internal address to

+ 10 - 24
lib/nat/structs.go

@@ -14,7 +14,7 @@ import (
 	"github.com/syncthing/syncthing/lib/sync"
 )
 
-type MappingChangeSubscriber func(*Mapping, []Address, []Address)
+type MappingChangeSubscriber func()
 
 type Mapping struct {
 	protocol Protocol
@@ -26,55 +26,41 @@ type Mapping struct {
 	mut          sync.RWMutex
 }
 
-func (m *Mapping) setAddress(id string, address Address) {
-	m.mut.Lock()
-	if existing, ok := m.extAddresses[id]; !ok || !existing.Equal(address) {
-		l.Infof("New NAT port mapping: external %s address %s to local address %s.", m.protocol, address, m.address)
-		m.extAddresses[id] = address
-	}
-	m.mut.Unlock()
+func (m *Mapping) setAddressLocked(id string, address Address) {
+	l.Infof("New NAT port mapping: external %s address %s to local address %s.", m.protocol, address, m.address)
+	m.extAddresses[id] = address
 }
 
-func (m *Mapping) removeAddress(id string) {
-	m.mut.Lock()
+func (m *Mapping) removeAddressLocked(id string) {
 	addr, ok := m.extAddresses[id]
 	if ok {
 		l.Infof("Removing NAT port mapping: external %s address %s, NAT %s is no longer available.", m.protocol, addr, id)
 		delete(m.extAddresses, id)
 	}
-	m.mut.Unlock()
 }
 
 func (m *Mapping) clearAddresses() {
 	m.mut.Lock()
-	var removed []Address
+	change := len(m.extAddresses) > 0
 	for id, addr := range m.extAddresses {
 		l.Debugf("Clearing mapping %s: ID: %s Address: %s", m, id, addr)
-		removed = append(removed, addr)
 		delete(m.extAddresses, id)
 	}
 	m.expires = time.Time{}
 	m.mut.Unlock()
-	if len(removed) > 0 {
-		m.notify(nil, removed)
+	if change {
+		m.notify()
 	}
 }
 
-func (m *Mapping) notify(added, removed []Address) {
+func (m *Mapping) notify() {
 	m.mut.RLock()
 	for _, subscriber := range m.subscribers {
-		subscriber(m, added, removed)
+		subscriber()
 	}
 	m.mut.RUnlock()
 }
 
-func (m *Mapping) addressMap() map[string]Address {
-	m.mut.RLock()
-	addrMap := m.extAddresses
-	m.mut.RUnlock()
-	return addrMap
-}
-
 func (m *Mapping) Protocol() Protocol {
 	return m.protocol
 }