Browse Source

feature/relayserver,net/{netcheck,udprelay}: implement addr discovery (#16253)

The relay server now fetches IPs from local interfaces and external
perspective IP:port's via netcheck (STUN).

Updates tailscale/corp#27502

Signed-off-by: Jordan Whited <[email protected]>
Jordan Whited 8 months ago
parent
commit
3ed76ceed3
4 changed files with 170 additions and 43 deletions
  1. 1 2
      feature/relayserver/relayserver.go
  2. 16 3
      net/netcheck/netcheck.go
  3. 152 37
      net/udprelay/server.go
  4. 1 1
      net/udprelay/server_test.go

+ 1 - 2
feature/relayserver/relayserver.go

@@ -10,7 +10,6 @@ import (
 	"errors"
 	"io"
 	"net/http"
-	"net/netip"
 	"sync"
 
 	"tailscale.com/envknob"
@@ -136,7 +135,7 @@ func (e *extension) relayServerOrInit() (relayServer, error) {
 		return nil, errors.New("TAILSCALE_USE_WIP_CODE envvar is not set")
 	}
 	var err error
-	e.server, _, err = udprelay.NewServer(*e.port, []netip.Addr{netip.MustParseAddr("127.0.0.1")})
+	e.server, _, err = udprelay.NewServer(e.logf, *e.port, nil)
 	if err != nil {
 		return nil, err
 	}

+ 16 - 3
net/netcheck/netcheck.go

@@ -753,6 +753,7 @@ func newReport() *Report {
 
 // GetReportOpts contains options that can be passed to GetReport. Unless
 // specified, all fields are optional and can be left as their zero value.
+// At most one of OnlyTCP443 or OnlySTUN may be set.
 type GetReportOpts struct {
 	// GetLastDERPActivity is a callback that, if provided, should return
 	// the absolute time that the calling code last communicated with a
@@ -765,6 +766,8 @@ type GetReportOpts struct {
 	// OnlyTCP443 constrains netcheck reporting to measurements over TCP port
 	// 443.
 	OnlyTCP443 bool
+	// OnlySTUN constrains netcheck reporting to STUN measurements over UDP.
+	OnlySTUN bool
 }
 
 // getLastDERPActivity calls o.GetLastDERPActivity if both o and
@@ -790,6 +793,13 @@ func (c *Client) SetForcePreferredDERP(region int) {
 //
 // It may not be called concurrently with itself.
 func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetReportOpts) (_ *Report, reterr error) {
+	onlySTUN := false
+	if opts != nil && opts.OnlySTUN {
+		if opts.OnlyTCP443 {
+			return nil, errors.New("netcheck: only one of OnlySTUN or OnlyTCP443 may be set in opts")
+		}
+		onlySTUN = true
+	}
 	defer func() {
 		if reterr != nil {
 			metricNumGetReportError.Add(1)
@@ -865,6 +875,9 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe
 	}()
 
 	if runtime.GOOS == "js" || runtime.GOOS == "tamago" || (runtime.GOOS == "plan9" && hostinfo.IsInVM86()) {
+		if onlySTUN {
+			return nil, errors.New("platform is restricted to HTTP, but OnlySTUN is set in opts")
+		}
 		if err := c.runHTTPOnlyChecks(ctx, last, rs, dm); err != nil {
 			return nil, err
 		}
@@ -896,7 +909,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe
 	// it's unnecessary.
 	captivePortalDone := syncs.ClosedChan()
 	captivePortalStop := func() {}
-	if !rs.incremental {
+	if !rs.incremental && !onlySTUN {
 		// NOTE(andrew): we can't simply add this goroutine to the
 		// `NewWaitGroupChan` below, since we don't wait for that
 		// waitgroup to finish when exiting this function and thus get
@@ -970,9 +983,9 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe
 	rs.stopTimers()
 
 	// Try HTTPS and ICMP latency check if all STUN probes failed due to
-	// UDP presumably being blocked.
+	// UDP presumably being blocked, and we are not constrained to only STUN.
 	// TODO: this should be moved into the probePlan, using probeProto probeHTTPS.
-	if !rs.anyUDP() && ctx.Err() == nil {
+	if !rs.anyUDP() && ctx.Err() == nil && !onlySTUN {
 		var wg sync.WaitGroup
 		var need []*tailcfg.DERPRegion
 		for rid, reg := range dm.Regions {

+ 152 - 37
net/udprelay/server.go

@@ -8,6 +8,7 @@ package udprelay
 
 import (
 	"bytes"
+	"context"
 	"crypto/rand"
 	"errors"
 	"fmt"
@@ -19,11 +20,18 @@ import (
 	"time"
 
 	"go4.org/mem"
+	"tailscale.com/client/local"
 	"tailscale.com/disco"
+	"tailscale.com/net/netcheck"
+	"tailscale.com/net/netmon"
 	"tailscale.com/net/packet"
+	"tailscale.com/net/stun"
 	"tailscale.com/net/udprelay/endpoint"
 	"tailscale.com/tstime"
 	"tailscale.com/types/key"
+	"tailscale.com/types/logger"
+	"tailscale.com/util/eventbus"
+	"tailscale.com/util/set"
 )
 
 const (
@@ -42,25 +50,22 @@ const (
 
 // Server implements an experimental UDP relay server.
 type Server struct {
-	// disco keypair used as part of 3-way bind handshake
-	disco       key.DiscoPrivate
-	discoPublic key.DiscoPublic
-
+	// The following fields are initialized once and never mutated.
+	logf                logger.Logf
+	disco               key.DiscoPrivate
+	discoPublic         key.DiscoPublic
 	bindLifetime        time.Duration
 	steadyStateLifetime time.Duration
-
-	// addrPorts contains the ip:port pairs returned as candidate server
-	// endpoints in response to an allocation request.
-	addrPorts []netip.AddrPort
-
-	uc *net.UDPConn
-
-	closeOnce sync.Once
-	wg        sync.WaitGroup
-	closeCh   chan struct{}
+	bus                 *eventbus.Bus
+	uc                  *net.UDPConn
+	closeOnce           sync.Once
+	wg                  sync.WaitGroup
+	closeCh             chan struct{}
+	netChecker          *netcheck.Client
+
+	mu        sync.Mutex       // guards the following fields
+	addrPorts []netip.AddrPort // the ip:port pairs returned as candidate endpoints
 	closed    bool
-
-	mu        sync.Mutex // guards the following fields
 	lamportID uint64
 	vniPool   []uint32 // the pool of available VNIs
 	byVNI     map[uint32]*serverEndpoint
@@ -270,14 +275,13 @@ func (e *serverEndpoint) isBound() bool {
 
 // NewServer constructs a [Server] listening on 0.0.0.0:'port'. IPv6 is not yet
 // supported. Port may be 0, and what ultimately gets bound is returned as
-// 'boundPort'. Supplied 'addrs' are joined with 'boundPort' and returned as
-// [endpoint.ServerEndpoint.AddrPorts] in response to Server.AllocateEndpoint()
-// requests.
+// 'boundPort'. If len(overrideAddrs) > 0 these will be used in place of dynamic
+// discovery, which is useful to override in tests.
 //
 // TODO: IPv6 support
-// TODO: dynamic addrs:port discovery
-func NewServer(port int, addrs []netip.Addr) (s *Server, boundPort int, err error) {
+func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, boundPort uint16, err error) {
 	s = &Server{
+		logf:                logger.WithPrefix(logf, "relayserver"),
 		disco:               key.NewDisco(),
 		bindLifetime:        defaultBindLifetime,
 		steadyStateLifetime: defaultSteadyStateLifetime,
@@ -292,26 +296,120 @@ func NewServer(port int, addrs []netip.Addr) (s *Server, boundPort int, err erro
 	for i := 1; i < 1<<24; i++ {
 		s.vniPool = append(s.vniPool, uint32(i))
 	}
-	boundPort, err = s.listenOn(port)
+
+	bus := eventbus.New()
+	s.bus = bus
+	netMon, err := netmon.New(s.bus, logf)
 	if err != nil {
 		return nil, 0, err
 	}
-	addrPorts := make([]netip.AddrPort, 0, len(addrs))
-	for _, addr := range addrs {
-		addrPort, err := netip.ParseAddrPort(net.JoinHostPort(addr.String(), strconv.Itoa(boundPort)))
-		if err != nil {
-			return nil, 0, err
-		}
-		addrPorts = append(addrPorts, addrPort)
+	s.netChecker = &netcheck.Client{
+		NetMon: netMon,
+		Logf:   logger.WithPrefix(logf, "relayserver: netcheck:"),
+		SendPacket: func(b []byte, addrPort netip.AddrPort) (int, error) {
+			return s.uc.WriteToUDPAddrPort(b, addrPort)
+		},
+	}
+
+	boundPort, err = s.listenOn(port)
+	if err != nil {
+		return nil, 0, err
 	}
-	s.addrPorts = addrPorts
-	s.wg.Add(2)
+
+	s.wg.Add(1)
 	go s.packetReadLoop()
+	s.wg.Add(1)
 	go s.endpointGCLoop()
+	if len(overrideAddrs) > 0 {
+		var addrPorts set.Set[netip.AddrPort]
+		addrPorts.Make()
+		for _, addr := range overrideAddrs {
+			if addr.IsValid() {
+				addrPorts.Add(netip.AddrPortFrom(addr, boundPort))
+			}
+		}
+		s.addrPorts = addrPorts.Slice()
+	} else {
+		s.wg.Add(1)
+		go s.addrDiscoveryLoop()
+	}
 	return s, boundPort, nil
 }
 
-func (s *Server) listenOn(port int) (int, error) {
+func (s *Server) addrDiscoveryLoop() {
+	defer s.wg.Done()
+
+	timer := time.NewTimer(0) // fire immediately
+	defer timer.Stop()
+
+	getAddrPorts := func() ([]netip.AddrPort, error) {
+		var addrPorts set.Set[netip.AddrPort]
+		addrPorts.Make()
+
+		// get local addresses
+		localPort := s.uc.LocalAddr().(*net.UDPAddr).Port
+		ips, _, err := netmon.LocalAddresses()
+		if err != nil {
+			return nil, err
+		}
+		for _, ip := range ips {
+			if ip.IsValid() {
+				addrPorts.Add(netip.AddrPortFrom(ip, uint16(localPort)))
+			}
+		}
+
+		// fetch DERPMap to feed to netcheck
+		derpMapCtx, derpMapCancel := context.WithTimeout(context.Background(), time.Second)
+		defer derpMapCancel()
+		localClient := &local.Client{}
+		// TODO(jwhited): We are in-process so use eventbus or similar.
+		//  local.Client gets us going.
+		dm, err := localClient.CurrentDERPMap(derpMapCtx)
+		if err != nil {
+			return nil, err
+		}
+
+		// get addrPorts as visible from DERP
+		netCheckerCtx, netCheckerCancel := context.WithTimeout(context.Background(), netcheck.ReportTimeout)
+		defer netCheckerCancel()
+		rep, err := s.netChecker.GetReport(netCheckerCtx, dm, &netcheck.GetReportOpts{
+			OnlySTUN: true,
+		})
+		if err != nil {
+			return nil, err
+		}
+		if rep.GlobalV4.IsValid() {
+			addrPorts.Add(rep.GlobalV4)
+		}
+		if rep.GlobalV6.IsValid() {
+			addrPorts.Add(rep.GlobalV6)
+		}
+		// TODO(jwhited): consider logging if rep.MappingVariesByDestIP as
+		//  that's a hint we are not well-positioned to operate as a UDP relay.
+		return addrPorts.Slice(), nil
+	}
+
+	for {
+		select {
+		case <-timer.C:
+			// Mirror magicsock behavior for duration between STUN. We consider
+			// 30s a min bound for NAT timeout.
+			timer.Reset(tstime.RandomDurationBetween(20*time.Second, 26*time.Second))
+			addrPorts, err := getAddrPorts()
+			if err != nil {
+				s.logf("error discovering IP:port candidates: %v", err)
+			}
+			s.mu.Lock()
+			s.addrPorts = addrPorts
+			s.mu.Unlock()
+		case <-s.closeCh:
+			return
+		}
+	}
+
+}
+
+func (s *Server) listenOn(port int) (uint16, error) {
 	uc, err := net.ListenUDP("udp4", &net.UDPAddr{Port: port})
 	if err != nil {
 		return 0, err
@@ -322,13 +420,13 @@ func (s *Server) listenOn(port int) (int, error) {
 		s.uc.Close()
 		return 0, err
 	}
-	boundPort, err := strconv.Atoi(boundPortStr)
+	boundPort, err := strconv.ParseUint(boundPortStr, 10, 16)
 	if err != nil {
 		s.uc.Close()
 		return 0, err
 	}
 	s.uc = uc
-	return boundPort, nil
+	return uint16(boundPort), nil
 }
 
 // Close closes the server.
@@ -343,6 +441,7 @@ func (s *Server) Close() error {
 		clear(s.byDisco)
 		s.vniPool = nil
 		s.closed = true
+		s.bus.Close()
 	})
 	return nil
 }
@@ -378,6 +477,13 @@ func (s *Server) endpointGCLoop() {
 }
 
 func (s *Server) handlePacket(from netip.AddrPort, b []byte, uw udpWriter) {
+	if stun.Is(b) && b[1] == 0x01 {
+		// A b[1] value of 0x01 (STUN method binding) is sufficiently
+		// non-overlapping with the Geneve header where the LSB is always 0
+		// (part of 6 "reserved" bits).
+		s.netChecker.ReceiveSTUNPacket(b, from)
+		return
+	}
 	gh := packet.GeneveHeader{}
 	err := gh.Decode(b)
 	if err != nil {
@@ -426,6 +532,10 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
 		return endpoint.ServerEndpoint{}, ErrServerClosed
 	}
 
+	if len(s.addrPorts) == 0 {
+		return endpoint.ServerEndpoint{}, errors.New("server addrPorts are not yet known")
+	}
+
 	if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 {
 		return endpoint.ServerEndpoint{}, fmt.Errorf("client disco equals server disco: %s", s.discoPublic.ShortString())
 	}
@@ -439,8 +549,13 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
 		// TODO: consider ServerEndpoint.BindLifetime -= time.Now()-e.allocatedAt
 		// to give the client a more accurate picture of the bind window.
 		return endpoint.ServerEndpoint{
-			ServerDisco:         s.discoPublic,
-			AddrPorts:           s.addrPorts,
+			ServerDisco: s.discoPublic,
+			// Returning the "latest" addrPorts for an existing allocation is
+			// the simple choice. It may not be the best depending on client
+			// behaviors and endpoint state (bound or not). We might want to
+			// consider storing them (maybe interning) in the [*serverEndpoint]
+			// at allocation time.
+			AddrPorts:           slices.Clone(s.addrPorts),
 			VNI:                 e.vni,
 			LamportID:           e.lamportID,
 			BindLifetime:        tstime.GoDuration{Duration: s.bindLifetime},
@@ -469,7 +584,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
 
 	return endpoint.ServerEndpoint{
 		ServerDisco:         s.discoPublic,
-		AddrPorts:           s.addrPorts,
+		AddrPorts:           slices.Clone(s.addrPorts),
 		VNI:                 e.vni,
 		LamportID:           e.lamportID,
 		BindLifetime:        tstime.GoDuration{Duration: s.bindLifetime},

+ 1 - 1
net/udprelay/server_test.go

@@ -156,7 +156,7 @@ func TestServer(t *testing.T) {
 
 	ipv4LoopbackAddr := netip.MustParseAddr("127.0.0.1")
 
-	server, _, err := NewServer(0, []netip.Addr{ipv4LoopbackAddr})
+	server, _, err := NewServer(t.Logf, 0, []netip.Addr{ipv4LoopbackAddr})
 	if err != nil {
 		t.Fatal(err)
 	}