Browse Source

cmd/{derper,stund},net/stunserver: add standalone stun server

Add a standalone server for STUN that can be hosted independently of the
derper, and factor that back into the derper.

Fixes #8434
Closes #8435
Closes #10745

Signed-off-by: James Tucker <[email protected]>
James Tucker 2 years ago
parent
commit
953fa80c6f

+ 4 - 2
Makefile

@@ -18,7 +18,8 @@ updatedeps: ## Update depaware deps
 	PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --update \
 		tailscale.com/cmd/tailscaled \
 		tailscale.com/cmd/tailscale \
-		tailscale.com/cmd/derper
+		tailscale.com/cmd/derper \
+		tailscale.com/cmd/stund
 
 depaware: ## Run depaware checks
 	# depaware (via x/tools/go/packages) shells back to "go", so make sure the "go"
@@ -26,7 +27,8 @@ depaware: ## Run depaware checks
 	PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --check \
 		tailscale.com/cmd/tailscaled \
 		tailscale.com/cmd/tailscale \
-		tailscale.com/cmd/derper
+		tailscale.com/cmd/derper \
+		tailscale.com/cmd/stund
 
 buildwindows: ## Build tailscale CLI for windows/amd64
 	GOOS=windows GOARCH=amd64 ./tool/go install tailscale.com/cmd/tailscale tailscale.com/cmd/tailscaled

+ 3 - 1
cmd/derper/depaware.txt

@@ -105,7 +105,8 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
         tailscale.com/net/netutil                                    from tailscale.com/client/tailscale
         tailscale.com/net/packet                                     from tailscale.com/wgengine/filter
         tailscale.com/net/sockstats                                  from tailscale.com/derp/derphttp
-        tailscale.com/net/stun                                       from tailscale.com/cmd/derper
+        tailscale.com/net/stun                                       from tailscale.com/net/stunserver
+        tailscale.com/net/stunserver                                 from tailscale.com/cmd/derper
    L    tailscale.com/net/tcpinfo                                    from tailscale.com/derp
         tailscale.com/net/tlsdial                                    from tailscale.com/derp/derphttp
         tailscale.com/net/tsaddr                                     from tailscale.com/ipn+
@@ -263,6 +264,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
         net/url                                                      from crypto/x509+
         os                                                           from crypto/rand+
         os/exec                                                      from golang.zx2c4.com/wireguard/windows/tunnel/winipcfg+
+        os/signal                                                    from tailscale.com/cmd/derper
    W    os/user                                                      from tailscale.com/util/winutil
         path                                                         from golang.org/x/crypto/acme/autocert+
         path/filepath                                                from crypto/x509+

+ 15 - 73
cmd/derper/derper.go

