Parcourir la source

wgengine/magicsock: actually use AF_PACKET socket for raw disco

Previously, despite what the commit said, we were using a raw IP socket
that was *not* an AF_PACKET socket, and thus was subject to the host
firewall rules. Switch to using a real AF_PACKET socket to actually get
the functionality we want.

Updates #13140

Signed-off-by: Andrew Dunham <[email protected]>
Change-Id: If657daeeda9ab8d967e75a4f049c66e2bca54b78
Andrew Dunham il y a 1 an
Parent
commit
1c972bc7cb

+ 1 - 1
cmd/k8s-operator/depaware.txt

@@ -171,7 +171,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
    L 💣 github.com/mdlayher/netlink/nlenc                            from github.com/jsimonetti/rtnetlink+
    L    github.com/mdlayher/netlink/nltest                           from github.com/google/nftables
    L    github.com/mdlayher/sdnotify                                 from tailscale.com/util/systemd
-   L 💣 github.com/mdlayher/socket                                   from github.com/mdlayher/netlink
+   L 💣 github.com/mdlayher/socket                                   from github.com/mdlayher/netlink+
         github.com/miekg/dns                                         from tailscale.com/net/dns/recursive
      💣 github.com/mitchellh/go-ps                                   from tailscale.com/safesocket
         github.com/modern-go/concurrent                              from github.com/json-iterator/go

+ 1 - 1
cmd/tailscaled/depaware.txt

@@ -139,7 +139,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
    L 💣 github.com/mdlayher/netlink/nlenc                            from github.com/jsimonetti/rtnetlink+
    L    github.com/mdlayher/netlink/nltest                           from github.com/google/nftables
    L    github.com/mdlayher/sdnotify                                 from tailscale.com/util/systemd
-   L 💣 github.com/mdlayher/socket                                   from github.com/mdlayher/netlink
+   L 💣 github.com/mdlayher/socket                                   from github.com/mdlayher/netlink+
         github.com/miekg/dns                                         from tailscale.com/net/dns/recursive
      💣 github.com/mitchellh/go-ps                                   from tailscale.com/safesocket
    L    github.com/pierrec/lz4/v4                                    from github.com/u-root/uio/uio

+ 5 - 0
net/packet/packet.go

