Преглед изворни кода

net/dnscache, net/tsdial: add DNS caching to tsdial UserDial

This is enough to handle the DNS queries as generated by Go's
net package (which our HTTP/SOCKS client uses), and the responses
generated by the ExitDNS DoH server.

This isn't yet suitable for putting on 100.100.100.100 where a number
of different DNS clients would hit it, as this doesn't yet do
EDNS0. It might work, but it's untested and likely incomplete.

Likewise, this doesn't handle anything about truncation, as the
exchanges are entirely in memory between Go or DoH. That would also
need to be handled later, if/when it's hooked up to 100.100.100.100.

Updates #3507

Change-Id: I1736b0ad31eea85ea853b310c52c5e6bf65c6e2a
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick пре 4 година
родитељ
комит
39ffa16853

+ 2 - 1
cmd/tailscale/depaware.txt

@@ -3,6 +3,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
    W 💣 github.com/alexbrainman/sspi                                 from github.com/alexbrainman/sspi/negotiate+
    W 💣 github.com/alexbrainman/sspi                                 from github.com/alexbrainman/sspi/negotiate+
    W    github.com/alexbrainman/sspi/internal/common                 from github.com/alexbrainman/sspi/negotiate
    W    github.com/alexbrainman/sspi/internal/common                 from github.com/alexbrainman/sspi/negotiate
    W 💣 github.com/alexbrainman/sspi/negotiate                       from tailscale.com/net/tshttpproxy
    W 💣 github.com/alexbrainman/sspi/negotiate                       from tailscale.com/net/tshttpproxy
+        github.com/golang/groupcache/lru                             from tailscale.com/net/dnscache
         github.com/kballard/go-shellquote                            from tailscale.com/cmd/tailscale/cli
         github.com/kballard/go-shellquote                            from tailscale.com/cmd/tailscale/cli
    L    github.com/klauspost/compress/flate                          from nhooyr.io/websocket
    L    github.com/klauspost/compress/flate                          from nhooyr.io/websocket
      💣 github.com/mitchellh/go-ps                                   from tailscale.com/cmd/tailscale/cli+
      💣 github.com/mitchellh/go-ps                                   from tailscale.com/cmd/tailscale/cli+
@@ -91,7 +92,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
         golang.org/x/crypto/nacl/secretbox                           from golang.org/x/crypto/nacl/box
         golang.org/x/crypto/nacl/secretbox                           from golang.org/x/crypto/nacl/box
         golang.org/x/crypto/poly1305                                 from golang.org/x/crypto/chacha20poly1305
         golang.org/x/crypto/poly1305                                 from golang.org/x/crypto/chacha20poly1305
         golang.org/x/crypto/salsa20/salsa                            from golang.org/x/crypto/nacl/box+
         golang.org/x/crypto/salsa20/salsa                            from golang.org/x/crypto/nacl/box+
-        golang.org/x/net/dns/dnsmessage                              from net
+        golang.org/x/net/dns/dnsmessage                              from net+
         golang.org/x/net/http/httpguts                               from net/http+
         golang.org/x/net/http/httpguts                               from net/http+
         golang.org/x/net/http/httpproxy                              from net/http
         golang.org/x/net/http/httpproxy                              from net/http
         golang.org/x/net/http2/hpack                                 from net/http
         golang.org/x/net/http2/hpack                                 from net/http

+ 1 - 0
cmd/tailscaled/depaware.txt

@@ -63,6 +63,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
    W 💣 github.com/go-ole/go-ole                                     from github.com/go-ole/go-ole/oleutil+
    W 💣 github.com/go-ole/go-ole                                     from github.com/go-ole/go-ole/oleutil+
    W 💣 github.com/go-ole/go-ole/oleutil                             from tailscale.com/wgengine/winnet
    W 💣 github.com/go-ole/go-ole/oleutil                             from tailscale.com/wgengine/winnet
    L 💣 github.com/godbus/dbus/v5                                    from tailscale.com/net/dns
    L 💣 github.com/godbus/dbus/v5                                    from tailscale.com/net/dns
