Browse Source

tstest/integration: add start of integration tests for incremental map updates

This adds a new integration test with two nodes where the first gets an
incremental MapResponse (with only PeersRemoved set) saying that the
second node disappeared.

This extends the testcontrol package to support sending raw
MapResponses to nodes.

Updates #1909

Change-Id: Iea0c25c19cf0d72b52dba5a46d01b5cc87b9b39d
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 2 years ago
parent
commit
39ade4d0d4
2 changed files with 196 additions and 32 deletions
  1. 86 0
      tstest/integration/integration_test.go
  2. 110 32
      tstest/integration/testcontrol/testcontrol.go

+ 86 - 0
tstest/integration/integration_test.go

@@ -328,6 +328,92 @@ func TestTwoNodes(t *testing.T) {
 	d2.MustCleanShutdown(t)
 }
 
+// tests two nodes where the first gets an incremental MapResponse (with only
+// PeersRemoved set) saying that the second node disappeared.
+func TestIncrementalMapUpdatePeersRemoved(t *testing.T) {
+	flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/3598")
+	t.Parallel()
+	env := newTestEnv(t)
+
+	// Create one node:
+	n1 := newTestNode(t, env)
+	d1 := n1.StartDaemon()
+	n1.AwaitListening()
+	n1.MustUp()
+	n1.AwaitRunning()
+
+	all := env.Control.AllNodes()
+	if len(all) != 1 {
+		t.Fatalf("expected 1 node, got %d nodes", len(all))
+	}
+	tnode1 := all[0]
+
+	n2 := newTestNode(t, env)
+	d2 := n2.StartDaemon()
+	n2.AwaitListening()
+	n2.MustUp()
+	n2.AwaitRunning()
+
+	all = env.Control.AllNodes()
+	if len(all) != 2 {
+		t.Fatalf("expected 2 node, got %d nodes", len(all))
+	}
+	var tnode2 *tailcfg.Node
+	for _, n := range all {
+		if n.ID != tnode1.ID {
+			tnode2 = n
+			break
+		}
+	}
+	if tnode2 == nil {
+		t.Fatalf("failed to find second node ID (two dups?)")
+	}
+
+	t.Logf("node1=%v, node2=%v", tnode1.ID, tnode2.ID)
+
+	if err := tstest.WaitFor(2*time.Second, func() error {
+		st := n1.MustStatus()
+		if len(st.Peer) == 0 {
+			return errors.New("no peers")
+		}
+		if len(st.Peer) > 1 {
+			return fmt.Errorf("got %d peers; want 1", len(st.Peer))
+		}
+		peer := st.Peer[st.Peers()[0]]
+		if peer.ID == st.Self.ID {
+			return errors.New("peer is self")
+		}
+		return nil
+	}); err != nil {
+		t.Fatal(err)
+	}
+
+	t.Logf("node1 saw node2")
+
+	// Now tell node1 that node2 is removed.
+	if !env.Control.AddRawMapResponse(tnode1.Key, &tailcfg.MapResponse{
+		PeersRemoved: []tailcfg.NodeID{tnode2.ID},
+	}) {
+		t.Fatalf("failed to add map response")
+	}
+
+	// And see that node1 saw that.
+	if err := tstest.WaitFor(2*time.Second, func() error {
+		st := n1.MustStatus()
+		if len(st.Peer) == 0 {
+			return nil
+		}
+		return fmt.Errorf("got %d peers; want 0", len(st.Peer))
+	}); err != nil {
+		t.Fatal(err)
+	}
+
+	t.Logf("node1 saw node2 disappear")
+
+	d1.MustCleanShutdown(t)
+	d2.MustCleanShutdown(t)
+}
+
 func TestNodeAddressIPFields(t *testing.T) {
 	flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/7008")
 	t.Parallel()

+ 110 - 32
tstest/integration/testcontrol/testcontrol.go

@@ -34,6 +34,7 @@ import (
 	"tailscale.com/types/logger"
 	"tailscale.com/types/ptr"
 	"tailscale.com/util/rands"
+	"tailscale.com/util/set"
 )
 
 const msgLimit = 1 << 20 // encrypted message length limit
@@ -67,6 +68,10 @@ type Server struct {
 	// masquerade address to use for that peer.
 	masquerades map[key.NodePublic]map[key.NodePublic]netip.Addr // node => peer => SelfNodeV4MasqAddrForThisPeer IP
 
+	// suppressAutoMapResponses is the set of nodes that should not be sent
+	// automatic map responses from serveMap. (They should only get manually sent ones)
+	suppressAutoMapResponses set.Set[key.NodePublic]
+
 	noisePubKey  key.MachinePublic
 	noisePrivKey key.ControlPrivate // not strictly needed vs. MachinePrivate, but handy to test type interactions.
 
@@ -76,8 +81,8 @@ type Server struct {
 	updates       map[tailcfg.NodeID]chan updateType
 	authPath      map[string]*AuthPath
 	nodeKeyAuthed map[key.NodePublic]bool // key => true once authenticated
-	pingReqsToAdd map[key.NodePublic]*tailcfg.PingRequest
-	allExpired    bool // All nodes will be told their node key is expired.
+	msgToSend     map[key.NodePublic]any  // value is *tailcfg.PingRequest or entire *tailcfg.MapResponse
+	allExpired    bool                    // All nodes will be told their node key is expired.
 }
 
 // BaseURL returns the server's base URL, without trailing slash.
@@ -146,13 +151,32 @@ func (s *Server) AwaitNodeInMapRequest(ctx context.Context, k key.NodePublic) er
 	}
 }
 
