Преглед изворни кода

util/mak: move tailssh's mapSet into a new package for reuse elsewhere

Change-Id: Idfe95db82275fd2be6ca88f245830731a0d5aecf
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick пре 3 година
родитељ
комит
910ae68e0b

+ 1 - 0
cmd/tailscaled/depaware.txt

@@ -264,6 +264,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
   LW    tailscale.com/util/endian                                    from tailscale.com/net/dns+
         tailscale.com/util/groupmember                               from tailscale.com/ipn/ipnserver
         tailscale.com/util/lineread                                  from tailscale.com/hostinfo+
+        tailscale.com/util/mak                                       from tailscale.com/control/controlclient+
         tailscale.com/util/multierr                                  from tailscale.com/cmd/tailscaled+
         tailscale.com/util/netconv                                   from tailscale.com/wgengine/magicsock
         tailscale.com/util/osshare                                   from tailscale.com/cmd/tailscaled+

+ 2 - 4
control/controlclient/noise.go

@@ -20,6 +20,7 @@ import (
 	"tailscale.com/control/controlhttp"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
+	"tailscale.com/util/mak"
 	"tailscale.com/util/multierr"
 )
 
@@ -137,9 +138,6 @@ func (nc *noiseClient) Close() error {
 func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
 	nc.mu.Lock()
 	connID := nc.nextID
-	if nc.connPool == nil {
-		nc.connPool = make(map[int]*noiseConn)
-	}
 	nc.nextID++
 	nc.mu.Unlock()
 
@@ -161,6 +159,6 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
 	nc.mu.Lock()
 	defer nc.mu.Unlock()
 	ncc := &noiseConn{Conn: conn, id: connID, pool: nc}
-	nc.connPool[ncc.id] = ncc
+	mak.Set(&nc.connPool, ncc.id, ncc)
 	return ncc, nil
 }

+ 2 - 4
ipn/store/stores.go

@@ -21,6 +21,7 @@ import (
 	"tailscale.com/ipn/store/mem"
 	"tailscale.com/paths"
 	"tailscale.com/types/logger"
+	"tailscale.com/util/mak"
 )
 
 // Provider returns a StateStore for the provided path.
@@ -82,10 +83,7 @@ func Register(prefix string, fn Provider) {
 	if _, ok := knownStores[prefix]; ok {
 		panic(fmt.Sprintf("%q already registered", prefix))
 	}
-	if knownStores == nil {
-		knownStores = make(map[string]Provider)
-	}
-	knownStores[prefix] = fn
+	mak.Set(&knownStores, prefix, fn)
 }
 
 // TryWindowsAppDataMigration attempts to copy the Windows state file

+ 4 - 11
ssh/tailssh/tailssh.go

