Browse Source

cmd/derper: add support for unpublished bootstrap DNS entries (#5529)

Signed-off-by: Andrew Dunham <[email protected]>
Andrew Dunham 3 years ago
parent
commit
a0bae4dac8
3 changed files with 206 additions and 21 deletions
  1. 81 16
      cmd/derper/bootstrap_dns.go
  2. 120 1
      cmd/derper/bootstrap_dns_test.go
  3. 5 4
      cmd/derper/derper.go

+ 81 - 16
cmd/derper/bootstrap_dns.go

@@ -17,16 +17,31 @@ import (
 	"tailscale.com/syncs"
 )
 
-var dnsCache syncs.AtomicValue[[]byte]
+const refreshTimeout = time.Minute
 
-var bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests")
+type dnsEntryMap map[string][]net.IP
+
+var (
+	dnsCache            syncs.AtomicValue[dnsEntryMap]
+	dnsCacheBytes       syncs.AtomicValue[[]byte] // of JSON
+	unpublishedDNSCache syncs.AtomicValue[dnsEntryMap]
+)
+
+var (
+	bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests")
+	publishedDNSHits     = expvar.NewInt("counter_bootstrap_dns_published_hits")
+	publishedDNSMisses   = expvar.NewInt("counter_bootstrap_dns_published_misses")
+	unpublishedDNSHits   = expvar.NewInt("counter_bootstrap_dns_unpublished_hits")
+	unpublishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_misses")
+)
 
 func refreshBootstrapDNSLoop() {
-	if *bootstrapDNS == "" {
+	if *bootstrapDNS == "" && *unpublishedDNS == "" {
 		return
 	}
 	for {
 		refreshBootstrapDNS()
+		refreshUnpublishedDNS()
 		time.Sleep(10 * time.Minute)
 	}
 }
@@ -35,10 +50,34 @@ func refreshBootstrapDNS() {
 	if *bootstrapDNS == "" {
 		return
 	}
-	dnsEntries := make(map[string][]net.IP)
-	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+	ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
+	defer cancel()
+	dnsEntries := resolveList(ctx, strings.Split(*bootstrapDNS, ","))
+	j, err := json.MarshalIndent(dnsEntries, "", "\t")
+	if err != nil {
+		// leave the old values in place
+		return
+	}
+
+	dnsCache.Store(dnsEntries)
+	dnsCacheBytes.Store(j)
+}
+
+func refreshUnpublishedDNS() {
+	if *unpublishedDNS == "" {
+		return
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
 	defer cancel()
-	names := strings.Split(*bootstrapDNS, ",")
+
+	dnsEntries := resolveList(ctx, strings.Split(*unpublishedDNS, ","))
+	unpublishedDNSCache.Store(dnsEntries)
+}
+
+func resolveList(ctx context.Context, names []string) dnsEntryMap {
+	dnsEntries := make(dnsEntryMap)
+
 	var r net.Resolver
 	for _, name := range names {
 		addrs, err := r.LookupIP(ctx, "ip", name)
@@ -48,21 +87,47 @@ func refreshBootstrapDNS() {
 		}
 		dnsEntries[name] = addrs
 	}
-	j, err := json.MarshalIndent(dnsEntries, "", "\t")
-	if err != nil {
-		// leave the old values in place
-		return
-	}
-	dnsCache.Store(j)
+	return dnsEntries
 }
 
 func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) {
 	bootstrapDNSRequests.Add(1)
+
 	w.Header().Set("Content-Type", "application/json")
-	j := dnsCache.Load()
-	// Bootstrap DNS requests occur cross-regions,
-	// and are randomized per request,
-	// so keeping a connection open is pointlessly expensive.
+	// Bootstrap DNS requests occur cross-regions, and are randomized per
+	// request, so keeping a connection open is pointlessly expensive.
 	w.Header().Set("Connection", "close")
+
+	// Try answering a query from our hidden map first
+	if q := r.URL.Query().Get("q"); q != "" {
+		if ips, ok := unpublishedDNSCache.Load()[q]; ok && len(ips) > 0 {
+			unpublishedDNSHits.Add(1)
+
+			// Only return the specific query, not everything.
+			m := dnsEntryMap{q: ips}
+			j, err := json.MarshalIndent(m, "", "\t")
+			if err == nil {
+				w.Write(j)
+				return
+			}
+		}
+
+		// If we have a "q" query for a name in the published cache
+		// list, then track whether that's a hit/miss.
+		if m, ok := dnsCache.Load()[q]; ok {
+			if len(m) > 0 {
+				publishedDNSHits.Add(1)
+			} else {
+				publishedDNSMisses.Add(1)
+			}
+		} else {
+			// If it wasn't in either cache, treat this as a query
+			// for the unpublished cache, and thus a cache miss.
+			unpublishedDNSMisses.Add(1)
+		}
+	}
+
+	// Fall back to returning the public set of cached DNS names
+	j := dnsCacheBytes.Load()
 	w.Write(j)
 }

+ 120 - 1
cmd/derper/bootstrap_dns_test.go

@@ -5,7 +5,12 @@
 package main
 
 import (
+	"encoding/json"
+	"net"
 	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"reflect"
 	"testing"
 )
 
@@ -17,11 +22,12 @@ func BenchmarkHandleBootstrapDNS(b *testing.B) {
 	}()
 	refreshBootstrapDNS()
 	w := new(bitbucketResponseWriter)
+	req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape("log.tailscale.io"), nil)
 	b.ReportAllocs()
 	b.ResetTimer()
 	b.RunParallel(func(b *testing.PB) {
 		for b.Next() {
-			handleBootstrapDNS(w, nil)
+			handleBootstrapDNS(w, req)
 		}
 	})
 }
