Selaa lähdekoodia

types/lazy: helpers for lazily computed values

Co-authored-by: Maisem Ali <[email protected]>
Co-authored-by: Brad Fitzpatrick <[email protected]>
Signed-off-by: David Anderson <[email protected]>
David Anderson 3 vuotta sitten
vanhempi
sitoutus
9e6b4d7ad8
4 muutettua tiedostoa jossa 477 lisäystä ja 0 poistoa
  1. 88 0
      types/lazy/lazy.go
  2. 150 0
      types/lazy/sync_test.go
  3. 99 0
      types/lazy/unsync.go
  4. 140 0
      types/lazy/unsync_test.go

+ 88 - 0
types/lazy/lazy.go

@@ -0,0 +1,88 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package lazy provides types for lazily initialized values.
+package lazy
+
+import "sync"
+
+// SyncValue is a lazily computed value.
+//
+// Use either Get or GetErr, depending on whether your fill function returns an
+// error.
+//
+// Recursive use of a SyncValue from its own fill function will deadlock.
+//
+// SyncValue is safe for concurrent use.
+type SyncValue[T any] struct {
+	once sync.Once
+	v    T
+	err  error
+}
+
+// Set attempts to set z's value to val, and reports whether it succeeded.
+// Set only succeeds if none of Get/GetErr/Set have been called before.
+func (z *SyncValue[T]) Set(val T) bool {
+	var wasSet bool
+	z.once.Do(func() {
+		z.v = val
+		wasSet = true
+	})
+	return wasSet
+}
+
+// MustSet sets z's value to val, or panics if z already has a value.
+func (z *SyncValue[T]) MustSet(val T) {
+	if !z.Set(val) {
+		panic("Set after already filled")
+	}
+}
+
+// Get returns z's value, calling fill to compute it if necessary.
+// f is called at most once.
+func (z *SyncValue[T]) Get(fill func() T) T {
+	z.once.Do(func() { z.v = fill() })
+	return z.v
+}
+
+// GetErr returns z's value, calling fill to compute it if necessary.
+// f is called at most once, and z remembers both of fill's outputs.
+func (z *SyncValue[T]) GetErr(fill func() (T, error)) (T, error) {
+	z.once.Do(func() { z.v, z.err = fill() })
+	return z.v, z.err
+}
+
+// SyncFunc wraps a function to make it lazy.
+//
+// The returned function calls fill the first time it's called, and returns
+// fill's result on every subsequent call.
+//
+// The returned function is safe for concurrent use.
+func SyncFunc[T any](fill func() T) func() T {
+	var (
+		once sync.Once
+		v    T
+	)
+	return func() T {
+		once.Do(func() { v = fill() })
+		return v
+	}
+}
+
+// SyncFuncErr wraps a function to make it lazy.
+//
+// The returned function calls fill the first time it's called, and returns
+// fill's results on every subsequent call.
+//
+// The returned function is safe for concurrent use.
+func SyncFuncErr[T any](fill func() (T, error)) func() (T, error) {
+	var (
+		once sync.Once
+		v    T
+		err  error
+	)
+	return func() (T, error) {
+		once.Do(func() { v, err = fill() })
+		return v, err
+	}
+}

+ 150 - 0
types/lazy/sync_test.go

