maps_test.go 13 KB


  1. package csync
import (
	"encoding/json"
	"maps"
	"sync"
	"sync/atomic"
	"testing"
	"testing/synctest"
	"time"

	"github.com/stretchr/testify/require"
)
  11. func TestNewMap(t *testing.T) {
  12. t.Parallel()
  13. m := NewMap[string, int]()
  14. require.NotNil(t, m)
  15. require.NotNil(t, m.inner)
  16. require.Equal(t, 0, m.Len())
  17. }
  18. func TestNewMapFrom(t *testing.T) {
  19. t.Parallel()
  20. original := map[string]int{
  21. "key1": 1,
  22. "key2": 2,
  23. }
  24. m := NewMapFrom(original)
  25. require.NotNil(t, m)
  26. require.Equal(t, original, m.inner)
  27. require.Equal(t, 2, m.Len())
  28. value, ok := m.Get("key1")
  29. require.True(t, ok)
  30. require.Equal(t, 1, value)
  31. }
  32. func TestNewLazyMap(t *testing.T) {
  33. t.Parallel()
  34. synctest.Test(t, func(t *testing.T) {
  35. t.Helper()
  36. waiter := sync.Mutex{}
  37. waiter.Lock()
  38. loadCalled := false
  39. loadFunc := func() map[string]int {
  40. waiter.Lock()
  41. defer waiter.Unlock()
  42. loadCalled = true
  43. return map[string]int{
  44. "key1": 1,
  45. "key2": 2,
  46. }
  47. }
  48. m := NewLazyMap(loadFunc)
  49. require.NotNil(t, m)
  50. waiter.Unlock() // Allow the load function to proceed
  51. time.Sleep(100 * time.Millisecond)
  52. require.True(t, loadCalled)
  53. require.Equal(t, 2, m.Len())
  54. value, ok := m.Get("key1")
  55. require.True(t, ok)
  56. require.Equal(t, 1, value)
  57. })
  58. }
  59. func TestMap_Reset(t *testing.T) {
  60. t.Parallel()
  61. m := NewMapFrom(map[string]int{
  62. "a": 10,
  63. })
  64. m.Reset(map[string]int{
  65. "b": 20,
  66. })
  67. value, ok := m.Get("b")
  68. require.True(t, ok)
  69. require.Equal(t, 20, value)
  70. require.Equal(t, 1, m.Len())
  71. }
  72. func TestMap_Set(t *testing.T) {
  73. t.Parallel()
  74. m := NewMap[string, int]()
  75. m.Set("key1", 42)
  76. value, ok := m.Get("key1")
  77. require.True(t, ok)
  78. require.Equal(t, 42, value)
  79. require.Equal(t, 1, m.Len())
  80. m.Set("key1", 100)
  81. value, ok = m.Get("key1")
  82. require.True(t, ok)
  83. require.Equal(t, 100, value)
  84. require.Equal(t, 1, m.Len())
  85. }
  86. func TestMap_GetOrSet(t *testing.T) {
  87. t.Parallel()
  88. m := NewMap[string, int]()
  89. require.Equal(t, 42, m.GetOrSet("key1", func() int { return 42 }))
  90. require.Equal(t, 42, m.GetOrSet("key1", func() int { return 99999 }))
  91. require.Equal(t, 1, m.Len())
  92. }
  93. func TestMap_Get(t *testing.T) {
  94. t.Parallel()
  95. m := NewMap[string, int]()
  96. value, ok := m.Get("nonexistent")
  97. require.False(t, ok)
  98. require.Equal(t, 0, value)
  99. m.Set("key1", 42)
  100. value, ok = m.Get("key1")
  101. require.True(t, ok)
  102. require.Equal(t, 42, value)
  103. }
  104. func TestMap_Del(t *testing.T) {
  105. t.Parallel()
  106. m := NewMap[string, int]()
  107. m.Set("key1", 42)
  108. m.Set("key2", 100)
  109. require.Equal(t, 2, m.Len())
  110. m.Del("key1")
  111. _, ok := m.Get("key1")
  112. require.False(t, ok)
  113. require.Equal(t, 1, m.Len())
  114. value, ok := m.Get("key2")
  115. require.True(t, ok)
  116. require.Equal(t, 100, value)
  117. m.Del("nonexistent")
  118. require.Equal(t, 1, m.Len())
  119. }
  120. func TestMap_Len(t *testing.T) {
  121. t.Parallel()
  122. m := NewMap[string, int]()
  123. require.Equal(t, 0, m.Len())
  124. m.Set("key1", 1)
  125. require.Equal(t, 1, m.Len())
  126. m.Set("key2", 2)
  127. require.Equal(t, 2, m.Len())
  128. m.Del("key1")
  129. require.Equal(t, 1, m.Len())
  130. m.Del("key2")
  131. require.Equal(t, 0, m.Len())
  132. }
  133. func TestMap_Take(t *testing.T) {
  134. t.Parallel()
  135. m := NewMap[string, int]()
  136. m.Set("key1", 42)
  137. m.Set("key2", 100)
  138. require.Equal(t, 2, m.Len())
  139. value, ok := m.Take("key1")
  140. require.True(t, ok)
  141. require.Equal(t, 42, value)
  142. require.Equal(t, 1, m.Len())
  143. _, exists := m.Get("key1")
  144. require.False(t, exists)
  145. value, ok = m.Get("key2")
  146. require.True(t, ok)
  147. require.Equal(t, 100, value)
  148. }
  149. func TestMap_Take_NonexistentKey(t *testing.T) {
  150. t.Parallel()
  151. m := NewMap[string, int]()
  152. m.Set("key1", 42)
  153. value, ok := m.Take("nonexistent")
  154. require.False(t, ok)
  155. require.Equal(t, 0, value)
  156. require.Equal(t, 1, m.Len())
  157. value, ok = m.Get("key1")
  158. require.True(t, ok)
  159. require.Equal(t, 42, value)
  160. }
  161. func TestMap_Take_EmptyMap(t *testing.T) {
  162. t.Parallel()
  163. m := NewMap[string, int]()
  164. value, ok := m.Take("key1")
  165. require.False(t, ok)
  166. require.Equal(t, 0, value)
  167. require.Equal(t, 0, m.Len())
  168. }
  169. func TestMap_Take_SameKeyTwice(t *testing.T) {
  170. t.Parallel()
  171. m := NewMap[string, int]()
  172. m.Set("key1", 42)
  173. value, ok := m.Take("key1")
  174. require.True(t, ok)
  175. require.Equal(t, 42, value)
  176. require.Equal(t, 0, m.Len())
  177. value, ok = m.Take("key1")
  178. require.False(t, ok)
  179. require.Equal(t, 0, value)
  180. require.Equal(t, 0, m.Len())
  181. }
  182. func TestMap_Seq2(t *testing.T) {
  183. t.Parallel()
  184. m := NewMap[string, int]()
  185. m.Set("key1", 1)
  186. m.Set("key2", 2)
  187. m.Set("key3", 3)
  188. collected := maps.Collect(m.Seq2())
  189. require.Equal(t, 3, len(collected))
  190. require.Equal(t, 1, collected["key1"])
  191. require.Equal(t, 2, collected["key2"])
  192. require.Equal(t, 3, collected["key3"])
  193. }
  194. func TestMap_Seq2_EarlyReturn(t *testing.T) {
  195. t.Parallel()
  196. m := NewMap[string, int]()
  197. m.Set("key1", 1)
  198. m.Set("key2", 2)
  199. m.Set("key3", 3)
  200. count := 0
  201. for range m.Seq2() {
  202. count++
  203. if count == 2 {
  204. break
  205. }
  206. }
  207. require.Equal(t, 2, count)
  208. }
  209. func TestMap_Seq2_EmptyMap(t *testing.T) {
  210. t.Parallel()
  211. m := NewMap[string, int]()
  212. count := 0
  213. for range m.Seq2() {
  214. count++
  215. }
  216. require.Equal(t, 0, count)
  217. }
  218. func TestMap_Seq(t *testing.T) {
  219. t.Parallel()
  220. m := NewMap[string, int]()
  221. m.Set("key1", 1)
  222. m.Set("key2", 2)
  223. m.Set("key3", 3)
  224. collected := make([]int, 0)
  225. for v := range m.Seq() {
  226. collected = append(collected, v)
  227. }
  228. require.Equal(t, 3, len(collected))
  229. require.Contains(t, collected, 1)
  230. require.Contains(t, collected, 2)
  231. require.Contains(t, collected, 3)
  232. }
  233. func TestMap_Seq_EarlyReturn(t *testing.T) {
  234. t.Parallel()
  235. m := NewMap[string, int]()
  236. m.Set("key1", 1)
  237. m.Set("key2", 2)
  238. m.Set("key3", 3)
  239. count := 0
  240. for range m.Seq() {
  241. count++
  242. if count == 2 {
  243. break
  244. }
  245. }
  246. require.Equal(t, 2, count)
  247. }
  248. func TestMap_Seq_EmptyMap(t *testing.T) {
  249. t.Parallel()
  250. m := NewMap[string, int]()
  251. count := 0
  252. for range m.Seq() {
  253. count++
  254. }
  255. require.Equal(t, 0, count)
  256. }
  257. func TestMap_MarshalJSON(t *testing.T) {
  258. t.Parallel()
  259. m := NewMap[string, int]()
  260. m.Set("key1", 1)
  261. m.Set("key2", 2)
  262. data, err := json.Marshal(m)
  263. require.NoError(t, err)
  264. result := &Map[string, int]{}
  265. err = json.Unmarshal(data, result)
  266. require.NoError(t, err)
  267. require.Equal(t, 2, result.Len())
  268. v1, _ := result.Get("key1")
  269. v2, _ := result.Get("key2")
  270. require.Equal(t, 1, v1)
  271. require.Equal(t, 2, v2)
  272. }
  273. func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
  274. t.Parallel()
  275. m := NewMap[string, int]()
  276. data, err := json.Marshal(m)
  277. require.NoError(t, err)
  278. require.Equal(t, "{}", string(data))
  279. }
  280. func TestMap_UnmarshalJSON(t *testing.T) {
  281. t.Parallel()
  282. jsonData := `{"key1": 1, "key2": 2}`
  283. m := NewMap[string, int]()
  284. err := json.Unmarshal([]byte(jsonData), m)
  285. require.NoError(t, err)
  286. require.Equal(t, 2, m.Len())
  287. value, ok := m.Get("key1")
  288. require.True(t, ok)
  289. require.Equal(t, 1, value)
  290. value, ok = m.Get("key2")
  291. require.True(t, ok)
  292. require.Equal(t, 2, value)
  293. }
  294. func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
  295. t.Parallel()
  296. jsonData := `{}`
  297. m := NewMap[string, int]()
  298. err := json.Unmarshal([]byte(jsonData), m)
  299. require.NoError(t, err)
  300. require.Equal(t, 0, m.Len())
  301. }
  302. func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
  303. t.Parallel()
  304. jsonData := `{"key1": 1, "key2":}`
  305. m := NewMap[string, int]()
  306. err := json.Unmarshal([]byte(jsonData), m)
  307. require.Error(t, err)
  308. }
  309. func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
  310. t.Parallel()
  311. m := NewMap[string, int]()
  312. m.Set("existing", 999)
  313. jsonData := `{"key1": 1, "key2": 2}`
  314. err := json.Unmarshal([]byte(jsonData), m)
  315. require.NoError(t, err)
  316. require.Equal(t, 2, m.Len())
  317. _, ok := m.Get("existing")
  318. require.False(t, ok)
  319. value, ok := m.Get("key1")
  320. require.True(t, ok)
  321. require.Equal(t, 1, value)
  322. }
  323. func TestMap_JSONRoundTrip(t *testing.T) {
  324. t.Parallel()
  325. original := NewMap[string, int]()
  326. original.Set("key1", 1)
  327. original.Set("key2", 2)
  328. original.Set("key3", 3)
  329. data, err := json.Marshal(original)
  330. require.NoError(t, err)
  331. restored := NewMap[string, int]()
  332. err = json.Unmarshal(data, restored)
  333. require.NoError(t, err)
  334. require.Equal(t, original.Len(), restored.Len())
  335. for k, v := range original.Seq2() {
  336. restoredValue, ok := restored.Get(k)
  337. require.True(t, ok)
  338. require.Equal(t, v, restoredValue)
  339. }
  340. }
  341. func TestMap_ConcurrentAccess(t *testing.T) {
  342. t.Parallel()
  343. m := NewMap[int, int]()
  344. const numGoroutines = 100
  345. const numOperations = 100
  346. var wg sync.WaitGroup
  347. wg.Add(numGoroutines)
  348. for i := range numGoroutines {
  349. go func(id int) {
  350. defer wg.Done()
  351. for j := range numOperations {
  352. key := id*numOperations + j
  353. m.Set(key, key*2)
  354. value, ok := m.Get(key)
  355. require.True(t, ok)
  356. require.Equal(t, key*2, value)
  357. }
  358. }(i)
  359. }
  360. wg.Wait()
  361. require.Equal(t, numGoroutines*numOperations, m.Len())
  362. }
  363. func TestMap_ConcurrentReadWrite(t *testing.T) {
  364. t.Parallel()
  365. m := NewMap[int, int]()
  366. const numReaders = 50
  367. const numWriters = 50
  368. const numOperations = 100
  369. for i := range 1000 {
  370. m.Set(i, i)
  371. }
  372. var wg sync.WaitGroup
  373. wg.Add(numReaders + numWriters)
  374. for range numReaders {
  375. go func() {
  376. defer wg.Done()
  377. for j := range numOperations {
  378. key := j % 1000
  379. value, ok := m.Get(key)
  380. if ok {
  381. require.Equal(t, key, value)
  382. }
  383. _ = m.Len()
  384. }
  385. }()
  386. }
  387. for i := range numWriters {
  388. go func(id int) {
  389. defer wg.Done()
  390. for j := range numOperations {
  391. key := 1000 + id*numOperations + j
  392. m.Set(key, key)
  393. if j%10 == 0 {
  394. m.Del(key)
  395. }
  396. }
  397. }(i)
  398. }
  399. wg.Wait()
  400. }
  401. func TestMap_ConcurrentSeq2(t *testing.T) {
  402. t.Parallel()
  403. m := NewMap[int, int]()
  404. for i := range 100 {
  405. m.Set(i, i*2)
  406. }
  407. var wg sync.WaitGroup
  408. const numIterators = 10
  409. wg.Add(numIterators)
  410. for range numIterators {
  411. go func() {
  412. defer wg.Done()
  413. count := 0
  414. for k, v := range m.Seq2() {
  415. require.Equal(t, k*2, v)
  416. count++
  417. }
  418. require.Equal(t, 100, count)
  419. }()
  420. }
  421. wg.Wait()
  422. }
  423. func TestMap_ConcurrentSeq(t *testing.T) {
  424. t.Parallel()
  425. m := NewMap[int, int]()
  426. for i := range 100 {
  427. m.Set(i, i*2)
  428. }
  429. var wg sync.WaitGroup
  430. const numIterators = 10
  431. wg.Add(numIterators)
  432. for range numIterators {
  433. go func() {
  434. defer wg.Done()
  435. count := 0
  436. values := make(map[int]bool)
  437. for v := range m.Seq() {
  438. values[v] = true
  439. count++
  440. }
  441. require.Equal(t, 100, count)
  442. for i := range 100 {
  443. require.True(t, values[i*2])
  444. }
  445. }()
  446. }
  447. wg.Wait()
  448. }
  449. func TestMap_ConcurrentTake(t *testing.T) {
  450. t.Parallel()
  451. m := NewMap[int, int]()
  452. const numItems = 1000
  453. for i := range numItems {
  454. m.Set(i, i*2)
  455. }
  456. var wg sync.WaitGroup
  457. const numWorkers = 10
  458. taken := make([][]int, numWorkers)
  459. wg.Add(numWorkers)
  460. for i := range numWorkers {
  461. go func(workerID int) {
  462. defer wg.Done()
  463. taken[workerID] = make([]int, 0)
  464. for j := workerID; j < numItems; j += numWorkers {
  465. if value, ok := m.Take(j); ok {
  466. taken[workerID] = append(taken[workerID], value)
  467. }
  468. }
  469. }(i)
  470. }
  471. wg.Wait()
  472. require.Equal(t, 0, m.Len())
  473. allTaken := make(map[int]bool)
  474. for _, workerTaken := range taken {
  475. for _, value := range workerTaken {
  476. require.False(t, allTaken[value], "Value %d was taken multiple times", value)
  477. allTaken[value] = true
  478. }
  479. }
  480. require.Equal(t, numItems, len(allTaken))
  481. for i := range numItems {
  482. require.True(t, allTaken[i*2], "Expected value %d to be taken", i*2)
  483. }
  484. }
  485. func TestMap_TypeSafety(t *testing.T) {
  486. t.Parallel()
  487. stringIntMap := NewMap[string, int]()
  488. stringIntMap.Set("key", 42)
  489. value, ok := stringIntMap.Get("key")
  490. require.True(t, ok)
  491. require.Equal(t, 42, value)
  492. intStringMap := NewMap[int, string]()
  493. intStringMap.Set(42, "value")
  494. strValue, ok := intStringMap.Get(42)
  495. require.True(t, ok)
  496. require.Equal(t, "value", strValue)
  497. structMap := NewMap[string, struct{ Name string }]()
  498. structMap.Set("key", struct{ Name string }{Name: "test"})
  499. structValue, ok := structMap.Get("key")
  500. require.True(t, ok)
  501. require.Equal(t, "test", structValue.Name)
  502. }
  503. func TestMap_InterfaceCompliance(t *testing.T) {
  504. t.Parallel()
  505. var _ json.Marshaler = &Map[string, any]{}
  506. var _ json.Unmarshaler = &Map[string, any]{}
  507. }
  508. func BenchmarkMap_Set(b *testing.B) {
  509. m := NewMap[int, int]()
  510. for i := 0; b.Loop(); i++ {
  511. m.Set(i, i*2)
  512. }
  513. }
  514. func BenchmarkMap_Get(b *testing.B) {
  515. m := NewMap[int, int]()
  516. for i := range 1000 {
  517. m.Set(i, i*2)
  518. }
  519. for i := 0; b.Loop(); i++ {
  520. m.Get(i % 1000)
  521. }
  522. }
  523. func BenchmarkMap_Seq2(b *testing.B) {
  524. m := NewMap[int, int]()
  525. for i := range 1000 {
  526. m.Set(i, i*2)
  527. }
  528. for b.Loop() {
  529. for range m.Seq2() {
  530. }
  531. }
  532. }
  533. func BenchmarkMap_Seq(b *testing.B) {
  534. m := NewMap[int, int]()
  535. for i := range 1000 {
  536. m.Set(i, i*2)
  537. }
  538. for b.Loop() {
  539. for range m.Seq() {
  540. }
  541. }
  542. }
  543. func BenchmarkMap_Take(b *testing.B) {
  544. m := NewMap[int, int]()
  545. for i := range 1000 {
  546. m.Set(i, i*2)
  547. }
  548. b.ResetTimer()
  549. for i := 0; b.Loop(); i++ {
  550. key := i % 1000
  551. m.Take(key)
  552. if i%1000 == 999 {
  553. b.StopTimer()
  554. for j := range 1000 {
  555. m.Set(j, j*2)
  556. }
  557. b.StartTimer()
  558. }
  559. }
  560. }
  561. func BenchmarkMap_ConcurrentReadWrite(b *testing.B) {
  562. m := NewMap[int, int]()
  563. for i := range 1000 {
  564. m.Set(i, i*2)
  565. }
  566. b.ResetTimer()
  567. b.RunParallel(func(pb *testing.PB) {
  568. i := 0
  569. for pb.Next() {
  570. if i%2 == 0 {
  571. m.Get(i % 1000)
  572. } else {
  573. m.Set(i+1000, i*2)
  574. }
  575. i++
  576. }
  577. })
  578. }