@@ -33,3 +39,116 @@ func (b *bitbucketResponseWriter) Header() http.Header { return make(http.Header
 func (b *bitbucketResponseWriter) Write(p []byte) (int, error) { return len(p), nil }
 
 func (b *bitbucketResponseWriter) WriteHeader(statusCode int) {}
+
+func getBootstrapDNS(t *testing.T, q string) dnsEntryMap {
+	t.Helper()
+	req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape(q), nil)
+	w := httptest.NewRecorder()
+	handleBootstrapDNS(w, req)
+
+	res := w.Result()
+	if res.StatusCode != 200 {
+		t.Fatalf("got status=%d; want %d", res.StatusCode, 200)
+	}
+	var ips dnsEntryMap
+	if err := json.NewDecoder(res.Body).Decode(&ips); err != nil {
+		t.Fatalf("error decoding response body: %v", err)
+	}
+	return ips
+}
+
+func TestUnpublishedDNS(t *testing.T) {
+	const published = "login.tailscale.com"
+	const unpublished = "log.tailscale.io"
+
+	prev1, prev2 := *bootstrapDNS, *unpublishedDNS
+	*bootstrapDNS = published
+	*unpublishedDNS = unpublished
+	t.Cleanup(func() {
+		*bootstrapDNS = prev1
+		*unpublishedDNS = prev2
+	})
+
+	refreshBootstrapDNS()
+	refreshUnpublishedDNS()
+
+	hasResponse := func(q string) bool {
+		_, found := getBootstrapDNS(t, q)[q]
+		return found
+	}
+
+	if !hasResponse(published) {
+		t.Errorf("expected response for: %s", published)
+	}
+	if !hasResponse(unpublished) {
+		t.Errorf("expected response for: %s", unpublished)
+	}
+
+	// Verify that querying for a published name or for an unknown name
+	// does not leak our unpublished domain.
+	m1 := getBootstrapDNS(t, published)
+	if _, found := m1[unpublished]; found {
+		t.Errorf("found unpublished domain %s: %+v", unpublished, m1)
+	}
+	m2 := getBootstrapDNS(t, "random.example.com")
+	if _, found := m2[unpublished]; found {
+		t.Errorf("found unpublished domain %s: %+v", unpublished, m2)
+	}
+}
+
+func resetMetrics() {
+	publishedDNSHits.Set(0)
+	publishedDNSMisses.Set(0)
+	unpublishedDNSHits.Set(0)
+	unpublishedDNSMisses.Set(0)
+}
+
+// Verify that we don't count an empty list in the unpublishedDNSCache as a
+// cache hit in our metrics.
+func TestUnpublishedDNSEmptyList(t *testing.T) {
+	pub := dnsEntryMap{
+		"tailscale.com": {net.IPv4(10, 10, 10, 10)},
+	}
+	dnsCache.Store(pub)
+	dnsCacheBytes.Store([]byte(`{"tailscale.com":["10.10.10.10"]}`))
+
+	unpublishedDNSCache.Store(dnsEntryMap{
+		"log.tailscale.io":           {},
+		"controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)},
+	})
+
+	t.Run("CacheMiss", func(t *testing.T) {
+		// One domain in map but empty, one not in map at all
+		for _, q := range []string{"log.tailscale.io", "login.tailscale.com"} {
+			resetMetrics()
+			ips := getBootstrapDNS(t, q)
+
+			// Expect our public map to be returned on a cache miss.
+			if !reflect.DeepEqual(ips, pub) {
+				t.Errorf("got ips=%+v; want %+v", ips, pub)
+			}
+			if v := unpublishedDNSHits.Value(); v != 0 {
+				t.Errorf("got hits=%d; want 0", v)
+			}
+			if v := unpublishedDNSMisses.Value(); v != 1 {
+				t.Errorf("got misses=%d; want 1", v)
+			}
+		}
+	})
+
+	// Verify that we do get a valid response and metric.
+	t.Run("CacheHit", func(t *testing.T) {
+		resetMetrics()
+		ips := getBootstrapDNS(t, "controlplane.tailscale.com")
+		want := dnsEntryMap{"controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)}}
+		if !reflect.DeepEqual(ips, want) {
+			t.Errorf("got ips=%+v; want %+v", ips, want)
+		}
+		if v := unpublishedDNSHits.Value(); v != 1 {
+			t.Errorf("got hits=%d; want 1", v)
+		}
+		if v := unpublishedDNSMisses.Value(); v != 0 {
+			t.Errorf("got misses=%d; want 0", v)
+		}
+	})
+}

+ 5 - 4
cmd/derper/derper.go

@@ -47,10 +47,11 @@ var (
 	hostname   = flag.String("hostname", "derp.tailscale.com", "LetsEncrypt host name, if addr's port is :443")
 	runSTUN    = flag.Bool("stun", true, "whether to run a STUN server. It will bind to the same IP (if any) as the --addr flag value.")
 
-	meshPSKFile   = flag.String("mesh-psk-file", defaultMeshPSKFile(), "if non-empty, path to file containing the mesh pre-shared key file. It should contain some hex string; whitespace is trimmed.")
-	meshWith      = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list")
-	bootstrapDNS  = flag.String("bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns")
-	verifyClients = flag.Bool("verify-clients", false, "verify clients to this DERP server through a local tailscaled instance.")
+	meshPSKFile    = flag.String("mesh-psk-file", defaultMeshPSKFile(), "if non-empty, path to file containing the mesh pre-shared key file. It should contain some hex string; whitespace is trimmed.")
+	meshWith       = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list")
+	bootstrapDNS   = flag.String("bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns")
+	unpublishedDNS = flag.String("unpublished-bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns and not publish in the list")
+	verifyClients  = flag.Bool("verify-clients", false, "verify clients to this DERP server through a local tailscaled instance.")
 
 	acceptConnLimit = flag.Float64("accept-connection-limit", math.Inf(+1), "rate limit for accepting new connection")
 	acceptConnBurst = flag.Int("accept-connection-burst", math.MaxInt, "burst limit for accepting new connection")