소스 검색

lib/config, lib/connections: Refactor handling of ignored devices (fixes #3470)

GitHub-Pull-Request: https://github.com/syncthing/syncthing/pull/3471
Jakob Borg 9 년 전
부모
커밋
f368d2278f
7개의 변경된 파일147개의 추가작업 그리고 63개의 파일을 삭제
  1. 10 0
      lib/config/config.go
  2. 46 0
      lib/config/config_test.go
  3. 10 0
      lib/config/testdata/ignoreddevices.xml
  4. 12 0
      lib/config/wrapper.go
  5. 47 48
      lib/connections/service.go
  6. 1 1
      lib/connections/structs.go
  7. 21 14
      lib/model/model.go

+ 10 - 0
lib/config/config.go

@@ -273,6 +273,16 @@ func (cfg *Configuration) prepare(myID protocol.DeviceID) error {
 		cfg.GUI.APIKey = rand.String(32)
 	}
 
+	// The list of ignored devices should not contain any devices that have
+	// been manually added to the config.
+	newIgnoredDevices := []protocol.DeviceID{}
+	for _, dev := range cfg.IgnoredDevices {
+		if !existingDevices[dev] {
+			newIgnoredDevices = append(newIgnoredDevices, dev)
+		}
+	}
+	cfg.IgnoredDevices = newIgnoredDevices
+
 	return nil
 }
 

+ 46 - 0
lib/config/config_test.go

@@ -697,3 +697,49 @@ func TestV14ListenAddressesMigration(t *testing.T) {
 		}
 	}
 }
+
+func TestIgnoredDevices(t *testing.T) {
+	// Verify that ignored devices that are also present in the
+	// configuration are not in fact ignored.
+
+	wrapper, err := Load("testdata/ignoreddevices.xml", device1)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if wrapper.IgnoredDevice(device1) {
+		t.Errorf("Device %v should not be ignored", device1)
+	}
+	if !wrapper.IgnoredDevice(device3) {
+		t.Errorf("Device %v should be ignored", device3)
+	}
+}
+
+func TestGetDevice(t *testing.T) {
+	// Verify that the Device() call does the right thing
+
+	wrapper, err := Load("testdata/ignoreddevices.xml", device1)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// device1 is mentioned in the config
+
+	device, ok := wrapper.Device(device1)
+	if !ok {
+		t.Error(device1, "should exist")
+	}
+	if device.DeviceID != device1 {
+		t.Error("Should have returned", device1, "not", device.DeviceID)
+	}
+
+	// device3 is not
+
+	device, ok = wrapper.Device(device3)
+	if ok {
+		t.Error(device3, "should not exist")
+	}
+	if device.DeviceID == device3 {
+		t.Error("Should not returned ID", device3)
+	}
+}

+ 10 - 0
lib/config/testdata/ignoreddevices.xml

@@ -0,0 +1,10 @@
+<configuration version="15">
+    <device id="AIR6LPZ-7K4PTTV-UXQSMUU-CPQ5YWH-OEDFIIQ-JUG777G-2YQXXR5-YD6AWQR">
+        <address>dynamic</address>
+    </device>
+    <device id="GYRZZQB-IRNPV4Z-T7TC52W-EQYJ3TT-FDQW6MW-DFLMU42-SSSU6EM-FBK2VAY">
+        <address>dynamic</address>
+    </device>
+    <ignoredDevice>AIR6LPZ-7K4PTTV-UXQSMUU-CPQ5YWH-OEDFIIQ-JUG777G-2YQXXR5-YD6AWQR</ignoredDevice>
+    <ignoredDevice>LGFPDIT-7SKNNJL-VJZA4FC-7QNCRKA-CE753K7-2BW5QDK-2FOZ7FR-FEP57QJ</ignoredDevice>
+</configuration>

+ 12 - 0
lib/config/wrapper.go

@@ -284,6 +284,18 @@ func (w *Wrapper) IgnoredDevice(id protocol.DeviceID) bool {
 	return false
 }
 
