Просмотр исходного кода

net/dns: retry forwarder requests over TCP

We weren't correctly retrying truncated queries to an upstream DNS
server over TCP. Instead, we'd return a truncated response to the user,
even if the user was querying us over TCP and thus able to handle a
large response.

Also, add an envknob and controlknob to allow users/us to disable this
behaviour if it turns out to be buggy (✨ DNS ✨).

Updates #9264

Signed-off-by: Andrew Dunham <[email protected]>
Change-Id: Ifb04b563839a9614c0ba03e9c564e8924c1a2bfd
Andrew Dunham 2 лет назад
Родитель
Сommit
530aaa52f1

+ 23 - 16
control/controlknobs/controlknobs.go

@@ -48,6 +48,10 @@ type Knobs struct {
 
 	// PeerMTUEnable is whether the node should do peer path MTU discovery.
 	PeerMTUEnable atomic.Bool
+
+	// DisableDNSForwarderTCPRetries is whether the DNS forwarder should
+	// skip retrying truncated queries over TCP.
+	DisableDNSForwarderTCPRetries atomic.Bool
 }
 
 // UpdateFromNodeAttributes updates k (if non-nil) based on the provided self
@@ -61,14 +65,15 @@ func (k *Knobs) UpdateFromNodeAttributes(selfNodeAttrs []tailcfg.NodeCapability,
 		return ok || slices.Contains(selfNodeAttrs, attr)
 	}
 	var (
-		keepFullWG          = has(tailcfg.NodeAttrDebugDisableWGTrim)
-		disableDRPO         = has(tailcfg.NodeAttrDebugDisableDRPO)
-		disableUPnP         = has(tailcfg.NodeAttrDisableUPnP)
-		randomizeClientPort = has(tailcfg.NodeAttrRandomizeClientPort)
-		disableDeltaUpdates = has(tailcfg.NodeAttrDisableDeltaUpdates)
-		oneCGNAT            opt.Bool
-		forceBackgroundSTUN = has(tailcfg.NodeAttrDebugForceBackgroundSTUN)
-		peerMTUEnable       = has(tailcfg.NodeAttrPeerMTUEnable)
+		keepFullWG                    = has(tailcfg.NodeAttrDebugDisableWGTrim)
+		disableDRPO                   = has(tailcfg.NodeAttrDebugDisableDRPO)
+		disableUPnP                   = has(tailcfg.NodeAttrDisableUPnP)
+		randomizeClientPort           = has(tailcfg.NodeAttrRandomizeClientPort)
+		disableDeltaUpdates           = has(tailcfg.NodeAttrDisableDeltaUpdates)
+		oneCGNAT                      opt.Bool
+		forceBackgroundSTUN           = has(tailcfg.NodeAttrDebugForceBackgroundSTUN)
+		peerMTUEnable                 = has(tailcfg.NodeAttrPeerMTUEnable)
+		dnsForwarderDisableTCPRetries = has(tailcfg.NodeAttrDNSForwarderDisableTCPRetries)
 	)
 
 	if has(tailcfg.NodeAttrOneCGNATEnable) {
@@ -85,6 +90,7 @@ func (k *Knobs) UpdateFromNodeAttributes(selfNodeAttrs []tailcfg.NodeCapability,
 	k.ForceBackgroundSTUN.Store(forceBackgroundSTUN)
 	k.DisableDeltaUpdates.Store(disableDeltaUpdates)
 	k.PeerMTUEnable.Store(peerMTUEnable)
+	k.DisableDNSForwarderTCPRetries.Store(dnsForwarderDisableTCPRetries)
 }
 
 // AsDebugJSON returns k as something that can be marshalled with json.Marshal
@@ -94,13 +100,14 @@ func (k *Knobs) AsDebugJSON() map[string]any {
 		return nil
 	}
 	return map[string]any{
-		"DisableUPnP":         k.DisableUPnP.Load(),
-		"DisableDRPO":         k.DisableDRPO.Load(),
-		"KeepFullWGConfig":    k.KeepFullWGConfig.Load(),
-		"RandomizeClientPort": k.RandomizeClientPort.Load(),
-		"OneCGNAT":            k.OneCGNAT.Load(),
-		"ForceBackgroundSTUN": k.ForceBackgroundSTUN.Load(),
-		"DisableDeltaUpdates": k.DisableDeltaUpdates.Load(),
-		"PeerMTUEnable":       k.PeerMTUEnable.Load(),
+		"DisableUPnP":                   k.DisableUPnP.Load(),
+		"DisableDRPO":                   k.DisableDRPO.Load(),
+		"KeepFullWGConfig":              k.KeepFullWGConfig.Load(),
+		"RandomizeClientPort":           k.RandomizeClientPort.Load(),
+		"OneCGNAT":                      k.OneCGNAT.Load(),
+		"ForceBackgroundSTUN":           k.ForceBackgroundSTUN.Load(),
+		"DisableDeltaUpdates":           k.DisableDeltaUpdates.Load(),
+		"PeerMTUEnable":                 k.PeerMTUEnable.Load(),
+		"DisableDNSForwarderTCPRetries": k.DisableDNSForwarderTCPRetries.Load(),
 	}
 }

