Browse Source

tsnet: return from Accept when the listener gets closed

Fixes #14808

Signed-off-by: Anton Tolchanov <[email protected]>
Anton Tolchanov 1 year ago
parent
commit
3abfbf50ae
2 changed files with 35 additions and 3 deletions
  1. 4 3
      tsnet/tsnet.go
  2. 31 0
      tsnet/tsnet_test.go

+ 4 - 3
tsnet/tsnet.go

@@ -1286,11 +1286,12 @@ type listener struct {
 }
 
 func (ln *listener) Accept() (net.Conn, error) {
-	c, ok := <-ln.conn
-	if !ok {
+	select {
+	case c := <-ln.conn:
+		return c, nil
+	case <-ln.closedc:
 		return nil, fmt.Errorf("tsnet: %w", net.ErrClosed)
 	}
-	return c, nil
 }
 
 func (ln *listener) Addr() net.Addr { return addr{ln} }

+ 31 - 0
tsnet/tsnet_test.go

@@ -667,6 +667,37 @@ func TestFunnel(t *testing.T) {
 	}
 }
 
+func TestListenerClose(t *testing.T) {
+	ctx := context.Background()
+	controlURL, _ := startControl(t)
+
+	s1, _, _ := startServer(t, ctx, controlURL, "s1")
+
+	ln, err := s1.Listen("tcp", ":8080")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	errc := make(chan error, 1)
+	go func() {
+		c, err := ln.Accept()
+		if c != nil {
+			c.Close()
+		}
+		errc <- err
+	}()
+
+	ln.Close()
+	select {
+	case err := <-errc:
+		if !errors.Is(err, net.ErrClosed) {
+			t.Errorf("unexpected error: %v", err)
+		}
+	case <-time.After(10 * time.Second):
+		t.Fatal("timeout waiting for Accept to return")
+	}
+}
+
 func dialIngressConn(from, to *Server, target string) (net.Conn, error) {
 	toLC := must.Get(to.LocalClient())
 	toStatus := must.Get(toLC.StatusWithoutPeers(context.Background()))