maps_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687
  1. package csync
  2. import (
  3. "encoding/json"
  4. "maps"
  5. "sync"
  6. "testing"
  7. "github.com/stretchr/testify/require"
  8. )
  9. func TestNewMap(t *testing.T) {
  10. t.Parallel()
  11. m := NewMap[string, int]()
  12. require.NotNil(t, m)
  13. require.NotNil(t, m.inner)
  14. require.Equal(t, 0, m.Len())
  15. }
  16. func TestNewMapFrom(t *testing.T) {
  17. t.Parallel()
  18. original := map[string]int{
  19. "key1": 1,
  20. "key2": 2,
  21. }
  22. m := NewMapFrom(original)
  23. require.NotNil(t, m)
  24. require.Equal(t, original, m.inner)
  25. require.Equal(t, 2, m.Len())
  26. value, ok := m.Get("key1")
  27. require.True(t, ok)
  28. require.Equal(t, 1, value)
  29. }
  30. func TestMap_Set(t *testing.T) {
  31. t.Parallel()
  32. m := NewMap[string, int]()
  33. m.Set("key1", 42)
  34. value, ok := m.Get("key1")
  35. require.True(t, ok)
  36. require.Equal(t, 42, value)
  37. require.Equal(t, 1, m.Len())
  38. m.Set("key1", 100)
  39. value, ok = m.Get("key1")
  40. require.True(t, ok)
  41. require.Equal(t, 100, value)
  42. require.Equal(t, 1, m.Len())
  43. }
  44. func TestMap_GetOrSet(t *testing.T) {
  45. t.Parallel()
  46. m := NewMap[string, int]()
  47. require.Equal(t, 42, m.GetOrSet("key1", func() int { return 42 }))
  48. require.Equal(t, 42, m.GetOrSet("key1", func() int { return 99999 }))
  49. require.Equal(t, 1, m.Len())
  50. }
  51. func TestMap_Get(t *testing.T) {
  52. t.Parallel()
  53. m := NewMap[string, int]()
  54. value, ok := m.Get("nonexistent")
  55. require.False(t, ok)
  56. require.Equal(t, 0, value)
  57. m.Set("key1", 42)
  58. value, ok = m.Get("key1")
  59. require.True(t, ok)
  60. require.Equal(t, 42, value)
  61. }
  62. func TestMap_Del(t *testing.T) {
  63. t.Parallel()
  64. m := NewMap[string, int]()
  65. m.Set("key1", 42)
  66. m.Set("key2", 100)
  67. require.Equal(t, 2, m.Len())
  68. m.Del("key1")
  69. _, ok := m.Get("key1")
  70. require.False(t, ok)
  71. require.Equal(t, 1, m.Len())
  72. value, ok := m.Get("key2")
  73. require.True(t, ok)
  74. require.Equal(t, 100, value)
  75. m.Del("nonexistent")
  76. require.Equal(t, 1, m.Len())
  77. }
  78. func TestMap_Len(t *testing.T) {
  79. t.Parallel()
  80. m := NewMap[string, int]()
  81. require.Equal(t, 0, m.Len())
  82. m.Set("key1", 1)
  83. require.Equal(t, 1, m.Len())
  84. m.Set("key2", 2)
  85. require.Equal(t, 2, m.Len())
  86. m.Del("key1")
  87. require.Equal(t, 1, m.Len())
  88. m.Del("key2")
  89. require.Equal(t, 0, m.Len())
  90. }
  91. func TestMap_Take(t *testing.T) {
  92. t.Parallel()
  93. m := NewMap[string, int]()
  94. m.Set("key1", 42)
  95. m.Set("key2", 100)
  96. require.Equal(t, 2, m.Len())
  97. value, ok := m.Take("key1")
  98. require.True(t, ok)
  99. require.Equal(t, 42, value)
  100. require.Equal(t, 1, m.Len())
  101. _, exists := m.Get("key1")
  102. require.False(t, exists)
  103. value, ok = m.Get("key2")
  104. require.True(t, ok)
  105. require.Equal(t, 100, value)
  106. }
  107. func TestMap_Take_NonexistentKey(t *testing.T) {
  108. t.Parallel()
  109. m := NewMap[string, int]()
  110. m.Set("key1", 42)
  111. value, ok := m.Take("nonexistent")
  112. require.False(t, ok)
  113. require.Equal(t, 0, value)
  114. require.Equal(t, 1, m.Len())
  115. value, ok = m.Get("key1")
  116. require.True(t, ok)
  117. require.Equal(t, 42, value)
  118. }
  119. func TestMap_Take_EmptyMap(t *testing.T) {
  120. t.Parallel()
  121. m := NewMap[string, int]()
  122. value, ok := m.Take("key1")
  123. require.False(t, ok)
  124. require.Equal(t, 0, value)
  125. require.Equal(t, 0, m.Len())
  126. }
  127. func TestMap_Take_SameKeyTwice(t *testing.T) {
  128. t.Parallel()
  129. m := NewMap[string, int]()
  130. m.Set("key1", 42)
  131. value, ok := m.Take("key1")
  132. require.True(t, ok)
  133. require.Equal(t, 42, value)
  134. require.Equal(t, 0, m.Len())
  135. value, ok = m.Take("key1")
  136. require.False(t, ok)
  137. require.Equal(t, 0, value)
  138. require.Equal(t, 0, m.Len())
  139. }
  140. func TestMap_Seq2(t *testing.T) {
  141. t.Parallel()
  142. m := NewMap[string, int]()
  143. m.Set("key1", 1)
  144. m.Set("key2", 2)
  145. m.Set("key3", 3)
  146. collected := maps.Collect(m.Seq2())
  147. require.Equal(t, 3, len(collected))
  148. require.Equal(t, 1, collected["key1"])
  149. require.Equal(t, 2, collected["key2"])
  150. require.Equal(t, 3, collected["key3"])
  151. }
  152. func TestMap_Seq2_EarlyReturn(t *testing.T) {
  153. t.Parallel()
  154. m := NewMap[string, int]()
  155. m.Set("key1", 1)
  156. m.Set("key2", 2)
  157. m.Set("key3", 3)
  158. count := 0
  159. for range m.Seq2() {
  160. count++
  161. if count == 2 {
  162. break
  163. }
  164. }
  165. require.Equal(t, 2, count)
  166. }
  167. func TestMap_Seq2_EmptyMap(t *testing.T) {
  168. t.Parallel()
  169. m := NewMap[string, int]()
  170. count := 0
  171. for range m.Seq2() {
  172. count++
  173. }
  174. require.Equal(t, 0, count)
  175. }
  176. func TestMap_Seq(t *testing.T) {
  177. t.Parallel()
  178. m := NewMap[string, int]()
  179. m.Set("key1", 1)
  180. m.Set("key2", 2)
  181. m.Set("key3", 3)
  182. collected := make([]int, 0)
  183. for v := range m.Seq() {
  184. collected = append(collected, v)
  185. }
  186. require.Equal(t, 3, len(collected))
  187. require.Contains(t, collected, 1)
  188. require.Contains(t, collected, 2)
  189. require.Contains(t, collected, 3)
  190. }
  191. func TestMap_Seq_EarlyReturn(t *testing.T) {
  192. t.Parallel()
  193. m := NewMap[string, int]()
  194. m.Set("key1", 1)
  195. m.Set("key2", 2)
  196. m.Set("key3", 3)
  197. count := 0
  198. for range m.Seq() {
  199. count++
  200. if count == 2 {
  201. break
  202. }
  203. }
  204. require.Equal(t, 2, count)
  205. }
  206. func TestMap_Seq_EmptyMap(t *testing.T) {
  207. t.Parallel()
  208. m := NewMap[string, int]()
  209. count := 0
  210. for range m.Seq() {
  211. count++
  212. }
  213. require.Equal(t, 0, count)
  214. }
  215. func TestMap_MarshalJSON(t *testing.T) {
  216. t.Parallel()
  217. m := NewMap[string, int]()
  218. m.Set("key1", 1)
  219. m.Set("key2", 2)
  220. data, err := json.Marshal(m)
  221. require.NoError(t, err)
  222. result := &Map[string, int]{}
  223. err = json.Unmarshal(data, result)
  224. require.NoError(t, err)
  225. require.Equal(t, 2, result.Len())
  226. v1, _ := result.Get("key1")
  227. v2, _ := result.Get("key2")
  228. require.Equal(t, 1, v1)
  229. require.Equal(t, 2, v2)
  230. }
  231. func TestMap_MarshalJSON_EmptyMap(t *testing.T) {
  232. t.Parallel()
  233. m := NewMap[string, int]()
  234. data, err := json.Marshal(m)
  235. require.NoError(t, err)
  236. require.Equal(t, "{}", string(data))
  237. }
  238. func TestMap_UnmarshalJSON(t *testing.T) {
  239. t.Parallel()
  240. jsonData := `{"key1": 1, "key2": 2}`
  241. m := NewMap[string, int]()
  242. err := json.Unmarshal([]byte(jsonData), m)
  243. require.NoError(t, err)
  244. require.Equal(t, 2, m.Len())
  245. value, ok := m.Get("key1")
  246. require.True(t, ok)
  247. require.Equal(t, 1, value)
  248. value, ok = m.Get("key2")
  249. require.True(t, ok)
  250. require.Equal(t, 2, value)
  251. }
  252. func TestMap_UnmarshalJSON_EmptyJSON(t *testing.T) {
  253. t.Parallel()
  254. jsonData := `{}`
  255. m := NewMap[string, int]()
  256. err := json.Unmarshal([]byte(jsonData), m)
  257. require.NoError(t, err)
  258. require.Equal(t, 0, m.Len())
  259. }
  260. func TestMap_UnmarshalJSON_InvalidJSON(t *testing.T) {
  261. t.Parallel()
  262. jsonData := `{"key1": 1, "key2":}`
  263. m := NewMap[string, int]()
  264. err := json.Unmarshal([]byte(jsonData), m)
  265. require.Error(t, err)
  266. }
  267. func TestMap_UnmarshalJSON_OverwritesExistingData(t *testing.T) {
  268. t.Parallel()
  269. m := NewMap[string, int]()
  270. m.Set("existing", 999)
  271. jsonData := `{"key1": 1, "key2": 2}`
  272. err := json.Unmarshal([]byte(jsonData), m)
  273. require.NoError(t, err)
  274. require.Equal(t, 2, m.Len())
  275. _, ok := m.Get("existing")
  276. require.False(t, ok)
  277. value, ok := m.Get("key1")
  278. require.True(t, ok)
  279. require.Equal(t, 1, value)
  280. }
  281. func TestMap_JSONRoundTrip(t *testing.T) {
  282. t.Parallel()
  283. original := NewMap[string, int]()
  284. original.Set("key1", 1)
  285. original.Set("key2", 2)
  286. original.Set("key3", 3)
  287. data, err := json.Marshal(original)
  288. require.NoError(t, err)
  289. restored := NewMap[string, int]()
  290. err = json.Unmarshal(data, restored)
  291. require.NoError(t, err)
  292. require.Equal(t, original.Len(), restored.Len())
  293. for k, v := range original.Seq2() {
  294. restoredValue, ok := restored.Get(k)
  295. require.True(t, ok)
  296. require.Equal(t, v, restoredValue)
  297. }
  298. }
  299. func TestMap_ConcurrentAccess(t *testing.T) {
  300. t.Parallel()
  301. m := NewMap[int, int]()
  302. const numGoroutines = 100
  303. const numOperations = 100
  304. var wg sync.WaitGroup
  305. wg.Add(numGoroutines)
  306. for i := range numGoroutines {
  307. go func(id int) {
  308. defer wg.Done()
  309. for j := range numOperations {
  310. key := id*numOperations + j
  311. m.Set(key, key*2)
  312. value, ok := m.Get(key)
  313. require.True(t, ok)
  314. require.Equal(t, key*2, value)
  315. }
  316. }(i)
  317. }
  318. wg.Wait()
  319. require.Equal(t, numGoroutines*numOperations, m.Len())
  320. }
  321. func TestMap_ConcurrentReadWrite(t *testing.T) {
  322. t.Parallel()
  323. m := NewMap[int, int]()
  324. const numReaders = 50
  325. const numWriters = 50
  326. const numOperations = 100
  327. for i := range 1000 {
  328. m.Set(i, i)
  329. }
  330. var wg sync.WaitGroup
  331. wg.Add(numReaders + numWriters)
  332. for range numReaders {
  333. go func() {
  334. defer wg.Done()
  335. for j := range numOperations {
  336. key := j % 1000
  337. value, ok := m.Get(key)
  338. if ok {
  339. require.Equal(t, key, value)
  340. }
  341. _ = m.Len()
  342. }
  343. }()
  344. }
  345. for i := range numWriters {
  346. go func(id int) {
  347. defer wg.Done()
  348. for j := range numOperations {
  349. key := 1000 + id*numOperations + j
  350. m.Set(key, key)
  351. if j%10 == 0 {
  352. m.Del(key)
  353. }
  354. }
  355. }(i)
  356. }
  357. wg.Wait()
  358. }
  359. func TestMap_ConcurrentSeq2(t *testing.T) {
  360. t.Parallel()
  361. m := NewMap[int, int]()
  362. for i := range 100 {
  363. m.Set(i, i*2)
  364. }
  365. var wg sync.WaitGroup
  366. const numIterators = 10
  367. wg.Add(numIterators)
  368. for range numIterators {
  369. go func() {
  370. defer wg.Done()
  371. count := 0
  372. for k, v := range m.Seq2() {
  373. require.Equal(t, k*2, v)
  374. count++
  375. }
  376. require.Equal(t, 100, count)
  377. }()
  378. }
  379. wg.Wait()
  380. }
  381. func TestMap_ConcurrentSeq(t *testing.T) {
  382. t.Parallel()
  383. m := NewMap[int, int]()
  384. for i := range 100 {
  385. m.Set(i, i*2)
  386. }
  387. var wg sync.WaitGroup
  388. const numIterators = 10
  389. wg.Add(numIterators)
  390. for range numIterators {
  391. go func() {
  392. defer wg.Done()
  393. count := 0
  394. values := make(map[int]bool)
  395. for v := range m.Seq() {
  396. values[v] = true
  397. count++
  398. }
  399. require.Equal(t, 100, count)
  400. for i := range 100 {
  401. require.True(t, values[i*2])
  402. }
  403. }()
  404. }
  405. wg.Wait()
  406. }
  407. func TestMap_ConcurrentTake(t *testing.T) {
  408. t.Parallel()
  409. m := NewMap[int, int]()
  410. const numItems = 1000
  411. for i := range numItems {
  412. m.Set(i, i*2)
  413. }
  414. var wg sync.WaitGroup
  415. const numWorkers = 10
  416. taken := make([][]int, numWorkers)
  417. wg.Add(numWorkers)
  418. for i := range numWorkers {
  419. go func(workerID int) {
  420. defer wg.Done()
  421. taken[workerID] = make([]int, 0)
  422. for j := workerID; j < numItems; j += numWorkers {
  423. if value, ok := m.Take(j); ok {
  424. taken[workerID] = append(taken[workerID], value)
  425. }
  426. }
  427. }(i)
  428. }
  429. wg.Wait()
  430. require.Equal(t, 0, m.Len())
  431. allTaken := make(map[int]bool)
  432. for _, workerTaken := range taken {
  433. for _, value := range workerTaken {
  434. require.False(t, allTaken[value], "Value %d was taken multiple times", value)
  435. allTaken[value] = true
  436. }
  437. }
  438. require.Equal(t, numItems, len(allTaken))
  439. for i := range numItems {
  440. require.True(t, allTaken[i*2], "Expected value %d to be taken", i*2)
  441. }
  442. }
  443. func TestMap_TypeSafety(t *testing.T) {
  444. t.Parallel()
  445. stringIntMap := NewMap[string, int]()
  446. stringIntMap.Set("key", 42)
  447. value, ok := stringIntMap.Get("key")
  448. require.True(t, ok)
  449. require.Equal(t, 42, value)
  450. intStringMap := NewMap[int, string]()
  451. intStringMap.Set(42, "value")
  452. strValue, ok := intStringMap.Get(42)
  453. require.True(t, ok)
  454. require.Equal(t, "value", strValue)
  455. structMap := NewMap[string, struct{ Name string }]()
  456. structMap.Set("key", struct{ Name string }{Name: "test"})
  457. structValue, ok := structMap.Get("key")
  458. require.True(t, ok)
  459. require.Equal(t, "test", structValue.Name)
  460. }
  461. func TestMap_InterfaceCompliance(t *testing.T) {
  462. t.Parallel()
  463. var _ json.Marshaler = &Map[string, any]{}
  464. var _ json.Unmarshaler = &Map[string, any]{}
  465. }
  466. func BenchmarkMap_Set(b *testing.B) {
  467. m := NewMap[int, int]()
  468. for i := 0; b.Loop(); i++ {
  469. m.Set(i, i*2)
  470. }
  471. }
  472. func BenchmarkMap_Get(b *testing.B) {
  473. m := NewMap[int, int]()
  474. for i := range 1000 {
  475. m.Set(i, i*2)
  476. }
  477. for i := 0; b.Loop(); i++ {
  478. m.Get(i % 1000)
  479. }
  480. }
  481. func BenchmarkMap_Seq2(b *testing.B) {
  482. m := NewMap[int, int]()
  483. for i := range 1000 {
  484. m.Set(i, i*2)
  485. }
  486. for b.Loop() {
  487. for range m.Seq2() {
  488. }
  489. }
  490. }
  491. func BenchmarkMap_Seq(b *testing.B) {
  492. m := NewMap[int, int]()
  493. for i := range 1000 {
  494. m.Set(i, i*2)
  495. }
  496. for b.Loop() {
  497. for range m.Seq() {
  498. }
  499. }
  500. }
  501. func BenchmarkMap_Take(b *testing.B) {
  502. m := NewMap[int, int]()
  503. for i := range 1000 {
  504. m.Set(i, i*2)
  505. }
  506. b.ResetTimer()
  507. for i := 0; b.Loop(); i++ {
  508. key := i % 1000
  509. m.Take(key)
  510. if i%1000 == 999 {
  511. b.StopTimer()
  512. for j := range 1000 {
  513. m.Set(j, j*2)
  514. }
  515. b.StartTimer()
  516. }
  517. }
  518. }
  519. func BenchmarkMap_ConcurrentReadWrite(b *testing.B) {
  520. m := NewMap[int, int]()
  521. for i := range 1000 {
  522. m.Set(i, i*2)
  523. }
  524. b.ResetTimer()
  525. b.RunParallel(func(pb *testing.PB) {
  526. i := 0
  527. for pb.Next() {
  528. if i%2 == 0 {
  529. m.Get(i % 1000)
  530. } else {
  531. m.Set(i+1000, i*2)
  532. }
  533. i++
  534. }
  535. })
  536. }