Просмотр исходного кода

net/packet: split off checksum munging into different pkg

The current structure meant that we were embedding netstack in
the tailscale CLI and in the GUIs. This removes that by isolating
the checksum munging to a different pkg which is only called from
`net/tstun`.

Fixes #9756

Signed-off-by: Maisem Ali <[email protected]>
Maisem Ali 2 лет назад
Родитель
Сommit
9d96e05267

+ 0 - 17
cmd/derper/depaware.txt

@@ -17,7 +17,6 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
         github.com/fxamacker/cbor/v2                                 from tailscale.com/tka
         github.com/golang/groupcache/lru                             from tailscale.com/net/dnscache
         github.com/golang/protobuf/proto                             from github.com/matttproud/golang_protobuf_extensions/pbutil
-        github.com/google/btree                                      from gvisor.dev/gvisor/pkg/tcpip/header
    L    github.com/google/nftables                                   from tailscale.com/util/linuxfw
    L 💣 github.com/google/nftables/alignedbuff                       from github.com/google/nftables/xt
    L 💣 github.com/google/nftables/binaryutil                        from github.com/google/nftables+
@@ -79,22 +78,6 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
         google.golang.org/protobuf/runtime/protoimpl                 from github.com/golang/protobuf/proto+
         google.golang.org/protobuf/types/descriptorpb                from google.golang.org/protobuf/reflect/protodesc
         google.golang.org/protobuf/types/known/timestamppb           from github.com/prometheus/client_golang/prometheus+
-        gvisor.dev/gvisor/pkg/atomicbitops                           from gvisor.dev/gvisor/pkg/buffer+
-        gvisor.dev/gvisor/pkg/bits                                   from gvisor.dev/gvisor/pkg/buffer
-     💣 gvisor.dev/gvisor/pkg/buffer                                 from gvisor.dev/gvisor/pkg/tcpip+
-        gvisor.dev/gvisor/pkg/context                                from gvisor.dev/gvisor/pkg/refs
-     💣 gvisor.dev/gvisor/pkg/gohacks                                from gvisor.dev/gvisor/pkg/state/wire+
-        gvisor.dev/gvisor/pkg/linewriter                             from gvisor.dev/gvisor/pkg/log
-        gvisor.dev/gvisor/pkg/log                                    from gvisor.dev/gvisor/pkg/context+
-        gvisor.dev/gvisor/pkg/refs                                   from gvisor.dev/gvisor/pkg/buffer
-     💣 gvisor.dev/gvisor/pkg/state                                  from gvisor.dev/gvisor/pkg/atomicbitops+
-        gvisor.dev/gvisor/pkg/state/wire                             from gvisor.dev/gvisor/pkg/state
-     💣 gvisor.dev/gvisor/pkg/sync                                   from gvisor.dev/gvisor/pkg/atomicbitops+
-        gvisor.dev/gvisor/pkg/tcpip                                  from gvisor.dev/gvisor/pkg/tcpip/header+
-        gvisor.dev/gvisor/pkg/tcpip/checksum                         from gvisor.dev/gvisor/pkg/buffer+
-        gvisor.dev/gvisor/pkg/tcpip/header                           from tailscale.com/net/packet
-        gvisor.dev/gvisor/pkg/tcpip/seqnum                           from gvisor.dev/gvisor/pkg/tcpip/header
-        gvisor.dev/gvisor/pkg/waiter                                 from gvisor.dev/gvisor/pkg/context+
         nhooyr.io/websocket                                          from tailscale.com/cmd/derper+
         nhooyr.io/websocket/internal/errd                            from nhooyr.io/websocket
         nhooyr.io/websocket/internal/xsync                           from nhooyr.io/websocket

+ 12 - 0
cmd/derper/derper_test.go

@@ -12,6 +12,7 @@ import (
 	"testing"
 
 	"tailscale.com/net/stun"
+	"tailscale.com/tstest/deptest"
 )
 
 func TestProdAutocertHostPolicy(t *testing.T) {
@@ -128,3 +129,14 @@ func TestNoContent(t *testing.T) {
 		})
 	}
 }
+
+func TestDeps(t *testing.T) {
+	deptest.DepChecker{
+		BadDeps: map[string]string{
+			"gvisor.dev/gvisor/pkg/buffer":       "https://github.com/tailscale/tailscale/issues/9756",
+			"gvisor.dev/gvisor/pkg/cpuid":        "https://github.com/tailscale/tailscale/issues/9756",
+			"gvisor.dev/gvisor/pkg/tcpip":        "https://github.com/tailscale/tailscale/issues/9756",
+			"gvisor.dev/gvisor/pkg/tcpip/header": "https://github.com/tailscale/tailscale/issues/9756",
+		},
+	}.Check(t)
+}

+ 0 - 17
cmd/tailscale/depaware.txt

@@ -17,7 +17,6 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
         github.com/fxamacker/cbor/v2                                 from tailscale.com/tka
    L 💣 github.com/godbus/dbus/v5                                    from github.com/coreos/go-systemd/v22/dbus
         github.com/golang/groupcache/lru                             from tailscale.com/net/dnscache
-        github.com/google/btree                                      from gvisor.dev/gvisor/pkg/tcpip/header
    L    github.com/google/nftables                                   from tailscale.com/util/linuxfw
    L 💣 github.com/google/nftables/alignedbuff                       from github.com/google/nftables/xt
    L 💣 github.com/google/nftables/binaryutil                        from github.com/google/nftables+
@@ -65,22 +64,6 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
         go4.org/netipx                                               from tailscale.com/wgengine/filter+
    W 💣 golang.zx2c4.com/wireguard/windows/tunnel/winipcfg           from tailscale.com/net/interfaces+
         gopkg.in/yaml.v2                                             from sigs.k8s.io/yaml
