Procházet zdrojové kódy

net/dns/recursive: add initial implementation of recursive DNS resolver

We've talked in the past about reworking how bootstrap DNS works to
instead do recursive DNS resolution from the root; this would better
support on-prem customers and Headscale users where the DERP servers
don't currently resolve their DNS server. This package is an initial
implementation of recursive resolution for A and AAAA records.

Updates #5853

Change-Id: Ibe974d78709b4b03674b47c4ef61f9a00addf8b4
Signed-off-by: Andrew Dunham <[email protected]>
Andrew Dunham před 2 roky
rodič
revize
f077b672e4
2 změnil soubory, kde provedl 1381 přidání a 0 odebrání
  1. 640 0
      net/dns/recursive/recursive.go
  2. 741 0
      net/dns/recursive/recursive_test.go

+ 640 - 0
net/dns/recursive/recursive.go

@@ -0,0 +1,640 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package recursive implements a simple recursive DNS resolver.
+package recursive
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net"
+	"net/netip"
+	"strings"
+	"time"
+
+	"github.com/miekg/dns"
+	"golang.org/x/exp/constraints"
+	"golang.org/x/exp/slices"
+	"tailscale.com/envknob"
+	"tailscale.com/net/netns"
+	"tailscale.com/types/logger"
+	"tailscale.com/util/dnsname"
+	"tailscale.com/util/mak"
+	"tailscale.com/util/multierr"
+	"tailscale.com/util/slicesx"
+)
+
+const (
+	// maxDepth is how deep from the root nameservers we'll recurse when
+	// resolving; reaching this limit makes resolution fail with
+	// ErrMaxDepth instead of recursing any further.
+	//
+	// maxDepth must be at least 20 to resolve "console.aws.amazon.com",
+	// which is a domain with a moderately complicated DNS setup. The
+	// current value of 30 was chosen semi-arbitrarily to ensure that we
+	// have about 50% headroom.
+	maxDepth = 30
+	// numStartingServers is the number of root nameservers, per address
+	// family, that we use as initial candidates for our recursion.
+	numStartingServers = 3
+	// udpQueryTimeout is the timeout placed on the context used for the
+	// UDP attempt of a query, before falling back to a TCP connection.
+	udpQueryTimeout = 5 * time.Second
+
+	// These constants aren't typed in the DNS package, so we create typed
+	// versions here to avoid having to do repeated type casts.
+	qtypeA    dns.Type = dns.Type(dns.TypeA)
+	qtypeAAAA dns.Type = dns.Type(dns.TypeAAAA)
+)
+
+var (
+	// ErrMaxDepth is returned when recursive resolving exceeds the maximum
+	// depth limit for this package (maxDepth).
+	ErrMaxDepth = fmt.Errorf("exceeded max depth %d when resolving", maxDepth)
+
+	// ErrAuthoritativeNoResponses is the error returned when a nameserver
+	// gives a response that is marked authoritative but that contains no
+	// usable answers to the given query.
+	ErrAuthoritativeNoResponses = errors.New("authoritative server returned no responses")
+
+	// ErrNoResponses is returned when our resolution process completes
+	// with no valid responses from any nameserver, but without any
+	// nameserver authoritatively reporting that there are no answers.
+	ErrNoResponses = errors.New("no responses to query")
+)
+
+// rootServersV4 holds the IPv4 addresses of the DNS root nameservers; a
+// random subset of these is used as the starting point for recursion.
+//
+// These must be kept in sync with the published root hints. Note that
+// b.root-servers.net was renumbered, so its previous address (199.9.14.201)
+// must no longer be used.
+var rootServersV4 = []netip.Addr{
+	netip.MustParseAddr("198.41.0.4"),     // a.root-servers.net
+	netip.MustParseAddr("170.247.170.2"),  // b.root-servers.net
+	netip.MustParseAddr("192.33.4.12"),    // c.root-servers.net
+	netip.MustParseAddr("199.7.91.13"),    // d.root-servers.net
+	netip.MustParseAddr("192.203.230.10"), // e.root-servers.net
+	netip.MustParseAddr("192.5.5.241"),    // f.root-servers.net
+	netip.MustParseAddr("192.112.36.4"),   // g.root-servers.net
+	netip.MustParseAddr("198.97.190.53"),  // h.root-servers.net
+	netip.MustParseAddr("192.36.148.17"),  // i.root-servers.net
+	netip.MustParseAddr("192.58.128.30"),  // j.root-servers.net
+	netip.MustParseAddr("193.0.14.129"),   // k.root-servers.net
+	netip.MustParseAddr("199.7.83.42"),    // l.root-servers.net
+	netip.MustParseAddr("202.12.27.33"),   // m.root-servers.net
+}
+
+// rootServersV6 holds the IPv6 addresses of the DNS root nameservers; a
+// random subset of these is used as the starting point for recursion unless
+// IPv6 has been disabled on the Resolver.
+//
+// These must be kept in sync with the published root hints. Note that
+// b.root-servers.net was renumbered, so its previous address
+// (2001:500:200::b) must no longer be used.
+var rootServersV6 = []netip.Addr{
+	netip.MustParseAddr("2001:503:ba3e::2:30"), // a.root-servers.net
+	netip.MustParseAddr("2801:1b8:10::b"),      // b.root-servers.net
+	netip.MustParseAddr("2001:500:2::c"),       // c.root-servers.net
+	netip.MustParseAddr("2001:500:2d::d"),      // d.root-servers.net
+	netip.MustParseAddr("2001:500:a8::e"),      // e.root-servers.net
+	netip.MustParseAddr("2001:500:2f::f"),      // f.root-servers.net
+	netip.MustParseAddr("2001:500:12::d0d"),    // g.root-servers.net
+	netip.MustParseAddr("2001:500:1::53"),      // h.root-servers.net
+	netip.MustParseAddr("2001:7fe::53"),        // i.root-servers.net
+	netip.MustParseAddr("2001:503:c27::2:30"),  // j.root-servers.net
+	netip.MustParseAddr("2001:7fd::1"),         // k.root-servers.net
+	netip.MustParseAddr("2001:500:9f::42"),     // l.root-servers.net
+	netip.MustParseAddr("2001:dc3::35"),        // m.root-servers.net
+}
+
+// debug is the accessor registered for the TS_DEBUG_RECURSIVE_DNS envknob;
+// the dlogf and depthlogf helpers only log when it returns true.
+var debug = envknob.RegisterBool("TS_DEBUG_RECURSIVE_DNS")
+
+// Resolver is a recursive DNS resolver that is designed for looking up A and AAAA records.
+//
+// A zero Resolver is usable: it dials with a zero net.Dialer, drops its logs,
+// and starts each resolution from a random subset of the built-in root
+// nameservers.
+type Resolver struct {
+	// Dialer is used to create outbound connections. If nil, a zero
+	// net.Dialer will be used instead.
+	Dialer netns.Dialer
+
+	// Logf is the logging function to use; if none is specified, then logs
+	// will be dropped.
+	Logf logger.Logf
+
+	// NoIPv6, if set, will prevent this package from querying for AAAA
+	// records and will avoid contacting nameservers over IPv6.
+	NoIPv6 bool
+
+	// Test mocks
+	//
+	// testQueryHook, if set, replaces queryNameserverProto entirely
+	// (including its caching). testExchangeHook, if set, replaces only the
+	// dial-and-exchange step. rootServers, if non-empty, is used instead of
+	// a random selection of root nameservers. timeNow, if set, is used
+	// instead of time.Now.
+	testQueryHook    func(name dnsname.FQDN, nameserver netip.Addr, protocol string, qtype dns.Type) (*dns.Msg, error)
+	testExchangeHook func(nameserver netip.Addr, network string, msg *dns.Msg) (*dns.Msg, error)
+	rootServers      []netip.Addr
+	timeNow          func() time.Time
+
+	// Caching
+	//
+	// queryCache holds previously received responses together with their
+	// expiry time, keyed by nameserver, name and query type.
+	// NOTE(andrew): if we make resolution parallel, this needs a mutex
+	queryCache map[dnsQuery]dnsMsgWithExpiry
+
+	// Possible future additions:
+	//    - Additional nameservers? From the system maybe?
+	//    - NoIPv4, the IPv4 counterpart of NoIPv6
+	//    - DNS-over-HTTPS or DNS-over-TLS support
+}
+
+// queryState stores all state during the course of a single query; one is
+// created per call to Resolve and passed down through the recursion.
+type queryState struct {
+	// rootServers are the root nameservers to start from
+	rootServers []netip.Addr
+
+	// TODO: metrics?
+}
+
+// dnsQuery identifies a single question sent to a single nameserver; it is
+// the key type of the Resolver's query cache.
+type dnsQuery struct {
+	nameserver netip.Addr   // server the query is sent to
+	name       dnsname.FQDN // name being queried
+	qtype      dns.Type     // record type being requested
+}
+
+// String implements fmt.Stringer, rendering every field of the query.
+func (q dnsQuery) String() string {
+	ns := q.nameserver.String()
+	return fmt.Sprintf("dnsQuery{nameserver:%q,name:%q,qtype:%v}", ns, q.name, q.qtype)
+}
+
+// dnsMsgWithExpiry is a cached DNS response together with the time at which
+// the cache entry expires.
+type dnsMsgWithExpiry struct {
+	*dns.Msg
+	expiresAt time.Time
+}
+
+// now returns the current time, using the timeNow test hook when one is
+// installed and time.Now otherwise.
+func (r *Resolver) now() time.Time {
+	clock := r.timeNow
+	if clock == nil {
+		clock = time.Now
+	}
+	return clock()
+}
+
+// logf logs via r.Logf, and is a no-op when no logging function is set.
+func (r *Resolver) logf(format string, args ...any) {
+	if fn := r.Logf; fn != nil {
+		fn(format, args...)
+	}
+}
+
+// dlogf is like logf, but only logs when debug logging is enabled.
+func (r *Resolver) dlogf(format string, args ...any) {
+	if r.Logf != nil && debug() {
+		r.Logf(format, args...)
+	}
+}
+
+// depthlogf is like dlogf, but prefixes the message with the recursion depth
+// and indents it by two spaces per level of depth.
+func (r *Resolver) depthlogf(depth int, format string, args ...any) {
+	if r.Logf != nil && debug() {
+		indent := strings.Repeat("  ", depth)
+		r.Logf(fmt.Sprintf("[%d] %s", depth, indent)+format, args...)
+	}
+}
+
+// defaultDialer is the zero net.Dialer used by Resolvers that have no Dialer
+// configured.
+var defaultDialer net.Dialer
+
+// dialer returns the Dialer to use for outbound connections: the configured
+// one if present, and defaultDialer otherwise.
+func (r *Resolver) dialer() netns.Dialer {
+	if d := r.Dialer; d != nil {
+		return d
+	}
+	return &defaultDialer
+}
+
+// newState builds the per-query state, which currently consists of the list
+// of root nameservers that the query starts from.
+func (r *Resolver) newState() *queryState {
+	// An explicitly-configured set of root servers is used verbatim.
+	if len(r.rootServers) > 0 {
+		return &queryState{rootServers: r.rootServers}
+	}
+
+	// Select a random subset of root nameservers to start from, since if
+	// we don't get responses from those, something else has probably gone
+	// horribly wrong.
+	pick := func(all []netip.Addr) []netip.Addr {
+		shuffled := slices.Clone(all)
+		slicesx.Shuffle(shuffled)
+		return shuffled[:numStartingServers]
+	}
+
+	v4 := pick(rootServersV4)
+	var v6 []netip.Addr
+	if !r.NoIPv6 {
+		v6 = pick(rootServersV6)
+	}
+
+	// Interleave the root servers so that we try to contact them over
+	// IPv4, then IPv6, IPv4, IPv6, etc.
+	return &queryState{rootServers: slicesx.Interleave(v4, v6)}
+}
+
+// Resolve will perform a recursive DNS resolution for the provided name,
+// starting at a randomly-chosen root DNS server, and return the A and AAAA
+// responses as a slice of netip.Addrs along with the minimum TTL for the
+// returned records.
+func (r *Resolver) Resolve(ctx context.Context, name string) (addrs []netip.Addr, minTTL time.Duration, err error) {
+	dnsName, err := dnsname.ToFQDN(name)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	qstate := r.newState()
+
+	r.logf("querying IPv4 addresses for: %q", name)
+	addrs4, minTTL4, err4 := r.resolveRecursiveFromRoot(ctx, qstate, 0, dnsName, qtypeA)
+
+	var (
+		addrs6  []netip.Addr
+		minTTL6 time.Duration
+		err6    error
+	)
+	if !r.NoIPv6 {
+		r.logf("querying IPv6 addresses for: %q", name)
+		addrs6, minTTL6, err6 = r.resolveRecursiveFromRoot(ctx, qstate, 0, dnsName, qtypeAAAA)
+	}
+
+	if err4 != nil && err6 != nil {
+		if err4 == err6 {
+			return nil, 0, err4
+		}
+
+		return nil, 0, multierr.New(err4, err6)
+	}
+	if err4 != nil {
+		return addrs6, minTTL6, nil
+	} else if err6 != nil {
+		return addrs4, minTTL4, nil
+	}
+
+	minTTL = minTTL4
+	if minTTL6 < minTTL {
+		minTTL = minTTL6
+	}
+
+	addrs = append(addrs4, addrs6...)
+	if len(addrs) == 0 {
+		return nil, 0, ErrNoResponses
+	}
+
+	slicesx.Shuffle(addrs)
+	return addrs, minTTL, nil
+}
+
+func (r *Resolver) resolveRecursiveFromRoot(
+	ctx context.Context,
+	qstate *queryState,
+	depth int,
+	name dnsname.FQDN, // what we're querying
+	qtype dns.Type,
+) ([]netip.Addr, time.Duration, error) {
+	r.depthlogf(depth, "resolving %q from root (type: %v)", name, qtype)
+
+	var depthError bool
+	for _, server := range qstate.rootServers {
+		addrs, minTTL, err := r.resolveRecursive(ctx, qstate, depth, name, server, qtype)
+		if err == nil {
+			return addrs, minTTL, err
+		} else if errors.Is(err, ErrAuthoritativeNoResponses) {
+			return nil, 0, ErrAuthoritativeNoResponses
+		} else if errors.Is(err, ErrMaxDepth) {
+			depthError = true
+		}
+	}
+
+	if depthError {
+		return nil, 0, ErrMaxDepth
+	}
+	return nil, 0, ErrNoResponses
+}
+
+// resolveRecursive asks "nameserver" for records of type "qtype" for "name"
+// and then acts on the response:
+//
+//   - address answers of the requested type are returned directly, along
+//     with their minimum TTL (capped at 24 hours);
+//   - otherwise, any CNAMEs in the answer are resolved from the root again;
+//   - otherwise, if the response is authoritative, resolution stops with
+//     ErrAuthoritativeNoResponses;
+//   - otherwise, the nameservers listed in the AUTHORITY section are
+//     queried, preferring those for which glue records were provided.
+//
+// When r.NoIPv6 is set, IPv6 glue records are ignored and authority
+// nameservers are only resolved to IPv4 addresses, so that no nameserver is
+// contacted over IPv6.
+//
+// It returns ErrMaxDepth when depth has reached maxDepth, or when recursion
+// failed and at least one failure was due to the depth limit; it returns
+// ErrNoResponses when nothing that was tried produced an answer.
+func (r *Resolver) resolveRecursive(
+	ctx context.Context,
+	qstate *queryState,
+	depth int,
+	name dnsname.FQDN, // what we're querying
+	nameserver netip.Addr,
+	qtype dns.Type,
+) ([]netip.Addr, time.Duration, error) {
+	if depth == maxDepth {
+		r.depthlogf(depth, "not recursing past maximum depth")
+		return nil, 0, ErrMaxDepth
+	}
+
+	// Ask this nameserver for an answer.
+	resp, err := r.queryNameserver(ctx, depth, name, nameserver, qtype)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	// If we get an actual answer from the nameserver, then return it.
+	var (
+		answers []netip.Addr
+		cnames  []dnsname.FQDN
+		minTTL  = 24 * 60 * 60 // 24 hours in seconds
+	)
+	for _, answer := range resp.Answer {
+		if crec, ok := answer.(*dns.CNAME); ok {
+			cnameFQDN, err := dnsname.ToFQDN(crec.Target)
+			if err != nil {
+				r.logf("bad CNAME %q returned: %v", crec.Target, err)
+				continue
+			}
+
+			cnames = append(cnames, cnameFQDN)
+			continue
+		}
+
+		addr := addrFromRecord(answer)
+		if !addr.IsValid() {
+			r.logf("[unexpected] invalid record in %T answer", answer)
+		} else if addr.Is4() && qtype != qtypeA {
+			r.logf("[unexpected] got IPv4 answer but qtype=%v", qtype)
+		} else if addr.Is6() && qtype != qtypeAAAA {
+			r.logf("[unexpected] got IPv6 answer but qtype=%v", qtype)
+		} else {
+			answers = append(answers, addr)
+			minTTL = min(minTTL, int(answer.Header().Ttl))
+		}
+	}
+
+	if len(answers) > 0 {
+		r.depthlogf(depth, "got answers for %q: %v", name, answers)
+		return answers, time.Duration(minTTL) * time.Second, nil
+	}
+
+	r.depthlogf(depth, "no answers for %q", name)
+
+	// If we have a non-zero number of CNAMEs, then try resolving those
+	// (from the root again) and return the first one that succeeds.
+	//
+	// TODO: return the union of all responses?
+	// TODO: parallelism?
+	if len(cnames) > 0 {
+		r.depthlogf(depth, "got CNAME responses for %q: %v", name, cnames)
+	}
+	var cnameDepthError bool
+	for _, cname := range cnames {
+		answers, minTTL, err := r.resolveRecursiveFromRoot(ctx, qstate, depth+1, cname, qtype)
+		if err == nil {
+			return answers, minTTL, nil
+		} else if errors.Is(err, ErrAuthoritativeNoResponses) {
+			return nil, 0, ErrAuthoritativeNoResponses
+		} else if errors.Is(err, ErrMaxDepth) {
+			cnameDepthError = true
+		}
+	}
+
+	// If this is an authoritative response, then we know that continuing
+	// to look further is not going to result in any answers and we should
+	// bail out.
+	if resp.MsgHdr.Authoritative {
+		// If we failed to recurse into a CNAME due to a depth limit,
+		// propagate that here.
+		if cnameDepthError {
+			return nil, 0, ErrMaxDepth
+		}
+
+		r.depthlogf(depth, "got authoritative response with no answers; stopping")
+		return nil, 0, ErrAuthoritativeNoResponses
+	}
+
+	r.depthlogf(depth, "got %d NS responses and %d ADDITIONAL responses for %q", len(resp.Ns), len(resp.Extra), name)
+
+	// No CNAMEs and no answers; see if we got any AUTHORITY responses,
+	// which indicate which nameservers to query next.
+	var authorities []dnsname.FQDN
+	for _, rr := range resp.Ns {
+		ns, ok := rr.(*dns.NS)
+		if !ok {
+			continue
+		}
+
+		nsName, err := dnsname.ToFQDN(ns.Ns)
+		if err != nil {
+			r.logf("unexpected bad NS name %q: %v", ns.Ns, err)
+			continue
+		}
+
+		authorities = append(authorities, nsName)
+	}
+
+	// Also check for "glue" records, which are IP addresses provided by
+	// the DNS server for authority responses; these are required when the
+	// authority server is a subdomain of what's being resolved.
+	glueRecords := make(map[dnsname.FQDN][]netip.Addr)
+	for _, rr := range resp.Extra {
+		glueName, err := dnsname.ToFQDN(rr.Header().Name)
+		if err != nil {
+			r.logf("unexpected bad Name %q in Extra addr: %v", rr.Header().Name, err)
+			continue
+		}
+
+		addr := addrFromRecord(rr)
+		switch {
+		case !addr.IsValid():
+			r.logf("unexpected bad Extra %T addr", rr)
+		case r.NoIPv6 && addr.Is6():
+			// We've been asked not to contact nameservers over
+			// IPv6, so this glue record is of no use to us.
+		default:
+			glueRecords[glueName] = append(glueRecords[glueName], addr)
+		}
+	}
+
+	// Try authorities with glue records first, to minimize the number of
+	// additional DNS queries that we need to make.
+	authoritiesGlue, authoritiesNoGlue := slicesx.Partition(authorities, func(aa dnsname.FQDN) bool {
+		return len(glueRecords[aa]) > 0
+	})
+
+	authorityDepthError := false
+
+	r.depthlogf(depth, "authorities with glue records for recursion: %v", authoritiesGlue)
+	for _, authority := range authoritiesGlue {
+		for _, nameserver := range glueRecords[authority] {
+			answers, minTTL, err := r.resolveRecursive(ctx, qstate, depth+1, name, nameserver, qtype)
+			if err == nil {
+				return answers, minTTL, nil
+			} else if errors.Is(err, ErrAuthoritativeNoResponses) {
+				return nil, 0, ErrAuthoritativeNoResponses
+			} else if errors.Is(err, ErrMaxDepth) {
+				authorityDepthError = true
+			}
+		}
+	}
+
+	// The address families to resolve authority nameservers for; these
+	// are independent of the current question type, but must not include
+	// IPv6 if we're not allowed to contact nameservers over it.
+	authorityQtypes := []dns.Type{qtypeAAAA, qtypeA}
+	if r.NoIPv6 {
+		authorityQtypes = []dns.Type{qtypeA}
+	}
+
+	r.depthlogf(depth, "authorities with no glue records for recursion: %v", authoritiesNoGlue)
+	for _, authority := range authoritiesNoGlue {
+		// First, resolve the IP for the authority server from the
+		// root, regardless of what the current question type is.
+		//
+		// TODO: check for infinite recursion; it'll get caught by our
+		// recursion depth, but we want to bail early.
+		for _, authorityQtype := range authorityQtypes {
+			answers, _, err := r.resolveRecursiveFromRoot(ctx, qstate, depth+1, authority, authorityQtype)
+			if err != nil {
+				r.depthlogf(depth, "error querying authority %q: %v", authority, err)
+				continue
+			}
+			r.depthlogf(depth, "resolved authority %q (type %v) to: %v", authority, authorityQtype, answers)
+
+			// Now, query this authority for the final address.
+			for _, nameserver := range answers {
+				answers, minTTL, err := r.resolveRecursive(ctx, qstate, depth+1, name, nameserver, qtype)
+				if err == nil {
+					return answers, minTTL, nil
+				} else if errors.Is(err, ErrAuthoritativeNoResponses) {
+					return nil, 0, ErrAuthoritativeNoResponses
+				} else if errors.Is(err, ErrMaxDepth) {
+					authorityDepthError = true
+				}
+			}
+		}
+	}
+
+	if authorityDepthError {
+		return nil, 0, ErrMaxDepth
+	}
+	return nil, 0, ErrNoResponses
+}
+
+// min returns a if a is strictly less than b, and b otherwise.
+func min[T constraints.Ordered](a, b T) T {
+	smaller := b
+	if a < b {
+		smaller = a
+	}
+	return smaller
+}
+
+// queryNameserver queries the nameserver "nameserver" for records of type
+// "qtype" for "name". The query is first attempted over UDP, bounded by
+// udpQueryTimeout; if that attempt fails for any reason, it is retried over
+// TCP. If both attempts fail, both errors are returned together.
+func (r *Resolver) queryNameserver(
+	ctx context.Context,
+	depth int,
+	name dnsname.FQDN, // what we're querying
+	nameserver netip.Addr, // destination of query
+	qtype dns.Type,
+) (*dns.Msg, error) {
+	// TODO(andrew): we should do QNAME minimisation here to avoid sending
+	// the full name to intermediate/root nameservers. See:
+	//    https://www.rfc-editor.org/rfc/rfc7816
+
+	// UDP may be blocked entirely, so the UDP attempt gets its own explicit
+	// timeout rather than relying solely on the caller's context.
+	udpCtx, cancel := context.WithTimeout(ctx, udpQueryTimeout)
+	defer cancel()
+
+	udpResp, udpErr := r.queryNameserverProto(udpCtx, depth, name, nameserver, "udp", qtype)
+	if udpErr == nil {
+		return udpResp, nil
+	}
+
+	tcpResp, tcpErr := r.queryNameserverProto(ctx, depth, name, nameserver, "tcp", qtype)
+	if tcpErr != nil {
+		return nil, multierr.New(udpErr, tcpErr)
+	}
+	return tcpResp, nil
+}
+
+// queryNameserverProto sends a query for "name" to the nameserver "nameserver"
+// for records of type "qtype" over the provided protocol (either "udp"
+// or "tcp"), and returns the DNS response or an error.
+//
+// Responses are cached per (nameserver, name, qtype) until the smallest TTL
+// of any record in the response has elapsed (capped at 24 hours), and a
+// cached response that has not yet expired is returned without touching the
+// network. Responses containing no records are not cached. If a UDP
+// response is truncated, the query is re-run over TCP.
+//
+// If the testQueryHook is set, it is called instead and none of the above
+// applies.
+func (r *Resolver) queryNameserverProto(
+	ctx context.Context,
+	depth int,
+	name dnsname.FQDN, // what we're querying
+	nameserver netip.Addr, // destination of query
+	protocol string,
+	qtype dns.Type,
+) (resp *dns.Msg, err error) {
+	if r.testQueryHook != nil {
+		return r.testQueryHook(name, nameserver, protocol, qtype)
+	}
+
+	now := r.now()
+	nameserverStr := nameserver.String()
+
+	cacheKey := dnsQuery{
+		nameserver: nameserver,
+		name:       name,
+		qtype:      qtype,
+	}
+	// Only use a cached response if it has not expired yet.
+	if cacheEntry, ok := r.queryCache[cacheKey]; ok && now.Before(cacheEntry.expiresAt) {
+		r.depthlogf(depth, "using cached response from %s about %q (type: %v)", nameserverStr, name, qtype)
+		return cacheEntry.Msg, nil
+	}
+
+	var network string
+	if nameserver.Is4() {
+		network = protocol + "4"
+	} else {
+		network = protocol + "6"
+	}
+
+	// Prepare a message asking for an appropriately-typed record
+	// for the name we're querying.
+	m := new(dns.Msg)
+	m.SetQuestion(name.WithTrailingDot(), uint16(qtype))
+
+	// Allow mocking out the network components with our exchange hook.
+	if r.testExchangeHook != nil {
+		resp, err = r.testExchangeHook(nameserver, network, m)
+	} else {
+		// Dial the current nameserver using our dialer.
+		var nconn net.Conn
+		nconn, err = r.dialer().DialContext(ctx, network, net.JoinHostPort(nameserverStr, "53"))
+		if err != nil {
+			return nil, err
+		}
+
+		var c dns.Client // TODO: share?
+		conn := &dns.Conn{
+			Conn:    nconn,
+			UDPSize: c.UDPSize,
+		}
+
+		// Send the DNS request to the current nameserver.
+		//
+		// TODO(andrew): use ExchangeWithConnContext after this upstream PR is
+		// merged:
+		//    https://github.com/miekg/dns/pull/1459
+		r.depthlogf(depth, "asking %s over %s about %q (type: %v)", nameserverStr, protocol, name, qtype)
+		resp, _, err = c.ExchangeWithConn(m, conn)
+
+		// ExchangeWithConn leaves the connection open, and we use a
+		// fresh connection per query, so release it now. The exchange
+		// has already completed (or failed), so an error from Close
+		// carries no useful information and is deliberately ignored.
+		nconn.Close()
+	}
+	if err != nil {
+		return nil, err
+	}
+
+	// If the message was truncated and we're using UDP, re-run with TCP.
+	if resp.MsgHdr.Truncated && protocol == "udp" {
+		r.depthlogf(depth, "response message truncated; re-running query with TCP")
+		resp, err = r.queryNameserverProto(ctx, depth, name, nameserver, "tcp", qtype)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	// Find minimum expiry for all records in this message. We start from
+	// an upper bound rather than from zero, since otherwise the minimum
+	// would always be zero and nothing would ever stay cached.
+	const maxCacheTTL = 24 * 60 * 60 // 24 hours in seconds
+	var (
+		minTTL     = maxCacheTTL
+		numRecords int
+	)
+	for _, section := range [][]dns.RR{resp.Answer, resp.Ns, resp.Extra} {
+		for _, rr := range section {
+			// The TTL field of an OPT pseudo-record carries EDNS
+			// flags, not a lifetime, so it must not be considered.
+			if _, isOPT := rr.(*dns.OPT); isOPT {
+				continue
+			}
+			minTTL = min(minTTL, int(rr.Header().Ttl))
+			numRecords++
+		}
+	}
+
+	// A response without any records gives us no lifetime to go by, so
+	// don't cache it.
+	if numRecords > 0 {
+		mak.Set(&r.queryCache, cacheKey, dnsMsgWithExpiry{
+			Msg:       resp,
+			expiresAt: now.Add(time.Duration(minTTL) * time.Second),
+		})
+	}
+	return resp, nil
+}
+
+// addrFromRecord returns the IP address contained in the given A or AAAA
+// record. It returns the zero (invalid) netip.Addr if the record is of any
+// other type, or if the address it holds is malformed or of the wrong
+// address family for the record type.
+func addrFromRecord(rr dns.RR) netip.Addr {
+	switch v := rr.(type) {
+	case *dns.A:
+		ip, ok := netip.AddrFromSlice(v.A)
+		if !ok {
+			return netip.Addr{}
+		}
+		// A net.IP may hold an IPv4 address in its 16-byte IPv4-mapped
+		// form (this is what net.IPv4 and net.ParseIP produce), so
+		// unmap it before checking the address family.
+		if ip = ip.Unmap(); !ip.Is4() {
+			return netip.Addr{}
+		}
+		return ip
+	case *dns.AAAA:
+		ip, ok := netip.AddrFromSlice(v.AAAA)
+		if !ok || !ip.Is6() {
+			return netip.Addr{}
+		}
+		return ip
+	}
+	return netip.Addr{}
+}

+ 741 - 0
net/dns/recursive/recursive_test.go

@@ -0,0 +1,741 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package recursive
+
+import (
+	"context"
+	"errors"
+	"flag"
+	"fmt"
+	"net"
+	"net/netip"
+	"reflect"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/miekg/dns"
+	"golang.org/x/exp/slices"
+	"tailscale.com/envknob"
+	"tailscale.com/tstest"
+)
+
+// testDomain is the domain resolved by the basic network-access tests.
+const testDomain = "tailscale.com"
+
+// Recursively resolving the AWS console requires being able to handle CNAMEs,
+// glue records, falling back from UDP to TCP for oversize queries, and more;
+// it's a great integration test for DNS resolution and they can handle the
+// traffic :)
+const complicatedTestDomain = "console.aws.amazon.com"
+
+// flagNetworkAccess gates the tests that contact real nameservers; those
+// tests skip themselves unless this flag is set.
+var flagNetworkAccess = flag.Bool("enable-network-access", false, "run tests that need external network access")
+
+// init sets the TS_DEBUG_RECURSIVE_DNS envknob for all tests in this package.
+func init() {
+	envknob.Setenv("TS_DEBUG_RECURSIVE_DNS", "true")
+}
+
+// newResolver returns a Resolver that logs to tb and that reads the time
+// from a test clock with a 50ms step.
+func newResolver(tb testing.TB) *Resolver {
+	r := &Resolver{Logf: tb.Logf}
+	clock := &tstest.Clock{Step: 50 * time.Millisecond}
+	r.timeNow = clock.Now
+	return r
+}
+
+// TestResolve resolves testDomain against the real DNS hierarchy and checks
+// that both IPv4 and IPv6 addresses come back with a plausible TTL. It is
+// skipped unless network access has been enabled.
+func TestResolve(t *testing.T) {
+	if !*flagNetworkAccess {
+		t.SkipNow()
+	}
+
+	r := newResolver(t)
+	addrs, minTTL, err := r.Resolve(context.Background(), testDomain)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Logf("addrs: %+v", addrs)
+	t.Logf("minTTL: %v", minTTL)
+	if len(addrs) < 1 {
+		t.Fatalf("expected at least one address")
+	}
+
+	if minTTL <= 10*time.Second || minTTL >= 24*time.Hour {
+		t.Errorf("invalid minimum TTL: %v", minTTL)
+	}
+
+	var sawV4, sawV6 bool
+	for _, a := range addrs {
+		switch {
+		case a.Is4():
+			sawV4 = true
+		case a.Is6():
+			sawV6 = true
+		}
+	}
+
+	if !sawV4 {
+		t.Errorf("expected at least one IPv4 address")
+	}
+	if !sawV6 {
+		t.Errorf("expected at least one IPv6 address")
+	}
+}
+
+// TestResolveComplicated resolves complicatedTestDomain against the real DNS
+// hierarchy and checks that at least one address comes back with a plausible
+// TTL. It is skipped unless network access has been enabled.
+func TestResolveComplicated(t *testing.T) {
+	if !*flagNetworkAccess {
+		t.SkipNow()
+	}
+
+	r := newResolver(t)
+	addrs, minTTL, err := r.Resolve(context.Background(), complicatedTestDomain)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Logf("addrs: %+v", addrs)
+	t.Logf("minTTL: %v", minTTL)
+	if len(addrs) < 1 {
+		t.Fatalf("expected at least one address")
+	}
+
+	tooShort := minTTL <= 10*time.Second
+	tooLong := minTTL >= 24*time.Hour
+	if tooShort || tooLong {
+		t.Errorf("invalid minimum TTL: %v", minTTL)
+	}
+}
+
+// TestResolveNoIPv6 resolves testDomain against the real DNS hierarchy with
+// NoIPv6 set and checks that no IPv6 addresses are returned. It is skipped
+// unless network access has been enabled.
+func TestResolveNoIPv6(t *testing.T) {
+	if !*flagNetworkAccess {
+		t.SkipNow()
+	}
+
+	r := newResolver(t)
+	r.NoIPv6 = true
+
+	ctx := context.Background()
+	addrs, _, err := r.Resolve(ctx, testDomain)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Logf("addrs: %+v", addrs)
+	if len(addrs) < 1 {
+		t.Fatalf("expected at least one address")
+	}
+
+	for i := range addrs {
+		if a := addrs[i]; a.Is6() {
+			t.Errorf("got unexpected IPv6 address: %v", a)
+		}
+	}
+}
+
+// TestResolveFallbackToTCP verifies that a truncated UDP response makes
+// queryNameserverProto re-run the query over TCP, and that repeating the
+// same query afterwards is served from the cache without any further
+// exchanges.
+func TestResolveFallbackToTCP(t *testing.T) {
+	var udpCalls, tcpCalls int
+	// hook answers UDP queries with an empty, truncated response and TCP
+	// queries with a single A record, counting the calls of each kind.
+	hook := func(nameserver netip.Addr, network string, req *dns.Msg) (*dns.Msg, error) {
+		if strings.HasPrefix(network, "udp") {
+			t.Logf("got %q query; returning truncated result", network)
+			udpCalls++
+			resp := &dns.Msg{}
+			resp.SetReply(req)
+			resp.Truncated = true
+			return resp, nil
+		}
+
+		t.Logf("got %q query; returning real result", network)
+		tcpCalls++
+		resp := &dns.Msg{}
+		resp.SetReply(req)
+		resp.Answer = append(resp.Answer, &dns.A{
+			Hdr: dns.RR_Header{
+				Name:   req.Question[0].Name,
+				Rrtype: req.Question[0].Qtype,
+				Class:  dns.ClassINET,
+				Ttl:    300,
+			},
+			A: net.IPv4(1, 2, 3, 4),
+		})
+		return resp, nil
+	}
+
+	r := newResolver(t)
+	r.testExchangeHook = hook
+
+	ctx := context.Background()
+	resp, err := r.queryNameserverProto(ctx, 0, "tailscale.com", netip.MustParseAddr("9.9.9.9"), "udp", dns.Type(dns.TypeA))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if len(resp.Answer) < 1 {
+		t.Fatalf("no answers in response: %v", resp)
+	}
+	rrA, ok := resp.Answer[0].(*dns.A)
+	if !ok {
+		t.Fatalf("invalid RR type: %T", resp.Answer[0])
+	}
+	if !rrA.A.Equal(net.IPv4(1, 2, 3, 4)) {
+		t.Errorf("wanted A response 1.2.3.4, got: %v", rrA.A)
+	}
+	// Exactly one exchange per protocol: the truncated UDP one, then the
+	// TCP retry.
+	if tcpCalls != 1 {
+		t.Errorf("got %d, want 1 TCP calls", tcpCalls)
+	}
+	if udpCalls != 1 {
+		t.Errorf("got %d, want 1 UDP calls", udpCalls)
+	}
+
+	// Verify that we're cached and re-run to fetch from the cache.
+	if len(r.queryCache) < 1 {
+		t.Errorf("wanted entries in the query cache")
+	}
+
+	resp2, err := r.queryNameserverProto(ctx, 0, "tailscale.com", netip.MustParseAddr("9.9.9.9"), "udp", dns.Type(dns.TypeA))
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !reflect.DeepEqual(resp, resp2) {
+		t.Errorf("expected equal responses; old=%+v new=%+v", resp, resp2)
+	}
+
+	// We didn't make any more network requests since we loaded from the cache.
+	if tcpCalls != 1 {
+		t.Errorf("got %d, want 1 TCP calls", tcpCalls)
+	}
+	if udpCalls != 1 {
+		t.Errorf("got %d, want 1 UDP calls", udpCalls)
+	}
+}
+
+// dnsIPRR returns an address record for name with a TTL of 300 seconds: an A
+// record if addr is an IPv4 address, and an AAAA record otherwise.
+func dnsIPRR(name string, addr netip.Addr) dns.RR {
+	hdr := dns.RR_Header{
+		Name:  name,
+		Class: dns.ClassINET,
+		Ttl:   300,
+	}
+	ip := net.IP(addr.AsSlice())
+
+	if !addr.Is4() {
+		hdr.Rrtype = dns.TypeAAAA
+		return &dns.AAAA{Hdr: hdr, AAAA: ip}
+	}
+
+	hdr.Rrtype = dns.TypeA
+	return &dns.A{Hdr: hdr, A: ip}
+}
+
+// cnameRR returns a CNAME record with a TTL of 300 seconds that points name
+// at target.
+func cnameRR(name, target string) dns.RR {
+	rec := new(dns.CNAME)
+	rec.Hdr = dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 300}
+	rec.Target = target
+	return rec
+}
+
+// nsRR returns an NS record with a TTL of 300 seconds that names target as a
+// nameserver for name.
+func nsRR(name, target string) dns.RR {
+	rec := new(dns.NS)
+	rec.Hdr = dns.RR_Header{Name: name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 300}
+	rec.Ns = target
+	return rec
+}
+
+// mockReply is a canned response that a mocked nameserver gives to a
+// question with a matching name and type.
+type mockReply struct {
+	name  string   // question name to match, with trailing dot
+	qtype dns.Type // question type to match
+	resp  *dns.Msg // response to return; copied before being returned
+}
+
+// replyMock mocks out the network by serving canned replies, configured per
+// nameserver address, through its exchangeHook method.
+type replyMock struct {
+	tb      testing.TB
+	replies map[netip.Addr][]mockReply
+}
+
+// exchangeHook is a Resolver testExchangeHook that answers the single
+// question in req with a copy of the reply configured for that nameserver,
+// name and type. It fails the test if req doesn't contain exactly one
+// question or if no matching reply has been configured.
+func (r *replyMock) exchangeHook(nameserver netip.Addr, network string, req *dns.Msg) (*dns.Msg, error) {
+	if n := len(req.Question); n != 1 {
+		r.tb.Fatalf("unsupported multiple or empty question: %v", req.Question)
+	}
+	q := req.Question[0]
+
+	candidates := r.replies[nameserver]
+	if len(candidates) == 0 {
+		r.tb.Fatalf("no configured replies for nameserver: %v", nameserver)
+	}
+
+	wantType := dns.Type(q.Qtype)
+	for i := range candidates {
+		if c := candidates[i]; c.name == q.Name && c.qtype == wantType {
+			return c.resp.Copy(), nil
+		}
+	}
+
+	r.tb.Fatalf("no replies found for query %q of type %v to %v", q.Name, q.Qtype, nameserver)
+	panic("unreachable")
+}
+
+// Addresses and responses for mocking, shared between the following tests.
+var (
+	rootServerAddr = netip.MustParseAddr("198.41.0.4") // a.root-servers.net.
+	comNSAddr      = netip.MustParseAddr("192.5.6.30") // a.gtld-servers.net.
+
+	// DNS response from the root nameservers for a .com nameserver: a
+	// referral to a.gtld-servers.net. together with its glue record.
+	comRecord = &dns.Msg{
+		Ns:    []dns.RR{nsRR("com.", "a.gtld-servers.net.")},
+		Extra: []dns.RR{dnsIPRR("a.gtld-servers.net.", comNSAddr)},
+	}
+
+	// Random Amazon nameservers that we use in glue records
+	amazonNS   = netip.MustParseAddr("205.251.192.197")
+	amazonNSv6 = netip.MustParseAddr("2600:9000:5306:1600::1")
+
+	// Nameservers for the tailscale.com domain; only the first of them has
+	// a glue record.
+	tailscaleNameservers = &dns.Msg{
+		Ns: []dns.RR{
+			nsRR("tailscale.com.", "ns-197.awsdns-24.com."),
+			nsRR("tailscale.com.", "ns-557.awsdns-05.net."),
+			nsRR("tailscale.com.", "ns-1558.awsdns-02.co.uk."),
+			nsRR("tailscale.com.", "ns-1359.awsdns-41.org."),
+		},
+		Extra: []dns.RR{
+			dnsIPRR("ns-197.awsdns-24.com.", amazonNS),
+		},
+	}
+)
+
+// TestBasicRecursion walks a mocked delegation chain (root, then .com, then
+// the authoritative nameserver) and verifies that Resolve returns the
+// authoritative server's A and AAAA answers along with their TTL.
+func TestBasicRecursion(t *testing.T) {
+	mock := &replyMock{
+		tb: t,
+		replies: map[netip.Addr][]mockReply{
+			// Query to the root server returns the .com server + a glue record
+			rootServerAddr: {
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: comRecord},
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: comRecord},
+			},
+
+			// Query to the ".com" server return the nameservers for tailscale.com
+			comNSAddr: {
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers},
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers},
+			},
+
+			// Query to the actual nameserver works.
+			amazonNS: {
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: &dns.Msg{
+					MsgHdr: dns.MsgHdr{Authoritative: true},
+					Answer: []dns.RR{
+						dnsIPRR("tailscale.com.", netip.MustParseAddr("13.248.141.131")),
+						dnsIPRR("tailscale.com.", netip.MustParseAddr("76.223.15.28")),
+					},
+				}},
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{
+					MsgHdr: dns.MsgHdr{Authoritative: true},
+					Answer: []dns.RR{
+						dnsIPRR("tailscale.com.", netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b")),
+						dnsIPRR("tailscale.com.", netip.MustParseAddr("2600:9000:a51d:27c1:1530:b9ef:2a6:b9e5")),
+					},
+				}},
+			},
+		},
+	}
+
+	r := newResolver(t)
+	r.testExchangeHook = mock.exchangeHook
+	r.rootServers = []netip.Addr{rootServerAddr}
+
+	// Query for tailscale.com, verify we get the right responses
+	ctx := context.Background()
+	addrs, minTTL, err := r.Resolve(ctx, "tailscale.com")
+	if err != nil {
+		t.Fatal(err)
+	}
+	wantAddrs := []netip.Addr{
+		netip.MustParseAddr("13.248.141.131"),
+		netip.MustParseAddr("76.223.15.28"),
+		netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b"),
+		netip.MustParseAddr("2600:9000:a51d:27c1:1530:b9ef:2a6:b9e5"),
+	}
+	// Resolve shuffles its result, so sort both sides before comparing.
+	slices.SortFunc(addrs, func(x, y netip.Addr) bool { return x.String() < y.String() })
+	slices.SortFunc(wantAddrs, func(x, y netip.Addr) bool { return x.String() < y.String() })
+
+	if !reflect.DeepEqual(addrs, wantAddrs) {
+		t.Errorf("got addrs=%+v; want %+v", addrs, wantAddrs)
+	}
+
+	// All of the mocked records have a TTL of 300 seconds.
+	const wantMinTTL = 5 * time.Minute
+	if minTTL != wantMinTTL {
+		t.Errorf("got minTTL=%+v; want %+v", minTTL, wantMinTTL)
+	}
+}
+
+// TestNoAnswers walks a mocked delegation chain whose authoritative
+// nameserver returns empty answers, and verifies that Resolve fails with
+// ErrAuthoritativeNoResponses.
+func TestNoAnswers(t *testing.T) {
+	mock := &replyMock{
+		tb: t,
+		replies: map[netip.Addr][]mockReply{
+			// Query to the root server returns the .com server + a glue record
+			rootServerAddr: {
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: comRecord},
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: comRecord},
+			},
+
+			// Query to the ".com" server return the nameservers for tailscale.com
+			comNSAddr: {
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers},
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers},
+			},
+
+			// Query to the actual nameserver returns no responses, authoritatively.
+			amazonNS: {
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: &dns.Msg{
+					MsgHdr: dns.MsgHdr{Authoritative: true},
+					Answer: []dns.RR{},
+				}},
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{
+					MsgHdr: dns.MsgHdr{Authoritative: true},
+					Answer: []dns.RR{},
+				}},
+			},
+		},
+	}
+
+	// This Resolver is built directly rather than via newResolver, so it
+	// has no test clock installed.
+	r := &Resolver{
+		Logf:             t.Logf,
+		testExchangeHook: mock.exchangeHook,
+		rootServers:      []netip.Addr{rootServerAddr},
+	}
+
+	// Query for tailscale.com, verify we get the right responses
+	_, _, err := r.Resolve(context.Background(), "tailscale.com")
+	if err == nil {
+		t.Fatalf("got no error, want error")
+	}
+	if !errors.Is(err, ErrAuthoritativeNoResponses) {
+		t.Fatalf("got err=%v, want %v", err, ErrAuthoritativeNoResponses)
+	}
+}
+
+// TestRecursionCNAME checks that Resolve follows a CNAME handed out by the
+// authoritative nameserver: subdomain.otherdomain.com is an alias for
+// subdomain.tailscale.com, and the A/AAAA records served for the latter are
+// the addresses the test expects back.
+func TestRecursionCNAME(t *testing.T) {
+	const (
+		aliasName  = "subdomain.otherdomain.com."
+		targetName = "subdomain.tailscale.com."
+	)
+	var (
+		typeA    = dns.Type(dns.TypeA)
+		typeAAAA = dns.Type(dns.TypeAAAA)
+
+		wantV4 = netip.MustParseAddr("13.248.141.131")
+		wantV6 = netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b")
+	)
+
+	// authoritative wraps a single record in a new authoritative reply.
+	authoritative := func(rr dns.RR) *dns.Msg {
+		return &dns.Msg{
+			MsgHdr: dns.MsgHdr{Authoritative: true},
+			Answer: []dns.RR{rr},
+		}
+	}
+
+	replies := map[netip.Addr][]mockReply{
+		// The root server refers both names to the .com server and
+		// includes a glue record.
+		rootServerAddr: {
+			{name: aliasName, qtype: typeA, resp: comRecord},
+			{name: aliasName, qtype: typeAAAA, resp: comRecord},
+
+			{name: targetName, qtype: typeA, resp: comRecord},
+			{name: targetName, qtype: typeAAAA, resp: comRecord},
+		},
+
+		// The ".com" server hands out the same set of nameservers
+		// for both names.
+		comNSAddr: {
+			{name: aliasName, qtype: typeA, resp: tailscaleNameservers},
+			{name: aliasName, qtype: typeAAAA, resp: tailscaleNameservers},
+
+			{name: targetName, qtype: typeA, resp: tailscaleNameservers},
+			{name: targetName, qtype: typeAAAA, resp: tailscaleNameservers},
+		},
+
+		// The actual nameserver answers the alias with a CNAME and
+		// the CNAME target with addresses. Note that the address
+		// records carry the owner name "tailscale.com." rather than
+		// the queried subdomain.
+		amazonNS: {
+			{name: aliasName, qtype: typeA, resp: authoritative(cnameRR(aliasName, targetName))},
+			{name: aliasName, qtype: typeAAAA, resp: authoritative(cnameRR(aliasName, targetName))},
+
+			{name: targetName, qtype: typeA, resp: authoritative(dnsIPRR("tailscale.com.", wantV4))},
+			{name: targetName, qtype: typeAAAA, resp: authoritative(dnsIPRR("tailscale.com.", wantV6))},
+		},
+	}
+	mock := &replyMock{tb: t, replies: replies}
+
+	resolver := &Resolver{
+		Logf:             t.Logf,
+		testExchangeHook: mock.exchangeHook,
+		rootServers:      []netip.Addr{rootServerAddr},
+	}
+
+	// Query for the alias, subdomain.otherdomain.com, and verify that we
+	// get the addresses of the CNAME target.
+	addrs, minTTL, err := resolver.Resolve(context.Background(), "subdomain.otherdomain.com")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Sort both sides the same way so the comparison does not depend on
+	// the order in which the resolver returned the addresses.
+	wantAddrs := []netip.Addr{wantV4, wantV6}
+	byString := func(x, y netip.Addr) bool { return x.String() < y.String() }
+	slices.SortFunc(addrs, byString)
+	slices.SortFunc(wantAddrs, byString)
+
+	if !reflect.DeepEqual(addrs, wantAddrs) {
+		t.Errorf("got addrs=%+v; want %+v", addrs, wantAddrs)
+	}
+
+	if wantMinTTL := 5 * time.Minute; minTTL != wantMinTTL {
+		t.Errorf("got minTTL=%+v; want %+v", minTTL, wantMinTTL)
+	}
+}
+
+// TestRecursionNoGlue checks that Resolve copes with a referral whose NS
+// record comes without a glue record. The mocked root, .co.uk and
+// intermediate servers provide the chain needed to look up the nameserver's
+// own address, after which that nameserver answers the original query.
+func TestRecursionNoGlue(t *testing.T) {
+	const (
+		queryName = "tailscale.com."
+		nsName    = "ns-1558.awsdns-02.co.uk."
+	)
+	var (
+		typeA    = dns.Type(dns.TypeA)
+		typeAAAA = dns.Type(dns.TypeAAAA)
+	)
+
+	// Referral, with glue, that the root server returns for nsName.
+	ukServer := netip.MustParseAddr("213.248.216.1")
+	ukReferral := &dns.Msg{
+		Ns:    []dns.RR{nsRR("com.", "dns1.nic.uk.")},
+		Extra: []dns.RR{dnsIPRR("dns1.nic.uk.", ukServer)},
+	}
+
+	// Referral, with glue, that the .co.uk server returns for nsName.
+	awsServer := netip.MustParseAddr("205.251.193.66") // g-ns-322.awsdns-02.co.uk.
+	awsReferral := &dns.Msg{
+		Ns:    []dns.RR{nsRR("awsdns-02.co.uk.", "g-ns-322.awsdns-02.co.uk.")},
+		Extra: []dns.RR{dnsIPRR("g-ns-322.awsdns-02.co.uk.", awsServer)},
+	}
+
+	// Referral for tailscale.com that names the nameserver but carries
+	// no Extra section, i.e. no glue record.
+	glueless := &dns.Msg{
+		Ns: []dns.RR{
+			nsRR(queryName, nsName),
+		},
+	}
+
+	// Final answers for tailscale.com; the same replies are served from
+	// both the IPv4 and the IPv6 address of the nameserver.
+	finalAnswers := []mockReply{
+		{name: queryName, qtype: typeA, resp: &dns.Msg{
+			MsgHdr: dns.MsgHdr{Authoritative: true},
+			Answer: []dns.RR{dnsIPRR("tailscale.com.", netip.MustParseAddr("13.248.141.131"))},
+		}},
+		{name: queryName, qtype: typeAAAA, resp: &dns.Msg{
+			MsgHdr: dns.MsgHdr{Authoritative: true},
+			Answer: []dns.RR{dnsIPRR("tailscale.com.", netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b"))},
+		}},
+	}
+
+	replies := map[netip.Addr][]mockReply{
+		rootServerAddr: {
+			// Asking the root server about tailscale.com returns
+			// the .com server + a glue record.
+			{name: queryName, qtype: typeA, resp: comRecord},
+			{name: queryName, qtype: typeAAAA, resp: comRecord},
+
+			// Asking the root server about the Amazon nameserver
+			// returns the .co.uk nameserver + a glue record.
+			{name: nsName, qtype: typeA, resp: ukReferral},
+			{name: nsName, qtype: typeAAAA, resp: ukReferral},
+		},
+
+		// The ".com" server returns the nameservers for
+		// tailscale.com, without a glue record.
+		comNSAddr: {
+			{name: queryName, qtype: typeA, resp: glueless},
+			{name: queryName, qtype: typeAAAA, resp: glueless},
+		},
+
+		// The ".co.uk" nameserver returns the address of the
+		// intermediate Amazon nameserver.
+		ukServer: {
+			{name: nsName, qtype: typeA, resp: awsReferral},
+			{name: nsName, qtype: typeAAAA, resp: awsReferral},
+		},
+
+		// The intermediate nameserver returns the addresses of the
+		// final Amazon nameserver.
+		awsServer: {
+			{name: nsName, qtype: typeA, resp: &dns.Msg{
+				MsgHdr: dns.MsgHdr{Authoritative: true},
+				Answer: []dns.RR{dnsIPRR(nsName, amazonNS)},
+			}},
+			{name: nsName, qtype: typeAAAA, resp: &dns.Msg{
+				MsgHdr: dns.MsgHdr{Authoritative: true},
+				Answer: []dns.RR{dnsIPRR(nsName, amazonNSv6)},
+			}},
+		},
+
+		// The final nameserver answers the original query on either
+		// of its addresses.
+		amazonNS:   finalAnswers,
+		amazonNSv6: finalAnswers,
+	}
+	mock := &replyMock{tb: t, replies: replies}
+
+	resolver := newResolver(t)
+	resolver.testExchangeHook = mock.exchangeHook
+	resolver.rootServers = []netip.Addr{rootServerAddr}
+
+	// Query for tailscale.com, verify we get the right responses
+	addrs, minTTL, err := resolver.Resolve(context.Background(), "tailscale.com")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Sort both sides the same way so the comparison does not depend on
+	// the order in which the resolver returned the addresses.
+	wantAddrs := []netip.Addr{
+		netip.MustParseAddr("13.248.141.131"),
+		netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b"),
+	}
+	byString := func(x, y netip.Addr) bool { return x.String() < y.String() }
+	slices.SortFunc(addrs, byString)
+	slices.SortFunc(wantAddrs, byString)
+
+	if !reflect.DeepEqual(addrs, wantAddrs) {
+		t.Errorf("got addrs=%+v; want %+v", addrs, wantAddrs)
+	}
+
+	if wantMinTTL := 5 * time.Minute; minTTL != wantMinTTL {
+		t.Errorf("got minTTL=%+v; want %+v", minTTL, wantMinTTL)
+	}
+}
+
+func TestRecursionLimit(t *testing.T) {
+	mock := &replyMock{
+		tb:      t,
+		replies: map[netip.Addr][]mockReply{},
+	}
+
+	// Fill out a CNAME chain equal to our recursion limit; we won't get
+	// this far since each CNAME is more than 1 level "deep", but this
+	// ensures that we have more than the limit.
+	for i := 0; i < maxDepth+1; i++ {
+		curr := fmt.Sprintf("%d-tailscale.com.", i)
+
+		tailscaleNameservers := &dns.Msg{
+			Ns:    []dns.RR{nsRR(curr, "ns-197.awsdns-24.com.")},
+			Extra: []dns.RR{dnsIPRR("ns-197.awsdns-24.com.", amazonNS)},
+		}
+
+		// Query to the root server returns the .com server + a glue record
+		mock.replies[rootServerAddr] = append(mock.replies[rootServerAddr],
+			mockReply{name: curr, qtype: dns.Type(dns.TypeA), resp: comRecord},
+			mockReply{name: curr, qtype: dns.Type(dns.TypeAAAA), resp: comRecord},
+		)
+
+		// Query to the ".com" server return the nameservers for NN-tailscale.com
+		mock.replies[comNSAddr] = append(mock.replies[comNSAddr],
+			mockReply{name: curr, qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers},
+			mockReply{name: curr, qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers},
+		)
+
+		// Queries to the nameserver return a CNAME for the n+1th server.
+		next := fmt.Sprintf("%d-tailscale.com.", i+1)
+		mock.replies[amazonNS] = append(mock.replies[amazonNS],
+			mockReply{
+				name:  curr,
+				qtype: dns.Type(dns.TypeA),
+				resp: &dns.Msg{
+					MsgHdr: dns.MsgHdr{Authoritative: true},
+					Answer: []dns.RR{cnameRR(curr, next)},
+				},
+			},
+			mockReply{
+				name:  curr,
+				qtype: dns.Type(dns.TypeAAAA),
+				resp: &dns.Msg{
+					MsgHdr: dns.MsgHdr{Authoritative: true},
+					Answer: []dns.RR{cnameRR(curr, next)},
+				},
+			},
+		)
+	}
+
+	r := newResolver(t)
+	r.testExchangeHook = mock.exchangeHook
+	r.rootServers = []netip.Addr{rootServerAddr}
+
+	// Query for the first node in the chain, 0-tailscale.com, and verify
+	// we get a max-depth error.
+	ctx := context.Background()
+	_, _, err := r.Resolve(ctx, "0-tailscale.com")
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	} else if !errors.Is(err, ErrMaxDepth) {
+		t.Fatalf("got err=%v, want ErrMaxDepth", err)
+	}
+}
+
+func TestInvalidResponses(t *testing.T) {
+	mock := &replyMock{
+		tb: t,
+		replies: map[netip.Addr][]mockReply{
+			// Query to the root server returns the .com server + a glue record
+			rootServerAddr: {
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: comRecord},
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: comRecord},
+			},
+
+			// Query to the ".com" server return the nameservers for tailscale.com
+			comNSAddr: {
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers},
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers},
+			},
+
+			// Query to the actual nameserver returns an invalid IP address
+			amazonNS: {
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: &dns.Msg{
+					MsgHdr: dns.MsgHdr{Authoritative: true},
+					Answer: []dns.RR{&dns.A{
+						Hdr: dns.RR_Header{
+							Name:   "tailscale.com.",
+							Rrtype: dns.TypeA,
+							Class:  dns.ClassINET,
+							Ttl:    300,
+						},
+						// Note: this is an IPv6 addr in an IPv4 response
+						A: net.IP(netip.MustParseAddr("2600:9000:a51d:27c1:1530:b9ef:2a6:b9e5").AsSlice()),
+					}},
+				}},
+				{name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{
+					MsgHdr: dns.MsgHdr{Authoritative: true},
+					// This an IPv4 response to an IPv6 query
+					Answer: []dns.RR{&dns.A{
+						Hdr: dns.RR_Header{
+							Name:   "tailscale.com.",
+							Rrtype: dns.TypeA,
+							Class:  dns.ClassINET,
+							Ttl:    300,
+						},
+						A: net.IP(netip.MustParseAddr("13.248.141.131").AsSlice()),
+					}},
+				}},
+			},
+		},
+	}
+
+	r := &Resolver{
+		Logf:             t.Logf,
+		testExchangeHook: mock.exchangeHook,
+		rootServers:      []netip.Addr{rootServerAddr},
+	}
+
+	// Query for tailscale.com, verify we get no responses since the
+	// addresses are invalid.
+	_, _, err := r.Resolve(context.Background(), "tailscale.com")
+	if err == nil {
+		t.Fatalf("got no error, want error")
+	}
+	if !errors.Is(err, ErrAuthoritativeNoResponses) {
+		t.Fatalf("got err=%v, want %v", err, ErrAuthoritativeNoResponses)
+	}
+}
+
+// TODO(andrew): test for more edge cases that aren't currently covered:
+//	* Nameservers that cross between IPv4 and IPv6
+//	* Authoritative reply with no answers after following a CNAME
+//	* Authoritative reply with no answers after following a non-glue NS record
+//	* Error querying a non-glue NS record, followed by success