소스 검색

lib/model, lib/protocol: Handle request concurrency in model (#5216)

Simon Frei 7 년 전
부모
커밋
4f27bdfc27

+ 1 - 0
lib/config/deviceconfiguration.go

@@ -28,6 +28,7 @@ type DeviceConfiguration struct {
 	MaxRecvKbps              int                  `xml:"maxRecvKbps" json:"maxRecvKbps"`
 	IgnoredFolders           []ObservedFolder     `xml:"ignoredFolder" json:"ignoredFolders"`
 	PendingFolders           []ObservedFolder     `xml:"pendingFolder" json:"pendingFolders"`
+	MaxRequestKiB            int                  `xml:"maxRequestKiB" json:"maxRequestKiB"`
 }
 
 func NewDeviceConfiguration(id protocol.DeviceID, name string) DeviceConfiguration {

+ 50 - 0
lib/model/bytesemaphore.go

@@ -0,0 +1,50 @@
+// Copyright (C) 2018 The Syncthing Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at https://mozilla.org/MPL/2.0/.
+
+package model
+
+import "sync"
+
+type byteSemaphore struct {
+	max       int
+	available int
+	mut       sync.Mutex
+	cond      *sync.Cond
+}
+
+func newByteSemaphore(max int) *byteSemaphore {
+	s := byteSemaphore{
+		max:       max,
+		available: max,
+	}
+	s.cond = sync.NewCond(&s.mut)
+	return &s
+}
+
+func (s *byteSemaphore) take(bytes int) {
+	if bytes > s.max {
+		bytes = s.max
+	}
+	s.mut.Lock()
+	for bytes > s.available {
+		s.cond.Wait()
+	}
+	s.available -= bytes
+	s.mut.Unlock()
+}
+
+func (s *byteSemaphore) give(bytes int) {
+	if bytes > s.max {
+		bytes = s.max
+	}
+	s.mut.Lock()
+	if s.available+bytes > s.max {
+		panic("bug: can never give more than max")
+	}
+	s.available += bytes
+	s.cond.Broadcast()
+	s.mut.Unlock()
+}

+ 5 - 45
lib/model/folder_sendrecv.go

@@ -15,7 +15,6 @@ import (
 	"runtime"
 	"sort"
 	"strings"
-	stdsync "sync"
 	"time"
 
 	"github.com/syncthing/syncthing/lib/config"
@@ -1147,7 +1146,10 @@ func (f *sendReceiveFolder) shortcutFile(file, curFile protocol.FileInfo, dbUpda
 // copierRoutine reads copierStates until the in channel closes and performs
 // the relevant copies when possible, or passes it to the puller routine.
 func (f *sendReceiveFolder) copierRoutine(in <-chan copyBlocksState, pullChan chan<- pullBlockState, out chan<- *sharedPullerState) {
-	buf := make([]byte, protocol.MinBlockSize)
+	buf := protocol.BufferPool.Get(protocol.MinBlockSize)
+	defer func() {
+		protocol.BufferPool.Put(buf)
+	}()
 
 	for state := range in {
 		dstFd, err := state.tempFile()
@@ -1223,11 +1225,7 @@ func (f *sendReceiveFolder) copierRoutine(in <-chan copyBlocksState, pullChan ch
 				continue
 			}
 
-			if s := int(block.Size); s > cap(buf) {
-				buf = make([]byte, s)
-			} else {
-				buf = buf[:s]
-			}
+			buf = protocol.BufferPool.Upgrade(buf, int(block.Size))
 
 			found, err := weakHashFinder.Iterate(block.WeakHash, buf, func(offset int64) bool {
 				if verifyBuffer(buf, block) != nil {
@@ -1935,41 +1933,3 @@ func componentCount(name string) int {
 	}
 	return count
 }
-
-type byteSemaphore struct {
-	max       int
-	available int
-	mut       stdsync.Mutex
-	cond      *stdsync.Cond
-}
-
-func newByteSemaphore(max int) *byteSemaphore {
-	s := byteSemaphore{
-		max:       max,
-		available: max,
-	}
-	s.cond = stdsync.NewCond(&s.mut)
-	return &s
-}
-
-func (s *byteSemaphore) take(bytes int) {
-	if bytes > s.max {
-		panic("bug: more than max bytes will never be available")
-	}
-	s.mut.Lock()
-	for bytes > s.available {
-		s.cond.Wait()
-	}
-	s.available -= bytes
-	s.mut.Unlock()
-}
-
-func (s *byteSemaphore) give(bytes int) {
-	s.mut.Lock()
-	if s.available+bytes > s.max {
-		panic("bug: can never give more than max")
-	}
-	s.available += bytes
-	s.cond.Broadcast()
-	s.mut.Unlock()
-}

+ 111 - 37
lib/model/model.go

@@ -105,6 +105,7 @@ type Model struct {
 
 	pmut                sync.RWMutex // protects the below
 	conn                map[protocol.DeviceID]connections.Connection
+	connRequestLimiters map[protocol.DeviceID]*byteSemaphore
 	closed              map[protocol.DeviceID]chan struct{}
 	helloMessages       map[protocol.DeviceID]protocol.HelloResult
 	deviceDownloads     map[protocol.DeviceID]*deviceDownloadState
@@ -158,6 +159,7 @@ func NewModel(cfg *config.Wrapper, id protocol.DeviceID, clientName, clientVersi
 		folderRunnerTokens:  make(map[string][]suture.ServiceToken),
 		folderStatRefs:      make(map[string]*stats.FolderStatisticsReference),
 		conn:                make(map[protocol.DeviceID]connections.Connection),
+		connRequestLimiters: make(map[protocol.DeviceID]*byteSemaphore),
 		closed:              make(map[protocol.DeviceID]chan struct{}),
 		helloMessages:       make(map[protocol.DeviceID]protocol.HelloResult),
 		deviceDownloads:     make(map[protocol.DeviceID]*deviceDownloadState),
@@ -1281,6 +1283,7 @@ func (m *Model) Closed(conn protocol.Connection, err error) {
 		m.progressEmitter.temporaryIndexUnsubscribe(conn)
 	}
 	delete(m.conn, device)
+	delete(m.connRequestLimiters, device)
 	delete(m.helloMessages, device)
 	delete(m.deviceDownloads, device)
 	delete(m.remotePausedFolders, device)
@@ -1314,19 +1317,40 @@ func (m *Model) closeLocked(device protocol.DeviceID) {
 	closeRawConn(conn)
 }
 
-// Request returns the specified data segment by reading it from local disk.
-// Implements the protocol.Model interface.
-func (m *Model) Request(deviceID protocol.DeviceID, folder, name string, offset int64, hash []byte, weakHash uint32, fromTemporary bool, buf []byte) error {
-	if offset < 0 {
-		return protocol.ErrInvalid
+// Implements protocol.RequestResponse
+type requestResponse struct {
+	data   []byte
+	closed chan struct{}
+	once   stdsync.Once
+}
+
+func newRequestResponse(size int) *requestResponse {
+	return &requestResponse{
+		data:   protocol.BufferPool.Get(size),
+		closed: make(chan struct{}),
 	}
+}
 
-	if cfg, ok := m.cfg.Folder(folder); !ok || !cfg.SharedWith(deviceID) {
-		l.Warnf("Request from %s for file %s in unshared folder %q", deviceID, name, folder)
-		return protocol.ErrNoSuchFile
-	} else if cfg.Paused {
-		l.Debugf("Request from %s for file %s in paused folder %q", deviceID, name, folder)
-		return protocol.ErrInvalid
+func (r *requestResponse) Data() []byte {
+	return r.data
+}
+
+func (r *requestResponse) Close() {
+	r.once.Do(func() {
+		protocol.BufferPool.Put(r.data)
+		close(r.closed)
+	})
+}
+
+func (r *requestResponse) Wait() {
+	<-r.closed
+}
+
+// Request returns the specified data segment by reading it from local disk.
+// Implements the protocol.Model interface.
+func (m *Model) Request(deviceID protocol.DeviceID, folder, name string, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (out protocol.RequestResponse, err error) {
+	if size < 0 || offset < 0 {
+		return nil, protocol.ErrInvalid
 	}
 
 	m.fmut.RLock()
@@ -1337,35 +1361,69 @@ func (m *Model) Request(deviceID protocol.DeviceID, folder, name string, offset
 		// The folder might be already unpaused in the config, but not yet
 		// in the model.
 		l.Debugf("Request from %s for file %s in unstarted folder %q", deviceID, name, folder)
-		return protocol.ErrInvalid
+		return nil, protocol.ErrInvalid
+	}
+
+	if !folderCfg.SharedWith(deviceID) {
+		l.Warnf("Request from %s for file %s in unshared folder %q", deviceID, name, folder)
+		return nil, protocol.ErrNoSuchFile
+	}
+	if folderCfg.Paused {
+		l.Debugf("Request from %s for file %s in paused folder %q", deviceID, name, folder)
+		return nil, protocol.ErrInvalid
 	}
 
 	// Make sure the path is valid and in canonical form
-	var err error
 	if name, err = fs.Canonicalize(name); err != nil {
 		l.Debugf("Request from %s in folder %q for invalid filename %s", deviceID, folder, name)
-		return protocol.ErrInvalid
+		return nil, protocol.ErrInvalid
 	}
 
 	if deviceID != protocol.LocalDeviceID {
-		l.Debugf("%v REQ(in): %s: %q / %q o=%d s=%d t=%v", m, deviceID, folder, name, offset, len(buf), fromTemporary)
+		l.Debugf("%v REQ(in): %s: %q / %q o=%d s=%d t=%v", m, deviceID, folder, name, offset, size, fromTemporary)
 	}
 
-	folderFs := folderCfg.Filesystem()
-
 	if fs.IsInternal(name) {
-		l.Debugf("%v REQ(in) for internal file: %s: %q / %q o=%d s=%d", m, deviceID, folder, name, offset, len(buf))
-		return protocol.ErrNoSuchFile
+		l.Debugf("%v REQ(in) for internal file: %s: %q / %q o=%d s=%d", m, deviceID, folder, name, offset, size)
+		return nil, protocol.ErrNoSuchFile
 	}
 
 	if folderIgnores.Match(name).IsIgnored() {
-		l.Debugf("%v REQ(in) for ignored file: %s: %q / %q o=%d s=%d", m, deviceID, folder, name, offset, len(buf))
-		return protocol.ErrNoSuchFile
+		l.Debugf("%v REQ(in) for ignored file: %s: %q / %q o=%d s=%d", m, deviceID, folder, name, offset, size)
+		return nil, protocol.ErrNoSuchFile
 	}
 
+	folderFs := folderCfg.Filesystem()
+
 	if err := osutil.TraversesSymlink(folderFs, filepath.Dir(name)); err != nil {
-		l.Debugf("%v REQ(in) traversal check: %s - %s: %q / %q o=%d s=%d", m, err, deviceID, folder, name, offset, len(buf))
-		return protocol.ErrNoSuchFile
+		l.Debugf("%v REQ(in) traversal check: %s - %s: %q / %q o=%d s=%d", m, err, deviceID, folder, name, offset, size)
+		return nil, protocol.ErrNoSuchFile
+	}
+
+	// Restrict parallel requests by connection/device
+
+	m.pmut.RLock()
+	limiter := m.connRequestLimiters[deviceID]
+	m.pmut.RUnlock()
+
+	if limiter != nil {
+		limiter.take(int(size))
+	}
+
+	// The requestResponse releases the bytes to the limiter when its Close method is called.
+	res := newRequestResponse(int(size))
+	defer func() {
+		// Close it ourselves if it isn't returned due to an error
+		if err != nil {
+			res.Close()
+		}
+	}()
+
+	if limiter != nil {
+		go func() {
+			res.Wait()
+			limiter.give(int(size))
+		}()
 	}
 
 	// Only check temp files if the flag is set, and if we are set to advertise
@@ -1376,11 +1434,12 @@ func (m *Model) Request(deviceID protocol.DeviceID, folder, name string, offset
 		if info, err := folderFs.Lstat(tempFn); err != nil || !info.IsRegular() {
 			// Reject reads for anything that doesn't exist or is something
 			// other than a regular file.
-			return protocol.ErrNoSuchFile
+			l.Debugf("%v REQ(in) failed stating temp file (%v): %s: %q / %q o=%d s=%d", m, err, deviceID, folder, name, offset, size)
+			return nil, protocol.ErrNoSuchFile
 		}
-		err := readOffsetIntoBuf(folderFs, tempFn, offset, buf)
-		if err == nil && scanner.Validate(buf, hash, weakHash) {
-			return nil
+		err := readOffsetIntoBuf(folderFs, tempFn, offset, res.data)
+		if err == nil && scanner.Validate(res.data, hash, weakHash) {
+			return res, nil
 		}
 		// Fall through to reading from a non-temp file, just incase the temp
 		// file has finished downloading.
@@ -1389,21 +1448,25 @@ func (m *Model) Request(deviceID protocol.DeviceID, folder, name string, offset
 	if info, err := folderFs.Lstat(name); err != nil || !info.IsRegular() {
 		// Reject reads for anything that doesn't exist or is something
 		// other than a regular file.
-		return protocol.ErrNoSuchFile
+		l.Debugf("%v REQ(in) failed stating file (%v): %s: %q / %q o=%d s=%d", m, err, deviceID, folder, name, offset, size)
+		return nil, protocol.ErrNoSuchFile
 	}
 
-	if err = readOffsetIntoBuf(folderFs, name, offset, buf); fs.IsNotExist(err) {
-		return protocol.ErrNoSuchFile
+	if err := readOffsetIntoBuf(folderFs, name, offset, res.data); fs.IsNotExist(err) {
+		l.Debugf("%v REQ(in) file doesn't exist: %s: %q / %q o=%d s=%d", m, deviceID, folder, name, offset, size)
+		return nil, protocol.ErrNoSuchFile
 	} else if err != nil {
-		return protocol.ErrGeneric
+		l.Debugf("%v REQ(in) failed reading file (%v): %s: %q / %q o=%d s=%d", m, err, deviceID, folder, name, offset, size)
+		return nil, protocol.ErrGeneric
 	}
 
-	if !scanner.Validate(buf, hash, weakHash) {
-		m.recheckFile(deviceID, folderFs, folder, name, int(offset)/len(buf), hash)
-		return protocol.ErrNoSuchFile
+	if !scanner.Validate(res.data, hash, weakHash) {
+		m.recheckFile(deviceID, folderFs, folder, name, int(offset)/int(size), hash)
+		l.Debugf("%v REQ(in) failed validating data (%v): %s: %q / %q o=%d s=%d", m, err, deviceID, folder, name, offset, size)
+		return nil, protocol.ErrNoSuchFile
 	}
 
-	return nil
+	return res, nil
 }
 
 func (m *Model) recheckFile(deviceID protocol.DeviceID, folderFs fs.Filesystem, folder, name string, blockIndex int, hash []byte) {
@@ -1598,6 +1661,11 @@ func (m *Model) GetHello(id protocol.DeviceID) protocol.HelloIntf {
 // folder changes.
 func (m *Model) AddConnection(conn connections.Connection, hello protocol.HelloResult) {
 	deviceID := conn.ID()
+	device, ok := m.cfg.Device(deviceID)
+	if !ok {
+		l.Infoln("Trying to add connection to unknown device")
+		return
+	}
 
 	m.pmut.Lock()
 	if oldConn, ok := m.conn[deviceID]; ok {
@@ -1617,6 +1685,13 @@ func (m *Model) AddConnection(conn connections.Connection, hello protocol.HelloR
 	m.conn[deviceID] = conn
 	m.closed[deviceID] = make(chan struct{})
 	m.deviceDownloads[deviceID] = newDeviceDownloadState()
+	// 0: default, <0: no limiting
+	switch {
+	case device.MaxRequestKiB > 0:
+		m.connRequestLimiters[deviceID] = newByteSemaphore(1024 * device.MaxRequestKiB)
+	case device.MaxRequestKiB == 0:
+		m.connRequestLimiters[deviceID] = newByteSemaphore(1024 * defaultPullerPendingKiB)
+	}
 
 	m.helloMessages[deviceID] = hello
 
@@ -1644,8 +1719,7 @@ func (m *Model) AddConnection(conn connections.Connection, hello protocol.HelloR
 	cm := m.generateClusterConfig(deviceID)
 	conn.ClusterConfig(cm)
 
-	device, ok := m.cfg.Devices()[deviceID]
-	if ok && (device.Name == "" || m.cfg.Options().OverwriteRemoteDevNames) && hello.DeviceName != "" {
+	if (device.Name == "" || m.cfg.Options().OverwriteRemoteDevNames) && hello.DeviceName != "" {
 		device.Name = hello.DeviceName
 		m.cfg.SetDevice(device)
 		m.cfg.Save()

+ 51 - 11
lib/model/model_test.go

@@ -183,45 +183,42 @@ func TestRequest(t *testing.T) {
 	defer m.Stop()
 	m.ScanFolder("default")
 
-	bs := make([]byte, protocol.MinBlockSize)
-
 	// Existing, shared file
-	bs = bs[:6]
-	err := m.Request(device1, "default", "foo", 0, nil, 0, false, bs)
+	res, err := m.Request(device1, "default", "foo", 6, 0, nil, 0, false)
 	if err != nil {
 		t.Error(err)
 	}
+	bs := res.Data()
 	if !bytes.Equal(bs, []byte("foobar")) {
 		t.Errorf("Incorrect data from request: %q", string(bs))
 	}
 
 	// Existing, nonshared file
-	err = m.Request(device2, "default", "foo", 0, nil, 0, false, bs)
+	_, err = m.Request(device2, "default", "foo", 6, 0, nil, 0, false)
 	if err == nil {
 		t.Error("Unexpected nil error on insecure file read")
 	}
 
 	// Nonexistent file
-	err = m.Request(device1, "default", "nonexistent", 0, nil, 0, false, bs)
+	_, err = m.Request(device1, "default", "nonexistent", 6, 0, nil, 0, false)
 	if err == nil {
 		t.Error("Unexpected nil error on insecure file read")
 	}
 
 	// Shared folder, but disallowed file name
-	err = m.Request(device1, "default", "../walk.go", 0, nil, 0, false, bs)
+	_, err = m.Request(device1, "default", "../walk.go", 6, 0, nil, 0, false)
 	if err == nil {
 		t.Error("Unexpected nil error on insecure file read")
 	}
 
 	// Negative offset
-	err = m.Request(device1, "default", "foo", -4, nil, 0, false, bs[:0])
+	_, err = m.Request(device1, "default", "foo", -4, 0, nil, 0, false)
 	if err == nil {
 		t.Error("Unexpected nil error on insecure file read")
 	}
 
 	// Larger block than available
-	bs = bs[:42]
-	err = m.Request(device1, "default", "foo", 0, nil, 0, false, bs)
+	_, err = m.Request(device1, "default", "foo", 42, 0, nil, 0, false)
 	if err == nil {
 		t.Error("Unexpected nil error on insecure file read")
 	}
@@ -536,7 +533,7 @@ func BenchmarkRequestInSingleFile(b *testing.B) {
 	b.ResetTimer()
 
 	for i := 0; i < b.N; i++ {
-		if err := m.Request(device1, "default", "request/for/a/file/in/a/couple/of/dirs/128k", 0, nil, 0, false, buf); err != nil {
+		if _, err := m.Request(device1, "default", "request/for/a/file/in/a/couple/of/dirs/128k", 128<<10, 0, nil, 0, false); err != nil {
 			b.Error(err)
 		}
 	}
@@ -3667,6 +3664,7 @@ func TestFolderRestartZombies(t *testing.T) {
 	// would leave more than one folder runner alive.
 
 	wrapper := createTmpWrapper(defaultCfg.Copy())
+	defer os.Remove(wrapper.ConfigPath())
 	folderCfg, _ := wrapper.Folder("default")
 	folderCfg.FilesystemType = fs.FilesystemTypeFake
 	wrapper.SetFolder(folderCfg)
@@ -3759,3 +3757,45 @@ func (c *alwaysChanged) Seen(fs fs.Filesystem, name string) bool {
 func (c *alwaysChanged) Changed() bool {
 	return true
 }
+
+func TestRequestLimit(t *testing.T) {
+	cfg := defaultCfg.Copy()
+	cfg.Devices = append(cfg.Devices, config.NewDeviceConfiguration(device2, "device2"))
+	cfg.Devices[1].MaxRequestKiB = 1
+	cfg.Folders[0].Devices = []config.FolderDeviceConfiguration{
+		{DeviceID: device1},
+		{DeviceID: device2},
+	}
+	m, _, wrapper := setupModelWithConnectionManual(cfg)
+	defer m.Stop()
+	defer os.Remove(wrapper.ConfigPath())
+
+	file := "tmpfile"
+	befReq := time.Now()
+	first, err := m.Request(device2, "default", file, 2000, 0, nil, 0, false)
+	if err != nil {
+		t.Fatalf("First request failed: %v", err)
+	}
+	reqDur := time.Since(befReq)
+	returned := make(chan struct{})
+	go func() {
+		second, err := m.Request(device2, "default", file, 2000, 0, nil, 0, false)
+		if err != nil {
+			t.Fatalf("Second request failed: %v", err)
+		}
+		close(returned)
+		second.Close()
+	}()
+	time.Sleep(10 * reqDur)
+	select {
+	case <-returned:
+		t.Fatalf("Second request returned before first was done")
+	default:
+	}
+	first.Close()
+	select {
+	case <-returned:
+	case <-time.After(time.Second):
+		t.Fatalf("Second request did not return after first was done")
+	}
+}

+ 6 - 7
lib/model/requests_test.go

@@ -98,9 +98,8 @@ func TestSymlinkTraversalRead(t *testing.T) {
 	<-done
 
 	// Request a file by traversing the symlink
-	buf := make([]byte, 10)
-	err := m.Request(device1, "default", "symlink/requests_test.go", 0, nil, 0, false, buf)
-	if err == nil || !bytes.Equal(buf, make([]byte, 10)) {
+	res, err := m.Request(device1, "default", "symlink/requests_test.go", 10, 0, nil, 0, false)
+	if err == nil || res != nil {
 		t.Error("Managed to traverse symlink")
 	}
 }
@@ -225,6 +224,7 @@ func TestRequestVersioningSymlinkAttack(t *testing.T) {
 	defer os.RemoveAll(tmpDir)
 
 	cfg := defaultCfgWrapper.RawCopy()
+	cfg.Devices = append(cfg.Devices, config.NewDeviceConfiguration(device2, "device2"))
 	cfg.Folders[0] = config.NewFolderConfiguration(protocol.LocalDeviceID, "default", "default", fs.FilesystemTypeBasic, tmpDir)
 	cfg.Folders[0].Devices = []config.FolderDeviceConfiguration{
 		{DeviceID: device1},
@@ -519,12 +519,11 @@ func TestRescanIfHaveInvalidContent(t *testing.T) {
 		t.Fatalf("unexpected weak hash: %d != 103547413", f.Blocks[0].WeakHash)
 	}
 
-	buf := make([]byte, len(payload))
-
-	err := m.Request(device2, "default", "foo", 0, f.Blocks[0].Hash, f.Blocks[0].WeakHash, false, buf)
+	res, err := m.Request(device2, "default", "foo", int32(len(payload)), 0, f.Blocks[0].Hash, f.Blocks[0].WeakHash, false)
 	if err != nil {
 		t.Fatal(err)
 	}
+	buf := res.Data()
 	if !bytes.Equal(buf, payload) {
 		t.Errorf("%s != %s", buf, payload)
 	}
@@ -536,7 +535,7 @@ func TestRescanIfHaveInvalidContent(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	err = m.Request(device2, "default", "foo", 0, f.Blocks[0].Hash, f.Blocks[0].WeakHash, false, buf)
+	res, err = m.Request(device2, "default", "foo", int32(len(payload)), 0, f.Blocks[0].Hash, f.Blocks[0].WeakHash, false)
 	if err == nil {
 		t.Fatalf("expected failure")
 	}

+ 3 - 2
lib/protocol/benchmark_test.go

@@ -171,12 +171,13 @@ func (m *fakeModel) Index(deviceID DeviceID, folder string, files []FileInfo) {
 func (m *fakeModel) IndexUpdate(deviceID DeviceID, folder string, files []FileInfo) {
 }
 
-func (m *fakeModel) Request(deviceID DeviceID, folder string, name string, offset int64, hash []byte, weakHAsh uint32, fromTemporary bool, buf []byte) error {
+func (m *fakeModel) Request(deviceID DeviceID, folder, name string, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error) {
 	// We write the offset to the end of the buffer, so the receiver
 	// can verify that it did in fact get some data back over the
 	// connection.
+	buf := make([]byte, size)
 	binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(offset))
-	return nil
+	return &fakeRequestResponse{buf}, nil
 }
 
 func (m *fakeModel) ClusterConfig(deviceID DeviceID, config ClusterConfig) {

+ 46 - 36
lib/protocol/bufferpool.go

@@ -4,32 +4,59 @@ package protocol
 
 import "sync"
 
+// Global pool to get buffers from. Requires Blocksizes to be initialised,
+// therefore it is initialized in the same init() as BlockSizes
+var BufferPool bufferPool
+
 type bufferPool struct {
-	minSize int
-	pool    sync.Pool
+	pools []sync.Pool
 }
 
-// get returns a new buffer of the requested size
-func (p *bufferPool) get(size int) []byte {
-	intf := p.pool.Get()
-	if intf == nil {
-		// Pool is empty, must allocate.
-		return p.new(size)
-	}
+func newBufferPool() bufferPool {
+	return bufferPool{make([]sync.Pool, len(BlockSizes))}
+}
 
-	bs := *intf.(*[]byte)
-	if cap(bs) < size {
-		// Buffer was too small, leave it for someone else and allocate.
-		p.pool.Put(intf)
-		return p.new(size)
+func (p *bufferPool) Get(size int) []byte {
+	// Too big, isn't pooled
+	if size > MaxBlockSize {
+		return make([]byte, size)
+	}
+	var i int
+	for i = range BlockSizes {
+		if size <= BlockSizes[i] {
+			break
+		}
 	}
+	var bs []byte
+	// Try the fitting and all bigger pools
+	for j := i; j < len(BlockSizes); j++ {
+		if intf := p.pools[j].Get(); intf != nil {
+			bs = *intf.(*[]byte)
+			return bs[:size]
+		}
+	}
+	// All pools are empty, must allocate.
+	return make([]byte, BlockSizes[i])[:size]
+}
 
-	return bs[:size]
+// Put makes the given byte slice availabe again in the global pool
+func (p *bufferPool) Put(bs []byte) {
+	c := cap(bs)
+	// Don't buffer huge byte slices
+	if c > 2*MaxBlockSize {
+		return
+	}
+	for i := range BlockSizes {
+		if c >= BlockSizes[i] {
+			p.pools[i].Put(&bs)
+			return
+		}
+	}
 }
 
-// upgrade grows the buffer to the requested size, while attempting to reuse
+// Upgrade grows the buffer to the requested size, while attempting to reuse
 // it if possible.
-func (p *bufferPool) upgrade(bs []byte, size int) []byte {
+func (p *bufferPool) Upgrade(bs []byte, size int) []byte {
 	if cap(bs) >= size {
 		// Reslicing is enough, lets go!
 		return bs[:size]
@@ -37,23 +64,6 @@ func (p *bufferPool) upgrade(bs []byte, size int) []byte {
 
 	// It was too small. But it pack into the pool and try to get another
 	// buffer.
-	p.put(bs)
-	return p.get(size)
-}
-
-// put returns the buffer to the pool
-func (p *bufferPool) put(bs []byte) {
-	p.pool.Put(&bs)
-}
-
-// new creates a new buffer of the requested size, taking the minimum
-// allocation count into account. For internal use only.
-func (p *bufferPool) new(size int) []byte {
-	allocSize := size
-	if allocSize < p.minSize {
-		// Avoid allocating tiny buffers that we won't be able to reuse for
-		// anything useful.
-		allocSize = p.minSize
-	}
-	return make([]byte, allocSize)[:size]
+	p.Put(bs)
+	return p.Get(size)
 }

+ 17 - 4
lib/protocol/common_test.go

@@ -9,7 +9,7 @@ type TestModel struct {
 	folder        string
 	name          string
 	offset        int64
-	size          int
+	size          int32
 	hash          []byte
 	weakHash      uint32
 	fromTemporary bool
@@ -29,16 +29,17 @@ func (t *TestModel) Index(deviceID DeviceID, folder string, files []FileInfo) {
 func (t *TestModel) IndexUpdate(deviceID DeviceID, folder string, files []FileInfo) {
 }
 
-func (t *TestModel) Request(deviceID DeviceID, folder, name string, offset int64, hash []byte, weakHash uint32, fromTemporary bool, buf []byte) error {
+func (t *TestModel) Request(deviceID DeviceID, folder, name string, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error) {
 	t.folder = folder
 	t.name = name
 	t.offset = offset
-	t.size = len(buf)
+	t.size = size
 	t.hash = hash
 	t.weakHash = weakHash
 	t.fromTemporary = fromTemporary
+	buf := make([]byte, len(t.data))
 	copy(buf, t.data)
-	return nil
+	return &fakeRequestResponse{buf}, nil
 }
 
 func (t *TestModel) Closed(conn Connection, err error) {
@@ -60,3 +61,15 @@ func (t *TestModel) closedError() error {
 		return nil // Timeout
 	}
 }
+
+type fakeRequestResponse struct {
+	data []byte
+}
+
+func (r *fakeRequestResponse) Data() []byte {
+	return r.data
+}
+
+func (r *fakeRequestResponse) Close() {}
+
+func (r *fakeRequestResponse) Wait() {}

+ 2 - 2
lib/protocol/nativemodel_darwin.go

@@ -26,7 +26,7 @@ func (m nativeModel) IndexUpdate(deviceID DeviceID, folder string, files []FileI
 	m.Model.IndexUpdate(deviceID, folder, files)
 }
 
-func (m nativeModel) Request(deviceID DeviceID, folder string, name string, offset int64, hash []byte, weakHash uint32, fromTemporary bool, buf []byte) error {
+func (m nativeModel) Request(deviceID DeviceID, folder, name string, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error) {
 	name = norm.NFD.String(name)
-	return m.Model.Request(deviceID, folder, name, offset, hash, weakHash, fromTemporary, buf)
+	return m.Model.Request(deviceID, folder, name, size, offset, hash, weakHash, fromTemporary)
 }

+ 3 - 3
lib/protocol/nativemodel_windows.go

@@ -25,14 +25,14 @@ func (m nativeModel) IndexUpdate(deviceID DeviceID, folder string, files []FileI
 	m.Model.IndexUpdate(deviceID, folder, files)
 }
 
-func (m nativeModel) Request(deviceID DeviceID, folder string, name string, offset int64, hash []byte, weakHash uint32, fromTemporary bool, buf []byte) error {
+func (m nativeModel) Request(deviceID DeviceID, folder, name string, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error) {
 	if strings.Contains(name, `\`) {
 		l.Warnf("Dropping request for %s, contains invalid path separator", name)
-		return ErrNoSuchFile
+		return nil, ErrNoSuchFile
 	}
 
 	name = filepath.FromSlash(name)
-	return m.Model.Request(deviceID, folder, name, offset, hash, weakHash, fromTemporary, buf)
+	return m.Model.Request(deviceID, folder, name, size, offset, hash, weakHash, fromTemporary)
 }
 
 func fixupFiles(files []FileInfo) []FileInfo {

+ 4 - 2
lib/protocol/nativemodel_windows_test.go

@@ -2,8 +2,10 @@
 
 package protocol
 
-import "testing"
-import "reflect"
+import (
+	"reflect"
+	"testing"
+)
 
 func TestFixupFiles(t *testing.T) {
 	files := []FileInfo{

+ 52 - 70
lib/protocol/protocol.go

@@ -48,6 +48,7 @@ func init() {
 		BlockSizes = append(BlockSizes, blockSize)
 		sha256OfEmptyBlock[blockSize] = sha256.Sum256(make([]byte, blockSize))
 	}
+	BufferPool = newBufferPool()
 }
 
 // BlockSize returns the block size to use for the given file size
@@ -125,7 +126,7 @@ type Model interface {
 	// An index update was received from the peer device
 	IndexUpdate(deviceID DeviceID, folder string, files []FileInfo)
 	// A request was made by the peer device
-	Request(deviceID DeviceID, folder string, name string, offset int64, hash []byte, weakHash uint32, fromTemporary bool, buf []byte) error
+	Request(deviceID DeviceID, folder, name string, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error)
 	// A cluster configuration message was received
 	ClusterConfig(deviceID DeviceID, config ClusterConfig)
 	// The peer device closed the connection
@@ -134,6 +135,12 @@ type Model interface {
 	DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate)
 }
 
+type RequestResponse interface {
+	Data() []byte
+	Close() // Must always be called once the byte slice is no longer in use
+	Wait()  // Blocks until Close is called
+}
+
 type Connection interface {
 	Start()
 	ID() DeviceID
@@ -166,7 +173,6 @@ type rawConnection struct {
 	outbox      chan asyncMessage
 	closed      chan struct{}
 	once        sync.Once
-	pool        bufferPool
 	compression Compression
 }
 
@@ -184,7 +190,7 @@ type message interface {
 
 type asyncMessage struct {
 	msg  message
-	done chan struct{} // done closes when we're done marshalling the message and its contents can be reused
+	done chan struct{} // done closes when we're done sending the message
 }
 
 const (
@@ -196,12 +202,6 @@ const (
 	ReceiveTimeout = 300 * time.Second
 )
 
-// A buffer pool for global use. We don't allocate smaller buffers than 64k,
-// in the hope of being able to reuse them later.
-var buffers = bufferPool{
-	minSize: 64 << 10,
-}
-
 func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) Connection {
 	cr := &countingReader{Reader: reader}
 	cw := &countingWriter{Writer: writer}
@@ -215,7 +215,6 @@ func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiv
 		awaiting:    make(map[int32]chan asyncResult),
 		outbox:      make(chan asyncMessage),
 		closed:      make(chan struct{}),
-		pool:        bufferPool{minSize: MinBlockSize},
 		compression: compress,
 	}
 
@@ -338,6 +337,7 @@ func (c *rawConnection) readerLoop() (err error) {
 		c.close(err)
 	}()
 
+	fourByteBuf := make([]byte, 4)
 	state := stateInitial
 	for {
 		select {
@@ -346,7 +346,7 @@ func (c *rawConnection) readerLoop() (err error) {
 		default:
 		}
 
-		msg, err := c.readMessage()
+		msg, err := c.readMessage(fourByteBuf)
 		if err == errUnknownMessage {
 			// Unknown message types are skipped, for future extensibility.
 			continue
@@ -394,7 +394,6 @@ func (c *rawConnection) readerLoop() (err error) {
 			if err := checkFilename(msg.Name); err != nil {
 				return fmt.Errorf("protocol error: request: %q: %v", msg.Name, err)
 			}
-			// Requests are handled asynchronously
 			go c.handleRequest(*msg)
 
 		case *Response:
@@ -429,30 +428,29 @@ func (c *rawConnection) readerLoop() (err error) {
 	}
 }
 
-func (c *rawConnection) readMessage() (message, error) {
-	hdr, err := c.readHeader()
+func (c *rawConnection) readMessage(fourByteBuf []byte) (message, error) {
+	hdr, err := c.readHeader(fourByteBuf)
 	if err != nil {
 		return nil, err
 	}
 
-	return c.readMessageAfterHeader(hdr)
+	return c.readMessageAfterHeader(hdr, fourByteBuf)
 }
 
-func (c *rawConnection) readMessageAfterHeader(hdr Header) (message, error) {
+func (c *rawConnection) readMessageAfterHeader(hdr Header, fourByteBuf []byte) (message, error) {
 	// First comes a 4 byte message length
 
-	buf := buffers.get(4)
-	if _, err := io.ReadFull(c.cr, buf); err != nil {
+	if _, err := io.ReadFull(c.cr, fourByteBuf[:4]); err != nil {
 		return nil, fmt.Errorf("reading message length: %v", err)
 	}
-	msgLen := int32(binary.BigEndian.Uint32(buf))
+	msgLen := int32(binary.BigEndian.Uint32(fourByteBuf))
 	if msgLen < 0 {
 		return nil, fmt.Errorf("negative message length %d", msgLen)
 	}
 
 	// Then comes the message
 
-	buf = buffers.upgrade(buf, int(msgLen))
+	buf := BufferPool.Get(int(msgLen))
 	if _, err := io.ReadFull(c.cr, buf); err != nil {
 		return nil, fmt.Errorf("reading message: %v", err)
 	}
@@ -465,7 +463,7 @@ func (c *rawConnection) readMessageAfterHeader(hdr Header) (message, error) {
 
 	case MessageCompressionLZ4:
 		decomp, err := c.lz4Decompress(buf)
-		buffers.put(buf)
+		BufferPool.Put(buf)
 		if err != nil {
 			return nil, fmt.Errorf("decompressing message: %v", err)
 		}
@@ -484,26 +482,25 @@ func (c *rawConnection) readMessageAfterHeader(hdr Header) (message, error) {
 	if err := msg.Unmarshal(buf); err != nil {
 		return nil, fmt.Errorf("unmarshalling message: %v", err)
 	}
-	buffers.put(buf)
+	BufferPool.Put(buf)
 
 	return msg, nil
 }
 
-func (c *rawConnection) readHeader() (Header, error) {
+func (c *rawConnection) readHeader(fourByteBuf []byte) (Header, error) {
 	// First comes a 2 byte header length
 
-	buf := buffers.get(2)
-	if _, err := io.ReadFull(c.cr, buf); err != nil {
+	if _, err := io.ReadFull(c.cr, fourByteBuf[:2]); err != nil {
 		return Header{}, fmt.Errorf("reading length: %v", err)
 	}
-	hdrLen := int16(binary.BigEndian.Uint16(buf))
+	hdrLen := int16(binary.BigEndian.Uint16(fourByteBuf))
 	if hdrLen < 0 {
 		return Header{}, fmt.Errorf("negative header length %d", hdrLen)
 	}
 
 	// Then comes the header
 
-	buf = buffers.upgrade(buf, int(hdrLen))
+	buf := BufferPool.Get(int(hdrLen))
 	if _, err := io.ReadFull(c.cr, buf); err != nil {
 		return Header{}, fmt.Errorf("reading header: %v", err)
 	}
@@ -513,7 +510,7 @@ func (c *rawConnection) readHeader() (Header, error) {
 		return Header{}, fmt.Errorf("unmarshalling header: %v", err)
 	}
 
-	buffers.put(buf)
+	BufferPool.Put(buf)
 	return hdr, nil
 }
 
@@ -590,38 +587,22 @@ func checkFilename(name string) error {
 }
 
 func (c *rawConnection) handleRequest(req Request) {
-	size := int(req.Size)
-	usePool := size <= MaxBlockSize
-
-	var buf []byte
-	var done chan struct{}
-
-	if usePool {
-		buf = c.pool.get(size)
-		done = make(chan struct{})
-	} else {
-		buf = make([]byte, size)
-	}
-
-	err := c.receiver.Request(c.id, req.Folder, req.Name, req.Offset, req.Hash, req.WeakHash, req.FromTemporary, buf)
+	res, err := c.receiver.Request(c.id, req.Folder, req.Name, req.Size, req.Offset, req.Hash, req.WeakHash, req.FromTemporary)
 	if err != nil {
 		c.send(&Response{
 			ID:   req.ID,
-			Data: nil,
 			Code: errorToCode(err),
-		}, done)
-	} else {
-		c.send(&Response{
-			ID:   req.ID,
-			Data: buf,
-			Code: errorToCode(err),
-		}, done)
-	}
-
-	if usePool {
-		<-done
-		c.pool.put(buf)
+		}, nil)
+		return
 	}
+	done := make(chan struct{})
+	c.send(&Response{
+		ID:   req.ID,
+		Data: res.Data(),
+		Code: errorToCode(nil),
+	}, done)
+	<-done
+	res.Close()
 }
 
 func (c *rawConnection) handleResponse(resp Response) {
@@ -639,6 +620,9 @@ func (c *rawConnection) send(msg message, done chan struct{}) bool {
 	case c.outbox <- asyncMessage{msg, done}:
 		return true
 	case <-c.closed:
+		if done != nil {
+			close(done)
+		}
 		return false
 	}
 }
@@ -647,7 +631,11 @@ func (c *rawConnection) writerLoop() {
 	for {
 		select {
 		case hm := <-c.outbox:
-			if err := c.writeMessage(hm); err != nil {
+			err := c.writeMessage(hm)
+			if hm.done != nil {
+				close(hm.done)
+			}
+			if err != nil {
 				c.close(err)
 				return
 			}
@@ -667,13 +655,10 @@ func (c *rawConnection) writeMessage(hm asyncMessage) error {
 
 func (c *rawConnection) writeCompressedMessage(hm asyncMessage) error {
 	size := hm.msg.ProtoSize()
-	buf := buffers.get(size)
+	buf := BufferPool.Get(size)
 	if _, err := hm.msg.MarshalTo(buf); err != nil {
 		return fmt.Errorf("marshalling message: %v", err)
 	}
-	if hm.done != nil {
-		close(hm.done)
-	}
 
 	compressed, err := c.lz4Compress(buf)
 	if err != nil {
@@ -690,7 +675,7 @@ func (c *rawConnection) writeCompressedMessage(hm asyncMessage) error {
 	}
 
 	totSize := 2 + hdrSize + 4 + len(compressed)
-	buf = buffers.upgrade(buf, totSize)
+	buf = BufferPool.Upgrade(buf, totSize)
 
 	// Header length
 	binary.BigEndian.PutUint16(buf, uint16(hdrSize))
@@ -702,10 +687,10 @@ func (c *rawConnection) writeCompressedMessage(hm asyncMessage) error {
 	binary.BigEndian.PutUint32(buf[2+hdrSize:], uint32(len(compressed)))
 	// Message
 	copy(buf[2+hdrSize+4:], compressed)
-	buffers.put(compressed)
+	BufferPool.Put(compressed)
 
 	n, err := c.cw.Write(buf)
-	buffers.put(buf)
+	BufferPool.Put(buf)
 
 	l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message (%d uncompressed)), err=%v", n, hdrSize, len(compressed), size, err)
 	if err != nil {
@@ -726,7 +711,7 @@ func (c *rawConnection) writeUncompressedMessage(hm asyncMessage) error {
 	}
 
 	totSize := 2 + hdrSize + 4 + size
-	buf := buffers.get(totSize)
+	buf := BufferPool.Get(totSize)
 
 	// Header length
 	binary.BigEndian.PutUint16(buf, uint16(hdrSize))
@@ -740,12 +725,9 @@ func (c *rawConnection) writeUncompressedMessage(hm asyncMessage) error {
 	if _, err := hm.msg.MarshalTo(buf[2+hdrSize+4:]); err != nil {
 		return fmt.Errorf("marshalling message: %v", err)
 	}
-	if hm.done != nil {
-		close(hm.done)
-	}
 
 	n, err := c.cw.Write(buf[:totSize])
-	buffers.put(buf)
+	BufferPool.Put(buf)
 
 	l.Debugf("wrote %d bytes on the wire (2 bytes length, %d bytes header, 4 bytes message length, %d bytes message), err=%v", n, hdrSize, size, err)
 	if err != nil {
@@ -904,7 +886,7 @@ func (c *rawConnection) Statistics() Statistics {
 
 func (c *rawConnection) lz4Compress(src []byte) ([]byte, error) {
 	var err error
-	buf := buffers.get(len(src))
+	buf := BufferPool.Get(len(src))
 	buf, err = lz4.Encode(buf, src)
 	if err != nil {
 		return nil, err
@@ -918,7 +900,7 @@ func (c *rawConnection) lz4Decompress(src []byte) ([]byte, error) {
 	size := binary.BigEndian.Uint32(src)
 	binary.LittleEndian.PutUint32(src, size)
 	var err error
-	buf := buffers.get(int(size))
+	buf := BufferPool.Get(int(size))
 	buf, err = lz4.Decode(buf, src)
 	if err != nil {
 		return nil, err