Browse Source

nettest, *: add option to run HTTP tests with in-memory network

This avoids ephemeral port / TIME_WAIT exhaustion with high --count
values, and will eventually let tests detect leaked connections. (Later,
the memory network will register a Cleanup on the TB to verify that
everything has been shut down.)

Updates tailscale/corp#27636

Change-Id: Id06f1ae750d8719c5a75d871654574a8226d2733
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 11 months ago
parent
commit
c76d075472

+ 4 - 4
client/local/local_test.go

@@ -9,10 +9,10 @@ import (
 	"context"
 	"net"
 	"net/http"
-	"net/http/httptest"
 	"testing"
 
 	"tailscale.com/tstest/deptest"
+	"tailscale.com/tstest/nettest"
 	"tailscale.com/types/key"
 )
 
@@ -36,15 +36,15 @@ func TestGetServeConfigFromJSON(t *testing.T) {
 }
 
 func TestWhoIsPeerNotFound(t *testing.T) {
-	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	nw := nettest.GetNetwork(t)
+	ts := nettest.NewHTTPServer(nw, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		w.WriteHeader(404)
 	}))
 	defer ts.Close()
 
 	lc := &Client{
 		Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
-			var std net.Dialer
-			return std.DialContext(ctx, network, ts.Listener.Addr().(*net.TCPAddr).String())
+			return nw.Dial(ctx, network, ts.Listener.Addr().String())
 		},
 	}
 	var k key.NodePublic

+ 2 - 1
client/web/web_test.go

@@ -28,6 +28,7 @@ import (
 	"tailscale.com/ipn/ipnstate"
 	"tailscale.com/net/memnet"
 	"tailscale.com/tailcfg"
+	"tailscale.com/tstest/nettest"
 	"tailscale.com/types/views"
 	"tailscale.com/util/httpm"
 )
@@ -1508,7 +1509,7 @@ func TestCSRFProtect(t *testing.T) {
 		}
 	})
 	h := s.withCSRF(mux)
-	ser := httptest.NewServer(h)
+	ser := nettest.NewHTTPServer(nettest.GetNetwork(t), h)
 	defer ser.Close()
 
 	jar, err := cookiejar.New(nil)

+ 7 - 2
control/controlclient/noise_test.go

