singleflight.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. // Copyright 2013 The Go Authors. All rights reserved.
  4. // Use of this source code is governed by a BSD-style
  5. // license that can be found in the LICENSE file.
  6. // Package singleflight provides a duplicate function call suppression
  7. // mechanism.
  8. //
  9. // This is a Tailscale fork of Go's singleflight package which has had several
  10. // homes in the past:
  11. //
  12. // - https://github.com/golang/go/commit/61d3b2db6292581fc07a3767ec23ec94ad6100d1
  13. // - https://github.com/golang/groupcache/tree/master/singleflight
  14. // - https://pkg.go.dev/golang.org/x/sync/singleflight
  15. //
  16. // This fork adds generics.
  17. package singleflight // import "tailscale.com/util/singleflight"
  18. import (
  19. "bytes"
  20. "context"
  21. "errors"
  22. "fmt"
  23. "runtime"
  24. "runtime/debug"
  25. "sync"
  26. "sync/atomic"
  27. )
  28. // errGoexit indicates the runtime.Goexit was called in
  29. // the user given function.
  30. var errGoexit = errors.New("runtime.Goexit was called")
  31. // A panicError is an arbitrary value recovered from a panic
  32. // with the stack trace during the execution of given function.
  33. type panicError struct {
  34. value interface{}
  35. stack []byte
  36. }
  37. // Error implements error interface.
  38. func (p *panicError) Error() string {
  39. return fmt.Sprintf("%v\n\n%s", p.value, p.stack)
  40. }
  41. func newPanicError(v interface{}) error {
  42. stack := debug.Stack()
  43. // The first line of the stack trace is of the form "goroutine N [status]:"
  44. // but by the time the panic reaches Do the goroutine may no longer exist
  45. // and its status will have changed. Trim out the misleading line.
  46. if line := bytes.IndexByte(stack[:], '\n'); line >= 0 {
  47. stack = stack[line+1:]
  48. }
  49. return &panicError{value: v, stack: stack}
  50. }
  51. // call is an in-flight or completed singleflight.Do call
  52. type call[V any] struct {
  53. wg sync.WaitGroup
  54. // These fields are written once before the WaitGroup is done
  55. // and are only read after the WaitGroup is done.
  56. val V
  57. err error
  58. // These fields are read and written with the singleflight
  59. // mutex held before the WaitGroup is done, and are read but
  60. // not written after the WaitGroup is done.
  61. dups int
  62. chans []chan<- Result[V]
  63. // These fields are only written when the call is being created, and
  64. // only in the DoChanContext method.
  65. cancel context.CancelFunc
  66. ctxWaiters atomic.Int64
  67. }
  68. // Group represents a class of work and forms a namespace in
  69. // which units of work can be executed with duplicate suppression.
  70. type Group[K comparable, V any] struct {
  71. mu sync.Mutex // protects m
  72. m map[K]*call[V] // lazily initialized
  73. }
  74. // Result holds the results of Do, so they can be passed
  75. // on a channel.
  76. type Result[V any] struct {
  77. Val V
  78. Err error
  79. Shared bool
  80. }
  81. // Do executes and returns the results of the given function, making
  82. // sure that only one execution is in-flight for a given key at a
  83. // time. If a duplicate comes in, the duplicate caller waits for the
  84. // original to complete and receives the same results.
  85. // The return value shared indicates whether v was given to multiple callers.
  86. func (g *Group[K, V]) Do(key K, fn func() (V, error)) (v V, err error, shared bool) {
  87. g.mu.Lock()
  88. if g.m == nil {
  89. g.m = make(map[K]*call[V])
  90. }
  91. if c, ok := g.m[key]; ok {
  92. c.dups++
  93. g.mu.Unlock()
  94. c.wg.Wait()
  95. if e, ok := c.err.(*panicError); ok {
  96. panic(e)
  97. } else if c.err == errGoexit {
  98. runtime.Goexit()
  99. }
  100. return c.val, c.err, true
  101. }
  102. c := new(call[V])
  103. c.wg.Add(1)
  104. g.m[key] = c
  105. g.mu.Unlock()
  106. g.doCall(c, key, fn)
  107. return c.val, c.err, c.dups > 0
  108. }
  109. // DoChan is like Do but returns a channel that will receive the
  110. // results when they are ready.
  111. //
  112. // The returned channel will not be closed.
  113. func (g *Group[K, V]) DoChan(key K, fn func() (V, error)) <-chan Result[V] {
  114. ch := make(chan Result[V], 1)
  115. g.mu.Lock()
  116. if g.m == nil {
  117. g.m = make(map[K]*call[V])
  118. }
  119. if c, ok := g.m[key]; ok {
  120. c.dups++
  121. c.chans = append(c.chans, ch)
  122. g.mu.Unlock()
  123. return ch
  124. }
  125. c := &call[V]{chans: []chan<- Result[V]{ch}}
  126. c.wg.Add(1)
  127. g.m[key] = c
  128. g.mu.Unlock()
  129. go g.doCall(c, key, fn)
  130. return ch
  131. }
  132. // DoChanContext is like [Group.DoChan], but supports context cancelation. The
  133. // context passed to the fn function is a context that is canceled only when
  134. // there are no callers waiting on a result (i.e. all callers have canceled
  135. // their contexts).
  136. //
  137. // The context that is passed to the fn function is not derived from any of the
  138. // input contexts, so context values will not be propagated. If context values
  139. // are needed, they must be propagated explicitly.
  140. //
  141. // The returned channel will not be closed. The Result.Err field is set to the
  142. // context error if the context is canceled.
  143. func (g *Group[K, V]) DoChanContext(ctx context.Context, key K, fn func(context.Context) (V, error)) <-chan Result[V] {
  144. ch := make(chan Result[V], 1)
  145. g.mu.Lock()
  146. if g.m == nil {
  147. g.m = make(map[K]*call[V])
  148. }
  149. c, ok := g.m[key]
  150. if ok {
  151. // Call already in progress; add to the waiters list and then
  152. // release the mutex.
  153. c.dups++
  154. c.ctxWaiters.Add(1)
  155. c.chans = append(c.chans, ch)
  156. g.mu.Unlock()
  157. } else {
  158. // The call hasn't been started yet; we need to start it.
  159. //
  160. // Create a context that is not canceled when the parent context is,
  161. // but otherwise propagates all values.
  162. callCtx, callCancel := context.WithCancel(context.Background())
  163. c = &call[V]{
  164. chans: []chan<- Result[V]{ch},
  165. cancel: callCancel,
  166. }
  167. c.wg.Add(1)
  168. c.ctxWaiters.Add(1) // one caller waiting
  169. g.m[key] = c
  170. g.mu.Unlock()
  171. // Wrap our function to provide the context.
  172. go g.doCall(c, key, func() (V, error) {
  173. return fn(callCtx)
  174. })
  175. }
  176. // Instead of returning the channel directly, we need to track
  177. // when the call finishes so we can handle context cancelation.
  178. // Do so by creating an final channel that gets the
  179. // result and hooking that up to the wait function.
  180. final := make(chan Result[V], 1)
  181. go g.waitCtx(ctx, c, ch, final)
  182. return final
  183. }
  184. // waitCtx will wait on the provided call to finish, or the context to be done.
  185. // If the context is done, and this is the last waiter, then the context
  186. // provided to the underlying function will be canceled.
  187. func (g *Group[K, V]) waitCtx(ctx context.Context, c *call[V], result <-chan Result[V], output chan<- Result[V]) {
  188. var res Result[V]
  189. select {
  190. case <-ctx.Done():
  191. case res = <-result:
  192. }
  193. // Decrement the caller count, and if we're the last one, cancel the
  194. // context we created. Do this in all cases, error and otherwise, so we
  195. // don't leak goroutines.
  196. //
  197. // Also wait on the call to finish, so we know that the call has
  198. // finished executing after the last caller has returned.
  199. if c.ctxWaiters.Add(-1) == 0 {
  200. c.cancel()
  201. c.wg.Wait()
  202. }
  203. // Ensure that context cancelation takes precedence over a value being
  204. // available by checking ctx.Err() before sending the result to the
  205. // caller. The select above will nondeterministically pick a case if a
  206. // result is available and the ctx.Done channel is closed, so we check
  207. // again here.
  208. if err := ctx.Err(); err != nil {
  209. res = Result[V]{Err: err}
  210. }
  211. output <- res
  212. }
  213. // doCall handles the single call for a key.
  214. func (g *Group[K, V]) doCall(c *call[V], key K, fn func() (V, error)) {
  215. normalReturn := false
  216. recovered := false
  217. // use double-defer to distinguish panic from runtime.Goexit,
  218. // more details see https://golang.org/cl/134395
  219. defer func() {
  220. // the given function invoked runtime.Goexit
  221. if !normalReturn && !recovered {
  222. c.err = errGoexit
  223. }
  224. g.mu.Lock()
  225. defer g.mu.Unlock()
  226. c.wg.Done()
  227. if g.m[key] == c {
  228. delete(g.m, key)
  229. }
  230. if e, ok := c.err.(*panicError); ok {
  231. // In order to prevent the waiting channels from being blocked forever,
  232. // needs to ensure that this panic cannot be recovered.
  233. if len(c.chans) > 0 {
  234. go panic(e)
  235. select {} // Keep this goroutine around so that it will appear in the crash dump.
  236. } else {
  237. panic(e)
  238. }
  239. } else if c.err == errGoexit {
  240. // Already in the process of goexit, no need to call again
  241. } else {
  242. // Normal return
  243. for _, ch := range c.chans {
  244. ch <- Result[V]{c.val, c.err, c.dups > 0}
  245. }
  246. }
  247. }()
  248. func() {
  249. defer func() {
  250. if !normalReturn {
  251. // Ideally, we would wait to take a stack trace until we've determined
  252. // whether this is a panic or a runtime.Goexit.
  253. //
  254. // Unfortunately, the only way we can distinguish the two is to see
  255. // whether the recover stopped the goroutine from terminating, and by
  256. // the time we know that, the part of the stack trace relevant to the
  257. // panic has been discarded.
  258. if r := recover(); r != nil {
  259. c.err = newPanicError(r)
  260. }
  261. }
  262. }()
  263. c.val, c.err = fn()
  264. normalReturn = true
  265. }()
  266. if !normalReturn {
  267. recovered = true
  268. }
  269. }
  270. // Forget tells the singleflight to forget about a key. Future calls
  271. // to Do for this key will call the function rather than waiting for
  272. // an earlier call to complete.
  273. func (g *Group[K, V]) Forget(key K) {
  274. g.mu.Lock()
  275. delete(g.m, key)
  276. g.mu.Unlock()
  277. }