@@ -393,6 +393,11 @@ func (q *Parsed) Buffer() []byte {
 // Payload returns the payload of the IP subprotocol section.
 // This is a read-only view; that is, q retains the ownership of the buffer.
 func (q *Parsed) Payload() []byte {
+	// If the packet is truncated, return nothing instead of crashing.
+	if q.length > len(q.b) || q.dataofs > len(q.b) {
+		return nil
+	}
+
 	return q.b[q.dataofs:q.length]
 }
 

+ 269 - 89
wgengine/magicsock/magicsock_linux.go

@@ -5,28 +5,37 @@ package magicsock
 
 import (
 	"bytes"
+	"context"
 	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
 	"net"
 	"net/netip"
+	"strings"
 	"syscall"
 	"time"
-	"unsafe"
 
+	"github.com/mdlayher/socket"
 	"golang.org/x/net/bpf"
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+	"golang.org/x/sys/cpu"
 	"golang.org/x/sys/unix"
+	"tailscale.com/disco"
 	"tailscale.com/envknob"
 	"tailscale.com/net/netns"
+	"tailscale.com/types/ipproto"
 	"tailscale.com/types/key"
 	"tailscale.com/types/logger"
 	"tailscale.com/types/nettype"
 )
 
 const (
-	udpHeaderSize          = 8
-	ipv6FragmentHeaderSize = 8
+	udpHeaderSize = 8
+
+	// discoMinHeaderSize is the minimum size of the disco header in bytes.
+	discoMinHeaderSize = len(disco.Magic) + 32 /* key length */ + disco.NonceLen
 )
 
 // Enable/disable using raw sockets to receive disco traffic.
@@ -38,8 +47,17 @@ var debugRawDiscoReads = envknob.RegisterBool("TS_DEBUG_RAW_DISCO")
 // These are our BPF filters that we use for testing packets.
 var (
 	magicsockFilterV4 = []bpf.Instruction{
-		// For raw UDPv4 sockets, BPF receives the entire IP packet to
-		// inspect.
+		// For raw sockets (with ETH_P_IP set), the BPF program
+		// receives the entire IPv4 packet, but not the Ethernet
+		// header.
+
+		// Double-check that this is a UDP packet; we shouldn't be
+		// seeing anything else given how we create our AF_PACKET
+		// socket, but an extra check here is cheap, and matches the
+		// check that we do in the IPv6 path.
+		bpf.LoadAbsolute{Off: 9, Size: 1},
+		bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(ipproto.UDP), SkipTrue: 1, SkipFalse: 0},
+		bpf.RetConstant{Val: 0x0},
 
 		// Disco packets are so small they should never get
 		// fragmented, and we don't want to handle reassembly.
@@ -53,6 +71,25 @@ var (
 		// Load IP header length into X register.
 		bpf.LoadMemShift{Off: 0},
 
+		// Verify that we have a packet that's big enough to (possibly)
+		// contain a disco packet.
+		//
+		// The length of an IPv4 disco packet is composed of:
+		// - 8 bytes for the UDP header
+		// - N bytes for the disco packet header
+		//
+		// bpf will implicitly return 0 ("skip") if attempting an
+		// out-of-bounds load, so we can check the length of the packet
+		// by loading a byte from that offset here. We subtract 1 byte
+		// from the offset to ensure that we accept a packet that's
+		// exactly the minimum size.
+		//
+		// We use LoadIndirect; since we loaded the start of the packet's
+		// payload into the X register, above, we don't need to add
+		// ipv4.HeaderLen to the offset (and this properly handles IPv4
+		// options, which make the IP header longer than the minimum).
+		bpf.LoadIndirect{Off: uint32(udpHeaderSize + discoMinHeaderSize - 1), Size: 1},
+
 		// Get the first 4 bytes of the UDP packet, compare with our magic number
 		bpf.LoadIndirect{Off: udpHeaderSize, Size: 4},
 		bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic1, SkipTrue: 0, SkipFalse: 3},
@@ -82,25 +119,24 @@ var (
 	// and thus we'd rather be conservative here and possibly not receive
 	// disco packets rather than slow down the system.
 	magicsockFilterV6 = []bpf.Instruction{
-		// For raw UDPv6 sockets, BPF receives _only_ the UDP header onwards, not an entire IP packet.
-		//
-		//    https://stackoverflow.com/questions/24514333/using-bpf-with-sock-dgram-on-linux-machine
-		//    https://blog.cloudflare.com/epbf_sockets_hop_distance/
-		//
-		// This is especially confusing because this *isn't* true for
-		// IPv4; see the following code from the 'ping' utility that
-		// corroborates this:
-		//
-		//    https://github.com/iputils/iputils/blob/1ab5fa/ping/ping.c#L1667-L1676
-		//    https://github.com/iputils/iputils/blob/1ab5fa/ping/ping6_common.c#L933-L941
+		// Do a bounds check to ensure we have enough space for a disco
+		// packet; see the comment in the IPv4 BPF program for more
+		// details.
+		bpf.LoadAbsolute{Off: uint32(ipv6.HeaderLen + udpHeaderSize + discoMinHeaderSize - 1), Size: 1},
+
+		// Verify that the 'next header' value of the IPv6 packet is
+		// UDP, which is what we're expecting; if it's anything else
+		// (including extension headers), we skip the packet.
+		bpf.LoadAbsolute{Off: 6, Size: 1},
+		bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(ipproto.UDP), SkipTrue: 0, SkipFalse: 5},
 
 		// Compare with our magic number. Start by loading and
 		// comparing the first 4 bytes of the UDP payload.
-		bpf.LoadAbsolute{Off: udpHeaderSize, Size: 4},
+		bpf.LoadAbsolute{Off: ipv6.HeaderLen + udpHeaderSize, Size: 4},
 		bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic1, SkipTrue: 0, SkipFalse: 3},
 
 		// Compare the next 2 bytes
-		bpf.LoadAbsolute{Off: udpHeaderSize + 4, Size: 2},
+		bpf.LoadAbsolute{Off: ipv6.HeaderLen + udpHeaderSize + 4, Size: 2},
 		bpf.JumpIf{Cond: bpf.JumpEqual, Val: discoMagic2, SkipTrue: 0, SkipFalse: 1},
 
 		// Accept the whole packet
@@ -140,21 +176,24 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
 	}
 
 	var (
-		network  string
+		udpnet   string
 		addr     string
-		testAddr string
+		proto    int
+		testAddr netip.AddrPort
 		prog     []bpf.Instruction
 	)
 	switch family {
 	case "ip4":
-		network = "ip4:17"
+		udpnet = "udp4"
 		addr = "0.0.0.0"
-		testAddr = "127.0.0.1:1"
+		proto = ethernetProtoIPv4()
+		testAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 1)
 		prog = magicsockFilterV4
 	case "ip6":
-		network = "ip6:17"
+		udpnet = "udp6"
 		addr = "::"
-		testAddr = "[::1]:1"
+		proto = ethernetProtoIPv6()
+		testAddr = netip.AddrPortFrom(netip.IPv6Loopback(), 1)
 		prog = magicsockFilterV6
 	default:
 		return nil, fmt.Errorf("unsupported address family %q", family)
@@ -165,72 +204,214 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
 		return nil, fmt.Errorf("assembling filter: %w", err)
 	}
 
