sync_test.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package lazy
  4. import (
  5. "errors"
  6. "fmt"
  7. "sync"
  8. "testing"
  9. "tailscale.com/types/opt"
  10. )
  11. func TestSyncValue(t *testing.T) {
  12. var lt SyncValue[int]
  13. n := int(testing.AllocsPerRun(1000, func() {
  14. got := lt.Get(fortyTwo)
  15. if got != 42 {
  16. t.Fatalf("got %v; want 42", got)
  17. }
  18. if p, ok := lt.Peek(); !ok {
  19. t.Fatalf("Peek failed")
  20. } else if p != 42 {
  21. t.Fatalf("Peek got %v; want 42", p)
  22. }
  23. }))
  24. if n != 0 {
  25. t.Errorf("allocs = %v; want 0", n)
  26. }
  27. }
  28. func TestSyncValueErr(t *testing.T) {
  29. var lt SyncValue[int]
  30. n := int(testing.AllocsPerRun(1000, func() {
  31. got, err := lt.GetErr(func() (int, error) {
  32. return 42, nil
  33. })
  34. if got != 42 || err != nil {
  35. t.Fatalf("got %v, %v; want 42, nil", got, err)
  36. }
  37. }))
  38. if n != 0 {
  39. t.Errorf("allocs = %v; want 0", n)
  40. }
  41. var lterr SyncValue[int]
  42. wantErr := errors.New("test error")
  43. n = int(testing.AllocsPerRun(1000, func() {
  44. got, err := lterr.GetErr(func() (int, error) {
  45. return 0, wantErr
  46. })
  47. if got != 0 || err != wantErr {
  48. t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr)
  49. }
  50. if p, ok := lt.Peek(); !ok {
  51. t.Fatalf("Peek failed")
  52. } else if got != 0 {
  53. t.Fatalf("Peek got %v; want 0", p)
  54. }
  55. }))
  56. if n != 0 {
  57. t.Errorf("allocs = %v; want 0", n)
  58. }
  59. }
  60. func TestSyncValueSet(t *testing.T) {
  61. var lt SyncValue[int]
  62. if !lt.Set(42) {
  63. t.Fatalf("Set failed")
  64. }
  65. if lt.Set(43) {
  66. t.Fatalf("Set succeeded after first Set")
  67. }
  68. if p, ok := lt.Peek(); !ok {
  69. t.Fatalf("Peek failed")
  70. } else if p != 42 {
  71. t.Fatalf("Peek got %v; want 42", p)
  72. }
  73. n := int(testing.AllocsPerRun(1000, func() {
  74. got := lt.Get(fortyTwo)
  75. if got != 42 {
  76. t.Fatalf("got %v; want 42", got)
  77. }
  78. }))
  79. if n != 0 {
  80. t.Errorf("allocs = %v; want 0", n)
  81. }
  82. }
  83. func TestSyncValueMustSet(t *testing.T) {
  84. var lt SyncValue[int]
  85. lt.MustSet(42)
  86. defer func() {
  87. if e := recover(); e == nil {
  88. t.Errorf("unexpected success; want panic")
  89. }
  90. }()
  91. lt.MustSet(43)
  92. }
  93. func TestSyncValueErrPeek(t *testing.T) {
  94. var sv SyncValue[int]
  95. sv.GetErr(func() (int, error) {
  96. return 123, errors.New("boom")
  97. })
  98. p, ok := sv.Peek()
  99. if ok {
  100. t.Error("unexpected Peek success")
  101. }
  102. if p != 0 {
  103. t.Fatalf("Peek got %v; want 0", p)
  104. }
  105. p, err, ok := sv.PeekErr()
  106. if !ok {
  107. t.Errorf("PeekErr ok=false; want true on error")
  108. }
  109. if got, want := fmt.Sprint(err), "boom"; got != want {
  110. t.Errorf("PeekErr error=%v; want %v", got, want)
  111. }
  112. if p != 123 {
  113. t.Fatalf("PeekErr got %v; want 123", p)
  114. }
  115. }
  116. func TestSyncValueConcurrent(t *testing.T) {
  117. var (
  118. lt SyncValue[int]
  119. wg sync.WaitGroup
  120. start = make(chan struct{})
  121. routines = 10000
  122. )
  123. wg.Add(routines)
  124. for range routines {
  125. go func() {
  126. defer wg.Done()
  127. // Every goroutine waits for the go signal, so that more of them
  128. // have a chance to race on the initial Get than with sequential
  129. // goroutine starts.
  130. <-start
  131. got := lt.Get(fortyTwo)
  132. if got != 42 {
  133. t.Errorf("got %v; want 42", got)
  134. }
  135. }()
  136. }
  137. close(start)
  138. wg.Wait()
  139. }
  140. func TestSyncValueSetForTest(t *testing.T) {
  141. testErr := errors.New("boom")
  142. tests := []struct {
  143. name string
  144. initValue opt.Value[int]
  145. initErr opt.Value[error]
  146. setForTestValue int
  147. setForTestErr error
  148. getValue int
  149. getErr opt.Value[error]
  150. wantValue int
  151. wantErr error
  152. routines int
  153. }{
  154. {
  155. name: "GetOk",
  156. setForTestValue: 42,
  157. getValue: 8,
  158. wantValue: 42,
  159. },
  160. {
  161. name: "GetOk/WithInit",
  162. initValue: opt.ValueOf(4),
  163. setForTestValue: 42,
  164. getValue: 8,
  165. wantValue: 42,
  166. },
  167. {
  168. name: "GetOk/WithInitErr",
  169. initValue: opt.ValueOf(4),
  170. initErr: opt.ValueOf(errors.New("blast")),
  171. setForTestValue: 42,
  172. getValue: 8,
  173. wantValue: 42,
  174. },
  175. {
  176. name: "GetErr",
  177. setForTestValue: 42,
  178. setForTestErr: testErr,
  179. getValue: 8,
  180. getErr: opt.ValueOf(errors.New("ka-boom")),
  181. wantValue: 42,
  182. wantErr: testErr,
  183. },
  184. {
  185. name: "GetErr/NilError",
  186. setForTestValue: 42,
  187. setForTestErr: nil,
  188. getValue: 8,
  189. getErr: opt.ValueOf(errors.New("ka-boom")),
  190. wantValue: 42,
  191. wantErr: nil,
  192. },
  193. {
  194. name: "GetErr/WithInitErr",
  195. initValue: opt.ValueOf(4),
  196. initErr: opt.ValueOf(errors.New("blast")),
  197. setForTestValue: 42,
  198. setForTestErr: testErr,
  199. getValue: 8,
  200. getErr: opt.ValueOf(errors.New("ka-boom")),
  201. wantValue: 42,
  202. wantErr: testErr,
  203. },
  204. {
  205. name: "Concurrent/GetOk",
  206. setForTestValue: 42,
  207. getValue: 8,
  208. wantValue: 42,
  209. routines: 10000,
  210. },
  211. {
  212. name: "Concurrent/GetOk/WithInitErr",
  213. initValue: opt.ValueOf(4),
  214. initErr: opt.ValueOf(errors.New("blast")),
  215. setForTestValue: 42,
  216. getValue: 8,
  217. wantValue: 42,
  218. routines: 10000,
  219. },
  220. {
  221. name: "Concurrent/GetErr",
  222. setForTestValue: 42,
  223. setForTestErr: testErr,
  224. getValue: 8,
  225. getErr: opt.ValueOf(errors.New("ka-boom")),
  226. wantValue: 42,
  227. wantErr: testErr,
  228. routines: 10000,
  229. },
  230. {
  231. name: "Concurrent/GetErr/WithInitErr",
  232. initValue: opt.ValueOf(4),
  233. initErr: opt.ValueOf(errors.New("blast")),
  234. setForTestValue: 42,
  235. setForTestErr: testErr,
  236. getValue: 8,
  237. getErr: opt.ValueOf(errors.New("ka-boom")),
  238. wantValue: 42,
  239. wantErr: testErr,
  240. routines: 10000,
  241. },
  242. }
  243. for _, tt := range tests {
  244. t.Run(tt.name, func(t *testing.T) {
  245. var v SyncValue[int]
  246. // Initialize the sync value with the specified value and/or error,
  247. // if required by the test.
  248. if initValue, ok := tt.initValue.GetOk(); ok {
  249. var wantInitErr, gotInitErr error
  250. var wantInitValue, gotInitValue int
  251. wantInitValue = initValue
  252. if initErr, ok := tt.initErr.GetOk(); ok {
  253. wantInitErr = initErr
  254. gotInitValue, gotInitErr = v.GetErr(func() (int, error) { return initValue, initErr })
  255. } else {
  256. gotInitValue = v.Get(func() int { return initValue })
  257. }
  258. if gotInitErr != wantInitErr {
  259. t.Fatalf("InitErr: got %v; want %v", gotInitErr, wantInitErr)
  260. }
  261. if gotInitValue != wantInitValue {
  262. t.Fatalf("InitValue: got %v; want %v", gotInitValue, wantInitValue)
  263. }
  264. // Verify that SetForTest reverted the error and the value during the test cleanup.
  265. t.Cleanup(func() {
  266. wantCleanupValue, wantCleanupErr := wantInitValue, wantInitErr
  267. gotCleanupValue, gotCleanupErr, ok := v.PeekErr()
  268. if !ok {
  269. t.Fatal("SyncValue is not set after cleanup")
  270. }
  271. if gotCleanupErr != wantCleanupErr {
  272. t.Fatalf("CleanupErr: got %v; want %v", gotCleanupErr, wantCleanupErr)
  273. }
  274. if gotCleanupValue != wantCleanupValue {
  275. t.Fatalf("CleanupValue: got %v; want %v", gotCleanupValue, wantCleanupValue)
  276. }
  277. })
  278. } else {
  279. // Verify that if v wasn't set prior to SetForTest, it's
  280. // reverted to a valid unset state during the test cleanup.
  281. t.Cleanup(func() {
  282. if _, _, ok := v.PeekErr(); ok {
  283. t.Fatal("SyncValue is set after cleanup")
  284. }
  285. wantCleanupValue, wantCleanupErr := 42, errors.New("ka-boom")
  286. gotCleanupValue, gotCleanupErr := v.GetErr(func() (int, error) { return wantCleanupValue, wantCleanupErr })
  287. if gotCleanupErr != wantCleanupErr {
  288. t.Fatalf("CleanupErr: got %v; want %v", gotCleanupErr, wantCleanupErr)
  289. }
  290. if gotCleanupValue != wantCleanupValue {
  291. t.Fatalf("CleanupValue: got %v; want %v", gotCleanupValue, wantCleanupValue)
  292. }
  293. })
  294. }
  295. // Set the test value and/or error.
  296. v.SetForTest(t, tt.setForTestValue, tt.setForTestErr)
  297. // Verify that the value and/or error have been set.
  298. // This will run on either the current goroutine
  299. // or concurrently depending on the tt.routines value.
  300. checkSyncValue := func() {
  301. var gotValue int
  302. var gotErr error
  303. if getErr, ok := tt.getErr.GetOk(); ok {
  304. gotValue, gotErr = v.GetErr(func() (int, error) { return tt.getValue, getErr })
  305. } else {
  306. gotValue = v.Get(func() int { return tt.getValue })
  307. }
  308. if gotErr != tt.wantErr {
  309. t.Errorf("Err: got %v; want %v", gotErr, tt.wantErr)
  310. }
  311. if gotValue != tt.wantValue {
  312. t.Errorf("Value: got %v; want %v", gotValue, tt.wantValue)
  313. }
  314. }
  315. switch tt.routines {
  316. case 0:
  317. checkSyncValue()
  318. default:
  319. var wg sync.WaitGroup
  320. wg.Add(tt.routines)
  321. start := make(chan struct{})
  322. for range tt.routines {
  323. go func() {
  324. defer wg.Done()
  325. // Every goroutine waits for the go signal, so that more of them
  326. // have a chance to race on the initial Get than with sequential
  327. // goroutine starts.
  328. <-start
  329. checkSyncValue()
  330. }()
  331. }
  332. close(start)
  333. wg.Wait()
  334. }
  335. })
  336. }
  337. }