Browse Source

tsnet: avoid deadlock on close

tsnet.Server.Close was calling listener.Close with the server mutex
held, but the listener close method tries to grab that mutex, resulting
in a deadlock.

Co-authored-by: David Crawshaw <[email protected]>
Signed-off-by: Maisem Ali <[email protected]>
Maisem Ali 3 years ago
parent
commit
b4d3e2928b
2 changed files with 48 additions and 8 deletions
  1. 24 8
      tsnet/tsnet.go
  2. 24 0
      tsnet/tsnet_test.go

+ 24 - 8
tsnet/tsnet.go

@@ -118,6 +118,7 @@ type Server struct {
 	mu        sync.Mutex
 	listeners map[listenKey]*listener
 	dialer    *tsdial.Dialer
+	closed    bool
 }
 
 // Dial connects to the address on the tailnet.
@@ -303,6 +304,11 @@ func (s *Server) Up(ctx context.Context) (*ipnstate.Status, error) {
 //
 // It must not be called before or concurrently with Start.
 func (s *Server) Close() error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.closed {
+		return fmt.Errorf("tsnet: %w", net.ErrClosed)
+	}
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()
 	var wg sync.WaitGroup
@@ -350,14 +356,12 @@ func (s *Server) Close() error {
 		s.loopbackListener.Close()
 	}
 
-	s.mu.Lock()
-	defer s.mu.Unlock()
 	for _, ln := range s.listeners {
-		ln.Close()
+		ln.closeLocked()
 	}
-	s.listeners = nil
 
 	wg.Wait()
+	s.closed = true
 	return nil
 }
 
@@ -1017,10 +1021,11 @@ type listenKey struct {
 }
 
 type listener struct {
-	s    *Server
-	keys []listenKey
-	addr string
-	conn chan net.Conn
+	s      *Server
+	keys   []listenKey
+	addr   string
+	conn   chan net.Conn
+	closed bool // guarded by s.mu
 }
 
 func (ln *listener) Accept() (net.Conn, error) {
@@ -1032,15 +1037,26 @@ func (ln *listener) Accept() (net.Conn, error) {
 }
 
 func (ln *listener) Addr() net.Addr { return addr{ln} }
+
 func (ln *listener) Close() error {
 	ln.s.mu.Lock()
 	defer ln.s.mu.Unlock()
+	return ln.closeLocked()
+}
+
+// closeLocked closes the listener.
+// It must be called with ln.s.mu held.
+func (ln *listener) closeLocked() error {
+	if ln.closed {
+		return fmt.Errorf("tsnet: %w", net.ErrClosed)
+	}
 	for _, key := range ln.keys {
 		if v, ok := ln.s.listeners[key]; ok && v == ln {
 			delete(ln.s.listeners, key)
 		}
 	}
 	close(ln.conn)
+	ln.closed = true
 	return nil
 }
 

+ 24 - 0
tsnet/tsnet_test.go

@@ -9,6 +9,7 @@ import (
 	"flag"
 	"fmt"
 	"io"
+	"net"
 	"net/http"
 	"net/http/httptest"
 	"net/netip"
@@ -344,3 +345,26 @@ func TestTailscaleIPs(t *testing.T) {
 			sIp4, upIp4, sIp6, upIp6)
 	}
 }
+
+// TestListenerCleanup is a regression test to verify that s.Close doesn't
+// deadlock if a listener is still open.
+func TestListenerCleanup(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	controlURL := startControl(t)
+	s1, _ := startServer(t, ctx, controlURL, "s1")
+
+	ln, err := s1.Listen("tcp", ":8081")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if err := s1.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := ln.Close(); !errors.Is(err, net.ErrClosed) {
+		t.Fatalf("second ln.Close error: %v, want net.ErrClosed", err)
+	}
+}