| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293 |
- // Copyright (c) Tailscale Inc & AUTHORS
- // SPDX-License-Identifier: BSD-3-Clause
- package syncs
- import (
- "context"
- "io"
- "os"
- "sync"
- "testing"
- "time"
- "github.com/google/go-cmp/cmp"
- )
- func TestAtomicValue(t *testing.T) {
- {
- // Always wrapping should not allocate for simple values
- // because wrappedValue[T] has the same memory layout as T.
- var v AtomicValue[bool]
- bools := []bool{true, false}
- if n := int(testing.AllocsPerRun(1000, func() {
- for _, b := range bools {
- v.Store(b)
- }
- })); n != 0 {
- t.Errorf("AllocsPerRun = %d, want 0", n)
- }
- }
- {
- var v AtomicValue[int]
- got, gotOk := v.LoadOk()
- if got != 0 || gotOk {
- t.Fatalf("LoadOk = (%v, %v), want (0, false)", got, gotOk)
- }
- v.Store(1)
- got, gotOk = v.LoadOk()
- if got != 1 || !gotOk {
- t.Fatalf("LoadOk = (%v, %v), want (1, true)", got, gotOk)
- }
- }
- {
- var v AtomicValue[error]
- got, gotOk := v.LoadOk()
- if got != nil || gotOk {
- t.Fatalf("LoadOk = (%v, %v), want (nil, false)", got, gotOk)
- }
- v.Store(io.EOF)
- got, gotOk = v.LoadOk()
- if got != io.EOF || !gotOk {
- t.Fatalf("LoadOk = (%v, %v), want (EOF, true)", got, gotOk)
- }
- err := &os.PathError{}
- v.Store(err)
- got, gotOk = v.LoadOk()
- if got != err || !gotOk {
- t.Fatalf("LoadOk = (%v, %v), want (%v, true)", got, gotOk, err)
- }
- v.Store(nil)
- got, gotOk = v.LoadOk()
- if got != nil || !gotOk {
- t.Fatalf("LoadOk = (%v, %v), want (nil, true)", got, gotOk)
- }
- }
- {
- c1, c2, c3 := make(chan struct{}), make(chan struct{}), make(chan struct{})
- var v AtomicValue[chan struct{}]
- if v.CompareAndSwap(c1, c2) != false {
- t.Fatalf("CompareAndSwap = true, want false")
- }
- if v.CompareAndSwap(nil, c1) != true {
- t.Fatalf("CompareAndSwap = false, want true")
- }
- if v.CompareAndSwap(c2, c3) != false {
- t.Fatalf("CompareAndSwap = true, want false")
- }
- if v.CompareAndSwap(c1, c2) != true {
- t.Fatalf("CompareAndSwap = false, want true")
- }
- }
- }
- func TestMutexValue(t *testing.T) {
- var v MutexValue[time.Time]
- if n := int(testing.AllocsPerRun(1000, func() {
- v.Store(v.Load())
- v.WithLock(func(*time.Time) {})
- })); n != 0 {
- t.Errorf("AllocsPerRun = %d, want 0", n)
- }
- now := time.Now()
- v.Store(now)
- if !v.Load().Equal(now) {
- t.Errorf("Load = %v, want %v", v.Load(), now)
- }
- var group sync.WaitGroup
- var v2 MutexValue[int]
- var sum int
- for i := range 10 {
- group.Go(func() {
- old1 := v2.Load()
- old2 := v2.Swap(old1 + i)
- delta := old2 - old1
- v2.WithLock(func(p *int) { *p += delta })
- })
- sum += i
- }
- group.Wait()
- if v2.Load() != sum {
- t.Errorf("Load = %v, want %v", v2.Load(), sum)
- }
- }
- func TestWaitGroupChan(t *testing.T) {
- wg := NewWaitGroupChan()
- wantNotDone := func() {
- t.Helper()
- select {
- case <-wg.DoneChan():
- t.Fatal("done too early")
- default:
- }
- }
- wantDone := func() {
- t.Helper()
- select {
- case <-wg.DoneChan():
- default:
- t.Fatal("expected to be done")
- }
- }
- wg.Add(2)
- wantNotDone()
- wg.Decr()
- wantNotDone()
- wg.Decr()
- wantDone()
- wantDone()
- }
- func TestClosedChan(t *testing.T) {
- ch := ClosedChan()
- for range 2 {
- select {
- case <-ch:
- default:
- t.Fatal("not closed")
- }
- }
- }
- func TestSemaphore(t *testing.T) {
- s := NewSemaphore(2)
- assertLen := func(want int) {
- t.Helper()
- if got := s.Len(); got != want {
- t.Fatalf("Len = %d, want %d", got, want)
- }
- }
- assertLen(0)
- s.Acquire()
- assertLen(1)
- if !s.TryAcquire() {
- t.Fatal("want true")
- }
- assertLen(2)
- if s.TryAcquire() {
- t.Fatal("want false")
- }
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
- if s.AcquireContext(ctx) {
- t.Fatal("want false")
- }
- s.Release()
- assertLen(1)
- if !s.AcquireContext(context.Background()) {
- t.Fatal("want true")
- }
- assertLen(2)
- s.Release()
- assertLen(1)
- s.Release()
- assertLen(0)
- }
- func TestMap(t *testing.T) {
- var m Map[string, int]
- if v, ok := m.Load("noexist"); v != 0 || ok {
- t.Errorf(`Load("noexist") = (%v, %v), want (0, false)`, v, ok)
- }
- m.LoadFunc("noexist", func(v int, ok bool) {
- if v != 0 || ok {
- t.Errorf(`LoadFunc("noexist") = (%v, %v), want (0, false)`, v, ok)
- }
- })
- m.Store("one", 1)
- if v, ok := m.LoadOrStore("one", -1); v != 1 || !ok {
- t.Errorf(`LoadOrStore("one", 1) = (%v, %v), want (1, true)`, v, ok)
- }
- if v, ok := m.Load("one"); v != 1 || !ok {
- t.Errorf(`Load("one") = (%v, %v), want (1, true)`, v, ok)
- }
- m.LoadFunc("one", func(v int, ok bool) {
- if v != 1 || !ok {
- t.Errorf(`LoadFunc("one") = (%v, %v), want (1, true)`, v, ok)
- }
- })
- if v, ok := m.LoadOrStore("two", 2); v != 2 || ok {
- t.Errorf(`LoadOrStore("two", 2) = (%v, %v), want (2, false)`, v, ok)
- }
- if v, ok := m.LoadOrInit("three", func() int { return 3 }); v != 3 || ok {
- t.Errorf(`LoadOrInit("three", 3) = (%v, %v), want (3, true)`, v, ok)
- }
- got := map[string]int{}
- want := map[string]int{"one": 1, "two": 2, "three": 3}
- for k, v := range m.All() {
- got[k] = v
- }
- if d := cmp.Diff(got, want); d != "" {
- t.Errorf("Range mismatch (-got +want):\n%s", d)
- }
- if v, ok := m.LoadAndDelete("two"); v != 2 || !ok {
- t.Errorf(`LoadAndDelete("two) = (%v, %v), want (2, true)`, v, ok)
- }
- if v, ok := m.LoadAndDelete("two"); v != 0 || ok {
- t.Errorf(`LoadAndDelete("two) = (%v, %v), want (0, false)`, v, ok)
- }
- m.Delete("three")
- m.Delete("one")
- m.Delete("noexist")
- got = map[string]int{}
- want = map[string]int{}
- for k, v := range m.All() {
- got[k] = v
- }
- if d := cmp.Diff(got, want); d != "" {
- t.Errorf("Range mismatch (-got +want):\n%s", d)
- }
- t.Run("LoadOrStore", func(t *testing.T) {
- var m Map[string, string]
- var wg sync.WaitGroup
- var ok1, ok2 bool
- wg.Go(func() { _, ok1 = m.LoadOrStore("", "") })
- wg.Go(func() { _, ok2 = m.LoadOrStore("", "") })
- wg.Wait()
- if ok1 == ok2 {
- t.Errorf("exactly one LoadOrStore should load")
- }
- })
- t.Run("Clear", func(t *testing.T) {
- var m Map[string, string]
- _, _ = m.LoadOrStore("a", "1")
- _, _ = m.LoadOrStore("b", "2")
- _, _ = m.LoadOrStore("c", "3")
- _, _ = m.LoadOrStore("d", "4")
- _, _ = m.LoadOrStore("e", "5")
- if m.Len() != 5 {
- t.Errorf("Len after loading want=5 got=%d", m.Len())
- }
- m.Clear()
- if m.Len() != 0 {
- t.Errorf("Len after Clear want=0 got=%d", m.Len())
- }
- })
- t.Run("Swap", func(t *testing.T) {
- var m Map[string, string]
- m.Store("hello", "world")
- if got, want := m.Swap("hello", "world2"), "world"; got != want {
- t.Errorf("got old value %q, want %q", got, want)
- }
- if got := m.Swap("empty", "foo"); got != "" {
- t.Errorf("got old value %q, want empty string", got)
- }
- })
- }
|