@@ -17,11 +17,12 @@ import (
 	"math"
 	"net"
 	"net/http"
-	"net/netip"
 	"os"
+	"os/signal"
 	"path/filepath"
 	"regexp"
 	"strings"
+	"syscall"
 	"time"
 
 	"go4.org/mem"
@@ -30,7 +31,7 @@ import (
 	"tailscale.com/derp"
 	"tailscale.com/derp/derphttp"
 	"tailscale.com/metrics"
-	"tailscale.com/net/stun"
+	"tailscale.com/net/stunserver"
 	"tailscale.com/tsweb"
 	"tailscale.com/types/key"
 	"tailscale.com/util/cmpx"
@@ -59,25 +60,11 @@ var (
 )
 
 var (
-	stats             = new(metrics.Set)
-	stunDisposition   = &metrics.LabelMap{Label: "disposition"}
-	stunAddrFamily    = &metrics.LabelMap{Label: "family"}
 	tlsRequestVersion = &metrics.LabelMap{Label: "version"}
 	tlsActiveVersion  = &metrics.LabelMap{Label: "version"}
-
-	stunReadError  = stunDisposition.Get("read_error")
-	stunNotSTUN    = stunDisposition.Get("not_stun")
-	stunWriteError = stunDisposition.Get("write_error")
-	stunSuccess    = stunDisposition.Get("success")
-
-	stunIPv4 = stunAddrFamily.Get("ipv4")
-	stunIPv6 = stunAddrFamily.Get("ipv6")
 )
 
 func init() {
-	stats.Set("counter_requests", stunDisposition)
-	stats.Set("counter_addrfamily", stunAddrFamily)
-	expvar.Publish("stun", stats)
 	expvar.Publish("derper_tls_request_version", tlsRequestVersion)
 	expvar.Publish("gauge_derper_tls_active_version", tlsActiveVersion)
 }
@@ -135,6 +122,9 @@ func writeNewConfig() config {
 func main() {
 	flag.Parse()
 
+	ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+	defer cancel()
+
 	if *dev {
 		*addr = ":3340" // above the keys DERP
 		log.Printf("Running in dev mode.")
@@ -146,6 +136,11 @@ func main() {
 		log.Fatalf("invalid server address: %v", err)
 	}
 
+	if *runSTUN {
+		ss := stunserver.New(ctx)
+		go ss.ListenAndServe(net.JoinHostPort(listenHost, fmt.Sprint(*stunPort)))
+	}
+
 	cfg := loadConfig()
 
 	serveTLS := tsweb.IsProd443(*addr) || *certMode == "manual"
@@ -221,10 +216,6 @@ func main() {
 	}))
 	debug.Handle("traffic", "Traffic check", http.HandlerFunc(s.ServeDebugTraffic))
 
-	if *runSTUN {
-		go serveSTUN(listenHost, *stunPort)
-	}
-
 	quietLogger := log.New(logFilter{}, "", 0)
 	httpsrv := &http.Server{
 		Addr:     *addr,
@@ -241,6 +232,10 @@ func main() {
 		ReadTimeout:  30 * time.Second,
 		WriteTimeout: 30 * time.Second,
 	}
+	go func() {
+		<-ctx.Done()
+		httpsrv.Shutdown(ctx)
+	}()
 
 	if serveTLS {
 		log.Printf("derper: serving on %s with TLS", *addr)
@@ -351,59 +346,6 @@ func probeHandler(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-func serveSTUN(host string, port int) {
-	pc, err := net.ListenPacket("udp", net.JoinHostPort(host, fmt.Sprint(port)))
-	if err != nil {
-		log.Fatalf("failed to open STUN listener: %v", err)
-	}
-	log.Printf("running STUN server on %v", pc.LocalAddr())
-	serverSTUNListener(context.Background(), pc.(*net.UDPConn))
-}
-
-func serverSTUNListener(ctx context.Context, pc *net.UDPConn) {
-	var buf [64 << 10]byte
-	var (
-		n   int
-		ua  *net.UDPAddr
-		err error
-	)
-	for {
-		n, ua, err = pc.ReadFromUDP(buf[:])
-		if err != nil {
-			if ctx.Err() != nil {
-				return
-			}
-			log.Printf("STUN ReadFrom: %v", err)
-			time.Sleep(time.Second)
-			stunReadError.Add(1)
-			continue
-		}
-		pkt := buf[:n]
-		if !stun.Is(pkt) {
-			stunNotSTUN.Add(1)
-			continue
-		}
-		txid, err := stun.ParseBindingRequest(pkt)
-		if err != nil {
-			stunNotSTUN.Add(1)
-			continue
-		}
-		if ua.IP.To4() != nil {
-			stunIPv4.Add(1)
-		} else {
-			stunIPv6.Add(1)
-		}
-		addr, _ := netip.AddrFromSlice(ua.IP)
-		res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(ua.Port)))
-		_, err = pc.WriteTo(res, ua)
-		if err != nil {
-			stunWriteError.Add(1)
-		} else {
-			stunSuccess.Add(1)
-		}
-	}
-}
-
 var validProdHostname = regexp.MustCompile(`^derp([^.]*)\.tailscale\.com\.?$`)
 
 func prodAutocertHostPolicy(_ context.Context, host string) error {

+ 0 - 34
cmd/derper/derper_test.go

@@ -5,13 +5,11 @@ package main
 
 import (
 	"context"
-	"net"
 	"net/http"
 	"net/http/httptest"
 	"strings"
 	"testing"
 
-	"tailscale.com/net/stun"
 	"tailscale.com/tstest/deptest"
 )
 
@@ -39,38 +37,6 @@ func TestProdAutocertHostPolicy(t *testing.T) {
 	}
 }
 
-func BenchmarkServerSTUN(b *testing.B) {
-	b.ReportAllocs()
-	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
-	if err != nil {
-		b.Fatal(err)
-	}
-	defer pc.Close()
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-	go serverSTUNListener(ctx, pc.(*net.UDPConn))
-	addr := pc.LocalAddr().(*net.UDPAddr)
-
-	var resBuf [1500]byte
-	cc, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")})
-	if err != nil {
-		b.Fatal(err)
-	}
-
-	tx := stun.NewTxID()
-	req := stun.Request(tx)
-	for i := 0; i < b.N; i++ {
-		if _, err := cc.WriteToUDP(req, addr); err != nil {
-			b.Fatal(err)
-		}
-		_, _, err := cc.ReadFromUDP(resBuf[:])
-		if err != nil {
-			b.Fatal(err)
-		}
-	}
-
-}
-
 func TestNoContent(t *testing.T) {
 	testCases := []struct {
 		name  string

+ 190 - 0
cmd/stund/depaware.txt

@@ -0,0 +1,190 @@
+tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depaware)
+
+        github.com/beorn7/perks/quantile                             from github.com/prometheus/client_golang/prometheus
+     💣 github.com/cespare/xxhash/v2                                 from github.com/prometheus/client_golang/prometheus
+        github.com/golang/protobuf/proto                             from github.com/matttproud/golang_protobuf_extensions/pbutil
+        github.com/google/uuid                                       from tailscale.com/tsweb
+        github.com/matttproud/golang_protobuf_extensions/pbutil      from github.com/prometheus/common/expfmt
+     💣 github.com/prometheus/client_golang/prometheus               from tailscale.com/tsweb/promvarz
+        github.com/prometheus/client_golang/prometheus/internal      from github.com/prometheus/client_golang/prometheus
+        github.com/prometheus/client_model/go                        from github.com/prometheus/client_golang/prometheus+
+        github.com/prometheus/common/expfmt                          from github.com/prometheus/client_golang/prometheus+
+        github.com/prometheus/common/internal/bitbucket.org/ww/goautoneg from github.com/prometheus/common/expfmt
+        github.com/prometheus/common/model                           from github.com/prometheus/client_golang/prometheus+
+  LD    github.com/prometheus/procfs                                 from github.com/prometheus/client_golang/prometheus
+  LD    github.com/prometheus/procfs/internal/fs                     from github.com/prometheus/procfs
+  LD    github.com/prometheus/procfs/internal/util                   from github.com/prometheus/procfs
+     💣 go4.org/mem                                                  from tailscale.com/metrics+
+        go4.org/netipx                                               from tailscale.com/net/tsaddr
+        google.golang.org/protobuf/encoding/prototext                from github.com/golang/protobuf/proto+
+        google.golang.org/protobuf/encoding/protowire                from github.com/golang/protobuf/proto+
+        google.golang.org/protobuf/internal/descfmt                  from google.golang.org/protobuf/internal/filedesc
+        google.golang.org/protobuf/internal/descopts                 from google.golang.org/protobuf/internal/filedesc+
+        google.golang.org/protobuf/internal/detrand                  from google.golang.org/protobuf/internal/descfmt+
+        google.golang.org/protobuf/internal/encoding/defval          from google.golang.org/protobuf/internal/encoding/tag+
+        google.golang.org/protobuf/internal/encoding/messageset      from google.golang.org/protobuf/encoding/prototext+
+        google.golang.org/protobuf/internal/encoding/tag             from google.golang.org/protobuf/internal/impl
+        google.golang.org/protobuf/internal/encoding/text            from google.golang.org/protobuf/encoding/prototext+
+        google.golang.org/protobuf/internal/errors                   from google.golang.org/protobuf/encoding/prototext+
+        google.golang.org/protobuf/internal/filedesc                 from google.golang.org/protobuf/internal/encoding/tag+
+        google.golang.org/protobuf/internal/filetype                 from google.golang.org/protobuf/runtime/protoimpl
+        google.golang.org/protobuf/internal/flags                    from google.golang.org/protobuf/encoding/prototext+
+        google.golang.org/protobuf/internal/genid                    from google.golang.org/protobuf/encoding/prototext+
+     💣 google.golang.org/protobuf/internal/impl                     from google.golang.org/protobuf/internal/filetype+
+        google.golang.org/protobuf/internal/order                    from google.golang.org/protobuf/encoding/prototext+
+        google.golang.org/protobuf/internal/pragma                   from google.golang.org/protobuf/encoding/prototext+
+        google.golang.org/protobuf/internal/set                      from google.golang.org/protobuf/encoding/prototext
+     💣 google.golang.org/protobuf/internal/strs                     from google.golang.org/protobuf/encoding/prototext+
+        google.golang.org/protobuf/internal/version                  from google.golang.org/protobuf/runtime/protoimpl
+        google.golang.org/protobuf/proto                             from github.com/golang/protobuf/proto+
+        google.golang.org/protobuf/reflect/protodesc                 from github.com/golang/protobuf/proto
+     💣 google.golang.org/protobuf/reflect/protoreflect              from github.com/golang/protobuf/proto+
+        google.golang.org/protobuf/reflect/protoregistry             from github.com/golang/protobuf/proto+
+        google.golang.org/protobuf/runtime/protoiface                from github.com/golang/protobuf/proto+
+        google.golang.org/protobuf/runtime/protoimpl                 from github.com/golang/protobuf/proto+
+        google.golang.org/protobuf/types/descriptorpb                from google.golang.org/protobuf/reflect/protodesc
+        google.golang.org/protobuf/types/known/timestamppb           from github.com/prometheus/client_golang/prometheus+
+        tailscale.com                                                from tailscale.com/version
+        tailscale.com/envknob                                        from tailscale.com/tsweb+
+        tailscale.com/metrics                                        from tailscale.com/net/stunserver+
+        tailscale.com/net/netaddr                                    from tailscale.com/net/tsaddr
+        tailscale.com/net/stun                                       from tailscale.com/net/stunserver
+        tailscale.com/net/stunserver                                 from tailscale.com/cmd/stund
+        tailscale.com/net/tsaddr                                     from tailscale.com/tsweb
+        tailscale.com/tailcfg                                        from tailscale.com/version
+        tailscale.com/tsweb                                          from tailscale.com/cmd/stund
+        tailscale.com/tsweb/promvarz                                 from tailscale.com/tsweb
+        tailscale.com/tsweb/varz                                     from tailscale.com/tsweb+
+        tailscale.com/types/dnstype                                  from tailscale.com/tailcfg
+        tailscale.com/types/ipproto                                  from tailscale.com/tailcfg
+        tailscale.com/types/key                                      from tailscale.com/tailcfg
+        tailscale.com/types/lazy                                     from tailscale.com/version+
+        tailscale.com/types/logger                                   from tailscale.com/tsweb
+        tailscale.com/types/opt                                      from tailscale.com/envknob+
+        tailscale.com/types/ptr                                      from tailscale.com/tailcfg
+        tailscale.com/types/structs                                  from tailscale.com/tailcfg+
+        tailscale.com/types/tkatype                                  from tailscale.com/tailcfg+
+        tailscale.com/types/views                                    from tailscale.com/net/tsaddr+
+        tailscale.com/util/cmpx                                      from tailscale.com/tailcfg+
+   L 💣 tailscale.com/util/dirwalk                                   from tailscale.com/metrics
+        tailscale.com/util/dnsname                                   from tailscale.com/tailcfg
+        tailscale.com/util/lineread                                  from tailscale.com/version/distro
+        tailscale.com/util/nocasemaps                                from tailscale.com/types/ipproto
+        tailscale.com/util/slicesx                                   from tailscale.com/tailcfg
+        tailscale.com/util/vizerror                                  from tailscale.com/tailcfg+
+        tailscale.com/version                                        from tailscale.com/envknob+
+        tailscale.com/version/distro                                 from tailscale.com/envknob
+        golang.org/x/crypto/blake2b                                  from golang.org/x/crypto/nacl/box
+        golang.org/x/crypto/chacha20                                 from golang.org/x/crypto/chacha20poly1305
+        golang.org/x/crypto/chacha20poly1305                         from crypto/tls
+        golang.org/x/crypto/cryptobyte                               from crypto/ecdsa+
+        golang.org/x/crypto/cryptobyte/asn1                          from crypto/ecdsa+
+        golang.org/x/crypto/curve25519                               from golang.org/x/crypto/nacl/box+
+        golang.org/x/crypto/hkdf                                     from crypto/tls
+        golang.org/x/crypto/nacl/box                                 from tailscale.com/types/key
+        golang.org/x/crypto/nacl/secretbox                           from golang.org/x/crypto/nacl/box
+        golang.org/x/crypto/salsa20/salsa                            from golang.org/x/crypto/nacl/box+
+        golang.org/x/net/dns/dnsmessage                              from net
+        golang.org/x/net/http/httpguts                               from net/http
+        golang.org/x/net/http/httpproxy                              from net/http
+        golang.org/x/net/http2/hpack                                 from net/http
+        golang.org/x/net/idna                                        from golang.org/x/net/http/httpguts+
+   D    golang.org/x/net/route                                       from net
+        golang.org/x/sys/cpu                                         from golang.org/x/crypto/blake2b+
+  LD    golang.org/x/sys/unix                                        from github.com/prometheus/procfs+
+   W    golang.org/x/sys/windows                                     from github.com/prometheus/client_golang/prometheus
+        golang.org/x/text/secure/bidirule                            from golang.org/x/net/idna
+        golang.org/x/text/transform                                  from golang.org/x/text/secure/bidirule+
+        golang.org/x/text/unicode/bidi                               from golang.org/x/net/idna+
+        golang.org/x/text/unicode/norm                               from golang.org/x/net/idna
+        bufio                                                        from compress/flate+
+        bytes                                                        from bufio+
+        cmp                                                          from slices
+        compress/flate                                               from compress/gzip
+        compress/gzip                                                from github.com/golang/protobuf/proto+
+        container/list                                               from crypto/tls+
+        context                                                      from crypto/tls+
+        crypto                                                       from crypto/ecdh+
+        crypto/aes                                                   from crypto/ecdsa+
+        crypto/cipher                                                from crypto/aes+
+        crypto/des                                                   from crypto/tls+
+        crypto/dsa                                                   from crypto/x509
+        crypto/ecdh                                                  from crypto/ecdsa+
+        crypto/ecdsa                                                 from crypto/tls+
+        crypto/ed25519                                               from crypto/tls+
+        crypto/elliptic                                              from crypto/ecdsa+
+        crypto/hmac                                                  from crypto/tls+
+        crypto/md5                                                   from crypto/tls+
+        crypto/rand                                                  from crypto/ed25519+
+        crypto/rc4                                                   from crypto/tls
+        crypto/rsa                                                   from crypto/tls+
+        crypto/sha1                                                  from crypto/tls+
+        crypto/sha256                                                from crypto/tls+
+        crypto/sha512                                                from crypto/ecdsa+
+        crypto/subtle                                                from crypto/aes+
+        crypto/tls                                                   from net/http+
+        crypto/x509                                                  from crypto/tls
+        crypto/x509/pkix                                             from crypto/x509
+        database/sql/driver                                          from github.com/google/uuid
+        embed                                                        from crypto/internal/nistec+
+        encoding                                                     from encoding/json+
+        encoding/asn1                                                from crypto/x509+
+        encoding/base64                                              from encoding/json+
+        encoding/binary                                              from compress/gzip+
+        encoding/hex                                                 from crypto/x509+
+        encoding/json                                                from expvar+
+        encoding/pem                                                 from crypto/tls+
+        errors                                                       from bufio+
+        expvar                                                       from github.com/prometheus/client_golang/prometheus+
+        flag                                                         from tailscale.com/cmd/stund
+        fmt                                                          from compress/flate+
+        go/token                                                     from google.golang.org/protobuf/internal/strs
+        hash                                                         from crypto+
+        hash/crc32                                                   from compress/gzip+
+        hash/fnv                                                     from google.golang.org/protobuf/internal/detrand
+        hash/maphash                                                 from go4.org/mem
+        html                                                         from net/http/pprof+
+        io                                                           from bufio+
+        io/fs                                                        from crypto/x509+
+        io/ioutil                                                    from github.com/golang/protobuf/proto+
+        log                                                          from expvar+
+        log/internal                                                 from log
+        maps                                                         from tailscale.com/tailcfg+
+        math                                                         from compress/flate+
+        math/big                                                     from crypto/dsa+
+        math/bits                                                    from compress/flate+
+        math/rand                                                    from math/big+
+        mime                                                         from github.com/prometheus/common/expfmt+
+        mime/multipart                                               from net/http
+        mime/quotedprintable                                         from mime/multipart
+        net                                                          from crypto/tls+
+        net/http                                                     from expvar+
+        net/http/httptrace                                           from net/http
+        net/http/internal                                            from net/http
+        net/http/pprof                                               from tailscale.com/tsweb+
+        net/netip                                                    from go4.org/netipx+
+        net/textproto                                                from golang.org/x/net/http/httpguts+
+        net/url                                                      from crypto/x509+
+        os                                                           from crypto/rand+
+        os/signal                                                    from tailscale.com/cmd/stund
+        path                                                         from github.com/prometheus/client_golang/prometheus/internal+
+        path/filepath                                                from crypto/x509+
+        reflect                                                      from crypto/x509+
+        regexp                                                       from github.com/prometheus/client_golang/prometheus/internal+
+        regexp/syntax                                                from regexp
+        runtime/debug                                                from github.com/prometheus/client_golang/prometheus+
+        runtime/metrics                                              from github.com/prometheus/client_golang/prometheus+
+        runtime/pprof                                                from net/http/pprof
+        runtime/trace                                                from net/http/pprof
+        slices                                                       from tailscale.com/metrics+
+        sort                                                         from compress/flate+
+        strconv                                                      from compress/flate+
+        strings                                                      from bufio+
+        sync                                                         from compress/flate+
+        sync/atomic                                                  from context+
+        syscall                                                      from crypto/rand+
+        text/tabwriter                                               from runtime/pprof
+        time                                                         from compress/gzip+
+        unicode                                                      from bytes+
+        unicode/utf16                                                from crypto/x509+
+        unicode/utf8                                                 from bufio+

+ 48 - 0
cmd/stund/stund.go

@@ -0,0 +1,48 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// The stund binary is a standalone STUN server.
+package main
+
+import (
+	"context"
+	"flag"
+	"io"
+	"log"
+	"net/http"
+	"os/signal"
+	"syscall"
+
+	"tailscale.com/net/stunserver"
+	"tailscale.com/tsweb"
+)
+
+var (
+	stunAddr = flag.String("stun", ":3478", "UDP address on which to start the STUN server")
+	httpAddr = flag.String("http", ":3479", "address on which to start the debug http server")
+)
+
+func main() {
+	flag.Parse()
+
+	ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+	defer cancel()
+
+	log.Printf("HTTP server listening on %s", *httpAddr)
+	go http.ListenAndServe(*httpAddr, mux())
+
+	s := stunserver.New(ctx)
+	if err := s.ListenAndServe(*stunAddr); err != nil {
+		log.Fatal(err)
+	}
+}
+
+func mux() *http.ServeMux {
+	mux := http.NewServeMux()
+	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+		io.WriteString(w, "<h1>stund</h1><a href=/debug>/debug</a>")
+	})
+	debug := tsweb.Debugger(mux)
+	debug.KV("stun_addr", *stunAddr)
+	return mux
+}

+ 126 - 0
net/stunserver/stunserver.go

@@ -0,0 +1,126 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package stunserver implements a STUN server. The package publishes a number of stats
+// to expvar under the top level label "stun". Logs are sent to the standard log package.
+package stunserver
+
+import (
+	"context"
+	"errors"
+	"expvar"
+	"io"
+	"log"
+	"net"
+	"net/netip"
+	"time"
+
+	"tailscale.com/metrics"
+	"tailscale.com/net/stun"
+)
+
+var (
+	stats           = new(metrics.Set)
+	stunDisposition = &metrics.LabelMap{Label: "disposition"}
+	stunAddrFamily  = &metrics.LabelMap{Label: "family"}
+	stunReadError   = stunDisposition.Get("read_error")
+	stunNotSTUN     = stunDisposition.Get("not_stun")
+	stunWriteError  = stunDisposition.Get("write_error")
+	stunSuccess     = stunDisposition.Get("success")
+
+	stunIPv4 = stunAddrFamily.Get("ipv4")
+	stunIPv6 = stunAddrFamily.Get("ipv6")
+)
+
+func init() {
+	stats.Set("counter_requests", stunDisposition)
+	stats.Set("counter_addrfamily", stunAddrFamily)
+	expvar.Publish("stun", stats)
+}
+
+type STUNServer struct {
+	ctx context.Context // ctx signals service shutdown
+	pc  *net.UDPConn    // pc is the UDP listener
+}
+
+// New creates a new STUN server. The server is shutdown when ctx is done.
+func New(ctx context.Context) *STUNServer {
+	return &STUNServer{ctx: ctx}
+}
+
+// Listen binds the listen socket for the server at listenAddr.
+func (s *STUNServer) Listen(listenAddr string) error {
+	uaddr, err := net.ResolveUDPAddr("udp", listenAddr)
+	if err != nil {
+		return err
+	}
+	s.pc, err = net.ListenUDP("udp", uaddr)
+	if err != nil {
+		return err
+	}
+	log.Printf("STUN server listening on %v", s.LocalAddr())
+	// close the listener on shutdown in order to break out of the read loop
+	go func() {
+		<-s.ctx.Done()
+		s.pc.Close()
+	}()
+	return nil
+}
+
+// Serve starts serving responses to STUN requests. Listen must be called before Serve.
+func (s *STUNServer) Serve() error {
+	var buf [64 << 10]byte
+	var (
+		n   int
+		ua  *net.UDPAddr
+		err error
+	)
+	for {
+		n, ua, err = s.pc.ReadFromUDP(buf[:])
+		if err != nil {
+			if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
+				return nil
+			}
+			log.Printf("STUN ReadFrom: %v", err)
+			time.Sleep(time.Second)
+			stunReadError.Add(1)
+			continue
+		}
+		pkt := buf[:n]
+		if !stun.Is(pkt) {
+			stunNotSTUN.Add(1)
+			continue
+		}
+		txid, err := stun.ParseBindingRequest(pkt)
+		if err != nil {
+			stunNotSTUN.Add(1)
+			continue
+		}
+		if ua.IP.To4() != nil {
+			stunIPv4.Add(1)
+		} else {
+			stunIPv6.Add(1)
+		}
+		addr, _ := netip.AddrFromSlice(ua.IP)
+		res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(ua.Port)))
+		_, err = s.pc.WriteTo(res, ua)
+		if err != nil {
+			stunWriteError.Add(1)
+		} else {
+			stunSuccess.Add(1)
+		}
+	}
+}
+
+// ListenAndServe starts the STUN server on listenAddr.
+func (s *STUNServer) ListenAndServe(listenAddr string) error {
+	if err := s.Listen(listenAddr); err != nil {
+		return err
+	}
+	return s.Serve()
+}
+
+// LocalAddr returns the local address of the STUN server. It must not be called before ListenAndServe.
+func (s *STUNServer) LocalAddr() net.Addr {
+	return s.pc.LocalAddr()
+}

