Browse Source

types/mapx, ipn/ipnext: add ordered map, akin to set.Slice

We had an ordered set type (set.Slice) already but we occasionally want
to do the same thing with a map, preserving the order things were added,
so add that too, as mapsx.OrderedMap[K, V], and then use in ipnext.

Updates #12614

Change-Id: I85e6f5e11035571a28316441075e952aef9a0863
Signed-off-by: Brad Fitzpatrick <[email protected]>
Brad Fitzpatrick 10 months ago
parent
commit
dbf13976d3

+ 1 - 0
cmd/k8s-operator/depaware.txt

@@ -921,6 +921,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
         tailscale.com/types/lazy                                     from tailscale.com/ipn/ipnlocal+
         tailscale.com/types/logger                                   from tailscale.com/appc+
         tailscale.com/types/logid                                    from tailscale.com/ipn/ipnlocal+
+        tailscale.com/types/mapx                                     from tailscale.com/ipn/ipnext
         tailscale.com/types/netlogtype                               from tailscale.com/net/connstats+
         tailscale.com/types/netmap                                   from tailscale.com/control/controlclient+
         tailscale.com/types/nettype                                  from tailscale.com/ipn/localapi+

+ 1 - 0
cmd/tailscaled/depaware.txt

@@ -373,6 +373,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
         tailscale.com/types/lazy                                     from tailscale.com/ipn/ipnlocal+
         tailscale.com/types/logger                                   from tailscale.com/appc+
         tailscale.com/types/logid                                    from tailscale.com/cmd/tailscaled+
+        tailscale.com/types/mapx                                     from tailscale.com/ipn/ipnext
         tailscale.com/types/netlogtype                               from tailscale.com/net/connstats+
         tailscale.com/types/netmap                                   from tailscale.com/control/controlclient+
         tailscale.com/types/nettype                                  from tailscale.com/ipn/localapi+

+ 10 - 17
ipn/ipnext/ipnext.go