+        github.com/golang/groupcache/lru                             from tailscale.com/net/dnscache
         github.com/google/btree                                      from inet.af/netstack/tcpip/header+
         github.com/google/btree                                      from inet.af/netstack/tcpip/header+
    L    github.com/insomniacslk/dhcp/dhcpv4                          from tailscale.com/net/tstun
    L    github.com/insomniacslk/dhcp/dhcpv4                          from tailscale.com/net/tstun
    L    github.com/insomniacslk/dhcp/iana                            from github.com/insomniacslk/dhcp/dhcpv4
    L    github.com/insomniacslk/dhcp/iana                            from github.com/insomniacslk/dhcp/dhcpv4

+ 1 - 0
go.mod

@@ -19,6 +19,7 @@ require (
 	github.com/gliderlabs/ssh v0.3.3
 	github.com/gliderlabs/ssh v0.3.3
 	github.com/go-ole/go-ole v1.2.6
 	github.com/go-ole/go-ole v1.2.6
 	github.com/godbus/dbus/v5 v5.0.6
 	github.com/godbus/dbus/v5 v5.0.6
+	github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da
 	github.com/google/go-cmp v0.5.6
 	github.com/google/go-cmp v0.5.6
 	github.com/google/uuid v1.3.0
 	github.com/google/uuid v1.3.0
 	github.com/goreleaser/nfpm v1.10.3
 	github.com/goreleaser/nfpm v1.10.3

+ 2 - 0
go.sum

@@ -387,6 +387,8 @@ github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4er
 github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
 github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
 github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
 github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
 github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
 github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
+github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
+github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
 github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
 github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
 github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
 github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
 github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
 github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=

+ 314 - 0
net/dnscache/messagecache.go

@@ -0,0 +1,314 @@
+// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dnscache
+
+import (
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+	"sync"
+	"time"
+
+	"github.com/golang/groupcache/lru"
+	"golang.org/x/net/dns/dnsmessage"
+)
+
+// MessageCache is a cache that works at the DNS message layer,
+// with its cache keyed on a DNS wire-level question, and capable
+// of replying to DNS messages.
+//
+// Its zero value is ready for use with a default cache size.
+// Use SetMaxCacheSize to specify the cache size.
+//
+// It's safe for concurrent use.
+type MessageCache struct {
+	// Clock is a clock, for testing.
+	// If nil, time.Now is used.
+	Clock func() time.Time
+
+	mu           sync.Mutex
+	cacheSizeSet int       // 0 means default
+	cache        lru.Cache // msgQ => *msgCacheValue
+}
+
+func (c *MessageCache) now() time.Time {
+	if c.Clock != nil {
+		return c.Clock()
+	}
+	return time.Now()
+}
+
+// SetMaxCacheSize sets the maximum number of DNS cache entries that
+// can be stored.
+func (c *MessageCache) SetMaxCacheSize(n int) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.cacheSizeSet = n
+	c.pruneLocked()
+}
+
+// Flush clears the cache.
+func (c *MessageCache) Flush() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.cache.Clear()
+}
+
+// pruneLocked prunes down the cache size to the configured (or
+// default) max size.
+func (c *MessageCache) pruneLocked() {
+	max := c.cacheSizeSet
+	if max == 0 {
+		max = 500
+	}
+	for c.cache.Len() > max {
+		c.cache.RemoveOldest()
+	}
+}
+
+// msgQ is the MessageCache cache key.
+//
+// It's basically a golang.org/x/net/dns/dnsmessage#Question but the
+// Class is omitted (we only cache ClassINET) and we store a Go string
+// instead of a 256 byte dnsmessage.Name array.
+type msgQ struct {
+	Name string
+	Type dnsmessage.Type // A, AAAA, MX, etc
+}
+
+// A *msgCacheValue is the cached value for a msgQ (question) key.
+//
+// Despite using pointers for storage and methods, the value is
+// immutable once placed in the cache.
+type msgCacheValue struct {
+	Expires time.Time
+
+	// Answers are the minimum data to reconstruct a DNS response
+	// message. TTLs are added later when converting to a
+	// dnsmessage.Resource.
+	Answers []msgResource
+}
+
+type msgResource struct {
+	Name string
+	Type dnsmessage.Type // dnsmessage.UnknownResource.Type
+	Data []byte          // dnsmessage.UnknownResource.Data
+}
+
+// ErrCacheMiss is a sentinel error returned by MessageCache.ReplyFromCache
+// when the request can not be satisified from cache.
+var ErrCacheMiss = errors.New("cache miss")
+
+var parserPool = &sync.Pool{
+	New: func() interface{} { return new(dnsmessage.Parser) },
+}
+
+// ReplyFromCache writes a DNS reply to w for the provided DNS query message,
+// which must begin with the two ID bytes of a DNS message.
+//
+// If there's a cache miss, the message is invalid or unexpected,
+// ErrCacheMiss is returned. On cache hit, either nil or an error from
+// a w.Write call is returned.
+func (c *MessageCache) ReplyFromCache(w io.Writer, dnsQueryMessage []byte) error {
+	cacheKey, txID, ok := getDNSQueryCacheKey(dnsQueryMessage)
+	if !ok {
+		return ErrCacheMiss
+	}
+	now := c.now()
+
+	c.mu.Lock()
+	cacheEntI, _ := c.cache.Get(cacheKey)
+	v, ok := cacheEntI.(*msgCacheValue)
+	if ok && now.After(v.Expires) {
+		c.cache.Remove(cacheKey)
+		ok = false
+	}
+	c.mu.Unlock()
+
+	if !ok {
+		return ErrCacheMiss
+	}
+
+	ttl := uint32(v.Expires.Sub(now).Seconds())
+
+	packedRes, err := packDNSResponse(cacheKey, txID, ttl, v.Answers)
+	if err != nil {
+		return ErrCacheMiss
+	}
+	_, err = w.Write(packedRes)
+	return err
+}
+
+var (
+	errNotCacheable = errors.New("question not cacheable")
+)
+
+// AddCacheEntry adds a cache entry to the cache.
+// It returns an error if the entry could not be cached.
+func (c *MessageCache) AddCacheEntry(qPacket, res []byte) error {
+	cacheKey, qID, ok := getDNSQueryCacheKey(qPacket)
+	if !ok {
+		return errNotCacheable
+	}
+	now := c.now()
+	v := &msgCacheValue{}
+
+	p := parserPool.Get().(*dnsmessage.Parser)
+	defer parserPool.Put(p)
+
+	resh, err := p.Start(res)
+	if err != nil {
+		return fmt.Errorf("reading header in response: %w", err)
+	}
+	if resh.ID != qID {
+		return fmt.Errorf("response ID doesn't match query ID")
+	}
+	q, err := p.Question()
+	if err != nil {
+		return fmt.Errorf("reading 1st question in response: %w", err)
+	}
+	if _, err := p.Question(); err != dnsmessage.ErrSectionDone {
+		if err == nil {
+			return errors.New("unexpected 2nd question in response")
+		}
+		return fmt.Errorf("after reading 1st question in response: %w", err)
+	}
+	if resName := asciiLowerName(q.Name).String(); resName != cacheKey.Name {
+		return fmt.Errorf("response question name %q != question name %q", resName, cacheKey.Name)
+	}
+	for {
+		rh, err := p.AnswerHeader()
+		if err == dnsmessage.ErrSectionDone {
+			break
+		}
+		if err != nil {
+			return fmt.Errorf("reading answer: %w", err)
+		}
+		res, err := p.UnknownResource()
+		if err != nil {
+			return fmt.Errorf("reading resource: %w", err)
+		}
+		if rh.Class != dnsmessage.ClassINET {
+			continue
+		}
+
+		// Set the cache entry's expiration to the soonest
+		// we've seen. (They should all be the same, though)
+		expires := now.Add(time.Duration(rh.TTL) * time.Second)
+		if v.Expires.IsZero() || expires.Before(v.Expires) {
+			v.Expires = expires
+		}
+		v.Answers = append(v.Answers, msgResource{
+			Name: rh.Name.String(),
+			Type: rh.Type,
+			Data: res.Data, // doesn't alias; a copy from dnsmessage.unpackUnknownResource
+		})
+	}
+	c.addCacheValue(cacheKey, v)
+	return nil
+}
+
+func (c *MessageCache) addCacheValue(cacheKey msgQ, v *msgCacheValue) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.cache.Add(cacheKey, v)
+	c.pruneLocked()
+}
+
+func getDNSQueryCacheKey(msg []byte) (cacheKey msgQ, txID uint16, ok bool) {
+	p := parserPool.Get().(*dnsmessage.Parser)
+	defer parserPool.Put(p)
+	h, err := p.Start(msg)
+	const dnsHeaderSize = 12
+	if err != nil || h.OpCode != 0 || h.Response || h.Truncated ||
+		len(msg) < dnsHeaderSize { // p.Start checks this anyway, but to be explicit for slicing below
+		return cacheKey, 0, false
+	}
+	var (
+		numQ    = binary.BigEndian.Uint16(msg[4:6])
+		numAns  = binary.BigEndian.Uint16(msg[6:8])
+		numAuth = binary.BigEndian.Uint16(msg[8:10])
+		numAddn = binary.BigEndian.Uint16(msg[10:12])
+	)
+	_ = numAddn // ignore this for now; do client OSes send EDNS additional? assume so, ignore.
+	if !(numQ == 1 && numAns == 0 && numAuth == 0) {
+		// Something weird. We don't want to deal with it.
+		return cacheKey, 0, false
+	}
+	q, err := p.Question()
+	if err != nil {
+		// Already verified numQ == 1 so shouldn't happen, but:
+		return cacheKey, 0, false
+	}
+	if q.Class != dnsmessage.ClassINET {
+		// We only cache the Internet class.
+		return cacheKey, 0, false
+	}
+	return msgQ{Name: asciiLowerName(q.Name).String(), Type: q.Type}, h.ID, true
+}
+
+func asciiLowerName(n dnsmessage.Name) dnsmessage.Name {
+	nb := n.Data[:]
+	if int(n.Length) < len(n.Data) {
+		nb = nb[:n.Length]
+	}
+	for i, b := range nb {
+		if 'A' <= b && b <= 'Z' {
+			n.Data[i] += 0x20
+		}
+	}
+	return n
+}
+
+// packDNSResponse builds a DNS response for the given question and
+// transaction ID. The response resource records will have have the
+// same provided TTL.
+func packDNSResponse(q msgQ, txID uint16, ttl uint32, answers []msgResource) ([]byte, error) {
+	var baseMem []byte // TODO: guess a max size based on looping over answers?
+	b := dnsmessage.NewBuilder(baseMem, dnsmessage.Header{
+		ID:            txID,
+		Response:      true,
+		OpCode:        0,
+		Authoritative: false,
+		Truncated:     false,
+		RCode:         dnsmessage.RCodeSuccess,
+	})
+	name, err := dnsmessage.NewName(q.Name)
+	if err != nil {
+		return nil, err
+	}
+	if err := b.StartQuestions(); err != nil {
+		return nil, err
+	}
+	if err := b.Question(dnsmessage.Question{
+		Name:  name,
+		Type:  q.Type,
+		Class: dnsmessage.ClassINET,
+	}); err != nil {
+		return nil, err
+	}
+	if err := b.StartAnswers(); err != nil {
+		return nil, err
+	}
+	for _, r := range answers {
+		name, err := dnsmessage.NewName(r.Name)
+		if err != nil {
+			return nil, err
+		}
+		if err := b.UnknownResource(dnsmessage.ResourceHeader{
+			Name:  name,
+			Type:  r.Type,
+			Class: dnsmessage.ClassINET,
+			TTL:   ttl,
+		}, dnsmessage.UnknownResource{
+			Type: r.Type,
+			Data: r.Data,
+		}); err != nil {
+			return nil, err
+		}
+	}
+	return b.Finish()
+}

