Browse Source

wgengine/router/osrouter: fix data race in magicsock port update callback

As found by @cmol in #17423.

Updates #17423

Change-Id: I1492501f74ca7b57a8c5278ea6cb87a56a4086b9
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 5 months ago
parent
commit
141eb64d3f
1 changed files with 19 additions and 17 deletions
  1. 19 17
      wgengine/router/osrouter/router_linux.go

+ 19 - 17
wgengine/router/osrouter/router_linux.go

@@ -86,8 +86,8 @@ type linuxRouter struct {
 	cmd commandRunner
 	nfr linuxfw.NetfilterRunner
 
-	magicsockPortV4 uint16
-	magicsockPortV6 uint16
+	magicsockPortV4 atomic.Uint32 // actually a uint16
+	magicsockPortV6 atomic.Uint32 // actually a uint16
 }
 
 func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Monitor, health *health.Tracker, bus *eventbus.Bus) (router.Router, error) {
@@ -546,7 +546,7 @@ func (r *linuxRouter) UpdateMagicsockPort(port uint16, network string) error {
 		}
 	}
 
-	var magicsockPort *uint16
+	var magicsockPort *atomic.Uint32
 	switch network {
 	case "udp4":
 		magicsockPort = &r.magicsockPortV4
@@ -566,27 +566,29 @@ func (r *linuxRouter) UpdateMagicsockPort(port uint16, network string) error {
 
 	// set the port, we'll make the firewall rule when netfilter turns back on
 	if r.netfilterMode == netfilterOff {
-		*magicsockPort = port
+		magicsockPort.Store(uint32(port))
 		return nil
 	}
 
-	if *magicsockPort == port {
+	cur := magicsockPort.Load()
+
+	if cur == uint32(port) {
 		return nil
 	}
 
-	if *magicsockPort != 0 {
-		if err := r.nfr.DelMagicsockPortRule(*magicsockPort, network); err != nil {
+	if cur != 0 {
+		if err := r.nfr.DelMagicsockPortRule(uint16(cur), network); err != nil {
 			return fmt.Errorf("del magicsock port rule: %w", err)
 		}
 	}
 
 	if port != 0 {
-		if err := r.nfr.AddMagicsockPortRule(*magicsockPort, network); err != nil {
+		if err := r.nfr.AddMagicsockPortRule(uint16(port), network); err != nil {
 			return fmt.Errorf("add magicsock port rule: %w", err)
 		}
 	}
 
-	*magicsockPort = port
+	magicsockPort.Store(uint32(port))
 	return nil
 }
 
@@ -658,13 +660,13 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error {
 			if err := r.nfr.AddBase(r.tunname); err != nil {
 				return err
 			}
-			if r.magicsockPortV4 != 0 {
-				if err := r.nfr.AddMagicsockPortRule(r.magicsockPortV4, "udp4"); err != nil {
+			if mport := uint16(r.magicsockPortV4.Load()); mport != 0 {
+				if err := r.nfr.AddMagicsockPortRule(mport, "udp4"); err != nil {
 					return fmt.Errorf("could not add magicsock port rule v4: %w", err)
 				}
 			}
-			if r.magicsockPortV6 != 0 && r.getV6FilteringAvailable() {
-				if err := r.nfr.AddMagicsockPortRule(r.magicsockPortV6, "udp6"); err != nil {
+			if mport := uint16(r.magicsockPortV6.Load()); mport != 0 && r.getV6FilteringAvailable() {
+				if err := r.nfr.AddMagicsockPortRule(mport, "udp6"); err != nil {
 					return fmt.Errorf("could not add magicsock port rule v6: %w", err)
 				}
 			}
@@ -698,13 +700,13 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error {
 			if err := r.nfr.AddBase(r.tunname); err != nil {
 				return err
 			}
-			if r.magicsockPortV4 != 0 {
-				if err := r.nfr.AddMagicsockPortRule(r.magicsockPortV4, "udp4"); err != nil {
+			if mport := uint16(r.magicsockPortV4.Load()); mport != 0 {
+				if err := r.nfr.AddMagicsockPortRule(mport, "udp4"); err != nil {
 					return fmt.Errorf("could not add magicsock port rule v4: %w", err)
 				}
 			}
-			if r.magicsockPortV6 != 0 && r.getV6FilteringAvailable() {
-				if err := r.nfr.AddMagicsockPortRule(r.magicsockPortV6, "udp6"); err != nil {
+			if mport := uint16(r.magicsockPortV6.Load()); mport != 0 && r.getV6FilteringAvailable() {
+				if err := r.nfr.AddMagicsockPortRule(mport, "udp6"); err != nil {
 					return fmt.Errorf("could not add magicsock port rule v6: %w", err)
 				}
 			}