Ver código fonte

net/dns, wgengine: implement DNS over TCP (#4598)

* net/dns, wgengine: implement DNS over TCP

Signed-off-by: Tom DNetto <[email protected]>

* wgengine/netstack: intercept only relevant port/protocols to quad-100

Signed-off-by: Tom DNetto <[email protected]>
Tom 3 anos atrás
pai
commit
d1d6ab068e
3 arquivos alterados com 273 adições e 0 exclusões
  1. 119 0
      net/dns/manager.go
  2. 136 0
      net/dns/manager_tcp_test.go
  3. 18 0
      wgengine/netstack/netstack.go

+ 119 - 0
net/dns/manager.go

@@ -7,7 +7,9 @@ package dns
 import (
 	"bufio"
 	"context"
+	"encoding/binary"
 	"errors"
+	"io"
 	"net"
 	"runtime"
 	"sync/atomic"
@@ -346,6 +348,123 @@ func (m *Manager) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([]
 	return m.resolver.Query(ctx, bs, from)
 }
 
+const (
+	// RFC 7766 6.2 recommends connection reuse & request pipelining
+	// be undertaken, and the connection be closed by the server
+	// using an idle timeout on the order of seconds.
+	idleTimeoutTCP = 45 * time.Second
+	// The RFCs don't specify the max size of a TCP-based DNS query,
+	// but we want to keep this reasonable. Given payloads are typically
+	// much larger and all known client send a single query, I've arbitrarily
+	// chosen 2k.
+	maxReqSizeTCP = 2048
+)
+
+// dnsTCPSession services DNS requests sent over TCP.
+type dnsTCPSession struct {
+	m *Manager
+
+	conn    net.Conn
+	srcAddr netaddr.IPPort
+
+	readClosing  chan struct{}
+	responses    chan []byte // DNS replies pending writing
+
+	ctx      context.Context
+	closeCtx context.CancelFunc
+}
+
+func (s *dnsTCPSession) handleWrites() {
+	defer s.conn.Close()
+	defer close(s.responses)
+	defer s.closeCtx()
+
+	for {
+		select {
+		case <-s.readClosing:
+			return // connection closed or timeout, teardown time
+
+		case resp := <-s.responses:
+			s.conn.SetWriteDeadline(time.Now().Add(idleTimeoutTCP))
+			if err := binary.Write(s.conn, binary.BigEndian, uint16(len(resp))); err != nil {
+				s.m.logf("tcp write (len): %v", err)
+				return
+			}
+			if _, err := s.conn.Write(resp); err != nil {
+				s.m.logf("tcp write (response): %v", err)
+				return
+			}
+		}
+	}
+}
+
+func (s *dnsTCPSession) handleQuery(q []byte) {
+	resp, err := s.m.Query(s.ctx, q, s.srcAddr)
+	if err != nil {
+		s.m.logf("tcp query: %v", err)
+		return
+	}
+
+	select {
+	case <-s.ctx.Done():
+	case s.responses <- resp:
+	}
+}
+
+func (s *dnsTCPSession) handleReads() {
+	defer close(s.readClosing)
+
+	for {
+		select {
+		case <-s.ctx.Done():
+			return
+
+		default:
+			s.conn.SetReadDeadline(time.Now().Add(idleTimeoutTCP))
+			var reqLen uint16
+			if err := binary.Read(s.conn, binary.BigEndian, &reqLen); err != nil {
+				if err == io.EOF || err == io.ErrClosedPipe {
+					return // connection closed nominally, we gucci
+				}
+				s.m.logf("tcp read (len): %v", err)
+				return
+			}
+			if int(reqLen) > maxReqSizeTCP {
+				s.m.logf("tcp request too large (%d > %d)", reqLen, maxReqSizeTCP)
+				return
+			}
+
+			buf := make([]byte, int(reqLen))
+			if _, err := io.ReadFull(s.conn, buf); err != nil {
+				s.m.logf("tcp read (payload): %v", err)
+				return
+			}
+
+			select {
+			case <-s.ctx.Done():
+				return
+			default:
+				go s.handleQuery(buf)
+			}
+		}
+	}
+}
+
+// HandleTCPConn implements magicDNS over TCP, taking a connection and
+// servicing DNS requests sent down it.
+func (m *Manager) HandleTCPConn(conn net.Conn, srcAddr netaddr.IPPort) {
+	s := dnsTCPSession{
+		m:            m,
+		conn:         conn,
+		srcAddr:      srcAddr,
+		responses:    make(chan []byte),
+		readClosing:  make(chan struct{}),
+	}
+	s.ctx, s.closeCtx = context.WithCancel(context.Background())
+	go s.handleReads()
+	s.handleWrites()
+}
+
 func (m *Manager) Down() error {
 	m.ctxCancel()
 	if err := m.os.Close(); err != nil {

+ 136 - 0
net/dns/manager_tcp_test.go

@@ -0,0 +1,136 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+	"encoding/binary"
+	"io"
+	"net"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	dns "golang.org/x/net/dns/dnsmessage"
+	"inet.af/netaddr"
+	"tailscale.com/net/tsdial"
+	"tailscale.com/util/dnsname"
+)
+
+func mkDNSRequest(domain dnsname.FQDN, tp dns.Type) []byte {
+	var dnsHeader dns.Header
+	question := dns.Question{
+		Name:  dns.MustNewName(domain.WithTrailingDot()),
+		Type:  tp,
+		Class: dns.ClassINET,
+	}
+
+	builder := dns.NewBuilder(nil, dnsHeader)
+	if err := builder.StartQuestions(); err != nil {
+		panic(err)
+	}
+	if err := builder.Question(question); err != nil {
+		panic(err)
+	}
+
+	if err := builder.StartAdditionals(); err != nil {
+		panic(err)
+	}
+
+	ednsHeader := dns.ResourceHeader{
+		Name:  dns.MustNewName("."),
+		Type:  dns.TypeOPT,
+		Class: dns.Class(4095),
+	}
+
+	if err := builder.OPTResource(ednsHeader, dns.OPTResource{}); err != nil {
+		panic(err)
+	}
+
+	payload, _ := builder.Finish()
+
+	return payload
+}
+
+func TestDNSOverTCP(t *testing.T) {
+	f := fakeOSConfigurator{
+		SplitDNS: true,
+		BaseConfig: OSConfig{
+			Nameservers:   mustIPs("8.8.8.8"),
+			SearchDomains: fqdns("coffee.shop"),
+		},
+	}
+	m := NewManager(t.Logf, &f, nil, new(tsdial.Dialer), nil)
+	m.resolver.TestOnlySetHook(f.SetResolver)
+	m.Set(Config{
+		Hosts: hosts(
+			"dave.ts.com.", "1.2.3.4",
+			"bradfitz.ts.com.", "2.3.4.5"),
+		Routes:        upstreams("ts.com", ""),
+		SearchDomains: fqdns("tailscale.com", "universe.tf"),
+	})
+	defer m.Down()
+
+	c, s := net.Pipe()
+	defer s.Close()
+	go m.HandleTCPConn(s, netaddr.IPPort{})
+	defer c.Close()
+
+	wantResults := map[dnsname.FQDN]string{
+		"dave.ts.com.":     "1.2.3.4",
+		"bradfitz.ts.com.": "2.3.4.5",
+	}
+
+	for domain, _ := range wantResults {
+		b := mkDNSRequest(domain, dns.TypeA)
+		binary.Write(c, binary.BigEndian, uint16(len(b)))
+		c.Write(b)
+	}
+
+	results := map[dnsname.FQDN]string{}
+	for i := 0; i < len(wantResults); i++ {
+		var respLength uint16
+		if err := binary.Read(c, binary.BigEndian, &respLength); err != nil {
+			t.Fatalf("reading len: %v", err)
+		}
+		resp := make([]byte, int(respLength))
+		if _, err := io.ReadFull(c, resp); err != nil {
+			t.Fatalf("reading data: %v", err)
+		}
+
+		var parser dns.Parser
+		if _, err := parser.Start(resp); err != nil {
+			t.Errorf("parser.Start() failed: %v", err)
+			continue
+		}
+		q, err := parser.Question()
+		if err != nil {
+			t.Errorf("parser.Question(): %v", err)
+			continue
+		}
+		if err := parser.SkipAllQuestions(); err != nil {
+			t.Errorf("parser.SkipAllQuestions(): %v", err)
+			continue
+		}
+		ah, err := parser.AnswerHeader()
+		if err != nil {
+			t.Errorf("parser.AnswerHeader(): %v", err)
+			continue
+		}
+		if ah.Type != dns.TypeA {
+			t.Errorf("unexpected answer type: got %v, want %v", ah.Type, dns.TypeA)
+			continue
+		}
+		res, err := parser.AResource()
+		if err != nil {
+			t.Errorf("parser.AResource(): %v", err)
+			continue
+		}
+		results[dnsname.FQDN(q.Name.String())] = net.IP(res.A[:]).String()
+	}
+	c.Close()
+
+	if diff := cmp.Diff(wantResults, results); diff != "" {
+		t.Errorf("wrong results (-got+want)\n%s", diff)
+	}
+}

+ 18 - 0
wgengine/netstack/netstack.go

@@ -373,6 +373,19 @@ func (ns *Impl) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) filter.Re
 	if dst := p.Dst.IP(); dst != magicDNSIP && dst != magicDNSIPv6 {
 		return filter.Accept
 	}
+	// Of traffic to the service IP, we only care about UDP 53, and TCP
+	// on port 80 & 53.
+	switch p.IPProto {
+	case ipproto.TCP:
+		if port := p.Dst.Port(); port != 53 && port != 80 {
+			return filter.Accept
+		}
+	case ipproto.UDP:
+		if port := p.Dst.Port(); port != 53 {
+			return filter.Accept
+		}
+	}
+
 
 	var pn tcpip.NetworkProtocolNumber
 	switch p.IPVersion {
@@ -758,6 +771,11 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
 	// block until the TCP handshake is complete.
 	c := gonet.NewTCPConn(&wq, ep)
 
+	if reqDetails.LocalPort == 53 && (dialIP == magicDNSIP || dialIP == magicDNSIPv6) {
+		go ns.dns.HandleTCPConn(c, netaddr.IPPortFrom(clientRemoteIP, reqDetails.RemotePort))
+		return
+	}
+
 	if ns.lb != nil {
 		if reqDetails.LocalPort == 22 && ns.processSSH() && ns.isLocalIP(dialIP) {
 			if err := ns.lb.HandleSSHConn(c); err != nil {