Browse Source

wgengine: use key.NodePublic instead of tailcfg.NodeKey.

Updates #3206

Signed-off-by: David Anderson <[email protected]>
David Anderson 4 years ago
parent
commit
c3d7115e63
2 changed files with 38 additions and 41 deletions
  1. 27 29
      wgengine/userspace.go
  2. 11 12
      wgengine/userspace_test.go

+ 27 - 29
wgengine/userspace.go

@@ -114,8 +114,8 @@ type userspaceEngine struct {
 	lastEngineSigFull   deephash.Sum // of full wireguard config
 	lastEngineSigTrim   deephash.Sum // of trimmed wireguard config
 	lastDNSConfig       *dns.Config
-	recvActivityAt      map[tailcfg.NodeKey]mono.Time
-	trimmedNodes        map[tailcfg.NodeKey]bool  // set of node keys of peers currently excluded from wireguard config
+	recvActivityAt      map[key.NodePublic]mono.Time
+	trimmedNodes        map[key.NodePublic]bool   // set of node keys of peers currently excluded from wireguard config
 	sentActivityAt      map[netaddr.IP]*mono.Time // value is accessed atomically
 	destIPActivityFuncs map[netaddr.IP]func()
 	statusBufioReader   *bufio.Reader // reusable for UAPI
@@ -127,7 +127,7 @@ type userspaceEngine struct {
 	netMap              *netmap.NetworkMap // or nil
 	closing             bool               // Close was called (even if we're still closing)
 	statusCallback      StatusCallback
-	peerSequence        []tailcfg.NodeKey
+	peerSequence        []key.NodePublic
 	endpoints           []tailcfg.Endpoint
 	pendOpen            map[flowtrack.Tuple]*pendingOpenFlow // see pendopen.go
 	networkMapCallbacks map[*someHandle]NetworkMapCallback
@@ -554,8 +554,7 @@ func isTrimmablePeer(p *wgcfg.Peer, numPeers int) bool {
 // noteRecvActivity is called by magicsock when a packet has been
 // received for the peer with node key nk. Magicsock calls this no
 // more than every 10 seconds for a given peer.
-func (e *userspaceEngine) noteRecvActivity(k key.NodePublic) {
-	nk := k.AsNodeKey()
+func (e *userspaceEngine) noteRecvActivity(nk key.NodePublic) {
 	e.wgLock.Lock()
 	defer e.wgLock.Unlock()
 
@@ -597,7 +596,7 @@ func (e *userspaceEngine) noteRecvActivity(k key.NodePublic) {
 // has had a packet sent to or received from it since t.
 //
 // e.wgLock must be held.
-func (e *userspaceEngine) isActiveSinceLocked(nk tailcfg.NodeKey, ip netaddr.IP, t mono.Time) bool {
+func (e *userspaceEngine) isActiveSinceLocked(nk key.NodePublic, ip netaddr.IP, t mono.Time) bool {
 	if e.recvActivityAt[nk].After(t) {
 		return true
 	}
@@ -614,7 +613,7 @@ func (e *userspaceEngine) isActiveSinceLocked(nk tailcfg.NodeKey, ip netaddr.IP,
 // If discoChanged is nil or empty, this extra removal step isn't done.
 //
 // e.wgLock must be held.
-func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[tailcfg.NodeKey]bool) error {
+func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.NodePublic]bool) error {
 	if hook := e.testMaybeReconfigHook; hook != nil {
 		hook()
 		return nil
@@ -640,36 +639,35 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[tailcfg.
 	// their NodeKey and Tailscale IPs.  These are the ones we'll need
 	// to install tracking hooks for to watch their send/receive
 	// activity.
-	trackNodes := make([]tailcfg.NodeKey, 0, len(full.Peers))
+	trackNodes := make([]key.NodePublic, 0, len(full.Peers))
 	trackIPs := make([]netaddr.IP, 0, len(full.Peers))
 
-	trimmedNodes := map[tailcfg.NodeKey]bool{} // TODO: don't re-alloc this map each time
+	trimmedNodes := map[key.NodePublic]bool{} // TODO: don't re-alloc this map each time
 
 	needRemoveStep := false
 	for i := range full.Peers {
 		p := &full.Peers[i]
 		nk := p.PublicKey
-		tnk := nk.AsNodeKey()
 		if !isTrimmablePeer(p, len(full.Peers)) {
 			min.Peers = append(min.Peers, *p)
-			if discoChanged[tnk] {
+			if discoChanged[nk] {
 				needRemoveStep = true
 			}
 			continue
 		}
-		trackNodes = append(trackNodes, tnk)
+		trackNodes = append(trackNodes, nk)
 		recentlyActive := false
 		for _, cidr := range p.AllowedIPs {
 			trackIPs = append(trackIPs, cidr.IP())
-			recentlyActive = recentlyActive || e.isActiveSinceLocked(tnk, cidr.IP(), activeCutoff)
+			recentlyActive = recentlyActive || e.isActiveSinceLocked(nk, cidr.IP(), activeCutoff)
 		}
 		if recentlyActive {
 			min.Peers = append(min.Peers, *p)
-			if discoChanged[tnk] {
+			if discoChanged[nk] {
 				needRemoveStep = true
 			}
 		} else {
-			trimmedNodes[tnk] = true
+			trimmedNodes[nk] = true
 		}
 	}
 	e.lastNMinPeers = len(min.Peers)
@@ -688,7 +686,7 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[tailcfg.
 		minner.Peers = nil
 		numRemove := 0
 		for _, p := range min.Peers {
-			if discoChanged[p.PublicKey.AsNodeKey()] {
+			if discoChanged[p.PublicKey] {
 				numRemove++
 				continue
 			}
@@ -716,10 +714,10 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[tailcfg.
 // as given to wireguard-go.
 //
 // e.wgLock must be held.
-func (e *userspaceEngine) updateActivityMapsLocked(trackNodes []tailcfg.NodeKey, trackIPs []netaddr.IP) {
+func (e *userspaceEngine) updateActivityMapsLocked(trackNodes []key.NodePublic, trackIPs []netaddr.IP) {
 	// Generate the new map of which nodekeys we want to track
 	// receive times for.
-	mr := map[tailcfg.NodeKey]mono.Time{} // TODO: only recreate this if set of keys changed
+	mr := map[key.NodePublic]mono.Time{} // TODO: only recreate this if set of keys changed
 	for _, nk := range trackNodes {
 		// Preserve old times in the new map, but also
 		// populate map entries for new trackNodes values with
@@ -808,7 +806,7 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config,
 	e.mu.Lock()
 	e.peerSequence = e.peerSequence[:0]
 	for _, p := range cfg.Peers {
-		e.peerSequence = append(e.peerSequence, p.PublicKey.AsNodeKey())
+		e.peerSequence = append(e.peerSequence, p.PublicKey)
 		peerSet[p.PublicKey] = struct{}{}
 	}
 	e.mu.Unlock()
@@ -841,12 +839,12 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config,
 	// If so, we need to update the wireguard-go/device.Device in two phases:
 	// once without the node which has restarted, to clear its wireguard session key,
 	// and a second time with it.
-	discoChanged := make(map[tailcfg.NodeKey]bool)
+	discoChanged := make(map[key.NodePublic]bool)
 	{
-		prevEP := make(map[tailcfg.NodeKey]key.DiscoPublic)
+		prevEP := make(map[key.NodePublic]key.DiscoPublic)
 		for i := range e.lastCfgFull.Peers {
 			if p := &e.lastCfgFull.Peers[i]; !p.DiscoKey.IsZero() {
-				prevEP[p.PublicKey.AsNodeKey()] = p.DiscoKey
+				prevEP[p.PublicKey] = p.DiscoKey
 			}
 		}
 		for i := range cfg.Peers {
@@ -854,7 +852,7 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config,
 			if p.DiscoKey.IsZero() {
 				continue
 			}
-			pub := p.PublicKey.AsNodeKey()
+			pub := p.PublicKey
 			if old, ok := prevEP[pub]; ok && old != p.DiscoKey {
 				discoChanged[pub] = true
 				e.logf("wgengine: Reconfig: %s changed from %q to %q", pub.ShortString(), old, p.DiscoKey)
@@ -979,7 +977,7 @@ func (e *userspaceEngine) getStatus() (*Status, error) {
 		errc <- err
 	}()
 
-	pp := make(map[tailcfg.NodeKey]ipnstate.PeerStatusLite)
+	pp := make(map[key.NodePublic]ipnstate.PeerStatusLite)
 	var p ipnstate.PeerStatusLite
 
 	var hst1, hst2, n int64
@@ -1013,7 +1011,7 @@ func (e *userspaceEngine) getStatus() (*Status, error) {
 				return nil, fmt.Errorf("IpcGetOperation: invalid key in line %q", line)
 			}
 			if !p.NodeKey.IsZero() {
-				pp[p.NodeKey] = p
+				pp[p.NodeKey.AsNodePublic()] = p
 			}
 			p = ipnstate.PeerStatusLite{NodeKey: pk.AsNodeKey()}
 		case "rx_bytes":
@@ -1044,7 +1042,7 @@ func (e *userspaceEngine) getStatus() (*Status, error) {
 		}
 	}
 	if !p.NodeKey.IsZero() {
-		pp[p.NodeKey] = p
+		pp[p.NodeKey.AsNodePublic()] = p
 	}
 	if err := <-errc; err != nil {
 		return nil, fmt.Errorf("IpcGetOperation: %v", err)
@@ -1457,7 +1455,7 @@ func (e *userspaceEngine) peerForIP(ip netaddr.IP) (n *tailcfg.Node, isSelf bool
 
 	// TODO(bradfitz): this is O(n peers). Add ART to netaddr?
 	var best netaddr.IPPrefix
-	var bestKey tailcfg.NodeKey
+	var bestKey key.NodePublic
 	for _, p := range e.lastCfgFull.Peers {
 		for _, cidr := range p.AllowedIPs {
 			if !cidr.Contains(ip) {
@@ -1465,7 +1463,7 @@ func (e *userspaceEngine) peerForIP(ip netaddr.IP) (n *tailcfg.Node, isSelf bool
 			}
 			if best.IsZero() || cidr.Bits() > best.Bits() {
 				best = cidr
-				bestKey = p.PublicKey.AsNodeKey()
+				bestKey = p.PublicKey
 			}
 		}
 	}
@@ -1473,7 +1471,7 @@ func (e *userspaceEngine) peerForIP(ip netaddr.IP) (n *tailcfg.Node, isSelf bool
 	// call. But TODO(bradfitz): add a lookup map to netmap.NetworkMap.
 	if !bestKey.IsZero() {
 		for _, p := range nm.Peers {
-			if p.Key == bestKey {
+			if p.Key.AsNodePublic() == bestKey {
 				return p, false, nil
 			}
 		}

+ 11 - 12
wgengine/userspace_test.go

@@ -37,16 +37,15 @@ func TestNoteReceiveActivity(t *testing.T) {
 	}
 	e := &userspaceEngine{
 		timeNow:               func() mono.Time { return now },
-		recvActivityAt:        map[tailcfg.NodeKey]mono.Time{},
+		recvActivityAt:        map[key.NodePublic]mono.Time{},
 		logf:                  logBuf.Logf,
 		tundev:                new(tstun.Wrapper),
 		testMaybeReconfigHook: func() { confc <- true },
-		trimmedNodes:          map[tailcfg.NodeKey]bool{},
+		trimmedNodes:          map[key.NodePublic]bool{},
 	}
 	ra := e.recvActivityAt
 
 	nk := key.NewNode().Public()
-	tnk := nk.AsNodeKey()
 
 	// Activity on an untracked key should do nothing.
 	e.noteRecvActivity(nk)
@@ -58,12 +57,12 @@ func TestNoteReceiveActivity(t *testing.T) {
 	}
 
 	// Now track it, but don't mark it trimmed, so shouldn't update.
-	ra[tnk] = 0
+	ra[nk] = 0
 	e.noteRecvActivity(nk)
 	if len(ra) != 1 {
 		t.Fatalf("unexpected growth in map: now has %d keys; want 1", len(ra))
 	}
-	if got := ra[tnk]; got != now {
+	if got := ra[nk]; got != now {
 		t.Fatalf("time in map = %v; want %v", got, now)
 	}
 	if gotConf() {
@@ -71,12 +70,12 @@ func TestNoteReceiveActivity(t *testing.T) {
 	}
 
 	// Now mark it trimmed and expect an update.
-	e.trimmedNodes[tnk] = true
+	e.trimmedNodes[nk] = true
 	e.noteRecvActivity(nk)
 	if len(ra) != 1 {
 		t.Fatalf("unexpected growth in map: now has %d keys; want 1", len(ra))
 	}
-	if got := ra[tnk]; got != now {
+	if got := ra[nk]; got != now {
 		t.Fatalf("time in map = %v; want %v", got, now)
 	}
 	if !gotConf() {
@@ -101,7 +100,7 @@ func TestUserspaceEngineReconfig(t *testing.T) {
 		nm := &netmap.NetworkMap{
 			Peers: []*tailcfg.Node{
 				&tailcfg.Node{
-					Key: nkFromHex(nodeHex),
+					Key: nkFromHex(nodeHex).AsNodeKey(),
 				},
 			},
 		}
@@ -126,14 +125,14 @@ func TestUserspaceEngineReconfig(t *testing.T) {
 			t.Fatal(err)
 		}
 
-		wantRecvAt := map[tailcfg.NodeKey]mono.Time{
+		wantRecvAt := map[key.NodePublic]mono.Time{
 			nkFromHex(nodeHex): 0,
 		}
 		if got := ue.recvActivityAt; !reflect.DeepEqual(got, wantRecvAt) {
 			t.Errorf("wrong recvActivityAt\n got: %v\nwant: %v\n", got, wantRecvAt)
 		}
 
-		wantTrimmedNodes := map[tailcfg.NodeKey]bool{
+		wantTrimmedNodes := map[key.NodePublic]bool{
 			nkFromHex(nodeHex): true,
 		}
 		if got := ue.trimmedNodes; !reflect.DeepEqual(got, wantTrimmedNodes) {
@@ -210,7 +209,7 @@ func TestUserspaceEnginePortReconfig(t *testing.T) {
 	}
 }
 
-func nkFromHex(hex string) tailcfg.NodeKey {
+func nkFromHex(hex string) key.NodePublic {
 	if len(hex) != 64 {
 		panic(fmt.Sprintf("%q is len %d; want 64", hex, len(hex)))
 	}
@@ -218,7 +217,7 @@ func nkFromHex(hex string) tailcfg.NodeKey {
 	if err != nil {
 		panic(fmt.Sprintf("%q is not hex: %v", hex, err))
 	}
-	return k.AsNodeKey()
+	return k
 }
 
 // an experiment to see if genLocalAddrFunc was worth it. As of Go