Browse Source

net/dnsfallback: add singleflight to recursive resolver

This prevents running more than one recursive resolution for the same
hostname in parallel, which can use excessive amounts of CPU when called
in a tight loop. Additionally, add tests that hit the network (when
run with a flag) to test the lookup behaviour.

Updates tailscale/corp#15261

Signed-off-by: Andrew Dunham <[email protected]>
Change-Id: I39351e1d2a8782dd4c52cb04b3bd982eb651c81e
Andrew Dunham 2 years ago
parent
commit
e33bc64cff
3 changed files with 177 additions and 56 deletions
  1. 1 1
      cmd/tailscale/depaware.txt
  2. 145 55
      net/dnsfallback/dnsfallback.go
  3. 31 0
      net/dnsfallback/dnsfallback_test.go

+ 1 - 1
cmd/tailscale/depaware.txt

@@ -155,7 +155,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
         tailscale.com/util/nocasemaps                                from tailscale.com/types/ipproto
         tailscale.com/util/quarantine                                from tailscale.com/cmd/tailscale/cli
         tailscale.com/util/set                                       from tailscale.com/health+
-        tailscale.com/util/singleflight                              from tailscale.com/net/dnscache
+        tailscale.com/util/singleflight                              from tailscale.com/net/dnscache+
         tailscale.com/util/slicesx                                   from tailscale.com/net/dnscache+
         tailscale.com/util/testenv                                   from tailscale.com/cmd/tailscale/cli
         tailscale.com/util/truncate                                  from tailscale.com/cmd/tailscale/cli

+ 145 - 55
net/dnsfallback/dnsfallback.go

@@ -36,6 +36,7 @@ import (
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/logger"
 	"tailscale.com/util/clientmetric"
+	"tailscale.com/util/singleflight"
 	"tailscale.com/util/slicesx"
 )
 
@@ -44,76 +45,165 @@ var (
 	disableRecursiveResolver = envknob.RegisterBool("TS_DNSFALLBACK_DISABLE_RECURSIVE_RESOLVER") // legacy pre-1.52 env knob name
 )
 
