Browse Source

ipn,types/persist: add DisallowedTKAStateIDs, refactor as view type

Supercedes https://github.com/tailscale/tailscale/pull/6557, precursor to trying https://github.com/tailscale/tailscale/pull/6546 again

Signed-off-by: Tom DNetto <[email protected]>
Tom DNetto 3 years ago
parent
commit
c4980f33f7

+ 1 - 4
ipn/ipn_clone.go

@@ -24,10 +24,7 @@ func (src *Prefs) Clone() *Prefs {
 	*dst = *src
 	dst.AdvertiseTags = append(src.AdvertiseTags[:0:0], src.AdvertiseTags...)
 	dst.AdvertiseRoutes = append(src.AdvertiseRoutes[:0:0], src.AdvertiseRoutes...)
-	if dst.Persist != nil {
-		dst.Persist = new(persist.Persist)
-		*dst.Persist = *src.Persist
-	}
+	dst.Persist = src.Persist.Clone()
 	return dst
 }
 

+ 1 - 7
ipn/ipn_view.go

@@ -87,13 +87,7 @@ func (v PrefsView) NoSNAT() bool                          { return v.ж.NoSNAT }
 func (v PrefsView) NetfilterMode() preftype.NetfilterMode { return v.ж.NetfilterMode }
 func (v PrefsView) OperatorUser() string                  { return v.ж.OperatorUser }
 func (v PrefsView) ProfileName() string                   { return v.ж.ProfileName }