@@ -40,6 +40,7 @@ import (
 	"tailscale.com/tailcfg"
 	"tailscale.com/tempfork/gliderlabs/ssh"
 	"tailscale.com/types/logger"
+	"tailscale.com/util/mak"
 )
 
 var (
@@ -471,7 +472,7 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
 
 	srv.mu.Lock()
 	defer srv.mu.Unlock()
-	mapSet(&srv.fetchPublicKeysCache, url, pubKeyCacheEntry{
+	mak.Set(&srv.fetchPublicKeysCache, url, pubKeyCacheEntry{
 		at:    srv.now(),
 		lines: lines,
 		etag:  etag,
@@ -731,8 +732,8 @@ func (srv *server) startSession(ss *sshSession) {
 	if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup {
 		panic("dup sharedID")
 	}
-	mapSet(&srv.activeSessionByH, ss.idH, ss)
-	mapSet(&srv.activeSessionBySharedID, ss.sharedID, ss)
+	mak.Set(&srv.activeSessionByH, ss.idH, ss)
+	mak.Set(&srv.activeSessionBySharedID, ss.sharedID, ss)
 }
 
 // endSession unregisters s from the list of active sessions.
@@ -1248,11 +1249,3 @@ func envEq(a, b string) bool {
 	}
 	return a == b
 }
-
-// mapSet assigns m[k] = v, making m if necessary.
-func mapSet[K comparable, V any](m *map[K]V, k K, v V) {
-	if *m == nil {
-		*m = make(map[K]V)
-	}
-	(*m)[k] = v
-}

+ 53 - 0
util/mak/mak.go

@@ -0,0 +1,53 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package mak helps make maps. It contains generic helpers to make/assign
+// things, notably to maps, but also slices.
+package mak
+
+import (
+	"fmt"
+	"reflect"
+)
+
+// Set populates an entry in a map, making the map if necessary.
+//
+// That is, it assigns (*m)[k] = v, making *m if it was nil.
+func Set[K comparable, V any, T ~map[K]V](m *T, k K, v V) {
+	if *m == nil {
+		*m = make(map[K]V)
+	}
+	(*m)[k] = v
+}
+
+// NonNil takes a pointer to a Go data structure
+// (currently only a slice or a map) and makes sure it's non-nil for
+// JSON serialization. (In particular, JavaScript clients usually want
+// the field to be defined after they decode the JSON.)
+// MakeNonNil takes a pointer to a Go data structure
+// (currently only a slice or a map) and makes sure it's non-nil for
+// JSON serialization. (In particular, JavaScript clients usually want
+// the field to be defined after they decode the JSON.)
+func NonNil(ptr interface{}) {
+	if ptr == nil {
+		panic("nil interface")
+	}
+	rv := reflect.ValueOf(ptr)
+	if rv.Kind() != reflect.Ptr {
+		panic(fmt.Sprintf("kind %v, not Ptr", rv.Kind()))
+	}
+	if rv.Pointer() == 0 {
+		panic("nil pointer")
+	}
+	rv = rv.Elem()
+	if rv.Pointer() != 0 {
+		return
+	}
+	switch rv.Type().Kind() {
+	case reflect.Slice:
+		rv.Set(reflect.MakeSlice(rv.Type(), 0, 0))
+	case reflect.Map:
+		rv.Set(reflect.MakeMap(rv.Type()))
+	}
+}

+ 71 - 0
util/mak/mak_test.go

@@ -0,0 +1,71 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package mak contains code to help make things.
+package mak
+
+import (
+	"reflect"
+	"testing"
+)
+
+type M map[string]int
+
+func TestSet(t *testing.T) {
+	t.Run("unnamed", func(t *testing.T) {
+		var m map[string]int
+		Set(&m, "foo", 42)
+		Set(&m, "bar", 1)
+		Set(&m, "bar", 2)
+		want := map[string]int{
+			"foo": 42,
+			"bar": 2,
+		}
+		if got := m; !reflect.DeepEqual(got, want) {
+			t.Errorf("got %v; want %v", got, want)
+		}
+	})
+	t.Run("named", func(t *testing.T) {
+		var m M
+		Set(&m, "foo", 1)
+		Set(&m, "bar", 1)
+		Set(&m, "bar", 2)
+		want := M{
+			"foo": 1,
+			"bar": 2,
+		}
+		if got := m; !reflect.DeepEqual(got, want) {
+			t.Errorf("got %v; want %v", got, want)
+		}
+	})
+}
+
+func TestNonNil(t *testing.T) {
+	var s []string
+	NonNil(&s)
+	if len(s) != 0 {
+		t.Errorf("slice len = %d; want 0", len(s))
+	}
+	if s == nil {
+		t.Error("slice still nil")
+	}
+
+	s = append(s, "foo")
+	NonNil(&s)
+	if len(s) != 1 {
+		t.Errorf("len = %d; want 1", len(s))
+	}
+	if s[0] != "foo" {
+		t.Errorf("value = %q; want foo", s)
+	}
+
+	var m map[string]string
+	NonNil(&m)
+	if len(m) != 0 {
+		t.Errorf("map len = %d; want 0", len(s))
+	}
+	if m == nil {
+		t.Error("map still nil")
+	}
+}

+ 5 - 11
wgengine/magicsock/magicsock.go

@@ -55,6 +55,7 @@ import (
 	"tailscale.com/types/nettype"
 	"tailscale.com/util/clientmetric"
 	"tailscale.com/util/netconv"
+	"tailscale.com/util/mak"
 	"tailscale.com/util/uniq"
 	"tailscale.com/version"
 	"tailscale.com/wgengine/monitor"
@@ -438,11 +439,7 @@ func (c *Conn) removeDerpPeerRoute(peer key.NodePublic, derpID int, dc *derphttp
 func (c *Conn) addDerpPeerRoute(peer key.NodePublic, derpID int, dc *derphttp.Client) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
-	if c.derpRoute == nil {
-		c.derpRoute = make(map[key.NodePublic]derpRoute)
-	}
-	r := derpRoute{derpID, dc}
-	c.derpRoute[peer] = r
+	mak.Set(&c.derpRoute, peer, derpRoute{derpID, dc})
 }
 
 // DerpMagicIP is a fake WireGuard endpoint IP address that means
@@ -1050,7 +1047,7 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro
 		}, nil
 	}
 
-	already := make(map[netaddr.IPPort]tailcfg.EndpointType) // endpoint -> how it was found
+	var already map[netaddr.IPPort]tailcfg.EndpointType // endpoint -> how it was found
 	var eps []tailcfg.Endpoint                               // unique endpoints
 
 	ipp := func(s string) (ipp netaddr.IPPort) {
@@ -1062,7 +1059,7 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro
 			return
 		}
 		if _, ok := already[ipp]; !ok {
-			already[ipp] = et
+			mak.Set(&already, ipp, et)
 			eps = append(eps, tailcfg.Endpoint{Addr: ipp, Type: et})
 		}
 	}
@@ -3957,9 +3954,6 @@ func (de *endpoint) handleCallMeMaybe(m *disco.CallMeMaybe) {
 	for ep := range de.isCallMeMaybeEP {
 		de.isCallMeMaybeEP[ep] = false // mark for deletion
 	}
-	if de.isCallMeMaybeEP == nil {
-		de.isCallMeMaybeEP = map[netaddr.IPPort]bool{}
-	}
 	var newEPs []netaddr.IPPort
 	for _, ep := range m.MyNumber {
 		if ep.IP().Is6() && ep.IP().IsLinkLocalUnicast() {
@@ -3968,7 +3962,7 @@ func (de *endpoint) handleCallMeMaybe(m *disco.CallMeMaybe) {
 			// for these.
 			continue
 		}
-		de.isCallMeMaybeEP[ep] = true
+		mak.Set(&de.isCallMeMaybeEP, ep, true)
 		if es, ok := de.endpointState[ep]; ok {
 			es.callMeMaybeTime = now
 		} else {

+ 2 - 4
wgengine/pendopen.go

@@ -15,6 +15,7 @@ import (
 	"tailscale.com/net/tsaddr"
 	"tailscale.com/net/tstun"
 	"tailscale.com/types/ipproto"
+	"tailscale.com/util/mak"
 	"tailscale.com/wgengine/filter"
 )
 
@@ -115,14 +116,11 @@ func (e *userspaceEngine) trackOpenPostFilterOut(pp *packet.Parsed, t *tstun.Wra
 
 	e.mu.Lock()
 	defer e.mu.Unlock()
-	if e.pendOpen == nil {
-		e.pendOpen = make(map[flowtrack.Tuple]*pendingOpenFlow)
-	}
 	if _, dup := e.pendOpen[flow]; dup {
 		// Duplicates are expected when the OS retransmits. Ignore.
 		return
 	}
-	e.pendOpen[flow] = &pendingOpenFlow{timer: timer}
+	mak.Set(&e.pendOpen, flow, &pendingOpenFlow{timer: timer})
 
 	return filter.Accept
 }