Browse Source

lib/connections: Correct service termination order (#7657)

Audrius Butkevicius 4 years ago
parent
commit
411796606c
2 changed files with 16 additions and 16 deletions
  1. 8 7
      lib/connections/quic_listen.go
  2. 8 9
      lib/connections/tcp_listen.go

+ 8 - 7
lib/connections/quic_listen.go

@@ -88,8 +88,10 @@ func (t *quicListener) serve(ctx context.Context) error {
 		l.Infoln("Listen (BEP/quic):", err)
 		l.Infoln("Listen (BEP/quic):", err)
 		return err
 		return err
 	}
 	}
+	defer packetConn.Close()
 
 
 	svc, conn := stun.New(t.cfg, t, packetConn)
 	svc, conn := stun.New(t.cfg, t, packetConn)
+	defer conn.Close()
 	wrapped := &stunConnQUICWrapper{
 	wrapped := &stunConnQUICWrapper{
 		PacketConn: conn,
 		PacketConn: conn,
 		underlying: packetConn.(*net.UDPConn),
 		underlying: packetConn.(*net.UDPConn),
@@ -98,29 +100,28 @@ func (t *quicListener) serve(ctx context.Context) error {
 	go svc.Serve(ctx)
 	go svc.Serve(ctx)
 
 
 	registry.Register(t.uri.Scheme, wrapped)
 	registry.Register(t.uri.Scheme, wrapped)
+	defer registry.Unregister(t.uri.Scheme, wrapped)
 
 
 	listener, err := quic.Listen(wrapped, t.tlsCfg, quicConfig)
 	listener, err := quic.Listen(wrapped, t.tlsCfg, quicConfig)
 	if err != nil {
 	if err != nil {
 		l.Infoln("Listen (BEP/quic):", err)
 		l.Infoln("Listen (BEP/quic):", err)
 		return err
 		return err
 	}
 	}
+	defer listener.Close()
+
 	t.notifyAddressesChanged(t)
 	t.notifyAddressesChanged(t)
+	defer t.clearAddresses(t)
 
 
 	l.Infof("QUIC listener (%v) starting", packetConn.LocalAddr())
 	l.Infof("QUIC listener (%v) starting", packetConn.LocalAddr())
+	defer l.Infof("QUIC listener (%v) shutting down", packetConn.LocalAddr())
+
 	t.mut.Lock()
 	t.mut.Lock()
 	t.laddr = packetConn.LocalAddr()
 	t.laddr = packetConn.LocalAddr()
 	t.mut.Unlock()
 	t.mut.Unlock()
-
 	defer func() {
 	defer func() {
-		l.Infof("QUIC listener (%v) shutting down", packetConn.LocalAddr())
 		t.mut.Lock()
 		t.mut.Lock()
 		t.laddr = nil
 		t.laddr = nil
 		t.mut.Unlock()
 		t.mut.Unlock()
-		registry.Unregister(t.uri.Scheme, wrapped)
-		t.clearAddresses(t)
-		_ = listener.Close()
-		_ = conn.Close()
-		_ = packetConn.Close()
 	}()
 	}()
 
 
 	acceptFailures := 0
 	acceptFailures := 0

+ 8 - 9
lib/connections/tcp_listen.go

@@ -61,34 +61,36 @@ func (t *tcpListener) serve(ctx context.Context) error {
 		l.Infoln("Listen (BEP/tcp):", err)
 		l.Infoln("Listen (BEP/tcp):", err)
 		return err
 		return err
 	}
 	}
+	defer listener.Close()
 
 
 	// We might bind to :0, so use the port we've been given.
 	// We might bind to :0, so use the port we've been given.
 	tcaddr = listener.Addr().(*net.TCPAddr)
 	tcaddr = listener.Addr().(*net.TCPAddr)
 
 
 	t.notifyAddressesChanged(t)
 	t.notifyAddressesChanged(t)
+	defer t.clearAddresses(t)
+
 	registry.Register(t.uri.Scheme, tcaddr)
 	registry.Register(t.uri.Scheme, tcaddr)
+	defer registry.Unregister(t.uri.Scheme, tcaddr)
 
 
 	l.Infof("TCP listener (%v) starting", tcaddr)
 	l.Infof("TCP listener (%v) starting", tcaddr)
+	defer l.Infof("TCP listener (%v) shutting down", tcaddr)
 
 
 	mapping := t.natService.NewMapping(nat.TCP, tcaddr.IP, tcaddr.Port)
 	mapping := t.natService.NewMapping(nat.TCP, tcaddr.IP, tcaddr.Port)
 	mapping.OnChanged(func(_ *nat.Mapping, _, _ []nat.Address) {
 	mapping.OnChanged(func(_ *nat.Mapping, _, _ []nat.Address) {
 		t.notifyAddressesChanged(t)
 		t.notifyAddressesChanged(t)
 	})
 	})
+	// Should be called after t.mapping is nil'ed out.
+	defer t.natService.RemoveMapping(mapping)
 
 
 	t.mut.Lock()
 	t.mut.Lock()
 	t.mapping = mapping
 	t.mapping = mapping
 	t.laddr = tcaddr
 	t.laddr = tcaddr
 	t.mut.Unlock()
 	t.mut.Unlock()
-
 	defer func() {
 	defer func() {
-		l.Infof("TCP listener (%v) shutting down", tcaddr)
-		t.natService.RemoveMapping(mapping)
 		t.mut.Lock()
 		t.mut.Lock()
+		t.mapping = nil
 		t.laddr = nil
 		t.laddr = nil
 		t.mut.Unlock()
 		t.mut.Unlock()
-		registry.Unregister(t.uri.Scheme, tcaddr)
-		t.clearAddresses(t)
-		_ = listener.Close()
 	}()
 	}()
 
 
 	acceptFailures := 0
 	acceptFailures := 0
@@ -105,9 +107,6 @@ func (t *tcpListener) serve(ctx context.Context) error {
 			if err == nil {
 			if err == nil {
 				conn.Close()
 				conn.Close()
 			}
 			}
-			t.mut.Lock()
-			t.mapping = nil
-			t.mut.Unlock()
 			return nil
 			return nil
 		default:
 		default:
 		}
 		}