1
0
Эх сурвалжийг харах

Merge pull request #1699 from calmh/connsvc

Break out connection handling into a service
Audrius Butkevicius 10 жил өмнө
parent
commit
ecc8591c95

+ 89 - 29
cmd/syncthing/connections.go

@@ -15,23 +15,84 @@ import (
 	"time"
 
 	"github.com/syncthing/protocol"
+	"github.com/syncthing/syncthing/internal/config"
 	"github.com/syncthing/syncthing/internal/events"
 	"github.com/syncthing/syncthing/internal/model"
+	"github.com/thejerf/suture"
 )
 
-func listenConnect(myID protocol.DeviceID, m *model.Model, tlsCfg *tls.Config) {
-	var conns = make(chan *tls.Conn)
+// The connection service listens on TLS and dials configured unconnected
+// devices. Successfull connections are handed to the model.
+type connectionSvc struct {
+	*suture.Supervisor
+	cfg    *config.Wrapper
+	myID   protocol.DeviceID
+	model  *model.Model
+	tlsCfg *tls.Config
+	conns  chan *tls.Conn
+}
 
-	// Listen
-	for _, addr := range cfg.Options().ListenAddress {
-		go listenTLS(conns, addr, tlsCfg)
+func newConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, model *model.Model, tlsCfg *tls.Config) *connectionSvc {
+	svc := &connectionSvc{
+		Supervisor: suture.NewSimple("connectionSvc"),
+		cfg:        cfg,
+		myID:       myID,
+		model:      model,
+		tlsCfg:     tlsCfg,
+		conns:      make(chan *tls.Conn),
 	}
 
-	// Connect
-	go dialTLS(m, conns, tlsCfg)
+	// There are several moving parts here; one routine per listening address
+	// to handle incoming connections, one routine to periodically attempt
+	// outgoing connections, and lastly one routine to the the common handling
+	// regardless of whether the connection was incoming or outgoing. It ends
+	// up as in the diagram below. We embed a Supervisor to manage the
+	// routines (i.e. log and restart if they crash or exit, etc).
+	//
+	//                +-----------------+
+	//    Incoming    | +---------------+-+      +-----------------+
+	//   Connections  | |                 |      |                 |   Outgoing
+	// -------------->| |   svc.listen    |      |                 |  Connections
+	//                | |  (1 per listen  |      |   svc.connect   |-------------->
+	//                | |    address)     |      |                 |
+	//                +-+                 |      |                 |
+	//                  +-----------------+      +-----------------+
+	//                           v                        v
+	//                           |                        |
+	//                           |                        |
+	//                           +------------+-----------+
+	//                                        |
+	//                                        | svc.conns
+	//                                        v
+	//                               +-----------------+
+	//                               |                 |
+	//                               |                 |
+	//                               |   svc.handle    |------> model.AddConnection()
+	//                               |                 |
+	//                               |                 |
+	//                               +-----------------+
+	//
+	// TODO: Clean shutdown, and/or handling config changes on the fly. We
+	// partly do this now - new devices and addresses will be picked up, but
+	// not new listen addresses and we don't support disconnecting devices
+	// that are removed and so on...
+
+	svc.Add(serviceFunc(svc.connect))
+	for _, addr := range svc.cfg.Options().ListenAddress {
+		addr := addr
+		listener := serviceFunc(func() {
+			svc.listen(addr)
+		})
+		svc.Add(listener)
+	}
+	svc.Add(serviceFunc(svc.handle))
 
+	return svc
+}
+
+func (s *connectionSvc) handle() {
 next:
-	for conn := range conns {
+	for conn := range s.conns {
 		cs := conn.ConnectionState()
 
 		// We should have negotiated the next level protocol "bep/1.0" as part
@@ -69,13 +130,13 @@ next:
 		// this one. But in case we are two devices connecting to each other
 		// in parallell we don't want to do that or we end up with no
 		// connections still established...
-		if m.ConnectedTo(remoteID) {
+		if s.model.ConnectedTo(remoteID) {
 			l.Infof("Connected to already connected device (%s)", remoteID)
 			conn.Close()
 			continue
 		}
 
-		for deviceID, deviceCfg := range cfg.Devices() {
+		for deviceID, deviceCfg := range s.cfg.Devices() {
 			if deviceID == remoteID {
 				// Verify the name on the certificate. By default we set it to
 				// "syncthing" when generating, but the user may have replaced
@@ -97,7 +158,7 @@ next:
 				// If rate limiting is set, and based on the address we should
 				// limit the connection, then we wrap it in a limiter.
 
-				limit := shouldLimit(conn.RemoteAddr())
+				limit := s.shouldLimit(conn.RemoteAddr())
 
 				wr := io.Writer(conn)
 				if limit && writeRateLimit != nil {
@@ -110,7 +171,7 @@ next:
 				}
 
 				name := fmt.Sprintf("%s-%s", conn.LocalAddr(), conn.RemoteAddr())
-				protoConn := protocol.NewConnection(remoteID, rd, wr, m, name, deviceCfg.Compression)
+				protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression)
 
 				l.Infof("Established secure connection to %s at %s", remoteID, name)
 				if debugNet {
@@ -121,12 +182,12 @@ next:
 					"addr": conn.RemoteAddr().String(),
 				})
 
-				m.AddConnection(conn, protoConn)
+				s.model.AddConnection(conn, protoConn)
 				continue next
 			}
 		}
 
-		if !cfg.IgnoredDevice(remoteID) {
+		if !s.cfg.IgnoredDevice(remoteID) {
 			events.Default.Log(events.DeviceRejected, map[string]string{
 				"device":  remoteID.String(),
 				"address": conn.RemoteAddr().String(),
@@ -140,7 +201,7 @@ next:
 	}
 }
 
-func listenTLS(conns chan *tls.Conn, addr string, tlsCfg *tls.Config) {
+func (s *connectionSvc) listen(addr string) {
 	if debugNet {
 		l.Debugln("listening on", addr)
 	}
@@ -166,9 +227,9 @@ func listenTLS(conns chan *tls.Conn, addr string, tlsCfg *tls.Config) {
 		}
 
 		tcpConn := conn.(*net.TCPConn)
-		setTCPOptions(tcpConn)
+		s.setTCPOptions(tcpConn)
 
-		tc := tls.Server(conn, tlsCfg)
+		tc := tls.Server(conn, s.tlsCfg)
 		err = tc.Handshake()
 		if err != nil {
 			l.Infoln("TLS handshake:", err)
@@ -176,21 +237,20 @@ func listenTLS(conns chan *tls.Conn, addr string, tlsCfg *tls.Config) {
 			continue
 		}
 
-		conns <- tc
+		s.conns <- tc
 	}
-
 }
 
-func dialTLS(m *model.Model, conns chan *tls.Conn, tlsCfg *tls.Config) {
+func (s *connectionSvc) connect() {
 	delay := time.Second
 	for {
 	nextDevice:
-		for deviceID, deviceCfg := range cfg.Devices() {
+		for deviceID, deviceCfg := range s.cfg.Devices() {
 			if deviceID == myID {
 				continue
 			}
 
-			if m.ConnectedTo(deviceID) {
+			if s.model.ConnectedTo(deviceID) {
 				continue
 			}
 
@@ -238,9 +298,9 @@ func dialTLS(m *model.Model, conns chan *tls.Conn, tlsCfg *tls.Config) {
 					continue
 				}
 
-				setTCPOptions(conn)
+				s.setTCPOptions(conn)
 
-				tc := tls.Client(conn, tlsCfg)
+				tc := tls.Client(conn, s.tlsCfg)
 				err = tc.Handshake()
 				if err != nil {
 					l.Infoln("TLS handshake:", err)
@@ -248,20 +308,20 @@ func dialTLS(m *model.Model, conns chan *tls.Conn, tlsCfg *tls.Config) {
 					continue
 				}
 
-				conns <- tc
+				s.conns <- tc
 				continue nextDevice
 			}
 		}
 
 		time.Sleep(delay)
 		delay *= 2
-		if maxD := time.Duration(cfg.Options().ReconnectIntervalS) * time.Second; delay > maxD {
+		if maxD := time.Duration(s.cfg.Options().ReconnectIntervalS) * time.Second; delay > maxD {
 			delay = maxD
 		}
 	}
 }
 
-func setTCPOptions(conn *net.TCPConn) {
+func (*connectionSvc) setTCPOptions(conn *net.TCPConn) {
 	var err error
 	if err = conn.SetLinger(0); err != nil {
 		l.Infoln(err)
@@ -277,8 +337,8 @@ func setTCPOptions(conn *net.TCPConn) {
 	}
 }
 
-func shouldLimit(addr net.Addr) bool {
-	if cfg.Options().LimitBandwidthInLan {
+func (s *connectionSvc) shouldLimit(addr net.Addr) bool {
+	if s.cfg.Options().LimitBandwidthInLan {
 		return true
 	}
 

+ 3 - 1
cmd/syncthing/main.go

@@ -584,7 +584,9 @@ func syncthingMain() {
 
 	// Routine to connect out to configured devices
 	discoverer = discovery(externalPort)
-	go listenConnect(myID, m, tlsCfg)
+
+	connectionSvc := newConnectionSvc(cfg, myID, m, tlsCfg)
+	mainSvc.Add(connectionSvc)
 
 	for _, folder := range cfg.Folders() {
 		// Routine to pull blocks from other devices to synchronize the local