Sfoglia il codice sorgente

Don't leak sendIndexes on disconnect (fixes #2589)

Adds a Closed() method on protocol.Connection and clears up
wireformatConnection a little too.
Jakob Borg 10 anni fa
parent
commit
acdddc0b79

+ 9 - 1
lib/model/model.go

@@ -1073,10 +1073,18 @@ func sendIndexes(conn protocol.Connection, folder string, fs *db.FileSet, ignore
 
 	minLocalVer, err := sendIndexTo(true, 0, conn, folder, fs, ignores)
 
-	sub := events.Default.Subscribe(events.LocalIndexUpdated)
+	// Subscribe to LocalIndexUpdated (we have new information to send) and
+	// DeviceDisconnected (it might be us who disconnected, so we should
+	// exit).
+	sub := events.Default.Subscribe(events.LocalIndexUpdated | events.DeviceDisconnected)
 	defer events.Default.Unsubscribe(sub)
 
 	for err == nil {
+		if conn.Closed() {
+			// Our work is done.
+			return
+		}
+
 		// While we have sent a localVersion at least equal to the one
 		// currently in the database, wait for the local index to update. The
 		// local index may update for other folders than the one we are

+ 4 - 0
lib/model/model_test.go

@@ -243,6 +243,10 @@ func (FakeConnection) Ping() bool {
 	return true
 }
 
+func (FakeConnection) Closed() bool {
+	return false
+}
+
 func (FakeConnection) Statistics() protocol.Statistics {
 	return protocol.Statistics{}
 }

+ 10 - 0
lib/protocol/protocol.go

@@ -113,6 +113,7 @@ type Connection interface {
 	Request(folder string, name string, offset int64, size int, hash []byte, flags uint32, options []Option) ([]byte, error)
 	ClusterConfig(config ClusterConfigMessage)
 	Statistics() Statistics
+	Closed() bool
 }
 
 type rawConnection struct {
@@ -287,6 +288,15 @@ func (c *rawConnection) ClusterConfig(config ClusterConfigMessage) {
 	c.send(-1, messageTypeClusterConfig, config, nil)
 }
 
+func (c *rawConnection) Closed() bool {
+	select {
+	case <-c.closed:
+		return true
+	default:
+		return false
+	}
+}
+
 func (c *rawConnection) ping() bool {
 	var id int
 	select {

+ 5 - 5
lib/protocol/protocol_test.go

@@ -76,9 +76,9 @@ func TestPing(t *testing.T) {
 	ar, aw := io.Pipe()
 	br, bw := io.Pipe()
 
-	c0 := NewConnection(c0ID, ar, bw, newTestModel(), "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
+	c0 := NewConnection(c0ID, ar, bw, newTestModel(), "name", CompressAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c0.Start()
-	c1 := NewConnection(c1ID, br, aw, newTestModel(), "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
+	c1 := NewConnection(c1ID, br, aw, newTestModel(), "name", CompressAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c1.Start()
 	c0.ClusterConfig(ClusterConfigMessage{})
 	c1.ClusterConfig(ClusterConfigMessage{})
@@ -98,7 +98,7 @@ func TestVersionErr(t *testing.T) {
 	ar, aw := io.Pipe()
 	br, bw := io.Pipe()
 
-	c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
+	c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c0.Start()
 	c1 := NewConnection(c1ID, br, aw, m1, "name", CompressAlways)
 	c1.Start()
@@ -125,7 +125,7 @@ func TestTypeErr(t *testing.T) {
 	ar, aw := io.Pipe()
 	br, bw := io.Pipe()
 
-	c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
+	c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c0.Start()
 	c1 := NewConnection(c1ID, br, aw, m1, "name", CompressAlways)
 	c1.Start()
@@ -152,7 +152,7 @@ func TestClose(t *testing.T) {
 	ar, aw := io.Pipe()
 	br, bw := io.Pipe()
 
-	c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).next.(*rawConnection)
+	c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressAlways).(wireFormatConnection).Connection.(*rawConnection)
 	c0.Start()
 	c1 := NewConnection(c1ID, br, aw, m1, "name", CompressAlways)
 	c1.Start()

+ 4 - 24
lib/protocol/wireformat.go

@@ -9,19 +9,7 @@ import (
 )
 
 type wireFormatConnection struct {
-	next Connection
-}
-
-func (c wireFormatConnection) Start() {
-	c.next.Start()
-}
-
-func (c wireFormatConnection) ID() DeviceID {
-	return c.next.ID()
-}
-
-func (c wireFormatConnection) Name() string {
-	return c.next.Name()
+	Connection
 }
 
 func (c wireFormatConnection) Index(folder string, fs []FileInfo, flags uint32, options []Option) error {
@@ -32,7 +20,7 @@ func (c wireFormatConnection) Index(folder string, fs []FileInfo, flags uint32,
 		myFs[i].Name = norm.NFC.String(filepath.ToSlash(myFs[i].Name))
 	}
 
-	return c.next.Index(folder, myFs, flags, options)
+	return c.Connection.Index(folder, myFs, flags, options)
 }
 
 func (c wireFormatConnection) IndexUpdate(folder string, fs []FileInfo, flags uint32, options []Option) error {
@@ -43,18 +31,10 @@ func (c wireFormatConnection) IndexUpdate(folder string, fs []FileInfo, flags ui
 		myFs[i].Name = norm.NFC.String(filepath.ToSlash(myFs[i].Name))
 	}
 
-	return c.next.IndexUpdate(folder, myFs, flags, options)
+	return c.Connection.IndexUpdate(folder, myFs, flags, options)
 }
 
 func (c wireFormatConnection) Request(folder, name string, offset int64, size int, hash []byte, flags uint32, options []Option) ([]byte, error) {
 	name = norm.NFC.String(filepath.ToSlash(name))
-	return c.next.Request(folder, name, offset, size, hash, flags, options)
-}
-
-func (c wireFormatConnection) ClusterConfig(config ClusterConfigMessage) {
-	c.next.ClusterConfig(config)
-}
-
-func (c wireFormatConnection) Statistics() Statistics {
-	return c.next.Statistics()
+	return c.Connection.Request(folder, name, offset, size, hash, flags, options)
 }