Procházet zdrojové kódy

control,tailcfg,wgengine/magicsock: add nodeAttr to enable/disable peer MTU

Add a nodeAttr to enable/disable peer path MTU discovery.

Updates #311

Signed-off-by: Val <[email protected]>
Val před 2 roky
rodič
revize
65dc711c76

+ 6 - 0
control/controlknobs/controlknobs.go

@@ -45,6 +45,9 @@ type Knobs struct {
 	// incremental (delta) netmap updates and should treat all netmap
 	// changes as "full" ones as tailscaled did in 1.48.x and earlier.
 	DisableDeltaUpdates atomic.Bool
+
+	// PeerMTUEnable is whether the node should do peer path MTU discovery.
+	PeerMTUEnable atomic.Bool
 }
 
 // UpdateFromNodeAttributes updates k (if non-nil) based on the provided self
@@ -65,6 +68,7 @@ func (k *Knobs) UpdateFromNodeAttributes(selfNodeAttrs []tailcfg.NodeCapability,
 		disableDeltaUpdates = has(tailcfg.NodeAttrDisableDeltaUpdates)
 		oneCGNAT            opt.Bool
 		forceBackgroundSTUN = has(tailcfg.NodeAttrDebugForceBackgroundSTUN)
+		peerMTUEnable       = has(tailcfg.NodeAttrPeerMTUEnable)
 	)
 
 	if has(tailcfg.NodeAttrOneCGNATEnable) {
@@ -80,6 +84,7 @@ func (k *Knobs) UpdateFromNodeAttributes(selfNodeAttrs []tailcfg.NodeCapability,
 	k.OneCGNAT.Store(oneCGNAT)
 	k.ForceBackgroundSTUN.Store(forceBackgroundSTUN)
 	k.DisableDeltaUpdates.Store(disableDeltaUpdates)
+	k.PeerMTUEnable.Store(peerMTUEnable)
 }
 
 // AsDebugJSON returns k as something that can be marshalled with json.Marshal
@@ -96,5 +101,6 @@ func (k *Knobs) AsDebugJSON() map[string]any {
 		"OneCGNAT":            k.OneCGNAT.Load(),
 		"ForceBackgroundSTUN": k.ForceBackgroundSTUN.Load(),
 		"DisableDeltaUpdates": k.DisableDeltaUpdates.Load(),
+		"PeerMTUEnable":       k.PeerMTUEnable.Load(),
 	}
 }

+ 4 - 0
tailcfg/tailcfg.go

@@ -2133,6 +2133,10 @@ const (
 	// rather than one big /10 CGNAT route. At most one of this or
 	// NodeAttrOneCGNATEnable may be set; if neither are, it's automatic.
 	NodeAttrOneCGNATDisable NodeCapability = "one-cgnat?v=false"
+
+	// NodeAttrPeerMTUEnable makes the client do path MTU discovery to its
+	// peers. If it isn't set, it defaults to the client default.
+	NodeAttrPeerMTUEnable NodeCapability = "peer-mtu-enable"
 )
 
 // SetDNSRequest is a request to add a DNS record.

+ 8 - 0
wgengine/magicsock/peermtu.go

@@ -34,6 +34,14 @@ func (c *Conn) ShouldPMTUD() bool {
 		}
 		return v
 	}
+	if c.controlKnobs != nil {
+		if v := c.controlKnobs.PeerMTUEnable.Load(); v {
+			if debugPMTUD() {
+				c.logf("magicsock: peermtu: peer path MTU discovery enabled by control")
+			}
+			return v
+		}
+	}
 	if debugPMTUD() {
 		c.logf("magicsock: peermtu: peer path MTU discovery set by default to false")
 	}

+ 84 - 0
wgengine/userspace_test.go