@@ -8,6 +8,7 @@ package ipnext
 import (
 	"errors"
 	"fmt"
+	"iter"
 
 	"tailscale.com/control/controlclient"
 	"tailscale.com/feature"
@@ -16,8 +17,7 @@ import (
 	"tailscale.com/tsd"
 	"tailscale.com/tstime"
 	"tailscale.com/types/logger"
-	"tailscale.com/types/views"
-	"tailscale.com/util/mak"
+	"tailscale.com/types/mapx"
 )
 
 // Extension augments LocalBackend with additional functionality.
@@ -91,13 +91,9 @@ func (d *Definition) MakeExtension(logf logger.Logf, sb SafeBackend) (Extension,
 	return ext, nil
 }
 
-// extensionsByName is a map of registered extensions,
+// extensions is a map of registered extensions,
 // where the key is the name of the extension.
-var extensionsByName map[string]*Definition
-
-// extensionsByOrder is a slice of registered extensions,
-// in the order they were registered.
-var extensionsByOrder []*Definition
+var extensions mapx.OrderedMap[string, *Definition]
 
 // RegisterExtension registers a function that instantiates an [Extension].
 // The name must be the same as returned by the extension's [Extension.Name].
@@ -111,19 +107,16 @@ func RegisterExtension(name string, newExt NewExtensionFn) {
 	if newExt == nil {
 		panic(fmt.Sprintf("ipnext: newExt is nil: %q", name))
 	}
-	if _, ok := extensionsByName[name]; ok {
+	if extensions.Contains(name) {
 		panic(fmt.Sprintf("ipnext: duplicate extensions: %q", name))
 	}
-	ext := &Definition{name, newExt}
-	mak.Set(&extensionsByName, name, ext)
-	extensionsByOrder = append(extensionsByOrder, ext)
+	extensions.Set(name, &Definition{name, newExt})
 }
 
-// Extensions returns a read-only view of the extensions
-// registered via [RegisterExtension]. It preserves the order
-// in which the extensions were registered.
-func Extensions() views.Slice[*Definition] {
-	return views.SliceOf(extensionsByOrder)
+// Extensions iterates over the extensions in the order they were registered
+// via [RegisterExtension].
+func Extensions() iter.Seq[*Definition] {
+	return extensions.Values()
 }
 
 // DefinitionForTest returns a [Definition] for the specified [Extension].

+ 3 - 6
ipn/ipnlocal/extension_host.go

@@ -162,17 +162,14 @@ func newExtensionHost(logf logger.Logf, b Backend, overrideExts ...*ipnext.Defin
 	}
 
 	// Use registered extensions.
-	exts := ipnext.Extensions().All()
-	numExts := ipnext.Extensions().Len()
+	extDef := ipnext.Extensions()
 	if overrideExts != nil {
 		// Use the provided, potentially empty, overrideExts
 		// instead of the registered ones.
-		exts = slices.All(overrideExts)
-		numExts = len(overrideExts)
+		extDef = slices.Values(overrideExts)
 	}
 
-	host.allExtensions = make([]ipnext.Extension, 0, numExts)
-	for _, d := range exts {
+	for d := range extDef {
 		ext, err := d.MakeExtension(logf, b)
 		if errors.Is(err, ipnext.SkipExtension) {
 			// The extension wants to be skipped.

+ 111 - 0
types/mapx/ordered.go

@@ -0,0 +1,111 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package mapx contains extra map types and functions.
+package mapx
+
+import (
+	"iter"
+	"slices"
+)
+
+// OrderedMap is a map that maintains the order of its keys.
+//
+// It is meant for maps that only grow or that are small;
+// is it not optimized for deleting keys.
+//
+// The zero value is ready to use.
+//
+// Locking-wise, it has the same rules as a regular Go map:
+// concurrent reads are safe, but not writes.
+type OrderedMap[K comparable, V any] struct {
+	// m is the underlying map.
+	m map[K]V
+
+	// keys is the order of keys in the map.
+	keys []K
+}
+
+func (m *OrderedMap[K, V]) init() {
+	if m.m == nil {
+		m.m = make(map[K]V)
+	}
+}
+
+// Set sets the value for the given key in the map.
+//
+// If the key already exists, it updates the value and keeps the order.
+func (m *OrderedMap[K, V]) Set(key K, value V) {
+	m.init()
+	len0 := len(m.keys)
+	m.m[key] = value
+	if len(m.m) > len0 {
+		// New key (not an update)
+		m.keys = append(m.keys, key)
+	}
+}
+
+// Get returns the value for the given key in the map.
+// If the key does not exist, it returns the zero value for V.
+func (m *OrderedMap[K, V]) Get(key K) V {
+	return m.m[key]
+}
+
+// GetOk returns the value for the given key in the map
+// and whether it was present in the map.
+func (m *OrderedMap[K, V]) GetOk(key K) (_ V, ok bool) {
+	v, ok := m.m[key]
+	return v, ok
+}
+
+// Contains reports whether the map contains the given key.
+func (m *OrderedMap[K, V]) Contains(key K) bool {
+	_, ok := m.m[key]
+	return ok
+}
+
+// Delete removes the key from the map.
+//
+// The cost is O(n) in the number of keys in the map.
+func (m *OrderedMap[K, V]) Delete(key K) {
+	len0 := len(m.m)
+	delete(m.m, key)
+	if len(m.m) == len0 {
+		// Wasn't present; no need to adjust keys.
+		return
+	}
+	was := m.keys
+	m.keys = m.keys[:0]
+	for _, k := range was {
+		if k != key {
+			m.keys = append(m.keys, k)
+		}
+	}
+}
+
+// All yields all the keys and values, in the order they were inserted.
+func (m *OrderedMap[K, V]) All() iter.Seq2[K, V] {
+	return func(yield func(K, V) bool) {
+		for _, k := range m.keys {
+			if !yield(k, m.m[k]) {
+				return
+			}
+		}
+	}
+}
+
+// Keys yields the map keys, in the order they were inserted.
+func (m *OrderedMap[K, V]) Keys() iter.Seq[K] {
+	return slices.Values(m.keys)
+}
+
+// Values yields the map values, in the order they were inserted.
+func (m *OrderedMap[K, V]) Values() iter.Seq[V] {
+	return func(yield func(V) bool) {
+		for _, k := range m.keys {
+			if !yield(m.m[k]) {
+				return
+			}
+		}
+	}
+}

+ 56 - 0
types/mapx/ordered_test.go

@@ -0,0 +1,56 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package mapx
+
+import (
+	"fmt"
+	"slices"
+	"testing"
+)
+
+func TestOrderedMap(t *testing.T) {
+	// Test the OrderedMap type and its methods.
+	var m OrderedMap[string, int]
+	m.Set("d", 4)
+	m.Set("a", 1)
+	m.Set("b", 1)
+	m.Set("b", 2)
+	m.Set("c", 3)
+	m.Delete("d")
+	m.Delete("e")
+
+	want := map[string]int{
+		"a": 1,
+		"b": 2,
+		"c": 3,
+		"d": 0,
+	}
+	for k, v := range want {
+		if m.Get(k) != v {
+			t.Errorf("Get(%q) = %d, want %d", k, m.Get(k), v)
+			continue
+		}
+		got, ok := m.GetOk(k)
+		if got != v {
+			t.Errorf("GetOk(%q) = %d, want %d", k, got, v)
+		}
+		if ok != m.Contains(k) {
+			t.Errorf("GetOk and Contains don't agree for %q", k)
+		}
+	}
+
+	if got, want := slices.Collect(m.Keys()), []string{"a", "b", "c"}; !slices.Equal(got, want) {
+		t.Errorf("Keys() = %q, want %q", got, want)
+	}
+	if got, want := slices.Collect(m.Values()), []int{1, 2, 3}; !slices.Equal(got, want) {
+		t.Errorf("Values() = %v, want %v", got, want)
+	}
+	var allGot []string
+	for k, v := range m.All() {
+		allGot = append(allGot, fmt.Sprintf("%s:%d", k, v))
+	}
+	if got, want := allGot, []string{"a:1", "b:2", "c:3"}; !slices.Equal(got, want) {
+		t.Errorf("All() = %q, want %q", got, want)
+	}
+}