Browse Source

cmd/natc: add a flag to use specific DNS servers

If natc is running on a host where tailscale is configured with
`--accept-dns=true`, a DNS loop can occur. Add a `--dns-servers` flag that
lets natc use specific upstream DNS servers instead, avoiding the loop.

Updates #14667

Signed-off-by: James Tucker <[email protected]>
James Tucker 8 months ago
parent
commit
86985228bc
2 changed files with 225 additions and 2 deletions
  1. 29 2
      cmd/natc/natc.go
  2. 196 0
      cmd/natc/natc_test.go

+ 29 - 2
cmd/natc/natc.go

@@ -54,6 +54,7 @@ func main() {
 		hostname        = fs.String("hostname", "", "Hostname to register the service under")
 		siteID          = fs.Uint("site-id", 1, "an integer site ID to use for the ULA prefix which allows for multiple proxies to act in a HA configuration")
 		v4PfxStr        = fs.String("v4-pfx", "100.64.1.0/24", "comma-separated list of IPv4 prefixes to advertise")
+		dnsServers      = fs.String("dns-servers", "", "comma separated list of upstream DNS to use, including host and port (use system if empty)")
 		verboseTSNet    = fs.Bool("verbose-tsnet", false, "enable verbose logging in tsnet")
 		printULA        = fs.Bool("print-ula", false, "print the ULA prefix and exit")
 		ignoreDstPfxStr = fs.String("ignore-destinations", "", "comma-separated list of prefixes to ignore")
@@ -78,7 +79,7 @@ func main() {
 	}
 
 	var ignoreDstTable *bart.Table[bool]
-	for _, s := range strings.Split(*ignoreDstPfxStr, ",") {
+	for s := range strings.SplitSeq(*ignoreDstPfxStr, ",") {
 		s := strings.TrimSpace(s)
 		if s == "" {
 			continue
@@ -185,11 +186,37 @@ func main() {
 		ipPool:     ipp,
 		routes:     routes,
 		dnsAddr:    dnsAddr,
-		resolver:   net.DefaultResolver,
+		resolver:   getResolver(*dnsServers),
 	}
 	c.run(ctx, lc)
 }
 
+// getResolver parses serverFlag, a comma-separated list of DNS server
+// addresses in IP:port form (for example "1.1.1.1:53,[2606:4700::1111]:53"),
+// and returns the resolver to use for upstream lookups.
+//
+// An empty flag, or one containing only whitespace and commas, selects
+// net.DefaultResolver. Whitespace around entries is trimmed and empty entries
+// (such as the one produced by a trailing comma) are skipped, matching how
+// the ignore-destinations flag is parsed in main. An entry that does not
+// parse as a netip.AddrPort terminates the process via log.Fatalf.
+//
+// When servers are given, the returned *net.Resolver prefers Go's built-in
+// resolver and ignores the address it is asked to dial, connecting instead
+// to one of the configured servers chosen at random for each connection.
+func getResolver(serverFlag string) lookupNetIPer {
+	var addrs []string
+	for s := range strings.SplitSeq(serverFlag, ",") {
+		s = strings.TrimSpace(s)
+		if s == "" {
+			continue
+		}
+		addr, err := netip.ParseAddrPort(s)
+		if err != nil {
+			log.Fatalf("dns server provided: %q does not parse: %v", s, err)
+		}
+		addrs = append(addrs, addr.String())
+	}
+	// Nothing usable was supplied, so behave as if the flag were unset. This
+	// also guarantees addrs is non-empty for rand.N below.
+	if len(addrs) == 0 {
+		return net.DefaultResolver
+	}
+	return &net.Resolver{
+		PreferGo: true,
+		Dial: func(ctx context.Context, network string, address string) (net.Conn, error) {
+			var dialer net.Dialer
+			// TODO(raggi): perhaps something other than random?
+			return dialer.DialContext(ctx, network, addrs[rand.N(len(addrs))])
+		},
+	}
+}
+
 func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *netipx.IPSet) {
 	var ipsb netipx.IPSetBuilder
 	for _, p := range prefixes {

+ 196 - 0
cmd/natc/natc_test.go

@@ -9,6 +9,7 @@ import (
 	"io"
 	"net"
 	"net/netip"
+	"sync"
 	"testing"
 	"time"
 
@@ -480,3 +481,198 @@ func TestV6V4(t *testing.T) {
 		}
 	}
 }
+
+// echoServer is a simple server that just echoes back data sent to it.
+type echoServer struct {
+	listener net.Listener   // accepts incoming connections
+	addr     string         // listener's bound address, as reported by listener.Addr()
+	wg       sync.WaitGroup // tracks the serve goroutine so close can wait for it
+	done     chan struct{}  // closed by close to tell serve to stop
+}
+
+// newEchoServer starts an echoServer listening on the given network and
+// address, and returns it after starting its accept loop in a background
+// goroutine. Pass port 0 in addr to have a free port chosen automatically;
+// the address actually bound is stored in the returned server's addr field.
+// The test fails immediately if the listener cannot be created. Callers must
+// call close on the returned server to stop it.
+func newEchoServer(t *testing.T, network, addr string) *echoServer {
+	// Attribute a listen failure to the calling test rather than this helper.
+	t.Helper()
+
+	listener, err := net.Listen(network, addr)
+	if err != nil {
+		t.Fatalf("Failed to create test DNS server: %v", err)
+	}
+
+	server := &echoServer{
+		listener: listener,
+		addr:     listener.Addr().String(),
+		done:     make(chan struct{}),
+	}
+
+	// Registered before the goroutine starts so that close's wg.Wait cannot
+	// return before serve has had a chance to run.
+	server.wg.Add(1)
+	go server.serve()
+
+	return server
+}
+
+// serve is the accept loop run by the goroutine started in newEchoServer. It
+// hands each accepted connection to handleConnection on its own goroutine
+// and keeps accepting until done is closed, at which point it returns and
+// releases the WaitGroup that close waits on.
+func (s *echoServer) serve() {
+	defer s.wg.Done()
+
+	// stopping reports, without blocking, whether done has been closed.
+	stopping := func() bool {
+		select {
+		case <-s.done:
+			return true
+		default:
+			return false
+		}
+	}
+
+	for !stopping() {
+		conn, err := s.listener.Accept()
+		if err != nil {
+			// Either the listener was closed during shutdown or Accept
+			// failed for another reason; the loop condition tells the two
+			// apart by re-checking done.
+			continue
+		}
+		go s.handleConnection(conn)
+	}
+}
+
+// handleConnection services one accepted connection: it performs a single
+// read of up to 1024 bytes, writes those bytes straight back, and closes the
+// connection. A failed read closes the connection without replying.
+func (s *echoServer) handleConnection(conn net.Conn) {
+	defer conn.Close()
+
+	payload := make([]byte, 1024)
+	if size, err := conn.Read(payload); err == nil {
+		// The echoed bytes are all a client needs to confirm connectivity.
+		conn.Write(payload[:size])
+	}
+}
+
+func (s *echoServer) close() {
+	close(s.done)      // tell serve to stop; it checks done whenever Accept returns
+	s.listener.Close() // unblocks the Accept call that serve may be waiting in
+	s.wg.Wait()        // wait for the serve goroutine to exit
+}
+
+// TestGetResolver verifies that a resolver built from a single server
+// address dials that server, over both IPv4 and IPv6 loopback, regardless of
+// the address the resolver is asked to connect to.
+func TestGetResolver(t *testing.T) {
+	tests := []struct {
+		name    string
+		network string
+		addr    string
+	}{
+		{
+			name:    "ipv4_loopback",
+			network: "tcp4",
+			addr:    "127.0.0.1:0",
+		},
+		{
+			name:    "ipv6_loopback",
+			network: "tcp6",
+			addr:    "[::1]:0",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			server := newEchoServer(t, tc.network, tc.addr)
+			defer server.close()
+			serverAddr := server.addr
+			resolver := getResolver(serverAddr)
+			if resolver == nil {
+				t.Fatal("getResolver returned nil")
+			}
+
+			netResolver, ok := resolver.(*net.Resolver)
+			if !ok {
+				t.Fatal("getResolver did not return a *net.Resolver")
+			}
+			if netResolver.Dial == nil {
+				t.Fatal("resolver.Dial is nil")
+			}
+
+			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+			defer cancel()
+			// The requested address is deliberately bogus: the resolver must
+			// ignore it and connect to the configured server instead.
+			conn, err := netResolver.Dial(ctx, "tcp", "dummy.address:53")
+			if err != nil {
+				t.Fatalf("Failed to dial test DNS server: %v", err)
+			}
+			defer conn.Close()
+
+			testData := []byte("test")
+			_, err = conn.Write(testData)
+			if err != nil {
+				t.Fatalf("Failed to write to connection: %v", err)
+			}
+
+			response := make([]byte, len(testData))
+			// A single Read may return fewer bytes than requested, so read
+			// until the whole echo has arrived.
+			_, err = io.ReadFull(conn, response)
+			if err != nil {
+				t.Fatalf("Failed to read from connection: %v", err)
+			}
+
+			if string(response) != string(testData) {
+				t.Fatalf("Expected echo response %q, got %q", testData, response)
+			}
+		})
+	}
+}
+
+func TestGetResolverMultipleServers(t *testing.T) {
+	server1 := newEchoServer(t, "tcp4", "127.0.0.1:0")
+	defer server1.close()
+	server2 := newEchoServer(t, "tcp4", "127.0.0.1:0")
+	defer server2.close()
+	serverFlag := server1.addr + ", " + server2.addr
+
+	resolver := getResolver(serverFlag)
+	netResolver, ok := resolver.(*net.Resolver)
+	if !ok {
+		t.Fatal("getResolver did not return a *net.Resolver")
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	servers := map[string]bool{
+		server1.addr: false,
+		server2.addr: false,
+	}
+
+	// Try up to 1000 times to hit all servers, this should be very quick, and
+	// if this fails randomness has regressed beyond reason.
+	for range 1000 {
+		conn, err := netResolver.Dial(ctx, "tcp", "dummy.address:53")
+		if err != nil {
+			t.Fatalf("Failed to dial test DNS server: %v", err)
+		}
+
+		remoteAddr := conn.RemoteAddr().String()
+
+		conn.Close()
+
+		servers[remoteAddr] = true
+
+		var allDone = true
+		for _, done := range servers {
+			if !done {
+				allDone = false
+				break
+			}
+		}
+		if allDone {
+			break
+		}
+	}
+
+	var allDone = true
+	for _, done := range servers {
+		if !done {
+			allDone = false
+			break
+		}
+	}
+	if !allDone {
+		t.Errorf("after 1000 queries, not all servers were hit, significant lack of randomness: %#v", servers)
+	}
+}
+
+// TestGetResolverEmpty verifies that an empty server list selects the
+// system default resolver.
+func TestGetResolverEmpty(t *testing.T) {
+	if got := getResolver(""); got != net.DefaultResolver {
+		t.Fatal(`getResolver("") should return net.DefaultResolver`)
+	}
+}