Browse Source

lib: Close underlying conn in protocol (fixes #7165) (#7212)

Simon Frei 5 years ago
parent
commit
c845e245a1

+ 2 - 3
lib/api/mocked_model_test.go

@@ -11,7 +11,6 @@ import (
 	"net"
 	"time"
 
-	"github.com/syncthing/syncthing/lib/connections"
 	"github.com/syncthing/syncthing/lib/db"
 	"github.com/syncthing/syncthing/lib/model"
 	"github.com/syncthing/syncthing/lib/protocol"
@@ -114,7 +113,7 @@ func (m *mockedModel) ScanFolderSubdirs(folder string, subs []string) error {
 
 func (m *mockedModel) BringToFront(folder, file string) {}
 
-func (m *mockedModel) Connection(deviceID protocol.DeviceID) (connections.Connection, bool) {
+func (m *mockedModel) Connection(deviceID protocol.DeviceID) (protocol.Connection, bool) {
 	return nil, false
 }
 
@@ -165,7 +164,7 @@ func (m *mockedModel) DownloadProgress(deviceID protocol.DeviceID, folder string
 	return nil
 }
 
-func (m *mockedModel) AddConnection(conn connections.Connection, hello protocol.Hello) {}
+func (m *mockedModel) AddConnection(conn protocol.Connection, hello protocol.Hello) {}
 
 func (m *mockedModel) OnHello(protocol.DeviceID, net.Addr, protocol.Hello) error {
 	return nil

+ 3 - 4
lib/connections/service.go

@@ -329,15 +329,14 @@ func (s *service) handle(ctx context.Context) error {
 		var protoConn protocol.Connection
 		passwords := s.cfg.FolderPasswords(remoteID)
 		if len(passwords) > 0 {
-			protoConn = protocol.NewEncryptedConnection(passwords, remoteID, rd, wr, s.model, c.String(), deviceCfg.Compression)
+			protoConn = protocol.NewEncryptedConnection(passwords, remoteID, rd, wr, c, s.model, c, deviceCfg.Compression)
 		} else {
-			protoConn = protocol.NewConnection(remoteID, rd, wr, s.model, c.String(), deviceCfg.Compression)
+			protoConn = protocol.NewConnection(remoteID, rd, wr, c, s.model, c, deviceCfg.Compression)
 		}
-		modelConn := completeConn{c, protoConn}
 
 		l.Infof("Established secure connection to %s at %s", remoteID, c)
 
-		s.model.AddConnection(modelConn, hello)
+		s.model.AddConnection(protoConn, hello)
 		continue
 	}
 }

+ 4 - 29
lib/connections/structs.go

@@ -22,31 +22,6 @@ import (
 	"github.com/thejerf/suture/v4"
 )
 
-// Connection is what we expose to the outside. It is a protocol.Connection
-// that can be closed and has some metadata.
-type Connection interface {
-	protocol.Connection
-	Type() string
-	Transport() string
-	RemoteAddr() net.Addr
-	Priority() int
-	String() string
-	Crypto() string
-}
-
-// completeConn is the aggregation of an internalConn and the
-// protocol.Connection running on top of it. It implements the Connection
-// interface.
-type completeConn struct {
-	internalConn
-	protocol.Connection
-}
-
-func (c completeConn) Close(err error) {
-	c.Connection.Close(err)
-	c.internalConn.Close()
-}
-
 type tlsConn interface {
 	io.ReadWriteCloser
 	ConnectionState() tls.ConnectionState
@@ -107,12 +82,12 @@ func (t connType) Transport() string {
 	}
 }
 
-func (c internalConn) Close() {
+func (c internalConn) Close() error {
 	// *tls.Conn.Close() does more than it says on the tin. Specifically, it
 	// sends a TLS alert message, which might block forever if the
 	// connection is dead and we don't have a deadline set.
 	_ = c.SetWriteDeadline(time.Now().Add(250 * time.Millisecond))
-	_ = c.tlsConn.Close()
+	return c.tlsConn.Close()
 }
 
 func (c internalConn) Type() string {
@@ -203,8 +178,8 @@ type genericListener interface {
 
 type Model interface {
 	protocol.Model
-	AddConnection(conn Connection, hello protocol.Hello)
-	Connection(remoteID protocol.DeviceID) (Connection, bool)
+	AddConnection(conn protocol.Connection, hello protocol.Hello)
+	Connection(remoteID protocol.DeviceID) (protocol.Connection, bool)
 	OnHello(protocol.DeviceID, net.Addr, protocol.Hello) error
 	GetHello(protocol.DeviceID) protocol.HelloIntf
 }

+ 2 - 50
lib/model/fakeconns_test.go

@@ -9,13 +9,12 @@ package model
 import (
 	"bytes"
 	"context"
-	"net"
 	"sync"
 	"time"
 
-	"github.com/syncthing/syncthing/lib/connections"
 	"github.com/syncthing/syncthing/lib/protocol"
 	"github.com/syncthing/syncthing/lib/scanner"
+	"github.com/syncthing/syncthing/lib/testutils"
 )
 
 type downloadProgressMessage struct {
@@ -24,7 +23,7 @@ type downloadProgressMessage struct {
 }
 
 type fakeConnection struct {
-	fakeUnderlyingConn
+	testutils.FakeConnectionInfo
 	id                       protocol.DeviceID
 	downloadProgressMessages []downloadProgressMessage
 	closed                   bool
@@ -219,50 +218,3 @@ func addFakeConn(m *testModel, dev protocol.DeviceID) *fakeConnection {
 
 	return fc
 }
-
-type fakeProtoConn struct {
-	protocol.Connection
-	fakeUnderlyingConn
-}
-
-func newFakeProtoConn(protoConn protocol.Connection) connections.Connection {
-	return &fakeProtoConn{Connection: protoConn}
-}
-
-// fakeUnderlyingConn implements the methods of connections.Connection that are
-// not implemented by protocol.Connection
-type fakeUnderlyingConn struct{}
-
-func (f *fakeUnderlyingConn) RemoteAddr() net.Addr {
-	return &fakeAddr{}
-}
-
-func (f *fakeUnderlyingConn) Type() string {
-	return "fake"
-}
-
-func (f *fakeUnderlyingConn) Crypto() string {
-	return "fake"
-}
-
-func (f *fakeUnderlyingConn) Transport() string {
-	return "fake"
-}
-
-func (f *fakeUnderlyingConn) Priority() int {
-	return 9000
-}
-
-func (f *fakeUnderlyingConn) String() string {
-	return ""
-}
-
-type fakeAddr struct{}
-
-func (fakeAddr) Network() string {
-	return "network"
-}
-
-func (fakeAddr) String() string {
-	return "address"
-}

+ 5 - 5
lib/model/model.go

@@ -150,7 +150,7 @@ type model struct {
 
 	// fields protected by pmut
 	pmut                sync.RWMutex
-	conn                map[protocol.DeviceID]connections.Connection
+	conn                map[protocol.DeviceID]protocol.Connection
 	connRequestLimiters map[protocol.DeviceID]*byteSemaphore
 	closed              map[protocol.DeviceID]chan struct{}
 	helloMessages       map[protocol.DeviceID]protocol.Hello
@@ -232,7 +232,7 @@ func NewModel(cfg config.Wrapper, id protocol.DeviceID, clientName, clientVersio
 
 		// fields protected by pmut
 		pmut:                sync.NewRWMutex(),
-		conn:                make(map[protocol.DeviceID]connections.Connection),
+		conn:                make(map[protocol.DeviceID]protocol.Connection),
 		connRequestLimiters: make(map[protocol.DeviceID]*byteSemaphore),
 		closed:              make(map[protocol.DeviceID]chan struct{}),
 		helloMessages:       make(map[protocol.DeviceID]protocol.Hello),
@@ -1660,7 +1660,7 @@ func (m *model) Closed(conn protocol.Connection, err error) {
 
 	m.progressEmitter.temporaryIndexUnsubscribe(conn)
 
-	l.Infof("Connection to %s at %s closed: %v", device, conn.Name(), err)
+	l.Infof("Connection to %s at %s closed: %v", device, conn, err)
 	m.evLogger.Log(events.DeviceDisconnected, map[string]string{
 		"id":    device.String(),
 		"error": err.Error(),
@@ -1912,7 +1912,7 @@ func (m *model) CurrentGlobalFile(folder string, file string) (protocol.FileInfo
 }
 
 // Connection returns the current connection for device, and a boolean whether a connection was found.
-func (m *model) Connection(deviceID protocol.DeviceID) (connections.Connection, bool) {
+func (m *model) Connection(deviceID protocol.DeviceID) (protocol.Connection, bool) {
 	m.pmut.RLock()
 	cn, ok := m.conn[deviceID]
 	m.pmut.RUnlock()
@@ -2039,7 +2039,7 @@ func (m *model) GetHello(id protocol.DeviceID) protocol.HelloIntf {
 // AddConnection adds a new peer connection to the model. An initial index will
 // be sent to the connected peer, thereafter index updates whenever the local
 // folder changes.
-func (m *model) AddConnection(conn connections.Connection, hello protocol.Hello) {
+func (m *model) AddConnection(conn protocol.Connection, hello protocol.Hello) {
 	deviceID := conn.ID()
 	device, ok := m.cfg.Device(deviceID)
 	if !ok {

+ 1 - 1
lib/model/model_test.go

@@ -3297,7 +3297,7 @@ func TestConnCloseOnRestart(t *testing.T) {
 
 	br := &testutils.BlockingRW{}
 	nw := &testutils.NoopRW{}
-	m.AddConnection(newFakeProtoConn(protocol.NewConnection(device1, br, nw, m, "testConn", protocol.CompressionNever)), protocol.Hello{})
+	m.AddConnection(protocol.NewConnection(device1, br, nw, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"fc"}, protocol.CompressionNever), protocol.Hello{})
 	m.pmut.RLock()
 	if len(m.closed) != 1 {
 		t.Fatalf("Expected just one conn (len(m.conn) == %v)", len(m.conn))

+ 3 - 2
lib/protocol/benchmark_test.go

@@ -10,6 +10,7 @@ import (
 	"testing"
 
 	"github.com/syncthing/syncthing/lib/dialer"
+	"github.com/syncthing/syncthing/lib/testutils"
 )
 
 func BenchmarkRequestsRawTCP(b *testing.B) {
@@ -59,9 +60,9 @@ func benchmarkRequestsTLS(b *testing.B, conn0, conn1 net.Conn) {
 
 func benchmarkRequestsConnPair(b *testing.B, conn0, conn1 net.Conn) {
 	// Start up Connections on them
-	c0 := NewConnection(LocalDeviceID, conn0, conn0, new(fakeModel), "c0", CompressionMetadata)
+	c0 := NewConnection(LocalDeviceID, conn0, conn0, testutils.NoopCloser{}, new(fakeModel), &testutils.FakeConnectionInfo{"c0"}, CompressionMetadata)
 	c0.Start()
-	c1 := NewConnection(LocalDeviceID, conn1, conn1, new(fakeModel), "c1", CompressionMetadata)
+	c1 := NewConnection(LocalDeviceID, conn1, conn1, testutils.NoopCloser{}, new(fakeModel), &testutils.FakeConnectionInfo{"c1"}, CompressionMetadata)
 	c1.Start()
 
 	// Satisfy the assertions in the protocol by sending an initial cluster config

+ 1 - 4
lib/protocol/encryption.go

@@ -128,6 +128,7 @@ func (e encryptedModel) Closed(conn Connection, err error) {
 // The encryptedConnection sits between the model and the encrypted device. It
 // encrypts outgoing metadata and decrypts incoming responses.
 type encryptedConnection struct {
+	ConnectionInfo
 	conn       Connection
 	folderKeys map[string]*[keySize]byte // folder ID -> key
 }
@@ -140,10 +141,6 @@ func (e encryptedConnection) ID() DeviceID {
 	return e.conn.ID()
 }
 
-func (e encryptedConnection) Name() string {
-	return e.conn.Name()
-}
-
 func (e encryptedConnection) Index(ctx context.Context, folder string, files []FileInfo) error {
 	if folderKey, ok := e.folderKeys[folder]; ok {
 		encryptFileInfos(files, folderKey)

+ 27 - 15
lib/protocol/protocol.go

@@ -8,6 +8,7 @@ import (
 	"encoding/binary"
 	"fmt"
 	"io"
+	"net"
 	"path"
 	"strings"
 	"sync"
@@ -134,7 +135,6 @@ type Connection interface {
 	Start()
 	Close(err error)
 	ID() DeviceID
-	Name() string
 	Index(ctx context.Context, folder string, files []FileInfo) error
 	IndexUpdate(ctx context.Context, folder string, files []FileInfo) error
 	Request(ctx context.Context, folder string, name string, blockNo int, offset int64, size int, hash []byte, weakHash uint32, fromTemporary bool) ([]byte, error)
@@ -142,16 +142,28 @@ type Connection interface {
 	DownloadProgress(ctx context.Context, folder string, updates []FileDownloadProgressUpdate)
 	Statistics() Statistics
 	Closed() bool
+	ConnectionInfo
+}
+
+type ConnectionInfo interface {
+	Type() string
+	Transport() string
+	RemoteAddr() net.Addr
+	Priority() int
+	String() string
+	Crypto() string
 }
 
 type rawConnection struct {
+	ConnectionInfo
+
 	id        DeviceID
-	name      string
 	receiver  Model
 	startTime time.Time
 
-	cr *countingReader
-	cw *countingWriter
+	cr     *countingReader
+	cw     *countingWriter
+	closer io.Closer // Closing the underlying connection and thus cr and cw
 
 	awaiting    map[int]chan asyncResult
 	awaitingMut sync.Mutex
@@ -205,13 +217,13 @@ const (
 // Should not be modified in production code, just for testing.
 var CloseTimeout = 10 * time.Second
 
-func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) Connection {
+func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression) Connection {
 	receiver = nativeModel{receiver}
-	rc := newRawConnection(deviceID, reader, writer, receiver, name, compress)
+	rc := newRawConnection(deviceID, reader, writer, closer, receiver, connInfo, compress)
 	return wireFormatConnection{rc}
 }
 
-func NewEncryptedConnection(passwords map[string]string, deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) Connection {
+func NewEncryptedConnection(passwords map[string]string, deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression) Connection {
 	keys := keysFromPasswords(passwords)
 
 	// Encryption / decryption is first (outermost) before conversion to
@@ -221,23 +233,24 @@ func NewEncryptedConnection(passwords map[string]string, deviceID DeviceID, read
 
 	// We do the wire format conversion first (outermost) so that the
 	// metadata is in wire format when it reaches the encryption step.
-	rc := newRawConnection(deviceID, reader, writer, em, name, compress)
-	ec := encryptedConnection{conn: rc, folderKeys: keys}
+	rc := newRawConnection(deviceID, reader, writer, closer, em, connInfo, compress)
+	ec := encryptedConnection{ConnectionInfo: rc, conn: rc, folderKeys: keys}
 	wc := wireFormatConnection{ec}
 
 	return wc
 }
 
-func newRawConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, receiver Model, name string, compress Compression) *rawConnection {
+func newRawConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression) *rawConnection {
 	cr := &countingReader{Reader: reader}
 	cw := &countingWriter{Writer: writer}
 
 	return &rawConnection{
+		ConnectionInfo:        connInfo,
 		id:                    deviceID,
-		name:                  name,
 		receiver:              receiver,
 		cr:                    cr,
 		cw:                    cw,
+		closer:                closer,
 		awaiting:              make(map[int]chan asyncResult),
 		inbox:                 make(chan message),
 		outbox:                make(chan asyncMessage),
@@ -282,10 +295,6 @@ func (c *rawConnection) ID() DeviceID {
 	return c.id
 }
 
-func (c *rawConnection) Name() string {
-	return c.name
-}
-
 // Index writes the list of file information to the connected peer device
 func (c *rawConnection) Index(ctx context.Context, folder string, idx []FileInfo) error {
 	select {
@@ -931,6 +940,9 @@ func (c *rawConnection) Close(err error) {
 func (c *rawConnection) internalClose(err error) {
 	c.closeOnce.Do(func() {
 		l.Debugln("close due to", err)
+		if cerr := c.closer.Close(); cerr != nil {
+			l.Debugln(c.id, "failed to close underlying conn:", cerr)
+		}
 		close(c.closed)
 
 		c.awaitingMut.Lock()

+ 11 - 11
lib/protocol/protocol_test.go

@@ -31,10 +31,10 @@ func TestPing(t *testing.T) {
 	ar, aw := io.Pipe()
 	br, bw := io.Pipe()
 
-	c0 := NewConnection(c0ID, ar, bw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, newTestModel(), &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c0.Start()
 	defer closeAndWait(c0, ar, bw)
-	c1 := NewConnection(c1ID, br, aw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, newTestModel(), &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c1.Start()
 	defer closeAndWait(c1, ar, bw)
 	c0.ClusterConfig(ClusterConfig{})
@@ -57,10 +57,10 @@ func TestClose(t *testing.T) {
 	ar, aw := io.Pipe()
 	br, bw := io.Pipe()
 
-	c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c0.Start()
 	defer closeAndWait(c0, ar, bw)
-	c1 := NewConnection(c1ID, br, aw, m1, "name", CompressionAlways)
+	c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, &testutils.FakeConnectionInfo{"name"}, CompressionAlways)
 	c1.Start()
 	defer closeAndWait(c1, ar, bw)
 	c0.ClusterConfig(ClusterConfig{})
@@ -102,7 +102,7 @@ func TestCloseOnBlockingSend(t *testing.T) {
 	m := newTestModel()
 
 	rw := testutils.NewBlockingRW()
-	c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c.Start()
 	defer closeAndWait(c, rw)
 
@@ -153,10 +153,10 @@ func TestCloseRace(t *testing.T) {
 	ar, aw := io.Pipe()
 	br, bw := io.Pipe()
 
-	c0 := NewConnection(c0ID, ar, bw, m0, "c0", CompressionNever).(wireFormatConnection).Connection.(*rawConnection)
+	c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, &testutils.FakeConnectionInfo{"c0"}, CompressionNever).(wireFormatConnection).Connection.(*rawConnection)
 	c0.Start()
 	defer closeAndWait(c0, ar, bw)
-	c1 := NewConnection(c1ID, br, aw, m1, "c1", CompressionNever)
+	c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, &testutils.FakeConnectionInfo{"c1"}, CompressionNever)
 	c1.Start()
 	defer closeAndWait(c1, ar, bw)
 	c0.ClusterConfig(ClusterConfig{})
@@ -193,7 +193,7 @@ func TestClusterConfigFirst(t *testing.T) {
 	m := newTestModel()
 
 	rw := testutils.NewBlockingRW()
-	c := NewConnection(c0ID, rw, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c := NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c.Start()
 	defer closeAndWait(c, rw)
 
@@ -245,7 +245,7 @@ func TestCloseTimeout(t *testing.T) {
 	m := newTestModel()
 
 	rw := testutils.NewBlockingRW()
-	c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c.Start()
 	defer closeAndWait(c, rw)
 
@@ -865,7 +865,7 @@ func TestClusterConfigAfterClose(t *testing.T) {
 	m := newTestModel()
 
 	rw := testutils.NewBlockingRW()
-	c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c.Start()
 	defer closeAndWait(c, rw)
 
@@ -889,7 +889,7 @@ func TestDispatcherToCloseDeadlock(t *testing.T) {
 	// the model callbacks (ClusterConfig).
 	m := newTestModel()
 	rw := testutils.NewBlockingRW()
-	c := NewConnection(c0ID, rw, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c := NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, &testutils.FakeConnectionInfo{"name"}, CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
 	m.ccFn = func(devID DeviceID, cc ClusterConfig) {
 		c.Close(errManual)
 	}

+ 47 - 0
lib/testutils/testutils.go

@@ -8,6 +8,7 @@ package testutils
 
 import (
 	"errors"
+	"net"
 	"sync"
 )
 
@@ -52,3 +53,49 @@ func (rw *NoopRW) Read(p []byte) (n int, err error) {
 func (rw *NoopRW) Write(p []byte) (n int, err error) {
 	return len(p), nil
 }
+
+type NoopCloser struct{}
+
+func (NoopCloser) Close() error {
+	return nil
+}
+
+// FakeConnectionInfo implements the methods of protocol.Connection that are
+// not implemented by protocol.Connection
+type FakeConnectionInfo struct {
+	Name string
+}
+
+func (f *FakeConnectionInfo) RemoteAddr() net.Addr {
+	return &FakeAddr{}
+}
+
+func (f *FakeConnectionInfo) Type() string {
+	return "fake"
+}
+
+func (f *FakeConnectionInfo) Crypto() string {
+	return "fake"
+}
+
+func (f *FakeConnectionInfo) Transport() string {
+	return "fake"
+}
+
+func (f *FakeConnectionInfo) Priority() int {
+	return 9000
+}
+
+func (f *FakeConnectionInfo) String() string {
+	return ""
+}
+
+type FakeAddr struct{}
+
+func (FakeAddr) Network() string {
+	return "network"
+}
+
+func (FakeAddr) String() string {
+	return "address"
+}