singleflight_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. // Copyright 2013 The Go Authors. All rights reserved.
  4. // Use of this source code is governed by a BSD-style
  5. // license that can be found in the LICENSE file.
  6. package singleflight
  7. import (
  8. "bytes"
  9. "context"
  10. "errors"
  11. "fmt"
  12. "os"
  13. "os/exec"
  14. "runtime"
  15. "runtime/debug"
  16. "strings"
  17. "sync"
  18. "sync/atomic"
  19. "testing"
  20. "time"
  21. )
  22. func TestDo(t *testing.T) {
  23. var g Group[string, any]
  24. v, err, _ := g.Do("key", func() (interface{}, error) {
  25. return "bar", nil
  26. })
  27. if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
  28. t.Errorf("Do = %v; want %v", got, want)
  29. }
  30. if err != nil {
  31. t.Errorf("Do error = %v", err)
  32. }
  33. }
  34. func TestDoErr(t *testing.T) {
  35. var g Group[string, any]
  36. someErr := errors.New("Some error")
  37. v, err, _ := g.Do("key", func() (interface{}, error) {
  38. return nil, someErr
  39. })
  40. if err != someErr {
  41. t.Errorf("Do error = %v; want someErr %v", err, someErr)
  42. }
  43. if v != nil {
  44. t.Errorf("unexpected non-nil value %#v", v)
  45. }
  46. }
  47. func TestDoDupSuppress(t *testing.T) {
  48. var g Group[string, any]
  49. var wg1, wg2 sync.WaitGroup
  50. c := make(chan string, 1)
  51. var calls int32
  52. fn := func() (interface{}, error) {
  53. if atomic.AddInt32(&calls, 1) == 1 {
  54. // First invocation.
  55. wg1.Done()
  56. }
  57. v := <-c
  58. c <- v // pump; make available for any future calls
  59. time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
  60. return v, nil
  61. }
  62. const n = 10
  63. wg1.Add(1)
  64. for range n {
  65. wg1.Add(1)
  66. wg2.Add(1)
  67. go func() {
  68. defer wg2.Done()
  69. wg1.Done()
  70. v, err, _ := g.Do("key", fn)
  71. if err != nil {
  72. t.Errorf("Do error: %v", err)
  73. return
  74. }
  75. if s, _ := v.(string); s != "bar" {
  76. t.Errorf("Do = %T %v; want %q", v, v, "bar")
  77. }
  78. }()
  79. }
  80. wg1.Wait()
  81. // At least one goroutine is in fn now and all of them have at
  82. // least reached the line before the Do.
  83. c <- "bar"
  84. wg2.Wait()
  85. if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
  86. t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
  87. }
  88. }
  89. // Test that singleflight behaves correctly after Forget called.
  90. // See https://github.com/golang/go/issues/31420
  91. func TestForget(t *testing.T) {
  92. var g Group[string, any]
  93. var (
  94. firstStarted = make(chan struct{})
  95. unblockFirst = make(chan struct{})
  96. firstFinished = make(chan struct{})
  97. )
  98. go func() {
  99. g.Do("key", func() (i interface{}, e error) {
  100. close(firstStarted)
  101. <-unblockFirst
  102. close(firstFinished)
  103. return
  104. })
  105. }()
  106. <-firstStarted
  107. g.Forget("key")
  108. unblockSecond := make(chan struct{})
  109. secondResult := g.DoChan("key", func() (i interface{}, e error) {
  110. <-unblockSecond
  111. return 2, nil
  112. })
  113. close(unblockFirst)
  114. <-firstFinished
  115. thirdResult := g.DoChan("key", func() (i interface{}, e error) {
  116. return 3, nil
  117. })
  118. close(unblockSecond)
  119. <-secondResult
  120. r := <-thirdResult
  121. if r.Val != 2 {
  122. t.Errorf("We should receive result produced by second call, expected: 2, got %d", r.Val)
  123. }
  124. }
  125. func TestDoChan(t *testing.T) {
  126. var g Group[string, any]
  127. ch := g.DoChan("key", func() (interface{}, error) {
  128. return "bar", nil
  129. })
  130. res := <-ch
  131. v := res.Val
  132. err := res.Err
  133. if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
  134. t.Errorf("Do = %v; want %v", got, want)
  135. }
  136. if err != nil {
  137. t.Errorf("Do error = %v", err)
  138. }
  139. }
  140. // Test singleflight behaves correctly after Do panic.
  141. // See https://github.com/golang/go/issues/41133
  142. func TestPanicDo(t *testing.T) {
  143. var g Group[string, any]
  144. fn := func() (interface{}, error) {
  145. panic("invalid memory address or nil pointer dereference")
  146. }
  147. const n = 5
  148. waited := int32(n)
  149. panicCount := int32(0)
  150. done := make(chan struct{})
  151. for range n {
  152. go func() {
  153. defer func() {
  154. if err := recover(); err != nil {
  155. t.Logf("Got panic: %v\n%s", err, debug.Stack())
  156. atomic.AddInt32(&panicCount, 1)
  157. }
  158. if atomic.AddInt32(&waited, -1) == 0 {
  159. close(done)
  160. }
  161. }()
  162. g.Do("key", fn)
  163. }()
  164. }
  165. select {
  166. case <-done:
  167. if panicCount != n {
  168. t.Errorf("Expect %d panic, but got %d", n, panicCount)
  169. }
  170. case <-time.After(time.Second):
  171. t.Fatalf("Do hangs")
  172. }
  173. }
  174. func TestGoexitDo(t *testing.T) {
  175. var g Group[string, any]
  176. fn := func() (interface{}, error) {
  177. runtime.Goexit()
  178. return nil, nil
  179. }
  180. const n = 5
  181. waited := int32(n)
  182. done := make(chan struct{})
  183. for range n {
  184. go func() {
  185. var err error
  186. defer func() {
  187. if err != nil {
  188. t.Errorf("Error should be nil, but got: %v", err)
  189. }
  190. if atomic.AddInt32(&waited, -1) == 0 {
  191. close(done)
  192. }
  193. }()
  194. _, err, _ = g.Do("key", fn)
  195. }()
  196. }
  197. select {
  198. case <-done:
  199. case <-time.After(time.Second):
  200. t.Fatalf("Do hangs")
  201. }
  202. }
  203. func TestPanicDoChan(t *testing.T) {
  204. if runtime.GOOS == "js" {
  205. t.Skipf("js does not support exec")
  206. }
  207. if os.Getenv("TEST_PANIC_DOCHAN") != "" {
  208. defer func() {
  209. recover()
  210. }()
  211. g := new(Group[string, any])
  212. ch := g.DoChan("", func() (interface{}, error) {
  213. panic("Panicking in DoChan")
  214. })
  215. <-ch
  216. t.Fatalf("DoChan unexpectedly returned")
  217. }
  218. t.Parallel()
  219. cmd := exec.Command(os.Args[0], "-test.run="+t.Name(), "-test.v")
  220. cmd.Env = append(os.Environ(), "TEST_PANIC_DOCHAN=1")
  221. out := new(bytes.Buffer)
  222. cmd.Stdout = out
  223. cmd.Stderr = out
  224. if err := cmd.Start(); err != nil {
  225. t.Fatal(err)
  226. }
  227. err := cmd.Wait()
  228. t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out)
  229. if err == nil {
  230. t.Errorf("Test subprocess passed; want a crash due to panic in DoChan")
  231. }
  232. if bytes.Contains(out.Bytes(), []byte("DoChan unexpectedly")) {
  233. t.Errorf("Test subprocess failed with an unexpected failure mode.")
  234. }
  235. if !bytes.Contains(out.Bytes(), []byte("Panicking in DoChan")) {
  236. t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in DoChan")
  237. }
  238. }
  239. func TestPanicDoSharedByDoChan(t *testing.T) {
  240. if runtime.GOOS == "js" {
  241. t.Skipf("js does not support exec")
  242. }
  243. if os.Getenv("TEST_PANIC_DOCHAN") != "" {
  244. blocked := make(chan struct{})
  245. unblock := make(chan struct{})
  246. g := new(Group[string, any])
  247. go func() {
  248. defer func() {
  249. recover()
  250. }()
  251. g.Do("", func() (interface{}, error) {
  252. close(blocked)
  253. <-unblock
  254. panic("Panicking in Do")
  255. })
  256. }()
  257. <-blocked
  258. ch := g.DoChan("", func() (interface{}, error) {
  259. panic("DoChan unexpectedly executed callback")
  260. })
  261. close(unblock)
  262. <-ch
  263. t.Fatalf("DoChan unexpectedly returned")
  264. }
  265. t.Parallel()
  266. cmd := exec.Command(os.Args[0], "-test.run="+t.Name(), "-test.v")
  267. cmd.Env = append(os.Environ(), "TEST_PANIC_DOCHAN=1")
  268. out := new(bytes.Buffer)
  269. cmd.Stdout = out
  270. cmd.Stderr = out
  271. if err := cmd.Start(); err != nil {
  272. t.Fatal(err)
  273. }
  274. err := cmd.Wait()
  275. t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out)
  276. if err == nil {
  277. t.Errorf("Test subprocess passed; want a crash due to panic in Do shared by DoChan")
  278. }
  279. if bytes.Contains(out.Bytes(), []byte("DoChan unexpectedly")) {
  280. t.Errorf("Test subprocess failed with an unexpected failure mode.")
  281. }
  282. if !bytes.Contains(out.Bytes(), []byte("Panicking in Do")) {
  283. t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in Do")
  284. }
  285. }
  286. func TestDoChanContext(t *testing.T) {
  287. t.Run("Basic", func(t *testing.T) {
  288. ctx, cancel := context.WithCancel(context.Background())
  289. defer cancel()
  290. var g Group[string, int]
  291. ch := g.DoChanContext(ctx, "key", func(_ context.Context) (int, error) {
  292. return 1, nil
  293. })
  294. ret := <-ch
  295. assertOKResult(t, ret, 1)
  296. })
  297. t.Run("DoesNotPropagateValues", func(t *testing.T) {
  298. ctx, cancel := context.WithCancel(context.Background())
  299. defer cancel()
  300. key := new(int)
  301. const value = "hello world"
  302. ctx = context.WithValue(ctx, key, value)
  303. var g Group[string, int]
  304. ch := g.DoChanContext(ctx, "foobar", func(ctx context.Context) (int, error) {
  305. if _, ok := ctx.Value(key).(string); ok {
  306. t.Error("expected no value, but was present in context")
  307. }
  308. return 1, nil
  309. })
  310. ret := <-ch
  311. assertOKResult(t, ret, 1)
  312. })
  313. t.Run("NoCancelWhenWaiters", func(t *testing.T) {
  314. testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second)
  315. defer testCancel()
  316. trigger := make(chan struct{})
  317. ctx1, cancel1 := context.WithCancel(context.Background())
  318. defer cancel1()
  319. ctx2, cancel2 := context.WithCancel(context.Background())
  320. defer cancel2()
  321. fn := func(ctx context.Context) (int, error) {
  322. select {
  323. case <-ctx.Done():
  324. return 0, ctx.Err()
  325. case <-trigger:
  326. return 1234, nil
  327. }
  328. }
  329. // Create two waiters, then cancel the first before we trigger
  330. // the function to return a value. This shouldn't result in a
  331. // context canceled error.
  332. var g Group[string, int]
  333. ch1 := g.DoChanContext(ctx1, "key", fn)
  334. ch2 := g.DoChanContext(ctx2, "key", fn)
  335. cancel1()
  336. // The first channel, now that it's canceled, should return a
  337. // context canceled error.
  338. select {
  339. case res := <-ch1:
  340. if !errors.Is(res.Err, context.Canceled) {
  341. t.Errorf("unexpected error; got %v, want context.Canceled", res.Err)
  342. }
  343. case <-testCtx.Done():
  344. t.Fatal("test timed out")
  345. }
  346. // Actually return
  347. close(trigger)
  348. res := <-ch2
  349. assertOKResult(t, res, 1234)
  350. })
  351. t.Run("AllCancel", func(t *testing.T) {
  352. for _, n := range []int{1, 2, 10, 20} {
  353. t.Run(fmt.Sprintf("NumWaiters=%d", n), func(t *testing.T) {
  354. testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second)
  355. defer testCancel()
  356. trigger := make(chan struct{})
  357. defer close(trigger)
  358. fn := func(ctx context.Context) (int, error) {
  359. select {
  360. case <-ctx.Done():
  361. return 0, ctx.Err()
  362. case <-trigger:
  363. t.Error("unexpected trigger; want all callers to cancel")
  364. return 0, errors.New("unexpected trigger")
  365. }
  366. }
  367. // Launch N goroutines that all wait on the same key.
  368. var (
  369. g Group[string, int]
  370. chs []<-chan Result[int]
  371. cancels []context.CancelFunc
  372. )
  373. for i := range n {
  374. ctx, cancel := context.WithCancel(context.Background())
  375. defer cancel()
  376. cancels = append(cancels, cancel)
  377. ch := g.DoChanContext(ctx, "key", fn)
  378. chs = append(chs, ch)
  379. // Every third goroutine should cancel
  380. // immediately, which better tests the
  381. // cancel logic.
  382. if i%3 == 0 {
  383. cancel()
  384. }
  385. }
  386. // Now that everything is waiting, cancel all the contexts.
  387. for _, cancel := range cancels {
  388. cancel()
  389. }
  390. // Wait for a result from each channel. They
  391. // should all return an error showing a context
  392. // cancel.
  393. for _, ch := range chs {
  394. select {
  395. case res := <-ch:
  396. if !errors.Is(res.Err, context.Canceled) {
  397. t.Errorf("unexpected error; got %v, want context.Canceled", res.Err)
  398. }
  399. case <-testCtx.Done():
  400. t.Fatal("test timed out")
  401. }
  402. }
  403. })
  404. }
  405. })
  406. }
  407. func assertOKResult[V comparable](t testing.TB, res Result[V], want V) {
  408. if res.Err != nil {
  409. t.Fatalf("unexpected error: %v", res.Err)
  410. }
  411. if res.Val != want {
  412. t.Fatalf("unexpected value; got %v, want %v", res.Val, want)
  413. }
  414. }