cert_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. //go:build !ios && !android && !js
  4. package ipnlocal
  5. import (
  6. "context"
  7. "crypto/ecdsa"
  8. "crypto/elliptic"
  9. "crypto/rand"
  10. "crypto/x509"
  11. "crypto/x509/pkix"
  12. "embed"
  13. "encoding/pem"
  14. "math/big"
  15. "os"
  16. "path/filepath"
  17. "testing"
  18. "time"
  19. "github.com/google/go-cmp/cmp"
  20. "tailscale.com/envknob"
  21. "tailscale.com/ipn/store/mem"
  22. "tailscale.com/tstest"
  23. "tailscale.com/types/logger"
  24. "tailscale.com/util/must"
  25. )
  26. func TestValidLookingCertDomain(t *testing.T) {
  27. tests := []struct {
  28. in string
  29. want bool
  30. }{
  31. {"foo.com", true},
  32. {"foo..com", false},
  33. {"foo/com.com", false},
  34. {"NUL", false},
  35. {"", false},
  36. {"foo\\bar.com", false},
  37. {"foo\x00bar.com", false},
  38. }
  39. for _, tt := range tests {
  40. if got := validLookingCertDomain(tt.in); got != tt.want {
  41. t.Errorf("validLookingCertDomain(%q) = %v, want %v", tt.in, got, tt.want)
  42. }
  43. }
  44. }
  45. //go:embed testdata/*
  46. var certTestFS embed.FS
  47. func TestCertStoreRoundTrip(t *testing.T) {
  48. const testDomain = "example.com"
  49. // Use fixed verification timestamps so validity doesn't change over time.
  50. // If you update the test data below, these may also need to be updated.
  51. testNow := time.Date(2023, time.February, 10, 0, 0, 0, 0, time.UTC)
  52. testExpired := time.Date(2026, time.February, 10, 0, 0, 0, 0, time.UTC)
  53. // To re-generate a root certificate and domain certificate for testing,
  54. // use:
  55. //
  56. // go run filippo.io/mkcert@latest example.com
  57. //
  58. // The content is not important except to be structurally valid so we can be
  59. // sure the round-trip succeeds.
  60. testRoot, err := certTestFS.ReadFile("testdata/rootCA.pem")
  61. if err != nil {
  62. t.Fatal(err)
  63. }
  64. roots := x509.NewCertPool()
  65. if !roots.AppendCertsFromPEM(testRoot) {
  66. t.Fatal("Unable to add test CA to the cert pool")
  67. }
  68. testCert, err := certTestFS.ReadFile("testdata/example.com.pem")
  69. if err != nil {
  70. t.Fatal(err)
  71. }
  72. testKey, err := certTestFS.ReadFile("testdata/example.com-key.pem")
  73. if err != nil {
  74. t.Fatal(err)
  75. }
  76. tests := []struct {
  77. name string
  78. store certStore
  79. debugACMEURL bool
  80. }{
  81. {"FileStore", certFileStore{dir: t.TempDir(), testRoots: roots}, false},
  82. {"FileStore_UnknownCA", certFileStore{dir: t.TempDir()}, true},
  83. {"StateStore", certStateStore{StateStore: new(mem.Store), testRoots: roots}, false},
  84. {"StateStore_UnknownCA", certStateStore{StateStore: new(mem.Store)}, true},
  85. }
  86. for _, test := range tests {
  87. t.Run(test.name, func(t *testing.T) {
  88. if test.debugACMEURL {
  89. t.Setenv("TS_DEBUG_ACME_DIRECTORY_URL", "https://acme-staging-v02.api.letsencrypt.org/directory")
  90. }
  91. if err := test.store.WriteTLSCertAndKey(testDomain, testCert, testKey); err != nil {
  92. t.Fatalf("WriteTLSCertAndKey: unexpected error: %v", err)
  93. }
  94. kp, err := test.store.Read(testDomain, testNow)
  95. if err != nil {
  96. t.Fatalf("Read: unexpected error: %v", err)
  97. }
  98. if diff := cmp.Diff(kp.CertPEM, testCert); diff != "" {
  99. t.Errorf("Certificate (-got, +want):\n%s", diff)
  100. }
  101. if diff := cmp.Diff(kp.KeyPEM, testKey); diff != "" {
  102. t.Errorf("Key (-got, +want):\n%s", diff)
  103. }
  104. unexpected, err := test.store.Read(testDomain, testExpired)
  105. if err != errCertExpired {
  106. t.Fatalf("Read: expected expiry error: %v", string(unexpected.CertPEM))
  107. }
  108. })
  109. }
  110. }
  111. func TestShouldStartDomainRenewal(t *testing.T) {
  112. reset := func() {
  113. renewMu.Lock()
  114. defer renewMu.Unlock()
  115. clear(renewCertAt)
  116. }
  117. mustMakePair := func(template *x509.Certificate) *TLSCertKeyPair {
  118. priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  119. if err != nil {
  120. panic(err)
  121. }
  122. b, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
  123. if err != nil {
  124. panic(err)
  125. }
  126. certPEM := pem.EncodeToMemory(&pem.Block{
  127. Type: "CERTIFICATE",
  128. Bytes: b,
  129. })
  130. return &TLSCertKeyPair{
  131. Cached: false,
  132. CertPEM: certPEM,
  133. KeyPEM: []byte("unused"),
  134. }
  135. }
  136. now := time.Unix(1685714838, 0)
  137. subject := pkix.Name{
  138. Organization: []string{"Tailscale, Inc."},
  139. Country: []string{"CA"},
  140. Province: []string{"ON"},
  141. Locality: []string{"Toronto"},
  142. StreetAddress: []string{"290 Bremner Blvd"},
  143. PostalCode: []string{"M5V 3L9"},
  144. }
  145. testCases := []struct {
  146. name string
  147. notBefore time.Time
  148. lifetime time.Duration
  149. want bool
  150. wantErr string
  151. }{
  152. {
  153. name: "should renew",
  154. notBefore: now.AddDate(0, 0, -89),
  155. lifetime: 90 * 24 * time.Hour,
  156. want: true,
  157. },
  158. {
  159. name: "short-lived renewal",
  160. notBefore: now.AddDate(0, 0, -7),
  161. lifetime: 10 * 24 * time.Hour,
  162. want: true,
  163. },
  164. {
  165. name: "no renew",
  166. notBefore: now.AddDate(0, 0, -59), // 59 days ago == not 2/3rds of the way through 90 days yet
  167. lifetime: 90 * 24 * time.Hour,
  168. want: false,
  169. },
  170. }
  171. b := new(LocalBackend)
  172. for _, tt := range testCases {
  173. t.Run(tt.name, func(t *testing.T) {
  174. reset()
  175. ret, err := b.domainRenewalTimeByExpiry(mustMakePair(&x509.Certificate{
  176. SerialNumber: big.NewInt(2019),
  177. Subject: subject,
  178. NotBefore: tt.notBefore,
  179. NotAfter: tt.notBefore.Add(tt.lifetime),
  180. }))
  181. if tt.wantErr != "" {
  182. if err == nil {
  183. t.Errorf("wanted error, got nil")
  184. } else if err.Error() != tt.wantErr {
  185. t.Errorf("got err=%q, want %q", err.Error(), tt.wantErr)
  186. }
  187. } else {
  188. renew := now.After(ret)
  189. if renew != tt.want {
  190. t.Errorf("got renew=%v (ret=%v), want renew %v", renew, ret, tt.want)
  191. }
  192. }
  193. })
  194. }
  195. }
  196. func TestDebugACMEDirectoryURL(t *testing.T) {
  197. for _, tc := range []string{"", "https://acme-staging-v02.api.letsencrypt.org/directory"} {
  198. const setting = "TS_DEBUG_ACME_DIRECTORY_URL"
  199. t.Run(tc, func(t *testing.T) {
  200. t.Setenv(setting, tc)
  201. ac, err := acmeClient(certStateStore{StateStore: new(mem.Store)})
  202. if err != nil {
  203. t.Fatalf("acmeClient creation err: %v", err)
  204. }
  205. if ac.DirectoryURL != tc {
  206. t.Fatalf("acmeClient.DirectoryURL = %q, want %q", ac.DirectoryURL, tc)
  207. }
  208. })
  209. }
  210. }
  211. func TestGetCertPEMWithValidity(t *testing.T) {
  212. const testDomain = "example.com"
  213. b := &LocalBackend{
  214. store: &mem.Store{},
  215. varRoot: t.TempDir(),
  216. ctx: context.Background(),
  217. logf: t.Logf,
  218. }
  219. certDir, err := b.certDir()
  220. if err != nil {
  221. t.Fatalf("certDir error: %v", err)
  222. }
  223. if _, err := b.getCertStore(); err != nil {
  224. t.Fatalf("getCertStore error: %v", err)
  225. }
  226. testRoot, err := certTestFS.ReadFile("testdata/rootCA.pem")
  227. if err != nil {
  228. t.Fatal(err)
  229. }
  230. roots := x509.NewCertPool()
  231. if !roots.AppendCertsFromPEM(testRoot) {
  232. t.Fatal("Unable to add test CA to the cert pool")
  233. }
  234. testX509Roots = roots
  235. defer func() { testX509Roots = nil }()
  236. tests := []struct {
  237. name string
  238. now time.Time
  239. // storeCerts is true if the test cert and key should be written to store.
  240. storeCerts bool
  241. readOnlyMode bool // TS_READ_ONLY_CERTS env var
  242. wantAsyncRenewal bool // async issuance should be started
  243. wantIssuance bool // sync issuance should be started
  244. wantErr bool
  245. }{
  246. {
  247. name: "valid_no_renewal",
  248. now: time.Date(2023, time.February, 20, 0, 0, 0, 0, time.UTC),
  249. storeCerts: true,
  250. wantAsyncRenewal: false,
  251. wantIssuance: false,
  252. wantErr: false,
  253. },
  254. {
  255. name: "issuance_needed",
  256. now: time.Date(2023, time.February, 20, 0, 0, 0, 0, time.UTC),
  257. storeCerts: false,
  258. wantAsyncRenewal: false,
  259. wantIssuance: true,
  260. wantErr: false,
  261. },
  262. {
  263. name: "renewal_needed",
  264. now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC),
  265. storeCerts: true,
  266. wantAsyncRenewal: true,
  267. wantIssuance: false,
  268. wantErr: false,
  269. },
  270. {
  271. name: "renewal_needed_read_only_mode",
  272. now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC),
  273. storeCerts: true,
  274. readOnlyMode: true,
  275. wantAsyncRenewal: false,
  276. wantIssuance: false,
  277. wantErr: false,
  278. },
  279. {
  280. name: "no_certs_read_only_mode",
  281. now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC),
  282. storeCerts: false,
  283. readOnlyMode: true,
  284. wantAsyncRenewal: false,
  285. wantIssuance: false,
  286. wantErr: true,
  287. },
  288. }
  289. for _, tt := range tests {
  290. t.Run(tt.name, func(t *testing.T) {
  291. if tt.readOnlyMode {
  292. envknob.Setenv("TS_CERT_SHARE_MODE", "ro")
  293. }
  294. os.RemoveAll(certDir)
  295. if tt.storeCerts {
  296. os.MkdirAll(certDir, 0755)
  297. if err := os.WriteFile(filepath.Join(certDir, "example.com.crt"),
  298. must.Get(os.ReadFile("testdata/example.com.pem")), 0644); err != nil {
  299. t.Fatal(err)
  300. }
  301. if err := os.WriteFile(filepath.Join(certDir, "example.com.key"),
  302. must.Get(os.ReadFile("testdata/example.com-key.pem")), 0644); err != nil {
  303. t.Fatal(err)
  304. }
  305. }
  306. b.clock = tstest.NewClock(tstest.ClockOpts{Start: tt.now})
  307. allDone := make(chan bool, 1)
  308. defer b.goTracker.AddDoneCallback(func() {
  309. b.mu.Lock()
  310. defer b.mu.Unlock()
  311. if b.goTracker.RunningGoroutines() > 0 {
  312. return
  313. }
  314. select {
  315. case allDone <- true:
  316. default:
  317. }
  318. })()
  319. // Set to true if get getCertPEM is called. GetCertPEM can be called in a goroutine for async
  320. // renewal or in the main goroutine if issuance is required to obtain valid TLS credentials.
  321. getCertPemWasCalled := false
  322. getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) {
  323. getCertPemWasCalled = true
  324. return nil, nil
  325. }
  326. prevGoRoutines := b.goTracker.StartedGoroutines()
  327. _, err = b.GetCertPEMWithValidity(context.Background(), testDomain, 0)
  328. if (err != nil) != tt.wantErr {
  329. t.Errorf("b.GetCertPemWithValidity got err %v, wants error: '%v'", err, tt.wantErr)
  330. }
  331. // GetCertPEMWithValidity calls getCertPEM in a goroutine if async renewal is needed. That's the
  332. // only goroutine it starts, so this can be used to test if async renewal was started.
  333. gotAsyncRenewal := b.goTracker.StartedGoroutines()-prevGoRoutines != 0
  334. if gotAsyncRenewal {
  335. select {
  336. case <-time.After(5 * time.Second):
  337. t.Fatal("timed out waiting for goroutines to finish")
  338. case <-allDone:
  339. }
  340. }
  341. // Verify that async renewal was triggered if expected.
  342. if tt.wantAsyncRenewal != gotAsyncRenewal {
  343. t.Fatalf("wants getCertPem to be called async: %v, got called %v", tt.wantAsyncRenewal, gotAsyncRenewal)
  344. }
  345. // Verify that (non-async) issuance was started if expected.
  346. gotIssuance := getCertPemWasCalled && !gotAsyncRenewal
  347. if tt.wantIssuance != gotIssuance {
  348. t.Errorf("wants getCertPem to be called: %v, got called %v", tt.wantIssuance, gotIssuance)
  349. }
  350. })
  351. }
  352. }