Browse Source

util/linuxfw,wgengine/router: add new netfilter rules for HA ingresses (#15896)

Add new rules to update DNAT rules for Kubernetes operator's
HA ingress where it's expected that rules will be added/removed
frequently (so we don't want to keep old rules around or rewrite
existing rules unnecessarily):
- allow deleting DNAT rules using metadata lookup
- allow inserting DNAT rules if they don't already
exist (using metadata lookup)

Updates tailscale/tailscale#15895

Signed-off-by: Irbe Krumina <[email protected]>
Co-authored-by: chaosinthecrd <[email protected]>
Irbe Krumina 9 months ago
parent
commit
2c16fcaa06

+ 95 - 0
util/linuxfw/fake_netfilter.go

@@ -0,0 +1,95 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build linux
+
+package linuxfw
+
+import (
+	"net/netip"
+
+	"tailscale.com/types/logger"
+)
+
+// FakeNetfilterRunner is a fake netfilter runner for tests.
+type FakeNetfilterRunner struct {
+	// services is a map that tracks the firewall rules added/deleted via
+	// EnsureDNATRuleForSvc/DeleteDNATRuleForSvc.
+	services map[string]struct {
+		VIPServiceIP netip.Addr
+		ClusterIP    netip.Addr
+	}
+}
+
+// NewFakeNetfilterRunner creates a new FakeNetfilterRunner.
+func NewFakeNetfilterRunner() *FakeNetfilterRunner {
+	return &FakeNetfilterRunner{
+		services: make(map[string]struct {
+			VIPServiceIP netip.Addr
+			ClusterIP    netip.Addr
+		}),
+	}
+}
+
+func (f *FakeNetfilterRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
+	f.services[svcName] = struct {
+		VIPServiceIP netip.Addr
+		ClusterIP    netip.Addr
+	}{origDst, dst}
+	return nil
+}
+
+func (f *FakeNetfilterRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
+	delete(f.services, svcName)
+	return nil
+}
+
+func (f *FakeNetfilterRunner) GetServiceState() map[string]struct {
+	VIPServiceIP netip.Addr
+	ClusterIP    netip.Addr
+} {
+	return f.services
+}
+
+func (f *FakeNetfilterRunner) HasIPV6() bool {
+	return true
+}
+
+func (f *FakeNetfilterRunner) HasIPV6Filter() bool {
+	return true
+}
+
+func (f *FakeNetfilterRunner) HasIPV6NAT() bool {
+	return true
+}
+
+func (f *FakeNetfilterRunner) AddBase(tunname string) error              { return nil }
+func (f *FakeNetfilterRunner) DelBase() error                            { return nil }
+func (f *FakeNetfilterRunner) AddChains() error                          { return nil }
+func (f *FakeNetfilterRunner) DelChains() error                          { return nil }
+func (f *FakeNetfilterRunner) AddHooks() error                           { return nil }
+func (f *FakeNetfilterRunner) DelHooks(logf logger.Logf) error           { return nil }
+func (f *FakeNetfilterRunner) AddSNATRule() error                        { return nil }
+func (f *FakeNetfilterRunner) DelSNATRule() error                        { return nil }
+func (f *FakeNetfilterRunner) AddStatefulRule(tunname string) error      { return nil }
+func (f *FakeNetfilterRunner) DelStatefulRule(tunname string) error      { return nil }
+func (f *FakeNetfilterRunner) AddLoopbackRule(addr netip.Addr) error     { return nil }
+func (f *FakeNetfilterRunner) DelLoopbackRule(addr netip.Addr) error     { return nil }
+func (f *FakeNetfilterRunner) AddDNATRule(origDst, dst netip.Addr) error { return nil }
+func (f *FakeNetfilterRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error {
+	return nil
+}
+func (f *FakeNetfilterRunner) EnsureSNATForDst(src, dst netip.Addr) error               { return nil }
+func (f *FakeNetfilterRunner) DNATNonTailscaleTraffic(tun string, dst netip.Addr) error { return nil }
+func (f *FakeNetfilterRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error         { return nil }
+func (f *FakeNetfilterRunner) AddMagicsockPortRule(port uint16, network string) error   { return nil }
+func (f *FakeNetfilterRunner) DelMagicsockPortRule(port uint16, network string) error   { return nil }
+func (f *FakeNetfilterRunner) DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
+	return nil
+}
+func (f *FakeNetfilterRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pms []PortMap) error {
+	return nil
+}
+func (f *FakeNetfilterRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
+	return nil
+}

+ 53 - 5
util/linuxfw/iptables_for_svcs.go

@@ -13,6 +13,7 @@ import (
 // This file contains functionality to insert portmapping rules for a 'service'.
 // These are currently only used by the Kubernetes operator proxies.
 // An iptables rule for such a service contains a comment with the service name.
+// A 'service' corresponds to a VIPService as used by the Kubernetes operator.
 
 // EnsurePortMapRuleForSvc adds a prerouting rule that forwards traffic received
 // on match port and NOT on the provided interface to target IP and target port.
@@ -24,10 +25,10 @@ func (i *iptablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip
 	if err != nil {
 		return fmt.Errorf("error checking if rule exists: %w", err)
 	}
-	if !exists {
-		return table.Append("nat", "PREROUTING", args...)
+	if exists {
+		return nil
 	}
-	return nil
+	return table.Append("nat", "PREROUTING", args...)
 }
 
 // DeleteMapRuleForSvc constructs a prerouting rule as would be created by
@@ -40,10 +41,41 @@ func (i *iptablesRunner) DeletePortMapRuleForSvc(svc, excludeI string, targetIP
 	if err != nil {
 		return fmt.Errorf("error checking if rule exists: %w", err)
 	}
+	if !exists {
+		return nil
+	}
+	return table.Delete("nat", "PREROUTING", args...)
+}
+
+// EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the
+// VIPService IP address to a local address. This is used by the Kubernetes
+// operator's network layer proxies to forward tailnet traffic for VIPServices
+// to Kubernetes Services.
+func (i *iptablesRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
+	table := i.getIPTByAddr(dst)
+	args := argsForIngressRule(svcName, origDst, dst)
+	exists, err := table.Exists("nat", "PREROUTING", args...)
+	if err != nil {
+		return fmt.Errorf("error checking if rule exists: %w", err)
+	}
 	if exists {
-		return table.Delete("nat", "PREROUTING", args...)
+		return nil
 	}
-	return nil
+	return table.Append("nat", "PREROUTING", args...)
+}
+
+// DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc.
+func (i *iptablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
+	table := i.getIPTByAddr(dst)
+	args := argsForIngressRule(svcName, origDst, dst)
+	exists, err := table.Exists("nat", "PREROUTING", args...)
+	if err != nil {
+		return fmt.Errorf("error checking if rule exists: %w", err)
+	}
+	if !exists {
+		return nil
+	}
+	return table.Delete("nat", "PREROUTING", args...)
 }
 
 // DeleteSvc constructs all possible rules that would have been created by
@@ -72,8 +104,24 @@ func argsForPortMapRule(svc, excludeI string, targetIP netip.Addr, pm PortMap) [
 	}
 }
 
+func argsForIngressRule(svcName string, origDst, targetIP netip.Addr) []string {
+	c := commentForIngressSvc(svcName, origDst, targetIP)
+	return []string{
+		"--destination", origDst.String(),
+		"-m", "comment", "--comment", c,
+		"-j", "DNAT",
+		"--to-destination", targetIP.String(),
+	}
+}
+
 // commentForSvc generates a comment to be added to an iptables DNAT rule for a
 // service. This is for iptables debugging/readability purposes only.
 func commentForSvc(svc string, pm PortMap) string {
 	return fmt.Sprintf("%s:%s:%d -> %s:%d", svc, pm.Protocol, pm.MatchPort, pm.Protocol, pm.TargetPort)
 }
+
+// commentForIngressSvc generates a comment to be added to an iptables DNAT rule for a
+// service. This is for iptables debugging/readability purposes only.
+func commentForIngressSvc(svc string, vip, clusterIP netip.Addr) string {
+	return fmt.Sprintf("svc: %s, %s -> %s", svc, vip.String(), clusterIP.String())
+}

+ 129 - 0
util/linuxfw/iptables_for_svcs_test.go

@@ -153,6 +153,135 @@ func Test_iptablesRunner_DeleteSvc(t *testing.T) {
 	svcMustExist(t, "svc2", map[string][]string{v4Addr.String(): s2R1, v6Addr.String(): s2R2}, iptr)
 }
 
+func Test_iptablesRunner_EnsureDNATRuleForSvc(t *testing.T) {
+	v4OrigDst := netip.MustParseAddr("10.0.0.1")
+	v4Target := netip.MustParseAddr("10.0.0.2")
+	v6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1")
+	v6Target := netip.MustParseAddr("fd7a:115c:a1e0::2")
+	v4Rule := argsForIngressRule("svc:test", v4OrigDst, v4Target)
+
+	tests := []struct {
+		name              string
+		svcName           string
+		origDst           netip.Addr
+		targetIP          netip.Addr
+		precreateSvcRules [][]string
+	}{
+		{
+			name:     "dnat_for_ipv4",
+			svcName:  "svc:test",
+			origDst:  v4OrigDst,
+			targetIP: v4Target,
+		},
+		{
+			name:     "dnat_for_ipv6",
+			svcName:  "svc:test-2",
+			origDst:  v6OrigDst,
+			targetIP: v6Target,
+		},
+		{
+			name:              "add_existing_rule",
+			svcName:           "svc:test",
+			origDst:           v4OrigDst,
+			targetIP:          v4Target,
+			precreateSvcRules: [][]string{v4Rule},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			iptr := NewFakeIPTablesRunner()
+			table := iptr.getIPTByAddr(tt.targetIP)
+			for _, ruleset := range tt.precreateSvcRules {
+				mustPrecreateDNATRule(t, ruleset, table)
+			}
+			if err := iptr.EnsureDNATRuleForSvc(tt.svcName, tt.origDst, tt.targetIP); err != nil {
+				t.Errorf("[unexpected error] iptablesRunner.EnsureDNATRuleForSvc() = %v", err)
+			}
+			args := argsForIngressRule(tt.svcName, tt.origDst, tt.targetIP)
+			exists, err := table.Exists("nat", "PREROUTING", args...)
+			if err != nil {
+				t.Fatalf("error checking if rule exists: %v", err)
+			}
+			if !exists {
+				t.Errorf("expected rule was not created")
+			}
+		})
+	}
+}
+
+func Test_iptablesRunner_DeleteDNATRuleForSvc(t *testing.T) {
+	v4OrigDst := netip.MustParseAddr("10.0.0.1")
+	v4Target := netip.MustParseAddr("10.0.0.2")
+	v6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1")
+	v6Target := netip.MustParseAddr("fd7a:115c:a1e0::2")
+	v4Rule := argsForIngressRule("svc:test", v4OrigDst, v4Target)
+	v6Rule := argsForIngressRule("svc:test", v6OrigDst, v6Target)
+
+	tests := []struct {
+		name              string
+		svcName           string
+		origDst           netip.Addr
+		targetIP          netip.Addr
+		precreateSvcRules [][]string
+	}{
+		{
+			name:              "multiple_rules_ipv4_deleted",
+			svcName:           "svc:test",
+			origDst:           v4OrigDst,
+			targetIP:          v4Target,
+			precreateSvcRules: [][]string{v4Rule, v6Rule},
+		},
+		{
+			name:              "multiple_rules_ipv6_deleted",
+			svcName:           "svc:test",
+			origDst:           v6OrigDst,
+			targetIP:          v6Target,
+			precreateSvcRules: [][]string{v4Rule, v6Rule},
+		},
+		{
+			name:              "non-existent_rule_deleted",
+			svcName:           "svc:test",
+			origDst:           v4OrigDst,
+			targetIP:          v4Target,
+			precreateSvcRules: [][]string{v6Rule},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			iptr := NewFakeIPTablesRunner()
+			table := iptr.getIPTByAddr(tt.targetIP)
+			for _, ruleset := range tt.precreateSvcRules {
+				mustPrecreateDNATRule(t, ruleset, table)
+			}
+			if err := iptr.DeleteDNATRuleForSvc(tt.svcName, tt.origDst, tt.targetIP); err != nil {
+				t.Errorf("iptablesRunner.DeleteDNATRuleForSvc() errored: %v ", err)
+			}
+			deletedRule := argsForIngressRule(tt.svcName, tt.origDst, tt.targetIP)
+			exists, err := table.Exists("nat", "PREROUTING", deletedRule...)
+			if err != nil {
+				t.Fatalf("error verifying that rule does not exist after deletion: %v", err)
+			}
+			if exists {
+				t.Errorf("DNAT rule exists after deletion")
+			}
+		})
+	}
+}
+
+func mustPrecreateDNATRule(t *testing.T, rules []string, table iptablesInterface) {
+	t.Helper()
+	exists, err := table.Exists("nat", "PREROUTING", rules...)
+	if err != nil {
+		t.Fatalf("error ensuring that nat PREROUTING table exists: %v", err)
+	}
+	if exists {
+		return
+	}
+	if err := table.Append("nat", "PREROUTING", rules...); err != nil {
+		t.Fatalf("error precreating DNAT rule: %v", err)
+	}
+}
+
 func svcMustExist(t *testing.T, svcName string, rules map[string][]string, iptr *iptablesRunner) {
 	t.Helper()
 	for dst, ruleset := range rules {

+ 64 - 0
util/linuxfw/nftables_for_svcs.go

@@ -119,6 +119,63 @@ func (n *nftablesRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm [
 	return n.conn.Flush()
 }
 
+// EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the
+// VIPService IP address to a local address. This is used by the Kubernetes
+// operator's network layer proxies to forward tailnet traffic for VIPServices
+// to Kubernetes Services.
+func (n *nftablesRunner) EnsureDNATRuleForSvc(svc string, origDst, dst netip.Addr) error {
+	t, ch, err := n.ensurePreroutingChain(origDst)
+	if err != nil {
+		return fmt.Errorf("error ensuring chain for %s: %w", svc, err)
+	}
+	meta := svcRuleMeta(svc, origDst, dst)
+	rule, err := n.findRuleByMetadata(t, ch, meta)
+	if err != nil {
+		return fmt.Errorf("error looking up rule: %w", err)
+	}
+	if rule != nil {
+		return nil
+	}
+	rule = dnatRuleForChain(t, ch, origDst, dst, meta)
+	n.conn.InsertRule(rule)
+	return n.conn.Flush()
+}
+
+// DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc.
+// We use the metadata attached to the rule to look it up.
+func (n *nftablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
+	table, err := n.getNFTByAddr(origDst)
+	if err != nil {
+		return fmt.Errorf("error setting up nftables for IP family of %s: %w", origDst, err)
+	}
+	t, err := getTableIfExists(n.conn, table.Proto, "nat")
+	if err != nil {
+		return fmt.Errorf("error checking if nat table exists: %w", err)
+	}
+	if t == nil {
+		return nil
+	}
+	ch, err := getChainFromTable(n.conn, t, "PREROUTING")
+	if errors.Is(err, errorChainNotFound{tableName: "nat", chainName: "PREROUTING"}) {
+		return nil
+	}
+	if err != nil {
+		return fmt.Errorf("error checking if chain PREROUTING exists: %w", err)
+	}
+	meta := svcRuleMeta(svcName, origDst, dst)
+	rule, err := n.findRuleByMetadata(t, ch, meta)
+	if err != nil {
+		return fmt.Errorf("error checking if rule exists: %w", err)
+	}
+	if rule == nil {
+		return nil
+	}
+	if err := n.conn.DelRule(rule); err != nil {
+		return fmt.Errorf("error deleting rule: %w", err)
+	}
+	return n.conn.Flush()
+}
+
 func portMapRule(t *nftables.Table, ch *nftables.Chain, tun string, targetIP netip.Addr, matchPort, targetPort uint16, proto uint8, meta []byte) *nftables.Rule {
 	var fam uint32
 	if targetIP.Is4() {
@@ -243,3 +300,10 @@ func protoFromString(s string) (uint8, error) {
 		return 0, fmt.Errorf("unrecognized protocol: %q", s)
 	}
 }
+
+// svcRuleMeta generates metadata for a rule.
+// This metadata can then be used to find the rule.
+// https://github.com/google/nftables/issues/48
+func svcRuleMeta(svcName string, origDst, dst netip.Addr) []byte {
+	return []byte(fmt.Sprintf("svc:%s,VIP:%s,ClusterIP:%s", svcName, origDst.String(), dst.String()))
+}

+ 194 - 29
util/linuxfw/nftables_for_svcs_test.go

@@ -14,8 +14,9 @@ import (
 
 // This test creates a temporary network namespace for the nftables rules being
 // set up, so it needs to run in a privileged mode. Locally it needs to be run
-// by root, else it will be silently skipped. In CI it runs in a privileged
-// container.
+// by root, else it will be silently skipped.
+// sudo  go test -v -run Test_nftablesRunner_EnsurePortMapRuleForSvc ./util/linuxfw/...
+// In CI it runs in a privileged container.
 func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) {
 	conn := newSysConn(t)
 	runner := newFakeNftablesRunnerWithConn(t, conn, true)
@@ -23,51 +24,215 @@ func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) {
 	pmTCP := PortMap{MatchPort: 4003, TargetPort: 80, Protocol: "TCP"}
 	pmTCP1 := PortMap{MatchPort: 4004, TargetPort: 443, Protocol: "TCP"}
 
-	// Create a rule for service 'foo' to forward TCP traffic to IPv4 endpoint
-	runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP)
+	// Create a rule for service 'svc:foo' to forward TCP traffic to IPv4 endpoint
+	runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP)
 	svcChains(t, 1, conn)
-	chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4)
-	checkPortMapRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
+	chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv4)
+	checkPortMapRule(t, "svc:foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
 
-	// Create another rule for service 'foo' to forward TCP traffic to the
+	// Create another rule for service 'svc:foo' to forward TCP traffic to the
 	// same IPv4 endpoint, but to a different port.
-	runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP1)
+	runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP1)
 	svcChains(t, 1, conn)
-	chainRuleCount(t, "foo", 2, conn, nftables.TableFamilyIPv4)
-	checkPortMapRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4)
+	chainRuleCount(t, "svc:foo", 2, conn, nftables.TableFamilyIPv4)
+	checkPortMapRule(t, "svc:foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4)
 
-	// Create a rule for service 'foo' to forward TCP traffic to an IPv6 endpoint
-	runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv6, pmTCP)
+	// Create a rule for service 'svc:foo' to forward TCP traffic to an IPv6 endpoint
+	runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv6, pmTCP)
 	svcChains(t, 2, conn)