@@ -10,7 +10,6 @@ import (
 	"io"
 	"math"
 	"net/http"
-	"net/http/httptest"
 	"testing"
 	"time"
 
@@ -20,6 +19,7 @@ import (
 	"tailscale.com/net/netmon"
 	"tailscale.com/net/tsdial"
 	"tailscale.com/tailcfg"
+	"tailscale.com/tstest/nettest"
 	"tailscale.com/types/key"
 	"tailscale.com/types/logger"
 )
@@ -178,7 +178,8 @@ func (tt noiseClientTest) run(t *testing.T) {
 
 	const msg = "Hello, client"
 	h2 := &http2.Server{}
-	hs := httptest.NewServer(&Upgrader{
+	nw := nettest.GetNetwork(t)
+	hs := nettest.NewHTTPServer(nw, &Upgrader{
 		h2srv:            h2,
 		noiseKeyPriv:     serverPrivate,
 		sendEarlyPayload: tt.sendEarlyPayload,
@@ -193,6 +194,10 @@ func (tt noiseClientTest) run(t *testing.T) {
 	defer hs.Close()
 
 	dialer := tsdial.NewDialer(netmon.NewStatic())
+	if nettest.PreferMemNetwork() {
+		dialer.SetSystemDialerForTest(nw.Dial)
+	}
+
 	nc, err := NewNoiseClient(NoiseOpts{
 		PrivKey:      clientPrivate,
 		ServerPubKey: serverPrivate.Public(),

+ 18 - 4
net/tsdial/tsdial.go

@@ -71,6 +71,7 @@ type Dialer struct {
 
 	netnsDialerOnce sync.Once
 	netnsDialer     netns.Dialer
+	sysDialForTest  func(_ context.Context, network, addr string) (net.Conn, error) // or nil
 
 	routes atomic.Pointer[bart.Table[bool]] // or nil if UserDial should not use routes. `true` indicates routes that point into the Tailscale interface
 
@@ -361,6 +362,13 @@ func (d *Dialer) logf(format string, args ...any) {
 	}
 }
 
+// SetSystemDialerForTest sets an alternate function to use for SystemDial
+// instead of netns.Dialer. This is intended for use with the in-memory
+// network returned by nettest.MemNetwork.
+//
+// It panics if called outside of a test binary. Passing a nil fn makes
+// SystemDial use netns.Dialer again.
+func (d *Dialer) SetSystemDialerForTest(fn func(ctx context.Context, network, addr string) (net.Conn, error)) {
+	testenv.AssertInTest()
+	d.sysDialForTest = fn
+}
+
 // SystemDial connects to the provided network address without going over
 // Tailscale. It prefers going over the default interface and closes existing
 // connections if the default interface changes. It is used to connect to
@@ -380,10 +388,16 @@ func (d *Dialer) SystemDial(ctx context.Context, network, addr string) (net.Conn
 		return nil, net.ErrClosed
 	}
 
-	d.netnsDialerOnce.Do(func() {
-		d.netnsDialer = netns.NewDialer(d.logf, d.netMon)
-	})
-	c, err := d.netnsDialer.DialContext(ctx, network, addr)
+	var c net.Conn
+	var err error
+	if d.sysDialForTest != nil {
+		c, err = d.sysDialForTest(ctx, network, addr)
+	} else {
+		d.netnsDialerOnce.Do(func() {
+			d.netnsDialer = netns.NewDialer(d.logf, d.netMon)
+		})
+		c, err = d.netnsDialer.DialContext(ctx, network, addr)
+	}
 	if err != nil {
 		return nil, err
 	}

+ 199 - 0
tstest/nettest/nettest.go

@@ -6,11 +6,23 @@
 package nettest
 
 import (
+	"context"
+	"flag"
+	"fmt"
+	"net"
+	"net/http"
+	"net/http/httptest"
+	"net/netip"
+	"sync"
 	"testing"
 
+	"tailscale.com/net/memnet"
 	"tailscale.com/net/netmon"
+	"tailscale.com/util/testenv"
 )
 
+var useMemNet = flag.Bool("use-test-memnet", false, "prefer using in-memory network for tests")
+
 // SkipIfNoNetwork skips the test if it looks like there's no network
 // access.
 func SkipIfNoNetwork(t testing.TB) {
@@ -19,3 +31,190 @@ func SkipIfNoNetwork(t testing.TB) {
 		t.Skip("skipping; test requires network but no interface is up")
 	}
 }
+
+// Network abstracts the listen and dial operations that tests need, so
+// the same test can run against either [RealNetwork] or [MemNetwork].
+type Network interface {
+	// Listen opens a listener for network on address.
+	Listen(network, address string) (net.Listener, error)
+	// Dial connects to address over network.
+	Dial(ctx context.Context, network, address string) (net.Conn, error)
+	// NewLocalTCPListener returns a TCP listener on a loopback address
+	// with an implementation-chosen port. The implementations in this
+	// package panic on failure.
+	NewLocalTCPListener() net.Listener
+}
+
+// PreferMemNetwork reports whether tests were asked, via the
+// --use-test-memnet flag, to prefer the in-memory network.
+func PreferMemNetwork() bool {
+	prefer := *useMemNet
+	return prefer
+}
+
+// GetNetwork returns a fresh [Network] for the test tb: the in-memory
+// network if the --use-test-memnet flag is set, and the real network
+// otherwise.
+//
+// Each call generates a new network.
+func GetNetwork(tb testing.TB) Network {
+	useMem := PreferMemNetwork()
+
+	var nw Network
+	switch {
+	case useMem:
+		nw = MemNetwork()
+	default:
+		nw = RealNetwork()
+	}
+
+	// Leak detection is skipped only for parallel tests on the real
+	// network. The parallel check is not evaluated at all when the
+	// in-memory network is in use.
+	if !useMem && testenv.InParallelTest(tb) {
+		return nw
+	}
+	tb.Cleanup(func() {
+		// TODO: leak detection, making sure no connections
+		// remain at the end of the test. For real network,
+		// snapshot conns in pid table before & after.
+	})
+	return nw
+}
+
+// RealNetwork returns a Network implementation that uses the real
+// net package.
+func RealNetwork() Network { return realNetwork{} }
+
+// realNetwork implements [Network] using the real net package.
+type realNetwork struct{}
+
+// Listen opens a listener for network on address via [net.Listen].
+func (realNetwork) Listen(network, address string) (net.Listener, error) {
+	ln, err := net.Listen(network, address)
+	return ln, err
+}
+
+func (realNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
+	var d net.Dialer
+	return d.DialContext(ctx, network, address)
+}
+
+// NewLocalTCPListener returns a TCP listener on an OS-assigned port of
+// the IPv4 loopback address, falling back to the IPv6 loopback address
+// if that fails. It panics if neither address can be listened on.
+func (realNetwork) NewLocalTCPListener() net.Listener {
+	ln, err4 := net.Listen("tcp", "127.0.0.1:0")
+	if err4 == nil {
+		return ln
+	}
+	ln, err6 := net.Listen("tcp6", "[::1]:0")
+	if err6 == nil {
+		return ln
+	}
+	// Report both failures, and attribute the panic to this package
+	// rather than to httptest, so the message points at the real cause.
+	panic(fmt.Sprintf("nettest: failed to listen on a port: %v (IPv4 attempt: %v)", err6, err4))
+}
+
+// memNetwork is the in-memory [Network] implementation. Its zero value
+// is ready to use.
+type memNetwork struct {
+	mu  sync.Mutex                  // guards lns
+	lns map[string]*memnet.Listener // "host:port" address -> listener; allocated lazily by Listen
+}
+
+// MemNetwork returns a Network implementation that uses an in-memory
+// network for testing. It is only suitable for tests that do not
+// require real network access.
+func MemNetwork() Network {
+	return new(memNetwork)
+}
+
+func (m *memNetwork) Listen(network, address string) (net.Listener, error) {
+	if network != "tcp" && network != "tcp4" && network != "tcp6" {
+		return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network)
+	}
+	ap, err := netip.ParseAddrPort(address)
+	if err != nil {
+		return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err)
+	}
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	if m.lns == nil {
+		m.lns = make(map[string]*memnet.Listener)
+	}
+	port := ap.Port()
+	for {
+		if port == 0 {
+			port = 33000
+		}
+		key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port))
+		_, ok := m.lns[key]
+		if ok {
+			if ap.Port() != 0 {
+				return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address)
+			}
+			port++
+			continue
+		}
+		ln := memnet.Listen(key)
+		m.lns[key] = ln
+		return ln, nil
+	}
+}
+
+// NewLocalTCPListener returns an in-memory listener on 127.0.0.1 with an
+// automatically chosen port. It panics if Listen fails.
+func (m *memNetwork) NewLocalTCPListener() net.Listener {
+	ln, err := m.Listen("tcp", "127.0.0.1:0")
+	if err == nil {
+		return ln
+	}
+	panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err))
+}
+
+// Dial connects to the listener previously registered by Listen under
+// exactly address. Only the "tcp", "tcp4" and "tcp6" networks are
+// supported; dialing an address with no registered listener is an error.
+func (m *memNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
+	switch network {
+	case "tcp", "tcp4", "tcp6":
+		// Supported.
+	default:
+		return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network)
+	}
+
+	// Hold the lock only for the map lookup, not for the dial itself.
+	m.mu.Lock()
+	target, found := m.lns[address]
+	m.mu.Unlock()
+
+	if found {
+		return target.Dial(ctx, network, address)
+	}
+	return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address)
+}
+
+// NewHTTPServer starts and returns a new [httptest.Server] that serves
+// handler on a local listener obtained from nw.
+// The caller should call Close when finished, to shut it down.
+func NewHTTPServer(nw Network, handler http.Handler) *httptest.Server {
+	// The parameter is named nw, as in NewUnstartedHTTPServer, so that it
+	// does not shadow the net package.
+	ts := NewUnstartedHTTPServer(nw, handler)
+	ts.Start()
+	return ts
+}
+
+// NewUnstartedHTTPServer returns a new [httptest.Server] listening on a
+// local listener obtained from nw, but doesn't start it.
+//
+// After changing its configuration, the caller should call Start or
+// StartTLS.
+//
+// The caller should call Close when finished, to shut it down.
+func NewUnstartedHTTPServer(nw Network, handler http.Handler) *httptest.Server {
+	srv := &httptest.Server{
+		Config: &http.Server{Handler: handler},
+	}
+
+	// routeClientThroughNetwork makes the server's own client dial via
+	// nw. It runs the first time the listener's Addr is requested.
+	routeClientThroughNetwork := func() {
+		client := srv.Client()
+		if client == nil {
+			// This httptest.Server.Start initialization order has been true
+			// for over 10 years. Let's keep counting on it.
+			panic("httptest.Server: Client not initialized before Addr called")
+		}
+		if client.Transport == nil {
+			client.Transport = &http.Transport{}
+		}
+		transport := client.Transport.(*http.Transport)
+		if transport.Dial != nil || transport.DialContext != nil {
+			panic("unexpected non-nil Dial or DialContext in httptest.Server.Client.Transport")
+		}
+		transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+			return nw.Dial(ctx, network, addr)
+		}
+	}
+
+	srv.Listener = &listenerOnAddrOnce{
+		Listener: nw.NewLocalTCPListener(),
+		fn:       routeClientThroughNetwork,
+	}
+	return srv
+}
+
+// listenerOnAddrOnce wraps a net.Listener and runs fn exactly once, on
+// the first call to Addr, before the wrapped listener's address is
+// returned.
+type listenerOnAddrOnce struct {
+	net.Listener
+	once sync.Once // ensures fn runs a single time
+	fn   func()    // hook invoked on the first Addr call
+}
+
+// Addr runs the hook if it has not run yet, then reports the wrapped
+// listener's address.
+func (ln *listenerOnAddrOnce) Addr() net.Addr {
+	ln.once.Do(ln.fn)
+	addr := ln.Listener.Addr()
+	return addr
+}

+ 7 - 0
util/testenv/testenv.go

@@ -58,3 +58,10 @@ func InParallelTest(t TB) (isParallel bool) {
 	t.Chdir(".") // panics in a t.Parallel test
 	return false
 }
+
+// AssertInTest is a guard for test-only code paths: it returns normally
+// when InTest reports true and panics otherwise.
+func AssertInTest() {
+	if InTest() {
+		return
+	}
+	panic("func called outside of test binary")
+}