+ 292 - 0
net/dnscache/messagecache_test.go

@@ -0,0 +1,292 @@
+// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dnscache
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"net"
+	"runtime"
+	"testing"
+	"time"
+
+	"golang.org/x/net/dns/dnsmessage"
+	"tailscale.com/tstest"
+)
+
+func TestMessageCache(t *testing.T) {
+	clock := &tstest.Clock{
+		Start: time.Date(1987, 11, 1, 0, 0, 0, 0, time.UTC),
+	}
+	mc := &MessageCache{Clock: clock.Now}
+	mc.SetMaxCacheSize(2)
+	clock.Advance(time.Second)
+
+	var out bytes.Buffer
+	if err := mc.ReplyFromCache(&out, makeQ(1, "foo.com.")); err != ErrCacheMiss {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if err := mc.AddCacheEntry(
+		makeQ(2, "foo.com."),
+		makeRes(2, "FOO.COM.", ttlOpt(10),
+			&dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}},
+			&dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}})); err != nil {
+		t.Fatal(err)
+	}
+
+	// Expect cache hit, with 10 seconds remaining.
+	out.Reset()
+	if err := mc.ReplyFromCache(&out, makeQ(3, "foo.com.")); err != nil {
+		t.Fatalf("expected cache hit; got: %v", err)
+	}
+	if p := mustParseResponse(t, out.Bytes()); p.TxID != 3 {
+		t.Errorf("TxID = %v; want %v", p.TxID, 3)
+	} else if p.TTL != 10 {
+		t.Errorf("TTL = %v; want 10", p.TTL)
+	}
+
+	// One second elapses, expect a cache hit, with 9 seconds
+	// remaining.
+	clock.Advance(time.Second)
+	out.Reset()
+	if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.")); err != nil {
+		t.Fatalf("expected cache hit; got: %v", err)
+	}
+	if p := mustParseResponse(t, out.Bytes()); p.TxID != 4 {
+		t.Errorf("TxID = %v; want %v", p.TxID, 4)
+	} else if p.TTL != 9 {
+		t.Errorf("TTL = %v; want 9", p.TTL)
+	}
+
+	// Expect cache miss on MX record.
+	if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.TypeMX)); err != ErrCacheMiss {
+		t.Fatalf("expected cache miss on MX; got: %v", err)
+	}
+	// Expect cache miss on CHAOS class.
+	if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.ClassCHAOS)); err != ErrCacheMiss {
+		t.Fatalf("expected cache miss on CHAOS; got: %v", err)
+	}
+
+	// Ten seconds elapses; expect a cache miss.
+	clock.Advance(10 * time.Second)
+	if err := mc.ReplyFromCache(&out, makeQ(5, "foo.com.")); err != ErrCacheMiss {
+		t.Fatalf("expected cache miss, got: %v", err)
+	}
+}
+
+type parsedMeta struct {
+	TxID uint16
+	TTL  uint32
+}
+
+func mustParseResponse(t testing.TB, r []byte) (ret parsedMeta) {
+	t.Helper()
+	var p dnsmessage.Parser
+	h, err := p.Start(r)
+	if err != nil {
+		t.Fatal(err)
+	}
+	ret.TxID = h.ID
+	qq, err := p.AllQuestions()
+	if err != nil {
+		t.Fatalf("AllQuestions: %v", err)
+	}
+	if len(qq) != 1 {
+		t.Fatalf("num questions = %v; want 1", len(qq))
+	}
+	aa, err := p.AllAnswers()
+	if err != nil {
+		t.Fatalf("AllAnswers: %v", err)
+	}
+	for _, r := range aa {
+		if ret.TTL == 0 {
+			ret.TTL = r.Header.TTL
+		}
+		if ret.TTL != r.Header.TTL {
+			t.Fatal("mixed TTLs")
+		}
+	}
+	return ret
+}
+
+type responseOpt bool
+
+type ttlOpt uint32
+
+func makeQ(txID uint16, name string, opt ...interface{}) []byte {
+	opt = append(opt, responseOpt(false))
+	return makeDNSPkt(txID, name, opt...)
+}
+
+func makeRes(txID uint16, name string, opt ...interface{}) []byte {
+	opt = append(opt, responseOpt(true))
+	return makeDNSPkt(txID, name, opt...)
+}
+
+func makeDNSPkt(txID uint16, name string, opt ...interface{}) []byte {
+	typ := dnsmessage.TypeA
+	class := dnsmessage.ClassINET
+	var response bool
+	var answers []dnsmessage.ResourceBody
+	var ttl uint32 = 1 // one second by default
+	for _, o := range opt {
+		switch o := o.(type) {
+		case dnsmessage.Type:
+			typ = o
+		case dnsmessage.Class:
+			class = o
+		case responseOpt:
+			response = bool(o)
+		case dnsmessage.ResourceBody:
+			answers = append(answers, o)
+		case ttlOpt:
+			ttl = uint32(o)
+		default:
+			panic(fmt.Sprintf("unknown opt type %T", o))
+		}
+	}
+	qname := dnsmessage.MustNewName(name)
+	msg := dnsmessage.Message{
+		Header: dnsmessage.Header{ID: txID, Response: response},
+		Questions: []dnsmessage.Question{
+			{
+				Name:  qname,
+				Type:  typ,
+				Class: class,
+			},
+		},
+	}
+	for _, rb := range answers {
+		msg.Answers = append(msg.Answers, dnsmessage.Resource{
+			Header: dnsmessage.ResourceHeader{
+				Name:  qname,
+				Type:  typ,
+				Class: class,
+				TTL:   ttl,
+			},
+			Body: rb,
+		})
+	}
+	buf, err := msg.Pack()
+	if err != nil {
+		panic(err)
+	}
+	return buf
+}
+
+func TestASCIILowerName(t *testing.T) {
+	n := asciiLowerName(dnsmessage.MustNewName("Foo.COM."))
+	if got, want := n.String(), "foo.com."; got != want {
+		t.Errorf("got = %q; want %q", got, want)
+	}
+}
+
+func TestGetDNSQueryCacheKey(t *testing.T) {
+	tests := []struct {
+		name  string
+		pkt   []byte
+		want  msgQ
+		txID  uint16
+		anyTX bool
+	}{
+		{
+			name: "empty",
+		},
+		{
+			name: "a",
+			pkt:  makeQ(123, "foo.com."),
+			want: msgQ{"foo.com.", dnsmessage.TypeA},
+			txID: 123,
+		},
+		{
+			name: "aaaa",
+			pkt:  makeQ(6, "foo.com.", dnsmessage.TypeAAAA),
+			want: msgQ{"foo.com.", dnsmessage.TypeAAAA},
+			txID: 6,
+		},
+		{
+			name: "normalize_case",
+			pkt:  makeQ(123, "FoO.CoM."),
+			want: msgQ{"foo.com.", dnsmessage.TypeA},
+			txID: 123,
+		},
+		{
+			name: "ignore_response",
+			pkt:  makeRes(123, "foo.com."),
+		},
+		{
+			name: "ignore_question_with_answers",
+			pkt:  makeQ(2, "foo.com.", &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}),
+		},
+		{
+			name:  "whatever_go_generates", // in case Go's net package grows functionality we don't handle
+			pkt:   getGoNetPacketDNSQuery("from-go.foo."),
+			want:  msgQ{"from-go.foo.", dnsmessage.TypeA},
+			anyTX: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, gotTX, ok := getDNSQueryCacheKey(tt.pkt)
+			if !ok {
+				if tt.txID == 0 && got == (msgQ{}) {
+					return
+				}
+				t.Fatal("failed")
+			}
+			if got != tt.want {
+				t.Errorf("got %+v, want %+v", got, tt.want)
+			}
+			if gotTX != tt.txID && !tt.anyTX {
+				t.Errorf("got tx %v, want %v", gotTX, tt.txID)
+			}
+		})
+	}
+}
+
+func getGoNetPacketDNSQuery(name string) []byte {
+	if runtime.GOOS == "windows" {
+		// On Windows, Go's net.Resolver doesn't use the DNS client.
+		// See https://github.com/golang/go/issues/33097 which
+		// was approved but not yet implemented.
+		// For now just pretend it's implemented to make this test
+		// pass on Windows with complicated the caller.
+		return makeQ(123, name)
+	}
+	res := make(chan []byte, 1)
+	r := &net.Resolver{
+		PreferGo: true,
+		Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
+			return goResolverConn(res), nil
+		},
+	}
+	r.LookupIP(context.Background(), "ip4", name)
+	return <-res
+}
+
+type goResolverConn chan<- []byte
+
+func (goResolverConn) Close() error                       { return nil }
+func (goResolverConn) LocalAddr() net.Addr                { return todoAddr{} }
+func (goResolverConn) RemoteAddr() net.Addr               { return todoAddr{} }
+func (goResolverConn) SetDeadline(t time.Time) error      { return nil }
+func (goResolverConn) SetReadDeadline(t time.Time) error  { return nil }
+func (goResolverConn) SetWriteDeadline(t time.Time) error { return nil }
+func (goResolverConn) Read([]byte) (int, error)           { return 0, errors.New("boom") }
+func (c goResolverConn) Write(p []byte) (int, error) {
+	select {
+	case c <- p[2:]: // skip 2 byte length for TCP mode DNS query
+	default:
+	}
+	return 0, errors.New("boom")
+}
+
+type todoAddr struct{}
+
+func (todoAddr) Network() string { return "unused" }
+func (todoAddr) String() string  { return "unused-todoAddr" }