+ 88 - 0
net/stunserver/stunserver_test.go

@@ -0,0 +1,88 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package stunserver
+
+import (
+	"context"
+	"net"
+	"sync"
+	"testing"
+	"time"
+
+	"tailscale.com/net/stun"
+	"tailscale.com/util/must"
+)
+
+func TestSTUNServer(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	s := New(ctx)
+	must.Do(s.Listen("localhost:0"))
+	var w sync.WaitGroup
+	w.Add(1)
+	var serveErr error
+	go func() {
+		defer w.Done()
+		serveErr = s.Serve()
+	}()
+
+	c := must.Get(net.DialUDP("udp", nil, s.LocalAddr().(*net.UDPAddr)))
+	defer c.Close()
+	c.SetDeadline(time.Now().Add(5 * time.Second))
+	txid := stun.NewTxID()
+	_, err := c.Write(stun.Request(txid))
+	if err != nil {
+		t.Fatalf("failed to write STUN request: %v", err)
+	}
+	var buf [64 << 10]byte
+	n, err := c.Read(buf[:])
+	if err != nil {
+		t.Fatalf("failed to read STUN response: %v", err)
+	}
+	if !stun.Is(buf[:n]) {
+		t.Fatalf("response is not STUN")
+	}
+	tid, _, err := stun.ParseResponse(buf[:n])
+	if err != nil {
+		t.Fatalf("failed to parse STUN response: %v", err)
+	}
+	if tid != txid {
+		t.Fatalf("STUN response has wrong transaction ID; got %d, want %d", tid, txid)
+	}
+
+	cancel()
+	w.Wait()
+	if serveErr != nil {
+		t.Fatalf("failed to listen and serve: %v", serveErr)
+	}
+}
+
+func BenchmarkServerSTUN(b *testing.B) {
+	b.ReportAllocs()
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	s := New(ctx)
+	s.Listen("localhost:0")
+	go s.Serve()
+	addr := s.LocalAddr().(*net.UDPAddr)
+
+	var resBuf [1500]byte
+	cc, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")})
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	tx := stun.NewTxID()
+	req := stun.Request(tx)
+	for i := 0; i < b.N; i++ {
+		if _, err := cc.WriteToUDP(req, addr); err != nil {
+			b.Fatal(err)
+		}
+		_, _, err := cc.ReadFromUDP(resBuf[:])
+		if err != nil {
+			b.Fatal(err)
+		}
+	}
+}