+// Device returns the configuration for the given device and an "ok" bool.
+func (w *Wrapper) Device(id protocol.DeviceID) (DeviceConfiguration, bool) {
+	w.mut.Lock()
+	defer w.mut.Unlock()
+	for _, device := range w.cfg.Devices {
+		if device.DeviceID == id {
+			return device, true
+		}
+	}
+	return DeviceConfiguration{}, false
+}
+
 // Save writes the configuration to disk, and generates a ConfigSaved event.
 func (w *Wrapper) Save() error {
 	fd, err := osutil.CreateAtomic(w.path, 0600)

+ 47 - 48
lib/connections/service.go

@@ -183,7 +183,13 @@ next:
 		}
 		c.SetDeadline(time.Time{})
 
-		s.model.OnHello(remoteID, c.RemoteAddr(), hello)
+		// The Model will return an error for devices that we don't want to
+		// have a connection with for whatever reason, for example unknown devices.
+		if err := s.model.OnHello(remoteID, c.RemoteAddr(), hello); err != nil {
+			l.Infof("Connection from %s at %s (%s) rejected: %v", remoteID, c.RemoteAddr(), c.Type, err)
+			c.Close()
+			continue
+		}
 
 		// If we have a relay connection, and the new incoming connection is
 		// not a relay connection, we should drop that, and prefer the this one.
@@ -205,63 +211,56 @@ next:
 			l.Infof("Connected to already connected device (%s)", remoteID)
 			c.Close()
 			continue
-		} else if s.model.IsPaused(remoteID) {
-			l.Infof("Connection from paused device (%s)", remoteID)
-			c.Close()
-			continue
 		}
 
-		for deviceID, deviceCfg := range s.cfg.Devices() {
-			if deviceID == remoteID {
-				// Verify the name on the certificate. By default we set it to
-				// "syncthing" when generating, but the user may have replaced
-				// the certificate and used another name.
-				certName := deviceCfg.CertName
-				if certName == "" {
-					certName = s.tlsDefaultCommonName
-				}
-				err := remoteCert.VerifyHostname(certName)
-				if err != nil {
-					// Incorrect certificate name is something the user most
-					// likely wants to know about, since it's an advanced
-					// config. Warn instead of Info.
-					l.Warnf("Bad certificate from %s (%v): %v", remoteID, c.RemoteAddr(), err)
-					c.Close()
-					continue next
-				}
+		deviceCfg, ok := s.cfg.Device(remoteID)
+		if !ok {
+			panic("bug: unknown device should already have been rejected")
+		}
 
-				// If rate limiting is set, and based on the address we should
-				// limit the connection, then we wrap it in a limiter.
+		// Verify the name on the certificate. By default we set it to
+		// "syncthing" when generating, but the user may have replaced
+		// the certificate and used another name.
+		certName := deviceCfg.CertName
+		if certName == "" {
+			certName = s.tlsDefaultCommonName
+		}
+		if err := remoteCert.VerifyHostname(certName); err != nil {
+			// Incorrect certificate name is something the user most
+			// likely wants to know about, since it's an advanced
+			// config. Warn instead of Info.
+			l.Warnf("Bad certificate from %s (%v): %v", remoteID, c.RemoteAddr(), err)
+			c.Close()
+			continue next
+		}
 
-				limit := s.shouldLimit(c.RemoteAddr())
+		// If rate limiting is set, and based on the address we should
+		// limit the connection, then we wrap it in a limiter.
 
-				wr := io.Writer(c)
-				if limit && s.writeRateLimit != nil {
-					wr = NewWriteLimiter(c, s.writeRateLimit)
-				}
+		limit := s.shouldLimit(c.RemoteAddr())
 
-				rd := io.Reader(c)
-				if limit && s.readRateLimit != nil {
-					rd = NewReadLimiter(c, s.readRateLimit)
-				}
+		wr := io.Writer(c)
+		if limit && s.writeRateLimit != nil {
+			wr = NewWriteLimiter(c, s.writeRateLimit)
+		}
 
-				name := fmt.Sprintf("%s-%s (%s)", c.LocalAddr(), c.RemoteAddr(), c.Type)
-				protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression)
-				modelConn := Connection{c, protoConn}
+		rd := io.Reader(c)
+		if limit && s.readRateLimit != nil {
+			rd = NewReadLimiter(c, s.readRateLimit)
+		}
 
-				l.Infof("Established secure connection to %s at %s", remoteID, name)
-				l.Debugf("cipher suite: %04X in lan: %t", c.ConnectionState().CipherSuite, !limit)
+		name := fmt.Sprintf("%s-%s (%s)", c.LocalAddr(), c.RemoteAddr(), c.Type)
+		protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression)
+		modelConn := Connection{c, protoConn}
 
