Explorar o código

net/{batching,packet},wgengine/magicsock: export batchingConn (#16848)

For eventual use by net/udprelay.Server.

Updates tailscale/corp#31164

Signed-off-by: Jordan Whited <[email protected]>
Jordan Whited hai 7 meses
pai
achega
16bc0a5558

+ 1 - 0
cmd/k8s-operator/depaware.txt

@@ -838,6 +838,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
         tailscale.com/logtail/filch                                  from tailscale.com/log/sockstatlog+
         tailscale.com/metrics                                        from tailscale.com/derp+
         tailscale.com/net/bakedroots                                 from tailscale.com/net/tlsdial+
+     💣 tailscale.com/net/batching                                   from tailscale.com/wgengine/magicsock
         tailscale.com/net/captivedetection                           from tailscale.com/ipn/ipnlocal+
         tailscale.com/net/connstats                                  from tailscale.com/net/tstun+
         tailscale.com/net/dns                                        from tailscale.com/ipn/ipnlocal+

+ 1 - 0
cmd/tailscaled/depaware.txt

@@ -311,6 +311,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
         tailscale.com/logtail/filch                                  from tailscale.com/log/sockstatlog+
         tailscale.com/metrics                                        from tailscale.com/derp+
         tailscale.com/net/bakedroots                                 from tailscale.com/net/tlsdial+
+     💣 tailscale.com/net/batching                                   from tailscale.com/wgengine/magicsock
         tailscale.com/net/captivedetection                           from tailscale.com/ipn/ipnlocal+
         tailscale.com/net/connstats                                  from tailscale.com/net/tstun+
         tailscale.com/net/dns                                        from tailscale.com/cmd/tailscaled+

+ 1 - 0
cmd/tsidp/depaware.txt

@@ -268,6 +268,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar
         tailscale.com/logtail/filch                                  from tailscale.com/log/sockstatlog+
         tailscale.com/metrics                                        from tailscale.com/derp+
         tailscale.com/net/bakedroots                                 from tailscale.com/ipn/ipnlocal+
+     💣 tailscale.com/net/batching                                   from tailscale.com/wgengine/magicsock
         tailscale.com/net/captivedetection                           from tailscale.com/ipn/ipnlocal+
         tailscale.com/net/connstats                                  from tailscale.com/net/tstun+
         tailscale.com/net/dns                                        from tailscale.com/ipn/ipnlocal+

+ 48 - 0
net/batching/conn.go

@@ -0,0 +1,48 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package batching implements a socket optimized for increased throughput.
+package batching
+
+import (
+	"net/netip"
+
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+	"tailscale.com/net/packet"
+	"tailscale.com/types/nettype"
+)
+
+var (
+	// This acts as a compile-time check for our usage of ipv6.Message in
+	// [Conn] for both IPv6 and IPv4 operations.
+	_ ipv6.Message = ipv4.Message{}
+)
+
+// Conn is a nettype.PacketConn that provides batched i/o using
+// platform-specific optimizations, e.g. {recv,send}mmsg & UDP GSO/GRO.
+//
+// Conn originated from (and is still used by) magicsock where its API was
+// strongly influenced by [wireguard-go/conn.Bind] constraints, namely
+// wireguard-go's ownership of packet memory.
+type Conn interface {
+	nettype.PacketConn
+	// ReadBatch reads messages from [Conn] into msgs. It returns the number of
+	// messages the caller should evaluate for nonzero len, as a zero len
+	// message may fall on either side of a nonzero.
+	//
+	// Each [ipv6.Message.OOB] must be sized to at least MinControlMessageSize().
+	// len(msgs) must be at least MinReadBatchMsgsLen().
+	ReadBatch(msgs []ipv6.Message, flags int) (n int, err error)
+	// WriteBatchTo writes buffs to addr.
+	//
+	// If geneve.VNI.IsSet(), then geneve is encoded into the space preceding
+	// offset, and offset must equal [packet.GeneveFixedHeaderLength]. If
+	// !geneve.VNI.IsSet() then the space preceding offset is ignored.
+	//
+	// len(buffs) must be <= batchSize supplied in TryUpgradeToConn().
+	//
+	// WriteBatchTo may return a [neterror.ErrUDPGSODisabled] error if UDP GSO
+	// was disabled as a result of a send error.
+	WriteBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) error
+}

+ 21 - 0
net/batching/conn_default.go

@@ -0,0 +1,21 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !linux
+
+package batching
+
+import (
+	"tailscale.com/types/nettype"
+)
+
+// TryUpgradeToConn is no-op on all platforms except linux.
+func TryUpgradeToConn(pconn nettype.PacketConn, _ string, _ int) nettype.PacketConn {
+	return pconn
+}
+
+var controlMessageSize = 0
+
+func MinControlMessageSize() int {
+	return controlMessageSize
+}

+ 44 - 33
wgengine/magicsock/batching_conn_linux.go → net/batching/conn_linux.go

@@ -1,7 +1,7 @@
 // Copyright (c) Tailscale Inc & AUTHORS
 // SPDX-License-Identifier: BSD-3-Clause
 
