Browse Source

lib/connections: Dial devices in parallel (#7783)

Simon Frei 4 years ago
parent
commit
c78fa42f31

+ 42 - 9
lib/connections/service.go

@@ -66,6 +66,8 @@ const (
 	worstDialerPriority           = math.MaxInt32
 	recentlySeenCutoff            = 7 * 24 * time.Hour
 	shortLivedConnectionThreshold = 5 * time.Second
+	dialMaxParallel               = 64
+	dialMaxParallelPerDevice      = 8
 )
 
 // From go/src/crypto/tls/cipher_suites.go
@@ -490,14 +492,40 @@ func (s *service) dialDevices(ctx context.Context, now time.Time, cfg config.Con
 	// Perform dials according to the queue, stopping when we've reached the
 	// allowed additional number of connections (if limited).
 	numConns := 0
-	for _, entry := range queue {
-		if conn, ok := s.dialParallel(ctx, entry.id, entry.targets); ok {
-			s.conns <- conn
-			numConns++
-			if allowAdditional > 0 && numConns >= allowAdditional {
-				break
-			}
+	var numConnsMut stdsync.Mutex
+	dialSemaphore := util.NewSemaphore(dialMaxParallel)
+	dialWG := new(stdsync.WaitGroup)
+	dialCtx, dialCancel := context.WithCancel(ctx)
+	defer func() {
+		dialWG.Wait()
+		dialCancel()
+	}()
+	for i := range queue {
+		select {
+		case <-dialCtx.Done():
+			return
+		default:
 		}
+		dialWG.Add(1)
+		go func(entry dialQueueEntry) {
+			defer dialWG.Done()
+			conn, ok := s.dialParallel(dialCtx, entry.id, entry.targets, dialSemaphore)
+			if !ok {
+				return
+			}
+			numConnsMut.Lock()
+			if allowAdditional == 0 || numConns < allowAdditional {
+				select {
+				case s.conns <- conn:
+					numConns++
+					if allowAdditional > 0 && numConns >= allowAdditional {
+						dialCancel()
+					}
+				case <-dialCtx.Done():
+				}
+			}
+			numConnsMut.Unlock()
+		}(queue[i])
 	}
 }
 
@@ -959,7 +987,7 @@ func IsAllowedNetwork(host string, allowed []string) bool {
 	return false
 }
 
-func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID, dialTargets []dialTarget) (internalConn, bool) {
+func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID, dialTargets []dialTarget, parentSema *util.Semaphore) (internalConn, bool) {
 	// Group targets into buckets by priority
 	dialTargetBuckets := make(map[int][]dialTarget, len(dialTargets))
 	for _, tgt := range dialTargets {
@@ -975,13 +1003,19 @@ func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID,
 	// Sort the priorities so that we dial lowest first (which means highest...)
 	sort.Ints(priorities)
 
+	sema := util.MultiSemaphore{util.NewSemaphore(dialMaxParallelPerDevice), parentSema}
 	for _, prio := range priorities {
 		tgts := dialTargetBuckets[prio]
 		res := make(chan internalConn, len(tgts))
 		wg := stdsync.WaitGroup{}
 		for _, tgt := range tgts {
+			sema.Take(1)
 			wg.Add(1)
 			go func(tgt dialTarget) {
+				defer func() {
+					wg.Done()
+					sema.Give(1)
+				}()
 				conn, err := tgt.Dial(ctx)
 				if err == nil {
 					// Closes the connection on error
@@ -994,7 +1028,6 @@ func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID,
 					l.Debugln("dialing", deviceID, tgt.uri, "success:", conn)
 					res <- conn
 				}
-				wg.Done()
 			}(tgt)
 		}
 

+ 0 - 109
lib/model/bytesemaphore.go

@@ -1,109 +0,0 @@
-// Copyright (C) 2018 The Syncthing Authors.
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this file,
-// You can obtain one at https://mozilla.org/MPL/2.0/.
-
-package model
-
-import (
-	"context"
-	"sync"
-)
-
-type byteSemaphore struct {
-	max       int
-	available int
-	mut       sync.Mutex
-	cond      *sync.Cond
-}
-
-func newByteSemaphore(max int) *byteSemaphore {
-	if max < 0 {
-		max = 0
-	}
-	s := byteSemaphore{
-		max:       max,
-		available: max,
-	}
-	s.cond = sync.NewCond(&s.mut)
-	return &s
-}
-
-func (s *byteSemaphore) takeWithContext(ctx context.Context, bytes int) error {
-	done := make(chan struct{})
-	var err error
-	go func() {
-		err = s.takeInner(ctx, bytes)
-		close(done)
-	}()
-	select {
-	case <-done:
-	case <-ctx.Done():
-		s.cond.Broadcast()
-		<-done
-	}
-	return err
-}
-
-func (s *byteSemaphore) take(bytes int) {
-	_ = s.takeInner(context.Background(), bytes)
-}
-
-func (s *byteSemaphore) takeInner(ctx context.Context, bytes int) error {
-	// Checking context for bytes <= s.available is required for testing and doesn't do any harm.
-	select {
-	case <-ctx.Done():
-		return ctx.Err()
-	default:
-	}
-	s.mut.Lock()
-	defer s.mut.Unlock()
-	if bytes > s.max {
-		bytes = s.max
-	}
-	for bytes > s.available {
-		s.cond.Wait()
-		select {
-		case <-ctx.Done():
-			return ctx.Err()
-		default:
-		}
-		if bytes > s.max {
-			bytes = s.max
-		}
-	}
-	s.available -= bytes
-	return nil
-}
-
-func (s *byteSemaphore) give(bytes int) {
-	s.mut.Lock()
-	if bytes > s.max {
-		bytes = s.max
-	}
-	if s.available+bytes > s.max {
-		s.available = s.max
-	} else {
-		s.available += bytes
-	}
-	s.cond.Broadcast()
-	s.mut.Unlock()
-}
-
-func (s *byteSemaphore) setCapacity(cap int) {
-	if cap < 0 {
-		cap = 0
-	}
-	s.mut.Lock()
-	diff := cap - s.max
-	s.max = cap
-	s.available += diff
-	if s.available < 0 {
-		s.available = 0
-	} else if s.available > s.max {
-		s.available = s.max
-	}
-	s.cond.Broadcast()
-	s.mut.Unlock()
-}

+ 8 - 8
lib/model/folder.go

@@ -38,7 +38,7 @@ type folder struct {
 	stateTracker
 	config.FolderConfiguration
 	*stats.FolderStatisticsReference
-	ioLimiter *byteSemaphore
+	ioLimiter *util.Semaphore
 
 	localFlags uint32
 
@@ -91,7 +91,7 @@ type puller interface {
 	pull() (bool, error) // true when successful and should not be retried
 }
 
-func newFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, evLogger events.Logger, ioLimiter *byteSemaphore, ver versioner.Versioner) folder {
+func newFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, evLogger events.Logger, ioLimiter *util.Semaphore, ver versioner.Versioner) folder {
 	f := folder{
 		stateTracker:              newStateTracker(cfg.ID, evLogger),
 		FolderConfiguration:       cfg,
@@ -375,10 +375,10 @@ func (f *folder) pull() (success bool, err error) {
 	if f.Type != config.FolderTypeSendOnly {
 		f.setState(FolderSyncWaiting)
 
-		if err := f.ioLimiter.takeWithContext(f.ctx, 1); err != nil {
+		if err := f.ioLimiter.TakeWithContext(f.ctx, 1); err != nil {
 			return true, err
 		}
-		defer f.ioLimiter.give(1)
+		defer f.ioLimiter.Give(1)
 	}
 
 	startTime := time.Now()
@@ -439,10 +439,10 @@ func (f *folder) scanSubdirs(subDirs []string) error {
 	f.setState(FolderScanWaiting)
 	defer f.setState(FolderIdle)
 
-	if err := f.ioLimiter.takeWithContext(f.ctx, 1); err != nil {
+	if err := f.ioLimiter.TakeWithContext(f.ctx, 1); err != nil {
 		return err
 	}
-	defer f.ioLimiter.give(1)
+	defer f.ioLimiter.Give(1)
 
 	for i := range subDirs {
 		sub := osutil.NativeFilename(subDirs[i])
@@ -870,10 +870,10 @@ func (f *folder) versionCleanupTimerFired() {
 	f.setState(FolderCleanWaiting)
 	defer f.setState(FolderIdle)
 
-	if err := f.ioLimiter.takeWithContext(f.ctx, 1); err != nil {
+	if err := f.ioLimiter.TakeWithContext(f.ctx, 1); err != nil {
 		return
 	}
-	defer f.ioLimiter.give(1)
+	defer f.ioLimiter.Give(1)
 
 	f.setState(FolderCleaning)
 

+ 2 - 1
lib/model/folder_recvenc.go

@@ -16,6 +16,7 @@ import (
 	"github.com/syncthing/syncthing/lib/fs"
 	"github.com/syncthing/syncthing/lib/ignore"
 	"github.com/syncthing/syncthing/lib/protocol"
+	"github.com/syncthing/syncthing/lib/util"
 	"github.com/syncthing/syncthing/lib/versioner"
 )
 
@@ -27,7 +28,7 @@ type receiveEncryptedFolder struct {
 	*sendReceiveFolder
 }
 
-func newReceiveEncryptedFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *byteSemaphore) service {
+func newReceiveEncryptedFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *util.Semaphore) service {
 	return &receiveEncryptedFolder{newSendReceiveFolder(model, fset, ignores, cfg, ver, evLogger, ioLimiter).(*sendReceiveFolder)}
 }
 

+ 2 - 1
lib/model/folder_recvonly.go

@@ -15,6 +15,7 @@ import (
 	"github.com/syncthing/syncthing/lib/events"
 	"github.com/syncthing/syncthing/lib/ignore"
 	"github.com/syncthing/syncthing/lib/protocol"
+	"github.com/syncthing/syncthing/lib/util"
 	"github.com/syncthing/syncthing/lib/versioner"
 )
 
@@ -56,7 +57,7 @@ type receiveOnlyFolder struct {
 	*sendReceiveFolder
 }
 
-func newReceiveOnlyFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *byteSemaphore) service {
+func newReceiveOnlyFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *util.Semaphore) service {
 	sr := newSendReceiveFolder(model, fset, ignores, cfg, ver, evLogger, ioLimiter).(*sendReceiveFolder)
 	sr.localFlags = protocol.FlagLocalReceiveOnly // gets propagated to the scanner, and set on locally changed files
 	return &receiveOnlyFolder{sr}

+ 2 - 1
lib/model/folder_sendonly.go

@@ -12,6 +12,7 @@ import (
 	"github.com/syncthing/syncthing/lib/events"
 	"github.com/syncthing/syncthing/lib/ignore"
 	"github.com/syncthing/syncthing/lib/protocol"
+	"github.com/syncthing/syncthing/lib/util"
 	"github.com/syncthing/syncthing/lib/versioner"
 )
 
@@ -23,7 +24,7 @@ type sendOnlyFolder struct {
 	folder
 }
 
-func newSendOnlyFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, _ versioner.Versioner, evLogger events.Logger, ioLimiter *byteSemaphore) service {
+func newSendOnlyFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, _ versioner.Versioner, evLogger events.Logger, ioLimiter *util.Semaphore) service {
 	f := &sendOnlyFolder{
 		folder: newFolder(model, fset, ignores, cfg, evLogger, ioLimiter, nil),
 	}

+ 9 - 8
lib/model/folder_sendrecv.go

@@ -28,6 +28,7 @@ import (
 	"github.com/syncthing/syncthing/lib/scanner"
 	"github.com/syncthing/syncthing/lib/sha256"
 	"github.com/syncthing/syncthing/lib/sync"
+	"github.com/syncthing/syncthing/lib/util"
 	"github.com/syncthing/syncthing/lib/versioner"
 	"github.com/syncthing/syncthing/lib/weakhash"
 )
@@ -123,17 +124,17 @@ type sendReceiveFolder struct {
 
 	queue              *jobQueue
 	blockPullReorderer blockPullReorderer
-	writeLimiter       *byteSemaphore
+	writeLimiter       *util.Semaphore
 
 	tempPullErrors map[string]string // pull errors that might be just transient
 }
 
-func newSendReceiveFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *byteSemaphore) service {
+func newSendReceiveFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *util.Semaphore) service {
 	f := &sendReceiveFolder{
 		folder:             newFolder(model, fset, ignores, cfg, evLogger, ioLimiter, ver),
 		queue:              newJobQueue(),
 		blockPullReorderer: newBlockPullReorderer(cfg.BlockPullOrder, model.id, cfg.DeviceIDs()),
-		writeLimiter:       newByteSemaphore(cfg.MaxConcurrentWrites),
+		writeLimiter:       util.NewSemaphore(cfg.MaxConcurrentWrites),
 	}
 	f.folder.puller = f
 
@@ -1435,7 +1436,7 @@ func (f *sendReceiveFolder) verifyBuffer(buf []byte, block protocol.BlockInfo) e
 }
 
 func (f *sendReceiveFolder) pullerRoutine(snap *db.Snapshot, in <-chan pullBlockState, out chan<- *sharedPullerState) {
-	requestLimiter := newByteSemaphore(f.PullerMaxPendingKiB * 1024)
+	requestLimiter := util.NewSemaphore(f.PullerMaxPendingKiB * 1024)
 	wg := sync.NewWaitGroup()
 
 	for state := range in {
@@ -1453,7 +1454,7 @@ func (f *sendReceiveFolder) pullerRoutine(snap *db.Snapshot, in <-chan pullBlock
 		state := state
 		bytes := int(state.block.Size)
 
-		if err := requestLimiter.takeWithContext(f.ctx, bytes); err != nil {
+		if err := requestLimiter.TakeWithContext(f.ctx, bytes); err != nil {
 			state.fail(err)
 			out <- state.sharedPullerState
 			continue
@@ -1463,7 +1464,7 @@ func (f *sendReceiveFolder) pullerRoutine(snap *db.Snapshot, in <-chan pullBlock
 
 		go func() {
 			defer wg.Done()
-			defer requestLimiter.give(bytes)
+			defer requestLimiter.Give(bytes)
 
 			f.pullBlock(state, snap, out)
 		}()
@@ -2085,10 +2086,10 @@ func (f *sendReceiveFolder) limitedWriteAt(fd io.WriterAt, data []byte, offset i
 }
 
 func (f *sendReceiveFolder) withLimiter(fn func() error) error {
-	if err := f.writeLimiter.takeWithContext(f.ctx, 1); err != nil {
+	if err := f.writeLimiter.TakeWithContext(f.ctx, 1); err != nil {
 		return err
 	}
-	defer f.writeLimiter.give(1)
+	defer f.writeLimiter.Give(1)
 	return fn()
 }
 

+ 16 - 23
lib/model/model.go

@@ -40,6 +40,7 @@ import (
 	"github.com/syncthing/syncthing/lib/svcutil"
 	"github.com/syncthing/syncthing/lib/sync"
 	"github.com/syncthing/syncthing/lib/ur/contract"
+	"github.com/syncthing/syncthing/lib/util"
 	"github.com/syncthing/syncthing/lib/versioner"
 )
 
@@ -132,10 +133,10 @@ type model struct {
 	shortID         protocol.ShortID
 	// globalRequestLimiter limits the amount of data in concurrent incoming
 	// requests
-	globalRequestLimiter *byteSemaphore
+	globalRequestLimiter *util.Semaphore
 	// folderIOLimiter limits the number of concurrent I/O heavy operations,
 	// such as scans and pulls.
-	folderIOLimiter *byteSemaphore
+	folderIOLimiter *util.Semaphore
 	fatalChan       chan error
 	started         chan struct{}
 
@@ -155,7 +156,7 @@ type model struct {
 	// fields protected by pmut
 	pmut                sync.RWMutex
 	conn                map[protocol.DeviceID]protocol.Connection
-	connRequestLimiters map[protocol.DeviceID]*byteSemaphore
+	connRequestLimiters map[protocol.DeviceID]*util.Semaphore
 	closed              map[protocol.DeviceID]chan struct{}
 	helloMessages       map[protocol.DeviceID]protocol.Hello
 	deviceDownloads     map[protocol.DeviceID]*deviceDownloadState
@@ -166,7 +167,7 @@ type model struct {
 	foldersRunning int32
 }
 
-type folderFactory func(*model, *db.FileSet, *ignore.Matcher, config.FolderConfiguration, versioner.Versioner, events.Logger, *byteSemaphore) service
+type folderFactory func(*model, *db.FileSet, *ignore.Matcher, config.FolderConfiguration, versioner.Versioner, events.Logger, *util.Semaphore) service
 
 var (
 	folderFactories = make(map[config.FolderType]folderFactory)
@@ -220,8 +221,8 @@ func NewModel(cfg config.Wrapper, id protocol.DeviceID, clientName, clientVersio
 		finder:               db.NewBlockFinder(ldb),
 		progressEmitter:      NewProgressEmitter(cfg, evLogger),
 		shortID:              id.Short(),
-		globalRequestLimiter: newByteSemaphore(1024 * cfg.Options().MaxConcurrentIncomingRequestKiB()),
-		folderIOLimiter:      newByteSemaphore(cfg.Options().MaxFolderConcurrency()),
+		globalRequestLimiter: util.NewSemaphore(1024 * cfg.Options().MaxConcurrentIncomingRequestKiB()),
+		folderIOLimiter:      util.NewSemaphore(cfg.Options().MaxFolderConcurrency()),
 		fatalChan:            make(chan error),
 		started:              make(chan struct{}),
 
@@ -240,7 +241,7 @@ func NewModel(cfg config.Wrapper, id protocol.DeviceID, clientName, clientVersio
 		// fields protected by pmut
 		pmut:                sync.NewRWMutex(),
 		conn:                make(map[protocol.DeviceID]protocol.Connection),
-		connRequestLimiters: make(map[protocol.DeviceID]*byteSemaphore),
+		connRequestLimiters: make(map[protocol.DeviceID]*util.Semaphore),
 		closed:              make(map[protocol.DeviceID]chan struct{}),
 		helloMessages:       make(map[protocol.DeviceID]protocol.Hello),
 		deviceDownloads:     make(map[protocol.DeviceID]*deviceDownloadState),
@@ -1906,23 +1907,15 @@ func (m *model) Request(deviceID protocol.DeviceID, folder, name string, blockNo
 // skipping nil limiters, then returns a requestResponse of the given size.
 // When the requestResponse is closed the limiters are given back the bytes,
 // in reverse order.
-func newLimitedRequestResponse(size int, limiters ...*byteSemaphore) *requestResponse {
-	for _, limiter := range limiters {
-		if limiter != nil {
-			limiter.take(size)
-		}
-	}
+func newLimitedRequestResponse(size int, limiters ...*util.Semaphore) *requestResponse {
+	multi := util.MultiSemaphore(limiters)
+	multi.Take(size)
 
 	res := newRequestResponse(size)
 
 	go func() {
 		res.Wait()
-		for i := range limiters {
-			limiter := limiters[len(limiters)-1-i]
-			if limiter != nil {
-				limiter.give(size)
-			}
-		}
+		multi.Give(size)
 	}()
 
 	return res
@@ -2230,9 +2223,9 @@ func (m *model) AddConnection(conn protocol.Connection, hello protocol.Hello) {
 	// 0: default, <0: no limiting
 	switch {
 	case device.MaxRequestKiB > 0:
-		m.connRequestLimiters[deviceID] = newByteSemaphore(1024 * device.MaxRequestKiB)
+		m.connRequestLimiters[deviceID] = util.NewSemaphore(1024 * device.MaxRequestKiB)
 	case device.MaxRequestKiB == 0:
-		m.connRequestLimiters[deviceID] = newByteSemaphore(1024 * defaultPullerPendingKiB)
+		m.connRequestLimiters[deviceID] = util.NewSemaphore(1024 * defaultPullerPendingKiB)
 	}
 
 	m.helloMessages[deviceID] = hello
@@ -2927,8 +2920,8 @@ func (m *model) CommitConfiguration(from, to config.Configuration) bool {
 	ignoredDevices := observedDeviceSet(to.IgnoredDevices)
 	m.cleanPending(toDevices, toFolders, ignoredDevices, removedFolders)
 
-	m.globalRequestLimiter.setCapacity(1024 * to.Options.MaxConcurrentIncomingRequestKiB())
-	m.folderIOLimiter.setCapacity(to.Options.MaxFolderConcurrency())
+	m.globalRequestLimiter.SetCapacity(1024 * to.Options.MaxConcurrentIncomingRequestKiB())
+	m.folderIOLimiter.SetCapacity(to.Options.MaxFolderConcurrency())
 
 	// Some options don't require restart as those components handle it fine
 	// by themselves. Compare the options structs containing only the

+ 6 - 5
lib/model/model_test.go

@@ -38,6 +38,7 @@ import (
 	protocolmocks "github.com/syncthing/syncthing/lib/protocol/mocks"
 	srand "github.com/syncthing/syncthing/lib/rand"
 	"github.com/syncthing/syncthing/lib/testutils"
+	"github.com/syncthing/syncthing/lib/util"
 	"github.com/syncthing/syncthing/lib/versioner"
 )
 
@@ -3319,14 +3320,14 @@ func TestDeviceWasSeen(t *testing.T) {
 }
 
 func TestNewLimitedRequestResponse(t *testing.T) {
-	l0 := newByteSemaphore(0)
-	l1 := newByteSemaphore(1024)
-	l2 := (*byteSemaphore)(nil)
+	l0 := util.NewSemaphore(0)
+	l1 := util.NewSemaphore(1024)
+	l2 := (*util.Semaphore)(nil)
 
 	// Should take 500 bytes from any non-unlimited non-nil limiters.
 	res := newLimitedRequestResponse(500, l0, l1, l2)
 
-	if l1.available != 1024-500 {
+	if l1.Available() != 1024-500 {
 		t.Error("should have taken bytes from limited limiter")
 	}
 
@@ -3336,7 +3337,7 @@ func TestNewLimitedRequestResponse(t *testing.T) {
 	// Try to take 1024 bytes to make sure the bytes were returned.
 	done := make(chan struct{})
 	go func() {
-		l1.take(1024)
+		l1.Take(1024)
 		close(done)
 	}()
 	select {

+ 148 - 0
lib/util/semaphore.go

@@ -0,0 +1,148 @@
+// Copyright (C) 2018 The Syncthing Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at https://mozilla.org/MPL/2.0/.
+
+package util
+
+import (
+	"context"
+	"sync"
+)
+
+type Semaphore struct {
+	max       int
+	available int
+	mut       sync.Mutex
+	cond      *sync.Cond
+}
+
+func NewSemaphore(max int) *Semaphore {
+	if max < 0 {
+		max = 0
+	}
+	s := Semaphore{
+		max:       max,
+		available: max,
+	}
+	s.cond = sync.NewCond(&s.mut)
+	return &s
+}
+
+func (s *Semaphore) TakeWithContext(ctx context.Context, size int) error {
+	done := make(chan struct{})
+	var err error
+	go func() {
+		err = s.takeInner(ctx, size)
+		close(done)
+	}()
+	select {
+	case <-done:
+	case <-ctx.Done():
+		s.cond.Broadcast()
+		<-done
+	}
+	return err
+}
+
+func (s *Semaphore) Take(size int) {
+	_ = s.takeInner(context.Background(), size)
+}
+
+func (s *Semaphore) takeInner(ctx context.Context, size int) error {
+	// Checking context for size <= s.available is required for testing and doesn't do any harm.
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	default:
+	}
+	s.mut.Lock()
+	defer s.mut.Unlock()
+	if size > s.max {
+		size = s.max
+	}
+	for size > s.available {
+		s.cond.Wait()
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+		}
+		if size > s.max {
+			size = s.max
+		}
+	}
+	s.available -= size
+	return nil
+}
+
+func (s *Semaphore) Give(size int) {
+	s.mut.Lock()
+	if size > s.max {
+		size = s.max
+	}
+	if s.available+size > s.max {
+		s.available = s.max
+	} else {
+		s.available += size
+	}
+	s.cond.Broadcast()
+	s.mut.Unlock()
+}
+
+func (s *Semaphore) SetCapacity(capacity int) {
+	if capacity < 0 {
+		capacity = 0
+	}
+	s.mut.Lock()
+	diff := capacity - s.max
+	s.max = capacity
+	s.available += diff
+	if s.available < 0 {
+		s.available = 0
+	} else if s.available > s.max {
+		s.available = s.max
+	}
+	s.cond.Broadcast()
+	s.mut.Unlock()
+}
+
+func (s *Semaphore) Available() int {
+	s.mut.Lock()
+	defer s.mut.Unlock()
+	return s.available
+}
+
+// MultiSemaphore combines semaphores, making sure to always take and give in
+// the same order (reversed for give). A semaphore may be nil, in which case it
+// is skipped.
+type MultiSemaphore []*Semaphore
+
+func (s MultiSemaphore) TakeWithContext(ctx context.Context, size int) error {
+	for _, limiter := range s {
+		if limiter != nil {
+			if err := limiter.TakeWithContext(ctx, size); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+func (s MultiSemaphore) Take(size int) {
+	for _, limiter := range s {
+		if limiter != nil {
+			limiter.Take(size)
+		}
+	}
+}
+
+func (s MultiSemaphore) Give(size int) {
+	for i := range s {
+		limiter := s[len(s)-1-i]
+		if limiter != nil {
+			limiter.Give(size)
+		}
+	}
+}

+ 23 - 23
lib/model/bytesemaphore_test.go → lib/util/semaphore_test.go

@@ -4,38 +4,38 @@
 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 // You can obtain one at https://mozilla.org/MPL/2.0/.
 
-package model
+package util
 
 import "testing"
 
 func TestZeroByteSempahore(t *testing.T) {
 	// A semaphore with zero capacity is just a no-op.
 
-	s := newByteSemaphore(0)
+	s := NewSemaphore(0)
 
 	// None of these should block or panic
-	s.take(123)
-	s.take(456)
-	s.give(1 << 30)
+	s.Take(123)
+	s.Take(456)
+	s.Give(1 << 30)
 }
 
 func TestByteSempahoreCapChangeUp(t *testing.T) {
 	// Waiting takes should unblock when the capacity increases
 
-	s := newByteSemaphore(100)
+	s := NewSemaphore(100)
 
-	s.take(75)
+	s.Take(75)
 	if s.available != 25 {
 		t.Error("bad state after take")
 	}
 
 	gotit := make(chan struct{})
 	go func() {
-		s.take(75)
+		s.Take(75)
 		close(gotit)
 	}()
 
-	s.setCapacity(155)
+	s.SetCapacity(155)
 	<-gotit
 	if s.available != 5 {
 		t.Error("bad state after both takes")
@@ -45,19 +45,19 @@ func TestByteSempahoreCapChangeUp(t *testing.T) {
 func TestByteSempahoreCapChangeDown1(t *testing.T) {
 	// Things should make sense when capacity is adjusted down
 
-	s := newByteSemaphore(100)
+	s := NewSemaphore(100)
 
-	s.take(75)
+	s.Take(75)
 	if s.available != 25 {
 		t.Error("bad state after take")
 	}
 
-	s.setCapacity(90)
+	s.SetCapacity(90)
 	if s.available != 15 {
 		t.Error("bad state after adjust")
 	}
 
-	s.give(75)
+	s.Give(75)
 	if s.available != 90 {
 		t.Error("bad state after give")
 	}
@@ -66,19 +66,19 @@ func TestByteSempahoreCapChangeDown1(t *testing.T) {
 func TestByteSempahoreCapChangeDown2(t *testing.T) {
 	// Things should make sense when capacity is adjusted down, different case
 
-	s := newByteSemaphore(100)
+	s := NewSemaphore(100)
 
-	s.take(75)
+	s.Take(75)
 	if s.available != 25 {
 		t.Error("bad state after take")
 	}
 
-	s.setCapacity(10)
+	s.SetCapacity(10)
 	if s.available != 0 {
 		t.Error("bad state after adjust")
 	}
 
-	s.give(75)
+	s.Give(75)
 	if s.available != 10 {
 		t.Error("bad state after give")
 	}
@@ -87,26 +87,26 @@ func TestByteSempahoreCapChangeDown2(t *testing.T) {
 func TestByteSempahoreGiveMore(t *testing.T) {
 	// We shouldn't end up with more available than we have capacity...
 
-	s := newByteSemaphore(100)
+	s := NewSemaphore(100)
 
-	s.take(150)
+	s.Take(150)
 	if s.available != 0 {
 		t.Errorf("bad state after large take")
 	}
 
-	s.give(150)
+	s.Give(150)
 	if s.available != 100 {
 		t.Errorf("bad state after large take + give")
 	}
 
-	s.take(150)
-	s.setCapacity(125)
+	s.Take(150)
+	s.SetCapacity(125)
 	// available was zero before, we're increasing capacity by 25
 	if s.available != 25 {
 		t.Errorf("bad state after setcap")
 	}
 
-	s.give(150)
+	s.Give(150)
 	if s.available != 125 {
 		t.Errorf("bad state after large take + give with adjustment")
 	}