-        gvisor.dev/gvisor/pkg/atomicbitops                           from gvisor.dev/gvisor/pkg/buffer+
-        gvisor.dev/gvisor/pkg/bits                                   from gvisor.dev/gvisor/pkg/buffer
-     💣 gvisor.dev/gvisor/pkg/buffer                                 from gvisor.dev/gvisor/pkg/tcpip+
-        gvisor.dev/gvisor/pkg/context                                from gvisor.dev/gvisor/pkg/refs
-     💣 gvisor.dev/gvisor/pkg/gohacks                                from gvisor.dev/gvisor/pkg/state/wire+
-        gvisor.dev/gvisor/pkg/linewriter                             from gvisor.dev/gvisor/pkg/log
-        gvisor.dev/gvisor/pkg/log                                    from gvisor.dev/gvisor/pkg/context+
-        gvisor.dev/gvisor/pkg/refs                                   from gvisor.dev/gvisor/pkg/buffer
-     💣 gvisor.dev/gvisor/pkg/state                                  from gvisor.dev/gvisor/pkg/atomicbitops+
-        gvisor.dev/gvisor/pkg/state/wire                             from gvisor.dev/gvisor/pkg/state
-     💣 gvisor.dev/gvisor/pkg/sync                                   from gvisor.dev/gvisor/pkg/atomicbitops+
-        gvisor.dev/gvisor/pkg/tcpip                                  from gvisor.dev/gvisor/pkg/tcpip/header+
-        gvisor.dev/gvisor/pkg/tcpip/checksum                         from gvisor.dev/gvisor/pkg/buffer+
-        gvisor.dev/gvisor/pkg/tcpip/header                           from tailscale.com/net/packet
-        gvisor.dev/gvisor/pkg/tcpip/seqnum                           from gvisor.dev/gvisor/pkg/tcpip/header
-        gvisor.dev/gvisor/pkg/waiter                                 from gvisor.dev/gvisor/pkg/context+
         k8s.io/client-go/util/homedir                                from tailscale.com/cmd/tailscale/cli
         nhooyr.io/websocket                                          from tailscale.com/derp/derphttp+
         nhooyr.io/websocket/internal/errd                            from nhooyr.io/websocket

+ 21 - 0
cmd/tailscale/tailscale_test.go

@@ -0,0 +1,21 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package main
+
+import (
+	"testing"
+
+	"tailscale.com/tstest/deptest"
+)
+
+func TestDeps(t *testing.T) {
+	deptest.DepChecker{
+		BadDeps: map[string]string{
+			"gvisor.dev/gvisor/pkg/buffer":       "https://github.com/tailscale/tailscale/issues/9756",
+			"gvisor.dev/gvisor/pkg/cpuid":        "https://github.com/tailscale/tailscale/issues/9756",
+			"gvisor.dev/gvisor/pkg/tcpip":        "https://github.com/tailscale/tailscale/issues/9756",
+			"gvisor.dev/gvisor/pkg/tcpip/header": "https://github.com/tailscale/tailscale/issues/9756",
+		},
+	}.Check(t)
+}

+ 1 - 0
cmd/tailscaled/depaware.txt

@@ -276,6 +276,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
      💣 tailscale.com/net/netstat                                    from tailscale.com/ipn/ipnauth+
         tailscale.com/net/netutil                                    from tailscale.com/ipn/ipnlocal+
         tailscale.com/net/packet                                     from tailscale.com/net/tstun+
+        tailscale.com/net/packet/checksum                            from tailscale.com/net/tstun
         tailscale.com/net/ping                                       from tailscale.com/net/netcheck+
         tailscale.com/net/portmapper                                 from tailscale.com/net/netcheck+
         tailscale.com/net/proxymux                                   from tailscale.com/cmd/tailscaled

+ 21 - 0
ipn/ipn_test.go

@@ -0,0 +1,21 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ipn
+
+import (
+	"testing"
+
+	"tailscale.com/tstest/deptest"
+)
+
+func TestDeps(t *testing.T) {
+	deptest.DepChecker{
+		BadDeps: map[string]string{
+			"gvisor.dev/gvisor/pkg/buffer":       "https://github.com/tailscale/tailscale/issues/9756",
+			"gvisor.dev/gvisor/pkg/cpuid":        "https://github.com/tailscale/tailscale/issues/9756",
+			"gvisor.dev/gvisor/pkg/tcpip":        "https://github.com/tailscale/tailscale/issues/9756",
+			"gvisor.dev/gvisor/pkg/tcpip/header": "https://github.com/tailscale/tailscale/issues/9756",
+		},
+	}.Check(t)
+}

+ 197 - 0
net/packet/checksum/checksum.go

