Przeglądaj źródła

wgengine/magicsock, types/nettype, etc: finish ReadFromUDPAddrPort netip migration

So we're staying within the netip.Addr/AddrPort consistently and
avoiding allocs/conversions to the legacy net addr types.

Updates #5162

Change-Id: I59feba60d3de39f773e68292d759766bac98c917
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 2 lat temu
rodzic
commit
10f1c90f4d

+ 1 - 1
net/dns/resolver/forwarder.go

@@ -521,7 +521,7 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn
 
 	// The 1 extra byte is to detect packet truncation.
 	out := make([]byte, maxResponseBytes+1)
-	n, _, err := conn.ReadFrom(out)
+	n, _, err := conn.ReadFromUDPAddrPort(out)
 	if err != nil {
 		if err := ctx.Err(); err != nil {
 			return nil, err

+ 4 - 9
net/netcheck/netcheck.go

@@ -208,7 +208,7 @@ type Client struct {
 // reusing an existing UDP connection.
 type STUNConn interface {
 	WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error)
-	ReadFrom([]byte) (int, net.Addr, error)
+	ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error)
 }
 
 func (c *Client) enoughRegions() int {
@@ -518,7 +518,7 @@ func nodeMight4(n *tailcfg.DERPNode) bool {
 }
 
 type packetReaderFromCloser interface {
-	ReadFrom([]byte) (int, net.Addr, error)
+	ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error)
 	io.Closer
 }
 
@@ -538,7 +538,7 @@ func (c *Client) readPackets(ctx context.Context, pc packetReaderFromCloser) {
 
 	var buf [64 << 10]byte
 	for {
-		n, addr, err := pc.ReadFrom(buf[:])
+		n, addr, err := pc.ReadFromUDPAddrPort(buf[:])
 		if err != nil {
 			if ctx.Err() != nil {
 				return
@@ -546,16 +546,11 @@ func (c *Client) readPackets(ctx context.Context, pc packetReaderFromCloser) {
 			c.logf("ReadFrom: %v", err)
 			return
 		}
-		ua, ok := addr.(*net.UDPAddr)
-		if !ok {
-			c.logf("ReadFrom: unexpected addr %T", addr)
-			continue
-		}
 		pkt := buf[:n]
 		if !stun.Is(pkt) {
 			continue
 		}
-		if ap := netaddr.Unmap(ua.AddrPort()); ap.IsValid() {
+		if ap := netaddr.Unmap(addr); ap.IsValid() {
 			c.ReceiveSTUNPacket(pkt, ap)
 		}
 	}

+ 5 - 10
net/portmapper/portmapper.go

@@ -531,7 +531,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor
 
 	res := make([]byte, 1500)
 	for {
-		n, srci, err := uc.ReadFrom(res)
+		n, src, err := uc.ReadFromUDPAddrPort(res)
 		if err != nil {
 			if ctx.Err() == context.Canceled {
 				return netip.AddrPort{}, err
@@ -542,8 +542,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor
 			}
 			return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices}
 		}
-		srcu := srci.(*net.UDPAddr)
-		src := netaddr.Unmap(srcu.AddrPort())
+		src = netaddr.Unmap(src)
 		if !src.IsValid() {
 			continue
 		}
@@ -793,18 +792,14 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) {
 			// Nothing more to discover.
 			return res, nil
 		}
-		n, addr, err := uc.ReadFrom(buf)
+		n, src, err := uc.ReadFromUDPAddrPort(buf)
 		if err != nil {
 			if ctx.Err() == context.DeadlineExceeded {
 				err = nil
 			}
 			return res, err
 		}
-		ip, ok := netip.AddrFromSlice(addr.(*net.UDPAddr).IP)
-		if !ok {
-			continue
-		}
-		ip = ip.Unmap()
+		ip := src.Addr().Unmap()
 
 		handleUPnPResponse := func() {
 			metricUPnPResponse.Add(1)
@@ -832,7 +827,7 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) {
 			c.mu.Unlock()
 		}
 
-		port := uint16(addr.(*net.UDPAddr).Port)
+		port := src.Port()
 		switch port {
 		case c.upnpPort():
 			if mem.Contains(mem.B(buf[:n]), mem.S(":InternetGatewayDevice:")) {

+ 10 - 11
net/stun/stuntest/stuntest.go

@@ -6,14 +6,15 @@ package stuntest
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"net"
 	"net/netip"
 	"strconv"
-	"strings"
 	"sync"
 	"testing"
 
+	"tailscale.com/net/netaddr"
 	"tailscale.com/net/stun"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/nettype"
@@ -44,28 +45,27 @@ func ServeWithPacketListener(t testing.TB, ln nettype.PacketListener) (addr *net
 		addr.IP = net.ParseIP("127.0.0.1")
 	}
 	doneCh := make(chan struct{})
-	go runSTUN(t, pc, &stats, doneCh)
+	go runSTUN(t, pc.(nettype.PacketConn), &stats, doneCh)
 	return addr, func() {
 		pc.Close()
 		<-doneCh
 	}
 }
 
-func runSTUN(t testing.TB, pc net.PacketConn, stats *stunStats, done chan<- struct{}) {
+func runSTUN(t testing.TB, pc nettype.PacketConn, stats *stunStats, done chan<- struct{}) {
 	defer close(done)
 
 	var buf [64 << 10]byte
 	for {
-		n, addr, err := pc.ReadFrom(buf[:])
+		n, src, err := pc.ReadFromUDPAddrPort(buf[:])
 		if err != nil {
-			// TODO: when we switch to Go 1.16, replace this with errors.Is(err, net.ErrClosed)
-			if strings.Contains(err.Error(), "closed network connection") {
+			if errors.Is(err, net.ErrClosed) {
 				t.Logf("STUN server shutdown")
 				return
 			}
 			continue
 		}
-		ua := addr.(*net.UDPAddr)
+		src = netaddr.Unmap(src)
 		pkt := buf[:n]
 		if !stun.Is(pkt) {
 			continue
@@ -76,16 +76,15 @@ func runSTUN(t testing.TB, pc net.PacketConn, stats *stunStats, done chan<- stru
 		}
 
 		stats.mu.Lock()
-		if ua.IP.To4() != nil {
+		if src.Addr().Is4() {
 			stats.readIPv4++
 		} else {
 			stats.readIPv6++
 		}
 		stats.mu.Unlock()
 
-		nia, _ := netip.AddrFromSlice(ua.IP)
-		res := stun.Response(txid, netip.AddrPortFrom(nia, uint16(ua.Port)))
-		if _, err := pc.WriteTo(res, addr); err != nil {
+		res := stun.Response(txid, src)
+		if _, err := pc.WriteToUDPAddrPort(res, src); err != nil {
 			t.Logf("STUN server write failed: %v", err)
 		}
 	}

+ 11 - 8
tstest/natlab/natlab.go

@@ -824,13 +824,21 @@ func (c *conn) Write(buf []byte) (int, error) {
 }
 
 func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+	n, ap, err := c.ReadFromUDPAddrPort(p)
+	if err != nil {
+		return 0, nil, err
+	}
+	return n, net.UDPAddrFromAddrPort(ap), nil
+}
+
+func (c *conn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
 	ar := &activeRead{cancel: cancel}
 
 	if err := c.canRead(); err != nil {
-		return 0, nil, err
+		return 0, netip.AddrPort{}, err
 	}
 
 	c.registerActiveRead(ar, true)
@@ -840,14 +848,9 @@ func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
 	case pkt := <-c.in:
 		n = copy(p, pkt.Payload)
 		pkt.Trace("PacketConn.ReadFrom")
-		ua := &net.UDPAddr{
-			IP:   pkt.Src.Addr().AsSlice(),
-			Port: int(pkt.Src.Port()),
-			Zone: pkt.Src.Addr().Zone(),
-		}
-		return n, ua, nil
+		return n, pkt.Src, nil
 	case <-ctx.Done():
-		return 0, nil, context.DeadlineExceeded
+		return 0, netip.AddrPort{}, context.DeadlineExceeded
 	}
 }
 

+ 3 - 3
types/nettype/nettype.go

@@ -30,11 +30,11 @@ func (Std) ListenPacket(ctx context.Context, network, address string) (net.Packe
 	return conf.ListenPacket(ctx, network, address)
 }
 
-// PacketConn is a net.PacketConn that's about halfway (as of 2023-04-15)
-// converted to use netip.AddrPort.
+// PacketConn is like a net.PacketConn but uses the newer netip.AddrPort
+// write/read methods.
 type PacketConn interface {
 	WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error)
-	ReadFrom(p []byte) (int, net.Addr, error)
+	ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error)
 	io.Closer
 	LocalAddr() net.Addr
 	SetDeadline(time.Time) error

+ 13 - 11
wgengine/magicsock/magicsock.go

@@ -44,6 +44,7 @@ import (
 	"tailscale.com/net/connstats"
 	"tailscale.com/net/dnscache"
 	"tailscale.com/net/interfaces"
+	"tailscale.com/net/netaddr"
 	"tailscale.com/net/netcheck"
 	"tailscale.com/net/neterror"
 	"tailscale.com/net/netns"
@@ -3420,7 +3421,7 @@ type batchingUDPConn struct {
 	sendBatchPool         sync.Pool
 }
 
-func (c *batchingUDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+func (c *batchingUDPConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) {
 	if c.rxOffload {
 		// UDP_GRO is opt-in on Linux via setsockopt(). Once enabled you may
 		// receive a "monster datagram" from any read call. The ReadFrom() API
@@ -3428,9 +3429,9 @@ func (c *batchingUDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
 		// case. Other platforms may vary in behavior, but we go with the most
 		// conservative approach to prevent this from becoming a footgun in the
 		// future.
-		return 0, nil, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable")
+		return 0, netip.AddrPort{}, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable")
 	}
-	return c.pc.ReadFrom(p)
+	return c.pc.ReadFromUDPAddrPort(p)
 }
 
 func (c *batchingUDPConn) SetDeadline(t time.Time) error {
@@ -3753,9 +3754,9 @@ func (c *RebindingUDPConn) currentConn() nettype.PacketConn {
 	return c.pconn
 }
 
-func (c *RebindingUDPConn) readFromWithInitPconn(pconn nettype.PacketConn, b []byte) (int, net.Addr, error) {
+func (c *RebindingUDPConn) readFromWithInitPconn(pconn nettype.PacketConn, b []byte) (int, netip.AddrPort, error) {
 	for {
-		n, addr, err := pconn.ReadFrom(b)
+		n, addr, err := pconn.ReadFromUDPAddrPort(b)
 		if err != nil && pconn != c.currentConn() {
 			pconn = *c.pconnAtomic.Load()
 			continue
@@ -3764,9 +3765,9 @@ func (c *RebindingUDPConn) readFromWithInitPconn(pconn nettype.PacketConn, b []b
 	}
 }
 
-// ReadFrom reads a packet from c into b.
+// ReadFromUDPAddrPort reads a packet from c into b.
 // It returns the number of bytes copied and the source address.
-func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
+func (c *RebindingUDPConn) ReadFromUDPAddrPort(b []byte) (int, netip.AddrPort, error) {
 	return c.readFromWithInitPconn(*c.pconnAtomic.Load(), b)
 }
 
@@ -3803,9 +3804,10 @@ func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error
 		pconn := *c.pconnAtomic.Load()
 		b, ok := pconn.(*batchingUDPConn)
 		if !ok {
-			var err error
-			msgs[0].N, msgs[0].Addr, err = c.readFromWithInitPconn(pconn, msgs[0].Buffers[0])
+			n, ap, err := c.readFromWithInitPconn(pconn, msgs[0].Buffers[0])
 			if err == nil {
+				msgs[0].N = n
+				msgs[0].Addr = net.UDPAddrFromAddrPort(netaddr.Unmap(ap))
 				return 1, nil
 			}
 			return 0, err
@@ -3880,13 +3882,13 @@ type blockForeverConn struct {
 	closed bool
 }
 
-func (c *blockForeverConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+func (c *blockForeverConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) {
 	c.mu.Lock()
 	for !c.closed {
 		c.cond.Wait()
 	}
 	c.mu.Unlock()
-	return 0, nil, net.ErrClosed
+	return 0, netip.AddrPort{}, net.ErrClosed
 }
 
 func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (int, error) {

+ 1 - 1
wgengine/magicsock/magicsock_test.go

@@ -1801,7 +1801,7 @@ func TestBlockForeverConnUnblocks(t *testing.T) {
 	done := make(chan error, 1)
 	go func() {
 		defer close(done)
-		_, _, err := c.ReadFrom(make([]byte, 1))
+		_, _, err := c.ReadFromUDPAddrPort(make([]byte, 1))
 		done <- err
 	}()
 	time.Sleep(50 * time.Millisecond) // give ReadFrom time to get blocked