shardvalue_test.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package syncs
  4. import (
  5. "math"
  6. "runtime"
  7. "sync"
  8. "sync/atomic"
  9. "testing"
  10. "golang.org/x/sys/cpu"
  11. )
  12. func TestShardValue(t *testing.T) {
  13. type intVal struct {
  14. atomic.Int64
  15. _ cpu.CacheLinePad
  16. }
  17. t.Run("One", func(t *testing.T) {
  18. sv := NewShardValue[intVal]()
  19. sv.One(func(v *intVal) {
  20. v.Store(10)
  21. })
  22. var v int64
  23. for i := range sv.shards {
  24. v += sv.shards[i].Load()
  25. }
  26. if v != 10 {
  27. t.Errorf("got %v, want 10", v)
  28. }
  29. })
  30. t.Run("All", func(t *testing.T) {
  31. sv := NewShardValue[intVal]()
  32. for i := range sv.shards {
  33. sv.shards[i].Store(int64(i))
  34. }
  35. var total int64
  36. sv.All(func(v *intVal) bool {
  37. total += v.Load()
  38. return true
  39. })
  40. // triangle coefficient lower one order due to 0 index
  41. want := int64(len(sv.shards) * (len(sv.shards) - 1) / 2)
  42. if total != want {
  43. t.Errorf("got %v, want %v", total, want)
  44. }
  45. })
  46. t.Run("Len", func(t *testing.T) {
  47. sv := NewShardValue[intVal]()
  48. if got, want := sv.Len(), runtime.NumCPU(); got != want {
  49. t.Errorf("got %v, want %v", got, want)
  50. }
  51. })
  52. t.Run("distribution", func(t *testing.T) {
  53. sv := NewShardValue[intVal]()
  54. goroutines := 1000
  55. iterations := 10000
  56. var wg sync.WaitGroup
  57. wg.Add(goroutines)
  58. for i := 0; i < goroutines; i++ {
  59. go func() {
  60. defer wg.Done()
  61. for i := 0; i < iterations; i++ {
  62. sv.One(func(v *intVal) {
  63. v.Add(1)
  64. })
  65. }
  66. }()
  67. }
  68. wg.Wait()
  69. var (
  70. total int64
  71. distribution []int64
  72. )
  73. t.Logf("distribution:")
  74. sv.All(func(v *intVal) bool {
  75. total += v.Load()
  76. distribution = append(distribution, v.Load())
  77. t.Logf("%d", v.Load())
  78. return true
  79. })
  80. if got, want := total, int64(goroutines*iterations); got != want {
  81. t.Errorf("got %v, want %v", got, want)
  82. }
  83. if got, want := len(distribution), runtime.NumCPU(); got != want {
  84. t.Errorf("got %v, want %v", got, want)
  85. }
  86. mean := total / int64(len(distribution))
  87. for _, v := range distribution {
  88. if v < mean/10 || v > mean*10 {
  89. t.Logf("distribution is very unbalanced: %v", distribution)
  90. }
  91. }
  92. t.Logf("mean: %d", mean)
  93. var standardDev int64
  94. for _, v := range distribution {
  95. standardDev += ((v - mean) * (v - mean))
  96. }
  97. standardDev = int64(math.Sqrt(float64(standardDev / int64(len(distribution)))))
  98. t.Logf("stdev: %d", standardDev)
  99. if standardDev > mean/3 {
  100. t.Logf("standard deviation is too high: %v", standardDev)
  101. }
  102. })
  103. }