@@ -0,0 +1,197 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package checksum provides functions for updating checksums in parsed packets.
+package checksum
+
+import (
+	"encoding/binary"
+	"net/netip"
+
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"tailscale.com/net/packet"
+	"tailscale.com/types/ipproto"
+)
+
+// UpdateSrcAddr updates the source address in the packet buffer (e.g. during
+// SNAT). It also updates the checksum. Currently (2023-09-22) only TCP/UDP/ICMP
+// is supported. It panics if provided with an address in a different
+// family to the parsed packet.
+func UpdateSrcAddr(q *packet.Parsed, src netip.Addr) {
+	if src.Is6() && q.IPVersion != 6 {
+		panic("UpdateSrcAddr: cannot write IPv6 address to v4 packet")
+	} else if src.Is4() && q.IPVersion != 4 {
+		panic("UpdateSrcAddr: cannot write IPv4 address to v6 packet")
+	}
+	q.CaptureMeta.DidSNAT = true
+	q.CaptureMeta.OriginalSrc = q.Src
+
+	old := q.Src.Addr()
+	q.Src = netip.AddrPortFrom(src, q.Src.Port())
+
+	b := q.Buffer()
+	if src.Is6() {
+		v6 := src.As16()
+		copy(b[8:24], v6[:])
+		updateV6PacketChecksums(q, old, src)
+	} else {
+		v4 := src.As4()
+		copy(b[12:16], v4[:])
+		updateV4PacketChecksums(q, old, src)
+	}
+}
+
+// UpdateDstAddr updates the destination address in the packet buffer (e.g. during
+// DNAT). It also updates the checksum. Currently (2022-12-10) only TCP/UDP/ICMP
+// is supported. It panics if provided with an address in a different
+// family to the parsed packet.
+func UpdateDstAddr(q *packet.Parsed, dst netip.Addr) {
+	if dst.Is6() && q.IPVersion != 6 {
+		panic("UpdateDstAddr: cannot write IPv6 address to v4 packet")
+	} else if dst.Is4() && q.IPVersion != 4 {
+		panic("UpdateDstAddr: cannot write IPv4 address to v6 packet")
+	}
+	q.CaptureMeta.DidDNAT = true
+	q.CaptureMeta.OriginalDst = q.Dst
+
+	old := q.Dst.Addr()
+	q.Dst = netip.AddrPortFrom(dst, q.Dst.Port())
+
+	b := q.Buffer()
+	if dst.Is6() {
+		v6 := dst.As16()
+		copy(b[24:36], v6[:])
+		updateV6PacketChecksums(q, old, dst)
+	} else {
+		v4 := dst.As4()
+		copy(b[16:20], v4[:])
+		updateV4PacketChecksums(q, old, dst)
+	}
+}
+
+// updateV4PacketChecksums updates the checksums in the packet buffer.
+// Currently (2023-03-01) only TCP/UDP/ICMP over IPv4 is supported.
+// p is modified in place.
+// If p.IPProto is unknown, only the IP header checksum is updated.
+func updateV4PacketChecksums(p *packet.Parsed, old, new netip.Addr) {
+	if len(p.Buffer()) < 12 {
+		// Not enough space for an IPv4 header.
+		return
+	}
+	o4, n4 := old.As4(), new.As4()
+
+	// First update the checksum in the IP header.
+	updateV4Checksum(p.Buffer()[10:12], o4[:], n4[:])
+
+	// Now update the transport layer checksums, where applicable.
+	tr := p.Transport()
+	switch p.IPProto {
+	case ipproto.UDP, ipproto.DCCP:
+		if len(tr) < header.UDPMinimumSize {
+			// Not enough space for a UDP header.
+			return
+		}
+		updateV4Checksum(tr[6:8], o4[:], n4[:])
+	case ipproto.TCP:
+		if len(tr) < header.TCPMinimumSize {
+			// Not enough space for a TCP header.
+			return
+		}
+		updateV4Checksum(tr[16:18], o4[:], n4[:])
+	case ipproto.GRE:
+		if len(tr) < 6 {
+			// Not enough space for a GRE header.
+			return
+		}
+		if tr[0] == 1 { // checksum present
+			updateV4Checksum(tr[4:6], o4[:], n4[:])
+		}
+	case ipproto.SCTP, ipproto.ICMPv4:
+		// No transport layer update required.
+	}
+}
+
+// updateV6PacketChecksums updates the checksums in the packet buffer.
+// p is modified in place.
+// If p.IPProto is unknown, no checksums are updated.
+func updateV6PacketChecksums(p *packet.Parsed, old, new netip.Addr) {
+	if len(p.Buffer()) < 40 {
+		// Not enough space for an IPv6 header.
+		return
+	}
+	o6, n6 := tcpip.AddrFrom16Slice(old.AsSlice()), tcpip.AddrFrom16Slice(new.AsSlice())
+
+	// Now update the transport layer checksums, where applicable.
+	tr := p.Transport()
+	switch p.IPProto {
+	case ipproto.ICMPv6:
+		if len(tr) < header.ICMPv6MinimumSize {
+			return
+		}
+		header.ICMPv6(tr).UpdateChecksumPseudoHeaderAddress(o6, n6)
+	case ipproto.UDP, ipproto.DCCP:
+		if len(tr) < header.UDPMinimumSize {
+			return
+		}
+		header.UDP(tr).UpdateChecksumPseudoHeaderAddress(o6, n6, true)
+	case ipproto.TCP:
+		if len(tr) < header.TCPMinimumSize {
+			return
+		}
+		header.TCP(tr).UpdateChecksumPseudoHeaderAddress(o6, n6, true)
+	case ipproto.SCTP:
+		// No transport layer update required.
+	}
+}
+
+// updateV4Checksum calculates and updates the checksum in the packet buffer for
+// a change between old and new. The oldSum must point to the 16-bit checksum
+// field in the packet buffer that holds the old checksum value, it will be
+// updated in place.
+//
+// The old and new must be the same length, and must be an even number of bytes.
+func updateV4Checksum(oldSum, old, new []byte) {
+	if len(old) != len(new) {
+		panic("old and new must be the same length")
+	}
+	if len(old)%2 != 0 {
+		panic("old and new must be of even length")
+	}
+	/*
+		RFC 1624
+		Given the following notation:
+
+		    HC  - old checksum in header
+		    C   - one's complement sum of old header
+		    HC' - new checksum in header
+		    C'  - one's complement sum of new header
+		    m   - old value of a 16-bit field
+		    m'  - new value of a 16-bit field
+
+		    HC' = ~(C + (-m) + m')  --    [Eqn. 3]
+		    HC' = ~(~HC + ~m + m')
+
+		This can be simplified to:
+		    HC' = ~(C + ~m + m')    --    [Eqn. 3]
+		    HC' = ~C'
+		    C'  = C + ~m + m'
+	*/
+
+	c := uint32(^binary.BigEndian.Uint16(oldSum))
+
+	cPrime := c
+	for len(new) > 0 {
+		mNot := uint32(^binary.BigEndian.Uint16(old[:2]))
+		mPrime := uint32(binary.BigEndian.Uint16(new[:2]))
+		cPrime += mPrime + mNot
+		new, old = new[2:], old[2:]
+	}
+
+	// Account for overflows by adding the carry bits back into the sum.
+	for (cPrime >> 16) > 0 {
+		cPrime = cPrime&0xFFFF + cPrime>>16
+	}
+	hcPrime := ^uint16(cPrime)
+	binary.BigEndian.PutUint16(oldSum, hcPrime)
+}