-package magicsock
+package batching
 
 import (
 	"encoding/binary"
@@ -43,10 +43,15 @@ type xnetBatchWriter interface {
 	WriteBatch([]ipv6.Message, int) (int, error)
 }
 
+var (
+	// [linuxBatchingConn] implements [Conn].
+	_ Conn = &linuxBatchingConn{}
+)
+
 // linuxBatchingConn is a UDP socket that provides batched i/o. It implements
-// batchingConn.
+// [Conn].
 type linuxBatchingConn struct {
-	pc                    nettype.PacketConn
+	pc                    *net.UDPConn
 	xpc                   xnetBatchReaderWriter
 	rxOffload             bool                                  // supports UDP GRO or similar
 	txOffload             atomic.Bool                           // supports UDP GSO or similar
@@ -98,9 +103,8 @@ const (
 //
 // All msgs have their Addr field set to addr.
 //
-// All msgs[i].Buffers[0] are preceded by a Geneve header with vni.get() if
-// vni.isSet().
-func (c *linuxBatchingConn) coalesceMessages(addr *net.UDPAddr, vni virtualNetworkID, buffs [][]byte, msgs []ipv6.Message, offset int) int {
+// All msgs[i].Buffers[0] are preceded by a Geneve header (geneve) if geneve.VNI.IsSet().
+func (c *linuxBatchingConn) coalesceMessages(addr *net.UDPAddr, geneve packet.GeneveHeader, buffs [][]byte, msgs []ipv6.Message, offset int) int {
 	var (
 		base     = -1 // index of msg we are currently coalescing into
 		gsoSize  int  // segmentation size of msgs[base]
@@ -111,15 +115,10 @@ func (c *linuxBatchingConn) coalesceMessages(addr *net.UDPAddr, vni virtualNetwo
 	if addr.IP.To4() == nil {
 		maxPayloadLen = maxIPv6PayloadLen
 	}
-	vniIsSet := vni.isSet()
-	var gh packet.GeneveHeader
-	if vniIsSet {
-		gh.Protocol = packet.GeneveProtocolWireGuard
-		gh.VNI = vni.get()
-	}
+	vniIsSet := geneve.VNI.IsSet()
 	for i, buff := range buffs {
 		if vniIsSet {
-			gh.Encode(buffs[i])
+			geneve.Encode(buff)
 		} else {
 			buff = buff[offset:]
 		}
@@ -179,37 +178,34 @@ func (c *linuxBatchingConn) putSendBatch(batch *sendBatch) {
 	c.sendBatchPool.Put(batch)
 }
 
-func (c *linuxBatchingConn) WriteBatchTo(buffs [][]byte, addr epAddr, offset int) error {
+func (c *linuxBatchingConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) error {
 	batch := c.getSendBatch()
 	defer c.putSendBatch(batch)
-	if addr.ap.Addr().Is6() {
-		as16 := addr.ap.Addr().As16()
+	if addr.Addr().Is6() {
+		as16 := addr.Addr().As16()
 		copy(batch.ua.IP, as16[:])
 		batch.ua.IP = batch.ua.IP[:16]
 	} else {
-		as4 := addr.ap.Addr().As4()
+		as4 := addr.Addr().As4()
 		copy(batch.ua.IP, as4[:])
 		batch.ua.IP = batch.ua.IP[:4]
 	}
-	batch.ua.Port = int(addr.ap.Port())
+	batch.ua.Port = int(addr.Port())
 	var (
 		n       int
 		retried bool
 	)
 retry:
 	if c.txOffload.Load() {
-		n = c.coalesceMessages(batch.ua, addr.vni, buffs, batch.msgs, offset)
+		n = c.coalesceMessages(batch.ua, geneve, buffs, batch.msgs, offset)
 	} else {
-		vniIsSet := addr.vni.isSet()
-		var gh packet.GeneveHeader
+		vniIsSet := geneve.VNI.IsSet()
 		if vniIsSet {
-			gh.Protocol = packet.GeneveProtocolWireGuard
-			gh.VNI = addr.vni.get()
 			offset -= packet.GeneveFixedHeaderLength
 		}
 		for i := range buffs {
 			if vniIsSet {
-				gh.Encode(buffs[i])
+				geneve.Encode(buffs[i])
 			}
 			batch.msgs[i].Buffers[0] = buffs[i][offset:]
 			batch.msgs[i].Addr = batch.ua
@@ -231,11 +227,7 @@ retry:
 }
 
 func (c *linuxBatchingConn) SyscallConn() (syscall.RawConn, error) {
-	sc, ok := c.pc.(syscall.Conn)
-	if !ok {
-		return nil, errUnsupportedConnType
-	}
-	return sc.SyscallConn()
+	return c.pc.SyscallConn()
 }
 
 func (c *linuxBatchingConn) writeBatch(msgs []ipv6.Message) error {
@@ -391,9 +383,10 @@ func setGSOSizeInControl(control *[]byte, gsoSize uint16) {
 	*control = (*control)[:unix.CmsgSpace(2)]
 }
 
-// tryUpgradeToBatchingConn probes the capabilities of the OS and pconn, and
-// upgrades pconn to a *linuxBatchingConn if appropriate.
-func tryUpgradeToBatchingConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn {
+// TryUpgradeToConn probes the capabilities of the OS and pconn, and upgrades
+// pconn to a [Conn] if appropriate. A batch size of MinReadBatchMsgsLen() is
+// suggested for the best performance.
+func TryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn {
 	if runtime.GOOS != "linux" {
 		// Exclude Android.
 		return pconn
@@ -415,7 +408,7 @@ func tryUpgradeToBatchingConn(pconn nettype.PacketConn, network string, batchSiz
 		return pconn
 	}
 	b := &linuxBatchingConn{
-		pc:                    pconn,
+		pc:                    uc,
 		getGSOSizeFromControl: getGSOSizeFromControl,
 		setGSOSizeInControl:   setGSOSizeInControl,
 		sendBatchPool: sync.Pool{
@@ -449,3 +442,21 @@ func tryUpgradeToBatchingConn(pconn nettype.PacketConn, network string, batchSiz
 	b.txOffload.Store(txOffload)
 	return b
 }
+
+var controlMessageSize = -1 // bomb if used for allocation before init
+
+func init() {
+	// controlMessageSize is set to hold a UDP_GRO or UDP_SEGMENT control
+	// message. These contain a single uint16 of data.
+	controlMessageSize = unix.CmsgSpace(2)
+}
+
+// MinControlMessageSize returns the minimum control message size required to
+// support read batching via [Conn.ReadBatch].
+func MinControlMessageSize() int {
+	return controlMessageSize
+}
+
+func MinReadBatchMsgsLen() int {
+	return 128
+}

+ 22 - 10
wgengine/magicsock/batching_conn_linux_test.go → net/batching/conn_linux_test.go

@@ -1,13 +1,14 @@
 // Copyright (c) Tailscale Inc & AUTHORS
 // SPDX-License-Identifier: BSD-3-Clause
 
-package magicsock
+package batching
 
 import (
 	"encoding/binary"
 	"net"
 	"testing"
 
+	"github.com/tailscale/wireguard-go/conn"
 	"golang.org/x/net/ipv6"
 	"tailscale.com/net/packet"
 )
@@ -159,13 +160,15 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) {
 		return make([]byte, len+packet.GeneveFixedHeaderLength, cap+packet.GeneveFixedHeaderLength)
 	}
 
-	vni1 := virtualNetworkID{}
-	vni1.set(1)
+	geneve := packet.GeneveHeader{
+		Protocol: packet.GeneveProtocolWireGuard,
+	}
+	geneve.VNI.Set(1)
 
 	cases := []struct {
 		name     string
 		buffs    [][]byte
-		vni      virtualNetworkID
+		geneve   packet.GeneveHeader
 		wantLens []int
 		wantGSO  []int
 	}{
@@ -182,7 +185,7 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) {
 			buffs: [][]byte{
 				withGeneveSpace(1, 1),
 			},
-			vni:      vni1,
+			geneve:   geneve,
 			wantLens: []int{1 + packet.GeneveFixedHeaderLength},
 			wantGSO:  []int{0},
 		},
@@ -201,7 +204,7 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) {
 				withGeneveSpace(1, 2+packet.GeneveFixedHeaderLength),
 				withGeneveSpace(1, 1),
 			},
-			vni:      vni1,
+			geneve:   geneve,
 			wantLens: []int{2 + (2 * packet.GeneveFixedHeaderLength)},
 			wantGSO:  []int{1 + packet.GeneveFixedHeaderLength},
 		},
@@ -220,7 +223,7 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) {
 				withGeneveSpace(2, 3+packet.GeneveFixedHeaderLength),
 				withGeneveSpace(1, 1),
 			},
-			vni:      vni1,
+			geneve:   geneve,
 			wantLens: []int{3 + (2 * packet.GeneveFixedHeaderLength)},
 			wantGSO:  []int{2 + packet.GeneveFixedHeaderLength},
 		},
@@ -241,7 +244,7 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) {
 				withGeneveSpace(1, 1),
 				withGeneveSpace(2, 2),
 			},
-			vni:      vni1,
+			geneve:   geneve,
 			wantLens: []int{3 + (2 * packet.GeneveFixedHeaderLength), 2 + packet.GeneveFixedHeaderLength},
 			wantGSO:  []int{2 + packet.GeneveFixedHeaderLength, 0},
 		},
@@ -262,7 +265,7 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) {
 				withGeneveSpace(2, 2),
 				withGeneveSpace(2, 2),
 			},
-			vni:      vni1,
+			geneve:   geneve,
 			wantLens: []int{4 + (2 * packet.GeneveFixedHeaderLength), 2 + packet.GeneveFixedHeaderLength},
 			wantGSO:  []int{2 + packet.GeneveFixedHeaderLength, 0},
 		},
@@ -279,7 +282,7 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) {
 				msgs[i].Buffers = make([][]byte, 1)
 				msgs[i].OOB = make([]byte, 0, 2)
 			}
-			got := c.coalesceMessages(addr, tt.vni, tt.buffs, msgs, packet.GeneveFixedHeaderLength)
+			got := c.coalesceMessages(addr, tt.geneve, tt.buffs, msgs, packet.GeneveFixedHeaderLength)
 			if got != len(tt.wantLens) {
 				t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
 			}
@@ -302,3 +305,12 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) {
 		})
 	}
 }
+
+func TestMinReadBatchMsgsLen(t *testing.T) {
+	// So long as magicsock uses [Conn], and [wireguard-go/conn.Bind] API is
+	// shaped for wireguard-go to control packet memory, these values should be
+	// aligned.
+	if MinReadBatchMsgsLen() != conn.IdealBatchSize {
+		t.Fatalf("MinReadBatchMsgsLen():%d != conn.IdealBatchSize(): %d", MinReadBatchMsgsLen(), conn.IdealBatchSize)
+	}
+}

+ 40 - 10
net/packet/geneve.go

@@ -24,6 +24,33 @@ const (
 	GeneveProtocolWireGuard uint16 = 0x7A12
 )
 
+// VirtualNetworkID is a Geneve header (RFC8926) 3-byte virtual network
+// identifier. Its methods are NOT thread-safe.
+type VirtualNetworkID struct {
+	_vni uint32
+}
+
+const (
+	vniSetMask uint32 = 0xFF000000
+	vniGetMask uint32 = ^vniSetMask
+)
+
+// IsSet returns true if Set() had been called previously, otherwise false.
+func (v *VirtualNetworkID) IsSet() bool {
+	return v._vni&vniSetMask != 0
+}
+
+// Set sets the provided VNI. If VNI exceeds the 3-byte storage it will be
+// clamped.
+func (v *VirtualNetworkID) Set(vni uint32) {
+	v._vni = vni | vniSetMask
+}
+
+// Get returns the VNI value.
+func (v *VirtualNetworkID) Get() uint32 {
+	return v._vni & vniGetMask
+}
+
 // GeneveHeader represents the fixed size Geneve header from RFC8926.
 // TLVs/options are not implemented/supported.
 //
@@ -51,7 +78,7 @@ type GeneveHeader struct {
 	// decisions or MAY be used as a mechanism to distinguish between
 	// overlapping address spaces contained in the encapsulated packet when load
 	// balancing across CPUs.
-	VNI uint32
+	VNI VirtualNetworkID
 
 	// O (1 bit): Control packet. This packet contains a control message.
 	// Control messages are sent between tunnel endpoints. Tunnel endpoints MUST
@@ -65,12 +92,18 @@ type GeneveHeader struct {
 	Control bool
 }
 
-// Encode encodes GeneveHeader into b. If len(b) < GeneveFixedHeaderLength an
-// io.ErrShortBuffer error is returned.
+var ErrGeneveVNIUnset = errors.New("VNI is unset")
+
+// Encode encodes GeneveHeader into b. If len(b) < [GeneveFixedHeaderLength] an
+// [io.ErrShortBuffer] error is returned. If !h.VNI.IsSet() then an
+// [ErrGeneveVNIUnset] error is returned.
 func (h *GeneveHeader) Encode(b []byte) error {
 	if len(b) < GeneveFixedHeaderLength {
 		return io.ErrShortBuffer
 	}
+	if !h.VNI.IsSet() {
+		return ErrGeneveVNIUnset
+	}
 	if h.Version > 3 {
 		return errors.New("version must be <= 3")
 	}
@@ -81,15 +114,12 @@ func (h *GeneveHeader) Encode(b []byte) error {
 		b[1] |= 0x80
 	}
 	binary.BigEndian.PutUint16(b[2:], h.Protocol)
-	if h.VNI > 1<<24-1 {
-		return errors.New("VNI must be <= 2^24-1")
-	}
-	binary.BigEndian.PutUint32(b[4:], h.VNI<<8)
+	binary.BigEndian.PutUint32(b[4:], h.VNI.Get()<<8)
 	return nil
 }
 
-// Decode decodes GeneveHeader from b. If len(b) < GeneveFixedHeaderLength an
-// io.ErrShortBuffer error is returned.
+// Decode decodes GeneveHeader from b. If len(b) < [GeneveFixedHeaderLength] an
+// [io.ErrShortBuffer] error is returned.
 func (h *GeneveHeader) Decode(b []byte) error {
 	if len(b) < GeneveFixedHeaderLength {
 		return io.ErrShortBuffer
@@ -99,6 +129,6 @@ func (h *GeneveHeader) Decode(b []byte) error {
 		h.Control = true
 	}
 	h.Protocol = binary.BigEndian.Uint16(b[2:])
-	h.VNI = binary.BigEndian.Uint32(b[4:]) >> 8
+	h.VNI.Set(binary.BigEndian.Uint32(b[4:]) >> 8)
 	return nil
 }

+ 54 - 2
net/packet/geneve_test.go

@@ -4,18 +4,21 @@
 package packet
 
 import (
+	"math"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
+	"tailscale.com/types/ptr"
 )
 
 func TestGeneveHeader(t *testing.T) {
 	in := GeneveHeader{
 		Version:  3,
 		Protocol: GeneveProtocolDisco,
-		VNI:      1<<24 - 1,
 		Control:  true,
 	}
+	in.VNI.Set(1<<24 - 1)
 	b := make([]byte, GeneveFixedHeaderLength)
 	err := in.Encode(b)
 	if err != nil {
@@ -26,7 +29,56 @@ func TestGeneveHeader(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	if diff := cmp.Diff(out, in); diff != "" {
+	if diff := cmp.Diff(out, in, cmpopts.EquateComparable(VirtualNetworkID{})); diff != "" {
 		t.Fatalf("wrong results (-got +want)\n%s", diff)
 	}
 }
+
+func TestVirtualNetworkID(t *testing.T) {
+	tests := []struct {
+		name string
+		set  *uint32
+		want uint32
+	}{
+		{
+			"don't Set",
+			nil,
+			0,
+		},
+		{
+			"Set 0",
+			ptr.To(uint32(0)),
+			0,
+		},
+		{
+			"Set 1",
+			ptr.To(uint32(1)),
+			1,
+		},
+		{
+			"Set math.MaxUint32",
+			ptr.To(uint32(math.MaxUint32)),
+			1<<24 - 1,
+		},
+		{
+			"Set max 3-byte value",
+			ptr.To(uint32(1<<24 - 1)),
+			1<<24 - 1,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			v := VirtualNetworkID{}
+			if tt.set != nil {
+				v.Set(*tt.set)
+			}
+			if v.IsSet() != (tt.set != nil) {
+				t.Fatalf("IsSet: %v != wantIsSet: %v", v.IsSet(), tt.set != nil)
+			}
+			if v.Get() != tt.want {
+				t.Fatalf("Get(): %v != want: %v", v.Get(), tt.want)
+			}
+		})
+	}
+}

+ 3 - 2
net/udprelay/server.go

@@ -140,7 +140,8 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
 		rand.Read(e.challenge[senderIndex][:])
 		copy(m.Challenge[:], e.challenge[senderIndex][:])
 		reply := make([]byte, packet.GeneveFixedHeaderLength, 512)
-		gh := packet.GeneveHeader{Control: true, VNI: e.vni, Protocol: packet.GeneveProtocolDisco}
+		gh := packet.GeneveHeader{Control: true, Protocol: packet.GeneveProtocolDisco}
+		gh.VNI.Set(e.vni)
 		err = gh.Encode(reply)
 		if err != nil {
 			return
@@ -543,7 +544,7 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte, rxSocket, otherAFSo
 	// it simple (and slow) for now.
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	e, ok := s.byVNI[gh.VNI]
+	e, ok := s.byVNI[gh.VNI.Get()]
 	if !ok {
 		// unknown VNI
 		return

+ 6 - 4
net/udprelay/server_test.go

@@ -62,7 +62,8 @@ func (c *testClient) read(t *testing.T) []byte {
 
 func (c *testClient) writeDataPkt(t *testing.T, b []byte) {
 	pkt := make([]byte, packet.GeneveFixedHeaderLength, packet.GeneveFixedHeaderLength+len(b))
-	gh := packet.GeneveHeader{Control: false, VNI: c.vni, Protocol: packet.GeneveProtocolWireGuard}
+	gh := packet.GeneveHeader{Control: false, Protocol: packet.GeneveProtocolWireGuard}
+	gh.VNI.Set(c.vni)
 	err := gh.Encode(pkt)
 	if err != nil {
 		t.Fatal(err)
@@ -84,7 +85,7 @@ func (c *testClient) readDataPkt(t *testing.T) []byte {
 	if gh.Control {
 		t.Fatal("unexpected control")
 	}
-	if gh.VNI != c.vni {
+	if gh.VNI.Get() != c.vni {
 		t.Fatal("unexpected vni")
 	}
 	return b[packet.GeneveFixedHeaderLength:]
@@ -92,7 +93,8 @@ func (c *testClient) readDataPkt(t *testing.T) []byte {
 
 func (c *testClient) writeControlDiscoMsg(t *testing.T, msg disco.Message) {
 	pkt := make([]byte, packet.GeneveFixedHeaderLength, 512)
-	gh := packet.GeneveHeader{Control: true, VNI: c.vni, Protocol: packet.GeneveProtocolDisco}
+	gh := packet.GeneveHeader{Control: true, Protocol: packet.GeneveProtocolDisco}
+	gh.VNI.Set(c.vni)
 	err := gh.Encode(pkt)
 	if err != nil {
 		t.Fatal(err)
@@ -117,7 +119,7 @@ func (c *testClient) readControlDiscoMsg(t *testing.T) disco.Message {
 	if !gh.Control {
 		t.Fatal("unexpected non-control")
 	}
-	if gh.VNI != c.vni {
+	if gh.VNI.Get() != c.vni {
 		t.Fatal("unexpected vni")
 	}
 	b = b[packet.GeneveFixedHeaderLength:]

+ 1 - 0
tsnet/depaware.txt

@@ -264,6 +264,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware)
         tailscale.com/logtail/filch                                  from tailscale.com/log/sockstatlog+
         tailscale.com/metrics                                        from tailscale.com/derp+
         tailscale.com/net/bakedroots                                 from tailscale.com/ipn/ipnlocal+
+     💣 tailscale.com/net/batching                                   from tailscale.com/wgengine/magicsock
         tailscale.com/net/captivedetection                           from tailscale.com/ipn/ipnlocal+
         tailscale.com/net/connstats                                  from tailscale.com/net/tstun+
         tailscale.com/net/dns                                        from tailscale.com/ipn/ipnlocal+

+ 0 - 23
wgengine/magicsock/batching_conn.go

@@ -1,23 +0,0 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-package magicsock
-
-import (
-	"golang.org/x/net/ipv4"
-	"golang.org/x/net/ipv6"
-	"tailscale.com/types/nettype"
-)
-
-var (
-	// This acts as a compile-time check for our usage of ipv6.Message in
-	// batchingConn for both IPv6 and IPv4 operations.
-	_ ipv6.Message = ipv4.Message{}
-)
-
-// batchingConn is a nettype.PacketConn that provides batched i/o.
-type batchingConn interface {
-	nettype.PacketConn
-	ReadBatch(msgs []ipv6.Message, flags int) (n int, err error)
-	WriteBatchTo(buffs [][]byte, addr epAddr, offset int) error
-}

+ 0 - 14
wgengine/magicsock/batching_conn_default.go

@@ -1,14 +0,0 @@
-// Copyright (c) Tailscale Inc & AUTHORS
-// SPDX-License-Identifier: BSD-3-Clause
-
-//go:build !linux
-
-package magicsock
-
-import (
-	"tailscale.com/types/nettype"
-)
-
-func tryUpgradeToBatchingConn(pconn nettype.PacketConn, _ string, _ int) nettype.PacketConn {
-	return pconn
-}

+ 2 - 2
wgengine/magicsock/debughttp.go

@@ -152,7 +152,7 @@ func printEndpointHTML(w io.Writer, ep *endpoint) {
 	io.WriteString(w, "<p>Endpoints:</p><ul>")
 	for _, ipp := range eps {
 		s := ep.endpointState[ipp]
-		if ipp == ep.bestAddr.ap && !ep.bestAddr.vni.isSet() {
+		if ipp == ep.bestAddr.ap && !ep.bestAddr.vni.IsSet() {
 			fmt.Fprintf(w, "<li><b>%s</b>: (best)<ul>", ipp)
 		} else {
 			fmt.Fprintf(w, "<li>%s: ...<ul>", ipp)
@@ -208,7 +208,7 @@ func epAddrLess(a, b epAddr) bool {
 		return v < 0
 	}
 	if a.ap.Port() == b.ap.Port() {
-		return a.vni.get() < b.vni.get()
+		return a.vni.Get() < b.vni.Get()
 	}
 	return a.ap.Port() < b.ap.Port()
 }

+ 18 - 18
wgengine/magicsock/endpoint.go

@@ -108,7 +108,7 @@ func (de *endpoint) udpRelayEndpointReady(maybeBest addrQuality) {
 	defer de.mu.Unlock()
 	now := mono.Now()
 	curBestAddrTrusted := now.Before(de.trustBestAddrUntil)
-	sameRelayServer := de.bestAddr.vni.isSet() && maybeBest.relayServerDisco.Compare(de.bestAddr.relayServerDisco) == 0
+	sameRelayServer := de.bestAddr.vni.IsSet() && maybeBest.relayServerDisco.Compare(de.bestAddr.relayServerDisco) == 0
 
 	if !curBestAddrTrusted ||
 		sameRelayServer ||
@@ -1070,7 +1070,7 @@ func (de *endpoint) send(buffs [][]byte, offset int) error {
 
 		switch {
 		case udpAddr.ap.Addr().Is4():
-			if udpAddr.vni.isSet() {
+			if udpAddr.vni.IsSet() {
 				de.c.metrics.outboundPacketsPeerRelayIPv4Total.Add(int64(len(buffs)))
 				de.c.metrics.outboundBytesPeerRelayIPv4Total.Add(int64(txBytes))
 			} else {
@@ -1078,7 +1078,7 @@ func (de *endpoint) send(buffs [][]byte, offset int) error {
 				de.c.metrics.outboundBytesIPv4Total.Add(int64(txBytes))
 			}
 		case udpAddr.ap.Addr().Is6():
-			if udpAddr.vni.isSet() {
+			if udpAddr.vni.IsSet() {
 				de.c.metrics.outboundPacketsPeerRelayIPv6Total.Add(int64(len(buffs)))
 				de.c.metrics.outboundBytesPeerRelayIPv6Total.Add(int64(txBytes))
 			} else {
@@ -1160,7 +1160,7 @@ func (de *endpoint) discoPingTimeout(txid stun.TxID) {
 		return
 	}
 	bestUntrusted := mono.Now().After(de.trustBestAddrUntil)
-	if sp.to == de.bestAddr.epAddr && sp.to.vni.isSet() && bestUntrusted {
+	if sp.to == de.bestAddr.epAddr && sp.to.vni.IsSet() && bestUntrusted {
 		// TODO(jwhited): consider applying this to direct UDP paths as well
 		de.clearBestAddrLocked()
 	}
@@ -1274,7 +1274,7 @@ func (de *endpoint) startDiscoPingLocked(ep epAddr, now mono.Time, purpose disco
 		return
 	}
 	if purpose != pingCLI &&
-		!ep.vni.isSet() { // de.endpointState is only relevant for direct/non-vni epAddr's
+		!ep.vni.IsSet() { // de.endpointState is only relevant for direct/non-vni epAddr's
 		st, ok := de.endpointState[ep.ap]
 		if !ok {
 			// Shouldn't happen. But don't ping an endpoint that's
@@ -1610,7 +1610,7 @@ func (de *endpoint) noteBadEndpoint(udpAddr epAddr) {
 
 	de.clearBestAddrLocked()
 
-	if !udpAddr.vni.isSet() {
+	if !udpAddr.vni.IsSet() {
 		if st, ok := de.endpointState[udpAddr.ap]; ok {
 			st.clear()
 		}
@@ -1644,7 +1644,7 @@ func pingSizeToPktLen(size int, udpAddr epAddr) tstun.WireMTU {
 		headerLen = ipv6.HeaderLen
 	}
 	headerLen += 8 // UDP header length
-	if udpAddr.vni.isSet() {
+	if udpAddr.vni.IsSet() {
 		headerLen += packet.GeneveFixedHeaderLength
 	}
 	return tstun.WireMTU(size + headerLen)
@@ -1699,7 +1699,7 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src epAdd
 	now := mono.Now()
 	latency := now.Sub(sp.at)
 
-	if !isDerp && !src.vni.isSet() {
+	if !isDerp && !src.vni.IsSet() {
 		// Note: we check vni.isSet() as relay [epAddr]'s are not stored in
 		// endpointState, they are either de.bestAddr or not.
 		st, ok := de.endpointState[sp.to.ap]
@@ -1748,7 +1748,7 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src epAdd
 		//  we don't clear direct UDP paths on disco ping timeout (see
 		//  discoPingTimeout).
 		if betterAddr(thisPong, de.bestAddr) {
-			if src.vni.isSet() {
+			if src.vni.IsSet() {
 				// This would be unexpected. Switching to a Geneve-encapsulated
 				// path should only happen in de.relayEndpointReady().
 				de.c.logf("[unexpected] switching to Geneve-encapsulated path %v from %v", thisPong, de.bestAddr)
@@ -1778,23 +1778,23 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src epAdd
 }
 
 // epAddr is a [netip.AddrPort] with an optional Geneve header (RFC8926)
-// [virtualNetworkID].
+// [packet.VirtualNetworkID].
 type epAddr struct {
-	ap  netip.AddrPort   // if ap == tailcfg.DerpMagicIPAddr then vni is never set
-	vni virtualNetworkID // vni.isSet() indicates if this [epAddr] involves a Geneve header
+	ap  netip.AddrPort          // if ap == tailcfg.DerpMagicIPAddr then vni is never set
+	vni packet.VirtualNetworkID // vni.IsSet() indicates if this [epAddr] involves a Geneve header
 }
 
 // isDirect returns true if e.ap is valid and not tailcfg.DerpMagicIPAddr,
 // and a VNI is not set.
 func (e epAddr) isDirect() bool {
-	return e.ap.IsValid() && e.ap.Addr() != tailcfg.DerpMagicIPAddr && !e.vni.isSet()
+	return e.ap.IsValid() && e.ap.Addr() != tailcfg.DerpMagicIPAddr && !e.vni.IsSet()
 }
 
 func (e epAddr) String() string {
-	if !e.vni.isSet() {
+	if !e.vni.IsSet() {
 		return e.ap.String()
 	}
-	return fmt.Sprintf("%v:vni:%d", e.ap.String(), e.vni.get())
+	return fmt.Sprintf("%v:vni:%d", e.ap.String(), e.vni.Get())
 }
 
 // addrQuality is an [epAddr], an optional [key.DiscoPublic] if a relay server
@@ -1833,10 +1833,10 @@ func betterAddr(a, b addrQuality) bool {
 
 	// Geneve-encapsulated paths (UDP relay servers) are lower preference in
 	// relation to non.
-	if !a.vni.isSet() && b.vni.isSet() {
+	if !a.vni.IsSet() && b.vni.IsSet() {
 		return true
 	}
-	if a.vni.isSet() && !b.vni.isSet() {
+	if a.vni.IsSet() && !b.vni.IsSet() {
 		return false
 	}
 
@@ -1982,7 +1982,7 @@ func (de *endpoint) populatePeerStatus(ps *ipnstate.PeerStatus) {
 	ps.Active = now.Sub(de.lastSendExt) < sessionActiveTimeout
 
 	if udpAddr, derpAddr, _ := de.addrForSendLocked(now); udpAddr.ap.IsValid() && !derpAddr.IsValid() {
-		if udpAddr.vni.isSet() {
+		if udpAddr.vni.IsSet() {
 			ps.PeerRelay = udpAddr.String()
 		} else {
 			ps.CurAddr = udpAddr.String()

+ 7 - 6
wgengine/magicsock/endpoint_test.go

@@ -8,6 +8,7 @@ import (
 	"testing"
 	"time"
 
+	"tailscale.com/net/packet"
 	"tailscale.com/tailcfg"
 	"tailscale.com/tstime/mono"
 	"tailscale.com/types/key"
@@ -327,24 +328,24 @@ func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) {
 }
 
 func Test_epAddr_isDirectUDP(t *testing.T) {
-	vni := virtualNetworkID{}
-	vni.set(7)
+	vni := packet.VirtualNetworkID{}
+	vni.Set(7)
 	tests := []struct {
 		name string
 		ap   netip.AddrPort
-		vni  virtualNetworkID
+		vni  packet.VirtualNetworkID
 		want bool
 	}{
 		{
 			name: "true",
 			ap:   netip.MustParseAddrPort("192.0.2.1:7"),
-			vni:  virtualNetworkID{},
+			vni:  packet.VirtualNetworkID{},
 			want: true,
 		},
 		{
 			name: "false derp magic addr",
 			ap:   netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, 0),
-			vni:  virtualNetworkID{},
+			vni:  packet.VirtualNetworkID{},
 			want: false,
 		},
 		{
@@ -370,7 +371,7 @@ func Test_epAddr_isDirectUDP(t *testing.T) {
 func Test_endpoint_udpRelayEndpointReady(t *testing.T) {
 	directAddrQuality := addrQuality{epAddr: epAddr{ap: netip.MustParseAddrPort("192.0.2.1:7")}}
 	peerRelayAddrQuality := addrQuality{epAddr: epAddr{ap: netip.MustParseAddrPort("192.0.2.2:77")}, latency: time.Second}
-	peerRelayAddrQuality.vni.set(1)
+	peerRelayAddrQuality.vni.Set(1)
 	peerRelayAddrQualityHigherLatencySameServer := addrQuality{
 		epAddr:  epAddr{ap: netip.MustParseAddrPort("192.0.2.3:77"), vni: peerRelayAddrQuality.vni},
 		latency: peerRelayAddrQuality.latency * 10,

+ 15 - 41
wgengine/magicsock/magicsock.go

@@ -36,6 +36,7 @@ import (
 	"tailscale.com/health"
 	"tailscale.com/hostinfo"
 	"tailscale.com/ipn/ipnstate"
+	"tailscale.com/net/batching"
 	"tailscale.com/net/connstats"
 	"tailscale.com/net/netcheck"
 	"tailscale.com/net/neterror"
@@ -626,7 +627,7 @@ func newConn(logf logger.Logf) *Conn {
 		msgs := make([]ipv6.Message, c.bind.BatchSize())
 		for i := range msgs {
 			msgs[i].Buffers = make([][]byte, 1)
-			msgs[i].OOB = make([]byte, controlMessageSize)
+			msgs[i].OOB = make([]byte, batching.MinControlMessageSize())
 		}
 		batch := &receiveBatch{
 			msgs: msgs,
@@ -1206,7 +1207,7 @@ func (c *Conn) Ping(peer tailcfg.NodeView, res *ipnstate.PingResult, size int, c
 func (c *Conn) populateCLIPingResponseLocked(res *ipnstate.PingResult, latency time.Duration, ep epAddr) {
 	res.LatencySeconds = latency.Seconds()
 	if ep.ap.Addr() != tailcfg.DerpMagicIPAddr {
-		if ep.vni.isSet() {
+		if ep.vni.IsSet() {
 			res.PeerRelay = ep.String()
 		} else {
 			res.Endpoint = ep.String()
@@ -1473,9 +1474,9 @@ func (c *Conn) Send(buffs [][]byte, ep conn.Endpoint, offset int) (err error) {
 		// deemed "under handshake load" and ends up transmitting a cookie reply
 		// using the received [conn.Endpoint] in [device.SendHandshakeCookie].
 		if ep.src.ap.Addr().Is6() {
-			return c.pconn6.WriteBatchTo(buffs, ep.src, offset)
+			return c.pconn6.WriteWireGuardBatchTo(buffs, ep.src, offset)
 		}
-		return c.pconn4.WriteBatchTo(buffs, ep.src, offset)
+		return c.pconn4.WriteWireGuardBatchTo(buffs, ep.src, offset)
 	}
 	return nil
 }
@@ -1498,9 +1499,9 @@ func (c *Conn) sendUDPBatch(addr epAddr, buffs [][]byte, offset int) (sent bool,
 		panic("bogus sendUDPBatch addr type")
 	}
 	if isIPv6 {
-		err = c.pconn6.WriteBatchTo(buffs, addr, offset)
+		err = c.pconn6.WriteWireGuardBatchTo(buffs, addr, offset)
 	} else {
-		err = c.pconn4.WriteBatchTo(buffs, addr, offset)
+		err = c.pconn4.WriteWireGuardBatchTo(buffs, addr, offset)
 	}
 	if err != nil {
 		var errGSO neterror.ErrUDPGSODisabled
@@ -1793,7 +1794,7 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *epAddrEndpointCach
 			c.logf("[unexpected] geneve header decoding error: %v", err)
 			return nil, 0, false, false
 		}
-		src.vni.set(geneve.VNI)
+		src.vni = geneve.VNI
 	}
 	switch pt {
 	case packetLooksLikeDisco:
@@ -1825,7 +1826,7 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *epAddrEndpointCach
 	// geneveInclusivePacketLen holds the packet length prior to any potential
 	// Geneve header stripping.
 	geneveInclusivePacketLen := len(b)
-	if src.vni.isSet() {
+	if src.vni.IsSet() {
 		// Strip away the Geneve header before returning the packet to
 		// wireguard-go.
 		//
@@ -1858,7 +1859,7 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *epAddrEndpointCach
 	if stats := c.stats.Load(); stats != nil {
 		stats.UpdateRxPhysical(ep.nodeAddr, ipp, 1, geneveInclusivePacketLen)
 	}
-	if src.vni.isSet() && (connNoted || looksLikeInitiationMsg(b)) {
+	if src.vni.IsSet() && (connNoted || looksLikeInitiationMsg(b)) {
 		// connNoted is periodic, but we also want to verify if the peer is who
 		// we believe for all initiation messages, otherwise we could get
 		// unlucky and fail to JIT configure the "correct" peer.
@@ -1887,33 +1888,6 @@ const (
 // speeds.
 var debugIPv4DiscoPingPenalty = envknob.RegisterDuration("TS_DISCO_PONG_IPV4_DELAY")
 
-// virtualNetworkID is a Geneve header (RFC8926) 3-byte virtual network
-// identifier. Its field must only ever be accessed via its methods.
-type virtualNetworkID struct {
-	_vni uint32
-}
-
-const (
-	vniSetMask uint32 = 0xFF000000
-	vniGetMask uint32 = ^vniSetMask
-)
-
-// isSet returns true if set() had been called previously, otherwise false.
-func (v *virtualNetworkID) isSet() bool {
-	return v._vni&vniSetMask != 0
-}
-
-// set sets the provided VNI. If VNI exceeds the 3-byte storage it will be
-// clamped.
-func (v *virtualNetworkID) set(vni uint32) {
-	v._vni = vni | vniSetMask
-}
-
-// get returns the VNI value.
-func (v *virtualNetworkID) get() uint32 {
-	return v._vni & vniGetMask
-}
-
 // sendDiscoAllocateUDPRelayEndpointRequest is primarily an alias for
 // sendDiscoMessage, but it will alternatively send m over the eventbus if dst
 // is a DERP IP:port, and dstKey is self. This saves a round-trip through DERP
@@ -1981,11 +1955,11 @@ func (c *Conn) sendDiscoMessage(dst epAddr, dstKey key.NodePublic, dstDisco key.
 	c.mu.Unlock()
 
 	pkt := make([]byte, 0, 512) // TODO: size it correctly? pool? if it matters.
-	if dst.vni.isSet() {
+	if dst.vni.IsSet() {
 		gh := packet.GeneveHeader{
 			Version:  0,
 			Protocol: packet.GeneveProtocolDisco,
-			VNI:      dst.vni.get(),
+			VNI:      dst.vni,
 			Control:  isRelayHandshakeMsg,
 		}
 		pkt = append(pkt, make([]byte, packet.GeneveFixedHeaderLength)...)
@@ -2006,7 +1980,7 @@ func (c *Conn) sendDiscoMessage(dst epAddr, dstKey key.NodePublic, dstDisco key.
 	box := di.sharedKey.Seal(m.AppendMarshal(nil))
 	pkt = append(pkt, box...)
 	const isDisco = true
-	sent, err = c.sendAddr(dst.ap, dstKey, pkt, isDisco, dst.vni.isSet())
+	sent, err = c.sendAddr(dst.ap, dstKey, pkt, isDisco, dst.vni.IsSet())
 	if sent {
 		if logLevel == discoLog || (logLevel == discoVerboseLog && debugDisco()) {
 			node := "?"
@@ -2294,7 +2268,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake
 			}
 			return true
 		})
-		if !knownTxID && src.vni.isSet() {
+		if !knownTxID && src.vni.IsSet() {
 			// If it's an unknown TxID, and it's Geneve-encapsulated, then
 			// make [relayManager] aware. It might be in the middle of probing
 			// src.
@@ -2512,7 +2486,7 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src epAddr, di *discoInfo, derpN
 	di.lastPingTime = time.Now()
 	isDerp := src.ap.Addr() == tailcfg.DerpMagicIPAddr
 
-	if src.vni.isSet() {
+	if src.vni.IsSet() {
 		if isDerp {
 			c.logf("[unexpected] got Geneve-encapsulated disco ping from %v/%v over DERP", src, derpNodeSrc)
 			return

+ 0 - 4
wgengine/magicsock/magicsock_default.go

@@ -21,7 +21,3 @@ func (c *Conn) listenRawDisco(family string) (io.Closer, error) {
 func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) {
 	portableTrySetSocketBuffer(pconn, logf)
 }
-
-const (
-	controlMessageSize = 0
-)

+ 0 - 8
wgengine/magicsock/magicsock_linux.go

@@ -516,11 +516,3 @@ func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) {
 		}
 	}
 }
-
-var controlMessageSize = -1 // bomb if used for allocation before init
-
-func init() {
-	// controlMessageSize is set to hold a UDP_GRO or UDP_SEGMENT control
-	// message. These contain a single uint16 of data.
-	controlMessageSize = unix.CmsgSpace(2)
-}

+ 10 - 61
wgengine/magicsock/magicsock_test.go

@@ -13,7 +13,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"math"
 	"math/rand"
 	"net"
 	"net/http"
@@ -1787,7 +1786,7 @@ func TestBetterAddr(t *testing.T) {
 	}
 	avl := func(ipps string, vni uint32, d time.Duration) addrQuality {
 		q := al(ipps, d)
-		q.vni.set(vni)
+		q.vni.Set(vni)
 		return q
 	}
 	zero := addrQuality{}
@@ -3178,9 +3177,9 @@ func Test_packetLooksLike(t *testing.T) {
 	gh := packet.GeneveHeader{
 		Version:  0,
 		Protocol: packet.GeneveProtocolDisco,
-		VNI:      1,
 		Control:  true,
 	}
+	gh.VNI.Set(1)
 	err := gh.Encode(geneveEncapDisco)
 	if err != nil {
 		t.Fatal(err)
@@ -3200,9 +3199,9 @@ func Test_packetLooksLike(t *testing.T) {
 	gh = packet.GeneveHeader{
 		Version:  0,
 		Protocol: packet.GeneveProtocolWireGuard,
-		VNI:      1,
 		Control:  true,
 	}
+	gh.VNI.Set(1)
 	err = gh.Encode(geneveEncapWireGuard)
 	if err != nil {
 		t.Fatal(err)
@@ -3213,9 +3212,9 @@ func Test_packetLooksLike(t *testing.T) {
 	gh = packet.GeneveHeader{
 		Version:  1,
 		Protocol: packet.GeneveProtocolDisco,
-		VNI:      1,
 		Control:  true,
 	}
+	gh.VNI.Set(1)
 	err = gh.Encode(geneveEncapDiscoNonZeroGeneveVersion)
 	if err != nil {
 		t.Fatal(err)
@@ -3226,9 +3225,9 @@ func Test_packetLooksLike(t *testing.T) {
 	gh = packet.GeneveHeader{
 		Version:  0,
 		Protocol: packet.GeneveProtocolDisco,
-		VNI:      1,
 		Control:  true,
 	}
+	gh.VNI.Set(1)
 	err = gh.Encode(geneveEncapDiscoNonZeroGeneveReservedBits)
 	if err != nil {
 		t.Fatal(err)
@@ -3240,9 +3239,9 @@ func Test_packetLooksLike(t *testing.T) {
 	gh = packet.GeneveHeader{
 		Version:  0,
 		Protocol: packet.GeneveProtocolDisco,
-		VNI:      1,
 		Control:  true,
 	}
+	gh.VNI.Set(1)
 	err = gh.Encode(geneveEncapDiscoNonZeroGeneveVNILSB)
 	if err != nil {
 		t.Fatal(err)
@@ -3342,55 +3341,6 @@ func Test_packetLooksLike(t *testing.T) {
 	}
 }
 
-func Test_virtualNetworkID(t *testing.T) {
-	tests := []struct {
-		name string
-		set  *uint32
-		want uint32
-	}{
-		{
-			"don't set",
-			nil,
-			0,
-		},
-		{
-			"set 0",
-			ptr.To(uint32(0)),
-			0,
-		},
-		{
-			"set 1",
-			ptr.To(uint32(1)),
-			1,
-		},
-		{
-			"set math.MaxUint32",
-			ptr.To(uint32(math.MaxUint32)),
-			1<<24 - 1,
-		},
-		{
-			"set max 3-byte value",
-			ptr.To(uint32(1<<24 - 1)),
-			1<<24 - 1,
-		},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			v := virtualNetworkID{}
-			if tt.set != nil {
-				v.set(*tt.set)
-			}
-			if v.isSet() != (tt.set != nil) {
-				t.Fatalf("isSet: %v != wantIsSet: %v", v.isSet(), tt.set != nil)
-			}
-			if v.get() != tt.want {
-				t.Fatalf("get(): %v != want: %v", v.get(), tt.want)
-			}
-		})
-	}
-}
-
 func Test_looksLikeInitiationMsg(t *testing.T) {
 	// initMsg was captured as the first packet from a WireGuard "session"
 	initMsg, err := hex.DecodeString("01000000d9205f67915a500e377b409e0c3d97ca91e68654b95952de965e75df491000cce00632678cd9e8c8525556aa8daf24e6cfc44c48812bb560ff3c1c5dee061b3f833dfaa48acf13b64bd1e0027aa4d977a3721b82fd6072338702fc3193651404980ad46dae2869ba6416cc0eb38621a4140b5b918eb6402b697202adb3002a6d00000000000000000000000000000000")
@@ -3772,6 +3722,7 @@ func TestConn_receiveIP(t *testing.T) {
 	gh := packet.GeneveHeader{
 		Protocol: packet.GeneveProtocolDisco,
 	}
+	gh.VNI.Set(1)
 	err := gh.Encode(looksLikeGeneveDisco)
 	if err != nil {
 		t.Fatal(err)
@@ -3796,10 +3747,8 @@ func TestConn_receiveIP(t *testing.T) {
 	looksLikeGeneveWireGuardInit := make([]byte, packet.GeneveFixedHeaderLength+device.MessageInitiationSize)
 	gh = packet.GeneveHeader{
 		Protocol: packet.GeneveProtocolWireGuard,
-		VNI:      1,
 	}
-	vni := virtualNetworkID{}
-	vni.set(gh.VNI)
+	gh.VNI.Set(1)
 	err = gh.Encode(looksLikeGeneveWireGuardInit)
 	if err != nil {
 		t.Fatal(err)
@@ -3922,7 +3871,7 @@ func TestConn_receiveIP(t *testing.T) {
 			ipp:                             netip.MustParseAddrPort("127.0.0.1:7777"),
 			cache:                           &epAddrEndpointCache{},
 			insertWantEndpointTypeInPeerMap: true,
-			peerMapEpAddr:                   epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777"), vni: vni},
+			peerMapEpAddr:                   epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777"), vni: gh.VNI},
 			wantEndpointType: &lazyEndpoint{
 				maybeEP: newPeerMapInsertableEndpoint(0),
 			},
@@ -3938,7 +3887,7 @@ func TestConn_receiveIP(t *testing.T) {
 			ipp:                             netip.MustParseAddrPort("127.0.0.1:7777"),
 			cache:                           &epAddrEndpointCache{},
 			insertWantEndpointTypeInPeerMap: true,
-			peerMapEpAddr:                   epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777"), vni: vni},
+			peerMapEpAddr:                   epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777"), vni: gh.VNI},
 			wantEndpointType: &lazyEndpoint{
 				maybeEP: newPeerMapInsertableEndpoint(mono.Now().Add(time.Hour * 24)),
 			},

+ 2 - 2
wgengine/magicsock/peermap.go

@@ -184,12 +184,12 @@ func (m *peerMap) setNodeKeyForEpAddr(addr epAddr, nk key.NodePublic) {
 	if pi := m.byEpAddr[addr]; pi != nil {
 		delete(pi.epAddrs, addr)
 		delete(m.byEpAddr, addr)
-		if addr.vni.isSet() {
+		if addr.vni.IsSet() {
 			delete(m.relayEpAddrByNodeKey, pi.ep.publicKey)
 		}
 	}
 	if pi, ok := m.byNodeKey[nk]; ok {
-		if addr.vni.isSet() {
+		if addr.vni.IsSet() {
 			relay, ok := m.relayEpAddrByNodeKey[nk]
 			if ok {
 				delete(pi.epAddrs, relay)

+ 3 - 2
wgengine/magicsock/peermap_test.go

@@ -7,6 +7,7 @@ import (
 	"net/netip"
 	"testing"
 
+	"tailscale.com/net/packet"
 	"tailscale.com/types/key"
 )
 
@@ -20,8 +21,8 @@ func Test_peerMap_oneRelayEpAddrPerNK(t *testing.T) {
 	ed := &endpointDisco{key: key.NewDisco().Public()}
 	ep.disco.Store(ed)
 	pm.upsertEndpoint(ep, key.DiscoPublic{})
-	vni := virtualNetworkID{}
-	vni.set(1)
+	vni := packet.VirtualNetworkID{}
+	vni.Set(1)
 	relayEpAddrA := epAddr{ap: netip.MustParseAddrPort("127.0.0.1:1"), vni: vni}
 	relayEpAddrB := epAddr{ap: netip.MustParseAddrPort("127.0.0.1:2"), vni: vni}
 	pm.setNodeKeyForEpAddr(relayEpAddrA, nk)

+ 21 - 19
wgengine/magicsock/rebinding_conn.go

@@ -13,6 +13,7 @@ import (
 	"syscall"
 
 	"golang.org/x/net/ipv6"
+	"tailscale.com/net/batching"
 	"tailscale.com/net/netaddr"
 	"tailscale.com/net/packet"
 	"tailscale.com/types/nettype"
@@ -42,7 +43,7 @@ type RebindingUDPConn struct {
 // disrupting surrounding code that assumes nettype.PacketConn is a
 // *net.UDPConn.
 func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string, batchSize int) {
-	upc := tryUpgradeToBatchingConn(p, network, batchSize)
+	upc := batching.TryUpgradeToConn(p, network, batchSize)
 	c.pconn = upc
 	c.pconnAtomic.Store(&upc)
 	c.port = uint16(c.localAddrLocked().Port)
@@ -72,25 +73,27 @@ func (c *RebindingUDPConn) ReadFromUDPAddrPort(b []byte) (int, netip.AddrPort, e
 	return c.readFromWithInitPconn(*c.pconnAtomic.Load(), b)
 }
 
-// WriteBatchTo writes buffs to addr.
-func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr epAddr, offset int) error {
+// WriteWireGuardBatchTo writes buffs to addr. It serves primarily as an alias
+// for [batching.Conn.WriteBatchTo], with fallback to single packet operations
+// if c.pconn is not a [batching.Conn].
+//
+// WriteWireGuardBatchTo assumes buffs are WireGuard packets, which is notable
+// for Geneve encapsulation: Geneve protocol is set to [packet.GeneveProtocolWireGuard],
+// and the control bit is left unset.
+func (c *RebindingUDPConn) WriteWireGuardBatchTo(buffs [][]byte, addr epAddr, offset int) error {
 	if offset != packet.GeneveFixedHeaderLength {
-		return fmt.Errorf("RebindingUDPConn.WriteBatchTo: [unexpected] offset (%d) != Geneve header length (%d)", offset, packet.GeneveFixedHeaderLength)
+		return fmt.Errorf("RebindingUDPConn.WriteWireGuardBatchTo: [unexpected] offset (%d) != Geneve header length (%d)", offset, packet.GeneveFixedHeaderLength)
+	}
+	gh := packet.GeneveHeader{
+		Protocol: packet.GeneveProtocolWireGuard,
+		VNI:      addr.vni,
 	}
 	for {
 		pconn := *c.pconnAtomic.Load()
-		b, ok := pconn.(batchingConn)
+		b, ok := pconn.(batching.Conn)
 		if !ok {
-			vniIsSet := addr.vni.isSet()
-			var gh packet.GeneveHeader
-			if vniIsSet {
-				gh = packet.GeneveHeader{
-					Protocol: packet.GeneveProtocolWireGuard,
-					VNI:      addr.vni.get(),
-				}
-			}
 			for _, buf := range buffs {
-				if vniIsSet {
+				if gh.VNI.IsSet() {
 					gh.Encode(buf)
 				} else {
 					buf = buf[offset:]
@@ -102,7 +105,7 @@ func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr epAddr, offset int)
 			}
 			return nil
 		}
-		err := b.WriteBatchTo(buffs, addr, offset)
+		err := b.WriteBatchTo(buffs, addr.ap, gh, offset)
 		if err != nil {
 			if pconn != c.currentConn() {
 				continue
@@ -113,13 +116,12 @@ func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr epAddr, offset int)
 	}
 }
 
-// ReadBatch reads messages from c into msgs. It returns the number of messages
-// the caller should evaluate for nonzero len, as a zero len message may fall
-// on either side of a nonzero.
+// ReadBatch is an alias for [batching.Conn.ReadBatch] with fallback to single
+// packet operations if c.pconn is not a [batching.Conn].
 func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error) {
 	for {
 		pconn := *c.pconnAtomic.Load()
-		b, ok := pconn.(batchingConn)
+		b, ok := pconn.(batching.Conn)
 		if !ok {
 			n, ap, err := c.readFromWithInitPconn(pconn, msgs[0].Buffers[0])
 			if err == nil {

+ 8 - 7
wgengine/magicsock/relaymanager.go

@@ -11,6 +11,7 @@ import (
 	"time"
 
 	"tailscale.com/disco"
+	"tailscale.com/net/packet"
 	"tailscale.com/net/stun"
 	udprelay "tailscale.com/net/udprelay/endpoint"
 	"tailscale.com/tailcfg"
@@ -384,7 +385,7 @@ func (r *relayManager) handleRxDiscoMsg(conn *Conn, dm disco.Message, relayServe
 		relayServerNodeKey: relayServerNodeKey,
 		disco:              discoKey,
 		from:               src.ap,
-		vni:                src.vni.get(),
+		vni:                src.vni.Get(),
 		at:                 time.Now(),
 	})
 }
@@ -535,8 +536,8 @@ func (r *relayManager) handleRxDiscoMsgRunLoop(event relayDiscoMsgEvent) {
 		// socket on Linux. We make no such efforts here as the raw socket BPF
 		// program does not support Geneve-encapsulated disco, and is also
 		// disabled by default.
-		vni := virtualNetworkID{}
-		vni.set(event.vni)
+		vni := packet.VirtualNetworkID{}
+		vni.Set(event.vni)
 		go event.conn.sendDiscoMessage(epAddr{ap: event.from, vni: vni}, key.NodePublic{}, event.disco, &disco.Pong{
 			TxID: msg.TxID,
 			Src:  event.from,
@@ -622,8 +623,8 @@ func (r *relayManager) handleHandshakeWorkDoneRunLoop(done relayEndpointHandshak
 		return
 	}
 	// This relay endpoint is functional.
-	vni := virtualNetworkID{}
-	vni.set(done.work.se.VNI)
+	vni := packet.VirtualNetworkID{}
+	vni.Set(done.work.se.VNI)
 	addr := epAddr{ap: done.pongReceivedFrom, vni: vni}
 	// ep.udpRelayEndpointReady() must be called in a new goroutine to prevent
 	// deadlocks as it acquires [endpoint] & [Conn] mutexes. See [relayManager]
@@ -784,8 +785,8 @@ func (r *relayManager) handshakeServerEndpoint(work *relayHandshakeWork, generat
 	bind := &disco.BindUDPRelayEndpoint{
 		BindUDPRelayEndpointCommon: common,
 	}
-	vni := virtualNetworkID{}
-	vni.set(work.se.VNI)
+	vni := packet.VirtualNetworkID{}
+	vni.Set(work.se.VNI)
 	for _, addrPort := range work.se.AddrPorts {
 		if addrPort.IsValid() {
 			sentBindAny = true