+ 10 - 6
net/dns/manager.go

@@ -17,6 +17,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	"tailscale.com/control/controlknobs"
 	"tailscale.com/health"
 	"tailscale.com/net/dns/resolver"
 	"tailscale.com/net/netmon"
@@ -66,14 +67,14 @@ type Manager struct {
 
 // NewManager creates a new manager from the given config.
 // The netMon parameter is optional; if non-nil it's used to do faster interface lookups.
-func NewManager(logf logger.Logf, oscfg OSConfigurator, netMon *netmon.Monitor, dialer *tsdial.Dialer, linkSel resolver.ForwardLinkSelector) *Manager {
+func NewManager(logf logger.Logf, oscfg OSConfigurator, netMon *netmon.Monitor, dialer *tsdial.Dialer, linkSel resolver.ForwardLinkSelector, knobs *controlknobs.Knobs) *Manager {
 	if dialer == nil {
 		panic("nil Dialer")
 	}
 	logf = logger.WithPrefix(logf, "dns: ")
 	m := &Manager{
 		logf:     logf,
-		resolver: resolver.New(logf, netMon, linkSel, dialer),
+		resolver: resolver.New(logf, netMon, linkSel, dialer, knobs),
 		os:       oscfg,
 	}
 	m.ctx, m.ctxCancel = context.WithCancel(context.Background())
@@ -295,7 +296,10 @@ func toIPsOnly(resolvers []*dnstype.Resolver) (ret []netip.Addr) {
 // Query executes a DNS query received from the given address. The query is
 // provided in bs as a wire-encoded DNS query without any transport header.
 // This method is called for requests arriving over UDP and TCP.
-func (m *Manager) Query(ctx context.Context, bs []byte, from netip.AddrPort) ([]byte, error) {
+//
+// The "family" parameter indicates the transport the query arrived over:
+// either "tcp" or "udp".
+func (m *Manager) Query(ctx context.Context, bs []byte, family string, from netip.AddrPort) ([]byte, error) {
 	select {
 	case <-m.ctx.Done():
 		return nil, net.ErrClosed
@@ -309,7 +313,7 @@ func (m *Manager) Query(ctx context.Context, bs []byte, from netip.AddrPort) ([]
 		return nil, errFullQueue
 	}
 	defer atomic.AddInt32(&m.activeQueriesAtomic, -1)
-	return m.resolver.Query(ctx, bs, from)
+	return m.resolver.Query(ctx, bs, family, from)
 }
 
 const (
@@ -371,7 +375,7 @@ func (s *dnsTCPSession) handleWrites() {
 }
 
 func (s *dnsTCPSession) handleQuery(q []byte) {
-	resp, err := s.m.Query(s.ctx, q, s.srcAddr)
+	resp, err := s.m.Query(s.ctx, q, "tcp", s.srcAddr)
 	if err != nil {
 		s.m.logf("tcp query: %v", err)
 		return
@@ -466,7 +470,7 @@ func Cleanup(logf logger.Logf, interfaceName string) {
 		logf("creating dns cleanup: %v", err)
 		return
 	}
-	dns := NewManager(logf, oscfg, nil, &tsdial.Dialer{Logf: logf}, nil)
+	dns := NewManager(logf, oscfg, nil, &tsdial.Dialer{Logf: logf}, nil, nil)
 	if err := dns.Down(); err != nil {
 		logf("dns down: %v", err)
 	}

+ 2 - 2
net/dns/manager_tcp_test.go

@@ -87,7 +87,7 @@ func TestDNSOverTCP(t *testing.T) {
 			SearchDomains: fqdns("coffee.shop"),
 		},
 	}
-	m := NewManager(t.Logf, &f, nil, new(tsdial.Dialer), nil)
+	m := NewManager(t.Logf, &f, nil, new(tsdial.Dialer), nil, nil)
 	m.resolver.TestOnlySetHook(f.SetResolver)
 	m.Set(Config{
 		Hosts: hosts(
@@ -172,7 +172,7 @@ func TestDNSOverTCP_TooLarge(t *testing.T) {
 			SearchDomains: fqdns("coffee.shop"),
 		},
 	}
-	m := NewManager(log, &f, nil, new(tsdial.Dialer), nil)
+	m := NewManager(log, &f, nil, new(tsdial.Dialer), nil, nil)
 	m.resolver.TestOnlySetHook(f.SetResolver)
 	m.Set(Config{
 		Hosts:         hosts("andrew.ts.com.", "1.2.3.4"),

+ 1 - 1
net/dns/manager_test.go

@@ -613,7 +613,7 @@ func TestManager(t *testing.T) {
 				SplitDNS:   test.split,
 				BaseConfig: test.bs,
 			}
-			m := NewManager(t.Logf, &f, nil, new(tsdial.Dialer), nil)
+			m := NewManager(t.Logf, &f, nil, new(tsdial.Dialer), nil, nil)
 			m.resolver.TestOnlySetHook(f.SetResolver)
 
 			if err := m.Set(test.in); err != nil {

+ 147 - 9
net/dns/resolver/forwarder.go

@@ -21,6 +21,7 @@ import (
 	"time"
 
 	dns "golang.org/x/net/dns/dnsmessage"
+	"tailscale.com/control/controlknobs"
 	"tailscale.com/envknob"
 	"tailscale.com/net/dns/publicdns"
 	"tailscale.com/net/dnscache"
@@ -68,6 +69,10 @@ const (
 	// DNS queries to the "fallback" DNS server IP for a known provider
 	// (e.g. how long to wait to query Google's 8.8.4.4 after 8.8.8.8).
 	wellKnownHostBackupDelay = 200 * time.Millisecond
+
+	// tcpQueryTimeout is the timeout for a DNS query performed over TCP.
+	// It matches the default 5sec timeout of the 'dig' utility.
+	tcpQueryTimeout = 5 * time.Second
 )
 
 // txid identifies a DNS transaction.
@@ -180,6 +185,8 @@ type forwarder struct {
 	linkSel ForwardLinkSelector // TODO(bradfitz): remove this when tsdial.Dialer absorbs it
 	dialer  *tsdial.Dialer
 
+	controlKnobs *controlknobs.Knobs // or nil
+
 	ctx       context.Context    // good until Close
 	ctxCancel context.CancelFunc // closes ctx
 
@@ -206,12 +213,13 @@ func init() {
 	rand.Seed(time.Now().UnixNano())
 }
 
-func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer) *forwarder {
+func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer, knobs *controlknobs.Knobs) *forwarder {
 	f := &forwarder{
-		logf:    logger.WithPrefix(logf, "forward: "),
-		netMon:  netMon,
-		linkSel: linkSel,
-		dialer:  dialer,
+		logf:         logger.WithPrefix(logf, "forward: "),
+		netMon:       netMon,
+		linkSel:      linkSel,
+		dialer:       dialer,
+		controlKnobs: knobs,
 	}
 	f.ctx, f.ctxCancel = context.WithCancel(context.Background())
 	return f
@@ -443,7 +451,10 @@ func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client,
 	return res, err
 }
 
-var verboseDNSForward = envknob.RegisterBool("TS_DEBUG_DNS_FORWARD_SEND")
+var (
+	verboseDNSForward = envknob.RegisterBool("TS_DEBUG_DNS_FORWARD_SEND")
+	skipTCPRetry      = envknob.RegisterBool("TS_DNS_FORWARD_SKIP_TCP_RETRY")
+)
 
 // send sends packet to dst. It is best effort.
 //
@@ -477,10 +488,49 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe
 		return nil, fmt.Errorf("tls:// resolvers not supported yet")
 	}
 
-	return f.sendUDP(ctx, fq, rr)
+	ret, err = f.sendUDP(ctx, fq, rr)
+	if err != nil {
+		return nil, err
+	}
+
+	if !truncatedFlagSet(ret) {
+		// Successful, non-truncated response; return it.
+		return ret, nil
+	}
+	if fq.family == "udp" {
+		// The client queried us over UDP, so return the truncated
+		// response as-is; the client can retry communicating with
+		// tailscaled over TCP. There's no point falling back to TCP
+		// for a truncated query if we can't return the results to
+		// the client.
+		return ret, nil
+	}
+	if skipTCPRetry() || (f.controlKnobs != nil && f.controlKnobs.DisableDNSForwarderTCPRetries.Load()) {
+		// Envknob or control knob disabled the TCP retry behaviour;
+		// just return what we have.
+		return ret, nil
+	}
+
+	// Don't retry if our context is done.
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+
+	// Retry over TCP, best-effort; return the truncated UDP response if we
+	// cannot query via TCP.
+	if ret2, err2 := f.sendTCP(ctx, fq, rr); err2 == nil {
+		if verboseDNSForward() {
+			f.logf("forwarder.send(%q): successfully retried via TCP", rr.name.Addr)
+		}
+		return ret2, nil
+	} else if verboseDNSForward() {
+		f.logf("forwarder.send(%q): could not retry via TCP: %v", rr.name.Addr, err2)
+	}
+	return ret, nil
 }
 
 var errServerFailure = errors.New("response code indicates server issue")
+var errTxIDMismatch = errors.New("txid doesn't match")
 
 func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
 	ipp, ok := rr.name.IPPort()
@@ -545,7 +595,7 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn
 	txid := getTxID(out)
 	if txid != fq.txid {
 		metricDNSFwdUDPErrorTxID.Add(1)
-		return nil, errors.New("txid doesn't match")
+		return nil, errTxIDMismatch
 	}
 	rcode := getRCode(out)
 	// don't forward transient errors back to the client when the server fails
@@ -577,6 +627,92 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn
 	return out, nil
 }
 
+func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
+	ipp, ok := rr.name.IPPort()
+	if !ok {
+		metricDNSFwdErrorType.Add(1)
+		return nil, fmt.Errorf("unrecognized resolver type %q", rr.name.Addr)
+	}
+	metricDNSFwdTCP.Add(1)
+	ctx = sockstats.WithSockStats(ctx, sockstats.LabelDNSForwarderTCP, f.logf)
+
+	// Specify the exact family to work around https://github.com/golang/go/issues/52264
+	tcpFam := "tcp4"
+	if ipp.Addr().Is6() {
+		tcpFam = "tcp6"
+	}
+
+	ctx, cancel := context.WithTimeout(ctx, tcpQueryTimeout)
+	defer cancel()
+
+	conn, err := f.dialer.SystemDial(ctx, tcpFam, ipp.String())
+	if err != nil {
+		return nil, err
+	}
+	defer conn.Close()
+
+	fq.closeOnCtxDone.Add(conn)
+	defer fq.closeOnCtxDone.Remove(conn)
+
+	ctxOrErr := func(err2 error) ([]byte, error) {
+		if err := ctx.Err(); err != nil {
+			return nil, err
+		}
+		return nil, err2
+	}
+
+	// Write the query to the server.
+	query := make([]byte, len(fq.packet)+2)
+	binary.BigEndian.PutUint16(query, uint16(len(fq.packet)))
+	copy(query[2:], fq.packet)
+	if _, err := conn.Write(query); err != nil {
+		metricDNSFwdTCPErrorWrite.Add(1)
+		return ctxOrErr(err)
+	}
+
+	metricDNSFwdTCPWrote.Add(1)
+
+	// Read the length header back from the server
+	var length uint16
+	if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
+		metricDNSFwdTCPErrorRead.Add(1)
+		return ctxOrErr(err)
+	}
+
+	// Now read the response
+	out := make([]byte, length)
+	n, err := io.ReadFull(conn, out)
+	if err != nil {
+		metricDNSFwdTCPErrorRead.Add(1)
+		return ctxOrErr(err)
+	}
+
+	if n < int(length) {
+		f.logf("sendTCP: packet too small (%d bytes)", n)
+		return nil, io.ErrUnexpectedEOF
+	}
+	out = out[:n]
+	txid := getTxID(out)
+	if txid != fq.txid {
+		metricDNSFwdTCPErrorTxID.Add(1)
+		return nil, errTxIDMismatch
+	}
+
+	rcode := getRCode(out)
+
+	// don't forward transient errors back to the client when the server fails
+	if rcode == dns.RCodeServerFailure {
+		f.logf("sendTCP: response code indicating server failure: %d", rcode)
+		metricDNSFwdTCPErrorServer.Add(1)
+		return nil, errServerFailure
+	}
+
+	// TODO(andrew): do we need to do this?
+	//clampEDNSSize(out, maxResponseBytes)
+	metricDNSFwdTCPSuccess.Add(1)
+	return out, nil
+}
+
 // resolvers returns the resolvers to use for domain.
 func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay {
 	f.mu.Lock()
@@ -601,6 +737,7 @@ func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay {
 type forwardQuery struct {
 	txid   txid
 	packet []byte
+	family string // "tcp" or "udp"
 
 	// closeOnCtxDone lets send register values to Close if the
 	// caller's ctx expires. This avoids send from allocating its
@@ -686,6 +823,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
 	fq := &forwardQuery{
 		txid:           getTxID(query.bs),
 		packet:         query.bs,
+		family:         query.family,
 		closeOnCtxDone: new(closePool),
 	}
 	defer fq.closeOnCtxDone.Close()
@@ -727,7 +865,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
 			case <-ctx.Done():
 				metricDNSFwdErrorContext.Add(1)
 				return ctx.Err()
-			case responseChan <- packet{v, query.addr}:
+			case responseChan <- packet{v, query.family, query.addr}:
 				metricDNSFwdSuccess.Add(1)
 				return nil
 			}

+ 233 - 0
net/dns/resolver/forwarder_test.go

@@ -4,14 +4,26 @@
 package resolver
 
 import (
+	"bytes"
+	"context"
+	"encoding/binary"
 	"flag"
 	"fmt"
+	"io"
+	"net"
+	"net/netip"
+	"os"
 	"reflect"
 	"strings"
+	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
 	dns "golang.org/x/net/dns/dnsmessage"
+	"tailscale.com/envknob"
+	"tailscale.com/net/netmon"
+	"tailscale.com/net/tsdial"
 	"tailscale.com/types/dnstype"
 )
 
@@ -240,3 +252,224 @@ func FuzzClampEDNSSize(f *testing.F) {
 		clampEDNSSize(data, maxResponseBytes)
 	})
 }
+
+func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte)) (port uint16) {
+	tcpResponse := make([]byte, len(response)+2)
+	binary.BigEndian.PutUint16(tcpResponse, uint16(len(response)))
+	copy(tcpResponse[2:], response)
+
+	// Repeatedly listen until we can get the same port.
+	const tries = 25
+	var (
+		tcpLn *net.TCPListener
+		udpLn *net.UDPConn
+		err   error
+	)
+	for try := 0; try < tries; try++ {
+		if tcpLn != nil {
+			tcpLn.Close()
+			tcpLn = nil
+		}
+
+		tcpLn, err = net.ListenTCP("tcp4", &net.TCPAddr{
+			IP:   net.IPv4(127, 0, 0, 1),
+			Port: 0, // Choose one
+		})
+		if err != nil {
+			tb.Fatal(err)
+		}
+		udpLn, err = net.ListenUDP("udp4", &net.UDPAddr{
+			IP:   net.IPv4(127, 0, 0, 1),
+			Port: tcpLn.Addr().(*net.TCPAddr).Port,
+		})
+		if err == nil {
+			break
+		}
+	}
+	if tcpLn == nil || udpLn == nil {
+		if tcpLn != nil {
+			tcpLn.Close()
+		}
+		if udpLn != nil {
+			udpLn.Close()
+		}
+
+		// Skip instead of being fatal to avoid flaking on extremely
+		// heavily-loaded CI systems.
+		tb.Skipf("failed to listen on same port for TCP/UDP after %d tries", tries)
+	}
+
+	port = uint16(tcpLn.Addr().(*net.TCPAddr).Port)
+
+	handleConn := func(conn net.Conn) {
+		defer conn.Close()
+
+		// Read the length header, then the buffer
+		var length uint16
+		if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
+			tb.Logf("error reading length header: %v", err)
+			return
+		}
+		req := make([]byte, length)
+		n, err := io.ReadFull(conn, req)
+		if err != nil {
+			tb.Logf("error reading query: %v", err)
+			return
+		}
+		req = req[:n]
+		onRequest(true, req)
+
+		// Write response
+		if _, err := conn.Write(tcpResponse); err != nil {
+			tb.Logf("error writing response: %v", err)
+			return
+		}
+	}
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for {
+			conn, err := tcpLn.Accept()
+			if err != nil {
+				return
+			}
+			go handleConn(conn)
+		}
+	}()
+
+	handleUDP := func(addr netip.AddrPort, req []byte) {
+		onRequest(false, req)
+		if _, err := udpLn.WriteToUDPAddrPort(response, addr); err != nil {
+			tb.Logf("error writing response: %v", err)
+		}
+	}
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for {
+			buf := make([]byte, 65535)
+			n, addr, err := udpLn.ReadFromUDPAddrPort(buf)
+			if err != nil {
+				return
+			}
+			buf = buf[:n]
+			go handleUDP(addr, buf)
+		}
+	}()
+
+	tb.Cleanup(func() {
+		tcpLn.Close()
+		udpLn.Close()
+		tb.Logf("waiting for listeners to finish...")
+		wg.Wait()
+	})
+	return
+}
+
+func TestForwarderTCPFallback(t *testing.T) {
+	const debugKnob = "TS_DEBUG_DNS_FORWARD_SEND"
+	oldVal := os.Getenv(debugKnob)
+	envknob.Setenv(debugKnob, "true")
+	t.Cleanup(func() { envknob.Setenv(debugKnob, oldVal) })
+
+	const domain = "large-dns-response.tailscale.com."
+
+	// Make a response that's very large, containing a bunch of localhost addresses.
+	largeResponse := func() []byte {
+		name := dns.MustNewName(domain)
+
+		builder := dns.NewBuilder(nil, dns.Header{})
+		builder.StartQuestions()
+		builder.Question(dns.Question{
+			Name:  name,
+			Type:  dns.TypeA,
+			Class: dns.ClassINET,
+		})
+		builder.StartAnswers()
+		for i := 0; i < 120; i++ {
+			builder.AResource(dns.ResourceHeader{
+				Name:  name,
+				Class: dns.ClassINET,
+				TTL:   300,
+			}, dns.AResource{
+				A: [4]byte{127, 0, 0, byte(i)},
+			})
+		}
+
+		msg, err := builder.Finish()
+		if err != nil {
+			t.Fatal(err)
+		}
+		return msg
+	}()
+	if len(largeResponse) <= maxResponseBytes {
+		t.Fatalf("got len(largeResponse)=%d, want > %d", len(largeResponse), maxResponseBytes)
+	}
+
+	// Our request is a single A query for the domain in the answer, above.
+	request := func() []byte {
+		builder := dns.NewBuilder(nil, dns.Header{})
+		builder.StartQuestions()
+		builder.Question(dns.Question{
+			Name:  dns.MustNewName(domain),
+			Type:  dns.TypeA,
+			Class: dns.ClassINET,
+		})
+		msg, err := builder.Finish()
+		if err != nil {
+			t.Fatal(err)
+		}
+		return msg
+	}()
+
+	var sawUDPRequest, sawTCPRequest atomic.Bool
+	port := runDNSServer(t, largeResponse, func(isTCP bool, gotRequest []byte) {
+		if isTCP {
+			sawTCPRequest.Store(true)
+		} else {
+			sawUDPRequest.Store(true)
+		}
+
+		if !bytes.Equal(request, gotRequest) {
+			t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
+		}
+	})
+
+	netMon, err := netmon.New(t.Logf)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var dialer tsdial.Dialer
+	dialer.SetNetMon(netMon)
+
+	fwd := newForwarder(t.Logf, netMon, nil, &dialer, nil)
+
+	fq := &forwardQuery{
+		txid:           getTxID(request),
+		packet:         request,
+		closeOnCtxDone: new(closePool),
+	}
+	defer fq.closeOnCtxDone.Close()
+
+	rr := resolverAndDelay{
+		name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)},
+	}
+
+	resp, err := fwd.send(context.Background(), fq, rr)
+	if err != nil {
+		t.Fatalf("error making request: %v", err)
+	}
+	if !bytes.Equal(resp, largeResponse) {
+		t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
+	}
+	if !sawTCPRequest.Load() {
+		t.Errorf("DNS server never saw TCP request")
+	}
+	if !sawUDPRequest.Load() {
+		t.Errorf("DNS server never saw UDP request")
+	}
+}

+ 17 - 7
net/dns/resolver/tsdns.go

@@ -23,6 +23,7 @@ import (
 	"time"
 
 	dns "golang.org/x/net/dns/dnsmessage"
+	"tailscale.com/control/controlknobs"
 	"tailscale.com/envknob"
 	"tailscale.com/net/dns/resolvconffile"
 	"tailscale.com/net/netaddr"
@@ -53,8 +54,9 @@ var (
 )
 
 type packet struct {
-	bs   []byte
-	addr netip.AddrPort // src for a request, dst for a response
+	bs     []byte
+	family string         // either "tcp" or "udp"
+	addr   netip.AddrPort // src for a request, dst for a response
 }
 
 // Config is a resolver configuration.
@@ -206,7 +208,7 @@ type ForwardLinkSelector interface {
 
 // New returns a new resolver.
 // netMon optionally specifies a network monitor to use for socket rebinding.
-func New(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer) *Resolver {
+func New(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer, knobs *controlknobs.Knobs) *Resolver {
 	if dialer == nil {
 		panic("nil Dialer")
 	}
@@ -218,7 +220,7 @@ func New(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector,
 		ipToHost: map[netip.Addr]dnsname.FQDN{},
 		dialer:   dialer,
 	}
-	r.forwarder = newForwarder(r.logf, netMon, linkSel, dialer)
+	r.forwarder = newForwarder(r.logf, netMon, linkSel, dialer, knobs)
 	return r
 }
 
@@ -266,7 +268,7 @@ func (r *Resolver) Close() {
 // bound on per-query resource usage.
 const dnsQueryTimeout = 10 * time.Second
 
-func (r *Resolver) Query(ctx context.Context, bs []byte, from netip.AddrPort) ([]byte, error) {
+func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from netip.AddrPort) ([]byte, error) {
 	metricDNSQueryLocal.Add(1)
 	select {
 	case <-r.closed:
@@ -281,7 +283,7 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, from netip.AddrPort) ([
 		ctx, cancel := context.WithTimeout(ctx, dnsQueryTimeout)
 		defer close(responses)
 		defer cancel()
-		err = r.forwarder.forwardWithDestChan(ctx, packet{bs, from}, responses)
+		err = r.forwarder.forwardWithDestChan(ctx, packet{bs, family, from}, responses)
 		if err != nil {
 			select {
 			// Best effort: use any error response sent by forwardWithDestChan.
@@ -369,7 +371,7 @@ func (r *Resolver) HandleExitNodeDNSQuery(ctx context.Context, q []byte, from ne
 			}}
 		}
 
-		err = r.forwarder.forwardWithDestChan(ctx, packet{q, from}, ch, resolvers...)
+		err = r.forwarder.forwardWithDestChan(ctx, packet{q, "tcp", from}, ch, resolvers...)
 		if err != nil {
 			metricDNSExitProxyErrorForward.Add(1)
 			return nil, err
@@ -1306,6 +1308,14 @@ var (
 	metricDNSFwdUDPErrorRead   = clientmetric.NewCounter("dns_query_fwd_udp_error_read")
 	metricDNSFwdUDPSuccess     = clientmetric.NewCounter("dns_query_fwd_udp_success")
 
+	metricDNSFwdTCP            = clientmetric.NewCounter("dns_query_fwd_tcp")       // on entry
+	metricDNSFwdTCPWrote       = clientmetric.NewCounter("dns_query_fwd_tcp_wrote") // sent TCP packet
+	metricDNSFwdTCPErrorWrite  = clientmetric.NewCounter("dns_query_fwd_tcp_error_write")
+	metricDNSFwdTCPErrorServer = clientmetric.NewCounter("dns_query_fwd_tcp_error_server")
+	metricDNSFwdTCPErrorTxID   = clientmetric.NewCounter("dns_query_fwd_tcp_error_txid")
+	metricDNSFwdTCPErrorRead   = clientmetric.NewCounter("dns_query_fwd_tcp_error_read")
+	metricDNSFwdTCPSuccess     = clientmetric.NewCounter("dns_query_fwd_tcp_success")
+
 	metricDNSFwdDoH               = clientmetric.NewCounter("dns_query_fwd_doh")
 	metricDNSFwdDoHErrorStatus    = clientmetric.NewCounter("dns_query_fwd_doh_error_status")
 	metricDNSFwdDoHErrorCT        = clientmetric.NewCounter("dns_query_fwd_doh_error_content_type")

+ 3 - 3
net/dns/resolver/tsdns_test.go

@@ -233,7 +233,7 @@ func unpackResponse(payload []byte) (dnsResponse, error) {
 }
 
 func syncRespond(r *Resolver, query []byte) ([]byte, error) {
-	return r.Query(context.Background(), query, netip.AddrPort{})
+	return r.Query(context.Background(), query, "udp", netip.AddrPort{})
 }
 
 func mustIP(str string) netip.Addr {
@@ -315,7 +315,7 @@ func TestRDNSNameToIPv6(t *testing.T) {
 }
 
 func newResolver(t testing.TB) *Resolver {
-	return New(t.Logf, nil /* no network monitor */, nil /* no link selector */, new(tsdial.Dialer))
+	return New(t.Logf, nil /* no network monitor */, nil /* no link selector */, new(tsdial.Dialer), nil /* no control knobs */)
 }
 
 func TestResolveLocal(t *testing.T) {
@@ -1016,7 +1016,7 @@ func TestForwardLinkSelection(t *testing.T) {
 			return "special"
 		}
 		return ""
-	}), new(tsdial.Dialer))
+	}), new(tsdial.Dialer), nil /* no control knobs */)
 
 	// Test non-special IP.
 	if got, err := fwd.packetListener(netip.Addr{}); err != nil {

+ 3 - 2
net/sockstats/label_string.go

@@ -20,11 +20,12 @@ func _() {
 	_ = x[LabelMagicsockConnUDP6-9]
 	_ = x[LabelNetlogLogger-10]
 	_ = x[LabelSockstatlogLogger-11]
+	_ = x[LabelDNSForwarderTCP-12]
 }
 
-const _Label_name = "ControlClientAutoControlClientDialerDERPHTTPClientLogtailLoggerDNSForwarderDoHDNSForwarderUDPNetcheckClientPortmapperClientMagicsockConnUDP4MagicsockConnUDP6NetlogLoggerSockstatlogLogger"
+const _Label_name = "ControlClientAutoControlClientDialerDERPHTTPClientLogtailLoggerDNSForwarderDoHDNSForwarderUDPNetcheckClientPortmapperClientMagicsockConnUDP4MagicsockConnUDP6NetlogLoggerSockstatlogLoggerDNSForwarderTCP"
 
-var _Label_index = [...]uint8{0, 17, 36, 50, 63, 78, 93, 107, 123, 140, 157, 169, 186}
+var _Label_index = [...]uint8{0, 17, 36, 50, 63, 78, 93, 107, 123, 140, 157, 169, 186, 201}
 
 func (i Label) String() string {
 	if i >= Label(len(_Label_index)-1) {

+ 1 - 0
net/sockstats/sockstats.go

@@ -51,6 +51,7 @@ const (
 	LabelMagicsockConnUDP6   Label = 9  // wgengine/magicsock/magicsock.go
 	LabelNetlogLogger        Label = 10 // wgengine/netlog/logger.go
 	LabelSockstatlogLogger   Label = 11 // log/sockstatlog/logger.go
+	LabelDNSForwarderTCP     Label = 12 // net/dns/resolver/forwarder.go
 )
 
 // WithSockStats instruments a context so that sockets created with it will

+ 6 - 1
tailcfg/tailcfg.go

@@ -114,7 +114,8 @@ type CapabilityVersion int
 //   - 72: 2023-08-23: TS-2023-006 UPnP issue fixed; UPnP can now be used again
 //   - 73: 2023-09-01: Non-Windows clients expect to receive ClientVersion
 //   - 74: 2023-09-18: Client understands NodeCapMap
-const CurrentCapabilityVersion CapabilityVersion = 74
+//   - 75: 2023-09-12: Client understands NodeAttrDNSForwarderDisableTCPRetries
+const CurrentCapabilityVersion CapabilityVersion = 75
 
 type StableID string
 
@@ -2137,6 +2138,10 @@ const (
 	// NodeAttrPeerMTUEnable makes the client do path MTU discovery to its
 	// peers. If it isn't set, it defaults to the client default.
 	NodeAttrPeerMTUEnable NodeCapability = "peer-mtu-enable"
+
+	// NodeAttrDNSForwarderDisableTCPRetries disables retrying DNS queries
+	// over TCP when the response is truncated.
+	NodeAttrDNSForwarderDisableTCPRetries NodeCapability = "dns-forwarder-disable-tcp-retries"
 )
 
 // SetDNSRequest is a request to add a DNS record.

+ 1 - 1
wgengine/netstack/netstack.go

@@ -1086,7 +1086,7 @@ func (ns *Impl) handleMagicDNSUDP(srcAddr netip.AddrPort, c *gonet.UDPConn) {
 			}
 			return
 		}
-		resp, err := ns.dns.Query(context.Background(), q[:n], srcAddr)
+		resp, err := ns.dns.Query(context.Background(), q[:n], "udp", srcAddr)
 		if err != nil {
 			ns.logf("dns udp query: %v", err)
 			return

+ 1 - 1
wgengine/userspace.go

@@ -304,7 +304,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
 	tunName, _ := conf.Tun.Name()
 	conf.Dialer.SetTUNName(tunName)
 	conf.Dialer.SetNetMon(e.netMon)
-	e.dns = dns.NewManager(logf, conf.DNS, e.netMon, conf.Dialer, fwdDNSLinkSelector{e, tunName})
+	e.dns = dns.NewManager(logf, conf.DNS, e.netMon, conf.Dialer, fwdDNSLinkSelector{e, tunName}, conf.ControlKnobs)
 
 	// TODO: there's probably a better place for this
 	sockstats.SetNetMon(e.netMon)