|
|
@@ -8,6 +8,7 @@ package udprelay
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
+ "context"
|
|
|
"crypto/rand"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
@@ -19,11 +20,18 @@ import (
|
|
|
"time"
|
|
|
|
|
|
"go4.org/mem"
|
|
|
+ "tailscale.com/client/local"
|
|
|
"tailscale.com/disco"
|
|
|
+ "tailscale.com/net/netcheck"
|
|
|
+ "tailscale.com/net/netmon"
|
|
|
"tailscale.com/net/packet"
|
|
|
+ "tailscale.com/net/stun"
|
|
|
"tailscale.com/net/udprelay/endpoint"
|
|
|
"tailscale.com/tstime"
|
|
|
"tailscale.com/types/key"
|
|
|
+ "tailscale.com/types/logger"
|
|
|
+ "tailscale.com/util/eventbus"
|
|
|
+ "tailscale.com/util/set"
|
|
|
)
|
|
|
|
|
|
const (
|
|
|
@@ -42,25 +50,22 @@ const (
|
|
|
|
|
|
// Server implements an experimental UDP relay server.
|
|
|
type Server struct {
|
|
|
- // disco keypair used as part of 3-way bind handshake
|
|
|
- disco key.DiscoPrivate
|
|
|
- discoPublic key.DiscoPublic
|
|
|
-
|
|
|
+ // The following fields are initialized once and never mutated.
|
|
|
+ logf logger.Logf
|
|
|
+ disco key.DiscoPrivate
|
|
|
+ discoPublic key.DiscoPublic
|
|
|
bindLifetime time.Duration
|
|
|
steadyStateLifetime time.Duration
|
|
|
-
|
|
|
- // addrPorts contains the ip:port pairs returned as candidate server
|
|
|
- // endpoints in response to an allocation request.
|
|
|
- addrPorts []netip.AddrPort
|
|
|
-
|
|
|
- uc *net.UDPConn
|
|
|
-
|
|
|
- closeOnce sync.Once
|
|
|
- wg sync.WaitGroup
|
|
|
- closeCh chan struct{}
|
|
|
+ bus *eventbus.Bus
|
|
|
+ uc *net.UDPConn
|
|
|
+ closeOnce sync.Once
|
|
|
+ wg sync.WaitGroup
|
|
|
+ closeCh chan struct{}
|
|
|
+ netChecker *netcheck.Client
|
|
|
+
|
|
|
+ mu sync.Mutex // guards the following fields
|
|
|
+ addrPorts []netip.AddrPort // the ip:port pairs returned as candidate endpoints
|
|
|
closed bool
|
|
|
-
|
|
|
- mu sync.Mutex // guards the following fields
|
|
|
lamportID uint64
|
|
|
vniPool []uint32 // the pool of available VNIs
|
|
|
byVNI map[uint32]*serverEndpoint
|
|
|
@@ -270,14 +275,13 @@ func (e *serverEndpoint) isBound() bool {
|
|
|
|
|
|
// NewServer constructs a [Server] listening on 0.0.0.0:'port'. IPv6 is not yet
|
|
|
// supported. Port may be 0, and what ultimately gets bound is returned as
|
|
|
-// 'boundPort'. Supplied 'addrs' are joined with 'boundPort' and returned as
|
|
|
-// [endpoint.ServerEndpoint.AddrPorts] in response to Server.AllocateEndpoint()
|
|
|
-// requests.
|
|
|
+// 'boundPort'. If len(overrideAddrs) > 0 these will be used in place of dynamic
|
|
|
+// discovery, which is useful to override in tests.
|
|
|
//
|
|
|
// TODO: IPv6 support
|
|
|
-// TODO: dynamic addrs:port discovery
|
|
|
-func NewServer(port int, addrs []netip.Addr) (s *Server, boundPort int, err error) {
|
|
|
+func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, boundPort uint16, err error) {
|
|
|
s = &Server{
|
|
|
+ logf: logger.WithPrefix(logf, "relayserver"),
|
|
|
disco: key.NewDisco(),
|
|
|
bindLifetime: defaultBindLifetime,
|
|
|
steadyStateLifetime: defaultSteadyStateLifetime,
|
|
|
@@ -292,26 +296,120 @@ func NewServer(port int, addrs []netip.Addr) (s *Server, boundPort int, err erro
|
|
|
for i := 1; i < 1<<24; i++ {
|
|
|
s.vniPool = append(s.vniPool, uint32(i))
|
|
|
}
|
|
|
- boundPort, err = s.listenOn(port)
|
|
|
+
|
|
|
+ bus := eventbus.New()
|
|
|
+ s.bus = bus
|
|
|
+ netMon, err := netmon.New(s.bus, logf)
|
|
|
if err != nil {
|
|
|
return nil, 0, err
|
|
|
}
|
|
|
- addrPorts := make([]netip.AddrPort, 0, len(addrs))
|
|
|
- for _, addr := range addrs {
|
|
|
- addrPort, err := netip.ParseAddrPort(net.JoinHostPort(addr.String(), strconv.Itoa(boundPort)))
|
|
|
- if err != nil {
|
|
|
- return nil, 0, err
|
|
|
- }
|
|
|
- addrPorts = append(addrPorts, addrPort)
|
|
|
+ s.netChecker = &netcheck.Client{
|
|
|
+ NetMon: netMon,
|
|
|
+ Logf: logger.WithPrefix(logf, "relayserver: netcheck:"),
|
|
|
+ SendPacket: func(b []byte, addrPort netip.AddrPort) (int, error) {
|
|
|
+ return s.uc.WriteToUDPAddrPort(b, addrPort)
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ boundPort, err = s.listenOn(port)
|
|
|
+ if err != nil {
|
|
|
+ return nil, 0, err
|
|
|
}
|
|
|
- s.addrPorts = addrPorts
|
|
|
- s.wg.Add(2)
|
|
|
+
|
|
|
+ s.wg.Add(1)
|
|
|
go s.packetReadLoop()
|
|
|
+ s.wg.Add(1)
|
|
|
go s.endpointGCLoop()
|
|
|
+ if len(overrideAddrs) > 0 {
|
|
|
+ var addrPorts set.Set[netip.AddrPort]
|
|
|
+ addrPorts.Make()
|
|
|
+ for _, addr := range overrideAddrs {
|
|
|
+ if addr.IsValid() {
|
|
|
+ addrPorts.Add(netip.AddrPortFrom(addr, boundPort))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ s.addrPorts = addrPorts.Slice()
|
|
|
+ } else {
|
|
|
+ s.wg.Add(1)
|
|
|
+ go s.addrDiscoveryLoop()
|
|
|
+ }
|
|
|
return s, boundPort, nil
|
|
|
}
|
|
|
|
|
|
-func (s *Server) listenOn(port int) (int, error) {
|
|
|
+func (s *Server) addrDiscoveryLoop() {
|
|
|
+ defer s.wg.Done()
|
|
|
+
|
|
|
+ timer := time.NewTimer(0) // fire immediately
|
|
|
+ defer timer.Stop()
|
|
|
+
|
|
|
+ getAddrPorts := func() ([]netip.AddrPort, error) {
|
|
|
+ var addrPorts set.Set[netip.AddrPort]
|
|
|
+ addrPorts.Make()
|
|
|
+
|
|
|
+ // get local addresses
|
|
|
+ localPort := s.uc.LocalAddr().(*net.UDPAddr).Port
|
|
|
+ ips, _, err := netmon.LocalAddresses()
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ for _, ip := range ips {
|
|
|
+ if ip.IsValid() {
|
|
|
+ addrPorts.Add(netip.AddrPortFrom(ip, uint16(localPort)))
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // fetch DERPMap to feed to netcheck
|
|
|
+ derpMapCtx, derpMapCancel := context.WithTimeout(context.Background(), time.Second)
|
|
|
+ defer derpMapCancel()
|
|
|
+ localClient := &local.Client{}
|
|
|
+ // TODO(jwhited): We are in-process so use eventbus or similar.
|
|
|
+ // local.Client gets us going.
|
|
|
+ dm, err := localClient.CurrentDERPMap(derpMapCtx)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ // get addrPorts as visible from DERP
|
|
|
+ netCheckerCtx, netCheckerCancel := context.WithTimeout(context.Background(), netcheck.ReportTimeout)
|
|
|
+ defer netCheckerCancel()
|
|
|
+ rep, err := s.netChecker.GetReport(netCheckerCtx, dm, &netcheck.GetReportOpts{
|
|
|
+ OnlySTUN: true,
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ if rep.GlobalV4.IsValid() {
|
|
|
+ addrPorts.Add(rep.GlobalV4)
|
|
|
+ }
|
|
|
+ if rep.GlobalV6.IsValid() {
|
|
|
+ addrPorts.Add(rep.GlobalV6)
|
|
|
+ }
|
|
|
+ // TODO(jwhited): consider logging if rep.MappingVariesByDestIP as
|
|
|
+ // that's a hint we are not well-positioned to operate as a UDP relay.
|
|
|
+ return addrPorts.Slice(), nil
|
|
|
+ }
|
|
|
+
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-timer.C:
|
|
|
+ // Mirror magicsock behavior for duration between STUN. We consider
|
|
|
+ // 30s a min bound for NAT timeout.
|
|
|
+ timer.Reset(tstime.RandomDurationBetween(20*time.Second, 26*time.Second))
|
|
|
+ addrPorts, err := getAddrPorts()
|
|
|
+ if err != nil {
|
|
|
+ s.logf("error discovering IP:port candidates: %v", err)
|
|
|
+ }
|
|
|
+ s.mu.Lock()
|
|
|
+ s.addrPorts = addrPorts
|
|
|
+ s.mu.Unlock()
|
|
|
+ case <-s.closeCh:
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+func (s *Server) listenOn(port int) (uint16, error) {
|
|
|
uc, err := net.ListenUDP("udp4", &net.UDPAddr{Port: port})
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
@@ -322,13 +420,13 @@ func (s *Server) listenOn(port int) (int, error) {
|
|
|
s.uc.Close()
|
|
|
return 0, err
|
|
|
}
|
|
|
- boundPort, err := strconv.Atoi(boundPortStr)
|
|
|
+ boundPort, err := strconv.ParseUint(boundPortStr, 10, 16)
|
|
|
if err != nil {
|
|
|
s.uc.Close()
|
|
|
return 0, err
|
|
|
}
|
|
|
s.uc = uc
|
|
|
- return boundPort, nil
|
|
|
+ return uint16(boundPort), nil
|
|
|
}
|
|
|
|
|
|
// Close closes the server.
|
|
|
@@ -343,6 +441,7 @@ func (s *Server) Close() error {
|
|
|
clear(s.byDisco)
|
|
|
s.vniPool = nil
|
|
|
s.closed = true
|
|
|
+ s.bus.Close()
|
|
|
})
|
|
|
return nil
|
|
|
}
|
|
|
@@ -378,6 +477,13 @@ func (s *Server) endpointGCLoop() {
|
|
|
}
|
|
|
|
|
|
func (s *Server) handlePacket(from netip.AddrPort, b []byte, uw udpWriter) {
|
|
|
+ if stun.Is(b) && b[1] == 0x01 {
|
|
|
+ // A b[1] value of 0x01 (STUN method binding) is sufficiently
|
|
|
+ // non-overlapping with the Geneve header where the LSB is always 0
|
|
|
+ // (part of 6 "reserved" bits).
|
|
|
+ s.netChecker.ReceiveSTUNPacket(b, from)
|
|
|
+ return
|
|
|
+ }
|
|
|
gh := packet.GeneveHeader{}
|
|
|
err := gh.Decode(b)
|
|
|
if err != nil {
|
|
|
@@ -426,6 +532,10 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
|
|
|
return endpoint.ServerEndpoint{}, ErrServerClosed
|
|
|
}
|
|
|
|
|
|
+ if len(s.addrPorts) == 0 {
|
|
|
+ return endpoint.ServerEndpoint{}, errors.New("server addrPorts are not yet known")
|
|
|
+ }
|
|
|
+
|
|
|
if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 {
|
|
|
return endpoint.ServerEndpoint{}, fmt.Errorf("client disco equals server disco: %s", s.discoPublic.ShortString())
|
|
|
}
|
|
|
@@ -439,8 +549,13 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
|
|
|
// TODO: consider ServerEndpoint.BindLifetime -= time.Now()-e.allocatedAt
|
|
|
// to give the client a more accurate picture of the bind window.
|
|
|
return endpoint.ServerEndpoint{
|
|
|
- ServerDisco: s.discoPublic,
|
|
|
- AddrPorts: s.addrPorts,
|
|
|
+ ServerDisco: s.discoPublic,
|
|
|
+ // Returning the "latest" addrPorts for an existing allocation is
|
|
|
+ // the simple choice. It may not be the best depending on client
|
|
|
+ // behaviors and endpoint state (bound or not). We might want to
|
|
|
+ // consider storing them (maybe interning) in the [*serverEndpoint]
|
|
|
+ // at allocation time.
|
|
|
+ AddrPorts: slices.Clone(s.addrPorts),
|
|
|
VNI: e.vni,
|
|
|
LamportID: e.lamportID,
|
|
|
BindLifetime: tstime.GoDuration{Duration: s.bindLifetime},
|
|
|
@@ -469,7 +584,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
|
|
|
|
|
|
return endpoint.ServerEndpoint{
|
|
|
ServerDisco: s.discoPublic,
|
|
|
- AddrPorts: s.addrPorts,
|
|
|
+ AddrPorts: slices.Clone(s.addrPorts),
|
|
|
VNI: e.vni,
|
|
|
LamportID: e.lamportID,
|
|
|
BindLifetime: tstime.GoDuration{Duration: s.bindLifetime},
|