-	pc, err := net.ListenPacket(network, addr)
+	sock, err := socket.Socket(
+		unix.AF_PACKET,
+		unix.SOCK_DGRAM,
+		proto,
+		"afpacket",
+		nil, // no config
+	)
 	if err != nil {
-		return nil, fmt.Errorf("creating packet conn: %w", err)
+		return nil, fmt.Errorf("creating AF_PACKET socket: %w", err)
 	}
 
-	if err := setBPF(pc, asm); err != nil {
-		pc.Close()
+	if err := sock.SetBPF(asm); err != nil {
+		sock.Close()
 		return nil, fmt.Errorf("installing BPF filter: %w", err)
 	}
 
 	// If all the above succeeds, we should be ready to receive. Just
 	// out of paranoia, check that we do receive a well-formed disco
 	// packet.
-	tc, err := net.ListenPacket("udp", net.JoinHostPort(addr, "0"))
+	tc, err := net.ListenPacket(udpnet, net.JoinHostPort(addr, "0"))
 	if err != nil {
-		pc.Close()
+		sock.Close()
 		return nil, fmt.Errorf("creating disco test socket: %w", err)
 	}
 	defer tc.Close()
-	if _, err := tc.(*net.UDPConn).WriteToUDPAddrPort(testDiscoPacket, netip.MustParseAddrPort(testAddr)); err != nil {
-		pc.Close()
+	if _, err := tc.(*net.UDPConn).WriteToUDPAddrPort(testDiscoPacket, testAddr); err != nil {
+		sock.Close()
 		return nil, fmt.Errorf("writing disco test packet: %w", err)
 	}
-	pc.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
-	var buf [1500]byte
+
+	const selfTestTimeout = 100 * time.Millisecond
+	if err := sock.SetReadDeadline(time.Now().Add(selfTestTimeout)); err != nil {
+		sock.Close()
+		return nil, fmt.Errorf("setting socket timeout: %w", err)
+	}
+
+	var (
+		ctx = context.Background()
+		buf [1500]byte
+	)
 	for {
-		n, _, err := pc.ReadFrom(buf[:])
+		n, _, err := sock.Recvfrom(ctx, buf[:], 0)
 		if err != nil {
-			pc.Close()
+			sock.Close()
 			return nil, fmt.Errorf("reading during raw disco self-test: %w", err)
 		}
-		if n < udpHeaderSize {
+
+		_ /* src */, _ /* dst */, payload := parseUDPPacket(buf[:n], family == "ip6")
+		if payload == nil {
 			continue
 		}
-		if !bytes.Equal(buf[udpHeaderSize:n], testDiscoPacket) {
+		if !bytes.Equal(payload, testDiscoPacket) {
+			c.discoLogf("listenRawDisco: self-test: received mismatched UDP packet of %d bytes", len(payload))
 			continue
 		}
+		c.logf("[v1] listenRawDisco: self-test passed for %s", family)
 		break
 	}
-	pc.SetReadDeadline(time.Time{})
+	sock.SetReadDeadline(time.Time{})
 
-	go c.receiveDisco(pc, family == "ip6")
-	return pc, nil
+	go c.receiveDisco(sock, family == "ip6")
+	return sock, nil
 }
 
-func (c *Conn) receiveDisco(pc net.PacketConn, isIPV6 bool) {
+// parseUDPPacket is a basic parser for UDP packets that returns the source and
+// destination addresses, and the payload. The returned payload is a sub-slice
+// of the input buffer.
+//
+// It expects to be called with a buffer that contains the entire UDP packet,
+// including the IP header, and one that has been filtered with the BPF
+// programs above.
+//
+// If an error occurs, it will return the zero values for all return values.
+func parseUDPPacket(buf []byte, isIPv6 bool) (src, dst netip.AddrPort, payload []byte) {
+	// First, parse the IPv4 or IPv6 header to get to the UDP header. Since
+	// we assume this was filtered with BPF, we know that there will be no
+	// IPv6 extension headers.
+	var (
+		srcIP, dstIP netip.Addr
+		udp          []byte
+	)
+	if isIPv6 {
+		// Basic length check to ensure that we don't panic
+		if len(buf) < ipv6.HeaderLen+udpHeaderSize {
+			return
+		}
+
+		// Extract the source and destination addresses from the IPv6
+		// header.
+		srcIP, _ = netip.AddrFromSlice(buf[8:24])
+		dstIP, _ = netip.AddrFromSlice(buf[24:40])
+
+		// We know that the UDP header starts immediately after the
+		// fixed-size IPv6 header.
+		udp = buf[ipv6.HeaderLen:]
+	} else {
+		// This is an IPv4 packet; read the header length (IHL) field from the header.
+		if len(buf) < ipv4.HeaderLen {
+			return
+		}
+		udpOffset := int((buf[0] & 0x0F) << 2)
+		if udpOffset+udpHeaderSize > len(buf) {
+			return
+		}
+
+		// Parse the source and destination IPs.
+		srcIP, _ = netip.AddrFromSlice(buf[12:16])
+		dstIP, _ = netip.AddrFromSlice(buf[16:20])
+		udp = buf[udpOffset:]
+	}
+
+	// Parse the ports
+	srcPort := binary.BigEndian.Uint16(udp[0:2])
+	dstPort := binary.BigEndian.Uint16(udp[2:4])
+
+	// The payload starts after the UDP header.
+	payload = udp[8:]
+	return netip.AddrPortFrom(srcIP, srcPort), netip.AddrPortFrom(dstIP, dstPort), payload
+}
+
+// ethernetProtoIPv4 returns the constant unix.ETH_P_IP, in network byte order.
+// packet(7) sockets require that the 'protocol' argument be in network byte
+// order; see:
+//
+//	https://man7.org/linux/man-pages/man7/packet.7.html
+//
+// Instead of using htons at runtime, we can just hardcode the value here...
+// but we also have a test that verifies that this is correct.
+func ethernetProtoIPv4() int {
+	if cpu.IsBigEndian {
+		return 0x0800
+	} else {
+		return 0x0008
+	}
+}
+
+// ethernetProtoIPv6 returns the constant unix.ETH_P_IPV6, and is otherwise the
+// same as ethernetProtoIPv4.
+func ethernetProtoIPv6() int {
+	if cpu.IsBigEndian {
+		return 0x86dd
+	} else {
+		return 0xdd86
+	}
+}
+
+func (c *Conn) discoLogf(format string, args ...any) {
+	// Enable debug logging if we're debugging raw disco reads or if the
+	// magicsock component logs are on.
+	if debugRawDiscoReads() {
+		c.logf(format, args...)
+	} else {
+		c.dlogf(format, args...)
+	}
+}
+
+func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) {
+	// Given that we're parsing raw packets, be extra careful and recover
+	// from any panics in this function.
+	//
+	// If we didn't have a recover() here and panic'd, we'd take down the
+	// entire process since this function is the top of a goroutine, and Go
+	// will kill the process if a goroutine panics and it unwinds past the
+	// top-level function.
+	defer func() {
+		if err := recover(); err != nil {
+			c.logf("[unexpected] recovered from panic in receiveDisco(isIPv6=%v): %v", isIPV6, err)
+		}
+	}()
+
+	ctx := context.Background()
+
+	// Set up our loggers
+	var family string
+	if isIPV6 {
+		family = "ip6"
+	} else {
+		family = "ip4"
+	}
+	var (
+		prefix string      = "disco raw " + family + ": "
+		logf   logger.Logf = logger.WithPrefix(c.logf, prefix)
+		dlogf  logger.Logf = logger.WithPrefix(c.discoLogf, prefix)
+	)
+
 	var buf [1500]byte
 	for {
-		n, src, err := pc.ReadFrom(buf[:])
+		n, src, err := pc.Recvfrom(ctx, buf[:], 0)
 		if debugRawDiscoReads() {
-			c.logf("raw disco read from %v = (%v, %v)", src, n, err)
+			logf("read from %s = (%v, %v)", printSockaddr(src), n, err)
 		}
-		if errors.Is(err, net.ErrClosed) {
+		if err != nil && (errors.Is(err, net.ErrClosed) || err.Error() == "use of closed file") {
+			// EOF; no need to print an error
 			return
 		} else if err != nil {
-			c.logf("disco raw reader failed: %v", err)
+			logf("reader failed: %v", err)
 			return
 		}
-		if n < udpHeaderSize {
-			// Too small to be a valid UDP datagram, drop.
+
+		srcAddr, dstAddr, payload := parseUDPPacket(buf[:n], family == "ip6")
+		if payload == nil {
+			// Truncated or malformed packet (parseUDPPacket does not log); drop it.
 			continue
 		}
 
-		dstPort := binary.BigEndian.Uint16(buf[2:4])
+		dstPort := dstAddr.Port()
 		if dstPort == 0 {
-			c.logf("[unexpected] disco raw: received packet for port 0")
+			logf("[unexpected] received packet for port 0")
 		}
 
 		var acceptPort uint16
@@ -242,59 +423,58 @@ func (c *Conn) receiveDisco(pc net.PacketConn, isIPV6 bool) {
 		if acceptPort == 0 {
 			// This should only typically happen if the receiving address family
 			// was recently disabled.
-			c.dlogf("[v1] disco raw: dropping packet for port %d as acceptPort=0", dstPort)
+			dlogf("[v1] dropping packet for port %d as acceptPort=0", dstPort)
 			continue
 		}
 
+		// If the packet isn't destined for our local port, then we
+		// should drop it since it might be for another Tailscale
+		// process on the same machine, or NATed to a different machine
+		// if this is a router, etc.
+		//
+		// We get the local port to compare against inside the receive
+		// loop; we can't cache this beforehand because it can change
+		// if/when we rebind.
 		if dstPort != acceptPort {
-			c.dlogf("[v1] disco raw: dropping packet for port %d", dstPort)
-			continue
-		}
-
-		srcIP, ok := netip.AddrFromSlice(src.(*net.IPAddr).IP)
-		if !ok {
-			c.logf("[unexpected] PacketConn.ReadFrom returned not-an-IP %v in from", src)
+			dlogf("[v1] dropping packet for port %d that isn't our local port", dstPort)
 			continue
 		}
-		srcPort := binary.BigEndian.Uint16(buf[:2])
 
-		if srcIP.Is4() {
-			metricRecvDiscoPacketIPv4.Add(1)
-		} else {
+		if isIPV6 {
 			metricRecvDiscoPacketIPv6.Add(1)
+		} else {
+			metricRecvDiscoPacketIPv4.Add(1)
 		}
 
-		c.handleDiscoMessage(buf[udpHeaderSize:n], netip.AddrPortFrom(srcIP, srcPort), key.NodePublic{}, discoRXPathRawSocket)
+		c.handleDiscoMessage(payload, srcAddr, key.NodePublic{}, discoRXPathRawSocket)
 	}
 }
 
-// setBPF installs filter as the BPF filter on conn.
-// Ideally we would just use SetBPF as implemented in x/net/ipv4,
-// but x/net/ipv6 doesn't implement it. And once you've written
-// this code once, it turns out to be address family agnostic, so
-// we might as well use it on both and get to use a net.PacketConn
-// directly for both families instead of being stuck with
-// different types.
-func setBPF(conn net.PacketConn, filter []bpf.RawInstruction) error {
-	sc, err := conn.(*net.IPConn).SyscallConn()
-	if err != nil {
-		return err
-	}
-	prog := &unix.SockFprog{
-		Len:    uint16(len(filter)),
-		Filter: (*unix.SockFilter)(unsafe.Pointer(&filter[0])),
-	}
-	var setErr error
-	err = sc.Control(func(fd uintptr) {
-		setErr = unix.SetsockoptSockFprog(int(fd), unix.SOL_SOCKET, unix.SO_ATTACH_FILTER, prog)
-	})
-	if err != nil {
-		return err
-	}
-	if setErr != nil {
-		return err
+// printSockaddr is a helper function to pretty-print various sockaddr types.
+func printSockaddr(sa unix.Sockaddr) string {
+	switch sa := sa.(type) {
+	case *unix.SockaddrInet4:
+		addr := netip.AddrFrom4(sa.Addr)
+		return netip.AddrPortFrom(addr, uint16(sa.Port)).String()
+	case *unix.SockaddrInet6:
+		addr := netip.AddrFrom16(sa.Addr)
+		return netip.AddrPortFrom(addr, uint16(sa.Port)).String()
+	case *unix.SockaddrLinklayer:
+		hwaddr := sa.Addr[:sa.Halen]
+
+		var buf strings.Builder
+		fmt.Fprintf(&buf, "link(ty=0x%04x,if=%d):[", sa.Protocol, sa.Ifindex)
+		for i, b := range hwaddr {
+			if i > 0 {
+				buf.WriteByte(':')
+			}
+			fmt.Fprintf(&buf, "%02x", b)
+		}
+		buf.WriteByte(']')
+		return buf.String()
+	default:
+		return fmt.Sprintf("unknown(%T)", sa)
 	}
-	return nil
 }
 
 // trySetSocketBuffer attempts to set SO_SNDBUFFORCE and SO_RECVBUFFORCE which

+ 148 - 0
wgengine/magicsock/magicsock_linux_test.go

@@ -0,0 +1,148 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package magicsock
+
+import (
+	"bytes"
+	"encoding/binary"
+	"net/netip"
+	"testing"
+
+	"golang.org/x/sys/cpu"
+	"golang.org/x/sys/unix"
+	"tailscale.com/disco"
+)
+
+func TestParseUDPPacket(t *testing.T) {
+	src4 := netip.MustParseAddrPort("127.0.0.1:12345")
+	dst4 := netip.MustParseAddrPort("127.0.0.2:54321")
+
+	src6 := netip.MustParseAddrPort("[::1]:12345")
+	dst6 := netip.MustParseAddrPort("[::2]:54321")
+
+	udp4Packet := []byte{
+		// IPv4 header
+		0x45, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, 0x00,
+		0x40, 0x11, 0x00, 0x00,
+		0x7f, 0x00, 0x00, 0x01, // source ip
+		0x7f, 0x00, 0x00, 0x02, // dest ip
+
+		// UDP header
+		0x30, 0x39, // src port
+		0xd4, 0x31, // dest port
+		0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes
+		0x00, 0x00, // checksum; unused
+
+		// Payload: disco magic plus 4 bytes
+		0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03,
+	}
+	udp6Packet := []byte{
+		// IPv6 header
+		0x60, 0x00, 0x00, 0x00,
+		0x00, 0x12, // payload length
+		0x11, // next header: UDP
+		0x00, // hop limit; unused
+
+		// Source IP
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
+		// Dest IP
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
+
+		// UDP header
+		0x30, 0x39, // src port
+		0xd4, 0x31, // dest port
+		0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes
+		0x00, 0x00, // checksum; unused
+
+		// Payload: disco magic plus 4 bytes
+		0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03,
+	}
+
+	// Verify that parsing the UDP packet works correctly.
+	t.Run("IPv4", func(t *testing.T) {
+		src, dst, payload := parseUDPPacket(udp4Packet, false)
+		if src != src4 {
+			t.Errorf("src = %v; want %v", src, src4)
+		}
+		if dst != dst4 {
+			t.Errorf("dst = %v; want %v", dst, dst4)
+		}
+		if !bytes.HasPrefix(payload, []byte(disco.Magic)) {
+			t.Errorf("payload = %x; must start with %x", payload, disco.Magic)
+		}
+	})
+	t.Run("IPv6", func(t *testing.T) {
+		src, dst, payload := parseUDPPacket(udp6Packet, true)
+		if src != src6 {
+			t.Errorf("src = %v; want %v", src, src6)
+		}
+		if dst != dst6 {
+			t.Errorf("dst = %v; want %v", dst, dst6)
+		}
+		if !bytes.HasPrefix(payload, []byte(disco.Magic)) {
+			t.Errorf("payload = %x; must start with %x", payload, disco.Magic)
+		}
+	})
+	t.Run("Truncated", func(t *testing.T) {
+		truncateBy := func(b []byte, n int) []byte {
+			if n >= len(b) {
+				return nil
+			}
+			return b[:len(b)-n]
+		}
+
+		src, dst, payload := parseUDPPacket(truncateBy(udp4Packet, 11), false)
+		if payload != nil {
+			t.Errorf("payload = %x; want nil", payload)
+		}
+		if src.IsValid() || dst.IsValid() {
+			t.Errorf("src = %v, dst = %v; want invalid", src, dst)
+		}
+
+		src, dst, payload = parseUDPPacket(truncateBy(udp6Packet, 11), true)
+		if payload != nil {
+			t.Errorf("payload = %x; want nil", payload)
+		}
+		if src.IsValid() || dst.IsValid() {
+			t.Errorf("src = %v, dst = %v; want invalid", src, dst)
+		}
+	})
+}
+
+func TestEthernetProto(t *testing.T) {
+	htons := func(x uint16) int {
+		// Network byte order is big-endian; write the value as
+		// big-endian to a byte slice and read it back in the native
+		// endian-ness. This is a no-op on a big-endian platform and a
+		// byte swap on a little-endian platform.
+		var b [2]byte
+		binary.BigEndian.PutUint16(b[:], x)
+		return int(binary.NativeEndian.Uint16(b[:]))
+	}
+
+	if v4 := ethernetProtoIPv4(); v4 != htons(unix.ETH_P_IP) {
+		t.Errorf("ethernetProtoIPv4 = 0x%04x; want 0x%04x", v4, htons(unix.ETH_P_IP))
+	}
+	if v6 := ethernetProtoIPv6(); v6 != htons(unix.ETH_P_IPV6) {
+		t.Errorf("ethernetProtoIPv6 = 0x%04x; want 0x%04x", v6, htons(unix.ETH_P_IPV6))
+	}
+
+	// As a way to verify that the htons function is working correctly,
+	// assert that the ETH_P_IP value returned from our function matches
+	// the value defined in the unix package based on whether the host is
+	// big-endian (network byte order) or little-endian.
+	if cpu.IsBigEndian {
+		if v4 := ethernetProtoIPv4(); v4 != unix.ETH_P_IP {
+			t.Errorf("ethernetProtoIPv4 = 0x%04x; want 0x%04x", v4, unix.ETH_P_IP)
+		}
+	} else {
+		if v4 := ethernetProtoIPv4(); v4 == unix.ETH_P_IP {
+			t.Errorf("ethernetProtoIPv4 = 0x%04x; want 0x%04x", v4, htons(unix.ETH_P_IP))
+		} else {
+			t.Logf("ethernetProtoIPv4 = 0x%04x, correctly different from 0x%04x", v4, unix.ETH_P_IP)
+		}
+	}
+}