distsign_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package distsign
  4. import (
  5. "bytes"
  6. "context"
  7. "crypto/ed25519"
  8. "net/http"
  9. "net/http/httptest"
  10. "net/url"
  11. "os"
  12. "path/filepath"
  13. "strings"
  14. "testing"
  15. "golang.org/x/crypto/blake2s"
  16. )
  17. func TestDownload(t *testing.T) {
  18. srv := newTestServer(t)
  19. c := srv.client(t)
  20. tests := []struct {
  21. desc string
  22. before func(*testing.T)
  23. src string
  24. want []byte
  25. wantErr bool
  26. }{
  27. {
  28. desc: "missing file",
  29. before: func(*testing.T) {},
  30. src: "hello",
  31. wantErr: true,
  32. },
  33. {
  34. desc: "success",
  35. before: func(*testing.T) {
  36. srv.addSigned("hello", []byte("world"))
  37. },
  38. src: "hello",
  39. want: []byte("world"),
  40. },
  41. {
  42. desc: "no signature",
  43. before: func(*testing.T) {
  44. srv.add("hello", []byte("world"))
  45. },
  46. src: "hello",
  47. wantErr: true,
  48. },
  49. {
  50. desc: "bad signature",
  51. before: func(*testing.T) {
  52. srv.add("hello", []byte("world"))
  53. srv.add("hello.sig", []byte("potato"))
  54. },
  55. src: "hello",
  56. wantErr: true,
  57. },
  58. {
  59. desc: "signed with untrusted key",
  60. before: func(t *testing.T) {
  61. srv.add("hello", []byte("world"))
  62. srv.add("hello.sig", newSigningKeyPair(t).sign([]byte("world")))
  63. },
  64. src: "hello",
  65. wantErr: true,
  66. },
  67. {
  68. desc: "signed with root key",
  69. before: func(t *testing.T) {
  70. srv.add("hello", []byte("world"))
  71. srv.add("hello.sig", ed25519.Sign(srv.roots[0].k, []byte("world")))
  72. },
  73. src: "hello",
  74. wantErr: true,
  75. },
  76. {
  77. desc: "bad signing key signature",
  78. before: func(t *testing.T) {
  79. srv.add("distsign.pub.sig", []byte("potato"))
  80. srv.addSigned("hello", []byte("world"))
  81. },
  82. src: "hello",
  83. wantErr: true,
  84. },
  85. }
  86. for _, tt := range tests {
  87. t.Run(tt.desc, func(t *testing.T) {
  88. srv.reset()
  89. tt.before(t)
  90. dst := filepath.Join(t.TempDir(), tt.src)
  91. t.Cleanup(func() {
  92. os.Remove(dst)
  93. })
  94. err := c.Download(context.Background(), tt.src, dst)
  95. if err != nil {
  96. if tt.wantErr {
  97. return
  98. }
  99. t.Fatalf("unexpected error from Download(%q): %v", tt.src, err)
  100. }
  101. if tt.wantErr {
  102. t.Fatalf("Download(%q) succeeded, expected an error", tt.src)
  103. }
  104. got, err := os.ReadFile(dst)
  105. if err != nil {
  106. t.Fatal(err)
  107. }
  108. if !bytes.Equal(tt.want, got) {
  109. t.Errorf("Download(%q): got %q, want %q", tt.src, got, tt.want)
  110. }
  111. })
  112. }
  113. }
  114. func TestValidateLocalBinary(t *testing.T) {
  115. srv := newTestServer(t)
  116. c := srv.client(t)
  117. tests := []struct {
  118. desc string
  119. before func(*testing.T)
  120. src string
  121. wantErr bool
  122. }{
  123. {
  124. desc: "missing file",
  125. before: func(*testing.T) {},
  126. src: "hello",
  127. wantErr: true,
  128. },
  129. {
  130. desc: "success",
  131. before: func(*testing.T) {
  132. srv.addSigned("hello", []byte("world"))
  133. },
  134. src: "hello",
  135. },
  136. {
  137. desc: "contents changed",
  138. before: func(*testing.T) {
  139. srv.addSigned("hello", []byte("new world"))
  140. },
  141. src: "hello",
  142. wantErr: true,
  143. },
  144. {
  145. desc: "no signature",
  146. before: func(*testing.T) {
  147. srv.add("hello", []byte("world"))
  148. },
  149. src: "hello",
  150. wantErr: true,
  151. },
  152. {
  153. desc: "bad signature",
  154. before: func(*testing.T) {
  155. srv.add("hello", []byte("world"))
  156. srv.add("hello.sig", []byte("potato"))
  157. },
  158. src: "hello",
  159. wantErr: true,
  160. },
  161. {
  162. desc: "signed with untrusted key",
  163. before: func(t *testing.T) {
  164. srv.add("hello", []byte("world"))
  165. srv.add("hello.sig", newSigningKeyPair(t).sign([]byte("world")))
  166. },
  167. src: "hello",
  168. wantErr: true,
  169. },
  170. {
  171. desc: "signed with root key",
  172. before: func(t *testing.T) {
  173. srv.add("hello", []byte("world"))
  174. srv.add("hello.sig", ed25519.Sign(srv.roots[0].k, []byte("world")))
  175. },
  176. src: "hello",
  177. wantErr: true,
  178. },
  179. {
  180. desc: "bad signing key signature",
  181. before: func(t *testing.T) {
  182. srv.add("distsign.pub.sig", []byte("potato"))
  183. srv.addSigned("hello", []byte("world"))
  184. },
  185. src: "hello",
  186. wantErr: true,
  187. },
  188. }
  189. for _, tt := range tests {
  190. t.Run(tt.desc, func(t *testing.T) {
  191. srv.reset()
  192. // First just do a successful Download.
  193. want := []byte("world")
  194. srv.addSigned("hello", want)
  195. dst := filepath.Join(t.TempDir(), tt.src)
  196. err := c.Download(context.Background(), tt.src, dst)
  197. if err != nil {
  198. t.Fatalf("unexpected error from Download(%q): %v", tt.src, err)
  199. }
  200. got, err := os.ReadFile(dst)
  201. if err != nil {
  202. t.Fatal(err)
  203. }
  204. if !bytes.Equal(want, got) {
  205. t.Errorf("Download(%q): got %q, want %q", tt.src, got, want)
  206. }
  207. // Now we reset srv with the test case and validate against the local dst.
  208. srv.reset()
  209. tt.before(t)
  210. err = c.ValidateLocalBinary(tt.src, dst)
  211. if err != nil {
  212. if tt.wantErr {
  213. return
  214. }
  215. t.Fatalf("unexpected error from ValidateLocalBinary(%q): %v", tt.src, err)
  216. }
  217. if tt.wantErr {
  218. t.Fatalf("ValidateLocalBinary(%q) succeeded, expected an error", tt.src)
  219. }
  220. })
  221. }
  222. }
  223. func TestRotateRoot(t *testing.T) {
  224. srv := newTestServer(t)
  225. c1 := srv.client(t)
  226. ctx := context.Background()
  227. srv.addSigned("hello", []byte("world"))
  228. if err := c1.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil {
  229. t.Fatalf("Download failed on a fresh server: %v", err)
  230. }
  231. // Remove first root and replace it with a new key.
  232. srv.roots = append(srv.roots[1:], newRootKeyPair(t))
  233. // Old client can still download files because it still trusts the old
  234. // root key.
  235. if err := c1.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil {
  236. t.Fatalf("Download failed after root rotation on old client: %v", err)
  237. }
  238. // New client should fail download because current signing key is signed by
  239. // the revoked root that new client doesn't trust.
  240. c2 := srv.client(t)
  241. if err := c2.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err == nil {
  242. t.Fatalf("Download succeeded on new client, but signing key is signed with revoked root key")
  243. }
  244. // Re-sign signing key with another valid root that client still trusts.
  245. srv.resignSigningKeys()
  246. // Both old and new clients should now be able to download.
  247. //
  248. // Note: we don't need to re-sign the "hello" file because signing key
  249. // didn't change (only signing key's signature).
  250. if err := c1.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil {
  251. t.Fatalf("Download failed after root rotation on old client with re-signed signing key: %v", err)
  252. }
  253. if err := c2.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil {
  254. t.Fatalf("Download failed after root rotation on new client with re-signed signing key: %v", err)
  255. }
  256. }
  257. func TestRotateSigning(t *testing.T) {
  258. srv := newTestServer(t)
  259. c := srv.client(t)
  260. ctx := context.Background()
  261. srv.addSigned("hello", []byte("world"))
  262. if err := c.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil {
  263. t.Fatalf("Download failed on a fresh server: %v", err)
  264. }
  265. // Replace signing key but don't publish it yet.
  266. srv.sign = append(srv.sign, newSigningKeyPair(t))
  267. if err := c.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil {
  268. t.Fatalf("Download failed after new signing key added but before publishing it: %v", err)
  269. }
  270. // Publish new signing key bundle with both keys.
  271. srv.resignSigningKeys()
  272. if err := c.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil {
  273. t.Fatalf("Download failed after new signing key was published: %v", err)
  274. }
  275. // Re-sign the "hello" file with new signing key.
  276. srv.add("hello.sig", srv.sign[1].sign([]byte("world")))
  277. if err := c.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil {
  278. t.Fatalf("Download failed after re-signing with new signing key: %v", err)
  279. }
  280. // Drop the old signing key.
  281. srv.sign = srv.sign[1:]
  282. srv.resignSigningKeys()
  283. if err := c.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil {
  284. t.Fatalf("Download failed after removing old signing key: %v", err)
  285. }
  286. // Add another key and re-sign the file with it *before* publishing.
  287. srv.sign = append(srv.sign, newSigningKeyPair(t))
  288. srv.add("hello.sig", srv.sign[1].sign([]byte("world")))
  289. if err := c.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err == nil {
  290. t.Fatalf("Download succeeded when signed with a not-yet-published signing key")
  291. }
  292. // Fix this by publishing the new key.
  293. srv.resignSigningKeys()
  294. if err := c.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil {
  295. t.Fatalf("Download failed after publishing new signing key: %v", err)
  296. }
  297. }
  298. func TestParseRootKey(t *testing.T) {
  299. tests := []struct {
  300. desc string
  301. generate func() ([]byte, []byte, error)
  302. wantErr bool
  303. }{
  304. {
  305. desc: "valid",
  306. generate: GenerateRootKey,
  307. },
  308. {
  309. desc: "signing",
  310. generate: GenerateSigningKey,
  311. wantErr: true,
  312. },
  313. {
  314. desc: "nil",
  315. generate: func() ([]byte, []byte, error) { return nil, nil, nil },
  316. wantErr: true,
  317. },
  318. {
  319. desc: "invalid PEM tag",
  320. generate: func() ([]byte, []byte, error) {
  321. priv, pub, err := GenerateRootKey()
  322. priv = bytes.Replace(priv, []byte("ROOT "), nil, -1)
  323. return priv, pub, err
  324. },
  325. wantErr: true,
  326. },
  327. {
  328. desc: "not PEM",
  329. generate: func() ([]byte, []byte, error) { return []byte("s3cr3t"), nil, nil },
  330. wantErr: true,
  331. },
  332. }
  333. for _, tt := range tests {
  334. t.Run(tt.desc, func(t *testing.T) {
  335. priv, _, err := tt.generate()
  336. if err != nil {
  337. t.Fatal(err)
  338. }
  339. r, err := ParseRootKey(priv)
  340. if err != nil {
  341. if tt.wantErr {
  342. return
  343. }
  344. t.Fatalf("unexpected error: %v", err)
  345. }
  346. if tt.wantErr {
  347. t.Fatal("expected non-nil error")
  348. }
  349. if r == nil {
  350. t.Errorf("got nil error and nil RootKey")
  351. }
  352. })
  353. }
  354. }
  355. func TestParseSigningKey(t *testing.T) {
  356. tests := []struct {
  357. desc string
  358. generate func() ([]byte, []byte, error)
  359. wantErr bool
  360. }{
  361. {
  362. desc: "valid",
  363. generate: GenerateSigningKey,
  364. },
  365. {
  366. desc: "root",
  367. generate: GenerateRootKey,
  368. wantErr: true,
  369. },
  370. {
  371. desc: "nil",
  372. generate: func() ([]byte, []byte, error) { return nil, nil, nil },
  373. wantErr: true,
  374. },
  375. {
  376. desc: "invalid PEM tag",
  377. generate: func() ([]byte, []byte, error) {
  378. priv, pub, err := GenerateSigningKey()
  379. priv = bytes.Replace(priv, []byte("SIGNING "), nil, -1)
  380. return priv, pub, err
  381. },
  382. wantErr: true,
  383. },
  384. {
  385. desc: "not PEM",
  386. generate: func() ([]byte, []byte, error) { return []byte("s3cr3t"), nil, nil },
  387. wantErr: true,
  388. },
  389. }
  390. for _, tt := range tests {
  391. t.Run(tt.desc, func(t *testing.T) {
  392. priv, _, err := tt.generate()
  393. if err != nil {
  394. t.Fatal(err)
  395. }
  396. r, err := ParseSigningKey(priv)
  397. if err != nil {
  398. if tt.wantErr {
  399. return
  400. }
  401. t.Fatalf("unexpected error: %v", err)
  402. }
  403. if tt.wantErr {
  404. t.Fatal("expected non-nil error")
  405. }
  406. if r == nil {
  407. t.Errorf("got nil error and nil SigningKey")
  408. }
  409. })
  410. }
  411. }
  412. type testServer struct {
  413. roots []rootKeyPair
  414. sign []signingKeyPair
  415. files map[string][]byte
  416. srv *httptest.Server
  417. }
  418. func newTestServer(t *testing.T) *testServer {
  419. var roots []rootKeyPair
  420. for range 3 {
  421. roots = append(roots, newRootKeyPair(t))
  422. }
  423. ts := &testServer{
  424. roots: roots,
  425. sign: []signingKeyPair{newSigningKeyPair(t)},
  426. }
  427. ts.reset()
  428. ts.srv = httptest.NewServer(ts)
  429. t.Cleanup(ts.srv.Close)
  430. return ts
  431. }
  432. func (s *testServer) client(t *testing.T) *Client {
  433. roots := make([]ed25519.PublicKey, 0, len(s.roots))
  434. for _, r := range s.roots {
  435. pub, err := parseSinglePublicKey(r.pubRaw, pemTypeRootPublic)
  436. if err != nil {
  437. t.Fatalf("parsePublicKey: %v", err)
  438. }
  439. roots = append(roots, pub)
  440. }
  441. u, err := url.Parse(s.srv.URL)
  442. if err != nil {
  443. t.Fatal(err)
  444. }
  445. return &Client{
  446. logf: t.Logf,
  447. roots: roots,
  448. pkgsAddr: u,
  449. }
  450. }
  451. func (s *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  452. path := strings.TrimPrefix(r.URL.Path, "/")
  453. data, ok := s.files[path]
  454. if !ok {
  455. http.NotFound(w, r)
  456. return
  457. }
  458. w.Write(data)
  459. }
  460. func (s *testServer) addSigned(name string, data []byte) {
  461. s.files[name] = data
  462. s.files[name+".sig"] = s.sign[0].sign(data)
  463. }
  464. func (s *testServer) add(name string, data []byte) {
  465. s.files[name] = data
  466. }
  467. func (s *testServer) reset() {
  468. s.files = make(map[string][]byte)
  469. s.resignSigningKeys()
  470. }
  471. func (s *testServer) resignSigningKeys() {
  472. var pubs [][]byte
  473. for _, k := range s.sign {
  474. pubs = append(pubs, k.pubRaw)
  475. }
  476. bundle := bytes.Join(pubs, []byte("\n"))
  477. sig := s.roots[0].sign(bundle)
  478. s.files["distsign.pub"] = bundle
  479. s.files["distsign.pub.sig"] = sig
  480. }
  481. type rootKeyPair struct {
  482. *RootKey
  483. keyPair
  484. }
  485. func newRootKeyPair(t *testing.T) rootKeyPair {
  486. privRaw, pubRaw, err := GenerateRootKey()
  487. if err != nil {
  488. t.Fatalf("GenerateRootKey: %v", err)
  489. }
  490. kp := keyPair{
  491. privRaw: privRaw,
  492. pubRaw: pubRaw,
  493. }
  494. priv, err := parsePrivateKey(kp.privRaw, pemTypeRootPrivate)
  495. if err != nil {
  496. t.Fatalf("parsePrivateKey: %v", err)
  497. }
  498. return rootKeyPair{
  499. RootKey: &RootKey{k: priv},
  500. keyPair: kp,
  501. }
  502. }
  503. func (s rootKeyPair) sign(bundle []byte) []byte {
  504. sig, err := s.SignSigningKeys(bundle)
  505. if err != nil {
  506. panic(err)
  507. }
  508. return sig
  509. }
  510. type signingKeyPair struct {
  511. *SigningKey
  512. keyPair
  513. }
  514. func newSigningKeyPair(t *testing.T) signingKeyPair {
  515. privRaw, pubRaw, err := GenerateSigningKey()
  516. if err != nil {
  517. t.Fatalf("GenerateSigningKey: %v", err)
  518. }
  519. kp := keyPair{
  520. privRaw: privRaw,
  521. pubRaw: pubRaw,
  522. }
  523. priv, err := parsePrivateKey(kp.privRaw, pemTypeSigningPrivate)
  524. if err != nil {
  525. t.Fatalf("parsePrivateKey: %v", err)
  526. }
  527. return signingKeyPair{
  528. SigningKey: &SigningKey{k: priv},
  529. keyPair: kp,
  530. }
  531. }
  532. func (s signingKeyPair) sign(blob []byte) []byte {
  533. hash := blake2s.Sum256(blob)
  534. sig, err := s.SignPackageHash(hash[:], int64(len(blob)))
  535. if err != nil {
  536. panic(err)
  537. }
  538. return sig
  539. }
  // keyPair holds the raw private and public key bytes exactly as returned by
// GenerateRootKey or GenerateSigningKey.
type keyPair struct {
	privRaw, pubRaw []byte
}