Browse Source

util/linuxfw: move detection logic

Just a refactor to consolidate the firewall detection logic in a single
package so that it can be reused in a later commit by containerboot.

Updates #9310

Signed-off-by: Maisem Ali <[email protected]>
Maisem Ali 2 years ago
parent
commit
05a1f5bf71

+ 110 - 0
util/linuxfw/detector.go

@@ -0,0 +1,110 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux
+
+package linuxfw
+
+import (
+	"tailscale.com/envknob"
+	"tailscale.com/hostinfo"
+	"tailscale.com/types/logger"
+	"tailscale.com/version/distro"
+)
+
+func detectFirewallMode(logf logger.Logf) FirewallMode {
+	if distro.Get() == distro.Gokrazy {
+		// Reduce startup logging on gokrazy. There's no way to do iptables on
+		// gokrazy anyway.
+		logf("GoKrazy should use nftables.")
+		hostinfo.SetFirewallMode("nft-gokrazy")
+		return FirewallModeNfTables
+	}
+
+	envMode := envknob.String("TS_DEBUG_FIREWALL_MODE")
+	// We now use iptables as default and have "auto" and "nftables" as
+	// options for people to test further.
+	switch envMode {
+	case "auto":
+		return pickFirewallModeFromInstalledRules(logf, linuxFWDetector{})
+	case "nftables":
+		logf("envknob TS_DEBUG_FIREWALL_MODE=nftables set")
+		hostinfo.SetFirewallMode("nft-forced")
+		return FirewallModeNfTables
+	case "iptables":
+		logf("envknob TS_DEBUG_FIREWALL_MODE=iptables set")
+		hostinfo.SetFirewallMode("ipt-forced")
+	default:
+		logf("default choosing iptables")
+		hostinfo.SetFirewallMode("ipt-default")
+	}
+	return FirewallModeIPTables
+}
+
+// tableDetector abstracts helpers to detect the firewall mode.
+// It is implemented for testing purposes.
+type tableDetector interface {
+	iptDetect() (int, error)
+	nftDetect() (int, error)
+}
+
+type linuxFWDetector struct{}
+
+// iptDetect returns the number of iptables rules in the current namespace.
+func (l linuxFWDetector) iptDetect() (int, error) {
+	return detectIptables()
+}
+
+// nftDetect returns the number of nftables rules in the current namespace.
+func (l linuxFWDetector) nftDetect() (int, error) {
+	return detectNetfilter()
+}
+
+// pickFirewallModeFromInstalledRules returns the firewall mode to use based on
+// the environment and the system's capabilities.
+func pickFirewallModeFromInstalledRules(logf logger.Logf, det tableDetector) FirewallMode {
+	if distro.Get() == distro.Gokrazy {
+		// Reduce startup logging on gokrazy. There's no way to do iptables on
+		// gokrazy anyway.
+		return FirewallModeNfTables
+	}
+	iptAva, nftAva := true, true
+	iptRuleCount, err := det.iptDetect()
+	if err != nil {
+		logf("detect iptables rule: %v", err)
+		iptAva = false
+	}
+	nftRuleCount, err := det.nftDetect()
+	if err != nil {
+		logf("detect nftables rule: %v", err)
+		nftAva = false
+	}
+	logf("nftables rule count: %d, iptables rule count: %d", nftRuleCount, iptRuleCount)
+	switch {
+	case nftRuleCount > 0 && iptRuleCount == 0:
+		logf("nftables is currently in use")
+		hostinfo.SetFirewallMode("nft-inuse")
+		return FirewallModeNfTables
+	case iptRuleCount > 0 && nftRuleCount == 0:
+		logf("iptables is currently in use")
+		hostinfo.SetFirewallMode("ipt-inuse")
+		return FirewallModeIPTables
+	case nftAva:
+		// if both iptables and nftables are available but
+		// neither/both are currently used, use nftables.
+		logf("nftables is available")
+		hostinfo.SetFirewallMode("nft")
+		return FirewallModeNfTables
+	case iptAva:
+		logf("iptables is available")
+		hostinfo.SetFirewallMode("ipt")
+		return FirewallModeIPTables
+	default:
+		// if neither iptables nor nftables are available, use iptablesRunner as a dummy
+		// runner which exists but won't do anything. Creating iptablesRunner errors only
+		// if the iptables command is missing or doesn’t support "--version", as long as it
+		// can determine a version then it’ll carry on.
+		hostinfo.SetFirewallMode("ipt-fb")
+		return FirewallModeIPTables
+	}
+}

+ 2 - 2
util/linuxfw/iptables.go