-	chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv6)
-	checkPortMapRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
+	chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv6)
+	checkPortMapRule(t, "svc:foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
 
-	// Create a rule for service 'bar' to forward TCP traffic to IPv4 endpoint
-	runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv4, pmTCP)
+	// Create a rule for service 'svc:bar' to forward TCP traffic to IPv4 endpoint
+	runner.EnsurePortMapRuleForSvc("svc:bar", "tailscale0", ipv4, pmTCP)
 	svcChains(t, 3, conn)
-	chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv4)
-	checkPortMapRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
+	chainRuleCount(t, "svc:bar", 1, conn, nftables.TableFamilyIPv4)
+	checkPortMapRule(t, "svc:bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
 
-	// Create a rule for service 'bar' to forward TCP traffic to an IPv6 endpoint
-	runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv6, pmTCP)
+	// Create a rule for service 'svc:bar' to forward TCP traffic to an IPv6 endpoint
+	runner.EnsurePortMapRuleForSvc("svc:bar", "tailscale0", ipv6, pmTCP)
 	svcChains(t, 4, conn)
-	chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv6)
-	checkPortMapRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
+	chainRuleCount(t, "svc:bar", 1, conn, nftables.TableFamilyIPv6)
+	checkPortMapRule(t, "svc:bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
 
-	// Delete service bar
-	runner.DeleteSvc("bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP})
+	// Delete service svc:bar
+	runner.DeleteSvc("svc:bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP})
 	svcChains(t, 2, conn)
 
-	// Delete a rule from service foo
-	runner.DeletePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP)
+	// Delete a rule from service svc:foo
+	runner.DeletePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP)
 	svcChains(t, 2, conn)
-	chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4)
+	chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv4)
 
