2
0
Эх сурвалжийг харах

appc,ipn/ipnlocal: add app connector routes if any part of a CNAME chain is routed

If any domain along a CNAME chain matches any of the routed domains, add
routes for the discovered domains.

Fixes tailscale/corp#16928

Signed-off-by: James Tucker <[email protected]>
James Tucker 2 жил өмнө
parent
commit
e1a4b89dbe

+ 115 - 41
appc/appconnector.go

@@ -22,6 +22,7 @@ import (
 	"tailscale.com/types/views"
 	"tailscale.com/util/dnsname"
 	"tailscale.com/util/execqueue"
+	"tailscale.com/util/mak"
 )
 
 // RouteAdvertiser is an interface that allows the AppConnector to advertise
@@ -206,7 +207,16 @@ func (e *AppConnector) ObserveDNSResponse(res []byte) {
 		return
 	}
 
-nextAnswer:
+	// cnameChain tracks a chain of CNAMEs for a given query in order to reverse
+	// a CNAME chain back to the original query for flattening. The keys are
+	// CNAME record targets, and the value is the name the record answers, so
+	// for www.example.com CNAME example.com, the map would contain
+	// ["example.com"] = "www.example.com".
+	var cnameChain map[string]string
+
+	// addressRecords is a list of address records found in the response.
+	var addressRecords map[string][]netip.Addr
+
 	for {
 		h, err := p.AnswerHeader()
 		if err == dnsmessage.ErrSectionDone {
@@ -222,83 +232,147 @@ nextAnswer:
 			}
 			continue
 		}
