Browse the code

wgengine/magicsock: fix missing Conn.hasPeerRelayServers.Store() call (#16792)

This commit also extends the updateRelayServersSet unit tests to cover
onNodeViewsUpdate.

Fixes tailscale/corp#31080

Signed-off-by: Jordan Whited <[email protected]>
Jordan Whited, 7 months ago
parent
commit
4666d4ca2a
2 files changed, with 78 additions and 20 deletions
  1. +4 −5
      wgengine/magicsock/magicsock.go
  2. +74 −15
      wgengine/magicsock/magicsock_test.go

+ 4 - 5
wgengine/magicsock/magicsock.go

@@ -274,11 +274,9 @@ type Conn struct {
 	captureHook syncs.AtomicValue[packet.CaptureCallback]
 
 	// hasPeerRelayServers is whether [relayManager] is configured with at least
-	// one peer relay server via [relayManager.handleRelayServersSet]. It is
-	// only accessed by [Conn.updateRelayServersSet], [endpoint.setDERPHome],
-	// and [endpoint.discoverUDPRelayPathsLocked]. It exists to suppress
-	// calls into [relayManager] leading to wasted work involving channel
-	// operations and goroutine creation.
+	// one peer relay server via [relayManager.handleRelayServersSet]. It exists
+	// to suppress calls into [relayManager] leading to wasted work involving
+	// channel operations and goroutine creation.
 	hasPeerRelayServers atomic.Bool
 
 	// discoPrivate is the private naclbox key used for active
@@ -2998,6 +2996,7 @@ func (c *Conn) onNodeViewsUpdate(update NodeViewsUpdate) {
 	if peersChanged || relayClientChanged {
 		if !relayClientEnabled {
 			c.relayManager.handleRelayServersSet(nil)
+			c.hasPeerRelayServers.Store(false)
 		} else {
 			c.updateRelayServersSet(filt, self, peers)
 		}

+ 74 - 15
wgengine/magicsock/magicsock_test.go

@@ -65,7 +65,6 @@ import (
 	"tailscale.com/types/netmap"
 	"tailscale.com/types/nettype"
 	"tailscale.com/types/ptr"
-	"tailscale.com/types/views"
 	"tailscale.com/util/cibuild"
 	"tailscale.com/util/clientmetric"
 	"tailscale.com/util/eventbus"
@@ -3584,7 +3583,7 @@ func Test_nodeHasCap(t *testing.T) {
 	}
 }
 
-func TestConn_updateRelayServersSet(t *testing.T) {
+func TestConn_onNodeViewsUpdate_updateRelayServersSet(t *testing.T) {
 	peerNodeCandidateRelay := &tailcfg.Node{
 		Cap: 121,
 		ID:  1,
@@ -3618,12 +3617,21 @@ func TestConn_updateRelayServersSet(t *testing.T) {
 		DiscoKey: key.NewDisco().Public(),
 	}
 
+	selfNodeNodeAttrDisableRelayClient := selfNode.Clone()
+	selfNodeNodeAttrDisableRelayClient.CapMap = make(tailcfg.NodeCapMap)
+	selfNodeNodeAttrDisableRelayClient.CapMap[tailcfg.NodeAttrDisableRelayClient] = nil
+
+	selfNodeNodeAttrOnlyTCP443 := selfNode.Clone()
+	selfNodeNodeAttrOnlyTCP443.CapMap = make(tailcfg.NodeCapMap)
+	selfNodeNodeAttrOnlyTCP443.CapMap[tailcfg.NodeAttrOnlyTCP443] = nil
+
 	tests := []struct {
-		name             string
-		filt             *filter.Filter
-		self             tailcfg.NodeView
-		peers            views.Slice[tailcfg.NodeView]
-		wantRelayServers set.Set[candidatePeerRelay]
+		name                   string
+		filt                   *filter.Filter
+		self                   tailcfg.NodeView
+		peers                  []tailcfg.NodeView
+		wantRelayServers       set.Set[candidatePeerRelay]
+		wantRelayClientEnabled bool
 	}{
 		{
 			name: "candidate relay server",
@@ -3639,7 +3647,7 @@ func TestConn_updateRelayServersSet(t *testing.T) {
 				},
 			}, nil, nil, nil, nil, nil),
 			self:  selfNode.View(),
-			peers: views.SliceOf([]tailcfg.NodeView{peerNodeCandidateRelay.View()}),
+			peers: []tailcfg.NodeView{peerNodeCandidateRelay.View()},
 			wantRelayServers: set.SetOf([]candidatePeerRelay{
 				{
 					nodeKey:          peerNodeCandidateRelay.Key,
@@ -3647,6 +3655,43 @@ func TestConn_updateRelayServersSet(t *testing.T) {
 					derpHomeRegionID: 1,
 				},
 			}),
+			wantRelayClientEnabled: true,
+		},
+		{
+			name: "no candidate relay server because self has tailcfg.NodeAttrDisableRelayClient",
+			filt: filter.New([]filtertype.Match{
+				{
+					Srcs: peerNodeCandidateRelay.Addresses,
+					Caps: []filtertype.CapMatch{
+						{
+							Dst: selfNodeNodeAttrDisableRelayClient.Addresses[0],
+							Cap: tailcfg.PeerCapabilityRelayTarget,
+						},
+					},
+				},
+			}, nil, nil, nil, nil, nil),
+			self:                   selfNodeNodeAttrDisableRelayClient.View(),
+			peers:                  []tailcfg.NodeView{peerNodeCandidateRelay.View()},
+			wantRelayServers:       make(set.Set[candidatePeerRelay]),
+			wantRelayClientEnabled: false,
+		},
+		{
+			name: "no candidate relay server because self has tailcfg.NodeAttrOnlyTCP443",
+			filt: filter.New([]filtertype.Match{
+				{
+					Srcs: peerNodeCandidateRelay.Addresses,
+					Caps: []filtertype.CapMatch{
+						{
+							Dst: selfNodeNodeAttrOnlyTCP443.Addresses[0],
+							Cap: tailcfg.PeerCapabilityRelayTarget,
+						},
+					},
+				},
+			}, nil, nil, nil, nil, nil),
+			self:                   selfNodeNodeAttrOnlyTCP443.View(),
+			peers:                  []tailcfg.NodeView{peerNodeCandidateRelay.View()},
+			wantRelayServers:       make(set.Set[candidatePeerRelay]),
+			wantRelayClientEnabled: false,
 		},
 		{
 			name: "self candidate relay server",
@@ -3662,7 +3707,7 @@ func TestConn_updateRelayServersSet(t *testing.T) {
 				},
 			}, nil, nil, nil, nil, nil),
 			self:  selfNode.View(),
-			peers: views.SliceOf([]tailcfg.NodeView{selfNode.View()}),
+			peers: []tailcfg.NodeView{selfNode.View()},
 			wantRelayServers: set.SetOf([]candidatePeerRelay{
 				{
 					nodeKey:          selfNode.Key,
@@ -3670,6 +3715,7 @@ func TestConn_updateRelayServersSet(t *testing.T) {
 					derpHomeRegionID: 2,
 				},
 			}),
+			wantRelayClientEnabled: true,
 		},
 		{
 			name: "no candidate relay server",
@@ -3684,21 +3730,34 @@ func TestConn_updateRelayServersSet(t *testing.T) {
 					},
 				},
 			}, nil, nil, nil, nil, nil),
-			self:             selfNode.View(),
-			peers:            views.SliceOf([]tailcfg.NodeView{peerNodeNotCandidateRelayCapVer.View()}),
-			wantRelayServers: make(set.Set[candidatePeerRelay]),
+			self:                   selfNode.View(),
+			peers:                  []tailcfg.NodeView{peerNodeNotCandidateRelayCapVer.View()},
+			wantRelayServers:       make(set.Set[candidatePeerRelay]),
+			wantRelayClientEnabled: true,
 		},
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			c := &Conn{}
-			c.updateRelayServersSet(tt.filt, tt.self, tt.peers)
+			c := newConn(t.Logf)
+			c.filt = tt.filt
+			if len(tt.wantRelayServers) == 0 {
+				// So we can verify it gets flipped back.
+				c.hasPeerRelayServers.Store(true)
+			}
+
+			c.onNodeViewsUpdate(NodeViewsUpdate{
+				SelfNode: tt.self,
+				Peers:    tt.peers,
+			})
 			got := c.relayManager.getServers()
 			if !got.Equal(tt.wantRelayServers) {
 				t.Fatalf("got: %v != want: %v", got, tt.wantRelayServers)
 			}
 			if len(tt.wantRelayServers) > 0 != c.hasPeerRelayServers.Load() {
-				t.Fatalf("c.hasPeerRelayServers: %v != wantRelayServers: %v", c.hasPeerRelayServers.Load(), tt.wantRelayServers)
+				t.Fatalf("c.hasPeerRelayServers: %v != len(tt.wantRelayServers) > 0: %v", c.hasPeerRelayServers.Load(), len(tt.wantRelayServers) > 0)
+			}
+			if c.relayClientEnabled != tt.wantRelayClientEnabled {
+				t.Fatalf("c.relayClientEnabled: %v != wantRelayClientEnabled: %v", c.relayClientEnabled, tt.wantRelayClientEnabled)
 			}
 		})
 	}