-				s.model.AddConnection(modelConn, hello)
-				s.curConMut.Lock()
-				s.currentConnection[remoteID] = modelConn
-				s.curConMut.Unlock()
-				continue next
-			}
-		}
+		l.Infof("Established secure connection to %s at %s", remoteID, name)
+		l.Debugf("cipher suite: %04X in lan: %t", c.ConnectionState().CipherSuite, !limit)
 
-		l.Infof("Connection from %s (%s) with ignored device ID %s", c.RemoteAddr(), c.Type, remoteID)
-		c.Close()
+		s.model.AddConnection(modelConn, hello)
+		s.curConMut.Lock()
+		s.currentConnection[remoteID] = modelConn
+		s.curConMut.Unlock()
+		continue next
 	}
 }
 

+ 1 - 1
lib/connections/structs.go

@@ -69,7 +69,7 @@ type Model interface {
 	AddConnection(conn Connection, hello protocol.HelloResult)
 	ConnectedTo(remoteID protocol.DeviceID) bool
 	IsPaused(remoteID protocol.DeviceID) bool
-	OnHello(protocol.DeviceID, net.Addr, protocol.HelloResult)
+	OnHello(protocol.DeviceID, net.Addr, protocol.HelloResult) error
 	GetHello(protocol.DeviceID) protocol.HelloIntf
 }
 

+ 21 - 14
lib/model/model.go

@@ -116,6 +116,9 @@ var (
 	errFolderNoSpace       = errors.New("folder has insufficient free space")
 	errUnsupportedSymlink  = errors.New("symlink not supported")
 	errInvalidFilename     = errors.New("filename is invalid")
+	errDeviceUnknown       = errors.New("unknown device")
+	errDevicePaused        = errors.New("device is paused")
+	errDeviceIgnored       = errors.New("device is ignored")
 )
 
 // NewModel creates and starts a new model. The model starts in read-only mode,
@@ -1065,23 +1068,27 @@ func (m *Model) SetIgnores(folder string, content []string) error {
 // OnHello is called when an device connects to us.
 // This allows us to extract some information from the Hello message
 // and add it to a list of known devices ahead of any checks.
-func (m *Model) OnHello(remoteID protocol.DeviceID, addr net.Addr, hello protocol.HelloResult) {
-	for deviceID := range m.cfg.Devices() {
-		if deviceID == remoteID {
-			// Existing device, we will get the hello message in AddConnection
-			// hence do not persist any state here, as the connection might
-			// get killed before AddConnection
-			return
-		}
+func (m *Model) OnHello(remoteID protocol.DeviceID, addr net.Addr, hello protocol.HelloResult) error {
+	if m.IsPaused(remoteID) {
+		return errDevicePaused
 	}
 
-	if !m.cfg.IgnoredDevice(remoteID) {
-		events.Default.Log(events.DeviceRejected, map[string]string{
-			"name":    hello.DeviceName,
-			"device":  remoteID.String(),
-			"address": addr.String(),
-		})
+	if m.cfg.IgnoredDevice(remoteID) {
+		return errDeviceIgnored
 	}
+
+	if _, ok := m.cfg.Device(remoteID); ok {
+		// The device exists
+		return nil
+	}
+
+	events.Default.Log(events.DeviceRejected, map[string]string{
+		"name":    hello.DeviceName,
+		"device":  remoteID.String(),
+		"address": addr.String(),
+	})
+
+	return errDeviceUnknown
 }
 
 // GetHello is called when we are about to connect to some remote device.