-		if h.Type != dnsmessage.TypeA && h.Type != dnsmessage.TypeAAAA {
+
+		switch h.Type {
+		case dnsmessage.TypeCNAME, dnsmessage.TypeA, dnsmessage.TypeAAAA:
+		default:
 			if err := p.SkipAnswer(); err != nil {
 				return
 			}
 			continue
-		}
 
-		domain := h.Name.String()
-		if len(domain) == 0 {
-			return
 		}
-		domain = strings.TrimSuffix(domain, ".")
-		domain = strings.ToLower(domain)
-		e.logf("[v2] observed DNS response for %s", domain)
 
-		e.mu.Lock()
-		addrs, ok := e.domains[domain]
-		// match wildcard domains
-		if !ok {
-			for _, wc := range e.wildcards {
-				if dnsname.HasSuffix(domain, wc) {
-					e.domains[domain] = nil
-					ok = true
-					break
-				}
-			}
+		domain := strings.TrimSuffix(strings.ToLower(h.Name.String()), ".")
+		if len(domain) == 0 {
+			continue
 		}
-		e.mu.Unlock()
 
-		if !ok {
-			if err := p.SkipAnswer(); err != nil {
+		if h.Type == dnsmessage.TypeCNAME {
+			res, err := p.CNAMEResource()
+			if err != nil {
 				return
 			}
+			cname := strings.TrimSuffix(strings.ToLower(res.CNAME.String()), ".")
+			if len(cname) == 0 {
+				continue
+			}
+			mak.Set(&cnameChain, cname, domain)
 			continue
 		}
 
-		var addr netip.Addr
 		switch h.Type {
 		case dnsmessage.TypeA:
 			r, err := p.AResource()
 			if err != nil {
 				return
 			}
-			addr = netip.AddrFrom4(r.A)
+			addr := netip.AddrFrom4(r.A)
+			mak.Set(&addressRecords, domain, append(addressRecords[domain], addr))
 		case dnsmessage.TypeAAAA:
 			r, err := p.AAAAResource()
 			if err != nil {
 				return
 			}
-			addr = netip.AddrFrom16(r.AAAA)
+			addr := netip.AddrFrom16(r.AAAA)
+			mak.Set(&addressRecords, domain, append(addressRecords[domain], addr))
 		default:
 			if err := p.SkipAnswer(); err != nil {
 				return
 			}
 			continue
 		}
-		if slices.Contains(addrs, addr) {
+	}
+
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	for domain, addrs := range addressRecords {
+		domain, isRouted := e.findRoutedDomainLocked(domain, cnameChain)
+
+		// domain and none of the CNAMEs in the chain are routed
+		if !isRouted {
 			continue
 		}
-		for _, route := range e.controlRoutes {
-			if route.Contains(addr) {
-				// record the new address associated with the domain for faster matching in subsequent
-				// requests and for diagnostic records.
-				e.mu.Lock()
-				e.domains[domain] = append(addrs, addr)
-				e.mu.Unlock()
-				continue nextAnswer
+
+		// advertise each address we have learned for the routed domain, that
+		// was not already known.
+		for _, addr := range addrs {
+			e.logf("[v2] observed routed DNS response for %s: %s", domain, addr)
+			if e.isAddrKnownLocked(domain, addr) {
+				continue
+			}
+
+			e.scheduleAdvertisement(domain, addr)
+		}
+	}
+}
+
+// starting from the given domain that resolved to an address, find it, or any
+// of the domains in the CNAME chain toward resolving it, that are routed
+// domains, returning the routed domain name and a bool indicating whether a
+// routed domain was found.
+// e.mu must be held.
+func (e *AppConnector) findRoutedDomainLocked(domain string, cnameChain map[string]string) (string, bool) {
+	var isRouted bool
+	for {
+		_, isRouted = e.domains[domain]
+		if isRouted {
+			break
+		}
+
+		// match wildcard domains
+		for _, wc := range e.wildcards {
+			if dnsname.HasSuffix(domain, wc) {
+				e.domains[domain] = nil
+				isRouted = true
+				break
 			}
 		}
+
+		next, ok := cnameChain[domain]
+		if !ok {
+			break
+		}
+		domain = next
+	}
+	return domain, isRouted
+}
+
+// isAddrKnownLocked returns true if the address is known to be associated with
+// the given domain. Known domain tables are updated for covered routes to speed
+// up future matches.
+// e.mu must be held.
+func (e *AppConnector) isAddrKnownLocked(domain string, addr netip.Addr) bool {
+	if slices.Contains(e.domains[domain], addr) {
+		return true
+	}
+	for _, route := range e.controlRoutes {
+		if route.Contains(addr) {
+			// record the new address associated with the domain for faster matching in subsequent
+			// requests and for diagnostic records.
+			e.domains[domain] = append(e.domains[domain], addr)
+			return true
+		}
+	}
+	return false
+
+}
+
+// scheduleAdvertisement schedules an advertisement of the given address
+// associated with the given domain.
+func (e *AppConnector) scheduleAdvertisement(domain string, addr netip.Addr) {
+	e.queue.Add(func() {
 		if err := e.routeAdvertiser.AdvertiseRoute(netip.PrefixFrom(addr, addr.BitLen())); err != nil {
 			e.logf("failed to advertise route for %s: %v: %v", domain, addr, err)
-			continue
+			return
 		}
-		e.logf("[v2] advertised route for %v: %v", domain, addr)
-
 		e.mu.Lock()
-		e.domains[domain] = append(addrs, addr)
-		e.mu.Unlock()
-	}
+		defer e.mu.Unlock()
+
+		if !slices.Contains(e.domains[domain], addr) {
+			e.logf("[v2] advertised route for %v: %v", domain, addr)
+			e.domains[domain] = append(e.domains[domain], addr)
+		}
+	})
 }

+ 82 - 0
appc/appconnector_test.go

@@ -99,6 +99,7 @@ func TestDomainRoutes(t *testing.T) {
 	a := NewAppConnector(t.Logf, rc)
 	a.updateDomains([]string{"example.com"})
 	a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8"))
+	a.Wait(context.Background())
 
 	want := map[string][]netip.Addr{
 		"example.com": {netip.MustParseAddr("192.0.0.8")},
@@ -110,6 +111,7 @@ func TestDomainRoutes(t *testing.T) {
 }
 
 func TestObserveDNSResponse(t *testing.T) {
+	ctx := context.Background()
 	rc := &appctest.RouteCollector{}
 	a := NewAppConnector(t.Logf, rc)
 
@@ -123,6 +125,26 @@ func TestObserveDNSResponse(t *testing.T) {
 
 	a.updateDomains([]string{"example.com"})
 	a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8"))
+	a.Wait(ctx)
+	if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) {
+		t.Errorf("got %v; want %v", got, want)
+	}
+
+	// a CNAME record chain should result in a route being added if the chain
+	// matches a routed domain.
+	a.updateDomains([]string{"www.example.com", "example.com"})
+	a.ObserveDNSResponse(dnsCNAMEResponse("192.0.0.9", "www.example.com.", "chain.example.com.", "example.com."))
+	a.Wait(ctx)
+	wantRoutes = append(wantRoutes, netip.MustParsePrefix("192.0.0.9/32"))
+	if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) {
+		t.Errorf("got %v; want %v", got, want)
+	}
+
+	// a CNAME record chain should result in a route being added if the chain
+	// even if only found in the middle of the chain
+	a.ObserveDNSResponse(dnsCNAMEResponse("192.0.0.10", "outside.example.org.", "www.example.com.", "example.org."))
+	a.Wait(ctx)
+	wantRoutes = append(wantRoutes, netip.MustParsePrefix("192.0.0.10/32"))
 	if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) {
 		t.Errorf("got %v; want %v", got, want)
 	}
@@ -130,12 +152,14 @@ func TestObserveDNSResponse(t *testing.T) {
 	wantRoutes = append(wantRoutes, netip.MustParsePrefix("2001:db8::1/128"))
 
 	a.ObserveDNSResponse(dnsResponse("example.com.", "2001:db8::1"))
+	a.Wait(ctx)
 	if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) {
 		t.Errorf("got %v; want %v", got, want)
 	}
 
 	// don't re-advertise routes that have already been advertised
 	a.ObserveDNSResponse(dnsResponse("example.com.", "2001:db8::1"))
+	a.Wait(ctx)
 	if !slices.Equal(rc.Routes(), wantRoutes) {
 		t.Errorf("rc.Routes(): got %v; want %v", rc.Routes(), wantRoutes)
 	}
@@ -145,6 +169,7 @@ func TestObserveDNSResponse(t *testing.T) {
 	a.updateRoutes([]netip.Prefix{pfx})
 	wantRoutes = append(wantRoutes, pfx)
 	a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.2.1"))
+	a.Wait(ctx)
 	if !slices.Equal(rc.Routes(), wantRoutes) {
 		t.Errorf("rc.Routes(): got %v; want %v", rc.Routes(), wantRoutes)
 	}
@@ -154,11 +179,13 @@ func TestObserveDNSResponse(t *testing.T) {
 }
 
 func TestWildcardDomains(t *testing.T) {
+	ctx := context.Background()
 	rc := &appctest.RouteCollector{}
 	a := NewAppConnector(t.Logf, rc)
 
 	a.updateDomains([]string{"*.example.com"})
 	a.ObserveDNSResponse(dnsResponse("foo.example.com.", "192.0.0.8"))
+	a.Wait(ctx)
 	if got, want := rc.Routes(), []netip.Prefix{netip.MustParsePrefix("192.0.0.8/32")}; !slices.Equal(got, want) {
 		t.Errorf("routes: got %v; want %v", got, want)
 	}
@@ -218,6 +245,61 @@ func dnsResponse(domain, address string) []byte {
 	return must.Get(b.Finish())
 }
 
+func dnsCNAMEResponse(address string, domains ...string) []byte {
+	addr := netip.MustParseAddr(address)
+	b := dnsmessage.NewBuilder(nil, dnsmessage.Header{})
+	b.EnableCompression()
+	b.StartAnswers()
+
+	if len(domains) >= 2 {
+		for i, domain := range domains[:len(domains)-1] {
+			b.CNAMEResource(
+				dnsmessage.ResourceHeader{
+					Name:  dnsmessage.MustNewName(domain),
+					Type:  dnsmessage.TypeCNAME,
+					Class: dnsmessage.ClassINET,
+					TTL:   0,
+				},
+				dnsmessage.CNAMEResource{
+					CNAME: dnsmessage.MustNewName(domains[i+1]),
+				},
+			)
+		}
+	}
+
+	domain := domains[len(domains)-1]
+
+	switch addr.BitLen() {
+	case 32:
+		b.AResource(
+			dnsmessage.ResourceHeader{
+				Name:  dnsmessage.MustNewName(domain),
+				Type:  dnsmessage.TypeA,
+				Class: dnsmessage.ClassINET,
+				TTL:   0,
+			},
+			dnsmessage.AResource{
+				A: addr.As4(),
+			},
+		)
+	case 128:
+		b.AAAAResource(
+			dnsmessage.ResourceHeader{
+				Name:  dnsmessage.MustNewName(domain),
+				Type:  dnsmessage.TypeAAAA,
+				Class: dnsmessage.ClassINET,
+				TTL:   0,
+			},
+			dnsmessage.AAAAResource{
+				AAAA: addr.As16(),
+			},
+		)
+	default:
+		panic("invalid address length")
+	}
+	return must.Get(b.Finish())
+}
+
 func prefixEqual(a, b netip.Prefix) bool {
 	return a == b
 }

+ 66 - 0
ipn/ipnlocal/peerapi_test.go

@@ -803,6 +803,72 @@ func TestPeerAPIReplyToDNSQueriesAreObserved(t *testing.T) {
 	}
 }
 
+func TestPeerAPIReplyToDNSQueriesAreObservedWithCNAMEFlattening(t *testing.T) {
+	ctx := context.Background()
+	var h peerAPIHandler
+	h.remoteAddr = netip.MustParseAddrPort("100.150.151.152:12345")
+
+	rc := &appctest.RouteCollector{}
+	eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0)
+	pm := must.Get(newProfileManager(new(mem.Store), t.Logf))
+	h.ps = &peerAPIServer{
+		b: &LocalBackend{
+			e:            eng,
+			pm:           pm,
+			store:        pm.Store(),
+			appConnector: appc.NewAppConnector(t.Logf, rc),
+		},
+	}
+	h.ps.b.appConnector.UpdateDomains([]string{"www.example.com"})
+	h.ps.b.appConnector.Wait(ctx)
+
+	h.ps.resolver = &fakeResolver{build: func(b *dnsmessage.Builder) {
+		b.CNAMEResource(
+			dnsmessage.ResourceHeader{
+				Name:  dnsmessage.MustNewName("www.example.com."),
+				Type:  dnsmessage.TypeCNAME,
+				Class: dnsmessage.ClassINET,
+				TTL:   0,
+			},
+			dnsmessage.CNAMEResource{
+				CNAME: dnsmessage.MustNewName("example.com."),
+			},
+		)
+		b.AResource(
+			dnsmessage.ResourceHeader{
+				Name:  dnsmessage.MustNewName("example.com."),
+				Type:  dnsmessage.TypeA,
+				Class: dnsmessage.ClassINET,
+				TTL:   0,
+			},
+			dnsmessage.AResource{
+				A: [4]byte{192, 0, 0, 8},
+			},
+		)
+	}}
+	f := filter.NewAllowAllForTest(logger.Discard)
+	h.ps.b.setFilter(f)
+
+	if !h.ps.b.OfferingAppConnector() {
+		t.Fatal("expecting to be offering app connector")
+	}
+	if !h.replyToDNSQueries() {
+		t.Errorf("unexpectedly deny; wanted to be a DNS server")
+	}
+
+	w := httptest.NewRecorder()
+	h.handleDNSQuery(w, httptest.NewRequest("GET", "/dns-query?q=www.example.com.", nil))
+	if w.Code != http.StatusOK {
+		t.Errorf("unexpected status code: %v", w.Code)
+	}
+	h.ps.b.appConnector.Wait(ctx)
+
+	wantRoutes := []netip.Prefix{netip.MustParsePrefix("192.0.0.8/32")}
+	if !slices.Equal(rc.Routes(), wantRoutes) {
+		t.Errorf("got %v; want %v", rc.Routes(), wantRoutes)
+	}
+}
+
 type fakeResolver struct {
 	build func(*dnsmessage.Builder)
 }