|
|
@@ -7,6 +7,7 @@ import (
|
|
|
"bytes"
|
|
|
"context"
|
|
|
"encoding/json"
|
|
|
+ "fmt"
|
|
|
"net"
|
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
|
@@ -877,3 +878,135 @@ func TestTKAForceDisable(t *testing.T) {
|
|
|
t.Fatal("tka was re-initalized")
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func TestTKAAffectedSigs(t *testing.T) {
|
|
|
+ nodePriv := key.NewNode()
|
|
|
+ // toSign := key.NewNode()
|
|
|
+ nlPriv := key.NewNLPrivate()
|
|
|
+
|
|
|
+ pm := must.Get(newProfileManager(new(mem.Store), t.Logf))
|
|
|
+ must.Do(pm.SetPrefs((&ipn.Prefs{
|
|
|
+ Persist: &persist.Persist{
|
|
|
+ PrivateNodeKey: nodePriv,
|
|
|
+ NetworkLockKey: nlPriv,
|
|
|
+ },
|
|
|
+ }).View()))
|
|
|
+
|
|
|
+ // Make a fake TKA authority, to seed local state.
|
|
|
+ disablementSecret := bytes.Repeat([]byte{0xa5}, 32)
|
|
|
+ tkaKey := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2}
|
|
|
+
|
|
|
+ temp := t.TempDir()
|
|
|
+ tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID))
|
|
|
+ os.Mkdir(tkaPath, 0755)
|
|
|
+ chonk, err := tka.ChonkDir(tkaPath)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ authority, _, err := tka.Create(chonk, tka.State{
|
|
|
+ Keys: []tka.Key{tkaKey},
|
|
|
+ DisablementSecrets: [][]byte{tka.DisablementKDF(disablementSecret)},
|
|
|
+ }, nlPriv)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("tka.Create() failed: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ untrustedKey := key.NewNLPrivate()
|
|
|
+ tcs := []struct {
|
|
|
+ name string
|
|
|
+ makeSig func() *tka.NodeKeySignature
|
|
|
+ wantErr string
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ "no error",
|
|
|
+ func() *tka.NodeKeySignature {
|
|
|
+ sig, _ := signNodeKey(tailcfg.TKASignInfo{NodePublic: nodePriv.Public()}, nlPriv)
|
|
|
+ return sig
|
|
|
+ },
|
|
|
+ "",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "signature for different keyID",
|
|
|
+ func() *tka.NodeKeySignature {
|
|
|
+ sig, _ := signNodeKey(tailcfg.TKASignInfo{NodePublic: nodePriv.Public()}, untrustedKey)
|
|
|
+ return sig
|
|
|
+ },
|
|
|
+ fmt.Sprintf("got signature with keyID %X from request for %X", untrustedKey.KeyID(), nlPriv.KeyID()),
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "invalid signature",
|
|
|
+ func() *tka.NodeKeySignature {
|
|
|
+ sig, _ := signNodeKey(tailcfg.TKASignInfo{NodePublic: nodePriv.Public()}, nlPriv)
|
|
|
+ copy(sig.Signature, []byte{1, 2, 3, 4, 5, 6}) // overwrite with trash to invalid signature
|
|
|
+ return sig
|
|
|
+ },
|
|
|
+ "signature 0 is not valid: invalid signature",
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, tc := range tcs {
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
|
+ s := tc.makeSig()
|
|
|
+ ts, client := fakeNoiseServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ defer r.Body.Close()
|
|
|
+ switch r.URL.Path {
|
|
|
+ case "/machine/tka/affected-sigs":
|
|
|
+ body := new(tailcfg.TKASignaturesUsingKeyRequest)
|
|
|
+ if err := json.NewDecoder(r.Body).Decode(body); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ if body.Version != tailcfg.CurrentCapabilityVersion {
|
|
|
+ t.Errorf("sign CapVer = %v, want %v", body.Version, tailcfg.CurrentCapabilityVersion)
|
|
|
+ }
|
|
|
+ if body.NodeKey != nodePriv.Public() {
|
|
|
+ t.Errorf("nodeKey = %v, want %v", body.NodeKey, nodePriv.Public())
|
|
|
+ }
|
|
|
+
|
|
|
+ w.WriteHeader(200)
|
|
|
+ if err := json.NewEncoder(w).Encode(tailcfg.TKASignaturesUsingKeyResponse{
|
|
|
+ Signatures: []tkatype.MarshaledSignature{s.Serialize()},
|
|
|
+ }); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ default:
|
|
|
+ t.Errorf("unhandled endpoint path: %v", r.URL.Path)
|
|
|
+ w.WriteHeader(404)
|
|
|
+ }
|
|
|
+ }))
|
|
|
+ defer ts.Close()
|
|
|
+ cc := fakeControlClient(t, client)
|
|
|
+ b := LocalBackend{
|
|
|
+ varRoot: temp,
|
|
|
+ cc: cc,
|
|
|
+ ccAuto: cc,
|
|
|
+ logf: t.Logf,
|
|
|
+ tka: &tkaState{
|
|
|
+ authority: authority,
|
|
|
+ storage: chonk,
|
|
|
+ },
|
|
|
+ pm: pm,
|
|
|
+ store: pm.Store(),
|
|
|
+ }
|
|
|
+
|
|
|
+ sigs, err := b.NetworkLockAffectedSigs(nlPriv.KeyID())
|
|
|
+ switch {
|
|
|
+ case tc.wantErr == "" && err != nil:
|
|
|
+ t.Errorf("NetworkLockAffectedSigs() failed: %v", err)
|
|
|
+ case tc.wantErr != "" && err == nil:
|
|
|
+ t.Errorf("NetworkLockAffectedSigs().err = nil, want %q", tc.wantErr)
|
|
|
+ case tc.wantErr != "" && err.Error() != tc.wantErr:
|
|
|
+ t.Errorf("NetworkLockAffectedSigs().err = %q, want %q", err.Error(), tc.wantErr)
|
|
|
+ }
|
|
|
+
|
|
|
+ if tc.wantErr == "" {
|
|
|
+ if len(sigs) != 1 {
|
|
|
+ t.Fatalf("len(sigs) = %d, want 1", len(sigs))
|
|
|
+ }
|
|
|
+ if !bytes.Equal(s.Serialize(), sigs[0]) {
|
|
|
+ t.Errorf("unexpected signature: got %v, want %v", sigs[0], s.Serialize())
|
|
|
+ }
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|