Browse Source

syncs: add Map.LoadFunc (#9869)

The LoadFunc loads a value and calls a user-provided function.
The utility of this method is to ensure that the map lock is held
while executing user-provided logic.
This allows us to solve TOCTOU bugs that would be nearly imposible
to the solve without this API.

Updates tailscale/corp#14772

Signed-off-by: Joe Tsai <[email protected]>
Joe Tsai 2 years ago
parent
commit
674beabc73
2 changed files with 30 additions and 3 deletions
  1. 20 3
      syncs/syncs.go
  2. 10 0
      syncs/syncs_test.go

+ 20 - 3
syncs/syncs.go

@@ -164,19 +164,33 @@ type Map[K comparable, V any] struct {
 	m  map[K]V
 }
 
-func (m *Map[K, V]) Load(key K) (value V, ok bool) {
+// Load loads the value for the provided key and whether it was found.
+func (m *Map[K, V]) Load(key K) (value V, loaded bool) {
 	m.mu.RLock()
 	defer m.mu.RUnlock()
-	value, ok = m.m[key]
-	return value, ok
+	value, loaded = m.m[key]
+	return value, loaded
+}
+
+// LoadFunc calls f with the value for the provided key
+// regardless of whether the entry exists or not.
+// The lock is held for the duration of the call to f.
+func (m *Map[K, V]) LoadFunc(key K, f func(value V, loaded bool)) {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	value, loaded := m.m[key]
+	f(value, loaded)
 }
 
+// Store stores the value for the provided key.
 func (m *Map[K, V]) Store(key K, value V) {
 	m.mu.Lock()
 	defer m.mu.Unlock()
 	mak.Set(&m.m, key, value)
 }
 
+// LoadOrStore returns the value for the given key if it exists
+// otherwise it stores value.
 func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
 	if actual, loaded = m.Load(key); loaded {
 		return actual, loaded
@@ -212,6 +226,8 @@ func (m *Map[K, V]) LoadOrInit(key K, f func() V) (actual V, loaded bool) {
 	return actual, loaded
 }
 
+// LoadAndDelete returns the value for the given key if it exists.
+// It ensures that the map is cleared of any entry for the key.
 func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
 	m.mu.Lock()
 	defer m.mu.Unlock()
@@ -222,6 +238,7 @@ func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
 	return value, loaded
 }
 
+// Delete deletes the entry identified by key.
 func (m *Map[K, V]) Delete(key K) {
 	m.mu.Lock()
 	defer m.mu.Unlock()

+ 10 - 0
syncs/syncs_test.go

@@ -81,6 +81,11 @@ func TestMap(t *testing.T) {
 	if v, ok := m.Load("noexist"); v != 0 || ok {
 		t.Errorf(`Load("noexist") = (%v, %v), want (0, false)`, v, ok)
 	}
+	m.LoadFunc("noexist", func(v int, ok bool) {
+		if v != 0 || ok {
+			t.Errorf(`LoadFunc("noexist") = (%v, %v), want (0, false)`, v, ok)
+		}
+	})
 	m.Store("one", 1)
 	if v, ok := m.LoadOrStore("one", -1); v != 1 || !ok {
 		t.Errorf(`LoadOrStore("one", 1) = (%v, %v), want (1, true)`, v, ok)
@@ -88,6 +93,11 @@ func TestMap(t *testing.T) {
 	if v, ok := m.Load("one"); v != 1 || !ok {
 		t.Errorf(`Load("one") = (%v, %v), want (1, true)`, v, ok)
 	}
+	m.LoadFunc("one", func(v int, ok bool) {
+		if v != 1 || !ok {
+			t.Errorf(`LoadFunc("one") = (%v, %v), want (1, true)`, v, ok)
+		}
+	})
 	if v, ok := m.LoadOrStore("two", 2); v != 2 || ok {
 		t.Errorf(`LoadOrStore("two", 2) = (%v, %v), want (2, false)`, v, ok)
 	}