Sfoglia il codice sorgente

ipn/ipnlocal: add wildcard TLS certificate support for subdomains (#18356)

When the NodeAttrDNSSubdomainResolve capability is present, enable
wildcard certificate issuance to cover all single-level subdomains
of a node's CertDomain.

Without the capability, only exact CertDomain matches are allowed,
so node.ts.net yields a cert for node.ts.net. With the capability,
we now generate wildcard certificates. Wildcard certs include both
the wildcard and base domain in their SANs, and ACME authorization
requests both identifiers. The cert filenames are kept still based
on the base domain with the wildcard prefix stripped, so we aren't
creating separate files. DNS challenges still used the base domain

The checkCertDomain function is replaced by resolveCertDomain that
both validates and returns the appropriate cert domain to request.
Name validation is now moved earlier into GetCertPEMWithValidity()

Fixes #1196

Signed-off-by: Fernando Serboncini <[email protected]>
Fernando Serboncini 1 mese fa
parent
commit
5edfa6f9a8
2 ha cambiato i file con 312 aggiunte e 40 eliminazioni
  1. 101 34
      ipn/ipnlocal/cert.go
  2. 211 6
      ipn/ipnlocal/cert_test.go

+ 101 - 34
ipn/ipnlocal/cert.go

@@ -37,7 +37,6 @@ import (
 	"tailscale.com/feature/buildfeatures"
 	"tailscale.com/hostinfo"
 	"tailscale.com/ipn"
-	"tailscale.com/ipn/ipnstate"
 	"tailscale.com/ipn/store"
 	"tailscale.com/ipn/store/mem"
 	"tailscale.com/net/bakedroots"
@@ -106,6 +105,13 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK
 //
 // If a cert is expired, or expires sooner than minValidity, it will be renewed
 // synchronously. Otherwise it will be renewed asynchronously.
+//
+// The domain must be one of:
+//
+//   - An exact CertDomain (e.g., "node.ts.net")
+//   - A wildcard domain (e.g., "*.node.ts.net")
+//
+// The wildcard format requires the NodeAttrDNSSubdomainResolve capability.
 func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string, minValidity time.Duration) (*TLSCertKeyPair, error) {
 	b.mu.Lock()
 	getCertForTest := b.getCertForTest
@@ -119,6 +125,13 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string
 	if !validLookingCertDomain(domain) {
 		return nil, errors.New("invalid domain")
 	}
+
+	certDomain, err := b.resolveCertDomain(domain)
+	if err != nil {
+		return nil, err
+	}
+	storageKey := strings.TrimPrefix(certDomain, "*.")
+
 	logf := logger.WithPrefix(b.logf, fmt.Sprintf("cert(%q): ", domain))
 	now := b.clock.Now()
 	traceACME := func(v any) {
@@ -134,13 +147,13 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string
 		return nil, err
 	}
 
-	if pair, err := getCertPEMCached(cs, domain, now); err == nil {
+	if pair, err := getCertPEMCached(cs, storageKey, now); err == nil {
 		if envknob.IsCertShareReadOnlyMode() {
 			return pair, nil
 		}
 		// If we got here, we have a valid unexpired cert.
 		// Check whether we should start an async renewal.
-		shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair, minValidity)
+		shouldRenew, err := b.shouldStartDomainRenewal(cs, storageKey, now, pair, minValidity)
 		if err != nil {
 			logf("error checking for certificate renewal: %v", err)
 			// Renewal check failed, but the current cert is valid and not
@@ -154,7 +167,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string
 			logf("starting async renewal")
 			// Start renewal in the background, return current valid cert.
 			b.goTracker.Go(func() {
-				if _, err := getCertPEM(context.Background(), b, cs, logf, traceACME, domain, now, minValidity); err != nil {
+				if _, err := getCertPEM(context.Background(), b, cs, logf, traceACME, certDomain, now, minValidity); err != nil {
 					logf("async renewal failed: getCertPem: %v", err)
 				}
 			})
@@ -169,7 +182,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string
 		return nil, fmt.Errorf("retrieving cached TLS certificate failed and cert store is configured in read-only mode, not attempting to issue a new certificate: %w", err)
 	}
 
-	pair, err := getCertPEM(ctx, b, cs, logf, traceACME, domain, now, minValidity)
+	pair, err := getCertPEM(ctx, b, cs, logf, traceACME, certDomain, now, minValidity)
 	if err != nil {
 		logf("getCertPEM: %v", err)
 		return nil, err
@@ -506,19 +519,24 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey
 }
 
 // getCertPem checks if a cert needs to be renewed and if so, renews it.
+// domain is the resolved cert domain (e.g., "*.node.ts.net" for wildcards).
 // It can be overridden in tests.
 var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) {
 	acmeMu.Lock()
 	defer acmeMu.Unlock()
 
+	// storageKey is used for file storage and renewal tracking.
+	// For wildcards, "*.node.ts.net" -> "node.ts.net"
+	storageKey, isWildcard := strings.CutPrefix(domain, "*.")
+
 	// In case this method was triggered multiple times in parallel (when
 	// serving incoming requests), check whether one of the other goroutines
 	// already renewed the cert before us.
-	previous, err := getCertPEMCached(cs, domain, now)
+	previous, err := getCertPEMCached(cs, storageKey, now)
 	if err == nil {
 		// shouldStartDomainRenewal caches its result so it's OK to call this
 		// frequently.
-		shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, previous, minValidity)
+		shouldRenew, err := b.shouldStartDomainRenewal(cs, storageKey, now, previous, minValidity)
 		if err != nil {
 			logf("error checking for certificate renewal: %v", err)
 		} else if !shouldRenew {
@@ -561,12 +579,6 @@ var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf l
 		return nil, fmt.Errorf("unexpected ACME account status %q", a.Status)
 	}
 
-	// Before hitting LetsEncrypt, see if this is a domain that Tailscale will do DNS challenges for.
-	st := b.StatusWithoutPeers()
-	if err := checkCertDomain(st, domain); err != nil {
-		return nil, err
-	}
-
 	// If we have a previous cert, include it in the order. Assuming we're
 	// within the ARI renewal window this should exclude us from LE rate
 	// limits.
@@ -580,7 +592,18 @@ var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf l
 			opts = append(opts, acme.WithOrderReplacesCert(prevCrt))
 		}
 	}
-	order, err := ac.AuthorizeOrder(ctx, []acme.AuthzID{{Type: "dns", Value: domain}}, opts...)
+
+	// For wildcards, we need to authorize both the wildcard and base domain.
+	var authzIDs []acme.AuthzID
+	if isWildcard {
+		authzIDs = []acme.AuthzID{
+			{Type: "dns", Value: domain},
+			{Type: "dns", Value: storageKey},
+		}
+	} else {
+		authzIDs = []acme.AuthzID{{Type: "dns", Value: domain}}
+	}
+	order, err := ac.AuthorizeOrder(ctx, authzIDs, opts...)
 	if err != nil {
 		return nil, err
 	}
@@ -598,7 +621,9 @@ var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf l
 				if err != nil {
 					return nil, err
 				}
-				key := "_acme-challenge." + domain
+				// For wildcards, the challenge is on the base domain.
+				// e.g., "*.node.ts.net" -> "_acme-challenge.node.ts.net"
+				key := "_acme-challenge." + strings.TrimPrefix(az.Identifier.Value, "*.")
 
 				// Do a best-effort lookup to see if we've already created this DNS name
 				// in a previous attempt. Don't burn too much time on it, though. Worst
@@ -608,14 +633,14 @@ var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf l
 				txts, _ := resolver.LookupTXT(lookupCtx, key)
 				lookupCancel()
 				if slices.Contains(txts, rec) {
-					logf("TXT record already existed")
+					logf("TXT record already existed for %s", key)
 				} else {
-					logf("starting SetDNS call...")
+					logf("starting SetDNS call for %s...", key)
 					err = b.SetDNS(ctx, key, rec)
 					if err != nil {
 						return nil, fmt.Errorf("SetDNS %q => %q: %w", key, rec, err)
 					}
-					logf("did SetDNS")
+					logf("did SetDNS for %s", key)
 				}
 
 				chal, err := ac.Accept(ctx, ch)
@@ -672,19 +697,27 @@ var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf l
 			return nil, err
 		}
 	}
