Browse Source

cmd/natc: add --ignore-destinations flag

Updates tailscale/corp#20503

Signed-off-by: Fran Bull <[email protected]>
Fran Bull 1 year ago
parent
commit
d2d459d442
1 changed files with 137 additions and 34 deletions
  1. 137 34
      cmd/natc/natc.go

+ 137 - 34
cmd/natc/natc.go

@@ -48,12 +48,13 @@ func main() {
 	// Parse flags
 	fs := flag.NewFlagSet("natc", flag.ExitOnError)
 	var (
-		debugPort    = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint")
-		hostname     = fs.String("hostname", "", "Hostname to register the service under")
-		siteID       = fs.Uint("site-id", 1, "an integer site ID to use for the ULA prefix which allows for multiple proxies to act in a HA configuration")
-		v4PfxStr     = fs.String("v4-pfx", "100.64.1.0/24", "comma-separated list of IPv4 prefixes to advertise")
-		verboseTSNet = fs.Bool("verbose-tsnet", false, "enable verbose logging in tsnet")
-		printULA     = fs.Bool("print-ula", false, "print the ULA prefix and exit")
+		debugPort       = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint")
+		hostname        = fs.String("hostname", "", "Hostname to register the service under")
+		siteID          = fs.Uint("site-id", 1, "an integer site ID to use for the ULA prefix which allows for multiple proxies to act in a HA configuration")
+		v4PfxStr        = fs.String("v4-pfx", "100.64.1.0/24", "comma-separated list of IPv4 prefixes to advertise")
+		verboseTSNet    = fs.Bool("verbose-tsnet", false, "enable verbose logging in tsnet")
+		printULA        = fs.Bool("print-ula", false, "print the ULA prefix and exit")
+		ignoreDstPfxStr = fs.String("ignore-destinations", "", "comma-separated list of prefixes to ignore")
 	)
 	ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_NATC"))
 
@@ -70,6 +71,24 @@ func main() {
 		log.Fatalf("site-id must be in the range [0, 65535]")
 	}
 
+	var ignoreDstTable *bart.Table[bool]
+	for _, s := range strings.Split(*ignoreDstPfxStr, ",") {
+		s := strings.TrimSpace(s)
+		if s == "" {
+			continue
+		}
+		if ignoreDstTable == nil {
+			ignoreDstTable = &bart.Table[bool]{}
+		}
+		pfx, err := netip.ParsePrefix(s)
+		if err != nil {
+			log.Fatalf("unable to parse prefix: %v", err)
+		}
+		if pfx.Masked() != pfx {
+			log.Fatalf("prefix %v is not normalized (bits are set outside the mask)", pfx)
+		}
+		ignoreDstTable.Insert(pfx, true)
+	}
 	var v4Prefixes []netip.Prefix
 	for _, s := range strings.Split(*v4PfxStr, ",") {
 		p := netip.MustParsePrefix(strings.TrimSpace(s))
@@ -112,11 +131,12 @@ func main() {
 	}
 
 	c := &connector{
-		ts:       ts,
-		lc:       lc,
-		dnsAddr:  dnsAddr,
-		v4Ranges: v4Prefixes,
-		v6ULA:    ula(uint16(*siteID)),
+		ts:         ts,
+		lc:         lc,
+		dnsAddr:    dnsAddr,
+		v4Ranges:   v4Prefixes,
+		v6ULA:      ula(uint16(*siteID)),
+		ignoreDsts: ignoreDstTable,
 	}
 	c.run(ctx)
 }
@@ -139,6 +159,15 @@ type connector struct {
 	v6ULA netip.Prefix
 
 	perPeerMap syncs.Map[tailcfg.NodeID, *perPeerState]
+
+	// ignoreDsts is initialized at start up with the contents of --ignore-destinations (if none it is nil)
+	// It is never mutated, only used for lookups.
+	// Users who want to natc a DNS wildcard but not every address record in that domain can supply the
+	// exceptions in --ignore-destinations. When we receive a dns request we will look up the fqdn
+	// and if any of the ip addresses in response to the lookup match any 'ignore destinations' prefix we will
+	// return a dns response that contains the ip addresses we discovered with the lookup (ie not the
+	// natc behavior, which would return a dummy ip address pointing at natc).
+	ignoreDsts *bart.Table[bool]
 }
 
 // v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses.
@@ -192,6 +221,26 @@ func (c *connector) serveDNS() {
 	}
 }
 
+func lookupDestinationIP(domain string) ([]netip.Addr, error) {
+	netIPs, err := net.LookupIP(domain)
+	if err != nil {
+		var dnsError *net.DNSError
+		if errors.As(err, &dnsError) && dnsError.IsNotFound {
+			return nil, nil
+		} else {
+			return nil, err
+		}
+	}
+	var addrs []netip.Addr
+	for _, ip := range netIPs {
+		a, ok := netip.AddrFromSlice(ip)
+		if ok {
+			addrs = append(addrs, a)
+		}
+	}
+	return addrs, nil
+}
+
 // handleDNS handles a DNS request to the app connector.
 // It generates a response based on the request and the node that sent it.
 //