+ 196 - 0
net/packet/checksum/checksum_test.go

@@ -0,0 +1,196 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package checksum
+
+import (
+	"encoding/binary"
+	"net/netip"
+	"testing"
+
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/checksum"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"tailscale.com/net/packet"
+)
+
+func fullHeaderChecksumV4(b []byte) uint16 {
+	s := uint32(0)
+	for i := 0; i < len(b); i += 2 {
+		if i == 10 {
+			// Skip checksum field.
+			continue
+		}
+		s += uint32(binary.BigEndian.Uint16(b[i : i+2]))
+	}
+	for s>>16 > 0 {
+		s = s&0xFFFF + s>>16
+	}
+	return ^uint16(s)
+}
+
+func TestHeaderChecksumsV4(t *testing.T) {
+	// This is not a good enough test, because it doesn't
+	// check the various packet types or the many edge cases
+	// of the checksum algorithm. But it's a start.
+
+	tests := []struct {
+		name   string
+		packet []byte
+	}{
+		{
+			name: "ICMPv4",
+			packet: []byte{
+				0x45, 0x00, 0x00, 0x54, 0xb7, 0x96, 0x40, 0x00, 0x40, 0x01, 0x7a, 0x06, 0x64, 0x7f, 0x3f, 0x4c, 0x64, 0x40, 0x01, 0x01, 0x08, 0x00, 0x47, 0x1a, 0x00, 0x11, 0x01, 0xac, 0xcc, 0xf5, 0x95, 0x63, 0x00, 0x00, 0x00, 0x00, 0x8d, 0xfc, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
+			},
+		},
+		{
+			name: "TLS",
+			packet: []byte{
+				0x45, 0x00, 0x00, 0x3c, 0x54, 0x29, 0x40, 0x00, 0x40, 0x06, 0xb1, 0xac, 0x64, 0x42, 0xd4, 0x33, 0x64, 0x61, 0x98, 0x0f, 0xb1, 0x94, 0x01, 0xbb, 0x0a, 0x51, 0xce, 0x7c, 0x00, 0x00, 0x00, 0x00, 0xa0, 0x02, 0xfb, 0xe0, 0x38, 0xf6, 0x00, 0x00, 0x02, 0x04, 0x04, 0xd8, 0x04, 0x02, 0x08, 0x0a, 0x86, 0x2b, 0xcc, 0xd5, 0x00, 0x00, 0x00, 0x00, 0x01, 0x03, 0x03, 0x07,
+			},
+		},
+		{
+			name: "DNS",
+			packet: []byte{
+				0x45, 0x00, 0x00, 0x74, 0xe2, 0x85, 0x00, 0x00, 0x40, 0x11, 0x96, 0xb5, 0x64, 0x64, 0x64, 0x64, 0x64, 0x42, 0xd4, 0x33, 0x00, 0x35, 0xec, 0x55, 0x00, 0x60, 0xd9, 0x19, 0xed, 0xfd, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x34, 0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x01, 0x1e, 0x00, 0x0c, 0x07, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x01, 0x6c, 0xc0, 0x15, 0xc0, 0x31, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x1e, 0x00, 0x04, 0x8e, 0xfa, 0xbd, 0xce, 0x00, 0x00, 0x29, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+			},
+		},
+		{
+			name: "DCCP",
+			packet: []byte{
+				0x45, 0x00, 0x00, 0x28, 0x15, 0x06, 0x40, 0x00, 0x40, 0x21, 0x5f, 0x2f, 0xc0, 0xa8, 0x01, 0x1f, 0xc9, 0x0b, 0x3b, 0xad, 0x80, 0x04, 0x13, 0x89, 0x05, 0x00, 0x08, 0xdb, 0x01, 0x00, 0x00, 0x04, 0x29, 0x01, 0x6d, 0xdc, 0x00, 0x00, 0x00, 0x00,
+			},
+		},
+		{
+			name: "SCTP",
+			packet: []byte{
+				0x45, 0x00, 0x00, 0x30, 0x09, 0xd9, 0x40, 0x00, 0xff, 0x84, 0x50, 0xe2, 0x0a, 0x1c, 0x06, 0x2c, 0x0a, 0x1c, 0x06, 0x2b, 0x0b, 0x80, 0x40, 0x00, 0x21, 0x44, 0x15, 0x23, 0x2b, 0xf2, 0x02, 0x4e, 0x03, 0x00, 0x00, 0x10, 0x28, 0x02, 0x43, 0x45, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00,
+			},
+		},
+		// TODO(maisem): add test for GRE.
+	}
+	var p packet.Parsed
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			p.Decode(tt.packet)
+			t.Log(p.String())
+			UpdateSrcAddr(&p, netip.MustParseAddr("100.64.0.1"))
+
+			got := binary.BigEndian.Uint16(tt.packet[10:12])
+			want := fullHeaderChecksumV4(tt.packet[:20])
+			if got != want {
+				t.Fatalf("got %x want %x", got, want)
+			}
+
+			UpdateDstAddr(&p, netip.MustParseAddr("100.64.0.2"))
+			got = binary.BigEndian.Uint16(tt.packet[10:12])
+			want = fullHeaderChecksumV4(tt.packet[:20])
+			if got != want {
+				t.Fatalf("got %x want %x", got, want)
+			}
+		})
+	}
+}
+
+func TestNatChecksumsV6UDP(t *testing.T) {
+	a1, a2 := netip.MustParseAddr("a::1"), netip.MustParseAddr("b::1")
+
+	// Make a fake UDP packet with 32 bytes of zeros as the datagram payload.
+	b := header.IPv6(make([]byte, header.IPv6MinimumSize+header.UDPMinimumSize+32))
+	b.Encode(&header.IPv6Fields{
+		PayloadLength:     header.UDPMinimumSize + 32,
+		TransportProtocol: header.UDPProtocolNumber,
+		HopLimit:          16,
+		SrcAddr:           tcpip.AddrFrom16Slice(a1.AsSlice()),
+		DstAddr:           tcpip.AddrFrom16Slice(a2.AsSlice()),
+	})
+	udp := header.UDP(b[header.IPv6MinimumSize:])
+	udp.Encode(&header.UDPFields{
+		SrcPort: 42,
+		DstPort: 43,
+		Length:  header.UDPMinimumSize + 32,
+	})
+	xsum := header.PseudoHeaderChecksum(
+		header.UDPProtocolNumber,
+		tcpip.AddrFrom16Slice(a1.AsSlice()),
+		tcpip.AddrFrom16Slice(a2.AsSlice()),
+		uint16(header.UDPMinimumSize+32),
+	)
+	xsum = checksum.Checksum(b.Payload()[header.UDPMinimumSize:], xsum)
+	udp.SetChecksum(^udp.CalculateChecksum(xsum))
+	if !udp.IsChecksumValid(tcpip.AddrFrom16Slice(a1.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), checksum.Checksum(b.Payload()[header.UDPMinimumSize:], 0)) {
+		t.Fatal("test broken; initial packet has incorrect checksum")
+	}
+
+	// Parse the packet.
+	var p packet.Parsed
+	p.Decode(b)
+	t.Log(p.String())
+
+	// Update the source address of the packet to be the same as the dest.
+	UpdateSrcAddr(&p, a2)
+	if !udp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), checksum.Checksum(b.Payload()[header.UDPMinimumSize:], 0)) {
+		t.Fatal("incorrect checksum after updating source address")
+	}
+
+	// Update the dest address of the packet to be the original source address.
+	UpdateDstAddr(&p, a1)
+	if !udp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a1.AsSlice()), checksum.Checksum(b.Payload()[header.UDPMinimumSize:], 0)) {
+		t.Fatal("incorrect checksum after updating destination address")
+	}
+}
+
+func TestNatChecksumsV6TCP(t *testing.T) {
+	a1, a2 := netip.MustParseAddr("a::1"), netip.MustParseAddr("b::1")
+
+	// Make a fake TCP packet with no payload.
+	b := header.IPv6(make([]byte, header.IPv6MinimumSize+header.TCPMinimumSize))
+	b.Encode(&header.IPv6Fields{
+		PayloadLength:     header.TCPMinimumSize,
+		TransportProtocol: header.TCPProtocolNumber,
+		HopLimit:          16,
+		SrcAddr:           tcpip.AddrFrom16Slice(a1.AsSlice()),
+		DstAddr:           tcpip.AddrFrom16Slice(a2.AsSlice()),
+	})
+	tcp := header.TCP(b[header.IPv6MinimumSize:])
+	tcp.Encode(&header.TCPFields{
+		SrcPort:       42,
+		DstPort:       43,
+		SeqNum:        1,
+		AckNum:        2,
+		DataOffset:    header.TCPMinimumSize,
+		Flags:         3,
+		WindowSize:    4,
+		Checksum:      0,
+		UrgentPointer: 5,
+	})
+	xsum := header.PseudoHeaderChecksum(
+		header.TCPProtocolNumber,
+		tcpip.AddrFrom16Slice(a1.AsSlice()),
+		tcpip.AddrFrom16Slice(a2.AsSlice()),
+		uint16(header.TCPMinimumSize),
+	)
+	tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
+
+	if !tcp.IsChecksumValid(tcpip.AddrFrom16Slice(a1.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), 0, 0) {
+		t.Fatal("test broken; initial packet has incorrect checksum")
+	}
+
+	// Parse the packet.
+	var p packet.Parsed
+	p.Decode(b)
+	t.Log(p.String())
+
+	// Update the source address of the packet to be the same as the dest.
+	UpdateSrcAddr(&p, a2)
+	if !tcp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), 0, 0) {
+		t.Fatal("incorrect checksum after updating source address")
+	}
+
+	// Update the dest address of the packet to be the original source address.
+	UpdateDstAddr(&p, a1)
+	if !tcp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a1.AsSlice()), 0, 0) {
+		t.Fatal("incorrect checksum after updating destination address")
+	}
+}

