Ver código fonte

derp: include src IPs in mesh watch messages

Updates tailscale/corp#13945

Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 2 anos atrás
pai
commit
6c791f7d60
6 arquivos alterados com 73 adições e 38 exclusões
  1. 2 1
      cmd/derper/mesh.go
  2. 1 1
      derp/derp.go
  3. 15 3
      derp/derp_client.go
  4. 45 25
      derp/derp_server.go
  5. 4 3
      derp/derp_test.go
  6. 6 5
      derp/derphttp/mesh_client.go

+ 2 - 1
cmd/derper/mesh.go

@@ -9,6 +9,7 @@ import (
 	"fmt"
 	"log"
 	"net"
+	"net/netip"
 	"strings"
 	"time"
 
@@ -67,7 +68,7 @@ func startMeshWithHost(s *derp.Server, host string) error {
 		return d.DialContext(ctx, network, addr)
 	})
 
-	add := func(k key.NodePublic) { s.AddPacketForwarder(k, c) }
+	add := func(k key.NodePublic, _ netip.AddrPort) { s.AddPacketForwarder(k, c) }
 	remove := func(k key.NodePublic) { s.RemovePacketForwarder(k, c) }
 	go c.RunWatchConnectionLoop(context.Background(), s.PublicKey(), logf, add, remove)
 	return nil

+ 1 - 1
derp/derp.go

@@ -85,7 +85,7 @@ const (
 
 	// framePeerPresent is like framePeerGone, but for other
 	// members of the DERP region when they're meshed up together.
-	framePeerPresent = frameType(0x09) // 32B pub key of peer that's connected
+	framePeerPresent = frameType(0x09) // 32B pub key of peer that's connected + optional 18B ip:port (16 byte IP + 2 byte BE uint16 port)
 
 	// frameWatchConns is how one DERP node in a regional mesh
 	// subscribes to the others in the region.

+ 15 - 3
derp/derp_client.go

@@ -363,7 +363,12 @@ func (PeerGoneMessage) msg() {}
 
 // PeerPresentMessage is a ReceivedMessage that indicates that the client
 // is connected to the server. (Only used by trusted mesh clients)
-type PeerPresentMessage key.NodePublic
+type PeerPresentMessage struct {
+	// Key is the public key of the client.
+	Key key.NodePublic
+	// IPPort is the remote IP and port of the client.
+	IPPort netip.AddrPort
+}
 
 func (PeerPresentMessage) msg() {}
 
@@ -546,8 +551,15 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro
 				c.logf("[unexpected] dropping short peerPresent frame from DERP server")
 				continue
 			}
-			pg := PeerPresentMessage(key.NodePublicFromRaw32(mem.B(b[:keyLen])))
-			return pg, nil
+			var msg PeerPresentMessage
+			msg.Key = key.NodePublicFromRaw32(mem.B(b[:keyLen]))
+			if n >= keyLen+16+2 {
+				msg.IPPort = netip.AddrPortFrom(
+					netip.AddrFrom16([16]byte(b[keyLen:keyLen+16])).Unmap(),
+					binary.BigEndian.Uint16(b[keyLen+16:keyLen+16+2]),
+				)
+			}
+			return msg, nil
 
 		case frameRecvPacket:
 			var rp ReceivedPacket

+ 45 - 25
derp/derp_server.go

