shardedmap.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package syncs
  4. import (
  5. "sync"
  6. "golang.org/x/sys/cpu"
  7. )
// ShardedMap is a synchronized map[K]V, internally sharded by a user-defined
// K-sharding function.
//
// Each shard carries its own mutex, so operations on keys that land in
// different shards do not take the same lock.
//
// The zero value is not safe for use; use NewShardedMap.
type ShardedMap[K comparable, V any] struct {
	// shardFunc maps a key to the index within shards of the shard that
	// stores it.
	shardFunc func(K) int
	// shards holds the per-shard maps and locks; it is allocated by
	// NewShardedMap.
	shards []mapShard[K, V]
}
// mapShard is a single shard of a ShardedMap: a plain map together with the
// mutex that the ShardedMap methods hold while reading or writing it.
type mapShard[K comparable, V any] struct {
	// mu is locked by every ShardedMap method for the duration of its access
	// to m.
	mu sync.Mutex
	m  map[K]V
	_  cpu.CacheLinePad // avoid false sharing of neighboring shards' mutexes
}
  21. // NewShardedMap returns a new ShardedMap with the given number of shards and
  22. // sharding function.
  23. //
  24. // The shard func must return a integer in the range [0, shards) purely
  25. // deterministically based on the provided K.
  26. func NewShardedMap[K comparable, V any](shards int, shard func(K) int) *ShardedMap[K, V] {
  27. m := &ShardedMap[K, V]{
  28. shardFunc: shard,
  29. shards: make([]mapShard[K, V], shards),
  30. }
  31. for i := range m.shards {
  32. m.shards[i].m = make(map[K]V)
  33. }
  34. return m
  35. }
  36. func (m *ShardedMap[K, V]) shard(key K) *mapShard[K, V] {
  37. return &m.shards[m.shardFunc(key)]
  38. }
  39. // GetOk returns m[key] and whether it was present.
  40. func (m *ShardedMap[K, V]) GetOk(key K) (value V, ok bool) {
  41. shard := m.shard(key)
  42. shard.mu.Lock()
  43. defer shard.mu.Unlock()
  44. value, ok = shard.m[key]
  45. return
  46. }
  47. // Get returns m[key] or the zero value of V if key is not present.
  48. func (m *ShardedMap[K, V]) Get(key K) (value V) {
  49. value, _ = m.GetOk(key)
  50. return
  51. }
  52. // Mutate atomically mutates m[k] by calling mutator.
  53. //
  54. // The mutator function is called with the old value (or its zero value) and
  55. // whether it existed in the map and it returns the new value and whether it
  56. // should be set in the map (true) or deleted from the map (false).
  57. //
  58. // It returns the change in size of the map as a result of the mutation, one of
  59. // -1 (delete), 0 (change), or 1 (addition).
  60. func (m *ShardedMap[K, V]) Mutate(key K, mutator func(oldValue V, oldValueExisted bool) (newValue V, keep bool)) (sizeDelta int) {
  61. shard := m.shard(key)
  62. shard.mu.Lock()
  63. defer shard.mu.Unlock()
  64. oldV, oldOK := shard.m[key]
  65. newV, newOK := mutator(oldV, oldOK)
  66. if newOK {
  67. shard.m[key] = newV
  68. if oldOK {
  69. return 0
  70. }
  71. return 1
  72. }
  73. delete(shard.m, key)
  74. if oldOK {
  75. return -1
  76. }
  77. return 0
  78. }
  79. // Set sets m[key] = value.
  80. //
  81. // present in m).
  82. func (m *ShardedMap[K, V]) Set(key K, value V) (grew bool) {
  83. shard := m.shard(key)
  84. shard.mu.Lock()
  85. defer shard.mu.Unlock()
  86. s0 := len(shard.m)
  87. shard.m[key] = value
  88. return len(shard.m) > s0
  89. }
  90. // Delete removes key from m.
  91. //
  92. // It reports whether the map size shrunk (that is, whether key was present in
  93. // the map).
  94. func (m *ShardedMap[K, V]) Delete(key K) (shrunk bool) {
  95. shard := m.shard(key)
  96. shard.mu.Lock()
  97. defer shard.mu.Unlock()
  98. s0 := len(shard.m)
  99. delete(shard.m, key)
  100. return len(shard.m) < s0
  101. }
  102. // Contains reports whether m contains key.
  103. func (m *ShardedMap[K, V]) Contains(key K) bool {
  104. shard := m.shard(key)
  105. shard.mu.Lock()
  106. defer shard.mu.Unlock()
  107. _, ok := shard.m[key]
  108. return ok
  109. }
  110. // Len returns the number of elements in m.
  111. //
  112. // It does so by locking shards one at a time, so it's not particularly cheap,
  113. // nor does it give a consistent snapshot of the map. It's mostly intended for
  114. // metrics or testing.
  115. func (m *ShardedMap[K, V]) Len() int {
  116. n := 0
  117. for i := range m.shards {
  118. shard := &m.shards[i]
  119. shard.mu.Lock()
  120. n += len(shard.m)
  121. shard.mu.Unlock()
  122. }
  123. return n
  124. }