|
|
@@ -21,6 +21,7 @@ import (
|
|
|
"time"
|
|
|
|
|
|
dns "golang.org/x/net/dns/dnsmessage"
|
|
|
+ "tailscale.com/control/controlknobs"
|
|
|
"tailscale.com/envknob"
|
|
|
"tailscale.com/net/dns/publicdns"
|
|
|
"tailscale.com/net/dnscache"
|
|
|
@@ -68,6 +69,10 @@ const (
|
|
|
// DNS queries to the "fallback" DNS server IP for a known provider
|
|
|
// (e.g. how long to wait to query Google's 8.8.4.4 after 8.8.8.8).
|
|
|
wellKnownHostBackupDelay = 200 * time.Millisecond
|
|
|
+
|
|
|
+ // tcpQueryTimeout is the timeout for a DNS query performed over TCP.
|
|
|
+ // It matches the default 5sec timeout of the 'dig' utility.
|
|
|
+ tcpQueryTimeout = 5 * time.Second
|
|
|
)
|
|
|
|
|
|
// txid identifies a DNS transaction.
|
|
|
@@ -180,6 +185,8 @@ type forwarder struct {
|
|
|
linkSel ForwardLinkSelector // TODO(bradfitz): remove this when tsdial.Dialer absorbs it
|
|
|
dialer *tsdial.Dialer
|
|
|
|
|
|
+ controlKnobs *controlknobs.Knobs // or nil
|
|
|
+
|
|
|
ctx context.Context // good until Close
|
|
|
ctxCancel context.CancelFunc // closes ctx
|
|
|
|
|
|
@@ -206,12 +213,13 @@ func init() {
|
|
|
rand.Seed(time.Now().UnixNano())
|
|
|
}
|
|
|
|
|
|
-func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer) *forwarder {
|
|
|
+func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer, knobs *controlknobs.Knobs) *forwarder {
|
|
|
f := &forwarder{
|
|
|
- logf: logger.WithPrefix(logf, "forward: "),
|
|
|
- netMon: netMon,
|
|
|
- linkSel: linkSel,
|
|
|
- dialer: dialer,
|
|
|
+ logf: logger.WithPrefix(logf, "forward: "),
|
|
|
+ netMon: netMon,
|
|
|
+ linkSel: linkSel,
|
|
|
+ dialer: dialer,
|
|
|
+ controlKnobs: knobs,
|
|
|
}
|
|
|
f.ctx, f.ctxCancel = context.WithCancel(context.Background())
|
|
|
return f
|
|
|
@@ -443,7 +451,10 @@ func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client,
|
|
|
return res, err
|
|
|
}
|
|
|
|
|
|
-var verboseDNSForward = envknob.RegisterBool("TS_DEBUG_DNS_FORWARD_SEND")
|
|
|
+var (
|
|
|
+ verboseDNSForward = envknob.RegisterBool("TS_DEBUG_DNS_FORWARD_SEND")
|
|
|
+ skipTCPRetry = envknob.RegisterBool("TS_DNS_FORWARD_SKIP_TCP_RETRY")
|
|
|
+)
|
|
|
|
|
|
// send sends packet to dst. It is best effort.
|
|
|
//
|
|
|
@@ -477,10 +488,49 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe
|
|
|
return nil, fmt.Errorf("tls:// resolvers not supported yet")
|
|
|
}
|
|
|
|
|
|
- return f.sendUDP(ctx, fq, rr)
|
|
|
+ ret, err = f.sendUDP(ctx, fq, rr)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ if !truncatedFlagSet(ret) {
|
|
|
+ // Successful, non-truncated response; return it.
|
|
|
+ return ret, nil
|
|
|
+ }
|
|
|
+ if fq.family == "udp" {
|
|
|
+ // If this is a UDP query, return it regardless of whether the
|
|
|
+ // response is truncated or not; the client can retry
|
|
|
+ // communicating with tailscaled over TCP. There's no point
|
|
|
+ // falling back to TCP for a truncated query if we can't return
|
|
|
+ // the results to the client.
|
|
|
+ return ret, nil
|
|
|
+ }
|
|
|
+ if skipTCPRetry() || (f.controlKnobs != nil && f.controlKnobs.DisableDNSForwarderTCPRetries.Load()) {
|
|
|
+ // Envknob or control knob disabled the TCP retry behaviour;
|
|
|
+ // just return what we have.
|
|
|
+ return ret, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // Don't retry if our context is done.
|
|
|
+ if err := ctx.Err(); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ // Retry over TCP, best-effort; return the truncated UDP response if we
|
|
|
+ // cannot query via TCP.
|
|
|
+ if ret2, err2 := f.sendTCP(ctx, fq, rr); err2 == nil {
|
|
|
+ if verboseDNSForward() {
|
|
|
+ f.logf("forwarder.send(%q): successfully retried via TCP", rr.name.Addr)
|
|
|
+ }
|
|
|
+ return ret2, nil
|
|
|
+ } else if verboseDNSForward() {
|
|
|
+ f.logf("forwarder.send(%q): could not retry via TCP: %v", rr.name.Addr, err2)
|
|
|
+ }
|
|
|
+ return ret, nil
|
|
|
}
|
|
|
|
|
|
var errServerFailure = errors.New("response code indicates server issue")
|
|
|
+var errTxIDMismatch = errors.New("txid doesn't match")
|
|
|
|
|
|
func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
|
|
|
ipp, ok := rr.name.IPPort()
|
|
|
@@ -545,7 +595,7 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn
|
|
|
txid := getTxID(out)
|
|
|
if txid != fq.txid {
|
|
|
metricDNSFwdUDPErrorTxID.Add(1)
|
|
|
- return nil, errors.New("txid doesn't match")
|
|
|
+ return nil, errTxIDMismatch
|
|
|
}
|
|
|
rcode := getRCode(out)
|
|
|
// don't forward transient errors back to the client when the server fails
|
|
|
@@ -577,6 +627,92 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn
|
|
|
return out, nil
|
|
|
}
|
|
|
|
|
|
+func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
|
|
|
+ ipp, ok := rr.name.IPPort()
|
|
|
+ if !ok {
|
|
|
+ metricDNSFwdErrorType.Add(1)
|
|
|
+ return nil, fmt.Errorf("unrecognized resolver type %q", rr.name.Addr)
|
|
|
+ }
|
|
|
+ metricDNSFwdTCP.Add(1)
|
|
|
+ ctx = sockstats.WithSockStats(ctx, sockstats.LabelDNSForwarderTCP, f.logf)
|
|
|
+
|
|
|
+ // Specify the exact family to work around https://github.com/golang/go/issues/52264
|
|
|
+ tcpFam := "tcp4"
|
|
|
+ if ipp.Addr().Is6() {
|
|
|
+ tcpFam = "tcp6"
|
|
|
+ }
|
|
|
+
|
|
|
+ ctx, cancel := context.WithTimeout(ctx, tcpQueryTimeout)
|
|
|
+ defer cancel()
|
|
|
+
|
|
|
+ conn, err := f.dialer.SystemDial(ctx, tcpFam, ipp.String())
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ defer conn.Close()
|
|
|
+
|
|
|
+ fq.closeOnCtxDone.Add(conn)
|
|
|
+ defer fq.closeOnCtxDone.Remove(conn)
|
|
|
+
|
|
|
+ ctxOrErr := func(err2 error) ([]byte, error) {
|
|
|
+ if err := ctx.Err(); err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ return nil, err2
|
|
|
+ }
|
|
|
+
|
|
|
+ // Write the query to the server.
|
|
|
+ query := make([]byte, len(fq.packet)+2)
|
|
|
+ binary.BigEndian.PutUint16(query, uint16(len(fq.packet)))
|
|
|
+ copy(query[2:], fq.packet)
|
|
|
+ if _, err := conn.Write(query); err != nil {
|
|
|
+ metricDNSFwdTCPErrorWrite.Add(1)
|
|
|
+ return ctxOrErr(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ metricDNSFwdTCPWrote.Add(1)
|
|
|
+
|
|
|
+ // Read the header length back from the server
|
|
|
+ var length uint16
|
|
|
+ if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
|
|
|
+ metricDNSFwdTCPErrorRead.Add(1)
|
|
|
+ return ctxOrErr(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Now read the response
|
|
|
+ out := make([]byte, length)
|
|
|
+ n, err := io.ReadFull(conn, out)
|
|
|
+ if err != nil {
|
|
|
+ metricDNSFwdTCPErrorRead.Add(1)
|
|
|
+ return ctxOrErr(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if n < int(length) {
|
|
|
+ f.logf("sendTCP: packet too small (%d bytes)", n)
|
|
|
+ return nil, io.ErrUnexpectedEOF
|
|
|
+ }
|
|
|
+ out = out[:n]
|
|
|
+ txid := getTxID(out)
|
|
|
+ if txid != fq.txid {
|
|
|
+ metricDNSFwdTCPErrorTxID.Add(1)
|
|
|
+ return nil, errTxIDMismatch
|
|
|
+ }
|
|
|
+
|
|
|
+ rcode := getRCode(out)
|
|
|
+
|
|
|
+ // don't forward transient errors back to the client when the server fails
|
|
|
+ if rcode == dns.RCodeServerFailure {
|
|
|
+ f.logf("sendTCP: response code indicating server failure: %d", rcode)
|
|
|
+ metricDNSFwdTCPErrorServer.Add(1)
|
|
|
+ return nil, errServerFailure
|
|
|
+ }
|
|
|
+
|
|
|
+ // TODO(andrew): do we need to do this?
|
|
|
+ //clampEDNSSize(out, maxResponseBytes)
|
|
|
+ metricDNSFwdTCPSuccess.Add(1)
|
|
|
+ return out, nil
|
|
|
+}
|
|
|
+
|
|
|
// resolvers returns the resolvers to use for domain.
|
|
|
func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay {
|
|
|
f.mu.Lock()
|
|
|
@@ -601,6 +737,7 @@ func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay {
|
|
|
type forwardQuery struct {
|
|
|
txid txid
|
|
|
packet []byte
|
|
|
+ family string // "tcp" or "udp"
|
|
|
|
|
|
// closeOnCtxDone lets send register values to Close if the
|
|
|
// caller's ctx expires. This avoids send from allocating its
|
|
|
@@ -686,6 +823,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
|
|
|
fq := &forwardQuery{
|
|
|
txid: getTxID(query.bs),
|
|
|
packet: query.bs,
|
|
|
+ family: query.family,
|
|
|
closeOnCtxDone: new(closePool),
|
|
|
}
|
|
|
defer fq.closeOnCtxDone.Close()
|
|
|
@@ -727,7 +865,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
|
|
|
case <-ctx.Done():
|
|
|
metricDNSFwdErrorContext.Add(1)
|
|
|
return ctx.Err()
|
|
|
- case responseChan <- packet{v, query.addr}:
|
|
|
+ case responseChan <- packet{v, query.family, query.addr}:
|
|
|
metricDNSFwdSuccess.Add(1)
|
|
|
return nil
|
|
|
}
|