@@ -12,6 +12,7 @@ import (
 	crand "crypto/rand"
 	"crypto/x509"
 	"crypto/x509/pkix"
+	"encoding/binary"
 	"encoding/json"
 	"errors"
 	"expvar"
@@ -43,6 +44,7 @@ import (
 	"tailscale.com/tstime/rate"
 	"tailscale.com/types/key"
 	"tailscale.com/types/logger"
+	"tailscale.com/util/set"
 	"tailscale.com/version"
 )
 
@@ -150,7 +152,7 @@ type Server struct {
 	closed   bool
 	netConns map[Conn]chan struct{} // chan is closed when conn closes
 	clients  map[key.NodePublic]clientSet
-	watchers map[*sclient]bool // mesh peer -> true
+	watchers set.Set[*sclient] // mesh peers
 	// clientsMesh tracks all clients in the cluster, both locally
 	// and to mesh peers.  If the value is nil, that means the
 	// peer is only local (and thus in the clients Map, but not
@@ -219,8 +221,7 @@ func (s singleClient) ForeachClient(f func(*sclient)) { f(s.c) }
 // All fields are guarded by Server.mu.
 type dupClientSet struct {
 	// set is the set of connected clients for sclient.key.
-	// The values are all true.
-	set map[*sclient]bool
+	set set.Set[*sclient]
 
 	// last is the most recent addition to set, or nil if the most
 	// recent one has since disconnected and nobody else has send
@@ -261,7 +262,7 @@ func (s *dupClientSet) removeClient(c *sclient) bool {
 
 	trim := s.sendHistory[:0]
 	for _, v := range s.sendHistory {
-		if s.set[v] && (len(trim) == 0 || trim[len(trim)-1] != v) {
+		if s.set.Contains(v) && (len(trim) == 0 || trim[len(trim)-1] != v) {
 			trim = append(trim, v)
 		}
 	}
@@ -316,7 +317,7 @@ func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server {
 		clientsMesh:          map[key.NodePublic]PacketForwarder{},
 		netConns:             map[Conn]chan struct{}{},
 		memSys0:              ms.Sys,
-		watchers:             map[*sclient]bool{},
+		watchers:             set.Set[*sclient]{},
 		sentTo:               map[key.NodePublic]map[key.NodePublic]int64{},
 		avgQueueDuration:     new(uint64),
 		tcpRtt:               metrics.LabelMap{Label: "le"},
@@ -498,8 +499,8 @@ func (s *Server) registerClient(c *sclient) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	set := s.clients[c.key]
-	switch set := set.(type) {
+	curSet := s.clients[c.key]
+	switch curSet := curSet.(type) {
 	case nil:
 		s.clients[c.key] = singleClient{c}
 		c.debugLogf("register single client")
@@ -507,14 +508,14 @@ func (s *Server) registerClient(c *sclient) {
 		s.dupClientKeys.Add(1)
 		s.dupClientConns.Add(2) // both old and new count
 		s.dupClientConnTotal.Add(1)
-		old := set.ActiveClient()
+		old := curSet.ActiveClient()
 		old.isDup.Store(true)
 		c.isDup.Store(true)
 		s.clients[c.key] = &dupClientSet{
 			last: c,
-			set: map[*sclient]bool{
-				old: true,
-				c:   true,
+			set: set.Set[*sclient]{
+				old: struct{}{},
+				c:   struct{}{},
 			},
 			sendHistory: []*sclient{old},
 		}
@@ -523,9 +524,9 @@ func (s *Server) registerClient(c *sclient) {
 		s.dupClientConns.Add(1)     // the gauge
 		s.dupClientConnTotal.Add(1) // the counter
 		c.isDup.Store(true)
-		set.set[c] = true
-		set.last = c
-		set.sendHistory = append(set.sendHistory, c)
+		curSet.set.Add(c)
+		curSet.last = c
+		curSet.sendHistory = append(curSet.sendHistory, c)
 		c.debugLogf("register another duplicate client")
 	}
 
@@ -534,7 +535,7 @@ func (s *Server) registerClient(c *sclient) {
 	}
 	s.keyOfAddr[c.remoteIPPort] = c.key
 	s.curClients.Add(1)
-	s.broadcastPeerStateChangeLocked(c.key, true)
+	s.broadcastPeerStateChangeLocked(c.key, c.remoteIPPort, true)
 }
 
 // broadcastPeerStateChangeLocked enqueues a message to all watchers
@@ -542,9 +543,13 @@ func (s *Server) registerClient(c *sclient) {
 // presence changed.
 //
 // s.mu must be held.
-func (s *Server) broadcastPeerStateChangeLocked(peer key.NodePublic, present bool) {
+func (s *Server) broadcastPeerStateChangeLocked(peer key.NodePublic, ipPort netip.AddrPort, present bool) {
 	for w := range s.watchers {
-		w.peerStateChange = append(w.peerStateChange, peerConnState{peer: peer, present: present})
+		w.peerStateChange = append(w.peerStateChange, peerConnState{
+			peer:    peer,
+			present: present,
+			ipPort:  ipPort,
+		})
 		go w.requestMeshUpdate()
 	}
 }
@@ -565,7 +570,7 @@ func (s *Server) unregisterClient(c *sclient) {
 			delete(s.clientsMesh, c.key)
 			s.notePeerGoneFromRegionLocked(c.key)
 		}
-		s.broadcastPeerStateChangeLocked(c.key, false)
+		s.broadcastPeerStateChangeLocked(c.key, netip.AddrPort{}, false)
 	case *dupClientSet:
 		c.debugLogf("removed duplicate client")
 		if set.removeClient(c) {
@@ -655,13 +660,21 @@ func (s *Server) addWatcher(c *sclient) {
 	defer s.mu.Unlock()
 
 	// Queue messages for each already-connected client.
-	for peer := range s.clients {
-		c.peerStateChange = append(c.peerStateChange, peerConnState{peer: peer, present: true})
+	for peer, clientSet := range s.clients {
+		ac := clientSet.ActiveClient()
+		if ac == nil {
+			continue
+		}
+		c.peerStateChange = append(c.peerStateChange, peerConnState{
+			peer:    peer,
+			present: true,
+			ipPort:  ac.remoteIPPort,
+		})
 	}
 
 	// And enroll the watcher in future updates (of both
 	// connections & disconnections).
-	s.watchers[c] = true
+	s.watchers.Add(c)
 
 	go c.requestMeshUpdate()
 }
@@ -1349,6 +1362,7 @@ type sclient struct {
 type peerConnState struct {
 	peer    key.NodePublic
 	present bool
+	ipPort  netip.AddrPort // if present, the peer's IP:port
 }
 
 // pkt is a request to write a data frame to an sclient.
@@ -1542,12 +1556,18 @@ func (c *sclient) sendPeerGone(peer key.NodePublic, reason PeerGoneReasonType) e
 }
 
 // sendPeerPresent sends a peerPresent frame, without flushing.
-func (c *sclient) sendPeerPresent(peer key.NodePublic) error {
+func (c *sclient) sendPeerPresent(peer key.NodePublic, ipPort netip.AddrPort) error {
 	c.setWriteDeadline()
-	if err := writeFrameHeader(c.bw.bw(), framePeerPresent, keyLen); err != nil {
+	const frameLen = keyLen + 16 + 2
+	if err := writeFrameHeader(c.bw.bw(), framePeerPresent, frameLen); err != nil {
 		return err
 	}
-	_, err := c.bw.Write(peer.AppendTo(nil))
+	payload := make([]byte, frameLen)
+	_ = peer.AppendTo(payload[:0])
+	a16 := ipPort.Addr().As16()
+	copy(payload[keyLen:], a16[:])
+	binary.BigEndian.PutUint16(payload[keyLen+16:], ipPort.Port())
+	_, err := c.bw.Write(payload)
 	return err
 }
 
@@ -1566,7 +1586,7 @@ func (c *sclient) sendMeshUpdates() error {
 		}
 		var err error
 		if pcs.present {
-			err = c.sendPeerPresent(pcs.peer)
+			err = c.sendPeerPresent(pcs.peer, pcs.ipPort)
 		} else {
 			err = c.sendPeerGone(pcs.peer, PeerGoneReasonDisconnected)
 		}

+ 4 - 3
derp/derp_test.go

@@ -92,7 +92,7 @@ func TestSendRecv(t *testing.T) {
 		defer cancel()
 
 		brwServer := bufio.NewReadWriter(bufio.NewReader(cin), bufio.NewWriter(cin))
-		go s.Accept(ctx, cin, brwServer, fmt.Sprintf("test-client-%d", i))
+		go s.Accept(ctx, cin, brwServer, fmt.Sprintf("[abc::def]:%v", i))
 
 		key := clientPrivateKeys[i]
 		brw := bufio.NewReadWriter(bufio.NewReader(cout), bufio.NewWriter(cout))
@@ -528,7 +528,7 @@ func newTestServer(t *testing.T, ctx context.Context) *testServer {
 			// TODO: register c in ts so Close also closes it?
 			go func(i int) {
 				brwServer := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
-				go s.Accept(ctx, c, brwServer, fmt.Sprintf("test-client-%d", i))
+				go s.Accept(ctx, c, brwServer, c.RemoteAddr().String())
 			}(i)
 		}
 	}()
@@ -615,7 +615,7 @@ func (tc *testClient) wantPresent(t *testing.T, peers ...key.NodePublic) {
 		}
 		switch m := m.(type) {
 		case PeerPresentMessage:
-			got := key.NodePublic(m)
+			got := m.Key
 			if !want[got] {
 				t.Fatalf("got peer present for %v; want present for %v", tc.ts.keyName(got), logger.ArgWriter(func(bw *bufio.Writer) {
 					for _, pub := range peers {
@@ -623,6 +623,7 @@ func (tc *testClient) wantPresent(t *testing.T, peers ...key.NodePublic) {
 					}
 				}))
 			}
+			t.Logf("got present with IP %v", m.IPPort)
 			delete(want, got)
 			if len(want) == 0 {
 				return

+ 6 - 5
derp/derphttp/mesh_client.go

@@ -5,6 +5,7 @@ package derphttp
 
 import (
 	"context"
+	"net/netip"
 	"sync"
 	"time"
 
@@ -26,7 +27,7 @@ import (
 //
 // To force RunWatchConnectionLoop to return quickly, its ctx needs to
 // be closed, and c itself needs to be closed.
-func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key.NodePublic, infoLogf logger.Logf, add, remove func(key.NodePublic)) {
+func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key.NodePublic, infoLogf logger.Logf, add func(key.NodePublic, netip.AddrPort), remove func(key.NodePublic)) {
 	if infoLogf == nil {
 		infoLogf = logger.Discard
 	}
@@ -68,9 +69,9 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key
 	})
 	defer timer.Stop()
 
-	updatePeer := func(k key.NodePublic, isPresent bool) {
+	updatePeer := func(k key.NodePublic, ipPort netip.AddrPort, isPresent bool) {
 		if isPresent {
-			add(k)
+			add(k, ipPort)
 		} else {
 			remove(k)
 		}
@@ -126,7 +127,7 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key
 			}
 			switch m := m.(type) {
 			case derp.PeerPresentMessage:
-				updatePeer(key.NodePublic(m), true)
+				updatePeer(m.Key, m.IPPort, true)
 			case derp.PeerGoneMessage:
 				switch m.Reason {
 				case derp.PeerGoneReasonDisconnected:
@@ -138,7 +139,7 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key
 					logf("Recv: peer %s not at server %s for unknown reason %v",
 						key.NodePublic(m.Peer).ShortString(), c.ServerPublicKey().ShortString(), m.Reason)
 				}
-				updatePeer(key.NodePublic(m.Peer), false)
+				updatePeer(key.NodePublic(m.Peer), netip.AddrPort{}, false)
 			default:
 				continue
 			}