+ 18 - 3
net/tsdial/dohclient.go

@@ -13,15 +13,18 @@ import (
 	"net"
 	"net"
 	"net/http"
 	"net/http"
 	"time"
 	"time"
+
+	"tailscale.com/net/dnscache"
 )
 )
 
 
 // dohConn is a net.PacketConn suitable for returning from
 // dohConn is a net.PacketConn suitable for returning from
 // net.Dialer.Dial to send DNS queries over PeerAPI to exit nodes'
 // net.Dialer.Dial to send DNS queries over PeerAPI to exit nodes'
 // ExitDNS DoH proxy service.
 // ExitDNS DoH proxy service.
 type dohConn struct {
 type dohConn struct {
-	ctx     context.Context
-	baseURL string
-	hc      *http.Client // if nil, default is used
+	ctx      context.Context
+	baseURL  string
+	hc       *http.Client // if nil, default is used
+	dnsCache *dnscache.MessageCache
 
 
 	rbuf bytes.Buffer
 	rbuf bytes.Buffer
 }
 }
@@ -52,6 +55,15 @@ func (c *dohConn) Read(p []byte) (n int, err error) {
 }
 }
 
 
 func (c *dohConn) Write(packet []byte) (n int, err error) {
 func (c *dohConn) Write(packet []byte) (n int, err error) {
+	if c.dnsCache != nil {
+		err := c.dnsCache.ReplyFromCache(&c.rbuf, packet)
+		if err == nil {
+			// Cache hit.
+			// TODO(bradfitz): add clientmetric
+			return len(packet), nil
+		}
+		c.rbuf.Reset()
+	}
 	req, err := http.NewRequestWithContext(c.ctx, "POST", c.baseURL, bytes.NewReader(packet))
 	req, err := http.NewRequestWithContext(c.ctx, "POST", c.baseURL, bytes.NewReader(packet))
 	if err != nil {
 	if err != nil {
 		return 0, err
 		return 0, err
@@ -77,6 +89,9 @@ func (c *dohConn) Write(packet []byte) (n int, err error) {
 	if err != nil {
 	if err != nil {
 		return 0, err
 		return 0, err
 	}
 	}
+	if c.dnsCache != nil {
+		c.dnsCache.AddCacheEntry(packet, c.rbuf.Bytes())
+	}
 	return len(packet), nil
 	return len(packet), nil
 }
 }
 
 

+ 19 - 5
net/tsdial/tsdial.go

@@ -11,6 +11,7 @@ import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
+	"runtime"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
@@ -18,6 +19,7 @@ import (
 	"time"
 	"time"
 
 
 	"inet.af/netaddr"
 	"inet.af/netaddr"
+	"tailscale.com/net/dnscache"
 	"tailscale.com/net/netknob"
 	"tailscale.com/net/netknob"
 	"tailscale.com/types/netmap"
 	"tailscale.com/types/netmap"
 	"tailscale.com/wgengine/monitor"
 	"tailscale.com/wgengine/monitor"
@@ -48,7 +50,8 @@ type Dialer struct {
 	dns            dnsMap
 	dns            dnsMap
 	tunName        string // tun device name
 	tunName        string // tun device name
 	linkMon        *monitor.Mon
 	linkMon        *monitor.Mon
-	exitDNSDoHBase string // non-empty if DoH-proxying exit node in use; base URL+path (without '?')
+	exitDNSDoHBase string                 // non-empty if DoH-proxying exit node in use; base URL+path (without '?')
+	dnsCache       *dnscache.MessageCache // nil until first first non-empty SetExitDNSDoH
 }
 }
 
 
 // SetTUNName sets the name of the tun device in use ("tailscale0", "utun6",
 // SetTUNName sets the name of the tun device in use ("tailscale0", "utun6",
@@ -76,7 +79,16 @@ func (d *Dialer) TUNName() string {
 func (d *Dialer) SetExitDNSDoH(doh string) {
 func (d *Dialer) SetExitDNSDoH(doh string) {
 	d.mu.Lock()
 	d.mu.Lock()
 	defer d.mu.Unlock()
 	defer d.mu.Unlock()
+	if d.exitDNSDoHBase == doh {
+		return
+	}
 	d.exitDNSDoHBase = doh
 	d.exitDNSDoHBase = doh
+	if doh != "" && d.dnsCache == nil {
+		d.dnsCache = new(dnscache.MessageCache)
+	}
+	if d.dnsCache != nil {
+		d.dnsCache.Flush()
+	}
 }
 }
 
 
 func (d *Dialer) SetLinkMonitor(mon *monitor.Mon) {
 func (d *Dialer) SetLinkMonitor(mon *monitor.Mon) {
@@ -149,12 +161,14 @@ func (d *Dialer) userDialResolve(ctx context.Context, network, addr string) (net
 	}
 	}
 
 
 	var r net.Resolver
 	var r net.Resolver
-	if exitDNSDoH != "" {
+	if exitDNSDoH != "" && runtime.GOOS != "windows" { // Windows: https://github.com/golang/go/issues/33097
+		r.PreferGo = true
 		r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
 		r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) {
 			return &dohConn{
 			return &dohConn{
-				ctx:     ctx,
-				baseURL: exitDNSDoH,
-				hc:      d.PeerAPIHTTPClient(),
+				ctx:      ctx,
+				baseURL:  exitDNSDoH,
+				hc:       d.PeerAPIHTTPClient(),
+				dnsCache: d.dnsCache,
 			}, nil
 			}, nil
 		}
 		}
 	}
 	}