@@ -0,0 +1,150 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package lazy
+
+import (
+	"errors"
+	"sync"
+	"testing"
+)
+
+func TestSyncValue(t *testing.T) {
+	var lt SyncValue[int]
+	n := int(testing.AllocsPerRun(1000, func() {
+		got := lt.Get(fortyTwo)
+		if got != 42 {
+			t.Fatalf("got %v; want 42", got)
+		}
+	}))
+	if n != 0 {
+		t.Errorf("allocs = %v; want 0", n)
+	}
+}
+
+func TestSyncValueErr(t *testing.T) {
+	var lt SyncValue[int]
+	n := int(testing.AllocsPerRun(1000, func() {
+		got, err := lt.GetErr(func() (int, error) {
+			return 42, nil
+		})
+		if got != 42 || err != nil {
+			t.Fatalf("got %v, %v; want 42, nil", got, err)
+		}
+	}))
+	if n != 0 {
+		t.Errorf("allocs = %v; want 0", n)
+	}
+
+	var lterr SyncValue[int]
+	wantErr := errors.New("test error")
+	n = int(testing.AllocsPerRun(1000, func() {
+		got, err := lterr.GetErr(func() (int, error) {
+			return 0, wantErr
+		})
+		if got != 0 || err != wantErr {
+			t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr)
+		}
+	}))
+	if n != 0 {
+		t.Errorf("allocs = %v; want 0", n)
+	}
+}
+
+func TestSyncValueSet(t *testing.T) {
+	var lt SyncValue[int]
+	if !lt.Set(42) {
+		t.Fatalf("Set failed")
+	}
+	if lt.Set(43) {
+		t.Fatalf("Set succeeded after first Set")
+	}
+	n := int(testing.AllocsPerRun(1000, func() {
+		got := lt.Get(fortyTwo)
+		if got != 42 {
+			t.Fatalf("got %v; want 42", got)
+		}
+	}))
+	if n != 0 {
+		t.Errorf("allocs = %v; want 0", n)
+	}
+}
+
+func TestSyncValueMustSet(t *testing.T) {
+	var lt SyncValue[int]
+	lt.MustSet(42)
+	defer func() {
+		if e := recover(); e == nil {
+			t.Errorf("unexpected success; want panic")
+		}
+	}()
+	lt.MustSet(43)
+}
+
+func TestSyncValueConcurrent(t *testing.T) {
+	var (
+		lt       SyncValue[int]
+		wg       sync.WaitGroup
+		start    = make(chan struct{})
+		routines = 10000
+	)
+	wg.Add(routines)
+	for i := 0; i < routines; i++ {
+		go func() {
+			defer wg.Done()
+			// Every goroutine waits for the go signal, so that more of them
+			// have a chance to race on the initial Get than with sequential
+			// goroutine starts.
+			<-start
+			got := lt.Get(fortyTwo)
+			if got != 42 {
+				t.Errorf("got %v; want 42", got)
+			}
+		}()
+	}
+	close(start)
+	wg.Wait()
+}
+
+func TestSyncFunc(t *testing.T) {
+	f := SyncFunc(fortyTwo)
+
+	n := int(testing.AllocsPerRun(1000, func() {
+		got := f()
+		if got != 42 {
+			t.Fatalf("got %v; want 42", got)
+		}
+	}))
+	if n != 0 {
+		t.Errorf("allocs = %v; want 0", n)
+	}
+}
+
+func TestSyncFuncErr(t *testing.T) {
+	f := SyncFuncErr(func() (int, error) {
+		return 42, nil
+	})
+	n := int(testing.AllocsPerRun(1000, func() {
+		got, err := f()
+		if got != 42 || err != nil {
+			t.Fatalf("got %v, %v; want 42, nil", got, err)
+		}
+	}))
+	if n != 0 {
+		t.Errorf("allocs = %v; want 0", n)
+	}
+
+	wantErr := errors.New("test error")
+	f = SyncFuncErr(func() (int, error) {
+		return 0, wantErr
+	})
+	n = int(testing.AllocsPerRun(1000, func() {
+		got, err := f()
+		if got != 0 || err != wantErr {
+			t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr)
+		}
+	}))
+	if n != 0 {
+		t.Errorf("allocs = %v; want 0", n)
+	}
+}

+ 99 - 0
types/lazy/unsync.go

@@ -0,0 +1,99 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package lazy
+
+// GValue is a lazily computed value.
+//
+// Use either Get or GetErr, depending on whether your fill function returns an
+// error.
+//
+// Recursive use of a GValue from its own fill function will panic.
+//
+// GValue is not safe for concurrent use. (Mnemonic: G is for one Goroutine,
+// which isn't strictly true if you provide your own synchronization between
+// goroutines, but in practice most of our callers have been using it within
+// a single goroutine.)
+type GValue[T any] struct {
+	done    bool
+	calling bool
+	V       T
+	err     error
+}
+
+// Set attempts to set z's value to val, and reports whether it succeeded.
+// Set only succeeds if none of Get/GetErr/Set have been called before.
+func (z *GValue[T]) Set(v T) bool {
+	if z.done {
+		return false
+	}
+	if z.calling {
+		panic("Set while Get fill is running")
+	}
+	z.V = v
+	z.done = true
+	return true
+}
+
+// MustSet sets z's value to val, or panics if z already has a value.
+func (z *GValue[T]) MustSet(val T) {
+	if !z.Set(val) {
+		panic("Set after already filled")
+	}
+}
+
+// Get returns z's value, calling fill to compute it if necessary.
+// f is called at most once.
+func (z *GValue[T]) Get(fill func() T) T {
+	if !z.done {
+		if z.calling {
+			panic("recursive lazy fill")
+		}
+		z.calling = true
+		z.V = fill()
+		z.done = true
+		z.calling = false
+	}
+	return z.V
+}
+
+// GetErr returns z's value, calling fill to compute it if necessary.
+// f is called at most once, and z remembers both of fill's outputs.
+func (z *GValue[T]) GetErr(fill func() (T, error)) (T, error) {
+	if !z.done {
+		if z.calling {
+			panic("recursive lazy fill")
+		}
+		z.calling = true
+		z.V, z.err = fill()
+		z.done = true
+		z.calling = false
+	}
+	return z.V, z.err
+}
+
+// GFunc wraps a function to make it lazy.
+//
+// The returned function calls fill the first time it's called, and returns
+// fill's result on every subsequent call.
+//
+// The returned function is not safe for concurrent use.
+func GFunc[T any](fill func() T) func() T {
+	var v GValue[T]
+	return func() T {
+		return v.Get(fill)
+	}
+}
+
+// SyncFuncErr wraps a function to make it lazy.
+//
+// The returned function calls fill the first time it's called, and returns
+// fill's results on every subsequent call.
+//
+// The returned function is not safe for concurrent use.
+func GFuncErr[T any](fill func() (T, error)) func() (T, error) {
+	var v GValue[T]
+	return func() (T, error) {
+		return v.GetErr(fill)
+	}
+}

