Bläddra i källkod

cmd/sniproxy: implement support for control configuration, multiple addresses

 * Implement missing tests for sniproxy
 * Wire sniproxy to new appc package
 * Add support to tsnet for routing subnet router traffic into netstack, so it can be handled

Updates: https://github.com/tailscale/corp/issues/15038
Signed-off-by: Tom DNetto <[email protected]>
Tom DNetto 2 år sedan
förälder
incheckning
a7c80c332a
4 ändrade filer med 354 tillägg och 271 borttagningar
  1. 4 4
      appc/appc.go
  2. 152 267
      cmd/sniproxy/sniproxy.go
  3. 186 0
      cmd/sniproxy/sniproxy_test.go
  4. 12 0
      tstest/integration/testcontrol/testcontrol.go

+ 4 - 4
appc/appc.go

@@ -67,6 +67,7 @@ func (s *Server) Configure(cfg *appctype.AppConnectorConfig) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	s.connectors = makeConnectorsFromConfig(cfg)
+	log.Printf("installed app connector config: %+v", s.connectors)
 }
 
 // HandleTCPFlow implements tsnet.FallbackTCPHandler.
@@ -193,8 +194,7 @@ func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (re
 }
 
 func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) {
-	buf := make([]byte, 1500)
-	resp := dnsmessage.NewBuilder(buf,
+	resp := dnsmessage.NewBuilder(response,
 		dnsmessage.Header{
 			ID:            req.Header.ID,
 			Response:      true,
@@ -203,8 +203,8 @@ func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (respon
 	resp.EnableCompression()
 
 	if len(req.Questions) == 0 {
-		buf, _ = resp.Finish()
-		return buf, nil
+		response, _ = resp.Finish()
+		return response, nil
 	}
 	q := req.Questions[0]
 	err = resp.StartQuestions()

+ 152 - 267
cmd/sniproxy/sniproxy.go

@@ -10,30 +10,34 @@ package main
 import (
 	"context"
 	"errors"
-	"expvar"
 	"flag"
 	"fmt"
 	"log"
 	"net"
 	"net/http"
+	"net/netip"
 	"os"
+	"sort"
 	"strconv"
 	"strings"
-	"time"
 
 	"github.com/peterbourgon/ff/v3"
 	"golang.org/x/net/dns/dnsmessage"
-	"inet.af/tcpproxy"
+	"tailscale.com/appc"
 	"tailscale.com/client/tailscale"
 	"tailscale.com/hostinfo"
-	"tailscale.com/metrics"
-	"tailscale.com/net/netutil"
+	"tailscale.com/ipn"
+	"tailscale.com/tailcfg"
 	"tailscale.com/tsnet"
 	"tailscale.com/tsweb"
+	"tailscale.com/types/appctype"
+	"tailscale.com/types/ipproto"
 	"tailscale.com/types/nettype"
-	"tailscale.com/util/clientmetric"
+	"tailscale.com/util/mak"
 )
 
+const configCapKey = "tailscale.com/sniproxy"
+
 var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
 
 // portForward is the state for a single port forwarding entry, as passed to the --forward flag.
@@ -68,6 +72,7 @@ func parseForward(value string) (*portForward, error) {
 }
 
 func main() {
+	// Parse flags
 	fs := flag.NewFlagSet("sniproxy", flag.ContinueOnError)
 	var (
 		ports        = fs.String("ports", "443", "comma-separated list of ports to proxy")
@@ -77,334 +82,214 @@ func main() {
 		debugPort    = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint")
 		hostname     = fs.String("hostname", "", "Hostname to register the service under")
 	)
-
 	err := ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_APPC"))
 	if err != nil {
 		log.Fatal("ff.Parse")
 	}
-	if *ports == "" {
-		log.Fatal("no ports")
-	}
 
-	hostinfo.SetApp("sniproxy")
+	var ts tsnet.Server
+	defer ts.Close()
 
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	run(ctx, &ts, *wgPort, *hostname, *promoteHTTPS, *debugPort, *ports, *forwards)
+}
+
+// run actually runs the sniproxy. Its separate from main() to assist in testing.
+func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, promoteHTTPS bool, debugPort int, ports, forwards string) {
+	// Wire up Tailscale node + app connector server
+	hostinfo.SetApp("sniproxy")
 	var s server
-	s.ts.Port = uint16(*wgPort)
-	s.ts.Hostname = *hostname
-	defer s.ts.Close()
+	s.ts = ts
+
+	s.ts.Port = uint16(wgPort)
+	s.ts.Hostname = hostname
 
 	lc, err := s.ts.LocalClient()
 	if err != nil {
-		log.Fatal(err)
+		log.Fatalf("LocalClient() failed: %v", err)
 	}
 	s.lc = lc
-	s.initMetrics()
-
-	for _, portStr := range strings.Split(*ports, ",") {
-		ln, err := s.ts.Listen("tcp", ":"+portStr)
-		if err != nil {
-			log.Fatal(err)
-		}
-		log.Printf("Serving on port %v ...", portStr)
-		go s.serve(ln)
-	}
-
-	for _, forwStr := range strings.Split(*forwards, ",") {
-		if forwStr == "" {
-			continue
-		}
-		forw, err := parseForward(forwStr)
-		if err != nil {
-			log.Fatal(err)
-		}
-
-		ln, err := s.ts.Listen("tcp", ":"+strconv.Itoa(forw.Port))
-		if err != nil {
-			log.Fatal(err)
-		}
-		log.Printf("Serving on port %d to %s...", forw.Port, forw.Destination)
-
-		// Add an entry to the expvar LabelMap for Prometheus metrics,
-		// and create a clientmetric to report that same value.
-		service := portNumberToName(forw)
-		s.numTCPsessions.SetInt64(service, 0)
-		metric := fmt.Sprintf("sniproxy_tcp_sessions_%s", service)
-		clientmetric.NewCounterFunc(metric, func() int64 {
-			return s.numTCPsessions.Get(service).Value()
-		})
-
-		go s.forward(ln, forw)
-	}
+	s.ts.RegisterFallbackTCPHandler(s.appc.HandleTCPFlow)
 
+	// Start special-purpose listeners: dns, http promotion, debug server
 	ln, err := s.ts.Listen("udp", ":53")
 	if err != nil {
-		log.Fatal(err)
+		log.Fatalf("failed listening on port 53: %v", err)
 	}
+	defer ln.Close()
 	go s.serveDNS(ln)
-
-	if *promoteHTTPS {
+	if promoteHTTPS {
 		ln, err := s.ts.Listen("tcp", ":80")
 		if err != nil {
-			log.Fatal(err)
+			log.Fatalf("failed listening on port 80: %v", err)
 		}
+		defer ln.Close()
 		log.Printf("Promoting HTTP to HTTPS ...")
 		go s.promoteHTTPS(ln)
 	}
-
-	if *debugPort != 0 {
+	if debugPort != 0 {
 		mux := http.NewServeMux()
 		tsweb.Debugger(mux)
-		dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", *debugPort))
+		dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", debugPort))
 		if err != nil {
-			log.Fatal(err)
+			log.Fatalf("failed listening on debug port: %v", err)
 		}
+		defer dln.Close()
 		go func() {
-			log.Fatal(http.Serve(dln, mux))
+			log.Fatalf("debug serve: %v", http.Serve(dln, mux))
 		}()
 	}
 
-	select {}
-}
-
-type server struct {
-	ts tsnet.Server
-	lc *tailscale.LocalClient
-
-	numTLSsessions expvar.Int
-	numTCPsessions *metrics.LabelMap
-	numBadAddrPort expvar.Int
-	dnsResponses   expvar.Int
-	dnsFailures    expvar.Int
-	httpPromoted   expvar.Int
-}
-
-func (s *server) serve(ln net.Listener) {
-	for {
-		c, err := ln.Accept()
-		if err != nil {
-			log.Fatal(err)
-		}
-		go s.serveConn(c)
+	// Finally, start mainloop to configure app connector based on information
+	// in the netmap.
+	// We set the NotifyInitialNetMap flag so we will always get woken with the
+	// current netmap, before only being woken on changes.
+	bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap|ipn.NotifyNoPrivateKeys)
+	if err != nil {
+		log.Fatalf("watching IPN bus: %v", err)
 	}
-}
-
-func (s *server) forward(ln net.Listener, forw *portForward) {
+	defer bus.Close()
 	for {
-		c, err := ln.Accept()
+		msg, err := bus.Next()
 		if err != nil {
-			log.Fatal(err)
+			if errors.Is(err, context.Canceled) {
+				return
+			}
+			log.Fatalf("reading IPN bus: %v", err)
 		}
-		go s.forwardConn(c, forw)
-	}
-}
 
-func (s *server) serveDNS(ln net.Listener) {
-	for {
-		c, err := ln.Accept()
-		if err != nil {
-			log.Fatal(err)
+		// NetMap contains app-connector configuration
+		if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() {
+			sn := nm.SelfNode.AsStruct()
+
+			var c appctype.AppConnectorConfig
+			nmConf, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorConfig](sn.CapMap, configCapKey)
+			if err != nil {
+				log.Printf("failed to read app connector configuration from coordination server: %v", err)
+			} else if len(nmConf) > 0 {
+				c = nmConf[0]
+			}
+
+			if c.AdvertiseRoutes {
+				if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil {
+					log.Printf("failed to advertise routes: %v", err)
+				}
+			}
+
+			// Backwards compatibility: combine any configuration from control with flags specified
+			// on the command line. This is intentionally done after we advertise any routes
+			// because its never correct to advertise the nodes native IP addresses.
+			s.mergeConfigFromFlags(&c, ports, forwards)
+			s.appc.Configure(&c)
 		}
-		go s.serveDNSConn(c.(nettype.ConnPacketConn))
 	}
 }
 
-func (s *server) serveDNSConn(c nettype.ConnPacketConn) {
-	defer c.Close()
-	c.SetReadDeadline(time.Now().Add(5 * time.Second))
-	buf := make([]byte, 1500)
-	n, err := c.Read(buf)
-	if err != nil {
-		log.Printf("c.Read failed: %v\n ", err)
-		s.dnsFailures.Add(1)
-		return
-	}
-
-	var msg dnsmessage.Message
-	err = msg.Unpack(buf[:n])
-	if err != nil {
-		log.Printf("dnsmessage unpack failed: %v\n ", err)
-		s.dnsFailures.Add(1)
-		return
-	}
-
-	buf, err = s.dnsResponse(&msg)
-	if err != nil {
-		log.Printf("s.dnsResponse failed: %v\n", err)
-		s.dnsFailures.Add(1)
-		return
-	}
-
-	_, err = c.Write(buf)
-	if err != nil {
-		log.Printf("c.Write failed: %v\n", err)
-		s.dnsFailures.Add(1)
-		return
-	}
-
-	s.dnsResponses.Add(1)
+type server struct {
+	appc appc.Server
+	ts   *tsnet.Server
+	lc   *tailscale.LocalClient
 }
 
-func (s *server) serveConn(c net.Conn) {
-	addrPortStr := c.LocalAddr().String()
-	_, port, err := net.SplitHostPort(addrPortStr)
-	if err != nil {
-		log.Printf("bogus addrPort %q", addrPortStr)
-		s.numBadAddrPort.Add(1)
-		c.Close()
-		return
+func (s *server) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error {
+	// Collect the set of addresses to advertise, using a map
+	// to avoid duplicate entries.
+	addrs := map[netip.Addr]struct{}{}
+	for _, c := range c.SNIProxy {
+		for _, ip := range c.Addrs {
+			addrs[ip] = struct{}{}
+		}
+	}
+	for _, c := range c.DNAT {
+		for _, ip := range c.Addrs {
+			addrs[ip] = struct{}{}
+		}
 	}
 
-	var dialer net.Dialer
-	dialer.Timeout = 5 * time.Second
-
-	var p tcpproxy.Proxy
-	p.ListenFunc = func(net, laddr string) (net.Listener, error) {
-		return netutil.NewOneConnListener(c, nil), nil
+	var routes []netip.Prefix
+	for a := range addrs {
+		routes = append(routes, netip.PrefixFrom(a, a.BitLen()))
 	}
-	p.AddSNIRouteFunc(addrPortStr, func(ctx context.Context, sniName string) (t tcpproxy.Target, ok bool) {
-		s.numTLSsessions.Add(1)
-		return &tcpproxy.DialProxy{
-			Addr:        net.JoinHostPort(sniName, port),
-			DialContext: dialer.DialContext,
-		}, true
+	sort.SliceStable(routes, func(i, j int) bool {
+		return routes[i].Addr().Less(routes[j].Addr()) // determinism r us
 	})
-	p.Start()
-}
 
-// portNumberToName returns a human-readable name for several port numbers commonly forwarded,
-// and "tcp###" for everything else. It is used for metric label names.
-func portNumberToName(forw *portForward) string {
-	switch forw.Port {
-	case 22:
-		return "ssh"
-	case 1433:
-		return "sqlserver"
-	case 3306:
-		return "mysql"
-	case 3389:
-		return "rdp"
-	case 5432:
-		return "postgres"
-	default:
-		return fmt.Sprintf("%s%d", forw.Proto, forw.Port)
-	}
+	_, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{
+		Prefs: ipn.Prefs{
+			AdvertiseRoutes: routes,
+		},
+		AdvertiseRoutesSet: true,
+	})
+	return err
 }
 
-// forwardConn sets up a forwarder for a TCP connection. It does not inspect of the data
-// like the SNI forwarding does, it merely forwards all data to the destination specified
-// in the --forward=tcp/22/github.com argument.
-func (s *server) forwardConn(c net.Conn, forw *portForward) {
-	addrPortStr := c.LocalAddr().String()
-
-	var dialer net.Dialer
-	dialer.Timeout = 30 * time.Second
+func (s *server) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, forwards string) {
+	ip4, ip6 := s.ts.TailscaleIPs()
 
-	var p tcpproxy.Proxy
-	p.ListenFunc = func(net, laddr string) (net.Listener, error) {
-		return netutil.NewOneConnListener(c, nil), nil
+	sniConfigFromFlags := appctype.SNIProxyConfig{
+		Addrs: []netip.Addr{ip4, ip6},
 	}
-
-	dial := &tcpproxy.DialProxy{
-		Addr:        fmt.Sprintf("%s:%d", forw.Destination, forw.Port),
-		DialContext: dialer.DialContext,
+	if ports != "" {
+		for _, portStr := range strings.Split(ports, ",") {
+			port, err := strconv.ParseUint(portStr, 10, 16)
+			if err != nil {
+				log.Fatalf("invalid port: %s", portStr)
+			}
+			sniConfigFromFlags.IP = append(sniConfigFromFlags.IP, tailcfg.ProtoPortRange{
+				Proto: int(ipproto.TCP),
+				Ports: tailcfg.PortRange{First: uint16(port), Last: uint16(port)},
+			})
+		}
 	}
 
-	p.AddRoute(addrPortStr, dial)
-	s.numTCPsessions.Add(portNumberToName(forw), 1)
-	p.Start()
-}
+	var forwardConfigFromFlags []appctype.DNATConfig
+	for _, forwStr := range strings.Split(forwards, ",") {
+		if forwStr == "" {
+			continue
+		}
+		forw, err := parseForward(forwStr)
+		if err != nil {
+			log.Printf("invalid forwarding spec: %v", err)
+			continue
+		}
 
-func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) {
-	resp := dnsmessage.NewBuilder(buf,
-		dnsmessage.Header{
-			ID:            req.Header.ID,
-			Response:      true,
-			Authoritative: true,
+		forwardConfigFromFlags = append(forwardConfigFromFlags, appctype.DNATConfig{
+			Addrs: []netip.Addr{ip4, ip6},
+			To:    []string{forw.Destination},
+			IP: []tailcfg.ProtoPortRange{
+				{
+					Proto: int(ipproto.TCP),
+					Ports: tailcfg.PortRange{First: uint16(forw.Port), Last: uint16(forw.Port)},
+				},
+			},
 		})
-	resp.EnableCompression()
-
-	if len(req.Questions) == 0 {
-		buf, _ = resp.Finish()
-		return
 	}
 
-	q := req.Questions[0]
-	err = resp.StartQuestions()
-	if err != nil {
-		return
+	if len(forwardConfigFromFlags) == 0 && len(sniConfigFromFlags.IP) == 0 {
+		return // no config specified on the command line
 	}
-	resp.Question(q)
 
-	ip4, ip6 := s.ts.TailscaleIPs()
-	err = resp.StartAnswers()
-	if err != nil {
-		return
-	}
-
-	switch q.Type {
-	case dnsmessage.TypeAAAA:
-		err = resp.AAAAResource(
-			dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
-			dnsmessage.AAAAResource{AAAA: ip6.As16()},
-		)
-
-	case dnsmessage.TypeA:
-		err = resp.AResource(
-			dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
-			dnsmessage.AResource{A: ip4.As4()},
-		)
-	case dnsmessage.TypeSOA:
-		err = resp.SOAResource(
-			dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
-			dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
-				Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
-		)
-	case dnsmessage.TypeNS:
-		err = resp.NSResource(
-			dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
-			dnsmessage.NSResource{NS: tsMBox},
-		)
+	mak.Set(&out.SNIProxy, "flags", sniConfigFromFlags)
+	for i, forward := range forwardConfigFromFlags {
+		mak.Set(&out.DNAT, appctype.ConfigID(fmt.Sprintf("flags_%d", i)), forward)
 	}
+}
 
-	if err != nil {
-		return
+func (s *server) serveDNS(ln net.Listener) {
+	for {
+		c, err := ln.Accept()
+		if err != nil {
+			log.Printf("serveDNS accept: %v", err)
+			return
+		}
+		go s.appc.HandleDNS(c.(nettype.ConnPacketConn))
 	}
-
-	return resp.Finish()
 }
 
 func (s *server) promoteHTTPS(ln net.Listener) {
 	err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		s.httpPromoted.Add(1)
 		http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound)
 	}))
 	log.Fatalf("promoteHTTPS http.Serve: %v", err)
 }
-
-// initMetrics sets up local prometheus metrics, and creates clientmetrics to report those
-// same counters.
-func (s *server) initMetrics() {
-	stats := new(metrics.Set)
-
-	stats.Set("tls_sessions", &s.numTLSsessions)
-	clientmetric.NewCounterFunc("sniproxy_tls_sessions", s.numTLSsessions.Value)
-
-	s.numTCPsessions = &metrics.LabelMap{Label: "proto"}
-	stats.Set("tcp_sessions", s.numTCPsessions)
-	// clientmetric doesn't have a good way to implement a Map type.
-	// We create clientmetrics dynamically when parsing the --forwards argument
-
-	stats.Set("bad_addrport", &s.numBadAddrPort)
-	clientmetric.NewCounterFunc("sniproxy_bad_addrport", s.numBadAddrPort.Value)
-
-	stats.Set("dns_responses", &s.dnsResponses)
-	clientmetric.NewCounterFunc("sniproxy_dns_responses", s.dnsResponses.Value)
-
-	stats.Set("dns_failed", &s.dnsFailures)
-	clientmetric.NewCounterFunc("sniproxy_dns_failed", s.dnsFailures.Value)
-
-	stats.Set("http_promoted", &s.httpPromoted)
-	clientmetric.NewCounterFunc("sniproxy_http_promoted", s.httpPromoted.Value)
-
-	expvar.Publish("sniproxy", stats)
-}

+ 186 - 0
cmd/sniproxy/sniproxy_test.go

@@ -4,10 +4,30 @@
 package main
 
 import (
+	"context"
+	"encoding/json"
+	"flag"
+	"fmt"
+	"net"
+	"net/http/httptest"
+	"net/netip"
+	"os"
+	"path/filepath"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/google/go-cmp/cmp"
+	"tailscale.com/ipn/store/mem"
+	"tailscale.com/net/netns"
+	"tailscale.com/tailcfg"
+	"tailscale.com/tsnet"
+	"tailscale.com/tstest/integration"
+	"tailscale.com/tstest/integration/testcontrol"
+	"tailscale.com/types/appctype"
+	"tailscale.com/types/ipproto"
+	"tailscale.com/types/key"
+	"tailscale.com/types/logger"
 )
 
 func TestPortForwardingArguments(t *testing.T) {
@@ -35,3 +55,169 @@ func TestPortForwardingArguments(t *testing.T) {
 		}
 	}
 }
+
+var verboseDERP = flag.Bool("verbose-derp", false, "if set, print DERP and STUN logs")
+var verboseNodes = flag.Bool("verbose-nodes", false, "if set, print tsnet.Server logs")
+
+func startControl(t *testing.T) (control *testcontrol.Server, controlURL string) {
+	// Corp#4520: don't use netns for tests.
+	netns.SetEnabled(false)
+	t.Cleanup(func() {
+		netns.SetEnabled(true)
+	})
+
+	derpLogf := logger.Discard
+	if *verboseDERP {
+		derpLogf = t.Logf
+	}
+	derpMap := integration.RunDERPAndSTUN(t, derpLogf, "127.0.0.1")
+	control = &testcontrol.Server{
+		DERPMap: derpMap,
+		DNSConfig: &tailcfg.DNSConfig{
+			Proxied: true,
+		},
+		MagicDNSDomain: "tail-scale.ts.net",
+	}
+	control.HTTPTestServer = httptest.NewUnstartedServer(control)
+	control.HTTPTestServer.Start()
+	t.Cleanup(control.HTTPTestServer.Close)
+	controlURL = control.HTTPTestServer.URL
+	t.Logf("testcontrol listening on %s", controlURL)
+	return control, controlURL
+}
+
+func startNode(t *testing.T, ctx context.Context, controlURL, hostname string) (*tsnet.Server, key.NodePublic, netip.Addr) {
+	t.Helper()
+
+	tmp := filepath.Join(t.TempDir(), hostname)
+	os.MkdirAll(tmp, 0755)
+	s := &tsnet.Server{
+		Dir:        tmp,
+		ControlURL: controlURL,
+		Hostname:   hostname,
+		Store:      new(mem.Store),
+		Ephemeral:  true,
+	}
+	if !*verboseNodes {
+		s.Logf = logger.Discard
+	}
+	t.Cleanup(func() { s.Close() })
+
+	status, err := s.Up(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	return s, status.Self.PublicKey, status.TailscaleIPs[0]
+}
+
+func TestSNIProxyWithNetmapConfig(t *testing.T) {
+	c, controlURL := startControl(t)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	// Create a listener to proxy connections to.
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ln.Close()
+
+	// Start sniproxy
+	sni, nodeKey, ip := startNode(t, ctx, controlURL, "snitest")
+	go run(ctx, sni, 0, sni.Hostname, false, 0, "", "")
+
+	// Configure the mock coordination server to send down app connector config.
+	config := &appctype.AppConnectorConfig{
+		DNAT: map[appctype.ConfigID]appctype.DNATConfig{
+			"nic_test": {
+				Addrs: []netip.Addr{ip},
+				To:    []string{"127.0.0.1"},
+				IP: []tailcfg.ProtoPortRange{
+					{
+						Proto: int(ipproto.TCP),
+						Ports: tailcfg.PortRange{First: uint16(ln.Addr().(*net.TCPAddr).Port), Last: uint16(ln.Addr().(*net.TCPAddr).Port)},
+					},
+				},
+			},
+		},
+	}
+	b, err := json.Marshal(config)
+	if err != nil {
+		t.Fatal(err)
+	}
+	c.SetNodeCapMap(nodeKey, tailcfg.NodeCapMap{
+		configCapKey: []tailcfg.RawMessage{tailcfg.RawMessage(b)},
+	})
+
+	// Lets spin up a second node (to represent the client).
+	client, _, _ := startNode(t, ctx, controlURL, "client")
+
+	// Make sure that the sni node has received its config.
+	l, err := sni.LocalClient()
+	if err != nil {
+		t.Fatal(err)
+	}
+	gotConfigured := false
+	for i := 0; i < 100; i++ {
+		s, err := l.StatusWithoutPeers(ctx)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if len(s.Self.CapMap) > 0 {
+			gotConfigured = true
+			break // we got it
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+	if !gotConfigured {
+		t.Error("sni node never received its configuration from the coordination server!")
+	}
+
+	// Lets make the client open a connection to the sniproxy node, and
+	// make sure it results in a connection to our test listener.
+	w, err := client.Dial(ctx, "tcp", fmt.Sprintf("%s:%d", ip, ln.Addr().(*net.TCPAddr).Port))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer w.Close()
+
+	r, err := ln.Accept()
+	if err != nil {
+		t.Fatal(err)
+	}
+	r.Close()
+}
+
+func TestSNIProxyWithFlagConfig(t *testing.T) {
+	_, controlURL := startControl(t)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	// Create a listener to proxy connections to.
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ln.Close()
+
+	// Start sniproxy
+	sni, _, ip := startNode(t, ctx, controlURL, "snitest")
+	go run(ctx, sni, 0, sni.Hostname, false, 0, "", fmt.Sprintf("tcp/%d/localhost", ln.Addr().(*net.TCPAddr).Port))
+
+	// Lets spin up a second node (to represent the client).
+	client, _, _ := startNode(t, ctx, controlURL, "client")
+
+	// Lets make the client open a connection to the sniproxy node, and
+	// make sure it results in a connection to our test listener.
+	w, err := client.Dial(ctx, "tcp", fmt.Sprintf("%s:%d", ip, ln.Addr().(*net.TCPAddr).Port))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer w.Close()
+
+	r, err := ln.Accept()
+	if err != nil {
+		t.Fatal(err)
+	}
+	r.Close()
+}

+ 12 - 0
tstest/integration/testcontrol/testcontrol.go

@@ -75,6 +75,9 @@ type Server struct {
 	// masquerade address to use for that peer.
 	masquerades map[key.NodePublic]map[key.NodePublic]netip.Addr // node => peer => SelfNodeV{4,6}MasqAddrForThisPeer IP
 
+	// nodeCapMaps overrides the capability map sent down to a client.
+	nodeCapMaps map[key.NodePublic]tailcfg.NodeCapMap
+
 	// suppressAutoMapResponses is the set of nodes that should not be sent
 	// automatic map responses from serveMap. (They should only get manually sent ones)
 	suppressAutoMapResponses set.Set[key.NodePublic]
@@ -369,6 +372,14 @@ func (s *Server) SetMasqueradeAddresses(pairs []MasqueradePair) {
 	s.updateLocked("SetMasqueradeAddresses", s.nodeIDsLocked(0))
 }
 
+// SetNodeCapMap overrides the capability map the specified client receives.
+func (s *Server) SetNodeCapMap(nodeKey key.NodePublic, capMap tailcfg.NodeCapMap) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	mak.Set(&s.nodeCapMaps, nodeKey, capMap)
+	s.updateLocked("SetNodeCapMap", s.nodeIDsLocked(0))
+}
+
 // nodeIDsLocked returns the node IDs of all nodes in the server, except
 // for the node with the given ID.
 func (s *Server) nodeIDsLocked(except tailcfg.NodeID) []tailcfg.NodeID {
@@ -881,6 +892,7 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse,
 		// node key rotated away (once test server supports that)
 		return nil, nil
 	}
+	node.CapMap = s.nodeCapMaps[nk]
 	node.Capabilities = append(node.Capabilities, tailcfg.NodeAttrDisableUPnP)
 
 	user, _ := s.getUser(nk)