Browse Source

lib: Ensure timely service termination (fixes #5860) (#5863)

Simon Frei 6 years ago
parent
commit
4d3432af3e
8 changed files with 174 additions and 150 deletions
  1. 0 19
      lib/beacon/beacon.go
  2. 51 58
      lib/beacon/broadcast.go
  3. 41 32
      lib/beacon/multicast.go
  4. 18 6
      lib/nat/registry.go
  5. 34 23
      lib/nat/service.go
  6. 17 10
      lib/relay/client/dynamic.go
  7. 6 2
      lib/stun/stun.go
  8. 7 0
      lib/util/utils.go

+ 0 - 19
lib/beacon/beacon.go

@@ -8,7 +8,6 @@ package beacon
 
 import (
 	"net"
-	stdsync "sync"
 
 	"github.com/thejerf/suture"
 )
@@ -24,21 +23,3 @@ type Interface interface {
 	Recv() ([]byte, net.Addr)
 	Error() error
 }
-
-type errorHolder struct {
-	err error
-	mut stdsync.Mutex // uses stdlib sync as I want this to be trivially embeddable, and there is no risk of blocking
-}
-
-func (e *errorHolder) setError(err error) {
-	e.mut.Lock()
-	e.err = err
-	e.mut.Unlock()
-}
-
-func (e *errorHolder) Error() error {
-	e.mut.Lock()
-	err := e.err
-	e.mut.Unlock()
-	return err
-}

+ 51 - 58
lib/beacon/broadcast.go

@@ -11,8 +11,9 @@ import (
 	"net"
 	"time"
 
-	"github.com/syncthing/syncthing/lib/sync"
 	"github.com/thejerf/suture"
+
+	"github.com/syncthing/syncthing/lib/util"
 )
 
 type Broadcast struct {
@@ -44,16 +45,16 @@ func NewBroadcast(port int) *Broadcast {
 	}
 
 	b.br = &broadcastReader{
-		port:    port,
-		outbox:  b.outbox,
-		connMut: sync.NewMutex(),
+		port:   port,
+		outbox: b.outbox,
 	}
+	b.br.ServiceWithError = util.AsServiceWithError(b.br.serve)
 	b.Add(b.br)
 	b.bw = &broadcastWriter{
-		port:    port,
-		inbox:   b.inbox,
-		connMut: sync.NewMutex(),
+		port:  port,
+		inbox: b.inbox,
 	}
+	b.bw.ServiceWithError = util.AsServiceWithError(b.bw.serve)
 	b.Add(b.bw)
 
 	return b
@@ -76,34 +77,42 @@ func (b *Broadcast) Error() error {
 }
 
 type broadcastWriter struct {
-	port    int
-	inbox   chan []byte
-	conn    *net.UDPConn
-	connMut sync.Mutex
-	errorHolder
+	util.ServiceWithError
+	port  int
+	inbox chan []byte
 }
 
-func (w *broadcastWriter) Serve() {
+func (w *broadcastWriter) serve(stop chan struct{}) error {
 	l.Debugln(w, "starting")
 	defer l.Debugln(w, "stopping")
 
 	conn, err := net.ListenUDP("udp4", nil)
 	if err != nil {
 		l.Debugln(err)
-		w.setError(err)
-		return
+		return err
 	}
-	defer conn.Close()
+	done := make(chan struct{})
+	defer close(done)
+	go func() {
+		select {
+		case <-stop:
+		case <-done:
+		}
+		conn.Close()
+	}()
 
-	w.connMut.Lock()
-	w.conn = conn
-	w.connMut.Unlock()
+	for {
+		var bs []byte
+		select {
+		case bs = <-w.inbox:
+		case <-stop:
+			return nil
+		}
 
-	for bs := range w.inbox {
 		addrs, err := net.InterfaceAddrs()
 		if err != nil {
 			l.Debugln(err)
-			w.setError(err)
+			w.SetError(err)
 			continue
 		}
 
@@ -134,14 +143,13 @@ func (w *broadcastWriter) Serve() {
 				// Write timeouts should not happen. We treat it as a fatal
 				// error on the socket.
 				l.Debugln(err)
-				w.setError(err)
-				return
+				return err
 			}
 
 			if err != nil {
 				// Some other error that we don't expect. Debug and continue.
 				l.Debugln(err)
-				w.setError(err)
+				w.SetError(err)
 				continue
 			}
 
@@ -150,57 +158,49 @@ func (w *broadcastWriter) Serve() {
 		}
 
 		if success > 0 {
-			w.setError(nil)
+			w.SetError(nil)
 		}
 	}
 }
 
-func (w *broadcastWriter) Stop() {
-	w.connMut.Lock()
-	if w.conn != nil {
-		w.conn.Close()
-	}
-	w.connMut.Unlock()
-}
-
 func (w *broadcastWriter) String() string {
 	return fmt.Sprintf("broadcastWriter@%p", w)
 }
 
 type broadcastReader struct {
-	port    int
-	outbox  chan recv
-	conn    *net.UDPConn
-	connMut sync.Mutex
-	errorHolder
+	util.ServiceWithError
+	port   int
+	outbox chan recv
 }
 
-func (r *broadcastReader) Serve() {
+func (r *broadcastReader) serve(stop chan struct{}) error {
 	l.Debugln(r, "starting")
 	defer l.Debugln(r, "stopping")
 
 	conn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: r.port})
 	if err != nil {
 		l.Debugln(err)
-		r.setError(err)
-		return
+		return err
 	}
-	defer conn.Close()
-
-	r.connMut.Lock()
-	r.conn = conn
-	r.connMut.Unlock()
+	done := make(chan struct{})
+	defer close(done)
+	go func() {
+		select {
+		case <-stop:
+		case <-done:
+		}
+		conn.Close()
+	}()
 
 	bs := make([]byte, 65536)
 	for {
 		n, addr, err := conn.ReadFrom(bs)
 		if err != nil {
 			l.Debugln(err)
-			r.setError(err)
-			return
+			return err
 		}
 
-		r.setError(nil)
+		r.SetError(nil)
 
 		l.Debugf("recv %d bytes from %s", n, addr)
 
@@ -208,19 +208,12 @@ func (r *broadcastReader) Serve() {
 		copy(c, bs)
 		select {
 		case r.outbox <- recv{c, addr}:
+		case <-stop:
+			return nil
 		default:
 			l.Debugln("dropping message")
 		}
 	}
-
-}
-
-func (r *broadcastReader) Stop() {
-	r.connMut.Lock()
-	if r.conn != nil {
-		r.conn.Close()
-	}
-	r.connMut.Unlock()
 }
 
 func (r *broadcastReader) String() string {

+ 41 - 32
lib/beacon/multicast.go

@@ -48,14 +48,14 @@ func NewMulticast(addr string) *Multicast {
 		addr:   addr,
 		outbox: m.outbox,
 	}
-	m.mr.Service = util.AsService(m.mr.serve)
+	m.mr.ServiceWithError = util.AsServiceWithError(m.mr.serve)
 	m.Add(m.mr)
 
 	m.mw = &multicastWriter{
 		addr:  addr,
 		inbox: m.inbox,
 	}
-	m.mw.Service = util.AsService(m.mw.serve)
+	m.mw.ServiceWithError = util.AsServiceWithError(m.mw.serve)
 	m.Add(m.mw)
 
 	return m
@@ -78,29 +78,35 @@ func (m *Multicast) Error() error {
 }
 
 type multicastWriter struct {
-	suture.Service
+	util.ServiceWithError
 	addr  string
 	inbox <-chan []byte
-	errorHolder
 }
 
-func (w *multicastWriter) serve(stop chan struct{}) {
+func (w *multicastWriter) serve(stop chan struct{}) error {
 	l.Debugln(w, "starting")
 	defer l.Debugln(w, "stopping")
 
 	gaddr, err := net.ResolveUDPAddr("udp6", w.addr)
 	if err != nil {
 		l.Debugln(err)
-		w.setError(err)
-		return
+		return err
 	}
 
 	conn, err := net.ListenPacket("udp6", ":0")
 	if err != nil {
 		l.Debugln(err)
-		w.setError(err)
-		return
+		return err
 	}
+	done := make(chan struct{})
+	defer close(done)
+	go func() {
+		select {
+		case <-stop:
+		case <-done:
+		}
+		conn.Close()
+	}()
 
 	pconn := ipv6.NewPacketConn(conn)
 
@@ -113,14 +119,13 @@ func (w *multicastWriter) serve(stop chan struct{}) {
 		select {
 		case bs = <-w.inbox:
 		case <-stop:
-			return
+			return nil
 		}
 
 		intfs, err := net.Interfaces()
 		if err != nil {
 			l.Debugln(err)
-			w.setError(err)
-			return
+			return err
 		}
 
 		success := 0
@@ -132,7 +137,7 @@ func (w *multicastWriter) serve(stop chan struct{}) {
 
 			if err != nil {
 				l.Debugln(err, "on write to", gaddr, intf.Name)
-				w.setError(err)
+				w.SetError(err)
 				continue
 			}
 
@@ -142,16 +147,13 @@ func (w *multicastWriter) serve(stop chan struct{}) {
 
 			select {
 			case <-stop:
-				return
+				return nil
 			default:
 			}
 		}
 
 		if success > 0 {
-			w.setError(nil)
-		} else {
-			l.Debugln(err)
-			w.setError(err)
+			w.SetError(nil)
 		}
 	}
 }
@@ -161,35 +163,40 @@ func (w *multicastWriter) String() string {
 }
 
 type multicastReader struct {
-	suture.Service
+	util.ServiceWithError
 	addr   string
 	outbox chan<- recv
-	errorHolder
 }
 
-func (r *multicastReader) serve(stop chan struct{}) {
+func (r *multicastReader) serve(stop chan struct{}) error {
 	l.Debugln(r, "starting")
 	defer l.Debugln(r, "stopping")
 
 	gaddr, err := net.ResolveUDPAddr("udp6", r.addr)
 	if err != nil {
 		l.Debugln(err)
-		r.setError(err)
-		return
+		return err
 	}
 
 	conn, err := net.ListenPacket("udp6", r.addr)
 	if err != nil {
 		l.Debugln(err)
-		r.setError(err)
-		return
+		return err
 	}
+	done := make(chan struct{})
+	defer close(done)
+	go func() {
+		select {
+		case <-stop:
+		case <-done:
+		}
+		conn.Close()
+	}()
 
 	intfs, err := net.Interfaces()
 	if err != nil {
 		l.Debugln(err)
-		r.setError(err)
-		return
+		return err
 	}
 
 	pconn := ipv6.NewPacketConn(conn)
@@ -206,16 +213,20 @@ func (r *multicastReader) serve(stop chan struct{}) {
 
 	if joined == 0 {
 		l.Debugln("no multicast interfaces available")
-		r.setError(errors.New("no multicast interfaces available"))
-		return
+		return errors.New("no multicast interfaces available")
 	}
 
 	bs := make([]byte, 65536)
 	for {
+		select {
+		case <-stop:
+			return nil
+		default:
+		}
 		n, _, addr, err := pconn.ReadFrom(bs)
 		if err != nil {
 			l.Debugln(err)
-			r.setError(err)
+			r.SetError(err)
 			continue
 		}
 		l.Debugf("recv %d bytes from %s", n, addr)
@@ -224,8 +235,6 @@ func (r *multicastReader) serve(stop chan struct{}) {
 		copy(c, bs)
 		select {
 		case r.outbox <- recv{c, addr}:
-		case <-stop:
-			return
 		default:
 			l.Debugln("dropping message")
 		}

+ 18 - 6
lib/nat/registry.go

@@ -19,7 +19,7 @@ func Register(provider DiscoverFunc) {
 	providers = append(providers, provider)
 }
 
-func discoverAll(renewal, timeout time.Duration) map[string]Device {
+func discoverAll(renewal, timeout time.Duration, stop chan struct{}) map[string]Device {
 	wg := &sync.WaitGroup{}
 	wg.Add(len(providers))
 
@@ -28,20 +28,32 @@ func discoverAll(renewal, timeout time.Duration) map[string]Device {
 
 	for _, discoverFunc := range providers {
 		go func(f DiscoverFunc) {
+			defer wg.Done()
 			for _, dev := range f(renewal, timeout) {
-				c <- dev
+				select {
+				case c <- dev:
+				case <-stop:
+					return
+				}
 			}
-			wg.Done()
 		}(discoverFunc)
 	}
 
 	nats := make(map[string]Device)
 
 	go func() {
-		for dev := range c {
-			nats[dev.ID()] = dev
+		defer close(done)
+		for {
+			select {
+			case dev, ok := <-c:
+				if !ok {
+					return
+				}
+				nats[dev.ID()] = dev
+			case <-stop:
+				return
+			}
 		}
-		close(done)
 	}()
 
 	wg.Wait()

+ 34 - 23
lib/nat/service.go

@@ -14,17 +14,21 @@ import (
 	stdsync "sync"
 	"time"
 
+	"github.com/thejerf/suture"
+
 	"github.com/syncthing/syncthing/lib/config"
 	"github.com/syncthing/syncthing/lib/protocol"
 	"github.com/syncthing/syncthing/lib/sync"
+	"github.com/syncthing/syncthing/lib/util"
 )
 
 // Service runs a loop for discovery of IGDs (Internet Gateway Devices) and
 // setup/renewal of a port mapping.
 type Service struct {
-	id   protocol.DeviceID
-	cfg  config.Wrapper
-	stop chan struct{}
+	suture.Service
+
+	id  protocol.DeviceID
+	cfg config.Wrapper
 
 	mappings []*Mapping
 	timer    *time.Timer
@@ -32,27 +36,28 @@ type Service struct {
 }
 
 func NewService(id protocol.DeviceID, cfg config.Wrapper) *Service {
-	return &Service{
+	s := &Service{
 		id:  id,
 		cfg: cfg,
 
 		timer: time.NewTimer(0),
 		mut:   sync.NewRWMutex(),
 	}
+	s.Service = util.AsService(s.serve)
+	return s
 }
 
-func (s *Service) Serve() {
+func (s *Service) serve(stop chan struct{}) {
 	announce := stdsync.Once{}
 
 	s.mut.Lock()
 	s.timer.Reset(0)
-	s.stop = make(chan struct{})
 	s.mut.Unlock()
 
 	for {
 		select {
 		case <-s.timer.C:
-			if found := s.process(); found != -1 {
+			if found := s.process(stop); found != -1 {
 				announce.Do(func() {
 					suffix := "s"
 					if found == 1 {
@@ -61,7 +66,7 @@ func (s *Service) Serve() {
 					l.Infoln("Detected", found, "NAT service"+suffix)
 				})
 			}
-		case <-s.stop:
+		case <-stop:
 			s.timer.Stop()
 			s.mut.RLock()
 			for _, mapping := range s.mappings {
@@ -73,7 +78,7 @@ func (s *Service) Serve() {
 	}
 }
 
-func (s *Service) process() int {
+func (s *Service) process(stop chan struct{}) int {
 	// toRenew are mappings which are due for renewal
 	// toUpdate are the remaining mappings, which will only be updated if one of
 	// the old IGDs has gone away, or a new IGD has appeared, but only if we
@@ -115,25 +120,19 @@ func (s *Service) process() int {
 		return -1
 	}
 
-	nats := discoverAll(time.Duration(s.cfg.Options().NATRenewalM)*time.Minute, time.Duration(s.cfg.Options().NATTimeoutS)*time.Second)
+	nats := discoverAll(time.Duration(s.cfg.Options().NATRenewalM)*time.Minute, time.Duration(s.cfg.Options().NATTimeoutS)*time.Second, stop)
 
 	for _, mapping := range toRenew {
-		s.updateMapping(mapping, nats, true)
+		s.updateMapping(mapping, nats, true, stop)
 	}
 
 	for _, mapping := range toUpdate {
-		s.updateMapping(mapping, nats, false)
+		s.updateMapping(mapping, nats, false, stop)
 	}
 
 	return len(nats)
 }
 
-func (s *Service) Stop() {
-	s.mut.RLock()
-	close(s.stop)
-	s.mut.RUnlock()
-}
-
 func (s *Service) NewMapping(protocol Protocol, ip net.IP, port int) *Mapping {
 	mapping := &Mapping{
 		protocol: protocol,
@@ -178,17 +177,17 @@ func (s *Service) RemoveMapping(mapping *Mapping) {
 // acquire mappings for natds which the mapping was unaware of before.
 // Optionally takes renew flag which indicates whether or not we should renew
 // mappings with existing natds
-func (s *Service) updateMapping(mapping *Mapping, nats map[string]Device, renew bool) {
+func (s *Service) updateMapping(mapping *Mapping, nats map[string]Device, renew bool, stop chan struct{}) {
 	var added, removed []Address
 
 	renewalTime := time.Duration(s.cfg.Options().NATRenewalM) * time.Minute
 	mapping.expires = time.Now().Add(renewalTime)
 
-	newAdded, newRemoved := s.verifyExistingMappings(mapping, nats, renew)
+	newAdded, newRemoved := s.verifyExistingMappings(mapping, nats, renew, stop)
 	added = append(added, newAdded...)
 	removed = append(removed, newRemoved...)
 
-	newAdded, newRemoved = s.acquireNewMappings(mapping, nats)
+	newAdded, newRemoved = s.acquireNewMappings(mapping, nats, stop)
 	added = append(added, newAdded...)
 	removed = append(removed, newRemoved...)
 
@@ -197,12 +196,18 @@ func (s *Service) updateMapping(mapping *Mapping, nats map[string]Device, renew
 	}
 }
 
-func (s *Service) verifyExistingMappings(mapping *Mapping, nats map[string]Device, renew bool) ([]Address, []Address) {
+func (s *Service) verifyExistingMappings(mapping *Mapping, nats map[string]Device, renew bool, stop chan struct{}) ([]Address, []Address) {
 	var added, removed []Address
 
 	leaseTime := time.Duration(s.cfg.Options().NATLeaseM) * time.Minute
 
 	for id, address := range mapping.addressMap() {
+		select {
+		case <-stop:
+			return nil, nil
+		default:
+		}
+
 		// Delete addresses for NATDevice's that do not exist anymore
 		nat, ok := nats[id]
 		if !ok {
@@ -242,13 +247,19 @@ func (s *Service) verifyExistingMappings(mapping *Mapping, nats map[string]Devic
 	return added, removed
 }
 
-func (s *Service) acquireNewMappings(mapping *Mapping, nats map[string]Device) ([]Address, []Address) {
+func (s *Service) acquireNewMappings(mapping *Mapping, nats map[string]Device, stop chan struct{}) ([]Address, []Address) {
 	var added, removed []Address
 
 	leaseTime := time.Duration(s.cfg.Options().NATLeaseM) * time.Minute
 	addrMap := mapping.addressMap()
 
 	for id, nat := range nats {
+		select {
+		case <-stop:
+			return nil, nil
+		default:
+		}
+
 		if _, ok := addrMap[id]; ok {
 			continue
 		}

+ 17 - 10
lib/relay/client/dynamic.go

@@ -69,15 +69,7 @@ func (c *dynamicClient) serve(stop chan struct{}) error {
 		addrs = append(addrs, ruri.String())
 	}
 
-	defer func() {
-		c.mut.RLock()
-		if c.client != nil {
-			c.client.Stop()
-		}
-		c.mut.RUnlock()
-	}()
-
-	for _, addr := range relayAddressesOrder(addrs) {
+	for _, addr := range relayAddressesOrder(addrs, stop) {
 		select {
 		case <-stop:
 			l.Debugln(c, "stopping")
@@ -104,6 +96,15 @@ func (c *dynamicClient) serve(stop chan struct{}) error {
 	return fmt.Errorf("could not find a connectable relay")
 }
 
+func (c *dynamicClient) Stop() {
+	c.mut.RLock()
+	if c.client != nil {
+		c.client.Stop()
+	}
+	c.mut.RUnlock()
+	c.commonClient.Stop()
+}
+
 func (c *dynamicClient) Error() error {
 	c.mut.RLock()
 	defer c.mut.RUnlock()
@@ -147,7 +148,7 @@ type dynamicAnnouncement struct {
 // the closest 50ms, and puts them in buckets of 50ms latency ranges. Then
 // shuffles each bucket, and returns all addresses starting with the ones from
 // the lowest latency bucket, ending with the highest latency buceket.
-func relayAddressesOrder(input []string) []string {
+func relayAddressesOrder(input []string, stop chan struct{}) []string {
 	buckets := make(map[int][]string)
 
 	for _, relay := range input {
@@ -159,6 +160,12 @@ func relayAddressesOrder(input []string) []string {
 		id := int(latency/time.Millisecond) / 50
 
 		buckets[id] = append(buckets[id], relay)
+
+		select {
+		case <-stop:
+			return nil
+		default:
+		}
 	}
 
 	var ids []int

+ 6 - 2
lib/stun/stun.go

@@ -109,8 +109,8 @@ func New(cfg config.Wrapper, subscriber Subscriber, conn net.PacketConn) (*Servi
 }
 
 func (s *Service) Stop() {
-	s.Service.Stop()
 	_ = s.stunConn.Close()
+	s.Service.Stop()
 }
 
 func (s *Service) serve(stop chan struct{}) {
@@ -163,7 +163,11 @@ func (s *Service) serve(stop chan struct{}) {
 
 		// We failed to contact all provided stun servers or the nat is not punchable.
 		// Chillout for a while.
-		time.Sleep(stunRetryInterval)
+		select {
+		case <-time.After(stunRetryInterval):
+		case <-stop:
+			return
+		}
 	}
 }
 

+ 7 - 0
lib/util/utils.go

@@ -187,6 +187,7 @@ func AsService(fn func(stop chan struct{})) suture.Service {
 type ServiceWithError interface {
 	suture.Service
 	Error() error
+	SetError(error)
 }
 
 // AsServiceWithError does the same as AsService, except that it keeps track
@@ -244,3 +245,9 @@ func (s *service) Error() error {
 	defer s.mut.Unlock()
 	return s.err
 }
+
+func (s *service) SetError(err error) {
+	s.mut.Lock()
+	s.err = err
+	s.mut.Unlock()
+}