Răsfoiți Sursa

net/dns/resolver: add Windows ExitDNS service support, using net package

Updates #1713
Updates #835

Change-Id: Ia71e96d0632c2d617b401695ad68301b07c1c2ec
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 4 ani în urmă
părinte
comite
cced414c7d

+ 1 - 2
ipn/ipnlocal/local.go

@@ -1797,10 +1797,9 @@ func (b *LocalBackend) peerAPIServicesLocked() (ret []tailcfg.Service) {
 		})
 	}
 	switch runtime.GOOS {
-	case "linux", "freebsd", "openbsd", "illumos", "darwin":
+	case "linux", "freebsd", "openbsd", "illumos", "darwin", "windows":
 		// These are the platforms currently supported by
 		// net/dns/resolver/tsdns.go:Resolver.HandleExitNodeDNSQuery.
-		// TODO(bradfitz): add windows once it's done there.
 		ret = append(ret, tailcfg.Service{
 			Proto: tailcfg.PeerAPIDNS,
 			Port:  1, // version

+ 262 - 7
net/dns/resolver/tsdns.go

@@ -360,7 +360,8 @@ func (r *Resolver) HandleExitNodeDNSQuery(ctx context.Context, q []byte, from ne
 	case "windows":
 		// TODO: use DnsQueryEx and write to ch.
 		// See https://docs.microsoft.com/en-us/windows/win32/api/windns/nf-windns-dnsqueryex.
-		return nil, errors.New("TODO: windows exit node suport")
+		// For now just use the net package:
+		return handleExitNodeDNSQueryWithNetPkg(ctx, nil, resp)
 	case "darwin":
 		// /etc/resolv.conf is a lie and only says one upstream DNS
 		// but for now that's probably good enough. Later we'll
@@ -404,6 +405,106 @@ func (r *Resolver) HandleExitNodeDNSQuery(ctx context.Context, q []byte, from ne
 	}
 }
 
+// handleExitNodeDNSQueryWithNetPkg takes a DNS query message in q and
+// return a reply (for the ExitDNS DoH service) using the net package's
+// native APIs. This is only used on Windows for now.
+//
+// If resolver is nil, the net.Resolver zero value is used.
+//
+// response contains the pre-serialized response, which notably
+// includes the original question and its header.
+func handleExitNodeDNSQueryWithNetPkg(ctx context.Context, resolver *net.Resolver, resp *response) (res []byte, err error) {
+	if resp.Question.Class != dns.ClassINET {
+		return nil, errors.New("unsupported class")
+	}
+
+	r := resolver
+	if r == nil {
+		r = new(net.Resolver)
+	}
+	name := resp.Question.Name.String()
+
+	handleError := func(err error) (res []byte, _ error) {
+		if isGoNoSuchHostError(err) {
+			resp.Header.RCode = dns.RCodeNameError
+			return marshalResponse(resp)
+		}
+		// TODO: map other errors to RCodeServerFailure?
+		// Or I guess our caller should do that?
+		return nil, err
+	}
+
+	resp.Header.RCode = dns.RCodeSuccess // unless changed below
+
+	switch resp.Question.Type {
+	case dns.TypeA, dns.TypeAAAA:
+		network := "ip4"
+		if resp.Question.Type == dns.TypeAAAA {
+			network = "ip6"
+		}
+		ips, err := r.LookupIP(ctx, network, name)
+		if err != nil {
+			return handleError(err)
+		}
+		for _, stdIP := range ips {
+			if ip, ok := netaddr.FromStdIP(stdIP); ok {
+				resp.IPs = append(resp.IPs, ip)
+			}
+		}
+	case dns.TypeTXT:
+		strs, err := r.LookupTXT(ctx, name)
+		if err != nil {
+			return handleError(err)
+		}
+		resp.TXT = strs
+	case dns.TypePTR:
+		ipStr, ok := unARPA(name)
+		if !ok {
+			// TODO: is this RCodeFormatError?
+			return nil, errors.New("bogus PTR name")
+		}
+		addrs, err := r.LookupAddr(ctx, ipStr)
+		if err != nil {
+			return handleError(err)
+		}
+		if len(addrs) > 0 {
+			resp.Name, _ = dnsname.ToFQDN(addrs[0])
+		}
+	case dns.TypeCNAME:
+		cname, err := r.LookupCNAME(ctx, name)
+		if err != nil {
+			return handleError(err)
+		}
+		resp.CNAME = cname
+	case dns.TypeSRV:
+		// Thanks, Go: "To accommodate services publishing SRV
+		// records under non-standard names, if both service
+		// and proto are empty strings, LookupSRV looks up
+		// name directly."
+		_, srvs, err := r.LookupSRV(ctx, "", "", name)
+		if err != nil {
+			return handleError(err)
+		}
+		resp.SRVs = srvs
+	case dns.TypeNS:
+		nss, err := r.LookupNS(ctx, name)
+		if err != nil {
+			return handleError(err)
+		}
+		resp.NSs = nss
+	default:
+		return nil, fmt.Errorf("unsupported record type %v", resp.Question.Type)
+	}
+	return marshalResponse(resp)
+}
+
+func isGoNoSuchHostError(err error) bool {
+	if de, ok := err.(*net.DNSError); ok {
+		return de.IsNotFound
+	}
+	return false
+}
+
 type resolvConfCache struct {
 	mod  time.Time
 	size int64
@@ -604,10 +705,27 @@ func (r *Resolver) handleQuery(pkt packet) {
 type response struct {
 	Header   dns.Header
 	Question dns.Question
+
 	// Name is the response to a PTR query.
 	Name dnsname.FQDN
-	// IP is the response to an A, AAAA, or ALL query.
-	IP netaddr.IP
+
+	// IP and IPs are the responses to an A, AAAA, or ALL query.
+	// Either/both/neither can be populated.
+	IP  netaddr.IP
+	IPs []netaddr.IP
+
+	// TXT is the response to a TXT query.
+	// Each one is its own RR with one string.
+	TXT []string
+
+	// CNAME is the response to a CNAME query.
+	CNAME string
+
+	// SRVs are the responses to a SRV query.
+	SRVs []*net.SRV
+
+	// NSs are the responses to an NS query.
+	NSs []*net.NS
 }
 
 var dnsParserPool = &sync.Pool{
@@ -683,6 +801,16 @@ func marshalAAAARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error
 	return builder.AAAAResource(answerHeader, answer)
 }
 
+func marshalIP(name dns.Name, ip netaddr.IP, builder *dns.Builder) error {
+	if ip.Is4() {
+		return marshalARecord(name, ip, builder)
+	}
+	if ip.Is6() {
+		return marshalAAAARecord(name, ip, builder)
+	}
+	return nil
+}
+
 // marshalPTRRecord serializes a PTR record into an active builder.
 // The caller may continue using the builder following the call.
 func marshalPTRRecord(queryName dns.Name, name dnsname.FQDN, builder *dns.Builder) error {
@@ -702,6 +830,83 @@ func marshalPTRRecord(queryName dns.Name, name dnsname.FQDN, builder *dns.Builde
 	return builder.PTRResource(answerHeader, answer)
 }
 
+func marshalTXT(queryName dns.Name, txts []string, builder *dns.Builder) error {
+	for _, txt := range txts {
+		if err := builder.TXTResource(dns.ResourceHeader{
+			Name:  queryName,
+			Type:  dns.TypeTXT,
+			Class: dns.ClassINET,
+			TTL:   uint32(defaultTTL / time.Second),
+		}, dns.TXTResource{
+			TXT: []string{txt},
+		}); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func marshalCNAME(queryName dns.Name, cname string, builder *dns.Builder) error {
+	if cname == "" {
+		return nil
+	}
+	name, err := dns.NewName(cname)
+	if err != nil {
+		return err
+	}
+	return builder.CNAMEResource(dns.ResourceHeader{
+		Name:  queryName,
+		Type:  dns.TypeCNAME,
+		Class: dns.ClassINET,
+		TTL:   uint32(defaultTTL / time.Second),
+	}, dns.CNAMEResource{
+		CNAME: name,
+	})
+}
+
+func marshalNS(queryName dns.Name, nss []*net.NS, builder *dns.Builder) error {
+	for _, ns := range nss {
+		name, err := dns.NewName(ns.Host)
+		if err != nil {
+			return err
+		}
+		err = builder.NSResource(dns.ResourceHeader{
+			Name:  queryName,
+			Type:  dns.TypeNS,
+			Class: dns.ClassINET,
+			TTL:   uint32(defaultTTL / time.Second),
+		}, dns.NSResource{NS: name})
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func marshalSRV(queryName dns.Name, srvs []*net.SRV, builder *dns.Builder) error {
+	for _, s := range srvs {
+		srvName, err := dns.NewName(s.Target)
+		if err != nil {
+			return err
+		}
+		err = builder.SRVResource(dns.ResourceHeader{
+			Name:  queryName,
+			Type:  dns.TypeSRV,
+			Class: dns.ClassINET,
+			TTL:   uint32(defaultTTL / time.Second),
+		}, dns.SRVResource{
+			Target:   srvName,
+			Priority: s.Priority,
+			Port:     s.Port,
+			Weight:   s.Weight,
+		})
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 // marshalResponse serializes the DNS response into a new buffer.
 func marshalResponse(resp *response) ([]byte, error) {
 	resp.Header.Response = true
@@ -712,6 +917,14 @@ func marshalResponse(resp *response) ([]byte, error) {
 
 	builder := dns.NewBuilder(nil, resp.Header)
 
+	// TODO(bradfitz): I'm not sure why this wasn't enabled
+	// before, but for now (2021-12-09) enable it at least when
+	// there's more than 1 record (which was never the case
+	// before), where it really helps.
+	if len(resp.IPs) > 1 {
+		builder.EnableCompression()
+	}
+
 	isSuccess := resp.Header.RCode == dns.RCodeSuccess
 
 	if resp.Question.Type != 0 || isSuccess {
@@ -738,13 +951,24 @@ func marshalResponse(resp *response) ([]byte, error) {
 
 	switch resp.Question.Type {
 	case dns.TypeA, dns.TypeAAAA, dns.TypeALL:
-		if resp.IP.Is4() {
-			err = marshalARecord(resp.Question.Name, resp.IP, &builder)
-		} else if resp.IP.Is6() {
-			err = marshalAAAARecord(resp.Question.Name, resp.IP, &builder)
+		if err := marshalIP(resp.Question.Name, resp.IP, &builder); err != nil {
+			return nil, err
+		}
+		for _, ip := range resp.IPs {
+			if err := marshalIP(resp.Question.Name, ip, &builder); err != nil {
+				return nil, err
+			}
 		}
 	case dns.TypePTR:
 		err = marshalPTRRecord(resp.Question.Name, resp.Name, &builder)
+	case dns.TypeTXT:
+		err = marshalTXT(resp.Question.Name, resp.TXT, &builder)
+	case dns.TypeCNAME:
+		err = marshalCNAME(resp.Question.Name, resp.CNAME, &builder)
+	case dns.TypeSRV:
+		err = marshalSRV(resp.Question.Name, resp.SRVs, &builder)
+	case dns.TypeNS:
+		err = marshalNS(resp.Question.Name, resp.NSs, &builder)
 	}
 	if err != nil {
 		return nil, err
@@ -926,6 +1150,37 @@ func (r *Resolver) respond(query []byte) ([]byte, error) {
 	return marshalResponse(resp)
 }
 
+// unARPA maps from "4.4.8.8.in-addr.arpa." to "8.8.4.4", etc.
+func unARPA(a string) (ipStr string, ok bool) {
+	const suf4 = ".in-addr.arpa."
+	if strings.HasSuffix(a, suf4) {
+		s := strings.TrimSuffix(a, suf4)
+		// Parse and reverse octets.
+		ip, err := netaddr.ParseIP(s)
+		if err != nil || !ip.Is4() {
+			return "", false
+		}
+		a4 := ip.As4()
+		return netaddr.IPv4(a4[3], a4[2], a4[1], a4[0]).String(), true
+	}
+	const suf6 = ".ip6.arpa."
+	if len(a) == len("e.0.0.2.0.0.0.0.0.0.0.0.0.0.0.0.b.0.8.0.a.0.0.4.0.b.8.f.7.0.6.2.ip6.arpa.") &&
+		strings.HasSuffix(a, suf6) {
+		var hx [32]byte
+		var a16 [16]byte
+		for i := range hx {
+			hx[31-i] = a[i*2]
+			if a[i*2+1] != '.' {
+				return "", false
+			}
+		}
+		hex.Decode(a16[:], hx[:])
+		return netaddr.IPFrom16(a16).String(), true
+	}
+	return "", false
+
+}
+
 var (
 	metricDNSQueryLocal       = clientmetric.NewCounter("dns_query_local")
 	metricDNSQueryErrorClosed = clientmetric.NewCounter("dns_query_local_error_closed")

+ 124 - 0
net/dns/resolver/tsdns_server_test.go

@@ -6,6 +6,7 @@ package resolver
 
 import (
 	"fmt"
+	"net"
 	"strings"
 	"testing"
 
@@ -179,6 +180,129 @@ var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg)
 	w.WriteMsg(m)
 })
 
+// weirdoGoCNAMEHandler returns a DNS handler that satisfies
+// Go's weird Resolver.LookupCNAME (read its godoc carefully!).
+//
+// This doesn't even return a CNAME record, because that's not
+// what Go looks for.
+func weirdoGoCNAMEHandler(target string) dns.HandlerFunc {
+	return func(w dns.ResponseWriter, req *dns.Msg) {
+		m := new(dns.Msg)
+		m.SetReply(req)
+		question := req.Question[0]
+
+		switch question.Qtype {
+		case dns.TypeA:
+			m.Answer = append(m.Answer, &dns.CNAME{
+				Hdr: dns.RR_Header{
+					Name:   target,
+					Rrtype: dns.TypeCNAME,
+					Class:  dns.ClassINET,
+					Ttl:    600,
+				},
+				Target: target,
+			})
+		case dns.TypeAAAA:
+			m.Answer = append(m.Answer, &dns.AAAA{
+				Hdr: dns.RR_Header{
+					Name:   target,
+					Rrtype: dns.TypeAAAA,
+					Class:  dns.ClassINET,
+					Ttl:    600,
+				},
+				AAAA: net.ParseIP("1::2"),
+			})
+		}
+		w.WriteMsg(m)
+	}
+}
+
+// dnsHandler returns a handler that replies with the answers/options
+// provided.
+//
+// Types supported: netaddr.IP.
+func dnsHandler(answers ...interface{}) dns.HandlerFunc {
+	return func(w dns.ResponseWriter, req *dns.Msg) {
+		m := new(dns.Msg)
+		m.SetReply(req)
+		if len(req.Question) != 1 {
+			panic("not a single-question request")
+		}
+		m.RecursionAvailable = true // to stop net package's errLameReferral on empty replies
+
+		question := req.Question[0]
+		for _, a := range answers {
+			switch a := a.(type) {
+			default:
+				panic(fmt.Sprintf("unsupported dnsHandler arg %T", a))
+			case netaddr.IP:
+				ip := a
+				if ip.Is4() {
+					m.Answer = append(m.Answer, &dns.A{
+						Hdr: dns.RR_Header{
+							Name:   question.Name,
+							Rrtype: dns.TypeA,
+							Class:  dns.ClassINET,
+						},
+						A: ip.IPAddr().IP,
+					})
+				} else if ip.Is6() {
+					m.Answer = append(m.Answer, &dns.AAAA{
+						Hdr: dns.RR_Header{
+							Name:   question.Name,
+							Rrtype: dns.TypeAAAA,
+							Class:  dns.ClassINET,
+						},
+						AAAA: ip.IPAddr().IP,
+					})
+				}
+			case dns.PTR:
+				ptr := a
+				ptr.Hdr = dns.RR_Header{
+					Name:   question.Name,
+					Rrtype: dns.TypePTR,
+					Class:  dns.ClassINET,
+				}
+				m.Answer = append(m.Answer, &ptr)
+			case dns.CNAME:
+				c := a
+				c.Hdr = dns.RR_Header{
+					Name:   question.Name,
+					Rrtype: dns.TypeCNAME,
+					Class:  dns.ClassINET,
+					Ttl:    600,
+				}
+				m.Answer = append(m.Answer, &c)
+			case dns.TXT:
+				txt := a
+				txt.Hdr = dns.RR_Header{
+					Name:   question.Name,
+					Rrtype: dns.TypeTXT,
+					Class:  dns.ClassINET,
+				}
+				m.Answer = append(m.Answer, &txt)
+			case dns.SRV:
+				srv := a
+				srv.Hdr = dns.RR_Header{
+					Name:   question.Name,
+					Rrtype: dns.TypeSRV,
+					Class:  dns.ClassINET,
+				}
+				m.Answer = append(m.Answer, &srv)
+			case dns.NS:
+				rr := a
+				rr.Hdr = dns.RR_Header{
+					Name:   question.Name,
+					Rrtype: dns.TypeNS,
+					Class:  dns.ClassINET,
+				}
+				m.Answer = append(m.Answer, &rr)
+			}
+		}
+		w.WriteMsg(m)
+	}
+}
+
 func serveDNS(tb testing.TB, addr string, records ...interface{}) *dns.Server {
 	if len(records)%2 != 0 {
 		panic("must have an even number of record values")

+ 388 - 0
net/dns/resolver/tsdns_test.go

@@ -6,16 +6,22 @@ package resolver
 
 import (
 	"bytes"
+	"context"
 	"encoding/hex"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"math/rand"
 	"net"
+	"reflect"
 	"runtime"
 	"strconv"
 	"strings"
 	"testing"
+	"time"
 
+	miekdns "github.com/miekg/dns"
+	"golang.org/x/net/dns/dnsmessage"
 	dns "golang.org/x/net/dns/dnsmessage"
 	"inet.af/netaddr"
 	"tailscale.com/net/tsdial"
@@ -43,6 +49,8 @@ var dnsCfg = Config{
 
 const noEdns = 0
 
+const dnsHeaderLen = 12
+
 func dnspacket(domain dnsname.FQDN, tp dns.Type, ednsSize uint16) []byte {
 	var dnsHeader dns.Header
 	question := dns.Question{
@@ -1093,3 +1101,383 @@ func TestForwardLinkSelection(t *testing.T) {
 type linkSelFunc func(ip netaddr.IP) string
 
 func (f linkSelFunc) PickLink(ip netaddr.IP) string { return f(ip) }
+
+func TestHandleExitNodeDNSQueryWithNetPkg(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("skipping test on Windows; waiting for golang.org/issue/33097")
+	}
+
+	records := []interface{}{
+		"no-records.test.",
+		dnsHandler(),
+
+		"one-a.test.",
+		dnsHandler(netaddr.MustParseIP("1.2.3.4")),
+
+		"two-a.test.",
+		dnsHandler(netaddr.MustParseIP("1.2.3.4"), netaddr.MustParseIP("5.6.7.8")),
+
+		"one-aaaa.test.",
+		dnsHandler(netaddr.MustParseIP("1::2")),
+
+		"two-aaaa.test.",
+		dnsHandler(netaddr.MustParseIP("1::2"), netaddr.MustParseIP("3::4")),
+
+		"nx-domain.test.",
+		resolveToNXDOMAIN,
+
+		"4.3.2.1.in-addr.arpa.",
+		dnsHandler(miekdns.PTR{Ptr: "foo.com."}),
+
+		"cname.test.",
+		weirdoGoCNAMEHandler("the-target.foo."),
+
+		"txt.test.",
+		dnsHandler(
+			miekdns.TXT{Txt: []string{"txt1=one"}},
+			miekdns.TXT{Txt: []string{"txt2=two"}},
+			miekdns.TXT{Txt: []string{"txt3=three"}},
+		),
+
+		"srv.test.",
+		dnsHandler(
+			miekdns.SRV{
+				Priority: 1,
+				Weight:   2,
+				Port:     3,
+				Target:   "foo.com.",
+			},
+			miekdns.SRV{
+				Priority: 4,
+				Weight:   5,
+				Port:     6,
+				Target:   "bar.com.",
+			},
+		),
+
+		"ns.test.",
+		dnsHandler(miekdns.NS{Ns: "ns1.foo."}, miekdns.NS{Ns: "ns2.bar."}),
+	}
+	v4server := serveDNS(t, "127.0.0.1:0", records...)
+	defer v4server.Shutdown()
+
+	// backendResolver is the resolver between
+	// handleExitNodeDNSQueryWithNetPkg and its upstream resolver,
+	// which in this test's case is the miekg/dns test DNS server
+	// (v4server).
+	backResolver := &net.Resolver{
+		PreferGo: true,
+		Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
+			var d net.Dialer
+			return d.DialContext(ctx, "udp", v4server.PacketConn.LocalAddr().String())
+		},
+	}
+
+	t.Run("no_such_host", func(t *testing.T) {
+		res, err := handleExitNodeDNSQueryWithNetPkg(context.Background(), backResolver, &response{
+			Header: dnsmessage.Header{
+				ID:       123,
+				Response: true,
+				OpCode:   0, // query
+			},
+			Question: dnsmessage.Question{
+				Name:  dnsmessage.MustNewName("nx-domain.test."),
+				Type:  dnsmessage.TypeA,
+				Class: dnsmessage.ClassINET,
+			},
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+		if len(res) < dnsHeaderLen {
+			t.Fatal("short reply")
+		}
+		rcode := dns.RCode(res[3] & 0x0f)
+		if rcode != dns.RCodeNameError {
+			t.Errorf("RCode = %v; want dns.RCodeNameError", rcode)
+			t.Logf("Response was: %q", res)
+		}
+	})
+
+	matchPacked := func(want string) func(t testing.TB, got []byte) {
+		return func(t testing.TB, got []byte) {
+			if string(got) == want {
+				return
+			}
+			t.Errorf("unexpected reply.\n got: %q\nwant: %q\n", got, want)
+			t.Errorf("\nin hex:\n got: % 2x\nwant: % 2x\n", got, want)
+		}
+	}
+
+	tests := []struct {
+		Type  dnsmessage.Type
+		Name  string
+		Check func(t testing.TB, got []byte)
+	}{
+		{
+			Type:  dnsmessage.TypeA,
+			Name:  "one-a.test.",
+			Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x01\x00\x00\x00\x00\x05one-a\x04test\x00\x00\x01\x00\x01\x05one-a\x04test\x00\x00\x01\x00\x01\x00\x00\x02X\x00\x04\x01\x02\x03\x04"),
+		},
+		{
+			Type:  dnsmessage.TypeA,
+			Name:  "two-a.test.",
+			Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x02\x00\x00\x00\x00\x05two-a\x04test\x00\x00\x01\x00\x01\xc0\f\x00\x01\x00\x01\x00\x00\x02X\x00\x04\x01\x02\x03\x04\xc0\f\x00\x01\x00\x01\x00\x00\x02X\x00\x04\x05\x06\a\b"),
+		},
+		{
+			Type:  dnsmessage.TypeAAAA,
+			Name:  "one-aaaa.test.",
+			Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x01\x00\x00\x00\x00\bone-aaaa\x04test\x00\x00\x1c\x00\x01\bone-aaaa\x04test\x00\x00\x1c\x00\x01\x00\x00\x02X\x00\x10\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"),
+		},
+		{
+			Type:  dnsmessage.TypeAAAA,
+			Name:  "two-aaaa.test.",
+			Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x02\x00\x00\x00\x00\btwo-aaaa\x04test\x00\x00\x1c\x00\x01\xc0\f\x00\x1c\x00\x01\x00\x00\x02X\x00\x10\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc0\f\x00\x1c\x00\x01\x00\x00\x02X\x00\x10\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04"),
+		},
+		{
+			Type:  dnsmessage.TypePTR,
+			Name:  "4.3.2.1.in-addr.arpa.",
+			Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x01\x00\x00\x00\x00\x014\x013\x012\x011\ain-addr\x04arpa\x00\x00\f\x00\x01\x014\x013\x012\x011\ain-addr\x04arpa\x00\x00\f\x00\x01\x00\x00\x02X\x00\t\x03foo\x03com\x00"),
+		},
+		{
+			Type:  dnsmessage.TypeCNAME,
+			Name:  "cname.test.",
+			Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x01\x00\x00\x00\x00\x05cname\x04test\x00\x00\x05\x00\x01\x05cname\x04test\x00\x00\x05\x00\x01\x00\x00\x02X\x00\x10\nthe-target\x03foo\x00"),
+		},
+
+		// No records of various types
+		{
+			Type:  dnsmessage.TypeA,
+			Name:  "no-records.test.",
+			Check: matchPacked("\x00{\x84\x03\x00\x01\x00\x00\x00\x00\x00\x00\nno-records\x04test\x00\x00\x01\x00\x01"),
+		},
+		{
+			Type:  dnsmessage.TypeAAAA,
+			Name:  "no-records.test.",
+			Check: matchPacked("\x00{\x84\x03\x00\x01\x00\x00\x00\x00\x00\x00\nno-records\x04test\x00\x00\x1c\x00\x01"),
+		},
+		{
+			Type:  dnsmessage.TypeCNAME,
+			Name:  "no-records.test.",
+			Check: matchPacked("\x00{\x84\x03\x00\x01\x00\x00\x00\x00\x00\x00\nno-records\x04test\x00\x00\x05\x00\x01"),
+		},
+		{
+			Type:  dnsmessage.TypeSRV,
+			Name:  "no-records.test.",
+			Check: matchPacked("\x00{\x84\x03\x00\x01\x00\x00\x00\x00\x00\x00\nno-records\x04test\x00\x00!\x00\x01"),
+		},
+		{
+			Type:  dnsmessage.TypeTXT,
+			Name:  "txt.test.",
+			Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x03\x00\x00\x00\x00\x03txt\x04test\x00\x00\x10\x00\x01\x03txt\x04test\x00\x00\x10\x00\x01\x00\x00\x02X\x00\t\btxt1=one\x03txt\x04test\x00\x00\x10\x00\x01\x00\x00\x02X\x00\t\btxt2=two\x03txt\x04test\x00\x00\x10\x00\x01\x00\x00\x02X\x00\v\ntxt3=three"),
+		},
+		{
+			Type:  dnsmessage.TypeSRV,
+			Name:  "srv.test.",
+			Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x02\x00\x00\x00\x00\x03srv\x04test\x00\x00!\x00\x01\x03srv\x04test\x00\x00!\x00\x01\x00\x00\x02X\x00\x0f\x00\x01\x00\x02\x00\x03\x03foo\x03com\x00\x03srv\x04test\x00\x00!\x00\x01\x00\x00\x02X\x00\x0f\x00\x04\x00\x05\x00\x06\x03bar\x03com\x00"),
+		},
+		{
+			Type:  dnsmessage.TypeNS,
+			Name:  "ns.test.",
+			Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x02\x00\x00\x00\x00\x02ns\x04test\x00\x00\x02\x00\x01\x02ns\x04test\x00\x00\x02\x00\x01\x00\x00\x02X\x00\t\x03ns1\x03foo\x00\x02ns\x04test\x00\x00\x02\x00\x01\x00\x00\x02X\x00\t\x03ns2\x03bar\x00"),
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(fmt.Sprintf("%v_%v", tt.Type, strings.Trim(tt.Name, ".")), func(t *testing.T) {
+			got, err := handleExitNodeDNSQueryWithNetPkg(context.Background(), backResolver, &response{
+				Header: dnsmessage.Header{
+					ID:       123,
+					Response: true,
+					OpCode:   0, // query
+				},
+				Question: dnsmessage.Question{
+					Name:  dnsmessage.MustNewName(tt.Name),
+					Type:  tt.Type,
+					Class: dnsmessage.ClassINET,
+				},
+			})
+			if err != nil {
+				t.Fatal(err)
+			}
+			if len(got) < dnsHeaderLen {
+				t.Errorf("short record")
+			}
+			if tt.Check != nil {
+				tt.Check(t, got)
+				if t.Failed() {
+					t.Errorf("Got: %q\nIn hex: % 02x", got, got)
+				}
+			}
+		})
+	}
+
+	wrapRes := newWrapResolver(backResolver)
+	ctx := context.Background()
+
+	t.Run("wrap_ip_a", func(t *testing.T) {
+		ips, err := wrapRes.LookupIP(ctx, "ip", "two-a.test.")
+		if err != nil {
+			t.Fatal(err)
+		}
+		if got, want := ips, []net.IP{
+			net.ParseIP("1.2.3.4").To4(),
+			net.ParseIP("5.6.7.8").To4(),
+		}; !reflect.DeepEqual(got, want) {
+			t.Errorf("LookupIP = %v; want %v", got, want)
+		}
+	})
+
+	t.Run("wrap_ip_aaaa", func(t *testing.T) {
+		ips, err := wrapRes.LookupIP(ctx, "ip", "two-aaaa.test.")
+		if err != nil {
+			t.Fatal(err)
+		}
+		if got, want := ips, []net.IP{
+			net.ParseIP("1::2"),
+			net.ParseIP("3::4"),
+		}; !reflect.DeepEqual(got, want) {
+			t.Errorf("LookupIP(v6) = %v; want %v", got, want)
+		}
+	})
+
+	t.Run("wrap_ip_nx", func(t *testing.T) {
+		ips, err := wrapRes.LookupIP(ctx, "ip", "nx-domain.test.")
+		if !isGoNoSuchHostError(err) {
+			t.Errorf("no NX domain = (%v, %v); want no host error", ips, err)
+		}
+	})
+
+	t.Run("wrap_srv", func(t *testing.T) {
+		_, srvs, err := wrapRes.LookupSRV(ctx, "", "", "srv.test.")
+		if err != nil {
+			t.Fatal(err)
+		}
+		if got, want := srvs, []*net.SRV{
+			{
+				Target:   "foo.com.",
+				Priority: 1,
+				Weight:   2,
+				Port:     3,
+			},
+			{
+				Target:   "bar.com.",
+				Priority: 4,
+				Weight:   5,
+				Port:     6,
+			},
+		}; !reflect.DeepEqual(got, want) {
+			jgot, _ := json.Marshal(got)
+			jwant, _ := json.Marshal(want)
+			t.Errorf("SRV = %s; want %s", jgot, jwant)
+		}
+	})
+
+	t.Run("wrap_txt", func(t *testing.T) {
+		txts, err := wrapRes.LookupTXT(ctx, "txt.test.")
+		if err != nil {
+			t.Fatal(err)
+		}
+		if got, want := txts, []string{"txt1=one", "txt2=two", "txt3=three"}; !reflect.DeepEqual(got, want) {
+			t.Errorf("TXT = %q; want %q", got, want)
+		}
+	})
+
+	t.Run("wrap_ns", func(t *testing.T) {
+		nss, err := wrapRes.LookupNS(ctx, "ns.test.")
+		if err != nil {
+			t.Fatal(err)
+		}
+		if got, want := nss, []*net.NS{
+			{Host: "ns1.foo."},
+			{Host: "ns2.bar."},
+		}; !reflect.DeepEqual(got, want) {
+			jgot, _ := json.Marshal(got)
+			jwant, _ := json.Marshal(want)
+			t.Errorf("NS = %s; want %s", jgot, jwant)
+		}
+	})
+}
+
+// newWrapResolver returns a resolver that uses r (via handleExitNodeDNSQueryWithNetPkg)
+// to make DNS requests.
+func newWrapResolver(r *net.Resolver) *net.Resolver {
+	if runtime.GOOS == "windows" {
+		panic("doesn't work on Windows") // golang.org/issue/33097
+	}
+	return &net.Resolver{
+		PreferGo: true,
+		Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
+			return &wrapResolverConn{ctx: ctx, r: r}, nil
+		},
+	}
+}
+
+type wrapResolverConn struct {
+	ctx context.Context
+	r   *net.Resolver
+	buf bytes.Buffer
+}
+
+var _ net.PacketConn = (*wrapResolverConn)(nil)
+
+func (*wrapResolverConn) Close() error                       { return nil }
+func (*wrapResolverConn) LocalAddr() net.Addr                { return fakeAddr{} }
+func (*wrapResolverConn) RemoteAddr() net.Addr               { return fakeAddr{} }
+func (*wrapResolverConn) SetDeadline(t time.Time) error      { return nil }
+func (*wrapResolverConn) SetReadDeadline(t time.Time) error  { return nil }
+func (*wrapResolverConn) SetWriteDeadline(t time.Time) error { return nil }
+
+func (a *wrapResolverConn) Read(p []byte) (n int, err error) {
+	n, _, err = a.ReadFrom(p)
+	return
+}
+
+func (a *wrapResolverConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+	n, err = a.buf.Read(p)
+	return n, fakeAddr{}, err
+}
+
+func (a *wrapResolverConn) Write(packet []byte) (n int, err error) {
+	return a.WriteTo(packet, fakeAddr{})
+}
+
+func (a *wrapResolverConn) WriteTo(q []byte, _ net.Addr) (n int, err error) {
+	resp := parseExitNodeQuery(q)
+	if resp == nil {
+		return 0, errors.New("bad query")
+	}
+	res, err := handleExitNodeDNSQueryWithNetPkg(context.Background(), a.r, resp)
+	if err != nil {
+		return 0, err
+	}
+	a.buf.Write(res)
+	return len(q), nil
+}
+
+type fakeAddr struct{}
+
+func (fakeAddr) Network() string { return "unused" }
+func (fakeAddr) String() string  { return "unused-todoAddr" }
+
+func TestUnARPA(t *testing.T) {
+	tests := []struct {
+		in, want string
+	}{
+		{"", ""},
+		{"bad", ""},
+		{"4.4.8.8.in-addr.arpa.", "8.8.4.4"},
+		{".in-addr.arpa.", ""},
+		{"e.0.0.2.0.0.0.0.0.0.0.0.0.0.0.0.b.0.8.0.a.0.0.4.0.b.8.f.7.0.6.2.ip6.arpa.", "2607:f8b0:400a:80b::200e"},
+		{".ip6.arpa.", ""},
+	}
+	for _, tt := range tests {
+		got, ok := unARPA(tt.in)
+		if ok != (got != "") {
+			t.Errorf("inconsistent results for %q: (%q, %v)", tt.in, got, ok)
+		}
+		if got != tt.want {
+			t.Errorf("unARPA(%q) = %q; want %q", tt.in, got, tt.want)
+		}
+	}
+}