Browse Source

tka: truncate long rotation signature chains

When a rotation signature chain reaches a certain size, remove the
oldest rotation signature from the chain before wrapping it in a new
rotation signature.

Since all previous rotation signatures are signed by the same wrapping
pubkey (node's own tailnet lock key), the node can re-construct the
chain, re-signing previous rotation signatures. This will satisfy the
existing certificate validation logic.

Updates #13185

Signed-off-by: Anton Tolchanov <[email protected]>
Anton Tolchanov 1 year ago
parent
commit
fd6686d81a
4 changed files with 221 additions and 11 deletions
  1. 11 10
      ipn/ipnlocal/network-lock.go
  2. 25 0
      ipn/ipnlocal/network-lock_test.go
  3. 51 1
      tka/sig.go
  4. 134 0
      tka/sig_test.go

+ 11 - 10
ipn/ipnlocal/network-lock.go

@@ -175,23 +175,24 @@ func (r *rotationTracker) addRotationDetails(np key.NodePublic, d *tka.RotationD
 // obsoleteKeys returns the set of node keys that are obsolete due to key rotation.
 func (r *rotationTracker) obsoleteKeys() set.Set[key.NodePublic] {
 	for _, v := range r.byWrappingKey {
+		// Do not consider signatures for keys that have been marked as obsolete
+		// by another signature.
+		v = slices.DeleteFunc(v, func(rd sigRotationDetails) bool {
+			return r.obsolete.Contains(rd.np)
+		})
+		if len(v) == 0 {
+			continue
+		}
+
 		// If there are multiple rotation signatures with the same wrapping
 		// pubkey, we need to decide which one is the "latest", and keep it.
 		// The signature with the largest number of previous keys is likely to
-		// be the latest, unless it has been marked as obsolete (rotated out) by
-		// another signature (which might happen in the future if we start
-		// compacting long rotated signature chains).
+		// be the latest.
 		slices.SortStableFunc(v, func(a, b sigRotationDetails) int {
-			// Group all obsolete keys after non-obsolete keys.
-			if ao, bo := r.obsolete.Contains(a.np), r.obsolete.Contains(b.np); ao != bo {
-				if ao {
-					return 1
-				}
-				return -1
-			}
 			// Sort by decreasing number of previous keys.
 			return b.numPrevKeys - a.numPrevKeys
 		})
+
 		// If there are several signatures with the same number of previous
 		// keys, we cannot determine which one is the latest, so all of them are
 		// rejected for safety.

+ 25 - 0
ipn/ipnlocal/network-lock_test.go

@@ -667,6 +667,31 @@ func TestTKAFilterNetmap(t *testing.T) {
 	if diff := cmp.Diff(want, nm.Peers, nodePubComparer); diff != "" {
 		t.Errorf("filtered netmap differs (-want, +got):\n%s", diff)
 	}
+
+	// Confirm that repeated rotation works correctly.
+	for range 100 {
+		n5Rotated, n5RotatedSig = resign(n5nl, n5RotatedSig)
+	}
+
+	n51, n51Sig := resign(n5nl, n5RotatedSig)
+
+	nm = &netmap.NetworkMap{
+		Peers: nodeViews([]*tailcfg.Node{
+			{ID: 1, Key: n1.Public(), KeySignature: n1GoodSig.Serialize()},
+			{ID: 5, Key: n5Rotated.Public(), KeySignature: n5RotatedSig}, // rotated
+			{ID: 51, Key: n51.Public(), KeySignature: n51Sig},
+		}),
+	}
+
+	b.tkaFilterNetmapLocked(nm)
+
+	want = nodeViews([]*tailcfg.Node{
+		{ID: 1, Key: n1.Public(), KeySignature: n1GoodSig.Serialize()},
+		{ID: 51, Key: n51.Public(), KeySignature: n51Sig},
+	})
+	if diff := cmp.Diff(want, nm.Peers, nodePubComparer); diff != "" {
+		t.Errorf("filtered netmap differs (-want, +got):\n%s", diff)
+	}
 }
 
 func TestTKADisable(t *testing.T) {

+ 51 - 1
tka/sig.go

@@ -372,10 +372,15 @@ func ResignNKS(priv key.NLPrivate, nodeKey key.NodePublic, oldNKS tkatype.Marsha
 		return oldNKS, nil
 	}
 
+	nested, err := maybeTrimRotationSignatureChain(oldSig, priv)
+	if err != nil {
+		return nil, fmt.Errorf("trimming rotation signature chain: %w", err)
+	}
+
 	newSig := NodeKeySignature{
 		SigKind: SigRotation,
 		Pubkey:  nk,
-		Nested:  &oldSig,
+		Nested:  &nested,
 	}
 	if newSig.Signature, err = priv.SignNKS(newSig.SigHash()); err != nil {
 		return nil, fmt.Errorf("signing NKS: %w", err)
@@ -384,6 +389,51 @@ func ResignNKS(priv key.NLPrivate, nodeKey key.NodePublic, oldNKS tkatype.Marsha
 	return newSig.Serialize(), nil
 }
 
+// maybeTrimRotationSignatureChain truncates rotation signature chain to ensure
+// it contains no more than 15 node keys.
+func maybeTrimRotationSignatureChain(sig NodeKeySignature, priv key.NLPrivate) (NodeKeySignature, error) {
+	if sig.SigKind != SigRotation {
+		return sig, nil
+	}
+
+	// Collect all the previous node keys, ordered from newest to oldest.
+	prevPubkeys := [][]byte{sig.Pubkey}
+	nested := sig.Nested
+	for nested != nil {
+		if len(nested.Pubkey) > 0 {
+			prevPubkeys = append(prevPubkeys, nested.Pubkey)
+		}
+		if nested.SigKind != SigRotation {
+			break
+		}
+		nested = nested.Nested
+	}
+
+	// Existing rotation signature with 15 keys is the maximum we can wrap in a
+	// new signature without hitting the CBOR nesting limit of 16 (see
+	// MaxNestedLevels in tka.go).
+	const maxPrevKeys = 15
+	if len(prevPubkeys) <= maxPrevKeys {
+		return sig, nil
+	}
+
+	// Create a new rotation signature chain, starting with the original
+	// direct signature.
+	var err error
+	result := nested // original direct signature
+	for i := maxPrevKeys - 2; i >= 0; i-- {
+		result = &NodeKeySignature{
+			SigKind: SigRotation,
+			Pubkey:  prevPubkeys[i],
+			Nested:  result,
+		}
+		if result.Signature, err = priv.SignNKS(result.SigHash()); err != nil {
+			return sig, fmt.Errorf("signing NKS: %w", err)
+		}
+	}
+	return *result, nil
+}
+
 // SignByCredential signs a node public key by a private key which has its
 // signing authority delegated by a SigCredential signature. This is used by
 // wrapped auth keys.

+ 134 - 0
tka/sig_test.go

@@ -9,7 +9,9 @@ import (
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
 	"tailscale.com/types/key"
+	"tailscale.com/types/tkatype"
 )
 
 func TestSigDirect(t *testing.T) {
@@ -74,6 +76,9 @@ func TestSigNested(t *testing.T) {
 	if err := nestedSig.verifySignature(oldNode.Public(), k); err != nil {
 		t.Fatalf("verifySignature(oldNode) failed: %v", err)
 	}
+	if l := sigChainLength(nestedSig); l != 1 {
+		t.Errorf("nestedSig chain length = %v, want 1", l)
+	}
 
 	// The signature authorizing the rotation, signed by the
 	// rotation key & embedding the original signature.
@@ -88,6 +93,9 @@ func TestSigNested(t *testing.T) {
 	if err := sig.verifySignature(node.Public(), k); err != nil {
 		t.Fatalf("verifySignature(node) failed: %v", err)
 	}
+	if l := sigChainLength(sig); l != 2 {
+		t.Errorf("sig chain length = %v, want 2", l)
+	}
 
 	// Test verification fails if the wrong verification key is provided
 	kBad := Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}, Votes: 2}
@@ -497,3 +505,129 @@ func TestDecodeWrappedAuthkey(t *testing.T) {
 	}
 
 }
+
+func TestResignNKS(t *testing.T) {
+	// Tailnet lock keypair of a signing node.
+	authPub, authPriv := testingKey25519(t, 1)
+	authKey := Key{Kind: Key25519, Public: authPub, Votes: 2}
+
+	// Node's own tailnet lock key used to sign rotation signatures.
+	tlPriv := key.NewNLPrivate()
+
+	// The original (oldest) node key, signed by a signing node.
+	origNode := key.NewNode()
+	origPub, _ := origNode.Public().MarshalBinary()
+
+	// The original signature for the old node key, signed by
+	// the network-lock key.
+	directSig := NodeKeySignature{
+		SigKind:        SigDirect,
+		KeyID:          authKey.MustID(),
+		Pubkey:         origPub,
+		WrappingPubkey: tlPriv.Public().Verifier(),
+	}
+	sigHash := directSig.SigHash()
+	directSig.Signature = ed25519.Sign(authPriv, sigHash[:])
+	if err := directSig.verifySignature(origNode.Public(), authKey); err != nil {
+		t.Fatalf("verifySignature(origNode) failed: %v", err)
+	}
+
+	// Generate a bunch of node keys to be used by tests.
+	var nodeKeys []key.NodePublic
+	for range 20 {
+		n := key.NewNode()
+		nodeKeys = append(nodeKeys, n.Public())
+	}
+
+	// mkSig creates a signature chain starting with a direct signature
+	// with rotation signatures matching provided keys (from the nodeKeys slice).
+	mkSig := func(prevKeyIDs ...int) tkatype.MarshaledSignature {
+		sig := &directSig
+		for _, i := range prevKeyIDs {
+			pk, _ := nodeKeys[i].MarshalBinary()
+			sig = &NodeKeySignature{
+				SigKind: SigRotation,
+				Pubkey:  pk,
+				Nested:  sig,
+			}
+			var err error
+			sig.Signature, err = tlPriv.SignNKS(sig.SigHash())
+			if err != nil {
+				t.Error(err)
+			}
+		}
+		return sig.Serialize()
+	}
+
+	tests := []struct {
+		name             string
+		oldSig           tkatype.MarshaledSignature
+		wantPrevNodeKeys []key.NodePublic
+	}{
+		{
+			name:             "first-rotation",
+			oldSig:           directSig.Serialize(),
+			wantPrevNodeKeys: []key.NodePublic{origNode.Public()},
+		},
+		{
+			name:             "second-rotation",
+			oldSig:           mkSig(0),
+			wantPrevNodeKeys: []key.NodePublic{nodeKeys[0], origNode.Public()},
+		},
+		{
+			name:   "truncate-chain",
+			oldSig: mkSig(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14),
+			wantPrevNodeKeys: []key.NodePublic{
+				nodeKeys[14],
+				nodeKeys[13],
+				nodeKeys[12],
+				nodeKeys[11],
+				nodeKeys[10],
+				nodeKeys[9],
+				nodeKeys[8],
+				nodeKeys[7],
+				nodeKeys[6],
+				nodeKeys[5],
+				nodeKeys[4],
+				nodeKeys[3],
+				nodeKeys[2],
+				nodeKeys[1],
+				origNode.Public(),
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			newNode := key.NewNode()
+			got, err := ResignNKS(tlPriv, newNode.Public(), tt.oldSig)
+			if err != nil {
+				t.Fatalf("ResignNKS() error = %v", err)
+			}
+			var gotSig NodeKeySignature
+			if err := gotSig.Unserialize(got); err != nil {
+				t.Fatalf("Unserialize() failed: %v", err)
+			}
+			if err := gotSig.verifySignature(newNode.Public(), authKey); err != nil {
+				t.Errorf("verifySignature(newNode) error: %v", err)
+			}
+
+			rd, err := gotSig.rotationDetails()
+			if err != nil {
+				t.Fatalf("rotationDetails() error = %v", err)
+			}
+			if sigChainLength(gotSig) != len(tt.wantPrevNodeKeys)+1 {
+				t.Errorf("sigChainLength() = %v, want %v", sigChainLength(gotSig), len(tt.wantPrevNodeKeys)+1)
+			}
+			if diff := cmp.Diff(tt.wantPrevNodeKeys, rd.PrevNodeKeys, cmpopts.EquateComparable(key.NodePublic{})); diff != "" {
+				t.Errorf("PrevNodeKeys mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+func sigChainLength(s NodeKeySignature) int {
+	if s.Nested != nil {
+		return 1 + sigChainLength(*s.Nested)
+	}
+	return 1
+}