semaphore.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. // Copyright (C) 2018 The Syncthing Authors.
  2. //
  3. // This Source Code Form is subject to the terms of the Mozilla Public
  4. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  5. // You can obtain one at https://mozilla.org/MPL/2.0/.
  6. package semaphore
  7. import (
  8. "context"
  9. "sync"
  10. )
  11. type Semaphore struct {
  12. max int
  13. available int
  14. mut sync.Mutex
  15. cond *sync.Cond
  16. }
  17. func New(max int) *Semaphore {
  18. if max < 0 {
  19. max = 0
  20. }
  21. s := Semaphore{
  22. max: max,
  23. available: max,
  24. }
  25. s.cond = sync.NewCond(&s.mut)
  26. return &s
  27. }
  28. func (s *Semaphore) TakeWithContext(ctx context.Context, size int) error {
  29. done := make(chan struct{})
  30. var err error
  31. go func() {
  32. err = s.takeInner(ctx, size)
  33. close(done)
  34. }()
  35. select {
  36. case <-done:
  37. case <-ctx.Done():
  38. s.cond.Broadcast()
  39. <-done
  40. }
  41. return err
  42. }
  43. func (s *Semaphore) Take(size int) {
  44. _ = s.takeInner(context.Background(), size)
  45. }
  46. func (s *Semaphore) takeInner(ctx context.Context, size int) error {
  47. // Checking context for size <= s.available is required for testing and doesn't do any harm.
  48. select {
  49. case <-ctx.Done():
  50. return ctx.Err()
  51. default:
  52. }
  53. s.mut.Lock()
  54. defer s.mut.Unlock()
  55. if size > s.max {
  56. size = s.max
  57. }
  58. for size > s.available {
  59. s.cond.Wait()
  60. select {
  61. case <-ctx.Done():
  62. return ctx.Err()
  63. default:
  64. }
  65. if size > s.max {
  66. size = s.max
  67. }
  68. }
  69. s.available -= size
  70. return nil
  71. }
  72. func (s *Semaphore) Give(size int) {
  73. s.mut.Lock()
  74. if size > s.max {
  75. size = s.max
  76. }
  77. if s.available+size > s.max {
  78. s.available = s.max
  79. } else {
  80. s.available += size
  81. }
  82. s.cond.Broadcast()
  83. s.mut.Unlock()
  84. }
  85. func (s *Semaphore) SetCapacity(capacity int) {
  86. if capacity < 0 {
  87. capacity = 0
  88. }
  89. s.mut.Lock()
  90. diff := capacity - s.max
  91. s.max = capacity
  92. s.available += diff
  93. if s.available < 0 {
  94. s.available = 0
  95. } else if s.available > s.max {
  96. s.available = s.max
  97. }
  98. s.cond.Broadcast()
  99. s.mut.Unlock()
  100. }
  101. func (s *Semaphore) Available() int {
  102. s.mut.Lock()
  103. defer s.mut.Unlock()
  104. return s.available
  105. }
  106. // MultiSemaphore combines semaphores, making sure to always take and give in
  107. // the same order (reversed for give). A semaphore may be nil, in which case it
  108. // is skipped.
  109. type MultiSemaphore []*Semaphore
  110. func (s MultiSemaphore) TakeWithContext(ctx context.Context, size int) error {
  111. for _, limiter := range s {
  112. if limiter != nil {
  113. if err := limiter.TakeWithContext(ctx, size); err != nil {
  114. return err
  115. }
  116. }
  117. }
  118. return nil
  119. }
  120. func (s MultiSemaphore) Take(size int) {
  121. for _, limiter := range s {
  122. if limiter != nil {
  123. limiter.Take(size)
  124. }
  125. }
  126. }
  127. func (s MultiSemaphore) Give(size int) {
  128. for i := range s {
  129. limiter := s[len(s)-1-i]
  130. if limiter != nil {
  131. limiter.Give(size)
  132. }
  133. }
  134. }