Browse Source

util/linuxfw: decoupling IPTables logic from linux router

This change is introducing new netfilterRunner interface and moving iptables manipulation to a lower leveled iptables runner.

For #391

Signed-off-by: KevinLiang10 <[email protected]>
KevinLiang10 2 years ago
parent
commit
243ce6ccc1

+ 27 - 1
cmd/derper/depaware.txt

@@ -12,9 +12,16 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
    W 💣 github.com/alexbrainman/sspi/negotiate                       from tailscale.com/net/tshttpproxy
         github.com/beorn7/perks/quantile                             from github.com/prometheus/client_golang/prometheus
      💣 github.com/cespare/xxhash/v2                                 from github.com/prometheus/client_golang/prometheus
+   L    github.com/coreos/go-iptables/iptables                       from tailscale.com/util/linuxfw
         github.com/fxamacker/cbor/v2                                 from tailscale.com/tka
         github.com/golang/groupcache/lru                             from tailscale.com/net/dnscache
         github.com/golang/protobuf/proto                             from github.com/matttproud/golang_protobuf_extensions/pbutil+
+   L    github.com/google/nftables                                   from tailscale.com/util/linuxfw
+   L 💣 github.com/google/nftables/alignedbuff                       from github.com/google/nftables/xt
+   L 💣 github.com/google/nftables/binaryutil                        from github.com/google/nftables+
+   L    github.com/google/nftables/expr                              from github.com/google/nftables+
+   L    github.com/google/nftables/internal/parseexprfunc            from github.com/google/nftables+
+   L    github.com/google/nftables/xt                                from github.com/google/nftables/expr+
         github.com/hdevalence/ed25519consensus                       from tailscale.com/tka
    L    github.com/josharian/native                                  from github.com/mdlayher/netlink+
    L 💣 github.com/jsimonetti/rtnetlink                              from tailscale.com/net/interfaces+
@@ -23,6 +30,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
         github.com/matttproud/golang_protobuf_extensions/pbutil      from github.com/prometheus/common/expfmt
    L 💣 github.com/mdlayher/netlink                                  from github.com/jsimonetti/rtnetlink+
    L 💣 github.com/mdlayher/netlink/nlenc                            from github.com/jsimonetti/rtnetlink+
+   L    github.com/mdlayher/netlink/nltest                           from github.com/google/nftables
    L 💣 github.com/mdlayher/socket                                   from github.com/mdlayher/netlink
      💣 github.com/mitchellh/go-ps                                   from tailscale.com/safesocket
      💣 github.com/prometheus/client_golang/prometheus               from tailscale.com/tsweb/promvarz
@@ -34,6 +42,9 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
   LD    github.com/prometheus/procfs                                 from github.com/prometheus/client_golang/prometheus
   LD    github.com/prometheus/procfs/internal/fs                     from github.com/prometheus/procfs
   LD    github.com/prometheus/procfs/internal/util                   from github.com/prometheus/procfs
+   L 💣 github.com/tailscale/netlink                                 from tailscale.com/util/linuxfw
+   L 💣 github.com/vishvananda/netlink/nl                            from github.com/tailscale/netlink
+   L    github.com/vishvananda/netns                                 from github.com/tailscale/netlink+
         github.com/x448/float16                                      from github.com/fxamacker/cbor/v2
      💣 go4.org/mem                                                  from tailscale.com/client/tailscale+
         go4.org/netipx                                               from tailscale.com/wgengine/filter
@@ -66,6 +77,20 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
         google.golang.org/protobuf/runtime/protoimpl                 from github.com/golang/protobuf/proto+
         google.golang.org/protobuf/types/descriptorpb                from google.golang.org/protobuf/reflect/protodesc
         google.golang.org/protobuf/types/known/timestamppb           from github.com/prometheus/client_golang/prometheus+
+   L    gvisor.dev/gvisor/pkg/abi                                    from gvisor.dev/gvisor/pkg/abi/linux
+   L 💣 gvisor.dev/gvisor/pkg/abi/linux                              from tailscale.com/util/linuxfw
+   L    gvisor.dev/gvisor/pkg/bits                                   from gvisor.dev/gvisor/pkg/abi/linux
+   L    gvisor.dev/gvisor/pkg/context                                from gvisor.dev/gvisor/pkg/abi/linux
+   L 💣 gvisor.dev/gvisor/pkg/gohacks                                from gvisor.dev/gvisor/pkg/abi/linux+
+   L 💣 gvisor.dev/gvisor/pkg/hostarch                               from gvisor.dev/gvisor/pkg/abi/linux+
+   L    gvisor.dev/gvisor/pkg/linewriter                             from gvisor.dev/gvisor/pkg/log
+   L    gvisor.dev/gvisor/pkg/log                                    from gvisor.dev/gvisor/pkg/context
+   L    gvisor.dev/gvisor/pkg/marshal                                from gvisor.dev/gvisor/pkg/abi/linux+
+   L 💣 gvisor.dev/gvisor/pkg/marshal/primitive                      from gvisor.dev/gvisor/pkg/abi/linux
+   L 💣 gvisor.dev/gvisor/pkg/state                                  from gvisor.dev/gvisor/pkg/abi/linux+
+   L    gvisor.dev/gvisor/pkg/state/wire                             from gvisor.dev/gvisor/pkg/state
+   L 💣 gvisor.dev/gvisor/pkg/sync                                   from gvisor.dev/gvisor/pkg/linewriter+
+   L    gvisor.dev/gvisor/pkg/waiter                                 from gvisor.dev/gvisor/pkg/context
         nhooyr.io/websocket                                          from tailscale.com/cmd/derper+
         nhooyr.io/websocket/internal/errd                            from nhooyr.io/websocket
         nhooyr.io/websocket/internal/xsync                           from nhooyr.io/websocket
@@ -130,8 +155,9 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
         tailscale.com/util/dnsname                                   from tailscale.com/hostinfo+
         tailscale.com/util/httpm                                     from tailscale.com/client/tailscale
         tailscale.com/util/lineread                                  from tailscale.com/hostinfo+
+   L 💣 tailscale.com/util/linuxfw                                   from tailscale.com/net/netns
         tailscale.com/util/mak                                       from tailscale.com/syncs+
-        tailscale.com/util/multierr                                  from tailscale.com/health
+        tailscale.com/util/multierr                                  from tailscale.com/health+
         tailscale.com/util/set                                       from tailscale.com/health+
         tailscale.com/util/singleflight                              from tailscale.com/net/dnscache
         tailscale.com/util/slicesx                                   from tailscale.com/cmd/derper+

+ 26 - 0
cmd/tailscale/depaware.txt

@@ -10,8 +10,15 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
    W 💣 github.com/alexbrainman/sspi                                 from github.com/alexbrainman/sspi/negotiate+
    W    github.com/alexbrainman/sspi/internal/common                 from github.com/alexbrainman/sspi/negotiate
    W 💣 github.com/alexbrainman/sspi/negotiate                       from tailscale.com/net/tshttpproxy
+   L    github.com/coreos/go-iptables/iptables                       from tailscale.com/util/linuxfw
         github.com/fxamacker/cbor/v2                                 from tailscale.com/tka
         github.com/golang/groupcache/lru                             from tailscale.com/net/dnscache
+   L    github.com/google/nftables                                   from tailscale.com/util/linuxfw
+   L 💣 github.com/google/nftables/alignedbuff                       from github.com/google/nftables/xt
+   L 💣 github.com/google/nftables/binaryutil                        from github.com/google/nftables+
+   L    github.com/google/nftables/expr                              from github.com/google/nftables+
+   L    github.com/google/nftables/internal/parseexprfunc            from github.com/google/nftables+
+   L    github.com/google/nftables/xt                                from github.com/google/nftables/expr+
         github.com/google/uuid                                       from tailscale.com/util/quarantine+
         github.com/hdevalence/ed25519consensus                       from tailscale.com/tka
    L    github.com/josharian/native                                  from github.com/mdlayher/netlink+
@@ -23,6 +30,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
      💣 github.com/mattn/go-isatty                                   from github.com/mattn/go-colorable+
    L 💣 github.com/mdlayher/netlink                                  from github.com/jsimonetti/rtnetlink+
    L 💣 github.com/mdlayher/netlink/nlenc                            from github.com/jsimonetti/rtnetlink+
+   L    github.com/mdlayher/netlink/nltest                           from github.com/google/nftables
    L 💣 github.com/mdlayher/socket                                   from github.com/mdlayher/netlink
      💣 github.com/mitchellh/go-ps                                   from tailscale.com/cmd/tailscale/cli+
         github.com/peterbourgon/ff/v3                                from github.com/peterbourgon/ff/v3/ffcli
@@ -36,13 +44,30 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
         github.com/tailscale/goupnp/scpd                             from github.com/tailscale/goupnp
         github.com/tailscale/goupnp/soap                             from github.com/tailscale/goupnp+
         github.com/tailscale/goupnp/ssdp                             from github.com/tailscale/goupnp
+   L 💣 github.com/tailscale/netlink                                 from tailscale.com/util/linuxfw
         github.com/tcnksm/go-httpstat                                from tailscale.com/net/netcheck
         github.com/toqueteos/webbrowser                              from tailscale.com/cmd/tailscale/cli
+   L 💣 github.com/vishvananda/netlink/nl                            from github.com/tailscale/netlink
+   L    github.com/vishvananda/netns                                 from github.com/tailscale/netlink+
         github.com/x448/float16                                      from github.com/fxamacker/cbor/v2
      💣 go4.org/mem                                                  from tailscale.com/derp+
         go4.org/netipx                                               from tailscale.com/wgengine/filter
    W 💣 golang.zx2c4.com/wireguard/windows/tunnel/winipcfg           from tailscale.com/net/interfaces+
         gopkg.in/yaml.v2                                             from sigs.k8s.io/yaml
+   L    gvisor.dev/gvisor/pkg/abi                                    from gvisor.dev/gvisor/pkg/abi/linux
+   L 💣 gvisor.dev/gvisor/pkg/abi/linux                              from tailscale.com/util/linuxfw
+   L    gvisor.dev/gvisor/pkg/bits                                   from gvisor.dev/gvisor/pkg/abi/linux
+   L    gvisor.dev/gvisor/pkg/context                                from gvisor.dev/gvisor/pkg/abi/linux
+   L 💣 gvisor.dev/gvisor/pkg/gohacks                                from gvisor.dev/gvisor/pkg/abi/linux+
+   L 💣 gvisor.dev/gvisor/pkg/hostarch                               from gvisor.dev/gvisor/pkg/abi/linux+
+   L    gvisor.dev/gvisor/pkg/linewriter                             from gvisor.dev/gvisor/pkg/log
+   L    gvisor.dev/gvisor/pkg/log                                    from gvisor.dev/gvisor/pkg/context
+   L    gvisor.dev/gvisor/pkg/marshal                                from gvisor.dev/gvisor/pkg/abi/linux+
+   L 💣 gvisor.dev/gvisor/pkg/marshal/primitive                      from gvisor.dev/gvisor/pkg/abi/linux
+   L 💣 gvisor.dev/gvisor/pkg/state                                  from gvisor.dev/gvisor/pkg/abi/linux+
+   L    gvisor.dev/gvisor/pkg/state/wire                             from gvisor.dev/gvisor/pkg/state
+   L 💣 gvisor.dev/gvisor/pkg/sync                                   from gvisor.dev/gvisor/pkg/linewriter+
+   L    gvisor.dev/gvisor/pkg/waiter                                 from gvisor.dev/gvisor/pkg/context
         k8s.io/client-go/util/homedir                                from tailscale.com/cmd/tailscale/cli
         nhooyr.io/websocket                                          from tailscale.com/derp/derphttp+
         nhooyr.io/websocket/internal/errd                            from nhooyr.io/websocket
@@ -120,6 +145,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
         tailscale.com/util/groupmember                               from tailscale.com/cmd/tailscale/cli
         tailscale.com/util/httpm                                     from tailscale.com/client/tailscale
         tailscale.com/util/lineread                                  from tailscale.com/net/interfaces+
+   L 💣 tailscale.com/util/linuxfw                                   from tailscale.com/net/netns
         tailscale.com/util/mak                                       from tailscale.com/net/netcheck+
         tailscale.com/util/multierr                                  from tailscale.com/control/controlhttp+
         tailscale.com/util/must                                      from tailscale.com/cmd/tailscale/cli

+ 16 - 3
cmd/tailscaled/depaware.txt

@@ -75,7 +75,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
    L    github.com/aws/smithy-go/transport/http                      from github.com/aws/aws-sdk-go-v2/aws/middleware+
    L    github.com/aws/smithy-go/transport/http/internal/io          from github.com/aws/smithy-go/transport/http
    L    github.com/aws/smithy-go/waiter                              from github.com/aws/aws-sdk-go-v2/service/ssm
-   L    github.com/coreos/go-iptables/iptables                       from tailscale.com/wgengine/router
+   L    github.com/coreos/go-iptables/iptables                       from tailscale.com/util/linuxfw
   LD 💣 github.com/creack/pty                                        from tailscale.com/ssh/tailssh
    W 💣 github.com/dblohm7/wingoes                                   from github.com/dblohm7/wingoes/com
    W 💣 github.com/dblohm7/wingoes/com                               from tailscale.com/cmd/tailscaled
@@ -86,6 +86,12 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
    L 💣 github.com/godbus/dbus/v5                                    from tailscale.com/net/dns+
         github.com/golang/groupcache/lru                             from tailscale.com/net/dnscache
         github.com/google/btree                                      from gvisor.dev/gvisor/pkg/tcpip/header+