-	// Delete service foo
-	runner.DeleteSvc("foo", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP, pmTCP1})
+	// Delete service svc:foo
+	runner.DeleteSvc("svc:foo", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP, pmTCP1})
 	svcChains(t, 0, conn)
 }
 
+func Test_nftablesRunner_EnsureDNATRuleForSvc(t *testing.T) {
+	conn := newSysConn(t)
+	runner := newFakeNftablesRunnerWithConn(t, conn, true)
+
+	// Test IPv4 DNAT rule
+	ipv4OrigDst := netip.MustParseAddr("10.0.0.1")
+	ipv4Target := netip.MustParseAddr("10.0.0.2")
+
+	// Create DNAT rule for service 'svc:foo' to forward IPv4 traffic
+	err := runner.EnsureDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target)
+	if err != nil {
+		t.Fatalf("error creating IPv4 DNAT rule: %v", err)
+	}
+	checkDNATRule(t, "svc:foo", ipv4OrigDst, ipv4Target, runner, nftables.TableFamilyIPv4)
+
+	// Test IPv6 DNAT rule
+	ipv6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1")
+	ipv6Target := netip.MustParseAddr("fd7a:115c:a1e0::2")
+
+	// Create DNAT rule for service 'svc:foo' to forward IPv6 traffic
+	err = runner.EnsureDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target)
+	if err != nil {
+		t.Fatalf("error creating IPv6 DNAT rule: %v", err)
+	}
+	checkDNATRule(t, "svc:foo", ipv6OrigDst, ipv6Target, runner, nftables.TableFamilyIPv6)
+
+	// Test creating rule for another service
+	err = runner.EnsureDNATRuleForSvc("svc:bar", ipv4OrigDst, ipv4Target)
+	if err != nil {
+		t.Fatalf("error creating DNAT rule for service 'svc:bar': %v", err)
+	}
+	checkDNATRule(t, "svc:bar", ipv4OrigDst, ipv4Target, runner, nftables.TableFamilyIPv4)
+}
+
+func Test_nftablesRunner_DeleteDNATRuleForSvc(t *testing.T) {
+	conn := newSysConn(t)
+	runner := newFakeNftablesRunnerWithConn(t, conn, true)
+
+	// Test IPv4 DNAT rule deletion
+	ipv4OrigDst := netip.MustParseAddr("10.0.0.1")
+	ipv4Target := netip.MustParseAddr("10.0.0.2")
+
+	// Create and then delete IPv4 DNAT rule
+	err := runner.EnsureDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target)
+	if err != nil {
+		t.Fatalf("error creating IPv4 DNAT rule: %v", err)
+	}
+
+	// Verify rule exists before deletion
+	table, err := runner.getNFTByAddr(ipv4OrigDst)
+	if err != nil {
+		t.Fatalf("error getting table: %v", err)
+	}
+	nftTable, err := getTableIfExists(runner.conn, table.Proto, "nat")
+	if err != nil {
+		t.Fatalf("error getting nat table: %v", err)
+	}
+	ch, err := getChainFromTable(runner.conn, nftTable, "PREROUTING")
+	if err != nil {
+		t.Fatalf("error getting PREROUTING chain: %v", err)
+	}
+	meta := svcRuleMeta("svc:foo", ipv4OrigDst, ipv4Target)
+	rule, err := runner.findRuleByMetadata(nftTable, ch, meta)
+	if err != nil {
+		t.Fatalf("error checking if rule exists: %v", err)
+	}
+	if rule == nil {
+		t.Fatal("rule does not exist before deletion")
+	}
+
+	err = runner.DeleteDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target)
+	if err != nil {
+		t.Fatalf("error deleting IPv4 DNAT rule: %v", err)
+	}
+
+	// Verify rule is deleted
+	rule, err = runner.findRuleByMetadata(nftTable, ch, meta)
+	if err != nil {
+		t.Fatalf("error checking if rule exists: %v", err)
+	}
+	if rule != nil {
+		t.Fatal("rule still exists after deletion")
+	}
+
+	// Test IPv6 DNAT rule deletion
+	ipv6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1")
+	ipv6Target := netip.MustParseAddr("fd7a:115c:a1e0::2")
+
+	// Create and then delete IPv6 DNAT rule
+	err = runner.EnsureDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target)
+	if err != nil {
+		t.Fatalf("error creating IPv6 DNAT rule: %v", err)
+	}
+
+	// Verify rule exists before deletion
+	table, err = runner.getNFTByAddr(ipv6OrigDst)
+	if err != nil {
+		t.Fatalf("error getting table: %v", err)
+	}
+	nftTable, err = getTableIfExists(runner.conn, table.Proto, "nat")
+	if err != nil {
+		t.Fatalf("error getting nat table: %v", err)
+	}
+	ch, err = getChainFromTable(runner.conn, nftTable, "PREROUTING")
+	if err != nil {
+		t.Fatalf("error getting PREROUTING chain: %v", err)
+	}
+	meta = svcRuleMeta("svc:foo", ipv6OrigDst, ipv6Target)
+	rule, err = runner.findRuleByMetadata(nftTable, ch, meta)
+	if err != nil {
+		t.Fatalf("error checking if rule exists: %v", err)
+	}
+	if rule == nil {
+		t.Fatal("rule does not exist before deletion")
+	}
+
+	err = runner.DeleteDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target)
+	if err != nil {
+		t.Fatalf("error deleting IPv6 DNAT rule: %v", err)
+	}
+
+	// Verify rule is deleted
+	rule, err = runner.findRuleByMetadata(nftTable, ch, meta)
+	if err != nil {
+		t.Fatalf("error checking if rule exists: %v", err)
+	}
+	if rule != nil {
+		t.Fatal("rule still exists after deletion")
+	}
+}
+
+// checkDNATRule verifies that a DNAT rule exists for the given service, original destination, and target IP.
+func checkDNATRule(t *testing.T, svc string, origDst, targetIP netip.Addr, runner *nftablesRunner, fam nftables.TableFamily) {
+	t.Helper()
+	table, err := runner.getNFTByAddr(origDst)
+	if err != nil {
+		t.Fatalf("error getting table: %v", err)
+	}
+	nftTable, err := getTableIfExists(runner.conn, table.Proto, "nat")
+	if err != nil {
+		t.Fatalf("error getting nat table: %v", err)
+	}
+	if nftTable == nil {
+		t.Fatal("nat table not found")
+	}
+
+	ch, err := getChainFromTable(runner.conn, nftTable, "PREROUTING")
+	if err != nil {
+		t.Fatalf("error getting PREROUTING chain: %v", err)
+	}
+	if ch == nil {
+		t.Fatal("PREROUTING chain not found")
+	}
+
+	meta := svcRuleMeta(svc, origDst, targetIP)
+	rule, err := runner.findRuleByMetadata(nftTable, ch, meta)
+	if err != nil {
+		t.Fatalf("error checking if rule exists: %v", err)
+	}
+	if rule == nil {
+		t.Fatal("DNAT rule not found")
+	}
+}
+
 // svcChains verifies that the expected number of chains exist (for either IP
 // family) and that each of them is configured as NAT prerouting chain.
 func svcChains(t *testing.T, wantCount int, conn *nftables.Conn) {

+ 15 - 5
util/linuxfw/nftables_runner.go

@@ -107,6 +107,12 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
 	if err != nil {
 		return err
 	}
+	rule := dnatRuleForChain(nat, preroutingCh, origDst, dst, nil)
+	n.conn.InsertRule(rule)
+	return n.conn.Flush()
+}
+
+func dnatRuleForChain(t *nftables.Table, ch *nftables.Chain, origDst, dst netip.Addr, meta []byte) *nftables.Rule {
 	var daddrOffset, fam, dadderLen uint32
 	if origDst.Is4() {
 		daddrOffset = 16
@@ -117,9 +123,9 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
 		dadderLen = 16
 		fam = unix.NFPROTO_IPV6
 	}
-	dnatRule := &nftables.Rule{
-		Table: nat,
-		Chain: preroutingCh,
+	rule := &nftables.Rule{
+		Table: t,
+		Chain: ch,
 		Exprs: []expr.Any{
 			&expr.Payload{
 				DestRegister: 1,
@@ -143,8 +149,10 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
 			},
 		},
 	}
-	n.conn.InsertRule(dnatRule)
-	return n.conn.Flush()
+	if len(meta) > 0 {
+		rule.UserData = meta
+	}
+	return rule
 }
 
 // DNATWithLoadBalancer currently just forwards all traffic destined for origDst
@@ -555,6 +563,8 @@ type NetfilterRunner interface {
 	EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error
 
 	DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error
+	EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error
+	DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error
 
 	DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm []PortMap) error
 

+ 8 - 0
wgengine/router/router_linux_test.go

@@ -557,6 +557,14 @@ func (n *fakeIPTablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
 	return errors.New("not implemented")
 }
 
+func (n *fakeIPTablesRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
+	return errors.New("not implemented")
+}
+
+func (n *fakeIPTablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
+	return errors.New("not implemented")
+}
+
 func (n *fakeIPTablesRunner) addBase4(tunname string) error {
 	curIPT := n.ipt4
 	newRules := []struct{ chain, rule string }{