|
@@ -1,23 +1,25 @@
|
|
|
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// Copyright (c) Tailscale Inc & AUTHORS
|
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
|
|
|
|
-// Package netx contains the Network type to abstract over either a real
|
|
|
|
|
-// network or a virtual network for testing.
|
|
|
|
|
|
|
+// Package netx contains types to describe and abstract over how dialing and
|
|
|
|
|
+// listening are performed.
|
|
|
package netx
|
|
package netx
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
"context"
|
|
"context"
|
|
|
"fmt"
|
|
"fmt"
|
|
|
"net"
|
|
"net"
|
|
|
- "net/netip"
|
|
|
|
|
- "sync"
|
|
|
|
|
-
|
|
|
|
|
- "tailscale.com/net/memnet"
|
|
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
+// DialFunc is a function that dials a network address.
|
|
|
|
|
+//
|
|
|
|
|
+// It's the type implemented by net.Dialer.DialContext or required
|
|
|
|
|
+// by net/http.Transport.DialContext, etc.
|
|
|
|
|
+type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)
|
|
|
|
|
+
|
|
|
// Network describes a network that can listen and dial. The two common
|
|
// Network describes a network that can listen and dial. The two common
|
|
|
// implementations are [RealNetwork], using the net package to use the real
|
|
// implementations are [RealNetwork], using the net package to use the real
|
|
|
-// network, or [MemNetwork], using an in-memory network (typically for testing)
|
|
|
|
|
|
|
+// network, or [memnet.Network], using an in-memory network (typically for testing)
|
|
|
type Network interface {
|
|
type Network interface {
|
|
|
NewLocalTCPListener() net.Listener
|
|
NewLocalTCPListener() net.Listener
|
|
|
Listen(network, address string) (net.Listener, error)
|
|
Listen(network, address string) (net.Listener, error)
|
|
@@ -44,77 +46,8 @@ func (realNetwork) NewLocalTCPListener() net.Listener {
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil {
|
|
if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil {
|
|
|
- panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
|
|
|
|
|
|
|
+ panic(fmt.Sprintf("failed to listen on either IPv4 or IPv6 localhost port: %v", err))
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
return ln
|
|
return ln
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
-// MemNetwork returns a Network implementation that uses an in-memory
|
|
|
|
|
-// network for testing. It is only suitable for tests that do not
|
|
|
|
|
-// require real network access.
|
|
|
|
|
-//
|
|
|
|
|
-// As of 2025-04-08, it only supports TCP.
|
|
|
|
|
-func MemNetwork() Network { return &memNetwork{} }
|
|
|
|
|
-
|
|
|
|
|
-// memNetwork implements [Network] using an in-memory network.
|
|
|
|
|
-type memNetwork struct {
|
|
|
|
|
- mu sync.Mutex
|
|
|
|
|
- lns map[string]*memnet.Listener // address -> listener
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-func (m *memNetwork) Listen(network, address string) (net.Listener, error) {
|
|
|
|
|
- if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
|
|
|
|
- return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network)
|
|
|
|
|
- }
|
|
|
|
|
- ap, err := netip.ParseAddrPort(address)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err)
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- m.mu.Lock()
|
|
|
|
|
- defer m.mu.Unlock()
|
|
|
|
|
-
|
|
|
|
|
- if m.lns == nil {
|
|
|
|
|
- m.lns = make(map[string]*memnet.Listener)
|
|
|
|
|
- }
|
|
|
|
|
- port := ap.Port()
|
|
|
|
|
- for {
|
|
|
|
|
- if port == 0 {
|
|
|
|
|
- port = 33000
|
|
|
|
|
- }
|
|
|
|
|
- key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port))
|
|
|
|
|
- _, ok := m.lns[key]
|
|
|
|
|
- if ok {
|
|
|
|
|
- if ap.Port() != 0 {
|
|
|
|
|
- return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address)
|
|
|
|
|
- }
|
|
|
|
|
- port++
|
|
|
|
|
- continue
|
|
|
|
|
- }
|
|
|
|
|
- ln := memnet.Listen(key)
|
|
|
|
|
- m.lns[key] = ln
|
|
|
|
|
- return ln, nil
|
|
|
|
|
- }
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-func (m *memNetwork) NewLocalTCPListener() net.Listener {
|
|
|
|
|
- ln, err := m.Listen("tcp", "127.0.0.1:0")
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err))
|
|
|
|
|
- }
|
|
|
|
|
- return ln
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-func (m *memNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
|
|
|
|
- if network != "tcp" && network != "tcp4" && network != "tcp6" {
|
|
|
|
|
- return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network)
|
|
|
|
|
- }
|
|
|
|
|
- m.mu.Lock()
|
|
|
|
|
- ln, ok := m.lns[address]
|
|
|
|
|
- m.mu.Unlock()
|
|
|
|
|
- if !ok {
|
|
|
|
|
- return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address)
|
|
|
|
|
- }
|
|
|
|
|
- return ln.Dial(ctx, network, address)
|
|
|
|
|
-}
|
|
|