+ 0 - 184
net/packet/packet.go

@@ -10,8 +10,6 @@ import (
 	"net/netip"
 	"strings"
 
-	"gvisor.dev/gvisor/pkg/tcpip"
-	"gvisor.dev/gvisor/pkg/tcpip/header"
 	"tailscale.com/net/netaddr"
 	"tailscale.com/types/ipproto"
 )
@@ -454,62 +452,6 @@ func (q *Parsed) IsEchoResponse() bool {
 	}
 }
 
-// UpdateSrcAddr updates the source address in the packet buffer (e.g. during
-// SNAT). It also updates the checksum. Currently (2023-09-22) only TCP/UDP/ICMP
-// is supported. It panics if provided with an address in a different
-// family to the parsed packet.
-func (q *Parsed) UpdateSrcAddr(src netip.Addr) {
-	if src.Is6() && q.IPVersion != 6 {
-		panic("UpdateSrcAddr: cannot write IPv6 address to v4 packet")
-	} else if src.Is4() && q.IPVersion != 4 {
-		panic("UpdateSrcAddr: cannot write IPv4 address to v6 packet")
-	}
-	q.CaptureMeta.DidSNAT = true
-	q.CaptureMeta.OriginalSrc = q.Src
-
-	old := q.Src.Addr()
-	q.Src = netip.AddrPortFrom(src, q.Src.Port())
-
-	b := q.Buffer()
-	if src.Is6() {
-		v6 := src.As16()
-		copy(b[8:24], v6[:])
-		updateV6PacketChecksums(q, old, src)
-	} else {
-		v4 := src.As4()
-		copy(b[12:16], v4[:])
-		updateV4PacketChecksums(q, old, src)
-	}
-}
-
-// UpdateDstAddr updates the destination address in the packet buffer (e.g. during
-// DNAT). It also updates the checksum. Currently (2022-12-10) only TCP/UDP/ICMP
-// is supported. It panics if provided with an address in a different
-// family to the parsed packet.
-func (q *Parsed) UpdateDstAddr(dst netip.Addr) {
-	if dst.Is6() && q.IPVersion != 6 {
-		panic("UpdateDstAddr: cannot write IPv6 address to v4 packet")
-	} else if dst.Is4() && q.IPVersion != 4 {
-		panic("UpdateDstAddr: cannot write IPv4 address to v6 packet")
-	}
-	q.CaptureMeta.DidDNAT = true
-	q.CaptureMeta.OriginalDst = q.Dst
-
-	old := q.Dst.Addr()
-	q.Dst = netip.AddrPortFrom(dst, q.Dst.Port())
-
-	b := q.Buffer()
-	if dst.Is6() {
-		v6 := dst.As16()
-		copy(b[24:36], v6[:])
-		updateV6PacketChecksums(q, old, dst)
-	} else {
-		v4 := dst.As4()
-		copy(b[16:20], v4[:])
-		updateV4PacketChecksums(q, old, dst)
-	}
-}
-
 // EchoIDSeq extracts the identifier/sequence bytes from an ICMP Echo response,
 // and returns them as a uint32, used to lookup internally routed ICMP echo
 // responses. This function is intentionally lightweight as it is called on