-func (v PrefsView) Persist() *persist.Persist {
-	if v.ж.Persist == nil {
-		return nil
-	}
-	x := *v.ж.Persist
-	return &x
-}
+func (v PrefsView) Persist() persist.PersistView          { return v.ж.Persist.View() }
 
 // A compilation failure here means this code must be regenerated, with the command at the top of this file.
 var _PrefsViewNeedsRegeneration = Prefs(struct {

+ 10 - 10
ipn/ipnlocal/local.go

@@ -517,7 +517,7 @@ func (b *LocalBackend) Shutdown() {
 }
 
 func stripKeysFromPrefs(p ipn.PrefsView) ipn.PrefsView {
-	if !p.Valid() || p.Persist() == nil {
+	if !p.Valid() || !p.Persist().Valid() {
 		return p
 	}
 
@@ -816,7 +816,7 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) {
 	b.mu.Lock()
 
 	if st.LogoutFinished != nil {
-		if p := b.pm.CurrentPrefs(); p.Persist() == nil || p.Persist().LoginName == "" {
+		if p := b.pm.CurrentPrefs(); !p.Persist().Valid() || p.Persist().LoginName() == "" {
 			b.mu.Unlock()
 			return
 		}
@@ -1203,7 +1203,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
 	if opts.UpdatePrefs != nil {
 		oldPrefs := b.pm.CurrentPrefs()
 		newPrefs := opts.UpdatePrefs.Clone()
-		newPrefs.Persist = oldPrefs.Persist()
+		newPrefs.Persist = oldPrefs.Persist().AsStruct()
 		pv := newPrefs.View()
 		if err := b.pm.SetPrefs(pv); err != nil {
 			b.logf("failed to save UpdatePrefs state: %v", err)
@@ -1228,7 +1228,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
 	b.applyPrefsToHostinfoLocked(hostinfo, prefs)
 
 	b.setNetMapLocked(nil)
-	persistv := prefs.Persist()
+	persistv := prefs.Persist().AsStruct()
 	if persistv == nil {
 		persistv = new(persist.Persist)
 	}
@@ -1947,8 +1947,8 @@ func (b *LocalBackend) initMachineKeyLocked() (err error) {
 	}
 
 	var legacyMachineKey key.MachinePrivate
-	if p := b.pm.CurrentPrefs().Persist(); p != nil {
-		legacyMachineKey = p.LegacyFrontendPrivateMachineKey
+	if p := b.pm.CurrentPrefs().Persist(); p.Valid() {
+		legacyMachineKey = p.LegacyFrontendPrivateMachineKey()
 	}
 
 	keyText, err := b.store.ReadState(ipn.MachineKeyStateKey)
@@ -2481,7 +2481,7 @@ func (b *LocalBackend) setPrefsLockedOnEntry(caller string, newp *ipn.Prefs) ipn
 
 	oldp := b.pm.CurrentPrefs()
 	if oldp.Valid() {
-		newp.Persist = oldp.Persist().Clone() // caller isn't allowed to override this
+		newp.Persist = oldp.Persist().AsStruct() // caller isn't allowed to override this
 	}
 	// findExitNodeIDLocked returns whether it updated b.prefs, but
 	// everything in this function treats b.prefs as completely new
@@ -3338,7 +3338,7 @@ func (b *LocalBackend) hasNodeKey() bool {
 	b.mu.Lock()
 	defer b.mu.Unlock()
 	p := b.pm.CurrentPrefs()
-	return p.Valid() && p.Persist() != nil && !p.Persist().PrivateNodeKey.IsZero()
+	return p.Valid() && p.Persist().Valid() && !p.Persist().PrivateNodeKey().IsZero()
 }
 
 // nextState returns the state the backend seems to be in, based on
@@ -3926,8 +3926,8 @@ func (b *LocalBackend) SetDNS(ctx context.Context, name, value string) error {
 
 	b.mu.Lock()
 	cc := b.ccAuto
-	if prefs := b.pm.CurrentPrefs(); prefs.Valid() {
-		req.NodeKey = prefs.Persist().PrivateNodeKey.Public()
+	if prefs := b.pm.CurrentPrefs(); prefs.Valid() && prefs.Persist().Valid() {
+		req.NodeKey = prefs.Persist().PrivateNodeKey().Public()
 	}
 	b.mu.Unlock()
 	if cc == nil {

+ 10 - 10
ipn/ipnlocal/network-lock.go

@@ -345,10 +345,10 @@ func (b *LocalBackend) NetworkLockStatus() *ipnstate.NetworkLockStatus {
 		nodeKey *key.NodePublic
 		nlPriv  key.NLPrivate
 	)
-	if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist() != nil && !p.Persist().PrivateNodeKey.IsZero() {
+	if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() && !p.Persist().PrivateNodeKey().IsZero() {
 		nkp := p.Persist().PublicNodeKey()
 		nodeKey = &nkp
-		nlPriv = p.Persist().NetworkLockKey
+		nlPriv = p.Persist().NetworkLockKey()
 	}
 
 	if nlPriv.IsZero() {
@@ -411,9 +411,9 @@ func (b *LocalBackend) NetworkLockInit(keys []tka.Key, disablementValues [][]byt
 	var ourNodeKey key.NodePublic
 	var nlPriv key.NLPrivate
 	b.mu.Lock()
-	if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist() != nil && !p.Persist().PrivateNodeKey.IsZero() {
+	if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() && !p.Persist().PrivateNodeKey().IsZero() {
 		ourNodeKey = p.Persist().PublicNodeKey()
-		nlPriv = p.Persist().NetworkLockKey
+		nlPriv = p.Persist().NetworkLockKey()
 	}
 	b.mu.Unlock()
 	if ourNodeKey.IsZero() || nlPriv.IsZero() {
@@ -503,8 +503,8 @@ func (b *LocalBackend) NetworkLockSign(nodeKey key.NodePublic, rotationPublic []
 		defer b.mu.Unlock()
 
 		var nlPriv key.NLPrivate
-		if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist() != nil {
-			nlPriv = p.Persist().NetworkLockKey
+		if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() {
+			nlPriv = p.Persist().NetworkLockKey()
 		}
 		if nlPriv.IsZero() {
 			return key.NodePublic{}, tka.NodeKeySignature{}, errMissingNetmap
@@ -557,7 +557,7 @@ func (b *LocalBackend) NetworkLockModify(addKeys, removeKeys []tka.Key) (err err
 	defer b.mu.Unlock()
 
 	var ourNodeKey key.NodePublic
-	if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist() != nil && !p.Persist().PrivateNodeKey.IsZero() {
+	if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() && !p.Persist().PrivateNodeKey().IsZero() {
 		ourNodeKey = p.Persist().PublicNodeKey()
 	}
 	if ourNodeKey.IsZero() {
@@ -568,8 +568,8 @@ func (b *LocalBackend) NetworkLockModify(addKeys, removeKeys []tka.Key) (err err
 		return err
 	}
 	var nlPriv key.NLPrivate
-	if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist() != nil {
-		nlPriv = p.Persist().NetworkLockKey
+	if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() {
+		nlPriv = p.Persist().NetworkLockKey()
 	}
 	if nlPriv.IsZero() {
 		return errMissingNetmap
@@ -634,7 +634,7 @@ func (b *LocalBackend) NetworkLockDisable(secret []byte) error {
 	)
 
 	b.mu.Lock()
-	if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist() != nil && !p.Persist().PrivateNodeKey.IsZero() {
+	if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() && !p.Persist().PrivateNodeKey().IsZero() {
 		ourNodeKey = p.Persist().PublicNodeKey()
 	}
 	if b.tka == nil {

+ 1 - 1
ipn/ipnlocal/profiles.go

@@ -179,7 +179,7 @@ func init() {
 // provided prefs, which may be accessed via CurrentPrefs.
 func (pm *profileManager) SetPrefs(prefsIn ipn.PrefsView) error {
 	prefs := prefsIn.AsStruct().View()
-	newPersist := prefs.Persist()
+	newPersist := prefs.Persist().AsStruct()
 	if newPersist == nil || newPersist.LoginName == "" {
 		return pm.setPrefsLocked(prefs)
 	}

+ 12 - 12
ipn/ipnlocal/state_test.go

@@ -489,7 +489,7 @@ func TestStateMachine(t *testing.T) {
 		c.Assert(nn[0].LoginFinished, qt.IsNotNil)
 		c.Assert(nn[1].Prefs, qt.IsNotNil)
 		c.Assert(nn[2].State, qt.IsNotNil)
-		c.Assert(nn[1].Prefs.Persist().LoginName, qt.Equals, "user1")
+		c.Assert(nn[1].Prefs.Persist().LoginName(), qt.Equals, "user1")
 		c.Assert(ipn.NeedsMachineAuth, qt.Equals, *nn[2].State)
 		c.Assert(ipn.NeedsMachineAuth, qt.Equals, b.State())
 	}
@@ -711,7 +711,7 @@ func TestStateMachine(t *testing.T) {
 		c.Assert(nn[1].Prefs.Persist(), qt.IsNotNil)
 		c.Assert(nn[2].State, qt.IsNotNil)
 		// Prefs after finishing the login, so LoginName updated.
-		c.Assert(nn[1].Prefs.Persist().LoginName, qt.Equals, "user2")
+		c.Assert(nn[1].Prefs.Persist().LoginName(), qt.Equals, "user2")
 		c.Assert(nn[1].Prefs.LoggedOut(), qt.IsFalse)
 		c.Assert(nn[1].Prefs.WantRunning(), qt.IsTrue)
 		c.Assert(ipn.Starting, qt.Equals, *nn[2].State)
@@ -852,7 +852,7 @@ func TestStateMachine(t *testing.T) {
 		c.Assert(nn[1].Prefs, qt.IsNotNil)
 		c.Assert(nn[2].State, qt.IsNotNil)
 		// Prefs after finishing the login, so LoginName updated.
-		c.Assert(nn[1].Prefs.Persist().LoginName, qt.Equals, "user3")
+		c.Assert(nn[1].Prefs.Persist().LoginName(), qt.Equals, "user3")
 		c.Assert(nn[1].Prefs.LoggedOut(), qt.IsFalse)
 		c.Assert(nn[1].Prefs.WantRunning(), qt.IsTrue)
 		c.Assert(ipn.Starting, qt.Equals, *nn[2].State)
@@ -957,7 +957,7 @@ func TestEditPrefsHasNoKeys(t *testing.T) {
 			LegacyFrontendPrivateMachineKey: key.NewMachine(),
 		},
 	}).View())
-	if b.pm.CurrentPrefs().Persist().PrivateNodeKey.IsZero() {
+	if p := b.pm.CurrentPrefs().Persist(); !p.Valid() || p.PrivateNodeKey().IsZero() {
 		t.Fatalf("PrivateNodeKey not set")
 	}
 	p, err := b.EditPrefs(&ipn.MaskedPrefs{
@@ -973,20 +973,20 @@ func TestEditPrefsHasNoKeys(t *testing.T) {
 		t.Errorf("Hostname = %q; want foo", p.Hostname())
 	}
 
-	if !p.Persist().PrivateNodeKey.IsZero() {
-		t.Errorf("PrivateNodeKey = %v; want zero", p.Persist().PrivateNodeKey)
+	if !p.Persist().PrivateNodeKey().IsZero() {
+		t.Errorf("PrivateNodeKey = %v; want zero", p.Persist().PrivateNodeKey())
 	}
 
-	if !p.Persist().OldPrivateNodeKey.IsZero() {
-		t.Errorf("OldPrivateNodeKey = %v; want zero", p.Persist().OldPrivateNodeKey)
+	if !p.Persist().OldPrivateNodeKey().IsZero() {
+		t.Errorf("OldPrivateNodeKey = %v; want zero", p.Persist().OldPrivateNodeKey())
 	}
 
-	if !p.Persist().LegacyFrontendPrivateMachineKey.IsZero() {
-		t.Errorf("LegacyFrontendPrivateMachineKey = %v; want zero", p.Persist().LegacyFrontendPrivateMachineKey)
+	if !p.Persist().LegacyFrontendPrivateMachineKey().IsZero() {
+		t.Errorf("LegacyFrontendPrivateMachineKey = %v; want zero", p.Persist().LegacyFrontendPrivateMachineKey())
 	}
 
-	if !p.Persist().NetworkLockKey.IsZero() {
-		t.Errorf("NetworkLockKey= %v; want zero", p.Persist().NetworkLockKey)
+	if !p.Persist().NetworkLockKey().IsZero() {
+		t.Errorf("NetworkLockKey= %v; want zero", p.Persist().NetworkLockKey())
 	}
 }
 

+ 16 - 1
types/persist/persist.go

@@ -7,6 +7,7 @@ package persist
 
 import (
 	"fmt"
+	"reflect"
 
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
@@ -39,6 +40,12 @@ type Persist struct {
 	UserProfile       tailcfg.UserProfile
 	NetworkLockKey    key.NLPrivate
 	NodeID            tailcfg.StableNodeID
+
+	// DisallowedTKAStateIDs stores the tka.State.StateID values which
+	// this node will not operate network lock on. This is used to
+	// prevent bootstrapping TKA onto a key authority which was forcibly
+	// disabled.
+	DisallowedTKAStateIDs []string `json:",omitempty"`
 }
 
 // PublicNodeKey returns the public key for the node key.
@@ -55,6 +62,13 @@ func (p PersistView) Equals(p2 PersistView) bool {
 	return p.ж.Equals(p2.ж)
 }
 
+func nilIfEmpty[E any](s []E) []E {
+	if len(s) == 0 {
+		return nil
+	}
+	return s
+}
+
 func (p *Persist) Equals(p2 *Persist) bool {
 	if p == nil && p2 == nil {
 		return true
@@ -70,7 +84,8 @@ func (p *Persist) Equals(p2 *Persist) bool {
 		p.LoginName == p2.LoginName &&
 		p.UserProfile == p2.UserProfile &&
 		p.NetworkLockKey.Equal(p2.NetworkLockKey) &&
-		p.NodeID == p2.NodeID
+		p.NodeID == p2.NodeID &&
+		reflect.DeepEqual(nilIfEmpty(p.DisallowedTKAStateIDs), nilIfEmpty(p2.DisallowedTKAStateIDs))
 }
 
 func (p *Persist) Pretty() string {

+ 2 - 0
types/persist/persist_clone.go

@@ -20,6 +20,7 @@ func (src *Persist) Clone() *Persist {
 	}
 	dst := new(Persist)
 	*dst = *src
+	dst.DisallowedTKAStateIDs = append(src.DisallowedTKAStateIDs[:0:0], src.DisallowedTKAStateIDs...)
 	return dst
 }
 
@@ -34,4 +35,5 @@ var _PersistCloneNeedsRegeneration = Persist(struct {
 	UserProfile                     tailcfg.UserProfile
 	NetworkLockKey                  key.NLPrivate
 	NodeID                          tailcfg.StableNodeID
+	DisallowedTKAStateIDs           []string
 }{})

+ 16 - 1
types/persist/persist_test.go

@@ -22,7 +22,7 @@ func fieldsOf(t reflect.Type) (fields []string) {
 }
 
 func TestPersistEqual(t *testing.T) {
-	persistHandles := []string{"LegacyFrontendPrivateMachineKey", "PrivateNodeKey", "OldPrivateNodeKey", "Provider", "LoginName", "UserProfile", "NetworkLockKey", "NodeID"}
+	persistHandles := []string{"LegacyFrontendPrivateMachineKey", "PrivateNodeKey", "OldPrivateNodeKey", "Provider", "LoginName", "UserProfile", "NetworkLockKey", "NodeID", "DisallowedTKAStateIDs"}
 	if have := fieldsOf(reflect.TypeOf(Persist{})); !reflect.DeepEqual(have, persistHandles) {
 		t.Errorf("Persist.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
 			have, persistHandles)
@@ -133,6 +133,21 @@ func TestPersistEqual(t *testing.T) {
 			&Persist{NodeID: "abc"},
 			false,
 		},
+		{
+			&Persist{DisallowedTKAStateIDs: nil},
+			&Persist{DisallowedTKAStateIDs: []string{"0:0"}},
+			false,
+		},
+		{
+			&Persist{DisallowedTKAStateIDs: []string{"0:1"}},
+			&Persist{DisallowedTKAStateIDs: []string{"0:1"}},
+			true,
+		},
+		{
+			&Persist{DisallowedTKAStateIDs: []string{}},
+			&Persist{DisallowedTKAStateIDs: nil},
+			true,
+		},
 	}
 	for i, test := range tests {
 		if got := test.a.Equals(test.b); got != test.want {

+ 5 - 0
types/persist/persist_view.go

@@ -13,6 +13,7 @@ import (
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
 	"tailscale.com/types/structs"
+	"tailscale.com/types/views"
 )
 
 //go:generate go run tailscale.com/cmd/cloner  -clonefunc=false -type=Persist
@@ -72,6 +73,9 @@ func (v PersistView) LoginName() string                  { return v.ж.LoginName
 func (v PersistView) UserProfile() tailcfg.UserProfile   { return v.ж.UserProfile }
 func (v PersistView) NetworkLockKey() key.NLPrivate      { return v.ж.NetworkLockKey }
 func (v PersistView) NodeID() tailcfg.StableNodeID       { return v.ж.NodeID }
+func (v PersistView) DisallowedTKAStateIDs() views.Slice[string] {
+	return views.SliceOf(v.ж.DisallowedTKAStateIDs)
+}
 
 // A compilation failure here means this code must be regenerated, with the command at the top of this file.
 var _PersistViewNeedsRegeneration = Persist(struct {
@@ -84,4 +88,5 @@ var _PersistViewNeedsRegeneration = Persist(struct {
 	UserProfile                     tailcfg.UserProfile
 	NetworkLockKey                  key.NLPrivate
 	NodeID                          tailcfg.StableNodeID
+	DisallowedTKAStateIDs           []string
 }{})