-// AddPingRequest sends the ping pr to nodeKeyDst. It reports whether it did so. That is,
-// it reports whether nodeKeyDst was connected.
+// AddPingRequest sends the ping pr to nodeKeyDst.
+//
+// It reports whether the message was enqueued. That is, it reports whether
+// nodeKeyDst was connected.
 func (s *Server) AddPingRequest(nodeKeyDst key.NodePublic, pr *tailcfg.PingRequest) bool {
+	return s.addDebugMessage(nodeKeyDst, pr)
+}
+
+// AddRawMapResponse delivers the raw MapResponse mr to nodeKeyDst. It's meant
+// for testing incremental map updates.
+//
+// Once a raw MapResponse has been enqueued for a node via AddRawMapResponse,
+// all future automatic MapResponses to that node will be suppressed and only
+// explicit MapResponses injected via AddRawMapResponse will be sent.
+//
+// It reports whether the message was enqueued. That is, it reports whether
+// nodeKeyDst was connected.
+func (s *Server) AddRawMapResponse(nodeKeyDst key.NodePublic, mr *tailcfg.MapResponse) bool {
+	return s.addDebugMessage(nodeKeyDst, mr)
+}
+
+func (s *Server) addDebugMessage(nodeKeyDst key.NodePublic, msg any) bool {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	if s.pingReqsToAdd == nil {
-		s.pingReqsToAdd = map[key.NodePublic]*tailcfg.PingRequest{}
+	if s.msgToSend == nil {
+		s.msgToSend = map[key.NodePublic]any{}
 	}
 	// Now send the update to the channel
 	node := s.nodeLocked(nodeKeyDst)
@@ -160,7 +184,14 @@ func (s *Server) AddPingRequest(nodeKeyDst key.NodePublic, pr *tailcfg.PingReque
 		return false
 	}
 
-	s.pingReqsToAdd[nodeKeyDst] = pr
+	if _, ok := msg.(*tailcfg.MapResponse); ok {
+		if s.suppressAutoMapResponses == nil {
+			s.suppressAutoMapResponses = set.Set[key.NodePublic]{}
+		}
+		s.suppressAutoMapResponses.Add(nodeKeyDst)
+	}
+
+	s.msgToSend[nodeKeyDst] = msg
 	nodeID := node.ID
 	oldUpdatesCh := s.updates[nodeID]
 	return sendUpdate(oldUpdatesCh, updateDebugInjection)
@@ -602,6 +633,7 @@ const (
 	updateSelfChanged
 
 	// updateDebugInjection is an update used for PingRequests
+	// or a raw MapResponse.
 	updateDebugInjection
 )
 
@@ -725,33 +757,49 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi
 
 	w.WriteHeader(200)
 	for {
-		res, err := s.MapResponse(req)
-		if err != nil {
-			// TODO: log
+		if resBytes, ok := s.takeRawMapMessage(req.NodeKey); ok {
+			if err := s.sendMapMsg(w, mkey, compress, resBytes); err != nil {
+				s.logf("sendMapMsg of raw message: %v", err)
+				return
+			}
+			if streaming {
+				continue
+			}
 			return
 		}
-		if res == nil {
-			return // done
-		}
 
-		s.mu.Lock()
-		allExpired := s.allExpired
-		s.mu.Unlock()
-		if allExpired {
-			res.Node.KeyExpiry = time.Now().Add(-1 * time.Minute)
-		}
-		// TODO: add minner if/when needed
-		resBytes, err := json.Marshal(res)
-		if err != nil {
-			s.logf("json.Marshal: %v", err)
-			return
-		}
-		if err := s.sendMapMsg(w, mkey, compress, resBytes); err != nil {
-			return
+		if s.canGenerateAutomaticMapResponseFor(req.NodeKey) {
+			res, err := s.MapResponse(req)
+			if err != nil {
+				// TODO: log
+				return
+			}
+			if res == nil {
+				return // done
+			}
+
+			s.mu.Lock()
+			allExpired := s.allExpired
+			s.mu.Unlock()
+			if allExpired {
+				res.Node.KeyExpiry = time.Now().Add(-1 * time.Minute)
+			}
+			// TODO: add minner if/when needed
+			resBytes, err := json.Marshal(res)
+			if err != nil {
+				s.logf("json.Marshal: %v", err)
+				return
+			}
+			if err := s.sendMapMsg(w, mkey, compress, resBytes); err != nil {
+				return
+			}
 		}
 		if !streaming {
 			return
 		}
+		if s.hasPendingRawMapMessage(req.NodeKey) {
+			continue
+		}
 	keepAliveLoop:
 		for {
 			var keepAliveTimer *time.Timer
@@ -874,16 +922,46 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse,
 	}
 	res.Node.AllowedIPs = res.Node.Addresses
 
-	// Consume the PingRequest while protected by mutex if it exists
+	// Consume a PingRequest while protected by mutex if it exists
 	s.mu.Lock()
-	if pr, ok := s.pingReqsToAdd[nk]; ok {
-		res.PingRequest = pr
-		delete(s.pingReqsToAdd, nk)
+	defer s.mu.Unlock()
+	switch m := s.msgToSend[nk].(type) {
+	case *tailcfg.PingRequest:
+		res.PingRequest = m
+		delete(s.msgToSend, nk)
 	}
-	s.mu.Unlock()
 	return res, nil
 }
 
+func (s *Server) canGenerateAutomaticMapResponseFor(nk key.NodePublic) bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return !s.suppressAutoMapResponses.Contains(nk)
+}
+
+func (s *Server) hasPendingRawMapMessage(nk key.NodePublic) bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	_, ok := s.msgToSend[nk].(*tailcfg.MapResponse)
+	return ok
+}
+
+func (s *Server) takeRawMapMessage(nk key.NodePublic) (mapResJSON []byte, ok bool) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	mr, ok := s.msgToSend[nk].(*tailcfg.MapResponse)
+	if !ok {
+		return nil, false
+	}
+	delete(s.msgToSend, nk)
+	var err error
+	mapResJSON, err = json.Marshal(mr)
+	if err != nil {
+		panic(err)
+	}
+	return mapResJSON, true
+}
+
 func (s *Server) sendMapMsg(w http.ResponseWriter, mkey key.MachinePublic, compress bool, msg any) error {
 	resBytes, err := s.encode(mkey, compress, msg)
 	if err != nil {