memoize.go 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package dist
  4. import (
  5. "sync"
  6. "tailscale.com/util/deephash"
  7. )
  // MemoizedFn is the signature of a computation that Memoize.Do runs and
  // caches: it takes no arguments and yields either a value or an error.
  type MemoizedFn[V any] func() (V, error)
  10. // Memoize runs MemoizedFns and remembers their results.
  11. type Memoize[O any] struct {
  12. mu sync.Mutex
  13. cond *sync.Cond
  14. outs map[deephash.Sum]O
  15. errs map[deephash.Sum]error
  16. inflight map[deephash.Sum]bool
  17. }
  18. // Do runs fn and returns its result.
  19. // fn is only run once per unique key. Subsequent Do calls with the same key
  20. // return the memoized result of the first call, even if fn is a different
  21. // function.
  22. func (m *Memoize[O]) Do(key any, fn MemoizedFn[O]) (ret O, err error) {
  23. m.mu.Lock()
  24. defer m.mu.Unlock()
  25. if m.cond == nil {
  26. m.cond = sync.NewCond(&m.mu)
  27. m.outs = map[deephash.Sum]O{}
  28. m.errs = map[deephash.Sum]error{}
  29. m.inflight = map[deephash.Sum]bool{}
  30. }
  31. k := deephash.Hash(&key)
  32. for m.inflight[k] {
  33. m.cond.Wait()
  34. }
  35. if err := m.errs[k]; err != nil {
  36. var ret O
  37. return ret, err
  38. }
  39. if ret, ok := m.outs[k]; ok {
  40. return ret, nil
  41. }
  42. m.inflight[k] = true
  43. m.mu.Unlock()
  44. defer func() {
  45. m.mu.Lock()
  46. delete(m.inflight, k)
  47. if err != nil {
  48. m.errs[k] = err
  49. } else {
  50. m.outs[k] = ret
  51. }
  52. m.cond.Broadcast()
  53. }()
  54. ret, err = fn()
  55. if err != nil {
  56. var ret O
  57. return ret, err
  58. }
  59. return ret, nil
  60. }
  61. // once is like memoize, but for functions that don't return non-error values.
  62. type once struct {
  63. m Memoize[any]
  64. }
  65. // Do runs fn.
  66. // fn is only run once per unique key. Subsequent Do calls with the same key
  67. // return the memoized result of the first call, even if fn is a different
  68. // function.
  69. func (o *once) Do(key any, fn func() error) error {
  70. _, err := o.m.Do(key, func() (any, error) {
  71. return nil, fn()
  72. })
  73. return err
  74. }