@@ -219,11 +268,44 @@ func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDP
 		return
 	}
 
+	// If there are destination ips that we don't want to route, we
+	// have to do a dns lookup here to find the destination ip.
+	if c.ignoreDsts != nil {
+		if len(msg.Questions) > 0 {
+			q := msg.Questions[0]
+			switch q.Type {
+			case dnsmessage.TypeAAAA, dnsmessage.TypeA:
+				dstAddrs, err := lookupDestinationIP(q.Name.String())
+				if err != nil {
+					log.Printf("HandleDNS: lookup destination failed: %v\n ", err)
+					return
+				}
+				if c.ignoreDestination(dstAddrs) {
+					bs, err := dnsResponse(&msg, dstAddrs)
+					// TODO (fran): treat as SERVFAIL
+					if err != nil {
+						log.Printf("HandleDNS: generate ignore response failed: %v\n", err)
+						return
+					}
+					_, err = pc.WriteTo(bs, remoteAddr)
+					if err != nil {
+						log.Printf("HandleDNS: write failed: %v\n", err)
+					}
+					return
+				}
+			}
+		}
+	}
+	// None of the destination IP addresses match an ignore destination prefix, do
+	// the natc thing.
+
 	resp, err := c.generateDNSResponse(&msg, who.Node.ID)
+	// TODO (fran): treat as SERVFAIL
 	if err != nil {
 		log.Printf("HandleDNS: connector handling failed: %v\n", err)
 		return
 	}
+	// TODO (fran): treat as NXDOMAIN
 	if len(resp) == 0 {
 		return
 	}
@@ -244,6 +326,23 @@ var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
 // argument is the NodeID of the node that sent the request.
 func (c *connector) generateDNSResponse(req *dnsmessage.Message, from tailcfg.NodeID) ([]byte, error) {
 	pm, _ := c.perPeerMap.LoadOrStore(from, &perPeerState{c: c})
+	var addrs []netip.Addr
+	if len(req.Questions) > 0 {
+		switch req.Questions[0].Type {
+		case dnsmessage.TypeAAAA, dnsmessage.TypeA:
+			var err error
+			addrs, err = pm.ipForDomain(req.Questions[0].Name.String())
+			if err != nil {
+				return nil, err
+			}
+		}
+	}
+	return dnsResponse(req, addrs)
+}
+
+// dnsResponse makes a DNS response for the natc. If the dnsmessage is requesting TypeAAAA
+// or TypeA the provided addrs of the requested type will be used.
+func dnsResponse(req *dnsmessage.Message, addrs []netip.Addr) ([]byte, error) {
 	b := dnsmessage.NewBuilder(nil,
 		dnsmessage.Header{
 			ID:            req.Header.ID,
@@ -265,51 +364,44 @@ func (c *connector) generateDNSResponse(req *dnsmessage.Message, from tailcfg.No
 	if err := b.StartAnswers(); err != nil {
 		return nil, err
 	}
-	var err error
 	switch q.Type {
 	case dnsmessage.TypeAAAA, dnsmessage.TypeA:
-		var addrs []netip.Addr
-		addrs, err = pm.ipForDomain(q.Name.String())
-		if err != nil {
-			return nil, err
-		}
 		want6 := q.Type == dnsmessage.TypeAAAA
-		found := false
 		for _, ip := range addrs {
 			if want6 != ip.Is6() {
 				continue
 			}
-			found = true
 			if want6 {
-				err = b.AAAAResource(
+				if err := b.AAAAResource(
 					dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 5},
 					dnsmessage.AAAAResource{AAAA: ip.As16()},
-				)
+				); err != nil {
+					return nil, err
+				}
 			} else {
-				err = b.AResource(
+				if err := b.AResource(
 					dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 5},
 					dnsmessage.AResource{A: ip.As4()},
-				)
+				); err != nil {
+					return nil, err
+				}
 			}
-			break
-		}
-		if !found {
-			err = errors.New("no address found")
 		}
 	case dnsmessage.TypeSOA:
-		err = b.SOAResource(
+		if err := b.SOAResource(
 			dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
 			dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
 				Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
-		)
+		); err != nil {
+			return nil, err
+		}
 	case dnsmessage.TypeNS:
-		err = b.NSResource(
+		if err := b.NSResource(
 			dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
 			dnsmessage.NSResource{NS: tsMBox},
-		)
-	}
-	if err != nil {
-		return nil, err
+		); err != nil {
+			return nil, err
+		}
 	}
 	return b.Finish()
 }
@@ -345,6 +437,17 @@ func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Con
 	}, true
 }
 
+// ignoreDestination reports whether any of the provided dstAddrs match the prefixes configured
+// in --ignore-destinations
+func (c *connector) ignoreDestination(dstAddrs []netip.Addr) bool {
+	for _, a := range dstAddrs {
+		if _, ok := c.ignoreDsts.Get(a); ok {
+			return true
+		}
+	}
+	return false
+}
+
 func proxyTCPConn(c net.Conn, dest string) {
 	addrPortStr := c.LocalAddr().String()
 	_, port, err := net.SplitHostPort(addrPortStr)