Просмотр исходного кода

control/controlclient: make Status.Persist a PersistView

Signed-off-by: Maisem Ali <[email protected]>
Maisem Ali 3 лет назад
Родитель
Сommit
f00a49667d

+ 2 - 2
control/controlclient/auto.go

@@ -580,7 +580,7 @@ func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkM
 
 	c.logf("[v1] sendStatus: %s: %v", who, state)
 
-	var p *persist.Persist
+	var p *persist.PersistView
 	var loginFin, logoutFin *empty.Message
 	if state == StateAuthenticated {
 		loginFin = new(empty.Message)
@@ -590,7 +590,7 @@ func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkM
 	}
 	if nm != nil && loggedIn && synced {
 		pp := c.direct.GetPersist()
-		p = pp.AsStruct()
+		p = &pp
 	} else {
 		// don't send netmap status, as it's misleading when we're
 		// not logged in.

+ 8 - 8
control/controlclient/direct.go

@@ -87,7 +87,7 @@ type Direct struct {
 	sfGroup     singleflight.Group[struct{}, *NoiseClient] // protects noiseClient creation.
 	noiseClient *NoiseClient
 
-	persist       persist.Persist
+	persist       persist.PersistView
 	authKey       string
 	tryingNewKey  key.NodePrivate
 	expiry        *time.Time
@@ -238,7 +238,7 @@ func NewDirect(opts Options) (*Direct, error) {
 		logf:                   opts.Logf,
 		newDecompressor:        opts.NewDecompressor,
 		keepAlive:              opts.KeepAlive,
-		persist:                opts.Persist,
+		persist:                opts.Persist.View(),
 		authKey:                opts.AuthKey,
 		discoPubKey:            opts.DiscoPublicKey,
 		debugFlags:             opts.DebugFlags,
@@ -336,7 +336,7 @@ func (c *Direct) SetTKAHead(tkaHead string) bool {
 func (c *Direct) GetPersist() persist.PersistView {
 	c.mu.Lock()
 	defer c.mu.Unlock()
-	return c.persist.View()
+	return c.persist
 }
 
 func (c *Direct) TryLogout(ctx context.Context) error {
@@ -346,7 +346,7 @@ func (c *Direct) TryLogout(ctx context.Context) error {
 	c.logf("[v1] TryLogout control response: mustRegen=%v, newURL=%v, err=%v", mustRegen, newURL, err)
 
 	c.mu.Lock()
-	c.persist = persist.Persist{}
+	c.persist = new(persist.Persist).View()
 	c.mu.Unlock()
 
 	return err
@@ -421,7 +421,7 @@ func (c *Direct) hostInfoLocked() *tailcfg.Hostinfo {
 
 func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, newURL string, nks tkatype.MarshaledSignature, err error) {
 	c.mu.Lock()
-	persist := c.persist
+	persist := c.persist.AsStruct()
 	tryingNewKey := c.tryingNewKey
 	serverKey := c.serverKey
 	serverNoiseKey := c.serverNoiseKey
@@ -660,7 +660,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
 		// save it for the retry-with-URL
 		c.tryingNewKey = tryingNewKey
 	}
-	c.persist = persist
+	c.persist = persist.View()
 	c.mu.Unlock()
 
 	if err != nil {
@@ -823,7 +823,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
 		return errors.New("getMachinePrivKey returned zero key")
 	}
 
-	if persist.PrivateNodeKey.IsZero() {
+	if persist.PrivateNodeKey().IsZero() {
 		return errors.New("privateNodeKey is zero")
 	}
 	if backendLogID == "" {
@@ -967,7 +967,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, readOnly bool
 		}
 	}()
 
-	sess := newMapSession(persist.PrivateNodeKey)
+	sess := newMapSession(persist.PrivateNodeKey())
 	sess.logf = c.logf
 	sess.vlogf = vlogf
 	sess.machinePubKey = machinePubKey

+ 1 - 1
control/controlclient/status.go

@@ -75,7 +75,7 @@ type Status struct {
 	// use them. Please don't use these fields.
 	// TODO(apenwarr): Unexport or remove these.
 	State   State
-	Persist *persist.Persist // locally persisted configuration
+	Persist *persist.PersistView // locally persisted configuration
 }
 
 // Equal reports whether s and s2 are equal.

+ 3 - 3
ipn/ipnlocal/local.go

@@ -818,10 +818,10 @@ func (b *LocalBackend) setClientStatus(st controlclient.Status) {
 		prefs.ControlURL = prefs.ControlURLOrDefault()
 		prefsChanged = true
 	}
-	if st.Persist != nil {
-		if !prefs.Persist.Equals(st.Persist) {
+	if st.Persist != nil && st.Persist.Valid() {
+		if !prefs.Persist.View().Equals(*st.Persist) {
 			prefsChanged = true
-			prefs.Persist = st.Persist.Clone()
+			prefs.Persist = st.Persist.AsStruct()
 		}
 	}
 	if st.URL != "" {

+ 2 - 1
ipn/ipnlocal/state_test.go

@@ -139,10 +139,11 @@ func (cc *mockControl) populateKeys() (newKeys bool) {
 // (In our tests here, upstream is the ipnlocal.Local instance.)
 func (cc *mockControl) send(err error, url string, loginFinished bool, nm *netmap.NetworkMap) {
 	if cc.statusFunc != nil {
+		pv := cc.persist.View()
 		s := controlclient.Status{
 			URL:     url,
 			NetMap:  nm,
-			Persist: cc.persist,
+			Persist: &pv,
 			Err:     err,
 		}
 		if loginFinished {

+ 9 - 0
types/persist/persist.go

@@ -43,6 +43,15 @@ func (p *Persist) PublicNodeKey() key.NodePublic {
 	return p.PrivateNodeKey.Public()
 }
 
+// PublicNodeKey returns the public key for the node key.
+func (p PersistView) PublicNodeKey() key.NodePublic {
+	return p.ж.PublicNodeKey()
+}
+
+func (p PersistView) Equals(p2 PersistView) bool {
+	return p.ж.Equals(p2.ж)
+}
+
 func (p *Persist) Equals(p2 *Persist) bool {
 	if p == nil && p2 == nil {
 		return true