@@ -6,12 +6,15 @@ package wgengine
 import (
 	"fmt"
 	"net/netip"
+	"os"
 	"reflect"
+	"runtime"
 	"testing"
 
 	"go4.org/mem"
 	"tailscale.com/cmd/testwrapper/flakytest"
 	"tailscale.com/control/controlknobs"
+	"tailscale.com/envknob"
 	"tailscale.com/net/dns"
 	"tailscale.com/net/netaddr"
 	"tailscale.com/net/tstun"
@@ -20,6 +23,7 @@ import (
 	"tailscale.com/tstime/mono"
 	"tailscale.com/types/key"
 	"tailscale.com/types/netmap"
+	"tailscale.com/types/opt"
 	"tailscale.com/wgengine/router"
 	"tailscale.com/wgengine/wgcfg"
 )
@@ -227,6 +231,86 @@ func TestUserspaceEnginePortReconfig(t *testing.T) {
 	}
 }
 
+// Test that enabling and disabling peer path MTU discovery works correctly.
+func TestUserspaceEnginePeerMTUReconfig(t *testing.T) {
+	if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
+		t.Skipf("skipping on %q; peer MTU not supported", runtime.GOOS)
+	}
+
+	defer os.Setenv("TS_DEBUG_ENABLE_PMTUD", os.Getenv("TS_DEBUG_ENABLE_PMTUD"))
+	envknob.Setenv("TS_DEBUG_ENABLE_PMTUD", "")
+	// Turn on debugging to help diagnose problems.
+	defer os.Setenv("TS_DEBUG_PMTUD", os.Getenv("TS_DEBUG_PMTUD"))
+	envknob.Setenv("TS_DEBUG_PMTUD", "true")
+
+	var knobs controlknobs.Knobs
+
+	e, err := NewFakeUserspaceEngine(t.Logf, 0, &knobs)
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Cleanup(e.Close)
+	ue := e.(*userspaceEngine)
+
+	if ue.magicConn.PeerMTUEnabled() != false {
+		t.Error("peer MTU enabled by default, should not be")
+	}
+	osDefaultDF, err := ue.magicConn.DontFragSetting()
+	if err != nil {
+		t.Errorf("get don't fragment bit failed: %v", err)
+	}
+	t.Logf("Info: OS default don't fragment bit(s) setting: %v", osDefaultDF)
+
+	// Build a set of configs to use as we change the peer MTU settings.
+	nodeKey, err := key.ParseNodePublicUntyped(mem.S("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	cfg := &wgcfg.Config{
+		Peers: []wgcfg.Peer{
+			{
+				PublicKey: nodeKey,
+				AllowedIPs: []netip.Prefix{
+					netip.PrefixFrom(netaddr.IPv4(100, 100, 99, 1), 32),
+				},
+			},
+		},
+	}
+	routerCfg := &router.Config{}
+
+	tests := []struct {
+		desc    string   // test description
+		wantP   bool     // desired value of PMTUD setting
+		wantDF  bool     // desired value of don't fragment bits
+		shouldP opt.Bool // if set, force peer MTU to this value
+	}{
+		{desc: "after_first_reconfig", wantP: false, wantDF: osDefaultDF, shouldP: ""},
+		{desc: "enabling_PMTUD_first_time", wantP: true, wantDF: true, shouldP: "true"},
+		{desc: "disabling_PMTUD", wantP: false, wantDF: false, shouldP: "false"},
+		{desc: "enabling_PMTUD_second_time", wantP: true, wantDF: true, shouldP: "true"},
+		{desc: "returning_to_default_PMTUD", wantP: false, wantDF: false, shouldP: ""},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.desc, func(t *testing.T) {
+			if v, ok := tt.shouldP.Get(); ok {
+				knobs.PeerMTUEnable.Store(v)
+			} else {
+				knobs.PeerMTUEnable.Store(false)
+			}
+			if err := ue.Reconfig(cfg, routerCfg, &dns.Config{}); err != nil {
+				t.Fatal(err)
+			}
+			if v := ue.magicConn.PeerMTUEnabled(); v != tt.wantP {
+				t.Errorf("peer MTU set to %v, want %v", v, tt.wantP)
+			}
+			if v, err := ue.magicConn.DontFragSetting(); v != tt.wantDF || err != nil {
+				t.Errorf("don't fragment bit set to %v, want %v, err %v", v, tt.wantP, err)
+			}
+		})
+	}
+}
+
 func nkFromHex(hex string) key.NodePublic {
 	if len(hex) != 64 {
 		panic(fmt.Sprintf("%q is len %d; want 64", hex, len(hex)))