@@ -572,129 +514,3 @@ func withIP(ap netip.AddrPort, ip netip.Addr) netip.AddrPort {
 func withPort(ap netip.AddrPort, port uint16) netip.AddrPort {
 	return netip.AddrPortFrom(ap.Addr(), port)
 }
-
-// updateV4PacketChecksums updates the checksums in the packet buffer.
-// Currently (2023-03-01) only TCP/UDP/ICMP over IPv4 is supported.
-// p is modified in place.
-// If p.IPProto is unknown, only the IP header checksum is updated.
-func updateV4PacketChecksums(p *Parsed, old, new netip.Addr) {
-	if len(p.Buffer()) < 12 {
-		// Not enough space for an IPv4 header.
-		return
-	}
-	o4, n4 := old.As4(), new.As4()
-
-	// First update the checksum in the IP header.
-	updateV4Checksum(p.Buffer()[10:12], o4[:], n4[:])
-
-	// Now update the transport layer checksums, where applicable.
-	tr := p.Transport()
-	switch p.IPProto {
-	case ipproto.UDP, ipproto.DCCP:
-		if len(tr) < header.UDPMinimumSize {
-			// Not enough space for a UDP header.
-			return
-		}
-		updateV4Checksum(tr[6:8], o4[:], n4[:])
-	case ipproto.TCP:
-		if len(tr) < header.TCPMinimumSize {
-			// Not enough space for a TCP header.
-			return
-		}
-		updateV4Checksum(tr[16:18], o4[:], n4[:])
-	case ipproto.GRE:
-		if len(tr) < 6 {
-			// Not enough space for a GRE header.
-			return
-		}
-		if tr[0] == 1 { // checksum present
-			updateV4Checksum(tr[4:6], o4[:], n4[:])
-		}
-	case ipproto.SCTP, ipproto.ICMPv4:
-		// No transport layer update required.
-	}
-}
-
-// updateV6PacketChecksums updates the checksums in the packet buffer.
-// p is modified in place.
-// If p.IPProto is unknown, no checksums are updated.
-func updateV6PacketChecksums(p *Parsed, old, new netip.Addr) {
-	if len(p.Buffer()) < 40 {
-		// Not enough space for an IPv6 header.
-		return
-	}
-	o6, n6 := tcpip.AddrFrom16Slice(old.AsSlice()), tcpip.AddrFrom16Slice(new.AsSlice())
-
-	// Now update the transport layer checksums, where applicable.
-	tr := p.Transport()
-	switch p.IPProto {
-	case ipproto.ICMPv6:
-		if len(tr) < header.ICMPv6MinimumSize {
-			return
-		}
-		header.ICMPv6(tr).UpdateChecksumPseudoHeaderAddress(o6, n6)
-	case ipproto.UDP, ipproto.DCCP:
-		if len(tr) < header.UDPMinimumSize {
-			return
-		}
-		header.UDP(tr).UpdateChecksumPseudoHeaderAddress(o6, n6, true)
-	case ipproto.TCP:
-		if len(tr) < header.TCPMinimumSize {
-			return
-		}
-		header.TCP(tr).UpdateChecksumPseudoHeaderAddress(o6, n6, true)
-	case ipproto.SCTP:
-		// No transport layer update required.
-	}
-}
-
-// updateV4Checksum calculates and updates the checksum in the packet buffer for
-// a change between old and new. The oldSum must point to the 16-bit checksum
-// field in the packet buffer that holds the old checksum value, it will be
-// updated in place.
-//
-// The old and new must be the same length, and must be an even number of bytes.
-func updateV4Checksum(oldSum, old, new []byte) {
-	if len(old) != len(new) {
-		panic("old and new must be the same length")
-	}
-	if len(old)%2 != 0 {
-		panic("old and new must be of even length")
-	}
-	/*
-		RFC 1624
-		Given the following notation:
-
-		    HC  - old checksum in header
-		    C   - one's complement sum of old header
-		    HC' - new checksum in header
-		    C'  - one's complement sum of new header
-		    m   - old value of a 16-bit field
-		    m'  - new value of a 16-bit field
-
-		    HC' = ~(C + (-m) + m')  --    [Eqn. 3]
-		    HC' = ~(~HC + ~m + m')
-
-		This can be simplified to:
-		    HC' = ~(C + ~m + m')    --    [Eqn. 3]
-		    HC' = ~C'
-		    C'  = C + ~m + m'
-	*/
-
-	c := uint32(^binary.BigEndian.Uint16(oldSum))
-
-	cPrime := c
-	for len(new) > 0 {
-		mNot := uint32(^binary.BigEndian.Uint16(old[:2]))
-		mPrime := uint32(binary.BigEndian.Uint16(new[:2]))
-		cPrime += mPrime + mNot
-		new, old = new[2:], old[2:]
-	}
-
-	// Account for overflows by adding the carry bits back into the sum.
-	for (cPrime >> 16) > 0 {
-		cPrime = cPrime&0xFFFF + cPrime>>16
-	}
-	hcPrime := ^uint16(cPrime)
-	binary.BigEndian.PutUint16(oldSum, hcPrime)
-}

+ 0 - 185
net/packet/packet_test.go

@@ -5,7 +5,6 @@ package packet
 
 import (
 	"bytes"
-	"encoding/binary"
 	"encoding/hex"
 	"net/netip"
 	"reflect"
@@ -13,9 +12,6 @@ import (
 	"testing"
 	"unicode"
 
-	"gvisor.dev/gvisor/pkg/tcpip"
-	"gvisor.dev/gvisor/pkg/tcpip/checksum"
-	"gvisor.dev/gvisor/pkg/tcpip/header"
 	"tailscale.com/tstest"
 	"tailscale.com/types/ipproto"
 	"tailscale.com/util/must"
@@ -33,187 +29,6 @@ const (
 	Fragment = ipproto.Fragment
 )
 
-func fullHeaderChecksumV4(b []byte) uint16 {
-	s := uint32(0)
-	for i := 0; i < len(b); i += 2 {
-		if i == 10 {
-			// Skip checksum field.
-			continue
-		}
-		s += uint32(binary.BigEndian.Uint16(b[i : i+2]))
-	}
-	for s>>16 > 0 {
-		s = s&0xFFFF + s>>16
-	}
-	return ^uint16(s)
-}
-
-func TestHeaderChecksumsV4(t *testing.T) {
-	// This is not a good enough test, because it doesn't
-	// check the various packet types or the many edge cases
-	// of the checksum algorithm. But it's a start.
-
-	tests := []struct {
-		name   string
-		packet []byte
-	}{
-		{
-			name: "ICMPv4",
-			packet: []byte{
-				0x45, 0x00, 0x00, 0x54, 0xb7, 0x96, 0x40, 0x00, 0x40, 0x01, 0x7a, 0x06, 0x64, 0x7f, 0x3f, 0x4c, 0x64, 0x40, 0x01, 0x01, 0x08, 0x00, 0x47, 0x1a, 0x00, 0x11, 0x01, 0xac, 0xcc, 0xf5, 0x95, 0x63, 0x00, 0x00, 0x00, 0x00, 0x8d, 0xfc, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
-			},
-		},
-		{
-			name: "TLS",
-			packet: []byte{
-				0x45, 0x00, 0x00, 0x3c, 0x54, 0x29, 0x40, 0x00, 0x40, 0x06, 0xb1, 0xac, 0x64, 0x42, 0xd4, 0x33, 0x64, 0x61, 0x98, 0x0f, 0xb1, 0x94, 0x01, 0xbb, 0x0a, 0x51, 0xce, 0x7c, 0x00, 0x00, 0x00, 0x00, 0xa0, 0x02, 0xfb, 0xe0, 0x38, 0xf6, 0x00, 0x00, 0x02, 0x04, 0x04, 0xd8, 0x04, 0x02, 0x08, 0x0a, 0x86, 0x2b, 0xcc, 0xd5, 0x00, 0x00, 0x00, 0x00, 0x01, 0x03, 0x03, 0x07,
-			},
-		},
-		{
-			name: "DNS",
-			packet: []byte{
-				0x45, 0x00, 0x00, 0x74, 0xe2, 0x85, 0x00, 0x00, 0x40, 0x11, 0x96, 0xb5, 0x64, 0x64, 0x64, 0x64, 0x64, 0x42, 0xd4, 0x33, 0x00, 0x35, 0xec, 0x55, 0x00, 0x60, 0xd9, 0x19, 0xed, 0xfd, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x34, 0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x01, 0x1e, 0x00, 0x0c, 0x07, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x01, 0x6c, 0xc0, 0x15, 0xc0, 0x31, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x1e, 0x00, 0x04, 0x8e, 0xfa, 0xbd, 0xce, 0x00, 0x00, 0x29, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-			},
-		},
-		{
-			name: "DCCP",
-			packet: []byte{
-				0x45, 0x00, 0x00, 0x28, 0x15, 0x06, 0x40, 0x00, 0x40, 0x21, 0x5f, 0x2f, 0xc0, 0xa8, 0x01, 0x1f, 0xc9, 0x0b, 0x3b, 0xad, 0x80, 0x04, 0x13, 0x89, 0x05, 0x00, 0x08, 0xdb, 0x01, 0x00, 0x00, 0x04, 0x29, 0x01, 0x6d, 0xdc, 0x00, 0x00, 0x00, 0x00,
-			},
-		},
-		{
-			name: "SCTP",
-			packet: []byte{
-				0x45, 0x00, 0x00, 0x30, 0x09, 0xd9, 0x40, 0x00, 0xff, 0x84, 0x50, 0xe2, 0x0a, 0x1c, 0x06, 0x2c, 0x0a, 0x1c, 0x06, 0x2b, 0x0b, 0x80, 0x40, 0x00, 0x21, 0x44, 0x15, 0x23, 0x2b, 0xf2, 0x02, 0x4e, 0x03, 0x00, 0x00, 0x10, 0x28, 0x02, 0x43, 0x45, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00,
-			},
-		},
-		// TODO(maisem): add test for GRE.
-	}
-	var p Parsed
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			p.Decode(tt.packet)
-			t.Log(p.String())
-			p.UpdateSrcAddr(netip.MustParseAddr("100.64.0.1"))
-
-			got := binary.BigEndian.Uint16(tt.packet[10:12])
-			want := fullHeaderChecksumV4(tt.packet[:20])
-			if got != want {
-				t.Fatalf("got %x want %x", got, want)
-			}
-
-			p.UpdateDstAddr(netip.MustParseAddr("100.64.0.2"))
-			got = binary.BigEndian.Uint16(tt.packet[10:12])
-			want = fullHeaderChecksumV4(tt.packet[:20])
-			if got != want {
-				t.Fatalf("got %x want %x", got, want)
-			}
-		})
-	}
-}
-
-func TestNatChecksumsV6UDP(t *testing.T) {
-	a1, a2 := netip.MustParseAddr("a::1"), netip.MustParseAddr("b::1")
-
-	// Make a fake UDP packet with 32 bytes of zeros as the datagram payload.
-	b := header.IPv6(make([]byte, header.IPv6MinimumSize+header.UDPMinimumSize+32))
-	b.Encode(&header.IPv6Fields{
-		PayloadLength:     header.UDPMinimumSize + 32,
-		TransportProtocol: header.UDPProtocolNumber,
-		HopLimit:          16,
-		SrcAddr:           tcpip.AddrFrom16Slice(a1.AsSlice()),
-		DstAddr:           tcpip.AddrFrom16Slice(a2.AsSlice()),
-	})
-	udp := header.UDP(b[header.IPv6MinimumSize:])
-	udp.Encode(&header.UDPFields{
-		SrcPort: 42,
-		DstPort: 43,
-		Length:  header.UDPMinimumSize + 32,
-	})
-	xsum := header.PseudoHeaderChecksum(
-		header.UDPProtocolNumber,
-		tcpip.AddrFrom16Slice(a1.AsSlice()),
-		tcpip.AddrFrom16Slice(a2.AsSlice()),
-		uint16(header.UDPMinimumSize+32),
-	)
-	xsum = checksum.Checksum(b.Payload()[header.UDPMinimumSize:], xsum)
-	udp.SetChecksum(^udp.CalculateChecksum(xsum))
-	if !udp.IsChecksumValid(tcpip.AddrFrom16Slice(a1.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), checksum.Checksum(b.Payload()[header.UDPMinimumSize:], 0)) {
-		t.Fatal("test broken; initial packet has incorrect checksum")
-	}
-
-	// Parse the packet.
-	var p Parsed
-	p.Decode(b)
-	t.Log(p.String())
-
-	// Update the source address of the packet to be the same as the dest.
-	p.UpdateSrcAddr(a2)
-	if !udp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), checksum.Checksum(b.Payload()[header.UDPMinimumSize:], 0)) {
-		t.Fatal("incorrect checksum after updating source address")
-	}
-
-	// Update the dest address of the packet to be the original source address.
-	p.UpdateDstAddr(a1)
-	if !udp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a1.AsSlice()), checksum.Checksum(b.Payload()[header.UDPMinimumSize:], 0)) {
-		t.Fatal("incorrect checksum after updating destination address")
-	}
-}
-
-func TestNatChecksumsV6TCP(t *testing.T) {
-	a1, a2 := netip.MustParseAddr("a::1"), netip.MustParseAddr("b::1")
-
-	// Make a fake TCP packet with no payload.
-	b := header.IPv6(make([]byte, header.IPv6MinimumSize+header.TCPMinimumSize))
-	b.Encode(&header.IPv6Fields{
-		PayloadLength:     header.TCPMinimumSize,
-		TransportProtocol: header.TCPProtocolNumber,
-		HopLimit:          16,
-		SrcAddr:           tcpip.AddrFrom16Slice(a1.AsSlice()),
-		DstAddr:           tcpip.AddrFrom16Slice(a2.AsSlice()),
-	})
-	tcp := header.TCP(b[header.IPv6MinimumSize:])
-	tcp.Encode(&header.TCPFields{
-		SrcPort:       42,
-		DstPort:       43,
-		SeqNum:        1,
-		AckNum:        2,
-		DataOffset:    header.TCPMinimumSize,
-		Flags:         3,
-		WindowSize:    4,
-		Checksum:      0,
-		UrgentPointer: 5,
-	})
-	xsum := header.PseudoHeaderChecksum(
-		header.TCPProtocolNumber,
-		tcpip.AddrFrom16Slice(a1.AsSlice()),
-		tcpip.AddrFrom16Slice(a2.AsSlice()),
-		uint16(header.TCPMinimumSize),
-	)
-	tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
-
-	if !tcp.IsChecksumValid(tcpip.AddrFrom16Slice(a1.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), 0, 0) {
-		t.Fatal("test broken; initial packet has incorrect checksum")
-	}
-
-	// Parse the packet.
-	var p Parsed
-	p.Decode(b)
-	t.Log(p.String())
-
-	// Update the source address of the packet to be the same as the dest.
-	p.UpdateSrcAddr(a2)
-	if !tcp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), 0, 0) {
-		t.Fatal("incorrect checksum after updating source address")
-	}
-
-	// Update the dest address of the packet to be the original source address.
-	p.UpdateDstAddr(a1)
-	if !tcp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a1.AsSlice()), 0, 0) {
-		t.Fatal("incorrect checksum after updating destination address")
-	}
-}
-
 func mustIPPort(s string) netip.AddrPort {
 	ipp, err := netip.ParseAddrPort(s)
 	if err != nil {

+ 3 - 2
net/tstun/wrap.go

@@ -25,6 +25,7 @@ import (
 	"tailscale.com/disco"
 	"tailscale.com/net/connstats"
 	"tailscale.com/net/packet"
+	"tailscale.com/net/packet/checksum"
 	"tailscale.com/net/tsaddr"
 	"tailscale.com/net/tstun/table"
 	"tailscale.com/syncs"
@@ -487,7 +488,7 @@ func (t *Wrapper) snat(p *packet.Parsed) {
 	oldSrc := p.Src.Addr()
 	newSrc := nc.selectSrcIP(oldSrc, p.Dst.Addr())
 	if oldSrc != newSrc {
-		p.UpdateSrcAddr(newSrc)
+		checksum.UpdateSrcAddr(p, newSrc)
 	}
 }
 
@@ -497,7 +498,7 @@ func (t *Wrapper) dnat(p *packet.Parsed) {
 	oldDst := p.Dst.Addr()
 	newDst := nc.mapDstIP(oldDst)
 	if newDst != oldDst {
-		p.UpdateDstAddr(newDst)
+		checksum.UpdateDstAddr(p, newDst)
 	}
 }