+ 140 - 0
types/lazy/unsync_test.go

@@ -0,0 +1,140 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package lazy
+
+import (
+	"errors"
+	"testing"
+)
+
+func fortyTwo() int { return 42 }
+
+func TestGValue(t *testing.T) {
+	var lt GValue[int]
+	n := int(testing.AllocsPerRun(1000, func() {
+		got := lt.Get(fortyTwo)
+		if got != 42 {
+			t.Fatalf("got %v; want 42", got)
+		}
+	}))
+	if n != 0 {
+		t.Errorf("allocs = %v; want 0", n)
+	}
+}
+
+func TestGValueErr(t *testing.T) {
+	var lt GValue[int]
+	n := int(testing.AllocsPerRun(1000, func() {
+		got, err := lt.GetErr(func() (int, error) {
+			return 42, nil
+		})
+		if got != 42 || err != nil {
+			t.Fatalf("got %v, %v; want 42, nil", got, err)
+		}
+	}))
+	if n != 0 {
+		t.Errorf("allocs = %v; want 0", n)
+	}
+
+	var lterr GValue[int]
+	wantErr := errors.New("test error")
+	n = int(testing.AllocsPerRun(1000, func() {
+		got, err := lterr.GetErr(func() (int, error) {
+			return 0, wantErr
+		})
+		if got != 0 || err != wantErr {
+			t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr)
+		}
+	}))
+	if n != 0 {
+		t.Errorf("allocs = %v; want 0", n)
+	}
+}
+
+func TestGValueSet(t *testing.T) {
+	var lt GValue[int]
+	if !lt.Set(42) {
+		t.Fatalf("Set failed")
+	}
+	if lt.Set(43) {
+		t.Fatalf("Set succeeded after first Set")
+	}
+	n := int(testing.AllocsPerRun(1000, func() {
+		got := lt.Get(fortyTwo)
+		if got != 42 {
+			t.Fatalf("got %v; want 42", got)
+		}
+	}))
+	if n != 0 {
+		t.Errorf("allocs = %v; want 0", n)
+	}
+}
+
+func TestGValueMustSet(t *testing.T) {
+	var lt GValue[int]
+	lt.MustSet(42)
+	defer func() {
+		if e := recover(); e == nil {
+			t.Errorf("unexpected success; want panic")
+		}
+	}()
+	lt.MustSet(43)
+}
+
+func TestGValueRecursivePanic(t *testing.T) {
+	defer func() {
+		if e := recover(); e != nil {
+			t.Logf("got panic, as expected")
+		} else {
+			t.Errorf("unexpected success; want panic")
+		}
+	}()
+	v := GValue[int]{}
+	v.Get(func() int {
+		return v.Get(func() int { return 42 })
+	})
+}
+
+func TestGFunc(t *testing.T) {
+	f := GFunc(fortyTwo)
+
+	n := int(testing.AllocsPerRun(1000, func() {
+		got := f()
+		if got != 42 {
+			t.Fatalf("got %v; want 42", got)
+		}
+	}))
+	if n != 0 {
+		t.Errorf("allocs = %v; want 0", n)
+	}
+}
+
+func TestGFuncErr(t *testing.T) {
+	f := GFuncErr(func() (int, error) {
+		return 42, nil
+	})
+	n := int(testing.AllocsPerRun(1000, func() {
+		got, err := f()
+		if got != 42 || err != nil {
+			t.Fatalf("got %v, %v; want 42, nil", got, err)
+		}
+	}))
+	if n != 0 {
+		t.Errorf("allocs = %v; want 0", n)
+	}
+
+	wantErr := errors.New("test error")
+	f = GFuncErr(func() (int, error) {
+		return 0, wantErr
+	})
+	n = int(testing.AllocsPerRun(1000, func() {
+		got, err := f()
+		if got != 0 || err != wantErr {
+			t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr)
+		}
+	}))
+	if n != 0 {
+		t.Errorf("allocs = %v; want 0", n)
+	}
+}