@@ -23,13 +23,13 @@ func DebugIptables(logf logger.Logf) error {
 	return nil
 }
 
-// DetectIptables returns the number of iptables rules that are present in the
+// detectIptables returns the number of iptables rules that are present in the
 // system, ignoring the default "ACCEPT" rule present in the standard iptables
 // chains.
 //
 // It only returns an error when there is no iptables binary, or when iptables -S
 // fails. In all other cases, it returns the number of non-default rules.
-func DetectIptables() (int, error) {
+func detectIptables() (int, error) {
 	// run "iptables -S" to get the list of rules using iptables
 	// exec.Command returns an error if the binary is not found
 	cmd := exec.Command("iptables", "-S")

+ 4 - 4
util/linuxfw/iptables_runner.go

@@ -45,11 +45,11 @@ func checkIP6TablesExists() error {
 	return nil
 }
 
-// NewIPTablesRunner constructs a NetfilterRunner that programs iptables rules.
+// newIPTablesRunner constructs a NetfilterRunner that programs iptables rules.
 // If the underlying iptables library fails to initialize, that error is
 // returned. The runner probes for IPv6 support once at initialization time and
 // if not found, no IPv6 rules will be modified for the lifetime of the runner.
-func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) {
+func newIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) {
 	ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
 	if err != nil {
 		return nil, err
@@ -79,12 +79,12 @@ func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) {
 	return &iptablesRunner{ipt4, ipt6, supportsV6, supportsV6NAT}, nil
 }
 
-// HasIPV6 returns true if the system supports IPv6.
+// HasIPV6 reports true if the system supports IPv6.
 func (i *iptablesRunner) HasIPV6() bool {
 	return i.v6Available
 }
 
-// HasIPV6NAT returns true if the system supports IPv6 NAT.
+// HasIPV6NAT reports true if the system supports IPv6 NAT.
 func (i *iptablesRunner) HasIPV6NAT() bool {
 	return i.v6NATAvailable
 }

+ 3 - 3
util/linuxfw/linuxfw_unsupported.go

@@ -25,16 +25,16 @@ func DebugNetfilter(logf logger.Logf) error {
 }
 
 // DetectNetfilter is not supported on non-Linux platforms.
-func DetectNetfilter() (int, error) {
+func detectNetfilter() (int, error) {
 	return 0, ErrUnsupported
 }
 
 // DebugIptables is not supported on non-Linux platforms.
-func DebugIptables(logf logger.Logf) error {
+func debugIptables(logf logger.Logf) error {
 	return ErrUnsupported
 }
 
 // DetectIptables is not supported on non-Linux platforms.
-func DetectIptables() (int, error) {
+func detectIptables() (int, error) {
 	return 0, ErrUnsupported
 }

+ 2 - 2
util/linuxfw/nftables.go

@@ -103,8 +103,8 @@ func DebugNetfilter(logf logger.Logf) error {
 	return nil
 }
 
-// DetectNetfilter returns the number of nftables rules present in the system.
-func DetectNetfilter() (int, error) {
+// detectNetfilter returns the number of nftables rules present in the system.
+func detectNetfilter() (int, error) {
 	conn, err := nftables.New()
 	if err != nil {
 		return 0, FWModeNotSupportedError{

+ 61 - 3
util/linuxfw/nftables_runner.go

@@ -175,9 +175,67 @@ func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
 	return nil
 }
 
-// NewNfTablesRunner creates a new nftablesRunner without guaranteeing
+// NetfilterRunner abstracts helpers to run netfilter commands. It is
+// implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner.
+type NetfilterRunner interface {
+	// AddLoopbackRule adds a rule to permit loopback traffic to addr. This rule
+	// is added only if it does not already exist.
+	AddLoopbackRule(addr netip.Addr) error
+
+	// DelLoopbackRule removes the rule added by AddLoopbackRule.
+	DelLoopbackRule(addr netip.Addr) error
+
+	// AddHooks adds rules to conventional chains like "FORWARD", "INPUT" and
+	// "POSTROUTING" to jump from those chains to tailscale chains.
+	AddHooks() error
+
+	// DelHooks deletes rules added by AddHooks.
+	DelHooks(logf logger.Logf) error
+
+	// AddChains creates custom Tailscale chains.
+	AddChains() error
+
+	// DelChains removes chains added by AddChains.
+	DelChains() error
+
+	// AddBase adds rules reused by different other rules.
+	AddBase(tunname string) error
+
+	// DelBase removes rules added by AddBase.
+	DelBase() error
+
+	// AddSNATRule adds the netfilter rule to SNAT incoming traffic over
+	// the Tailscale interface destined for local subnets. An error is
+	// returned if the rule already exists.
+	AddSNATRule() error
+
+	// DelSNATRule removes the rule added by AddSNATRule.
+	DelSNATRule() error
+
+	// HasIPV6 reports true if the system supports IPv6.
+	HasIPV6() bool
+
+	// HasIPV6NAT reports true if the system supports IPv6 NAT.
+	HasIPV6NAT() bool
+}
+
+// New creates a NetfilterRunner using either nftables or iptables.
+// As nftables is still experimental, iptables will be used unless TS_DEBUG_USE_NETLINK_NFTABLES is set.
+func New(logf logger.Logf) (NetfilterRunner, error) {
+	mode := detectFirewallMode(logf)
+	switch mode {
+	case FirewallModeIPTables:
+		return newIPTablesRunner(logf)
+	case FirewallModeNfTables:
+		return newNfTablesRunner(logf)
+	default:
+		return nil, fmt.Errorf("unknown firewall mode %v", mode)
+	}
+}
+
+// newNfTablesRunner creates a new nftablesRunner without guaranteeing
 // the existence of the tables and chains.
-func NewNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
+func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
 	conn, err := nftables.New()
 	if err != nil {
 		return nil, fmt.Errorf("nftables connection: %w", err)
@@ -231,7 +289,7 @@ func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, err
 	}
 }
 
-// HasIPV6 returns true if the system supports IPv6.
+// HasIPV6 reports true if the system supports IPv6.
 func (n *nftablesRunner) HasIPV6() bool {
 	return n.v6Available
 }

+ 61 - 0
util/linuxfw/nftables_runner_test.go

@@ -7,6 +7,7 @@ package linuxfw
 
 import (
 	"bytes"
+	"errors"
 	"fmt"
 	"net/netip"
 	"os"
@@ -946,3 +947,63 @@ func TestNFTAddAndDelHookRule(t *testing.T) {
 		t.Fatalf("expected 0 rule in POSTROUTING chain, got %v", len(postroutingChainRules))
 	}
 }
+
+type testFWDetector struct {
+	iptRuleCount, nftRuleCount int
+	iptErr, nftErr             error
+}
+
+func (t *testFWDetector) iptDetect() (int, error) {
+	return t.iptRuleCount, t.iptErr
+}
+
+func (t *testFWDetector) nftDetect() (int, error) {
+	return t.nftRuleCount, t.nftErr
+}
+
+func TestPickFirewallModeFromInstalledRules(t *testing.T) {
+	tests := []struct {
+		name string
+		det  *testFWDetector
+		want FirewallMode
+	}{
+		{
+			name: "using iptables legacy",
+			det:  &testFWDetector{iptRuleCount: 1},
+			want: FirewallModeIPTables,
+		},
+		{
+			name: "using nftables",
+			det:  &testFWDetector{nftRuleCount: 1},
+			want: FirewallModeNfTables,
+		},
+		{
+			name: "using both iptables and nftables",
+			det:  &testFWDetector{iptRuleCount: 2, nftRuleCount: 2},
+			want: FirewallModeNfTables,
+		},
+		{
+			name: "not using any firewall, both available",
+			det:  &testFWDetector{},
+			want: FirewallModeNfTables,
+		},
+		{
+			name: "not using any firewall, iptables available only",
+			det:  &testFWDetector{iptRuleCount: 1, nftErr: errors.New("nft error")},
+			want: FirewallModeIPTables,
+		},
+		{
+			name: "not using any firewall, nftables available only",
+			det:  &testFWDetector{iptErr: errors.New("iptables error"), nftRuleCount: 1},
+			want: FirewallModeNfTables,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := pickFirewallModeFromInstalledRules(t.Logf, tt.det)
+			if got != tt.want {
+				t.Errorf("chooseFireWallMode() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}

+ 3 - 143
wgengine/router/router_linux.go

@@ -22,7 +22,6 @@ import (
 	"golang.org/x/sys/unix"
 	"golang.org/x/time/rate"
 	"tailscale.com/envknob"
-	"tailscale.com/hostinfo"
 	"tailscale.com/net/netmon"
 	"tailscale.com/types/logger"
 	"tailscale.com/types/preftype"
@@ -37,145 +36,6 @@ const (
 	netfilterOn       = preftype.NetfilterOn
 )
 
-// netfilterRunner abstracts helpers to run netfilter commands. It is
-// implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner.
-type netfilterRunner interface {
-	AddLoopbackRule(addr netip.Addr) error
-	DelLoopbackRule(addr netip.Addr) error
-	AddHooks() error
-	DelHooks(logf logger.Logf) error
-	AddChains() error
-	DelChains() error
-	AddBase(tunname string) error
-	DelBase() error
-	AddSNATRule() error
-	DelSNATRule() error
-
-	HasIPV6() bool
-	HasIPV6NAT() bool
-}
-
-// tableDetector abstracts helpers to detect the firewall mode.
-// It is implemented for testing purposes.
-type tableDetector interface {
-	iptDetect() (int, error)
-	nftDetect() (int, error)
-}
-
-type linuxFWDetector struct{}
-
-// iptDetect returns the number of iptables rules in the current namespace.
-func (l *linuxFWDetector) iptDetect() (int, error) {
-	return linuxfw.DetectIptables()
-}
-
-// nftDetect returns the number of nftables rules in the current namespace.
-func (l *linuxFWDetector) nftDetect() (int, error) {
-	return linuxfw.DetectNetfilter()
-}
-
-// chooseFireWallMode returns the firewall mode to use based on the
-// environment and the system's capabilities.
-func chooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMode {
-	if distro.Get() == distro.Gokrazy {
-		// Reduce startup logging on gokrazy. There's no way to do iptables on
-		// gokrazy anyway.
-		return linuxfw.FirewallModeNfTables
-	}
-	iptAva, nftAva := true, true
-	iptRuleCount, err := det.iptDetect()
-	if err != nil {
-		logf("detect iptables rule: %v", err)
-		iptAva = false
-	}
-	nftRuleCount, err := det.nftDetect()
-	if err != nil {
-		logf("detect nftables rule: %v", err)
-		nftAva = false
-	}
-	logf("nftables rule count: %d, iptables rule count: %d", nftRuleCount, iptRuleCount)
-	switch {
-	case nftRuleCount > 0 && iptRuleCount == 0:
-		logf("nftables is currently in use")
-		hostinfo.SetFirewallMode("nft-inuse")
-		return linuxfw.FirewallModeNfTables
-	case iptRuleCount > 0 && nftRuleCount == 0:
-		logf("iptables is currently in use")
-		hostinfo.SetFirewallMode("ipt-inuse")
-		return linuxfw.FirewallModeIPTables
-	case nftAva:
-		// if both iptables and nftables are available but
-		// neither/both are currently used, use nftables.
-		logf("nftables is available")
-		hostinfo.SetFirewallMode("nft")
-		return linuxfw.FirewallModeNfTables
-	case iptAva:
-		logf("iptables is available")
-		hostinfo.SetFirewallMode("ipt")
-		return linuxfw.FirewallModeIPTables
-	default:
-		// if neither iptables nor nftables are available, use iptablesRunner as a dummy
-		// runner which exists but won't do anything. Creating iptablesRunner errors only
-		// if the iptables command is missing or doesn’t support "--version", as long as it
-		// can determine a version then it’ll carry on.
-		hostinfo.SetFirewallMode("ipt-fb")
-		return linuxfw.FirewallModeIPTables
-	}
-}
-
-// newNetfilterRunner creates a netfilterRunner using either nftables or iptables.
-// As nftables is still experimental, iptables will be used unless TS_DEBUG_USE_NETLINK_NFTABLES is set.
-func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) {
-	tableDetector := &linuxFWDetector{}
-	var mode linuxfw.FirewallMode
-
-	// We now use iptables as default and have "auto" and "nftables" as
-	// options for people to test further.
-	switch {
-	case distro.Get() == distro.Gokrazy:
-		// Reduce startup logging on gokrazy. There's no way to do iptables on
-		// gokrazy anyway.
-		logf("GoKrazy should use nftables.")
-		hostinfo.SetFirewallMode("nft-gokrazy")
-		mode = linuxfw.FirewallModeNfTables
-	case envknob.String("TS_DEBUG_FIREWALL_MODE") == "nftables":
-		logf("envknob TS_DEBUG_FIREWALL_MODE=nftables set")
-		hostinfo.SetFirewallMode("nft-forced")
-		mode = linuxfw.FirewallModeNfTables
-	case envknob.String("TS_DEBUG_FIREWALL_MODE") == "auto":
-		mode = chooseFireWallMode(logf, tableDetector)
-	case envknob.String("TS_DEBUG_FIREWALL_MODE") == "iptables":
-		logf("envknob TS_DEBUG_FIREWALL_MODE=iptables set")
-		hostinfo.SetFirewallMode("ipt-forced")
-		mode = linuxfw.FirewallModeIPTables
-	default:
-		logf("default choosing iptables")
-		hostinfo.SetFirewallMode("ipt-default")
-		mode = linuxfw.FirewallModeIPTables
-	}
-
-	var nfr netfilterRunner
-	var err error
-	switch mode {
-	case linuxfw.FirewallModeIPTables:
-		logf("using iptables")
-		nfr, err = linuxfw.NewIPTablesRunner(logf)
-		if err != nil {
-			return nil, err
-		}
-	case linuxfw.FirewallModeNfTables:
-		logf("using nftables")
-		nfr, err = linuxfw.NewNfTablesRunner(logf)
-		if err != nil {
-			return nil, err
-		}
-	default:
-		return nil, fmt.Errorf("unknown firewall mode: %v", mode)
-	}
-
-	return nfr, nil
-}
-
 type linuxRouter struct {
 	closed           atomic.Bool
 	logf             func(fmt string, args ...any)
@@ -200,7 +60,7 @@ type linuxRouter struct {
 	// ipPolicyPrefBase is the base priority at which ip rules are installed.
 	ipPolicyPrefBase int
 
-	nfr netfilterRunner
+	nfr linuxfw.NetfilterRunner
 	cmd commandRunner
 }
 
@@ -210,7 +70,7 @@ func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Moni
 		return nil, err
 	}
 
-	nfr, err := newNetfilterRunner(logf)
+	nfr, err := linuxfw.New(logf)
 	if err != nil {
 		return nil, err
 	}
@@ -222,7 +82,7 @@ func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Moni
 	return newUserspaceRouterAdvanced(logf, tunname, netMon, nfr, cmd)
 }
 
-func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, nfr netfilterRunner, cmd commandRunner) (Router, error) {
+func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, nfr linuxfw.NetfilterRunner, cmd commandRunner) (Router, error) {
 	r := &linuxRouter{
 		logf:          logf,
 		tunname:       tunname,

+ 2 - 62
wgengine/router/router_linux_test.go

@@ -372,7 +372,7 @@ type fakeIPTablesRunner struct {
 	//we always assume ipv6 and ipv6 nat are enabled when testing
 }
 
-func newIPTablesRunner(t *testing.T) netfilterRunner {
+func newIPTablesRunner(t *testing.T) linuxfw.NetfilterRunner {
 	return &fakeIPTablesRunner{
 		t: t,
 		ipt4: map[string][]string{
@@ -603,7 +603,7 @@ type fakeOS struct {
 	rules  []string
 	//This test tests on the router level, so we will not bother
 	//with using iptables or nftables, chose the simpler one.
-	nfr netfilterRunner
+	nfr linuxfw.NetfilterRunner
 }
 
 func NewFakeOS(t *testing.T) *fakeOS {
@@ -1063,63 +1063,3 @@ func adjustFwmask(t *testing.T, s string) string {
 
 	return fwmaskAdjustRe.ReplaceAllString(s, "$1")
 }
-
-type testFWDetector struct {
-	iptRuleCount, nftRuleCount int
-	iptErr, nftErr             error
-}
-
-func (t *testFWDetector) iptDetect() (int, error) {
-	return t.iptRuleCount, t.iptErr
-}
-
-func (t *testFWDetector) nftDetect() (int, error) {
-	return t.nftRuleCount, t.nftErr
-}
-
-func TestChooseFireWallMode(t *testing.T) {
-	tests := []struct {
-		name string
-		det  *testFWDetector
-		want linuxfw.FirewallMode
-	}{
-		{
-			name: "using iptables legacy",
-			det:  &testFWDetector{iptRuleCount: 1},
-			want: linuxfw.FirewallModeIPTables,
-		},
-		{
-			name: "using nftables",
-			det:  &testFWDetector{nftRuleCount: 1},
-			want: linuxfw.FirewallModeNfTables,
-		},
-		{
-			name: "using both iptables and nftables",
-			det:  &testFWDetector{iptRuleCount: 2, nftRuleCount: 2},
-			want: linuxfw.FirewallModeNfTables,
-		},
-		{
-			name: "not using any firewall, both available",
-			det:  &testFWDetector{},
-			want: linuxfw.FirewallModeNfTables,
-		},
-		{
-			name: "not using any firewall, iptables available only",
-			det:  &testFWDetector{iptRuleCount: 1, nftErr: errors.New("nft error")},
-			want: linuxfw.FirewallModeIPTables,
-		},
-		{
-			name: "not using any firewall, nftables available only",
-			det:  &testFWDetector{iptErr: errors.New("iptables error"), nftRuleCount: 1},
-			want: linuxfw.FirewallModeNfTables,
-		},
-	}
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			got := chooseFireWallMode(t.Logf, tt.det)
-			if got != tt.want {
-				t.Errorf("chooseFireWallMode() = %v, want %v", got, tt.want)
-			}
-		})
-	}
-}