+   L    github.com/google/nftables                                   from tailscale.com/util/linuxfw
+   L 💣 github.com/google/nftables/alignedbuff                       from github.com/google/nftables/xt
+   L 💣 github.com/google/nftables/binaryutil                        from github.com/google/nftables+
+   L    github.com/google/nftables/expr                              from github.com/google/nftables+
+   L    github.com/google/nftables/internal/parseexprfunc            from github.com/google/nftables+
+   L    github.com/google/nftables/xt                                from github.com/google/nftables/expr+
         github.com/hdevalence/ed25519consensus                       from tailscale.com/tka
    L 💣 github.com/illarion/gonotify                                 from tailscale.com/net/dns
    L    github.com/insomniacslk/dhcp/dhcpv4                          from tailscale.com/net/tstun
@@ -109,6 +115,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
    L    github.com/mdlayher/genetlink                                from tailscale.com/net/tstun
    L 💣 github.com/mdlayher/netlink                                  from github.com/jsimonetti/rtnetlink+
    L 💣 github.com/mdlayher/netlink/nlenc                            from github.com/jsimonetti/rtnetlink+
+   L    github.com/mdlayher/netlink/nltest                           from github.com/google/nftables
    L    github.com/mdlayher/sdnotify                                 from tailscale.com/util/systemd
    L 💣 github.com/mdlayher/socket                                   from github.com/mdlayher/netlink
      💣 github.com/mitchellh/go-ps                                   from tailscale.com/safesocket
@@ -153,13 +160,18 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
         go4.org/netipx                                               from tailscale.com/ipn/ipnlocal+
    W 💣 golang.zx2c4.com/wintun                                      from github.com/tailscale/wireguard-go/tun+
    W 💣 golang.zx2c4.com/wireguard/windows/tunnel/winipcfg           from tailscale.com/net/dns+
+   L    gvisor.dev/gvisor/pkg/abi                                    from gvisor.dev/gvisor/pkg/abi/linux
+   L 💣 gvisor.dev/gvisor/pkg/abi/linux                              from tailscale.com/util/linuxfw
         gvisor.dev/gvisor/pkg/atomicbitops                           from gvisor.dev/gvisor/pkg/tcpip+
-        gvisor.dev/gvisor/pkg/bits                                   from gvisor.dev/gvisor/pkg/bufferv2
+        gvisor.dev/gvisor/pkg/bits                                   from gvisor.dev/gvisor/pkg/bufferv2+
      💣 gvisor.dev/gvisor/pkg/bufferv2                               from gvisor.dev/gvisor/pkg/tcpip+
-        gvisor.dev/gvisor/pkg/context                                from gvisor.dev/gvisor/pkg/refs
+        gvisor.dev/gvisor/pkg/context                                from gvisor.dev/gvisor/pkg/refs+
      💣 gvisor.dev/gvisor/pkg/gohacks                                from gvisor.dev/gvisor/pkg/state/wire+
+   L 💣 gvisor.dev/gvisor/pkg/hostarch                               from gvisor.dev/gvisor/pkg/abi/linux+
         gvisor.dev/gvisor/pkg/linewriter                             from gvisor.dev/gvisor/pkg/log
         gvisor.dev/gvisor/pkg/log                                    from gvisor.dev/gvisor/pkg/context+
+   L    gvisor.dev/gvisor/pkg/marshal                                from gvisor.dev/gvisor/pkg/abi/linux+
+   L 💣 gvisor.dev/gvisor/pkg/marshal/primitive                      from gvisor.dev/gvisor/pkg/abi/linux
         gvisor.dev/gvisor/pkg/rand                                   from gvisor.dev/gvisor/pkg/tcpip/network/hash+
         gvisor.dev/gvisor/pkg/refs                                   from gvisor.dev/gvisor/pkg/bufferv2+
      💣 gvisor.dev/gvisor/pkg/sleep                                  from gvisor.dev/gvisor/pkg/tcpip/transport/tcp
@@ -317,6 +329,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
      💣 tailscale.com/util/hashx                                     from tailscale.com/util/deephash
         tailscale.com/util/httpm                                     from tailscale.com/client/tailscale+
         tailscale.com/util/lineread                                  from tailscale.com/hostinfo+
+   L 💣 tailscale.com/util/linuxfw                                   from tailscale.com/net/netns+
         tailscale.com/util/mak                                       from tailscale.com/control/controlclient+
         tailscale.com/util/multierr                                  from tailscale.com/control/controlclient+
         tailscale.com/util/must                                      from tailscale.com/logpolicy

+ 2 - 9
net/netns/netns_linux.go

@@ -17,16 +17,9 @@ import (
 	"tailscale.com/net/interfaces"
 	"tailscale.com/net/netmon"
 	"tailscale.com/types/logger"
+	"tailscale.com/util/linuxfw"
 )
 
