Przeglądaj źródła

net/portmapper: actually test something in TestProbeIntegration

And use dynamic port numbers in tests, as Linux on GitHub Actions and
Windows in general have things running on these ports.

Co-Author: Julian Knodt <[email protected]>
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 4 lat temu
rodzic
commit
bdb93c5942

+ 15 - 2
net/portmapper/igd_test.go

@@ -49,10 +49,11 @@ func NewTestIGD() (*TestIGD, error) {
 		doUPnP: true,
 	}
 	var err error
-	if d.upnpConn, err = net.ListenPacket("udp", "127.0.0.1:1900"); err != nil {
+	if d.upnpConn, err = testListenUDP(); err != nil {
 		return nil, err
 	}
-	if d.pxpConn, err = net.ListenPacket("udp", "127.0.0.1:5351"); err != nil {
+	if d.pxpConn, err = testListenUDP(); err != nil {
+		d.upnpConn.Close()
 		return nil, err
 	}
 	d.ts = httptest.NewServer(http.HandlerFunc(d.serveUPnPHTTP))
@@ -61,6 +62,18 @@ func NewTestIGD() (*TestIGD, error) {
 	return d, nil
 }
 
+func testListenUDP() (net.PacketConn, error) {
+	return net.ListenPacket("udp4", "127.0.0.1:0")
+}
+
+func (d *TestIGD) TestPxPPort() uint16 {
+	return uint16(d.pxpConn.LocalAddr().(*net.UDPAddr).Port)
+}
+
+func (d *TestIGD) TestUPnPPort() uint16 {
+	return uint16(d.upnpConn.LocalAddr().(*net.UDPAddr).Port)
+}
+
 func (d *TestIGD) Close() error {
 	d.ts.Close()
 	d.upnpConn.Close()

+ 8 - 6
net/portmapper/pcp.go

@@ -12,7 +12,6 @@ import (
 	"time"
 
 	"inet.af/netaddr"
-	"tailscale.com/net/netns"
 )
 
 // References:
@@ -22,8 +21,8 @@ import (
 
 // PCP constants
 const (
-	pcpVersion = 2
-	pcpPort    = 5351
+	pcpVersion     = 2
+	pcpDefaultPort = 5351
 
 	pcpMapLifetimeSec = 7200 // TODO does the RFC recommend anything? This is taken from PMP.
 
@@ -39,7 +38,8 @@ const (
 )
 
 type pcpMapping struct {
-	gw       netaddr.IP
+	c        *Client
+	gw       netaddr.IPPort
 	internal netaddr.IPPort
 	external netaddr.IPPort
 
@@ -54,13 +54,13 @@ func (p *pcpMapping) GoodUntil() time.Time     { return p.goodUntil }
 func (p *pcpMapping) RenewAfter() time.Time    { return p.renewAfter }
 func (p *pcpMapping) External() netaddr.IPPort { return p.external }
 func (p *pcpMapping) Release(ctx context.Context) {
-	uc, err := netns.Listener().ListenPacket(ctx, "udp4", ":0")
+	uc, err := p.c.listenPacket(ctx, "udp4", ":0")
 	if err != nil {
 		return
 	}
 	defer uc.Close()
 	pkt := buildPCPRequestMappingPacket(p.internal.IP(), p.internal.Port(), p.external.Port(), 0, p.external.IP())
-	uc.WriteTo(pkt, netaddr.IPPortFrom(p.gw, pcpPort).UDPAddr())
+	uc.WriteTo(pkt, p.gw.UDPAddr())
 }
 
 // buildPCPRequestMappingPacket generates a PCP packet with a MAP opcode.
@@ -95,6 +95,8 @@ func buildPCPRequestMappingPacket(
 	return pkt
 }
 
+// parsePCPMapResponse parses resp into a partially populated pcpMapping.
+// In particular, its Client is not populated.
 func parsePCPMapResponse(resp []byte) (*pcpMapping, error) {
 	if len(resp) < 60 {
 		return nil, fmt.Errorf("Does not appear to be PCP MAP response")

+ 49 - 18
net/portmapper/portmapper.go

@@ -14,6 +14,7 @@ import (
 	"io"
 	"net"
 	"net/http"
+	"os"
 	"sync"
 	"time"
 
@@ -55,6 +56,8 @@ type Client struct {
 	logf         logger.Logf
 	ipAndGateway func() (gw, ip netaddr.IP, ok bool)
 	onChange     func() // or nil
+	testPxPPort  uint16 // if non-zero, pxpPort to use for tests
+	testUPnPPort uint16 // if non-zero, uPnPPort to use for tests
 
 	mu sync.Mutex // guards following, and all fields thereof
 
@@ -113,7 +116,8 @@ func (c *Client) HaveMapping() bool {
 //
 // All fields are immutable once created.
 type pmpMapping struct {
-	gw         netaddr.IP
+	c          *Client
+	gw         netaddr.IPPort
 	external   netaddr.IPPort
 	internal   netaddr.IPPort
 	renewAfter time.Time // the time at which we want to renew the mapping
@@ -132,13 +136,13 @@ func (p *pmpMapping) External() netaddr.IPPort { return p.external }
 
 // Release does a best effort fire-and-forget release of the PMP mapping m.
 func (m *pmpMapping) Release(ctx context.Context) {
-	uc, err := netns.Listener().ListenPacket(ctx, "udp4", ":0")
+	uc, err := m.c.listenPacket(ctx, "udp4", ":0")
 	if err != nil {
 		return
 	}
 	defer uc.Close()
 	pkt := buildPMPRequestMappingPacket(m.internal.Port(), m.external.Port(), pmpMapLifetimeDelete)
-	uc.WriteTo(pkt, netaddr.IPPortFrom(m.gw, pmpPort).UDPAddr())
+	uc.WriteTo(pkt, m.gw.UDPAddr())
 }
 
 // NewClient returns a new portmapping client.
@@ -213,6 +217,32 @@ func (c *Client) gatewayAndSelfIP() (gw, myIP netaddr.IP, ok bool) {
 	return
 }
 
+// pxpPort returns the NAT-PMP and PCP port number.
+// It returns 5351, except for in tests where it varies by run.
+func (c *Client) pxpPort() uint16 {
+	if c.testPxPPort != 0 {
+		return c.testPxPPort
+	}
+	return pmpDefaultPort
+}
+
+// upnpPort returns the UPnP discovery port number.
+// It returns 1900, except for in tests where it varies by run.
+func (c *Client) upnpPort() uint16 {
+	if c.testUPnPPort != 0 {
+		return c.testUPnPPort
+	}
+	return upnpDefaultPort
+}
+
+func (c *Client) listenPacket(ctx context.Context, network, addr string) (net.PacketConn, error) {
+	if (c.testPxPPort != 0 || c.testUPnPPort != 0) && os.Getenv("GITHUB_ACTIONS") == "true" {
+		var lc net.ListenConfig
+		return lc.ListenPacket(ctx, network, addr)
+	}
+	return netns.Listener().ListenPacket(ctx, network, addr)
+}
+
 func (c *Client) invalidateMappingsLocked(releaseOld bool) {
 	if c.mapping != nil {
 		if releaseOld {
@@ -399,7 +429,8 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netaddr.IPPor
 	// PCP returns all the information necessary for a mapping in a single packet, so we can
 	// construct it upon receiving that packet.
 	m := &pmpMapping{
-		gw:       gw,
+		c:        c,
+		gw:       netaddr.IPPortFrom(gw, c.pxpPort()),
 		internal: internalAddr,
 	}
 	if haveRecentPMP {
@@ -415,7 +446,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netaddr.IPPor
 	}
 	c.mu.Unlock()
 
-	uc, err := netns.Listener().ListenPacket(ctx, "udp4", ":0")
+	uc, err := c.listenPacket(ctx, "udp4", ":0")
 	if err != nil {
 		return netaddr.IPPort{}, err
 	}
@@ -424,7 +455,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netaddr.IPPor
 	uc.SetReadDeadline(time.Now().Add(portMapServiceTimeout))
 	defer closeCloserOnContextDone(ctx, uc)()
 
-	pxpAddr := netaddr.IPPortFrom(gw, pmpPort)
+	pxpAddr := netaddr.IPPortFrom(gw, c.pxpPort())
 	pxpAddru := pxpAddr.UDPAddr()
 
 	preferPCP := !DisablePCP && (DisablePMP || (!haveRecentPMP && haveRecentPCP))
@@ -499,8 +530,9 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netaddr.IPPor
 					// PCP should only have a single packet response
 					return netaddr.IPPort{}, NoMappingError{ErrNoPortMappingServices}
 				}
+				pcpMapping.c = c
 				pcpMapping.internal = m.internal
-				pcpMapping.gw = gw
+				pcpMapping.gw = netaddr.IPPortFrom(gw, c.pxpPort())
 				c.mu.Lock()
 				defer c.mu.Unlock()
 				c.mapping = pcpMapping
@@ -524,7 +556,7 @@ type pmpResultCode uint16
 
 // NAT-PMP constants.
 const (
-	pmpPort              = 5351
+	pmpDefaultPort       = 5351
 	pmpMapLifetimeSec    = 7200 // RFC recommended 2 hour map duration
 	pmpMapLifetimeDelete = 0    // 0 second lifetime deletes
 
@@ -622,7 +654,7 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) {
 		}
 	}()
 
-	uc, err := netns.Listener().ListenPacket(context.Background(), "udp4", ":0")
+	uc, err := c.listenPacket(context.Background(), "udp4", ":0")
 	if err != nil {
 		c.logf("ProbePCP: %v", err)
 		return res, err
@@ -632,9 +664,8 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) {
 	defer cancel()
 	defer closeCloserOnContextDone(ctx, uc)()
 
-	pcpAddr := netaddr.IPPortFrom(gw, pcpPort).UDPAddr()
-	pmpAddr := netaddr.IPPortFrom(gw, pmpPort).UDPAddr()
-	upnpAddr := netaddr.IPPortFrom(gw, upnpPort).UDPAddr()
+	pxpAddr := netaddr.IPPortFrom(gw, c.pxpPort()).UDPAddr()
+	upnpAddr := netaddr.IPPortFrom(gw, c.upnpPort()).UDPAddr()
 
 	// Don't send probes to services that we recently learned (for
 	// the same gw/myIP) are available. See
@@ -642,12 +673,12 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) {
 	if c.sawPMPRecently() {
 		res.PMP = true
 	} else if !DisablePMP {
-		uc.WriteTo(pmpReqExternalAddrPacket, pmpAddr)
+		uc.WriteTo(pmpReqExternalAddrPacket, pxpAddr)
 	}
 	if c.sawPCPRecently() {
 		res.PCP = true
 	} else if !DisablePCP {
-		uc.WriteTo(pcpAnnounceRequest(myIP), pcpAddr)
+		uc.WriteTo(pcpAnnounceRequest(myIP), pxpAddr)
 	}
 	if c.sawUPnPRecently() {
 		res.UPnP = true
@@ -669,9 +700,9 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) {
 			}
 			return res, err
 		}
-		port := addr.(*net.UDPAddr).Port
+		port := uint16(addr.(*net.UDPAddr).Port)
 		switch port {
-		case upnpPort:
+		case c.upnpPort():
 			if mem.Contains(mem.B(buf[:n]), mem.S(":InternetGatewayDevice:")) {
 				meta, err := parseUPnPDiscoResponse(buf[:n])
 				if err != nil {
@@ -686,7 +717,7 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) {
 				c.uPnPMeta = meta
 				c.mu.Unlock()
 			}
-		case pcpPort: // same as pmpPort
+		case c.pxpPort(): // same value for PMP and PCP
 			if pres, ok := parsePCPResponse(buf[:n]); ok {
 				if pres.OpCode == pcpOpReply|pcpOpAnnounce {
 					pcpHeard = true
@@ -729,7 +760,7 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) {
 var pmpReqExternalAddrPacket = []byte{pmpVersion, pmpOpMapPublicAddr} // 0, 0
 
 const (
-	upnpPort = 1900 // for UDP discovery only; TCP port discovered later
+	upnpDefaultPort = 1900 // for UDP discovery only; TCP port discovered later
 )
 
 // uPnPPacket is the UPnP UDP discovery packet's request body.

+ 19 - 2
net/portmapper/portmapper_test.go

@@ -7,6 +7,7 @@ package portmapper
 import (
 	"context"
 	"os"
+	"reflect"
 	"strconv"
 	"testing"
 	"time"
@@ -72,7 +73,9 @@ func TestProbeIntegration(t *testing.T) {
 		logf("portmapping changed.")
 		logf("have mapping: %v", c.HaveMapping())
 	})
-
+	c.testPxPPort = igd.TestPxPPort()
+	c.testUPnPPort = igd.TestUPnPPort()
+	t.Logf("Listening on pxp=%v, upnp=%v", c.testPxPPort, c.testUPnPPort)
 	c.SetGatewayLookupFunc(func() (gw, self netaddr.IP, ok bool) {
 		return netaddr.IPv4(127, 0, 0, 1), netaddr.IPv4(1, 2, 3, 4), true
 	})
@@ -81,7 +84,21 @@ func TestProbeIntegration(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Probe: %v", err)
 	}
+	if !res.UPnP {
+		t.Errorf("didn't detect UPnP")
+	}
+	st := igd.stats()
+	want := igdCounters{
+		numUPnPDiscoRecv:     1,
+		numPMPRecv:           1,
+		numPCPRecv:           1,
+		numPMPPublicAddrRecv: 1,
+	}
+	if !reflect.DeepEqual(st, want) {
+		t.Errorf("unexpected stats:\n got: %+v\nwant: %+v", st, want)
+	}
+
 	t.Logf("Probe: %+v", res)
-	t.Logf("IGD stats: %+v", igd.stats())
+	t.Logf("IGD stats: %+v", st)
 	// TODO(bradfitz): finish
 }