-	if err := cs.WriteTLSCertAndKey(domain, certPEM.Bytes(), privPEM.Bytes()); err != nil {
+	if err := cs.WriteTLSCertAndKey(storageKey, certPEM.Bytes(), privPEM.Bytes()); err != nil {
 		return nil, err
 	}
-	b.domainRenewed(domain)
+	b.domainRenewed(storageKey)
 
 	return &TLSCertKeyPair{CertPEM: certPEM.Bytes(), KeyPEM: privPEM.Bytes()}, nil
 }
 
-// certRequest generates a CSR for the given common name cn and optional SANs.
-func certRequest(key crypto.Signer, name string, ext []pkix.Extension) ([]byte, error) {
+// certRequest generates a CSR for the given domain and optional SANs.
+func certRequest(key crypto.Signer, domain string, ext []pkix.Extension) ([]byte, error) {
+	dnsNames := []string{domain}
+	if base, ok := strings.CutPrefix(domain, "*."); ok {
+		// Wildcard cert must also include the base domain as a SAN.
+		// This is load-bearing: getCertPEMCached validates certs using
+		// the storage key (base domain), which only passes x509 verification
+		// if the base domain is in DNSNames.
+		dnsNames = append(dnsNames, base)
+	}
 	req := &x509.CertificateRequest{
-		Subject:         pkix.Name{CommonName: name},
-		DNSNames:        []string{name},
+		Subject:         pkix.Name{CommonName: domain},
+		DNSNames:        dnsNames,
 		ExtraExtensions: ext,
 	}
 	return x509.CreateCertificateRequest(rand.Reader, req, key)
@@ -844,7 +877,7 @@ func isDefaultDirectoryURL(u string) bool {
 // we might be able to get a cert for.
 //
 // It's a light check primarily for double checking before it's used
-// as part of a filesystem path. The actual validation happens in checkCertDomain.
+// as part of a filesystem path. The actual validation happens in resolveCertDomain.
 func validLookingCertDomain(name string) bool {
 	if name == "" ||
 		strings.Contains(name, "..") ||
@@ -852,22 +885,56 @@ func validLookingCertDomain(name string) bool {
 		!strings.Contains(name, ".") {
 		return false
 	}
+	// Only allow * as a wildcard prefix "*.domain.tld"
+	if rest, ok := strings.CutPrefix(name, "*."); ok {
+		if strings.Contains(rest, "*") || !strings.Contains(rest, ".") {
+			return false
+		}
+	} else if strings.Contains(name, "*") {
+		return false
+	}
 	return true
 }
 
-func checkCertDomain(st *ipnstate.Status, domain string) error {
+// resolveCertDomain validates a domain and returns the cert domain to use.
+//
+//   - "node.ts.net" -> "node.ts.net" (exact CertDomain match)
+//   - "*.node.ts.net" -> "*.node.ts.net" (explicit wildcard, requires NodeAttrDNSSubdomainResolve)
+//
+// Subdomain requests like "app.node.ts.net" are rejected; callers should
+// request "*.node.ts.net" explicitly for subdomain coverage.
+func (b *LocalBackend) resolveCertDomain(domain string) (string, error) {
 	if domain == "" {
-		return errors.New("missing domain name")
+		return "", errors.New("missing domain name")
 	}
-	for _, d := range st.CertDomains {
-		if d == domain {
-			return nil
+
+	// Read the netmap once to get both CertDomains and capabilities atomically.
+	nm := b.NetMap()
+	if nm == nil {
+		return "", errors.New("no netmap available")
+	}
+	certDomains := nm.DNS.CertDomains
+	if len(certDomains) == 0 {
+		return "", errors.New("your Tailscale account does not support getting TLS certs")
+	}
+
+	// Wildcard request like "*.node.ts.net".
+	if base, ok := strings.CutPrefix(domain, "*."); ok {
+		if !nm.AllCaps.Contains(tailcfg.NodeAttrDNSSubdomainResolve) {
+			return "", fmt.Errorf("wildcard certificates are not enabled for this node")
 		}
+		if !slices.Contains(certDomains, base) {
+			return "", fmt.Errorf("invalid domain %q; parent domain must be one of %q", domain, certDomains)
+		}
+		return domain, nil
 	}
-	if len(st.CertDomains) == 0 {
-		return errors.New("your Tailscale account does not support getting TLS certs")
+
+	// Exact CertDomain match.
+	if slices.Contains(certDomains, domain) {
+		return domain, nil
 	}
-	return fmt.Errorf("invalid domain %q; must be one of %q", domain, st.CertDomains)
+
+	return "", fmt.Errorf("invalid domain %q; must be one of %q", domain, certDomains)
 }
 
 // handleC2NTLSCertStatus returns info about the last TLS certificate issued for the
@@ -884,7 +951,7 @@ func handleC2NTLSCertStatus(b *LocalBackend, w http.ResponseWriter, r *http.Requ
 		return
 	}
 
-	domain := r.FormValue("domain")
+	domain := strings.TrimPrefix(r.FormValue("domain"), "*.")
 	if domain == "" {
 		http.Error(w, "no 'domain'", http.StatusBadRequest)
 		return

+ 211 - 6
ipn/ipnlocal/cert_test.go

@@ -17,17 +17,205 @@ import (
 	"math/big"
 	"os"
 	"path/filepath"
+	"slices"
 	"testing"
 	"time"
 
 	"github.com/google/go-cmp/cmp"
 	"tailscale.com/envknob"
 	"tailscale.com/ipn/store/mem"
+	"tailscale.com/tailcfg"
 	"tailscale.com/tstest"
 	"tailscale.com/types/logger"
+	"tailscale.com/types/netmap"
 	"tailscale.com/util/must"
+	"tailscale.com/util/set"
 )
 
+func TestCertRequest(t *testing.T) {
+	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	if err != nil {
+		t.Fatalf("GenerateKey: %v", err)
+	}
+
+	tests := []struct {
+		domain   string
+		wantSANs []string
+	}{
+		{
+			domain:   "example.com",
+			wantSANs: []string{"example.com"},
+		},
+		{
+			domain:   "*.example.com",
+			wantSANs: []string{"*.example.com", "example.com"},
+		},
+		{
+			domain:   "*.foo.bar.com",
+			wantSANs: []string{"*.foo.bar.com", "foo.bar.com"},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.domain, func(t *testing.T) {
+			csrDER, err := certRequest(key, tt.domain, nil)
+			if err != nil {
+				t.Fatalf("certRequest: %v", err)
+			}
+			csr, err := x509.ParseCertificateRequest(csrDER)
+			if err != nil {
+				t.Fatalf("ParseCertificateRequest: %v", err)
+			}
+			if csr.Subject.CommonName != tt.domain {
+				t.Errorf("CommonName = %q, want %q", csr.Subject.CommonName, tt.domain)
+			}
+			if !slices.Equal(csr.DNSNames, tt.wantSANs) {
+				t.Errorf("DNSNames = %v, want %v", csr.DNSNames, tt.wantSANs)
+			}
+		})
+	}
+}
+
+func TestResolveCertDomain(t *testing.T) {
+	tests := []struct {
+		name        string
+		domain      string
+		certDomains []string
+		hasCap      bool
+		skipNetmap  bool
+		want        string
+		wantErr     string
+	}{
+		{
+			name:        "exact_match",
+			domain:      "node.ts.net",
+			certDomains: []string{"node.ts.net"},
+			want:        "node.ts.net",
+		},
+		{
+			name:        "exact_match_with_cap",
+			domain:      "node.ts.net",
+			certDomains: []string{"node.ts.net"},
+			hasCap:      true,
+			want:        "node.ts.net",
+		},
+		{
+			name:        "wildcard_with_cap",
+			domain:      "*.node.ts.net",
+			certDomains: []string{"node.ts.net"},
+			hasCap:      true,
+			want:        "*.node.ts.net",
+		},
+		{
+			name:        "wildcard_without_cap",
+			domain:      "*.node.ts.net",
+			certDomains: []string{"node.ts.net"},
+			hasCap:      false,
+			wantErr:     "wildcard certificates are not enabled for this node",
+		},
+		{
+			name:        "subdomain_with_cap_rejected",
+			domain:      "app.node.ts.net",
+			certDomains: []string{"node.ts.net"},
+			hasCap:      true,
+			wantErr:     `invalid domain "app.node.ts.net"; must be one of ["node.ts.net"]`,
+		},
+		{
+			name:        "subdomain_without_cap_rejected",
+			domain:      "app.node.ts.net",
+			certDomains: []string{"node.ts.net"},
+			hasCap:      false,
+			wantErr:     `invalid domain "app.node.ts.net"; must be one of ["node.ts.net"]`,
+		},
+		{
+			name:        "multi_level_subdomain_rejected",
+			domain:      "a.b.node.ts.net",
+			certDomains: []string{"node.ts.net"},
+			hasCap:      true,
+			wantErr:     `invalid domain "a.b.node.ts.net"; must be one of ["node.ts.net"]`,
+		},
+		{
+			name:        "wildcard_no_matching_parent",
+			domain:      "*.unrelated.ts.net",
+			certDomains: []string{"node.ts.net"},
+			hasCap:      true,
+			wantErr:     `invalid domain "*.unrelated.ts.net"; parent domain must be one of ["node.ts.net"]`,
+		},
+		{
+			name:        "subdomain_unrelated_rejected",
+			domain:      "app.unrelated.ts.net",
+			certDomains: []string{"node.ts.net"},
+			hasCap:      true,
+			wantErr:     `invalid domain "app.unrelated.ts.net"; must be one of ["node.ts.net"]`,
+		},
+		{
+			name:        "no_cert_domains",
+			domain:      "node.ts.net",
+			certDomains: nil,
+			wantErr:     "your Tailscale account does not support getting TLS certs",
+		},
+		{
+			name:        "wildcard_no_cert_domains",
+			domain:      "*.foo.ts.net",
+			certDomains: nil,
+			hasCap:      true,
+			wantErr:     "your Tailscale account does not support getting TLS certs",
+		},
+		{
+			name:        "empty_domain",
+			domain:      "",
+			certDomains: []string{"node.ts.net"},
+			wantErr:     "missing domain name",
+		},
+		{
+			name:       "nil_netmap",
+			domain:     "node.ts.net",
+			skipNetmap: true,
+			wantErr:    "no netmap available",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			b := newTestLocalBackend(t)
+
+			if !tt.skipNetmap {
+				// Set up netmap with CertDomains and capability
+				var allCaps set.Set[tailcfg.NodeCapability]
+				if tt.hasCap {
+					allCaps = set.Of(tailcfg.NodeAttrDNSSubdomainResolve)
+				}
+				b.mu.Lock()
+				b.currentNode().SetNetMap(&netmap.NetworkMap{
+					SelfNode: (&tailcfg.Node{}).View(),
+					DNS: tailcfg.DNSConfig{
+						CertDomains: tt.certDomains,
+					},
+					AllCaps: allCaps,
+				})
+				b.mu.Unlock()
+			}
+
+			got, err := b.resolveCertDomain(tt.domain)
+			if tt.wantErr != "" {
+				if err == nil {
+					t.Errorf("resolveCertDomain(%q) = %q, want error %q", tt.domain, got, tt.wantErr)
+				} else if err.Error() != tt.wantErr {
+					t.Errorf("resolveCertDomain(%q) error = %q, want %q", tt.domain, err.Error(), tt.wantErr)
+				}
+				return
+			}
+			if err != nil {
+				t.Errorf("resolveCertDomain(%q) error = %v, want nil", tt.domain, err)
+				return
+			}
+			if got != tt.want {
+				t.Errorf("resolveCertDomain(%q) = %q, want %q", tt.domain, got, tt.want)
+			}
+		})
+	}
+}
+
 func TestValidLookingCertDomain(t *testing.T) {
 	tests := []struct {
 		in   string
@@ -40,6 +228,16 @@ func TestValidLookingCertDomain(t *testing.T) {
 		{"", false},
 		{"foo\\bar.com", false},
 		{"foo\x00bar.com", false},
+		// Wildcard tests
+		{"*.foo.com", true},
+		{"*.foo.bar.com", true},
+		{"*foo.com", false},      // must be *.
+		{"*.com", false},         // must have domain after *.
+		{"*.", false},            // must have domain after *.
+		{"*.*.foo.com", false},   // no nested wildcards
+		{"foo.*.bar.com", false}, // no wildcard mid-string
+		{"app.foo.com", true},    // regular subdomain
+		{"*", false},             // bare asterisk
 	}
 	for _, tt := range tests {
 		if got := validLookingCertDomain(tt.in); got != tt.want {
@@ -231,12 +429,19 @@ func TestDebugACMEDirectoryURL(t *testing.T) {
 
 func TestGetCertPEMWithValidity(t *testing.T) {
 	const testDomain = "example.com"
-	b := &LocalBackend{
-		store:   &mem.Store{},
-		varRoot: t.TempDir(),
-		ctx:     context.Background(),
-		logf:    t.Logf,
-	}
+	b := newTestLocalBackend(t)
+	b.varRoot = t.TempDir()
+
+	// Set up netmap with CertDomains so resolveCertDomain works
+	b.mu.Lock()
+	b.currentNode().SetNetMap(&netmap.NetworkMap{
+		SelfNode: (&tailcfg.Node{}).View(),
+		DNS: tailcfg.DNSConfig{
+			CertDomains: []string{testDomain},
+		},
+	})
+	b.mu.Unlock()
+
 	certDir, err := b.certDir()
 	if err != nil {
 		t.Fatalf("certDir error: %v", err)