-// tailscaleBypassMark is the mark indicating that packets originating
-// from a socket should bypass Tailscale-managed routes during routing
-// table lookups.
-//
-// Keep this in sync with tailscaleBypassMark in
-// wgengine/router/router_linux.go.
-const tailscaleBypassMark = 0x80000
-
 // socketMarkWorksOnce is the sync.Once & cached value for useSocketMark.
 var socketMarkWorksOnce struct {
 	sync.Once
@@ -119,7 +112,7 @@ func controlC(network, address string, c syscall.RawConn) error {
 }
 
 func setBypassMark(fd uintptr) error {
-	if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, tailscaleBypassMark); err != nil {
+	if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, linuxfw.TailscaleBypassMarkNum); err != nil {
 		return fmt.Errorf("setting SO_MARK bypass: %w", err)
 	}
 	return nil

+ 0 - 42
net/netns/netns_linux_test.go

@@ -4,51 +4,9 @@
 package netns
 
 import (
-	"fmt"
-	"go/ast"
-	"go/parser"
-	"go/token"
 	"testing"
 )
 
-// verifies tailscaleBypassMark is in sync with wgengine.
-func TestBypassMarkInSync(t *testing.T) {
-	want := fmt.Sprintf("%q", fmt.Sprintf("0x%x", tailscaleBypassMark))
-	fset := token.NewFileSet()
-	f, err := parser.ParseFile(fset, "../../wgengine/router/router_linux.go", nil, 0)
-	if err != nil {
-		t.Fatal(err)
-	}
-	for _, decl := range f.Decls {
-		gd, ok := decl.(*ast.GenDecl)
-		if !ok || gd.Tok != token.CONST {
-			continue
-		}
-		for _, spec := range gd.Specs {
-			vs, ok := spec.(*ast.ValueSpec)
-			if !ok {
-				continue
-			}
-			for i, ident := range vs.Names {
-				if ident.Name != "tailscaleBypassMark" {
-					continue
-				}
-				valExpr := vs.Values[i]
-				lit, ok := valExpr.(*ast.BasicLit)
-				if !ok {
-					t.Errorf("tailscaleBypassMark = %T, expected *ast.BasicLit", valExpr)
-				}
-				if lit.Value == want {
-					// Pass.
-					return
-				}
-				t.Fatalf("router_linux.go's tailscaleBypassMark = %s; not in sync with netns's %s", lit.Value, want)
-			}
-		}
-	}
-	t.Errorf("tailscaleBypassMark not found in router_linux.go")
-}
-
 func TestSocketMarkWorks(t *testing.T) {
 	_ = socketMarkWorks()
 	// we cannot actually assert whether the test runner has SO_MARK available

+ 475 - 0
util/linuxfw/iptables_runner.go

@@ -0,0 +1,475 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux
+
+package linuxfw
+
+import (
+	"fmt"
+	"net/netip"
+	"strings"
+
+	"github.com/coreos/go-iptables/iptables"
+	"tailscale.com/net/tsaddr"
+	"tailscale.com/types/logger"
+	"tailscale.com/util/multierr"
+)
+
+type iptablesInterface interface {
+	// Adding this interface for testing purposes so we can mock out
+	// the iptables library, in reality this is a wrapper to *iptables.IPTables.
+	Insert(table, chain string, pos int, args ...string) error
+	Append(table, chain string, args ...string) error
+	Exists(table, chain string, args ...string) (bool, error)
+	Delete(table, chain string, args ...string) error
+	ClearChain(table, chain string) error
+	NewChain(table, chain string) error
+	DeleteChain(table, chain string) error
+}
+
+type iptablesRunner struct {
+	ipt4 iptablesInterface
+	ipt6 iptablesInterface
+
+	v6Available    bool
+	v6NATAvailable bool
+}
+
+// NewIPTablesRunner constructs a NetfilterRunner that programs iptables rules.
+// If the underlying iptables library fails to initialize, that error is
+// returned. The runner probes for IPv6 support once at initialization time and
+// if not found, no IPv6 rules will be modified for the lifetime of the runner.
+func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) {
+	ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
+	if err != nil {
+		return nil, err
+	}
+
+	supportsV6, supportsV6NAT := false, false
+	v6err := checkIPv6(logf)
+	if v6err != nil {
+		logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err)
+	} else {
+		supportsV6 = true
+		supportsV6NAT = supportsV6 && checkSupportsV6NAT()
+		logf("v6nat = %v", supportsV6NAT)
+	}
+
+	var ipt6 *iptables.IPTables
+	if supportsV6 {
+		ipt6, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
+		if err != nil {
+			return nil, err
+		}
+	}
+	return &iptablesRunner{ipt4, ipt6, supportsV6, supportsV6NAT}, nil
+}
+
+// HasIPV6 returns true if the system supports IPv6.
+func (i *iptablesRunner) HasIPV6() bool {
+	return i.v6Available
+}
+
+// HasIPV6NAT returns true if the system supports IPv6 NAT.
+func (i *iptablesRunner) HasIPV6NAT() bool {
+	return i.v6NATAvailable
+}
+
+func isErrChainNotExist(err error) bool {
+	return errCode(err) == 1
+}
+
+// getIPTByAddr returns the iptablesInterface with correct IP family
+// that we will be using for the given address.
+func (i *iptablesRunner) getIPTByAddr(addr netip.Addr) iptablesInterface {
+	nf := i.ipt4
+	if addr.Is6() {
+		nf = i.ipt6
+	}
+	return nf
+}
+
+// AddLoopbackRule adds an iptables rule to permit loopback traffic to
+// a local Tailscale IP.
+func (i *iptablesRunner) AddLoopbackRule(addr netip.Addr) error {
+	if err := i.getIPTByAddr(addr).Insert("filter", "ts-input", 1, "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil {
+		return fmt.Errorf("adding loopback allow rule for %q: %w", addr, err)
+	}
+
+	return nil
+}
+
+// tsChain returns the name of the tailscale sub-chain corresponding
+// to the given "parent" chain (e.g. INPUT, FORWARD, ...).
+func tsChain(chain string) string {
+	return "ts-" + strings.ToLower(chain)
+}
+
+// DelLoopbackRule removes the iptables rule permitting loopback
+// traffic to a Tailscale IP.
+func (i *iptablesRunner) DelLoopbackRule(addr netip.Addr) error {
+	if err := i.getIPTByAddr(addr).Delete("filter", "ts-input", "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil {
+		return fmt.Errorf("deleting loopback allow rule for %q: %w", addr, err)
+	}
+
+	return nil
+}
+
+// getTables gets the available iptablesInterface in iptables runner.
+func (i *iptablesRunner) getTables() []iptablesInterface {
+	if i.HasIPV6() {
+		return []iptablesInterface{i.ipt4, i.ipt6}
+	}
+	return []iptablesInterface{i.ipt4}
+}
+
+// getNATTables gets the available iptablesInterface in iptables runner.
+// If the system does not support IPv6 NAT, only the IPv4 iptablesInterface
+// is returned.
+func (i *iptablesRunner) getNATTables() []iptablesInterface {
+	if i.HasIPV6NAT() {
+		return i.getTables()
+	}
+	return []iptablesInterface{i.ipt4}
+}
+
+// AddHooks inserts calls to tailscale's netfilter chains in
+// the relevant main netfilter chains. The tailscale chains must
+// already exist. If they do not, an error is returned.
+func (i *iptablesRunner) AddHooks() error {
+	// divert inserts a jump to the tailscale chain in the given table/chain.
+	// If the jump already exists, it is a no-op.
+	divert := func(ipt iptablesInterface, table, chain string) error {
+		tsChain := tsChain(chain)
+
+		args := []string{"-j", tsChain}
+		exists, err := ipt.Exists(table, chain, args...)
+		if err != nil {
+			return fmt.Errorf("checking for %v in %s/%s: %w", args, table, chain, err)
+		}
+		if exists {
+			return nil
+		}
+		if err := ipt.Insert(table, chain, 1, args...); err != nil {
+			return fmt.Errorf("adding %v in %s/%s: %w", args, table, chain, err)
+		}
+		return nil
+	}
+
+	for _, ipt := range i.getTables() {
+		if err := divert(ipt, "filter", "INPUT"); err != nil {
+			return err
+		}
+		if err := divert(ipt, "filter", "FORWARD"); err != nil {
+			return err
+		}
+	}
+
+	for _, ipt := range i.getNATTables() {
+		if err := divert(ipt, "nat", "POSTROUTING"); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// AddChains creates custom Tailscale chains in netfilter via iptables
+// if the ts-chain doesn't already exist.
+func (i *iptablesRunner) AddChains() error {
+	// create creates a chain in the given table if it doesn't already exist.
+	// If the chain already exists, it is a no-op.
+	create := func(ipt iptablesInterface, table, chain string) error {
+		err := ipt.ClearChain(table, chain)
+		if isErrChainNotExist(err) {
+			// nonexistent chain. let's create it!
+			return ipt.NewChain(table, chain)
+		}
+		if err != nil {
+			return fmt.Errorf("setting up %s/%s: %w", table, chain, err)
+		}
+		return nil
+	}
+
+	for _, ipt := range i.getTables() {
+		if err := create(ipt, "filter", "ts-input"); err != nil {
+			return err
+		}
+		if err := create(ipt, "filter", "ts-forward"); err != nil {
+			return err
+		}
+	}
+
+	for _, ipt := range i.getNATTables() {
+		if err := create(ipt, "nat", "ts-postrouting"); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// AddBase adds some basic processing rules to be supplemented by
+// later calls to other helpers.
+func (i *iptablesRunner) AddBase(tunname string) error {
+	if err := i.addBase4(tunname); err != nil {
+		return err
+	}
+	if i.HasIPV6() {
+		if err := i.addBase6(tunname); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// addBase4 adds some basic IPv6 processing rules to be
+// supplemented by later calls to other helpers.
+func (i *iptablesRunner) addBase4(tunname string) error {
+	// Only allow CGNAT range traffic to come from tailscale0. There
+	// is an exception carved out for ranges used by ChromeOS, for
+	// which we fall out of the Tailscale chain.
+	//
+	// Note, this will definitely break nodes that end up using the
+	// CGNAT range for other purposes :(.
+	args := []string{"!", "-i", tunname, "-s", tsaddr.ChromeOSVMRange().String(), "-j", "RETURN"}
+	if err := i.ipt4.Append("filter", "ts-input", args...); err != nil {
+		return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err)
+	}
+	args = []string{"!", "-i", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"}
+	if err := i.ipt4.Append("filter", "ts-input", args...); err != nil {
+		return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err)
+	}
+
+	// Forward all traffic from the Tailscale interface, and drop
+	// traffic to the tailscale interface by default. We use packet
+	// marks here so both filter/FORWARD and nat/POSTROUTING can match
+	// on these packets of interest.
+	//
+	// In particular, we only want to apply SNAT rules in
+	// nat/POSTROUTING to packets that originated from the Tailscale
+	// interface, but we can't match on the inbound interface in
+	// POSTROUTING. So instead, we match on the inbound interface in
+	// filter/FORWARD, and set a packet mark that nat/POSTROUTING can
+	// use to effectively run that same test again.
+	args = []string{"-i", tunname, "-j", "MARK", "--set-mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask}
+	if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil {
+		return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err)
+	}
+	args = []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "ACCEPT"}
+	if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil {
+		return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err)
+	}
+	args = []string{"-o", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"}
+	if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil {
+		return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err)
+	}
+	args = []string{"-o", tunname, "-j", "ACCEPT"}
+	if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil {
+		return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err)
+	}
+
+	return nil
+}
+
+// addBase6 adds some basic IPv4 processing rules to be
+// supplemented by later calls to other helpers.
+func (i *iptablesRunner) addBase6(tunname string) error {
+	// TODO: only allow traffic from Tailscale's ULA range to come
+	// from tailscale0.
+
+	args := []string{"-i", tunname, "-j", "MARK", "--set-mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask}
+	if err := i.ipt6.Append("filter", "ts-forward", args...); err != nil {
+		return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err)
+	}
+	args = []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "ACCEPT"}
+	if err := i.ipt6.Append("filter", "ts-forward", args...); err != nil {
+		return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err)
+	}
+	// TODO: drop forwarded traffic to tailscale0 from tailscale's ULA
+	// (see corresponding IPv4 CGNAT rule).
+	args = []string{"-o", tunname, "-j", "ACCEPT"}
+	if err := i.ipt6.Append("filter", "ts-forward", args...); err != nil {
+		return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err)
+	}
+
+	return nil
+}
+
+// DelChains removes the custom Tailscale chains from netfilter via iptables.
+func (i *iptablesRunner) DelChains() error {
+	for _, ipt := range i.getTables() {
+		if err := delChain(ipt, "filter", "ts-input"); err != nil {
+			return err
+		}
+		if err := delChain(ipt, "filter", "ts-forward"); err != nil {
+			return err
+		}
+	}
+
+	for _, ipt := range i.getNATTables() {
+		if err := delChain(ipt, "nat", "ts-postrouting"); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// DelBase empties but does not remove custom Tailscale chains from
+// netfilter via iptables.
+func (i *iptablesRunner) DelBase() error {
+	del := func(ipt iptablesInterface, table, chain string) error {
+		if err := ipt.ClearChain(table, chain); err != nil {
+			if isErrChainNotExist(err) {
+				// nonexistent chain. That's fine, since it's
+				// the desired state anyway.
+				return nil
+			}
+			return fmt.Errorf("flushing %s/%s: %w", table, chain, err)
+		}
+		return nil
+	}
+
+	for _, ipt := range i.getTables() {
+		if err := del(ipt, "filter", "ts-input"); err != nil {
+			return err
+		}
+		if err := del(ipt, "filter", "ts-forward"); err != nil {
+			return err
+		}
+	}
+	for _, ipt := range i.getNATTables() {
+		if err := del(ipt, "nat", "ts-postrouting"); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// DelHooks deletes the calls to tailscale's netfilter chains
+// in the relevant main netfilter chains.
+func (i *iptablesRunner) DelHooks(logf logger.Logf) error {
+	for _, ipt := range i.getTables() {
+		if err := delTSHook(ipt, "filter", "INPUT", logf); err != nil {
+			return err
+		}
+		if err := delTSHook(ipt, "filter", "FORWARD", logf); err != nil {
+			return err
+		}
+	}
+	for _, ipt := range i.getNATTables() {
+		if err := delTSHook(ipt, "nat", "POSTROUTING", logf); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// AddSNATRule adds a netfilter rule to SNAT traffic destined for
+// local subnets.
+func (i *iptablesRunner) AddSNATRule() error {
+	args := []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "MASQUERADE"}
+	for _, ipt := range i.getNATTables() {
+		if err := ipt.Append("nat", "ts-postrouting", args...); err != nil {
+			return fmt.Errorf("adding %v in nat/ts-postrouting: %w", args, err)
+		}
+	}
+	return nil
+}
+
+// DelSNATRule removes the netfilter rule to SNAT traffic destined for
+// local subnets. An error is returned if the rule does not exist.
+func (i *iptablesRunner) DelSNATRule() error {
+	args := []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "MASQUERADE"}
+	for _, ipt := range i.getNATTables() {
+		if err := ipt.Delete("nat", "ts-postrouting", args...); err != nil {
+			return fmt.Errorf("deleting %v in nat/ts-postrouting: %w", args, err)
+		}
+	}
+	return nil
+}
+
+// IPTablesCleanup removes all Tailscale added iptables rules.
+// Any errors that occur are logged to the provided logf.
+func IPTablesCleanup(logf logger.Logf) {
+	err := clearRules(iptables.ProtocolIPv4, logf)
+	if err != nil {
+		logf("linuxfw: clear iptables: %v", err)
+	}
+
+	err = clearRules(iptables.ProtocolIPv6, logf)
+	if err != nil {
+		logf("linuxfw: clear ip6tables: %v", err)
+	}
+}
+
+// delTSHook deletes hook in a chain that jumps to a ts-chain. If the hook does not
+// exist, it's a no-op since the desired state is already achieved but we log the
+// error because error code from the iptables module resists unwrapping.
+func delTSHook(ipt iptablesInterface, table, chain string, logf logger.Logf) error {
+	tsChain := tsChain(chain)
+	args := []string{"-j", tsChain}
+	if err := ipt.Delete(table, chain, args...); err != nil {
+		// TODO(apenwarr): check for errCode(1) here.
+		// Unfortunately the error code from the iptables
+		// module resists unwrapping, unlike with other
+		// calls. So we have to assume if Delete fails,
+		// it's because there is no such rule.
+		logf("deleting %v in %s/%s: %v", args, table, chain, err)
+		return nil
+	}
+	return nil
+}
+
+// delChain flushs and deletes a chain. If the chain does not exist, it's a no-op
+// since the desired state is already achieved. otherwise, it returns an error.
+func delChain(ipt iptablesInterface, table, chain string) error {
+	if err := ipt.ClearChain(table, chain); err != nil {
+		if isErrChainNotExist(err) {
+			// nonexistent chain. nothing to do.
+			return nil
+		}
+		return fmt.Errorf("flushing %s/%s: %w", table, chain, err)
+	}
+	if err := ipt.DeleteChain(table, chain); err != nil {
+		return fmt.Errorf("deleting %s/%s: %w", table, chain, err)
+	}
+	return nil
+}
+
+// clearRules clears all the iptables rules created by Tailscale
+// for the given protocol. If error occurs, it's logged but not returned.
+func clearRules(proto iptables.Protocol, logf logger.Logf) error {
+	ipt, err := iptables.NewWithProtocol(proto)
+	if err != nil {
+		return err
+	}
+
+	var errs []error
+
+	if err := delTSHook(ipt, "filter", "INPUT", logf); err != nil {
+		errs = append(errs, err)
+	}
+	if err := delTSHook(ipt, "filter", "FORWARD", logf); err != nil {
+		errs = append(errs, err)
+	}
+	if err := delTSHook(ipt, "nat", "POSTROUTING", logf); err != nil {
+		errs = append(errs, err)
+	}
+
+	if err := delChain(ipt, "filter", "ts-input"); err != nil {
+		errs = append(errs, err)
+	}
+	if err := delChain(ipt, "filter", "ts-forward"); err != nil {
+		errs = append(errs, err)
+	}
+
+	if err := delChain(ipt, "nat", "ts-postrouting"); err != nil {
+		errs = append(errs, err)
+	}
+
+	return multierr.New(errs...)
+}

+ 420 - 0
util/linuxfw/iptables_runner_test.go

@@ -0,0 +1,420 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux
+
+package linuxfw
+
+import (
+	"errors"
+	"net/netip"
+	"strings"
+	"testing"
+
+	"tailscale.com/net/tsaddr"
+)
+
+var errExec = errors.New("execution failed")
+
+type fakeIPTables struct {
+	t *testing.T
+	n map[string][]string
+}
+
+type fakeRule struct {
+	table, chain string
+	args         []string
+}
+
+func newIPTables(t *testing.T) *fakeIPTables {
+	return &fakeIPTables{
+		t: t,
+		n: map[string][]string{
+			"filter/INPUT":    nil,
+			"filter/OUTPUT":   nil,
+			"filter/FORWARD":  nil,
+			"nat/PREROUTING":  nil,
+			"nat/OUTPUT":      nil,
+			"nat/POSTROUTING": nil,
+		},
+	}
+}
+
+func (n *fakeIPTables) Insert(table, chain string, pos int, args ...string) error {
+	k := table + "/" + chain
+	if rules, ok := n.n[k]; ok {
+		if pos > len(rules)+1 {
+			n.t.Errorf("bad position %d in %s", pos, k)
+			return errExec
+		}
+		rules = append(rules, "")
+		copy(rules[pos:], rules[pos-1:])
+		rules[pos-1] = strings.Join(args, " ")
+		n.n[k] = rules
+	} else {
+		n.t.Errorf("unknown table/chain %s", k)
+		return errExec
+	}
+	return nil
+}
+
+func (n *fakeIPTables) Append(table, chain string, args ...string) error {
+	k := table + "/" + chain
+	return n.Insert(table, chain, len(n.n[k])+1, args...)
+}
+
+func (n *fakeIPTables) Exists(table, chain string, args ...string) (bool, error) {
+	k := table + "/" + chain
+	if rules, ok := n.n[k]; ok {
+		for _, rule := range rules {
+			if rule == strings.Join(args, " ") {
+				return true, nil
+			}
+		}
+		return false, nil
+	} else {
+		n.t.Logf("unknown table/chain %s", k)
+		return false, errExec
+	}
+}
+
+func hasChain(n *fakeIPTables, table, chain string) bool {
+	k := table + "/" + chain
+	if _, ok := n.n[k]; ok {
+		return true
+	} else {
+		return false
+	}
+}
+
+func (n *fakeIPTables) Delete(table, chain string, args ...string) error {
+	k := table + "/" + chain
+	if rules, ok := n.n[k]; ok {
+		for i, rule := range rules {
+			if rule == strings.Join(args, " ") {
+				rules = append(rules[:i], rules[i+1:]...)
+				n.n[k] = rules
+				return nil
+			}
+		}
+		n.t.Errorf("delete of unknown rule %q from %s", strings.Join(args, " "), k)
+		return errExec
+	} else {
+		n.t.Errorf("unknown table/chain %s", k)
+		return errExec
+	}
+}
+
+func (n *fakeIPTables) ClearChain(table, chain string) error {
+	k := table + "/" + chain
+	if _, ok := n.n[k]; ok {
+		n.n[k] = nil
+		return nil
+	} else {
+		n.t.Logf("note: ClearChain: unknown table/chain %s", k)
+		return errors.New("exitcode:1")
+	}
+}
+
+func (n *fakeIPTables) NewChain(table, chain string) error {
+	k := table + "/" + chain
+	if _, ok := n.n[k]; ok {
+		n.t.Errorf("table/chain %s already exists", k)
+		return errExec
+	}
+	n.n[k] = nil
+	return nil
+}
+
+func (n *fakeIPTables) DeleteChain(table, chain string) error {
+	k := table + "/" + chain
+	if rules, ok := n.n[k]; ok {
+		if len(rules) != 0 {
+			n.t.Errorf("%s is not empty", k)
+			return errExec
+		}
+		delete(n.n, k)
+		return nil
+	} else {
+		n.t.Errorf("%s does not exist", k)
+		return errExec
+	}
+}
+
+func newFakeIPTablesRunner(t *testing.T) *iptablesRunner {
+	ipt4 := newIPTables(t)
+	ipt6 := newIPTables(t)
+
+	iptr := &iptablesRunner{ipt4, ipt6, true, true}
+	return iptr
+}
+
+func TestAddAndDeleteChains(t *testing.T) {
+	iptr := newFakeIPTablesRunner(t)
+	err := iptr.AddChains()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Check that the chains were created.
+	tsChains := []struct{ table, chain string }{ // table/chain
+		{"filter", "ts-input"},
+		{"filter", "ts-forward"},
+		{"nat", "ts-postrouting"},
+	}
+
+	for _, proto := range []iptablesInterface{iptr.ipt4, iptr.ipt6} {
+		for _, tc := range tsChains {
+			// Exists returns error if the chain doesn't exist.
+			if _, err := proto.Exists(tc.table, tc.chain); err != nil {
+				t.Errorf("chain %s/%s doesn't exist", tc.table, tc.chain)
+			}
+		}
+	}
+
+	err = iptr.DelChains()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Check that the chains were deleted.
+	for _, proto := range []iptablesInterface{iptr.ipt4, iptr.ipt6} {
+		for _, tc := range tsChains {
+			if _, err = proto.Exists(tc.table, tc.chain); err == nil {
+				t.Errorf("chain %s/%s still exists", tc.table, tc.chain)
+			}
+		}
+	}
+
+}
+
+func TestAddAndDeleteHooks(t *testing.T) {
+	iptr := newFakeIPTablesRunner(t)
+	// don't need to test what happens if the chains don't exist, because
+	// this is handled by fake iptables, in realife iptables would return error.
+	if err := iptr.AddChains(); err != nil {
+		t.Fatal(err)
+	}
+	defer iptr.DelChains()
+
+	if err := iptr.AddHooks(); err != nil {
+		t.Fatal(err)
+	}
+
+	// Check that the rules were created.
+	tsRules := []fakeRule{ // table/chain/rule
+		{"filter", "INPUT", []string{"-j", "ts-input"}},
+		{"filter", "FORWARD", []string{"-j", "ts-forward"}},
+		{"nat", "POSTROUTING", []string{"-j", "ts-postrouting"}},
+	}
+
+	for _, proto := range []iptablesInterface{iptr.ipt4, iptr.ipt6} {
+		for _, tr := range tsRules {
+			if exists, err := proto.Exists(tr.table, tr.chain, tr.args...); err != nil {
+				t.Fatal(err)
+			} else if !exists {
+				t.Errorf("rule %s/%s/%s doesn't exist", tr.table, tr.chain, strings.Join(tr.args, " "))
+			}
+			// check if the rule is at front of the chain
+			if proto.(*fakeIPTables).n[tr.table+"/"+tr.chain][0] != strings.Join(tr.args, " ") {
+				t.Errorf("v4 rule %s/%s/%s is not at the top", tr.table, tr.chain, strings.Join(tr.args, " "))
+			}
+		}
+	}
+
+	if err := iptr.DelHooks(t.Logf); err != nil {
+		t.Fatal(err)
+	}
+
+	// Check that the rules were deleted.
+	for _, proto := range []iptablesInterface{iptr.ipt4, iptr.ipt6} {
+		for _, tr := range tsRules {
+			if exists, err := proto.Exists(tr.table, tr.chain, tr.args...); err != nil {
+				t.Fatal(err)
+			} else if exists {
+				t.Errorf("rule %s/%s/%s still exists", tr.table, tr.chain, strings.Join(tr.args, " "))
+			}
+		}
+	}
+
+	if err := iptr.AddHooks(); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestAddAndDeleteBase(t *testing.T) {
+	iptr := newFakeIPTablesRunner(t)
+	tunname := "tun0"
+	if err := iptr.AddChains(); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := iptr.AddBase(tunname); err != nil {
+		t.Fatal(err)
+	}
+
+	// Check that the rules were created.
+	tsRulesV4 := []fakeRule{ // table/chain/rule
+		{"filter", "ts-input", []string{"!", "-i", tunname, "-s", tsaddr.ChromeOSVMRange().String(), "-j", "RETURN"}},
+		{"filter", "ts-input", []string{"!", "-i", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"}},
+		{"filter", "ts-forward", []string{"-o", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"}},
+	}
+
+	tsRulesCommon := []fakeRule{ // table/chain/rule
+		{"filter", "ts-forward", []string{"-i", tunname, "-j", "MARK", "--set-mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask}},
+		{"filter", "ts-forward", []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "ACCEPT"}},
+		{"filter", "ts-forward", []string{"-o", tunname, "-j", "ACCEPT"}},
+	}
+
+	// check that the rules were created for ipt4
+	for _, tr := range append(tsRulesV4, tsRulesCommon...) {
+		if exists, err := iptr.ipt4.Exists(tr.table, tr.chain, tr.args...); err != nil {
+			t.Fatal(err)
+		} else if !exists {
+			t.Errorf("rule %s/%s/%s doesn't exist", tr.table, tr.chain, strings.Join(tr.args, " "))
+		}
+	}
+
+	// check that the rules were created for ipt6
+	for _, tr := range tsRulesCommon {
+		if exists, err := iptr.ipt6.Exists(tr.table, tr.chain, tr.args...); err != nil {
+			t.Fatal(err)
+		} else if !exists {
+			t.Errorf("rule %s/%s/%s doesn't exist", tr.table, tr.chain, strings.Join(tr.args, " "))
+		}
+	}
+
+	if err := iptr.DelBase(); err != nil {
+		t.Fatal(err)
+	}
+
+	// Check that the rules were deleted.
+	for _, proto := range []iptablesInterface{iptr.ipt4, iptr.ipt6} {
+		for _, tr := range append(tsRulesV4, tsRulesCommon...) {
+			if exists, err := proto.Exists(tr.table, tr.chain, tr.args...); err != nil {
+				t.Fatal(err)
+			} else if exists {
+				t.Errorf("rule %s/%s/%s still exists", tr.table, tr.chain, strings.Join(tr.args, " "))
+			}
+		}
+	}
+
+	if err := iptr.DelChains(); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestAddAndDelLoopbackRule(t *testing.T) {
+	iptr := newFakeIPTablesRunner(t)
+	// We don't need to test for malformed addresses, AddLoopbackRule
+	// takes in a netip.Addr, which is already valid.
+	fakeAddrV4 := netip.MustParseAddr("192.168.0.2")
+	fakeAddrV6 := netip.MustParseAddr("2001:db8::2")
+
+	if err := iptr.AddChains(); err != nil {
+		t.Fatal(err)
+	}
+	if err := iptr.AddLoopbackRule(fakeAddrV4); err != nil {
+		t.Fatal(err)
+	}
+	if err := iptr.AddLoopbackRule(fakeAddrV6); err != nil {
+		t.Fatal(err)
+	}
+
+	// Check that the rules were created.
+	tsRulesV4 := fakeRule{ // table/chain/rule
+		"filter", "ts-input", []string{"-i", "lo", "-s", fakeAddrV4.String(), "-j", "ACCEPT"}}
+
+	tsRulesV6 := fakeRule{ // table/chain/rule
+		"filter", "ts-input", []string{"-i", "lo", "-s", fakeAddrV6.String(), "-j", "ACCEPT"}}
+
+	// check that the rules were created for ipt4 and ipt6
+	if exist, err := iptr.ipt4.Exists(tsRulesV4.table, tsRulesV4.chain, tsRulesV4.args...); err != nil {
+		t.Fatal(err)
+	} else if !exist {
+		t.Errorf("rule %s/%s/%s doesn't exist", tsRulesV4.table, tsRulesV4.chain, strings.Join(tsRulesV4.args, " "))
+	}
+	if exist, err := iptr.ipt6.Exists(tsRulesV6.table, tsRulesV6.chain, tsRulesV6.args...); err != nil {
+		t.Fatal(err)
+	} else if !exist {
+		t.Errorf("rule %s/%s/%s doesn't exist", tsRulesV6.table, tsRulesV6.chain, strings.Join(tsRulesV6.args, " "))
+	}
+
+	// check that the rule is at the top
+	chain := "filter/ts-input"
+	if iptr.ipt4.(*fakeIPTables).n[chain][0] != strings.Join(tsRulesV4.args, " ") {
+		t.Errorf("v4 rule %s/%s/%s is not at the top", tsRulesV4.table, tsRulesV4.chain, strings.Join(tsRulesV4.args, " "))
+	}
+	if iptr.ipt6.(*fakeIPTables).n[chain][0] != strings.Join(tsRulesV6.args, " ") {
+		t.Errorf("v6 rule %s/%s/%s is not at the top", tsRulesV6.table, tsRulesV6.chain, strings.Join(tsRulesV6.args, " "))
+	}
+
+	// delete the rules
+	if err := iptr.DelLoopbackRule(fakeAddrV4); err != nil {
+		t.Fatal(err)
+	}
+	if err := iptr.DelLoopbackRule(fakeAddrV6); err != nil {
+		t.Fatal(err)
+	}
+
+	// Check that the rules were deleted.
+	if exist, err := iptr.ipt4.Exists(tsRulesV4.table, tsRulesV4.chain, tsRulesV4.args...); err != nil {
+		t.Fatal(err)
+	} else if exist {
+		t.Errorf("rule %s/%s/%s still exists", tsRulesV4.table, tsRulesV4.chain, strings.Join(tsRulesV4.args, " "))
+	}
+
+	if exist, err := iptr.ipt6.Exists(tsRulesV6.table, tsRulesV6.chain, tsRulesV6.args...); err != nil {
+		t.Fatal(err)
+	} else if exist {
+		t.Errorf("rule %s/%s/%s still exists", tsRulesV6.table, tsRulesV6.chain, strings.Join(tsRulesV6.args, " "))
+	}
+
+	if err := iptr.DelChains(); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestAddAndDelSNATRule(t *testing.T) {
+	iptr := newFakeIPTablesRunner(t)
+
+	if err := iptr.AddChains(); err != nil {
+		t.Fatal(err)
+	}
+
+	rule := fakeRule{ // table/chain/rule
+		"nat", "ts-postrouting", []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "MASQUERADE"},
+	}
+
+	// Add SNAT rule
+	if err := iptr.AddSNATRule(); err != nil {
+		t.Fatal(err)
+	}
+
+	// Check that the rule was created for ipt4 and ipt6
+	for _, proto := range []iptablesInterface{iptr.ipt4, iptr.ipt6} {
+		if exist, err := proto.Exists(rule.table, rule.chain, rule.args...); err != nil {
+			t.Fatal(err)
+		} else if !exist {
+			t.Errorf("rule %s/%s/%s doesn't exist", rule.table, rule.chain, strings.Join(rule.args, " "))
+		}
+	}
+
+	// Delete SNAT rule
+	if err := iptr.DelSNATRule(); err != nil {
+		t.Fatal(err)
+	}
+
+	// Check that the rule was deleted for ipt4 and ipt6
+	for _, proto := range []iptablesInterface{iptr.ipt4, iptr.ipt6} {
+		if exist, err := proto.Exists(rule.table, rule.chain, rule.args...); err != nil {
+			t.Fatal(err)
+		} else if exist {
+			t.Errorf("rule %s/%s/%s still exists", rule.table, rule.chain, strings.Join(rule.args, " "))
+		}
+	}
+
+	if err := iptr.DelChains(); err != nil {
+		t.Fatal(err)
+	}
+}

+ 173 - 4
util/linuxfw/linuxfw.go

@@ -2,10 +2,179 @@
 // SPDX-License-Identifier: BSD-3-Clause
 
 // Package linuxfw returns the kind of firewall being used by the kernel.
+
+//go:build linux
+
 package linuxfw
 
-import "errors"
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"os"
+	"os/exec"
+	"strconv"
+	"strings"
+
+	"github.com/tailscale/netlink"
+	"tailscale.com/types/logger"
+)
+
+// The following bits are added to packet marks for Tailscale use.
+//
+// We tried to pick bits sufficiently out of the way that it's
+// unlikely to collide with existing uses. We have 4 bytes of mark
+// bits to play with. We leave the lower byte alone on the assumption
+// that sysadmins would use those. Kubernetes uses a few bits in the
+// second byte, so we steer clear of that too.
+//
+// Empirically, most of the documentation on packet marks on the
+// internet gives the impression that the marks are 16 bits
+// wide. Based on this, we theorize that the upper two bytes are
+// relatively unused in the wild, and so we consume bits 16:23 (the
+// third byte).
+//
+// The constants are in the iptables/iproute2 string format for
+// matching and setting the bits, so they can be directly embedded in
+// commands.
+const (
+	// The mask for reading/writing the 'firewall mask' bits on a packet.
+	// See the comment on the const block on why we only use the third byte.
+	//
+	// We claim bits 16:23 entirely. For now we only use the lower four
+	// bits, leaving the higher 4 bits for future use.
+	TailscaleFwmarkMask    = "0xff0000"
+	TailscaleFwmarkMaskNeg = "0xff00ffff"
+	TailscaleFwmarkMaskNum = 0xff0000
+
+	// Packet is from Tailscale and to a subnet route destination, so
+	// is allowed to be routed through this machine.
+	TailscaleSubnetRouteMark    = "0x40000"
+	TailscaleSubnetRouteMarkNum = 0x40000
+	// This one is same value but padded to even number of digit, so
+	// hex decoding can work correctly.
+	TailscaleSubnetRouteMarkHexStr = "0x040000"
+
+	// Packet was originated by tailscaled itself, and must not be
+	// routed over the Tailscale network.
+	TailscaleBypassMark    = "0x80000"
+	TailscaleBypassMarkNum = 0x80000
+)
+
+// errCode extracts and returns the process exit code from err, or
+// zero if err is nil.
+func errCode(err error) int {
+	if err == nil {
+		return 0
+	}
+	var e *exec.ExitError
+	if ok := errors.As(err, &e); ok {
+		return e.ExitCode()
+	}
+	s := err.Error()
+	if strings.HasPrefix(s, "exitcode:") {
+		code, err := strconv.Atoi(s[9:])
+		if err == nil {
+			return code
+		}
+	}
+	return -42
+}
+
+// checkIPv6 checks whether the system appears to have a working IPv6
+// network stack. It returns an error explaining what looks wrong or
+// missing.  It does not check that IPv6 is currently functional or
+// that there's a global address, just that the system would support
+// IPv6 if it were on an IPv6 network.
+func checkIPv6(logf logger.Logf) error {
+	_, err := os.Stat("/proc/sys/net/ipv6")
+	if os.IsNotExist(err) {
+		return err
+	}
+	bs, err := os.ReadFile("/proc/sys/net/ipv6/conf/all/disable_ipv6")
+	if err != nil {
+		// Be conservative if we can't find the IPv6 configuration knob.
+		return err
+	}
+	disabled, err := strconv.ParseBool(strings.TrimSpace(string(bs)))
+	if err != nil {
+		return errors.New("disable_ipv6 has invalid bool")
+	}
+	if disabled {
+		return errors.New("disable_ipv6 is set")
+	}
+
+	// Older kernels don't support IPv6 policy routing. Some kernels
+	// support policy routing but don't have this knob, so absence of
+	// the knob is not fatal.
+	bs, err = os.ReadFile("/proc/sys/net/ipv6/conf/all/disable_policy")
+	if err == nil {
+		disabled, err = strconv.ParseBool(strings.TrimSpace(string(bs)))
+		if err != nil {
+			return errors.New("disable_policy has invalid bool")
+		}
+		if disabled {
+			return errors.New("disable_policy is set")
+		}
+	}
+
+	if err := CheckIPRuleSupportsV6(logf); err != nil {
+		return fmt.Errorf("kernel doesn't support IPv6 policy routing: %w", err)
+	}
+
+	// Some distros ship ip6tables separately from iptables.
+	if _, err := exec.LookPath("ip6tables"); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// checkSupportsV6NAT returns whether the system has a "nat" table in the
+// IPv6 netfilter stack.
+//
+// The nat table was added after the initial release of ipv6
+// netfilter, so some older distros ship a kernel that can't NAT IPv6
+// traffic.
+func checkSupportsV6NAT() bool {
+	bs, err := os.ReadFile("/proc/net/ip6_tables_names")
+	if err != nil {
+		// Can't read the file. Assume SNAT works.
+		return true
+	}
+	if bytes.Contains(bs, []byte("nat\n")) {
+		return true
+	}
+	// In nftables mode, that proc file will be empty. Try another thing:
+	if exec.Command("modprobe", "ip6table_nat").Run() == nil {
+		return true
+	}
+	return false
+}
+
+func CheckIPRuleSupportsV6(logf logger.Logf) error {
+	// First try just a read-only operation to ideally avoid
+	// having to modify any state.
+	if rules, err := netlink.RuleList(netlink.FAMILY_V6); err != nil {
+		return fmt.Errorf("querying IPv6 policy routing rules: %w", err)
+	} else {
+		if len(rules) > 0 {
+			logf("[v1] kernel supports IPv6 policy routing (found %d rules)", len(rules))
+			return nil
+		}
+	}
 
-// ErrUnsupported is the error returned from all functions on non-Linux
-// platforms.
-var ErrUnsupported = errors.New("unsupported")
+	// Try to actually create & delete one as a test.
+	rule := netlink.NewRule()
+	rule.Priority = 1234
+	rule.Mark = TailscaleBypassMarkNum
+	rule.Table = 52
+	rule.Family = netlink.FAMILY_V6
+	// First delete the rule unconditionally, and don't check for
+	// errors. This is just cleaning up anything that might be already
+	// there.
+	netlink.RuleDel(rule)
+	// And clean up on exit.
+	defer netlink.RuleDel(rule)
+	return netlink.RuleAdd(rule)
+}

+ 6 - 0
util/linuxfw/linuxfw_unsupported.go

@@ -9,9 +9,15 @@
 package linuxfw
 
 import (
+	"errors"
+
 	"tailscale.com/types/logger"
 )
 
+// ErrUnsupported is the error returned from all functions on non-Linux
+// platforms.
+var ErrUnsupported = errors.New("linuxfw:unsupported")
+
 // DebugNetfilter is not supported on non-Linux platforms.
 func DebugNetfilter(logf logger.Logf) error {
 	return ErrUnsupported

+ 104 - 553
wgengine/router/router_linux.go

@@ -4,7 +4,6 @@
 package router
 
 import (
-	"bytes"
 	"errors"
 	"fmt"
 	"net"
@@ -17,7 +16,6 @@ import (
 	"syscall"
 	"time"
 
-	"github.com/coreos/go-iptables/iptables"
 	"github.com/tailscale/netlink"
 	"github.com/tailscale/wireguard-go/tun"
 	"go4.org/netipx"
@@ -25,9 +23,9 @@ import (
 	"golang.org/x/time/rate"
 	"tailscale.com/envknob"
 	"tailscale.com/net/netmon"
-	"tailscale.com/net/tsaddr"
 	"tailscale.com/types/logger"
 	"tailscale.com/types/preftype"
+	"tailscale.com/util/linuxfw"
 	"tailscale.com/util/multierr"
 	"tailscale.com/version/distro"
 )
@@ -38,56 +36,34 @@ const (
 	netfilterOn       = preftype.NetfilterOn
 )
 
-// The following bits are added to packet marks for Tailscale use.
-//
-// We tried to pick bits sufficiently out of the way that it's
-// unlikely to collide with existing uses. We have 4 bytes of mark
-// bits to play with. We leave the lower byte alone on the assumption
-// that sysadmins would use those. Kubernetes uses a few bits in the
-// second byte, so we steer clear of that too.
-//
-// Empirically, most of the documentation on packet marks on the
-// internet gives the impression that the marks are 16 bits
-// wide. Based on this, we theorize that the upper two bytes are
-// relatively unused in the wild, and so we consume bits 16:23 (the
-// third byte).
-//
-// The constants are in the iptables/iproute2 string format for
-// matching and setting the bits, so they can be directly embedded in
-// commands.
-const (
-	// The mask for reading/writing the 'firewall mask' bits on a packet.
-	// See the comment on the const block on why we only use the third byte.
-	//
-	// We claim bits 16:23 entirely. For now we only use the lower four
-	// bits, leaving the higher 4 bits for future use.
-	tailscaleFwmarkMask    = "0xff0000"
-	tailscaleFwmarkMaskNum = 0xff0000
-
-	// Packet is from Tailscale and to a subnet route destination, so
-	// is allowed to be routed through this machine.
-	tailscaleSubnetRouteMark = "0x40000"
-
-	// Packet was originated by tailscaled itself, and must not be
-	// routed over the Tailscale network.
-	//
-	// Keep this in sync with tailscaleBypassMark in
-	// net/netns/netns_linux.go.
-	tailscaleBypassMark    = "0x80000"
-	tailscaleBypassMarkNum = 0x80000
-)
-
 // netfilterRunner abstracts helpers to run netfilter commands. It
 // exists purely to swap out go-iptables for a fake implementation in
 // tests.
 type netfilterRunner interface {
-	Insert(table, chain string, pos int, args ...string) error
-	Append(table, chain string, args ...string) error
-	Exists(table, chain string, args ...string) (bool, error)
-	Delete(table, chain string, args ...string) error
-	ClearChain(table, chain string) error
-	NewChain(table, chain string) error
-	DeleteChain(table, chain string) error
+	AddLoopbackRule(addr netip.Addr) error
+	DelLoopbackRule(addr netip.Addr) error
+	AddHooks() error
+	DelHooks(logf logger.Logf) error
+	AddChains() error
+	DelChains() error
+	AddBase(tunname string) error
+	DelBase() error
+	AddSNATRule() error
+	DelSNATRule() error
+
+	HasIPV6() bool
+	HasIPV6NAT() bool
+}
+
+func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) {
+	var nfr netfilterRunner
+	var err error
+	nfr, err = linuxfw.NewIPTablesRunner(logf)
+	if err != nil {
+		return nil, err
+	}
+
+	return nfr, nil
 }
 
 type linuxRouter struct {
@@ -109,16 +85,13 @@ type linuxRouter struct {
 
 	// Various feature checks for the network stack.
 	ipRuleAvailable bool // whether kernel was built with IP_MULTIPLE_TABLES
-	v6Available     bool
-	v6NATAvailable  bool
 	fwmaskWorks     bool // whether we can use 'ip rule...fwmark <mark>/<mask>'
 
 	// ipPolicyPrefBase is the base priority at which ip rules are installed.
 	ipPolicyPrefBase int
 
-	ipt4 netfilterRunner
-	ipt6 netfilterRunner
-	cmd  commandRunner
+	nfr netfilterRunner
+	cmd commandRunner
 }
 
 func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Monitor) (Router, error) {
@@ -127,51 +100,27 @@ func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Moni
 		return nil, err
 	}
 
-	ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
+	nfr, err := newNetfilterRunner(logf)
 	if err != nil {
 		return nil, err
 	}
 
-	v6err := checkIPv6(logf)
-	if v6err != nil {
-		logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err)
-	}
-	supportsV6 := v6err == nil
-	supportsV6NAT := supportsV6 && supportsV6NAT()
-	if supportsV6 {
-		logf("v6nat = %v", supportsV6NAT)
-	}
-
-	var ipt6 netfilterRunner
-	if supportsV6 {
-		// The iptables package probes for `ip6tables` and errors out
-		// if unavailable. We want that to be a non-fatal error.
-		ipt6, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
-		if err != nil {
-			return nil, err
-		}
-	}
-
 	cmd := osCommandRunner{
 		ambientCapNetAdmin: useAmbientCaps(),
 	}
 
-	return newUserspaceRouterAdvanced(logf, tunname, netMon, ipt4, ipt6, cmd, supportsV6, supportsV6NAT)
+	return newUserspaceRouterAdvanced(logf, tunname, netMon, nfr, cmd)
 }
 
-func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, netfilter4, netfilter6 netfilterRunner, cmd commandRunner, supportsV6, supportsV6NAT bool) (Router, error) {
+func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, nfr netfilterRunner, cmd commandRunner) (Router, error) {
 	r := &linuxRouter{
 		logf:          logf,
 		tunname:       tunname,
 		netfilterMode: netfilterOff,
 		netMon:        netMon,
 
-		v6Available:    supportsV6,
-		v6NATAvailable: supportsV6NAT,
-
-		ipt4: netfilter4,
-		ipt6: netfilter6,
-		cmd:  cmd,
+		nfr: nfr,
+		cmd: cmd,
 
 		ipRuleFixLimiter: rate.NewLimiter(rate.Every(5*time.Second), 10),
 		ipPolicyPrefBase: 5200,
@@ -484,23 +433,23 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error {
 	case netfilterOff:
 		switch r.netfilterMode {
 		case netfilterNoDivert:
-			if err := r.delNetfilterBase(); err != nil {
+			if err := r.nfr.DelBase(); err != nil {
 				return err
 			}
-			if err := r.delNetfilterChains(); err != nil {
+			if err := r.nfr.DelChains(); err != nil {
 				r.logf("note: %v", err)
 				// harmless, continue.
 				// This can happen if someone left a ref to
 				// this table somewhere else.
 			}
 		case netfilterOn:
-			if err := r.delNetfilterHooks(); err != nil {
+			if err := r.nfr.DelHooks(r.logf); err != nil {
 				return err
 			}
-			if err := r.delNetfilterBase(); err != nil {
+			if err := r.nfr.DelBase(); err != nil {
 				return err
 			}
-			if err := r.delNetfilterChains(); err != nil {
+			if err := r.nfr.DelChains(); err != nil {
 				r.logf("note: %v", err)
 				// harmless, continue.
 				// This can happen if someone left a ref to
@@ -512,15 +461,15 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error {
 		switch r.netfilterMode {
 		case netfilterOff:
 			reprocess = true
-			if err := r.addNetfilterChains(); err != nil {
+			if err := r.nfr.AddChains(); err != nil {
 				return err
 			}
-			if err := r.addNetfilterBase(); err != nil {
+			if err := r.nfr.AddBase(r.tunname); err != nil {
 				return err
 			}
 			r.snatSubnetRoutes = false
 		case netfilterOn:
-			if err := r.delNetfilterHooks(); err != nil {
+			if err := r.nfr.DelHooks(r.logf); err != nil {
 				return err
 			}
 		}
@@ -529,33 +478,33 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error {
 		// we can't add a "-j ts-forward" rule to FORWARD
 		// while ts-forward contains an "-m mark" rule. But
 		// we can add the row *before* populating ts-forward.
-		// So we have to delNetFilterBase, then add the hooks,
-		// then re-addNetFilterBase, just in case.
+		// So we have to delBase, then add the hooks,
+		// then re-addBase, just in case.
 		switch r.netfilterMode {
 		case netfilterOff:
 			reprocess = true
-			if err := r.addNetfilterChains(); err != nil {
+			if err := r.nfr.AddChains(); err != nil {
 				return err
 			}
-			if err := r.delNetfilterBase(); err != nil {
+			if err := r.nfr.DelBase(); err != nil {
 				return err
 			}
-			if err := r.addNetfilterHooks(); err != nil {
+			if err := r.nfr.AddHooks(); err != nil {
 				return err
 			}
-			if err := r.addNetfilterBase(); err != nil {
+			if err := r.nfr.AddBase(r.tunname); err != nil {
 				return err
 			}
 			r.snatSubnetRoutes = false
 		case netfilterNoDivert:
 			reprocess = true
-			if err := r.delNetfilterBase(); err != nil {
+			if err := r.nfr.DelBase(); err != nil {
 				return err
 			}
-			if err := r.addNetfilterHooks(); err != nil {
+			if err := r.nfr.AddHooks(); err != nil {
 				return err
 			}
-			if err := r.addNetfilterBase(); err != nil {
+			if err := r.nfr.AddBase(r.tunname); err != nil {
 				return err
 			}
 			r.snatSubnetRoutes = false
@@ -579,11 +528,19 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error {
 	return nil
 }
 
+func (r *linuxRouter) getV6Available() bool {
+	return r.nfr.HasIPV6()
+}
+
+func (r *linuxRouter) getV6NATAvailable() bool {
+	return r.nfr.HasIPV6NAT()
+}
+
 // addAddress adds an IP/mask to the tunnel interface. Fails if the
 // address is already assigned to the interface, or if the addition
 // fails.
 func (r *linuxRouter) addAddress(addr netip.Prefix) error {
-	if !r.v6Available && addr.Addr().Is6() {
+	if !r.getV6Available() && addr.Addr().Is6() {
 		return nil
 	}
 	if r.useIPCommand() {
@@ -609,7 +566,7 @@ func (r *linuxRouter) addAddress(addr netip.Prefix) error {
 // the address is not assigned to the interface, or if the removal
 // fails.
 func (r *linuxRouter) delAddress(addr netip.Prefix) error {
-	if !r.v6Available && addr.Addr().Is6() {
+	if !r.getV6Available() && addr.Addr().Is6() {
 		return nil
 	}
 	if err := r.delLoopbackRule(addr.Addr()); err != nil {
@@ -638,17 +595,8 @@ func (r *linuxRouter) addLoopbackRule(addr netip.Addr) error {
 		return nil
 	}
 
-	nf := r.ipt4
-	if addr.Is6() {
-		if !r.v6Available {
-			// IPv6 not available, ignore.
-			return nil
-		}
-		nf = r.ipt6
-	}
-
-	if err := nf.Insert("filter", "ts-input", 1, "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil {
-		return fmt.Errorf("adding loopback allow rule for %q: %w", addr, err)
+	if err := r.nfr.AddLoopbackRule(addr); err != nil {
+		return err
 	}
 	return nil
 }
@@ -660,17 +608,8 @@ func (r *linuxRouter) delLoopbackRule(addr netip.Addr) error {
 		return nil
 	}
 
-	nf := r.ipt4
-	if addr.Is6() {
-		if !r.v6Available {
-			// IPv6 not available, ignore.
-			return nil
-		}
-		nf = r.ipt6
-	}
-
-	if err := nf.Delete("filter", "ts-input", "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil {
-		return fmt.Errorf("deleting loopback allow rule for %q: %w", addr, err)
+	if err := r.nfr.DelLoopbackRule(addr); err != nil {
+		return err
 	}
 	return nil
 }
@@ -679,7 +618,7 @@ func (r *linuxRouter) delLoopbackRule(addr netip.Addr) error {
 // interface. Fails if the route already exists, or if adding the
 // route fails.
 func (r *linuxRouter) addRoute(cidr netip.Prefix) error {
-	if !r.v6Available && cidr.Addr().Is6() {
+	if !r.getV6Available() && cidr.Addr().Is6() {
 		return nil
 	}
 	if r.useIPCommand() {
@@ -704,7 +643,7 @@ func (r *linuxRouter) addThrowRoute(cidr netip.Prefix) error {
 	if !r.ipRuleAvailable {
 		return nil
 	}
-	if !r.v6Available && cidr.Addr().Is6() {
+	if !r.getV6Available() && cidr.Addr().Is6() {
 		return nil
 	}
 	if r.useIPCommand() {
@@ -712,7 +651,7 @@ func (r *linuxRouter) addThrowRoute(cidr netip.Prefix) error {
 	}
 	err := netlink.RouteReplace(&netlink.Route{
 		Dst:   netipx.PrefixIPNet(cidr.Masked()),
-		Table: tailscaleRouteTable.num,
+		Table: tailscaleRouteTable.Num,
 		Type:  unix.RTN_THROW,
 	})
 	if err != nil {
@@ -722,7 +661,7 @@ func (r *linuxRouter) addThrowRoute(cidr netip.Prefix) error {
 }
 
 func (r *linuxRouter) addRouteDef(routeDef []string, cidr netip.Prefix) error {
-	if !r.v6Available && cidr.Addr().Is6() {
+	if !r.getV6Available() && cidr.Addr().Is6() {
 		return nil
 	}
 	args := append([]string{"ip", "route", "add"}, routeDef...)
@@ -756,7 +695,7 @@ var (
 // interface. Fails if the route doesn't exist, or if removing the
 // route fails.
 func (r *linuxRouter) delRoute(cidr netip.Prefix) error {
-	if !r.v6Available && cidr.Addr().Is6() {
+	if !r.getV6Available() && cidr.Addr().Is6() {
 		return nil
 	}
 	if r.useIPCommand() {
@@ -784,7 +723,7 @@ func (r *linuxRouter) delThrowRoute(cidr netip.Prefix) error {
 	if !r.ipRuleAvailable {
 		return nil
 	}
-	if !r.v6Available && cidr.Addr().Is6() {
+	if !r.getV6Available() && cidr.Addr().Is6() {
 		return nil
 	}
 	if r.useIPCommand() {
@@ -803,7 +742,7 @@ func (r *linuxRouter) delThrowRoute(cidr netip.Prefix) error {
 }
 
 func (r *linuxRouter) delRouteDef(routeDef []string, cidr netip.Prefix) error {
-	if !r.v6Available && cidr.Addr().Is6() {
+	if !r.getV6Available() && cidr.Addr().Is6() {
 		return nil
 	}
 	args := append([]string{"ip", "route", "del"}, routeDef...)
@@ -865,7 +804,7 @@ func (r *linuxRouter) linkIndex() (int, error) {
 // routeTable returns the route table to use.
 func (r *linuxRouter) routeTable() int {
 	if r.ipRuleAvailable {
-		return tailscaleRouteTable.num
+		return tailscaleRouteTable.Num
 	}
 	return 0
 }
@@ -962,7 +901,7 @@ func (f addrFamily) netlinkInt() int {
 }
 
 func (r *linuxRouter) addrFamilies() []addrFamily {
-	if r.v6Available {
+	if r.getV6Available() {
 		return []addrFamily{v4, v6}
 	}
 	return []addrFamily{v4}
@@ -985,30 +924,34 @@ func (r *linuxRouter) addIPRules() error {
 	return r.justAddIPRules()
 }
 
-// routeTable is a Linux routing table: both its name and number.
+// RouteTable is a Linux routing table: both its name and number.
 // See /etc/iproute2/rt_tables.
-type routeTable struct {
-	name string
-	num  int
+type RouteTable struct {
+	Name string
+	Num  int
 }
 
-// ipCmdArg returns the string form of the table to pass to the "ip" command.
-func (rt routeTable) ipCmdArg() string {
-	if rt.num >= 253 {
-		return rt.name
+var routeTableByNumber = map[int]RouteTable{}
+
+// IpCmdArg returns the string form of the table to pass to the "ip" command.
+func (rt RouteTable) ipCmdArg() string {
+	if rt.Num >= 253 {
+		return rt.Name
 	}
-	return strconv.Itoa(rt.num)
+	return strconv.Itoa(rt.Num)
 }
 
-var routeTableByNumber = map[int]routeTable{}
-
-func newRouteTable(name string, num int) routeTable {
-	rt := routeTable{name, num}
+func newRouteTable(name string, num int) RouteTable {
+	rt := RouteTable{name, num}
 	routeTableByNumber[num] = rt
 	return rt
 }
 
-func mustRouteTable(num int) routeTable {
+// MustRouteTable returns the RouteTable with the given number key.
+// It panics if the number is unknown because this result is a part
+// of IP rule argument and we don't want to continue with an invalid
+// argument with table no exist.
+func mustRouteTable(num int) RouteTable {
 	rt, ok := routeTableByNumber[num]
 	if !ok {
 		panic(fmt.Sprintf("unknown route table %v", num))
@@ -1059,22 +1002,22 @@ var ipRules = []netlink.Rule{
 	// main routing table.
 	{
 		Priority: 10,
-		Mark:     tailscaleBypassMarkNum,
-		Table:    mainRouteTable.num,
+		Mark:     linuxfw.TailscaleBypassMarkNum,
+		Table:    mainRouteTable.Num,
 	},
 	// ...and then we try the 'default' table, for correctness,
 	// even though it's been empty on every Linux system I've ever seen.
 	{
 		Priority: 30,
-		Mark:     tailscaleBypassMarkNum,
-		Table:    defaultRouteTable.num,
+		Mark:     linuxfw.TailscaleBypassMarkNum,
+		Table:    defaultRouteTable.Num,
 	},
 	// If neither of those matched (no default route on this system?)
 	// then packets from us should be aborted rather than falling through
 	// to the tailscale routes, because that would create routing loops.
 	{
 		Priority: 50,
-		Mark:     tailscaleBypassMarkNum,
+		Mark:     linuxfw.TailscaleBypassMarkNum,
 		Type:     unix.RTN_UNREACHABLE,
 	},
 	// If we get to this point, capture all packets and send them
@@ -1084,7 +1027,7 @@ var ipRules = []netlink.Rule{
 	// beat non-VPN routes.
 	{
 		Priority: 70,
-		Table:    tailscaleRouteTable.num,
+		Table:    tailscaleRouteTable.Num,
 	},
 	// If that didn't match, then non-fwmark packets fall through to the
 	// usual rules (pref 32766 and 32767, ie. main and default).
@@ -1105,7 +1048,7 @@ func (r *linuxRouter) justAddIPRules() error {
 			// Note: r is a value type here; safe to mutate it.
 			ru.Family = family.netlinkInt()
 			if ru.Mark != 0 {
-				ru.Mask = tailscaleFwmarkMaskNum
+				ru.Mask = linuxfw.TailscaleFwmarkMaskNum
 			}
 			ru.Goto = -1
 			ru.SuppressIfgroup = -1
@@ -1138,7 +1081,7 @@ func (r *linuxRouter) addIPRulesWithIPCommand() error {
 			}
 			if rule.Mark != 0 {
 				if r.fwmaskWorks {
-					args = append(args, "fwmark", fmt.Sprintf("0x%x/%s", rule.Mark, tailscaleFwmarkMask))
+					args = append(args, "fwmark", fmt.Sprintf("0x%x/%s", rule.Mark, linuxfw.TailscaleFwmarkMask))
 				} else {
 					args = append(args, "fwmark", fmt.Sprintf("0x%x", rule.Mark))
 				}
@@ -1239,284 +1182,6 @@ func (r *linuxRouter) delIPRulesWithIPCommand() error {
 	return rg.ErrAcc
 }
 
-func (r *linuxRouter) netfilterFamilies() []netfilterRunner {
-	if r.v6Available {
-		return []netfilterRunner{r.ipt4, r.ipt6}
-	}
-	return []netfilterRunner{r.ipt4}
-}
-
-// addNetfilterChains creates custom Tailscale chains in netfilter.
-func (r *linuxRouter) addNetfilterChains() error {
-	create := func(ipt netfilterRunner, table, chain string) error {
-		err := ipt.ClearChain(table, chain)
-		if errCode(err) == 1 {
-			// nonexistent chain. let's create it!
-			return ipt.NewChain(table, chain)
-		}
-		if err != nil {
-			return fmt.Errorf("setting up %s/%s: %w", table, chain, err)
-		}
-		return nil
-	}
-
-	for _, ipt := range r.netfilterFamilies() {
-		if err := create(ipt, "filter", "ts-input"); err != nil {
-			return err
-		}
-		if err := create(ipt, "filter", "ts-forward"); err != nil {
-			return err
-		}
-	}
-	if err := create(r.ipt4, "nat", "ts-postrouting"); err != nil {
-		return err
-	}
-	if r.v6NATAvailable {
-		if err := create(r.ipt6, "nat", "ts-postrouting"); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-// addNetfilterBase adds some basic processing rules to be
-// supplemented by later calls to other helpers.
-func (r *linuxRouter) addNetfilterBase() error {
-	if err := r.addNetfilterBase4(); err != nil {
-		return err
-	}
-	if r.v6Available {
-		if err := r.addNetfilterBase6(); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-// addNetfilterBase4 adds some basic IPv4 processing rules to be
-// supplemented by later calls to other helpers.
-func (r *linuxRouter) addNetfilterBase4() error {
-	// Only allow CGNAT range traffic to come from tailscale0. There
-	// is an exception carved out for ranges used by ChromeOS, for
-	// which we fall out of the Tailscale chain.
-	//
-	// Note, this will definitely break nodes that end up using the
-	// CGNAT range for other purposes :(.
-	args := []string{"!", "-i", r.tunname, "-s", tsaddr.ChromeOSVMRange().String(), "-j", "RETURN"}
-	if err := r.ipt4.Append("filter", "ts-input", args...); err != nil {
-		return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err)
-	}
-	args = []string{"!", "-i", r.tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"}
-	if err := r.ipt4.Append("filter", "ts-input", args...); err != nil {
-		return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err)
-	}
-
-	// Forward all traffic from the Tailscale interface, and drop
-	// traffic to the tailscale interface by default. We use packet
-	// marks here so both filter/FORWARD and nat/POSTROUTING can match
-	// on these packets of interest.
-	//
-	// In particular, we only want to apply SNAT rules in
-	// nat/POSTROUTING to packets that originated from the Tailscale
-	// interface, but we can't match on the inbound interface in
-	// POSTROUTING. So instead, we match on the inbound interface in
-	// filter/FORWARD, and set a packet mark that nat/POSTROUTING can
-	// use to effectively run that same test again.
-	args = []string{"-i", r.tunname, "-j", "MARK", "--set-mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask}
-	if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil {
-		return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err)
-	}
-	args = []string{"-m", "mark", "--mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask, "-j", "ACCEPT"}
-	if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil {
-		return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err)
-	}
-	args = []string{"-o", r.tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"}
-	if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil {
-		return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err)
-	}
-	args = []string{"-o", r.tunname, "-j", "ACCEPT"}
-	if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil {
-		return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err)
-	}
-
-	return nil
-}
-
-// addNetfilterBase4 adds some basic IPv6 processing rules to be
-// supplemented by later calls to other helpers.
-func (r *linuxRouter) addNetfilterBase6() error {
-	// TODO: only allow traffic from Tailscale's ULA range to come
-	// from tailscale0.
-
-	args := []string{"-i", r.tunname, "-j", "MARK", "--set-mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask}
-	if err := r.ipt6.Append("filter", "ts-forward", args...); err != nil {
-		return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err)
-	}
-	args = []string{"-m", "mark", "--mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask, "-j", "ACCEPT"}
-	if err := r.ipt6.Append("filter", "ts-forward", args...); err != nil {
-		return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err)
-	}
-	// TODO: drop forwarded traffic to tailscale0 from tailscale's ULA
-	// (see corresponding IPv4 CGNAT rule).
-	args = []string{"-o", r.tunname, "-j", "ACCEPT"}
-	if err := r.ipt6.Append("filter", "ts-forward", args...); err != nil {
-		return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err)
-	}
-
-	return nil
-}
-
-// delNetfilterChains removes the custom Tailscale chains from netfilter.
-func (r *linuxRouter) delNetfilterChains() error {
-	del := func(ipt netfilterRunner, table, chain string) error {
-		if err := ipt.ClearChain(table, chain); err != nil {
-			if errCode(err) == 1 {
-				// nonexistent chain. That's fine, since it's
-				// the desired state anyway.
-				return nil
-			}
-			return fmt.Errorf("flushing %s/%s: %w", table, chain, err)
-		}
-		if err := ipt.DeleteChain(table, chain); err != nil {
-			// this shouldn't fail, because if the chain didn't
-			// exist, we would have returned after ClearChain.
-			return fmt.Errorf("deleting %s/%s: %v", table, chain, err)
-		}
-		return nil
-	}
-
-	for _, ipt := range r.netfilterFamilies() {
-		if err := del(ipt, "filter", "ts-input"); err != nil {
-			return err
-		}
-		if err := del(ipt, "filter", "ts-forward"); err != nil {
-			return err
-		}
-	}
-	if err := del(r.ipt4, "nat", "ts-postrouting"); err != nil {
-		return err
-	}
-	if r.v6NATAvailable {
-		if err := del(r.ipt6, "nat", "ts-postrouting"); err != nil {
-			return err
-		}
-	}
-
-	return nil
-}
-
-// delNetfilterBase empties but does not remove custom Tailscale chains from
-// netfilter.
-func (r *linuxRouter) delNetfilterBase() error {
-	del := func(ipt netfilterRunner, table, chain string) error {
-		if err := ipt.ClearChain(table, chain); err != nil {
-			if errCode(err) == 1 {
-				// nonexistent chain. That's fine, since it's
-				// the desired state anyway.
-				return nil
-			}
-			return fmt.Errorf("flushing %s/%s: %w", table, chain, err)
-		}
-		return nil
-	}
-
-	for _, ipt := range r.netfilterFamilies() {
-		if err := del(ipt, "filter", "ts-input"); err != nil {
-			return err
-		}
-		if err := del(ipt, "filter", "ts-forward"); err != nil {
-			return err
-		}
-	}
-	if err := del(r.ipt4, "nat", "ts-postrouting"); err != nil {
-		return err
-	}
-	if r.v6NATAvailable {
-		if err := del(r.ipt6, "nat", "ts-postrouting"); err != nil {
-			return err
-		}
-	}
-
-	return nil
-}
-
-// addNetfilterHooks inserts calls to tailscale's netfilter chains in
-// the relevant main netfilter chains. The tailscale chains must
-// already exist.
-func (r *linuxRouter) addNetfilterHooks() error {
-	divert := func(ipt netfilterRunner, table, chain string) error {
-		tsChain := tsChain(chain)
-
-		args := []string{"-j", tsChain}
-		exists, err := ipt.Exists(table, chain, args...)
-		if err != nil {
-			return fmt.Errorf("checking for %v in %s/%s: %w", args, table, chain, err)
-		}
-		if exists {
-			return nil
-		}
-		if err := ipt.Insert(table, chain, 1, args...); err != nil {
-			return fmt.Errorf("adding %v in %s/%s: %w", args, table, chain, err)
-		}
-		return nil
-	}
-
-	for _, ipt := range r.netfilterFamilies() {
-		if err := divert(ipt, "filter", "INPUT"); err != nil {
-			return err
-		}
-		if err := divert(ipt, "filter", "FORWARD"); err != nil {
-			return err
-		}
-	}
-	if err := divert(r.ipt4, "nat", "POSTROUTING"); err != nil {
-		return err
-	}
-	if r.v6NATAvailable {
-		if err := divert(r.ipt6, "nat", "POSTROUTING"); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-// delNetfilterHooks deletes the calls to tailscale's netfilter chains
-// in the relevant main netfilter chains.
-func (r *linuxRouter) delNetfilterHooks() error {
-	del := func(ipt netfilterRunner, table, chain string) error {
-		tsChain := tsChain(chain)
-		args := []string{"-j", tsChain}
-		if err := ipt.Delete(table, chain, args...); err != nil {
-			// TODO(apenwarr): check for errCode(1) here.
-			// Unfortunately the error code from the iptables
-			// module resists unwrapping, unlike with other
-			// calls. So we have to assume if Delete fails,
-			// it's because there is no such rule.
-			r.logf("note: deleting %v in %s/%s: %w", args, table, chain, err)
-			return nil
-		}
-		return nil
-	}
-
-	for _, ipt := range r.netfilterFamilies() {
-		if err := del(ipt, "filter", "INPUT"); err != nil {
-			return err
-		}
-		if err := del(ipt, "filter", "FORWARD"); err != nil {
-			return err
-		}
-	}
-	if err := del(r.ipt4, "nat", "POSTROUTING"); err != nil {
-		return err
-	}
-	if r.v6NATAvailable {
-		if err := del(r.ipt6, "nat", "POSTROUTING"); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
 // addSNATRule adds a netfilter rule to SNAT traffic destined for
 // local subnets.
 func (r *linuxRouter) addSNATRule() error {
@@ -1524,14 +1189,8 @@ func (r *linuxRouter) addSNATRule() error {
 		return nil
 	}
 
-	args := []string{"-m", "mark", "--mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask, "-j", "MASQUERADE"}
-	if err := r.ipt4.Append("nat", "ts-postrouting", args...); err != nil {
-		return fmt.Errorf("adding %v in v4/nat/ts-postrouting: %w", args, err)
-	}
-	if r.v6NATAvailable {
-		if err := r.ipt6.Append("nat", "ts-postrouting", args...); err != nil {
-			return fmt.Errorf("adding %v in v6/nat/ts-postrouting: %w", args, err)
-		}
+	if err := r.nfr.AddSNATRule(); err != nil {
+		return err
 	}
 	return nil
 }
@@ -1543,14 +1202,8 @@ func (r *linuxRouter) delSNATRule() error {
 		return nil
 	}
 
-	args := []string{"-m", "mark", "--mark", tailscaleSubnetRouteMark + "/" + tailscaleFwmarkMask, "-j", "MASQUERADE"}
-	if err := r.ipt4.Delete("nat", "ts-postrouting", args...); err != nil {
-		return fmt.Errorf("deleting %v in v4/nat/ts-postrouting: %w", args, err)
-	}
-	if r.v6NATAvailable {
-		if err := r.ipt6.Delete("nat", "ts-postrouting", args...); err != nil {
-			return fmt.Errorf("deleting %v in v6/nat/ts-postrouting: %w", args, err)
-		}
+	if err := r.nfr.DelSNATRule(); err != nil {
+		return err
 	}
 	return nil
 }
@@ -1619,12 +1272,6 @@ func cidrDiff(kind string, old map[netip.Prefix]bool, new []netip.Prefix, add, d
 	return ret, nil
 }
 
-// tsChain returns the name of the tailscale sub-chain corresponding
-// to the given "parent" chain (e.g. INPUT, FORWARD, ...).
-func tsChain(chain string) string {
-	return "ts-" + strings.ToLower(chain)
-}
-
 // normalizeCIDR returns cidr as an ip/mask string, with the host bits
 // of the IP address zeroed out.
 func normalizeCIDR(cidr netip.Prefix) string {
@@ -1632,105 +1279,9 @@ func normalizeCIDR(cidr netip.Prefix) string {
 }
 
 func cleanup(logf logger.Logf, interfaceName string) {
-	// TODO(dmytro): clean up iptables.
-}
-
-// checkIPv6 checks whether the system appears to have a working IPv6
-// network stack. It returns an error explaining what looks wrong or
-// missing.  It does not check that IPv6 is currently functional or
-// that there's a global address, just that the system would support
-// IPv6 if it were on an IPv6 network.
-func checkIPv6(logf logger.Logf) error {
-	_, err := os.Stat("/proc/sys/net/ipv6")
-	if os.IsNotExist(err) {
-		return err
-	}
-	bs, err := os.ReadFile("/proc/sys/net/ipv6/conf/all/disable_ipv6")
-	if err != nil {
-		// Be conservative if we can't find the IPv6 configuration knob.
-		return err
+	if interfaceName != "userspace-networking" {
+		linuxfw.IPTablesCleanup(logf)
 	}
-	disabled, err := strconv.ParseBool(strings.TrimSpace(string(bs)))
-	if err != nil {
-		return errors.New("disable_ipv6 has invalid bool")
-	}
-	if disabled {
-		return errors.New("disable_ipv6 is set")
-	}
-
-	// Older kernels don't support IPv6 policy routing. Some kernels
-	// support policy routing but don't have this knob, so absence of
-	// the knob is not fatal.
-	bs, err = os.ReadFile("/proc/sys/net/ipv6/conf/all/disable_policy")
-	if err == nil {
-		disabled, err = strconv.ParseBool(strings.TrimSpace(string(bs)))
-		if err != nil {
-			return errors.New("disable_policy has invalid bool")
-		}
-		if disabled {
-			return errors.New("disable_policy is set")
-		}
-	}
-
-	if err := checkIPRuleSupportsV6(logf); err != nil {
-		return fmt.Errorf("kernel doesn't support IPv6 policy routing: %w", err)
-	}
-
-	// Some distros ship ip6tables separately from iptables.
-	if _, err := exec.LookPath("ip6tables"); err != nil {
-		return err
-	}
-
-	return nil
-}
-
-// supportsV6NAT returns whether the system has a "nat" table in the
-// IPv6 netfilter stack.
-//
-// The nat table was added after the initial release of ipv6
-// netfilter, so some older distros ship a kernel that can't NAT IPv6
-// traffic.
-func supportsV6NAT() bool {
-	bs, err := os.ReadFile("/proc/net/ip6_tables_names")
-	if err != nil {
-		// Can't read the file. Assume SNAT works.
-		return true
-	}
-	if bytes.Contains(bs, []byte("nat\n")) {
-		return true
-	}
-	// In nftables mode, that proc file will be empty. Try another thing:
-	if exec.Command("modprobe", "ip6table_nat").Run() == nil {
-		return true
-	}
-	return false
-}
-
-func checkIPRuleSupportsV6(logf logger.Logf) error {
-	// First try just a read-only operation to ideally avoid
-	// having to modify any state.
-	if rules, err := netlink.RuleList(netlink.FAMILY_V6); err != nil {
-		return fmt.Errorf("querying IPv6 policy routing rules: %w", err)
-	} else {
-		if len(rules) > 0 {
-			logf("[v1] kernel supports IPv6 policy routing (found %d rules)", len(rules))
-			return nil
-		}
-	}
-
-	// Try to actually create & delete one as a test.
-	rule := netlink.NewRule()
-	rule.Priority = 1234
-	rule.Mark = tailscaleBypassMarkNum
-	rule.Table = tailscaleRouteTable.num
-	rule.Family = netlink.FAMILY_V6
-	// First delete the rule unconditionally, and don't check for
-	// errors. This is just cleaning up anything that might be already
-	// there.
-	netlink.RuleDel(rule)
-	// And clean up on exit.
-	defer netlink.RuleDel(rule)
-	return netlink.RuleAdd(rule)
 }
 
 // Checks if the running openWRT system is using mwan3, based on the heuristic

+ 205 - 86
wgengine/router/router_linux_test.go

@@ -22,8 +22,10 @@ import (
 	"github.com/vishvananda/netlink"
 	"golang.org/x/exp/slices"
 	"tailscale.com/net/netmon"
+	"tailscale.com/net/tsaddr"
 	"tailscale.com/tstest"
 	"tailscale.com/types/logger"
+	"tailscale.com/util/linuxfw"
 )
 
 func TestRouterStates(t *testing.T) {
@@ -328,7 +330,7 @@ ip route add throw 192.168.0.0/24 table 52` + basic,
 	defer mon.Close()
 
 	fake := NewFakeOS(t)
-	router, err := newUserspaceRouterAdvanced(t.Logf, "tailscale0", mon, fake.netfilter4, fake.netfilter6, fake, true, true)
+	router, err := newUserspaceRouterAdvanced(t.Logf, "tailscale0", mon, fake.nfr, fake)
 	if err != nil {
 		t.Fatalf("failed to create router: %v", err)
 	}
@@ -362,15 +364,17 @@ ip route add throw 192.168.0.0/24 table 52` + basic,
 	}
 }
 
-type fakeNetfilter struct {
-	t *testing.T
-	n map[string][]string
+type fakeIPTablesRunner struct {
+	t    *testing.T
+	ipt4 map[string][]string
+	ipt6 map[string][]string
+	//we always assume ipv6 and ipv6 nat are enabled when testing
 }
 
-func newNetfilter(t *testing.T) *fakeNetfilter {
-	return &fakeNetfilter{
+func newIPTablesRunner(t *testing.T) netfilterRunner {
+	return &fakeIPTablesRunner{
 		t: t,
-		n: map[string][]string{
+		ipt4: map[string][]string{
 			"filter/INPUT":    nil,
 			"filter/OUTPUT":   nil,
 			"filter/FORWARD":  nil,
@@ -378,118 +382,233 @@ func newNetfilter(t *testing.T) *fakeNetfilter {
 			"nat/OUTPUT":      nil,
 			"nat/POSTROUTING": nil,
 		},
+		ipt6: map[string][]string{
+			"filter/INPUT":    nil,
+			"filter/OUTPUT":   nil,
+			"filter/FORWARD":  nil,
+			"nat/PREROUTING":  nil,
+			"nat/OUTPUT":      nil,
+			"nat/POSTROUTING": nil,
+		},
+	}
+}
+
+func insertRule(n *fakeIPTablesRunner, curIPT map[string][]string, chain, newRule string) error {
+	// Get current rules for filter/ts-input chain with according IP version
+	curTSInputRules, ok := curIPT[chain]
+	if !ok {
+		n.t.Fatalf("no %s chain exists", chain)
+		return fmt.Errorf("no %s chain exists", chain)
+	}
+
+	// Add new rule to top of filter/ts-input
+	curTSInputRules = append(curTSInputRules, "")
+	copy(curTSInputRules[1:], curTSInputRules)
+	curTSInputRules[0] = newRule
+	curIPT[chain] = curTSInputRules
+	return nil
+}
+
+func appendRule(n *fakeIPTablesRunner, curIPT map[string][]string, chain, newRule string) error {
+	// Get current rules for filter/ts-input chain with according IP version
+	curTSInputRules, ok := curIPT[chain]
+	if !ok {
+		n.t.Fatalf("no %s chain exists", chain)
+		return fmt.Errorf("no %s chain exists", chain)
+	}
+
+	// Add new rule to end of filter/ts-input
+	curTSInputRules = append(curTSInputRules, newRule)
+	curIPT[chain] = curTSInputRules
+	return nil
+}
+
+func deleteRule(n *fakeIPTablesRunner, curIPT map[string][]string, chain, delRule string) error {
+	// Get current rules for filter/ts-input chain with according IP version
+	curTSInputRules, ok := curIPT[chain]
+	if !ok {
+		n.t.Fatalf("no %s chain exists", chain)
+		return fmt.Errorf("no %s chain exists", chain)
+	}
+
+	// Remove rule from filter/ts-input
+	for i, rule := range curTSInputRules {
+		if rule == delRule {
+			curTSInputRules = append(curTSInputRules[:i], curTSInputRules[i+1:]...)
+			break
+		}
+	}
+	curIPT[chain] = curTSInputRules
+	return nil
+}
+
+func (n *fakeIPTablesRunner) AddLoopbackRule(addr netip.Addr) error {
+	curIPT := n.ipt4
+	if addr.Is6() {
+		curIPT = n.ipt6
 	}
+	newRule := fmt.Sprintf("-i lo -s %s -j ACCEPT", addr.String())
+
+	return insertRule(n, curIPT, "filter/ts-input", newRule)
 }
 
-func (n *fakeNetfilter) Insert(table, chain string, pos int, args ...string) error {
-	k := table + "/" + chain
-	if rules, ok := n.n[k]; ok {
-		if pos > len(rules)+1 {
-			n.t.Errorf("bad position %d in %s", pos, k)
-			return errExec
+func (n *fakeIPTablesRunner) AddBase(tunname string) error {
+	if err := n.AddBase4(tunname); err != nil {
+		return err
+	}
+	if n.HasIPV6() {
+		if err := n.AddBase6(tunname); err != nil {
+			return err
 		}
-		rules = append(rules, "")
-		copy(rules[pos:], rules[pos-1:])
-		rules[pos-1] = strings.Join(args, " ")
-		n.n[k] = rules
-	} else {
-		n.t.Errorf("unknown table/chain %s", k)
-		return errExec
 	}
 	return nil
 }
 
-func (n *fakeNetfilter) Append(table, chain string, args ...string) error {
-	k := table + "/" + chain
-	return n.Insert(table, chain, len(n.n[k])+1, args...)
+func (n *fakeIPTablesRunner) AddBase4(tunname string) error {
+	curIPT := n.ipt4
+	newRules := []struct{ chain, rule string }{
+		{"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j RETURN", tunname, tsaddr.ChromeOSVMRange().String())},
+		{"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())},
+		{"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)},
+		{"filter/ts-forward", fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)},
+		{"filter/ts-forward", fmt.Sprintf("-o %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())},
+		{"filter/ts-forward", fmt.Sprintf("-o %s -j ACCEPT", tunname)},
+	}
+	for _, rule := range newRules {
+		if err := appendRule(n, curIPT, rule.chain, rule.rule); err != nil {
+			return fmt.Errorf("add rule %q to chain %q: %w", rule.rule, rule.chain, err)
+		}
+	}
+	return nil
 }
 
-func (n *fakeNetfilter) Exists(table, chain string, args ...string) (bool, error) {
-	k := table + "/" + chain
-	if rules, ok := n.n[k]; ok {
-		for _, rule := range rules {
-			if rule == strings.Join(args, " ") {
-				return true, nil
+func (n *fakeIPTablesRunner) AddBase6(tunname string) error {
+	curIPT := n.ipt6
+	newRules := []struct{ chain, rule string }{
+		{"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)},
+		{"filter/ts-forward", fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)},
+		{"filter/ts-forward", fmt.Sprintf("-o %s -j ACCEPT", tunname)},
+	}
+	for _, rule := range newRules {
+		if err := appendRule(n, curIPT, rule.chain, rule.rule); err != nil {
+			return fmt.Errorf("add rule %q to chain %q: %w", rule.rule, rule.chain, err)
+		}
+	}
+	return nil
+}
+
+func (n *fakeIPTablesRunner) DelLoopbackRule(addr netip.Addr) error {
+	curIPT := n.ipt4
+	if addr.Is6() {
+		curIPT = n.ipt6
+	}
+
+	delRule := fmt.Sprintf("-i lo -s %s -j ACCEPT", addr.String())
+
+	return deleteRule(n, curIPT, "filter/ts-input", delRule)
+}
+
+func (n *fakeIPTablesRunner) AddHooks() error {
+	newRules := []struct{ chain, rule string }{
+		{"filter/INPUT", "-j ts-input"},
+		{"filter/FORWARD", "-j ts-forward"},
+		{"nat/POSTROUTING", "-j ts-postrouting"},
+	}
+	for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} {
+		for _, r := range newRules {
+			if err := insertRule(n, ipt, r.chain, r.rule); err != nil {
+				return err
 			}
 		}
-		return false, nil
-	} else {
-		n.t.Errorf("unknown table/chain %s", k)
-		return false, errExec
 	}
+	return nil
 }
 
-func (n *fakeNetfilter) Delete(table, chain string, args ...string) error {
-	k := table + "/" + chain
-	if rules, ok := n.n[k]; ok {
-		for i, rule := range rules {
-			if rule == strings.Join(args, " ") {
-				rules = append(rules[:i], rules[i+1:]...)
-				n.n[k] = rules
-				return nil
+func (n *fakeIPTablesRunner) DelHooks(logf logger.Logf) error {
+	delRules := []struct{ chain, rule string }{
+		{"filter/INPUT", "-j ts-input"},
+		{"filter/FORWARD", "-j ts-forward"},
+		{"nat/POSTROUTING", "-j ts-postrouting"},
+	}
+	for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} {
+		for _, r := range delRules {
+			if err := deleteRule(n, ipt, r.chain, r.rule); err != nil {
+				return err
 			}
 		}
-		n.t.Errorf("delete of unknown rule %q from %s", strings.Join(args, " "), k)
-		return errExec
-	} else {
-		n.t.Errorf("unknown table/chain %s", k)
-		return errExec
 	}
+	return nil
 }
 
-func (n *fakeNetfilter) ClearChain(table, chain string) error {
-	k := table + "/" + chain
-	if _, ok := n.n[k]; ok {
-		n.n[k] = nil
-		return nil
-	} else {
-		n.t.Logf("note: ClearChain: unknown table/chain %s", k)
-		return errors.New("exitcode:1")
+func (n *fakeIPTablesRunner) AddChains() error {
+	for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} {
+		for _, chain := range []string{"filter/ts-input", "filter/ts-forward", "nat/ts-postrouting"} {
+			ipt[chain] = nil
+		}
 	}
+	return nil
 }
 
-func (n *fakeNetfilter) NewChain(table, chain string) error {
-	k := table + "/" + chain
-	if _, ok := n.n[k]; ok {
-		n.t.Errorf("table/chain %s already exists", k)
-		return errExec
+func (n *fakeIPTablesRunner) DelChains() error {
+	for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} {
+		for chain := range ipt {
+			if strings.HasPrefix(chain, "filter/ts-") || strings.HasPrefix(chain, "nat/ts-") {
+				delete(ipt, chain)
+			}
+		}
 	}
-	n.n[k] = nil
 	return nil
 }
 
-func (n *fakeNetfilter) DeleteChain(table, chain string) error {
-	k := table + "/" + chain
-	if rules, ok := n.n[k]; ok {
-		if len(rules) != 0 {
-			n.t.Errorf("%s is not empty", k)
-			return errExec
+func (n *fakeIPTablesRunner) DelBase() error {
+	for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} {
+		for _, chain := range []string{"filter/ts-input", "filter/ts-forward", "nat/ts-postrouting"} {
+			ipt[chain] = nil
 		}
-		delete(n.n, k)
-		return nil
-	} else {
-		n.t.Errorf("%s does not exist", k)
-		return errExec
 	}
+	return nil
 }
 
+func (n *fakeIPTablesRunner) AddSNATRule() error {
+	newRule := fmt.Sprintf("-m mark --mark %s/%s -j MASQUERADE", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)
+	for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} {
+		if err := appendRule(n, ipt, "nat/ts-postrouting", newRule); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (n *fakeIPTablesRunner) DelSNATRule() error {
+	delRule := fmt.Sprintf("-m mark --mark %s/%s -j MASQUERADE", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)
+	for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} {
+		if err := deleteRule(n, ipt, "nat/ts-postrouting", delRule); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (n *fakeIPTablesRunner) HasIPV6() bool    { return true }
+func (n *fakeIPTablesRunner) HasIPV6NAT() bool { return true }
+
 // fakeOS implements commandRunner and provides v4 and v6
 // netfilterRunners, but captures changes without touching the OS.
 type fakeOS struct {
-	t          *testing.T
-	up         bool
-	ips        []string
-	routes     []string
-	rules      []string
-	netfilter4 *fakeNetfilter
-	netfilter6 *fakeNetfilter
+	t      *testing.T
+	up     bool
+	ips    []string
+	routes []string
+	rules  []string
+	//This test tests on the router level, so we will not bother
+	//with using iptables or nftables, chose the simpler one.
+	nfr netfilterRunner
 }
 
 func NewFakeOS(t *testing.T) *fakeOS {
 	return &fakeOS{
-		t:          t,
-		netfilter4: newNetfilter(t),
-		netfilter6: newNetfilter(t),
+		t:   t,
+		nfr: newIPTablesRunner(t),
 	}
 }
 
@@ -516,23 +635,23 @@ func (o *fakeOS) String() string {
 	}
 
 	var chains []string
-	for chain := range o.netfilter4.n {
+	for chain := range o.nfr.(*fakeIPTablesRunner).ipt4 {
 		chains = append(chains, chain)
 	}
 	sort.Strings(chains)
 	for _, chain := range chains {
-		for _, rule := range o.netfilter4.n[chain] {
+		for _, rule := range o.nfr.(*fakeIPTablesRunner).ipt4[chain] {
 			fmt.Fprintf(&b, "v4/%s %s\n", chain, rule)
 		}
 	}
 
 	chains = nil
-	for chain := range o.netfilter6.n {
+	for chain := range o.nfr.(*fakeIPTablesRunner).ipt6 {
 		chains = append(chains, chain)
 	}
 	sort.Strings(chains)
 	for _, chain := range chains {
-		for _, rule := range o.netfilter6.n[chain] {
+		for _, rule := range o.nfr.(*fakeIPTablesRunner).ipt6[chain] {
 			fmt.Fprintf(&b, "v6/%s %s\n", chain, rule)
 		}
 	}
@@ -806,7 +925,7 @@ func TestDebugListRules(t *testing.T) {
 }
 
 func TestCheckIPRuleSupportsV6(t *testing.T) {
-	err := checkIPRuleSupportsV6(t.Logf)
+	err := linuxfw.CheckIPRuleSupportsV6(t.Logf)
 	if err != nil && os.Getuid() != 0 {
 		t.Skipf("skipping, error when not root: %v", err)
 	}