+type resolveResult struct {
+	addrs  []netip.Addr
+	minTTL time.Duration
+}
+
 // MakeLookupFunc creates a function that can be used to resolve hostnames
 // (e.g. as a LookupIPFallback from dnscache.Resolver).
 // The netMon parameter is optional; if non-nil it's used to do faster interface lookups.
 func MakeLookupFunc(logf logger.Logf, netMon *netmon.Monitor) func(ctx context.Context, host string) ([]netip.Addr, error) {
-	return func(ctx context.Context, host string) ([]netip.Addr, error) {
-		// If they've explicitly disabled the recursive resolver with the legacy
-		// TS_DNSFALLBACK_DISABLE_RECURSIVE_RESOLVER envknob or not set the
-		// newer TS_DNSFALLBACK_RECURSIVE_RESOLVER to true, then don't use the
-		// recursive resolver. (tailscale/corp#15261) In the future, we might
-		// change the default (the opt.Bool being unset) to mean enabled.
-		if disableRecursiveResolver() || !optRecursiveResolver().EqualBool(true) {
-			return lookup(ctx, host, logf, netMon)
-		}
+	fr := &fallbackResolver{
+		logf:   logf,
+		netMon: netMon,
+	}
+	return fr.Lookup
+}
 
-		addrsCh := make(chan []netip.Addr, 1)
+// fallbackResolver contains the state and configuration for a DNS resolution
+// function.
+type fallbackResolver struct {
+	logf   logger.Logf
+	netMon *netmon.Monitor // or nil
+	sf     singleflight.Group[string, resolveResult]
 
-		// Run the recursive resolver in the background so we can
-		// compare the results.
-		go func() {
-			logf := logger.WithPrefix(logf, "recursive: ")
-
-			// Ensure that we catch panics while we're testing this
-			// code path; this should never panic, but we don't
-			// want to take down the process by having the panic
-			// propagate to the top of the goroutine's stack and
-			// then terminate.
-			defer func() {
-				if r := recover(); r != nil {
-					logf("bootstrap DNS: recovered panic: %v", r)
-					metricRecursiveErrors.Add(1)
-				}
-			}()
-
-			resolver := recursive.Resolver{
-				Dialer: netns.NewDialer(logf, netMon),
-				Logf:   logf,
-			}
-			addrs, minTTL, err := resolver.Resolve(ctx, host)
-			if err != nil {
-				logf("error using recursive resolver: %v", err)
-				metricRecursiveErrors.Add(1)
-				return
-			}
+	// for tests
+	waitForCompare bool
+}
 
-			compareAddr := func(a, b netip.Addr) int { return a.Compare(b) }
-			slices.SortFunc(addrs, compareAddr)
+func (fr *fallbackResolver) Lookup(ctx context.Context, host string) ([]netip.Addr, error) {
+	// If they've explicitly disabled the recursive resolver with the legacy
+	// TS_DNSFALLBACK_DISABLE_RECURSIVE_RESOLVER envknob or not set the
+	// newer TS_DNSFALLBACK_RECURSIVE_RESOLVER to true, then don't use the
+	// recursive resolver. (tailscale/corp#15261) In the future, we might
+	// change the default (the opt.Bool being unset) to mean enabled.
+	if disableRecursiveResolver() || !optRecursiveResolver().EqualBool(true) {
+		return lookup(ctx, host, fr.logf, fr.netMon)
+	}
 
-			// Wait for a response from the main function
-			oldAddrs := <-addrsCh
-			slices.SortFunc(oldAddrs, compareAddr)
+	addrsCh := make(chan []netip.Addr, 1)
 
-			matches := slices.Equal(addrs, oldAddrs)
+	// Run the recursive resolver in the background so we can
+	// compare the results. For tests, we also allow waiting for the
+	// comparison to complete; normally, we do this entirely asynchronously
+	// so as not to block the caller.
+	var done chan struct{}
+	if fr.waitForCompare {
+		done = make(chan struct{})
+		go func() {
+			defer close(done)
+			fr.compareWithRecursive(ctx, addrsCh, host)
+		}()
+	} else {
+		go fr.compareWithRecursive(ctx, addrsCh, host)
+	}
 
-			logf("bootstrap DNS comparison: matches=%v oldAddrs=%v addrs=%v minTTL=%v", matches, oldAddrs, addrs, minTTL)
+	addrs, err := lookup(ctx, host, fr.logf, fr.netMon)
+	if err != nil {
+		addrsCh <- nil
+		return nil, err
+	}
 
-			if matches {
-				metricRecursiveMatches.Add(1)
-			} else {
-				metricRecursiveMismatches.Add(1)
-			}
-		}()
+	addrsCh <- slices.Clone(addrs)
+	if fr.waitForCompare {
+		select {
+		case <-done:
+		case <-ctx.Done():
+		}
+	}
+	return addrs, nil
+}
 
-		addrs, err := lookup(ctx, host, logf, netMon)
+// compareWithRecursive is responsible for comparing the DNS resolution
+// performed via the "normal" path (bootstrap DNS requests to the DERP servers)
+// with DNS resolution performed with our in-process recursive DNS resolver.
+//
+// It will select on addrsCh to read exactly one set of addrs (returned by the
+// "normal" path) and compare against the results returned by the recursive
+// resolver. If ctx is canceled, then it will abort.
+func (fr *fallbackResolver) compareWithRecursive(
+	ctx context.Context,
+	addrsCh <-chan []netip.Addr,
+	host string,
+) {
+	logf := logger.WithPrefix(fr.logf, "recursive: ")
+
+	// Ensure that we catch panics while we're testing this
+	// code path; this should never panic, but we don't
+	// want to take down the process by having the panic
+	// propagate to the top of the goroutine's stack and
+	// then terminate.
+	defer func() {
+		if r := recover(); r != nil {
+			logf("bootstrap DNS: recovered panic: %v", r)
+			metricRecursiveErrors.Add(1)
+		}
+	}()
+
+	// Don't resolve the same host multiple times
+	// concurrently; if we end up in a tight loop, this can
+	// take up a lot of CPU.
+	var didRun bool
+	result, err, _ := fr.sf.Do(host, func() (resolveResult, error) {
+		didRun = true
+		resolver := &recursive.Resolver{
+			Dialer: netns.NewDialer(logf, fr.netMon),
+			Logf:   logf,
+		}
+		addrs, minTTL, err := resolver.Resolve(ctx, host)
 		if err != nil {
-			addrsCh <- nil
-			return nil, err
+			logf("error using recursive resolver: %v", err)
+			metricRecursiveErrors.Add(1)
+			return resolveResult{}, err
 		}
+		return resolveResult{addrs, minTTL}, nil
+	})
+
+	// The singleflight function handled errors; return if
+	// there was one. Additionally, don't bother doing the
+	// comparison if we waited on another singleflight
+	// caller; the results are likely to be the same, so
+	// rather than spam the logs we can just exit and let
+	// the singleflight call that did execute do the
+	// comparison.
+	//
+	// Returning here is safe because the addrsCh channel
+	// is buffered, so the main function won't block even
+	// if we never read from it.
+	if err != nil || !didRun {
+		return
+	}
+
+	addrs, minTTL := result.addrs, result.minTTL
+	compareAddr := func(a, b netip.Addr) int { return a.Compare(b) }
+	slices.SortFunc(addrs, compareAddr)
+
+	// Wait for a response from the main function; try this once before we
+	// check whether the context is canceled since selects are
+	// nondeterministic.
+	var oldAddrs []netip.Addr
+	select {
+	case oldAddrs = <-addrsCh:
+		// All good; continue
+	default:
+		// Now block.
+		select {
+		case oldAddrs = <-addrsCh:
+		case <-ctx.Done():
+			return
+		}
+	}
+	slices.SortFunc(oldAddrs, compareAddr)
+
+	matches := slices.Equal(addrs, oldAddrs)
+
+	logf("bootstrap DNS comparison: matches=%v oldAddrs=%v addrs=%v minTTL=%v", matches, oldAddrs, addrs, minTTL)
 
-		addrsCh <- slices.Clone(addrs)
-		return addrs, nil
+	if matches {
+		metricRecursiveMatches.Add(1)
+	} else {
+		metricRecursiveMismatches.Add(1)
 	}
 }
 

+ 31 - 0
net/dnsfallback/dnsfallback_test.go

@@ -4,13 +4,17 @@
 package dnsfallback
 
 import (
+	"context"
 	"encoding/json"
+	"flag"
 	"os"
 	"path/filepath"
 	"reflect"
 	"testing"
 
+	"tailscale.com/net/netmon"
 	"tailscale.com/tailcfg"
+	"tailscale.com/types/logger"
 )
 
 func TestGetDERPMap(t *testing.T) {
@@ -170,3 +174,30 @@ func TestCacheUnchanged(t *testing.T) {
 		t.Fatalf("didn't find non-empty regular file; mode=%v size=%d", st.Mode(), st.Size())
 	}
 }
+
+var extNetwork = flag.Bool("use-external-network", false, "use the external network in tests")
+
+func TestLookup(t *testing.T) {
+	if !*extNetwork {
+		t.Skip("skipping test without --use-external-network")
+	}
+
+	logf, closeLogf := logger.LogfCloser(t.Logf)
+	defer closeLogf()
+
+	netMon, err := netmon.New(logf)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	resolver := &fallbackResolver{
+		logf:           logf,
+		netMon:         netMon,
+		waitForCompare: true,
+	}
+	addrs, err := resolver.Lookup(context.Background(), "controlplane.tailscale.com")
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Logf("addrs: %+v", addrs)
+}