فهرست منبع

Decouple connections service from model

The connections service no longer depends directly on the
syncthing model object, but on an interface instead. This
makes it drastically easier to write clients that handle
the model differently, but still want to benefit from
existing and future connections changes in the core.

This was motivated by burkemw3's interest in creating a
FUSE client that can present a view of the global model,
but not have all of the file data locally.

The actual decoupling was done by adding a connections.Model
interface. This interface is effectively an extension of the
protocol.Model interface that also handles connections
alongside the modified service.
Matt Burke 10 سال پیش
والد
کامیت
2234c45c19

+ 2 - 8
cmd/syncthing/main.go

@@ -28,6 +28,7 @@ import (
 	"github.com/calmh/logger"
 	"github.com/calmh/logger"
 	"github.com/juju/ratelimit"
 	"github.com/juju/ratelimit"
 	"github.com/syncthing/syncthing/lib/config"
 	"github.com/syncthing/syncthing/lib/config"
+	"github.com/syncthing/syncthing/lib/connections"
 	"github.com/syncthing/syncthing/lib/db"
 	"github.com/syncthing/syncthing/lib/db"
 	"github.com/syncthing/syncthing/lib/discover"
 	"github.com/syncthing/syncthing/lib/discover"
 	"github.com/syncthing/syncthing/lib/events"
 	"github.com/syncthing/syncthing/lib/events"
@@ -577,13 +578,6 @@ func syncthingMain() {
 		symlinks.Supported = false
 		symlinks.Supported = false
 	}
 	}
 
 
-	if opts.MaxSendKbps > 0 {
-		writeRateLimit = ratelimit.NewBucketWithRate(float64(1000*opts.MaxSendKbps), int64(5*1000*opts.MaxSendKbps))
-	}
-	if opts.MaxRecvKbps > 0 {
-		readRateLimit = ratelimit.NewBucketWithRate(float64(1000*opts.MaxRecvKbps), int64(5*1000*opts.MaxRecvKbps))
-	}
-
 	if (opts.MaxRecvKbps > 0 || opts.MaxSendKbps > 0) && !opts.LimitBandwidthInLan {
 	if (opts.MaxRecvKbps > 0 || opts.MaxSendKbps > 0) && !opts.LimitBandwidthInLan {
 		lans, _ = osutil.GetLans()
 		lans, _ = osutil.GetLans()
 		networks := make([]string, 0, len(lans))
 		networks := make([]string, 0, len(lans))
@@ -750,7 +744,7 @@ func syncthingMain() {
 
 
 	// Start connection management
 	// Start connection management
 
 
-	connectionSvc := newConnectionSvc(cfg, myID, m, tlsCfg, cachedDiscovery, relaySvc)
+	connectionSvc := connections.NewConnectionSvc(cfg, myID, m, tlsCfg, cachedDiscovery, relaySvc, bepProtocolName, tlsDefaultCommonName, lans)
 	mainSvc.Add(connectionSvc)
 	mainSvc.Add(connectionSvc)
 
 
 	if cpuProfile {
 	if cpuProfile {

+ 78 - 37
cmd/syncthing/connections.go → lib/connections/connections.go

@@ -4,7 +4,7 @@
 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 // You can obtain one at http://mozilla.org/MPL/2.0/.
 // You can obtain one at http://mozilla.org/MPL/2.0/.
 
 
-package main
+package connections
 
 
 import (
 import (
 	"crypto/tls"
 	"crypto/tls"
@@ -15,6 +15,7 @@ import (
 	"sync"
 	"sync"
 	"time"
 	"time"
 
 
+	"github.com/juju/ratelimit"
 	"github.com/syncthing/syncthing/lib/config"
 	"github.com/syncthing/syncthing/lib/config"
 	"github.com/syncthing/syncthing/lib/discover"
 	"github.com/syncthing/syncthing/lib/discover"
 	"github.com/syncthing/syncthing/lib/events"
 	"github.com/syncthing/syncthing/lib/events"
@@ -35,17 +36,39 @@ var (
 	listeners = make(map[string]ListenerFactory, 0)
 	listeners = make(map[string]ListenerFactory, 0)
 )
 )
 
 
+type Model interface {
+	AddConnection(conn model.Connection)
+	ConnectedTo(remoteID protocol.DeviceID) bool
+	IsPaused(remoteID protocol.DeviceID) bool
+
+	// An index was received from the peer device
+	Index(deviceID protocol.DeviceID, folder string, files []protocol.FileInfo, flags uint32, options []protocol.Option)
+	// An index update was received from the peer device
+	IndexUpdate(deviceID protocol.DeviceID, folder string, files []protocol.FileInfo, flags uint32, options []protocol.Option)
+	// A request was made by the peer device
+	Request(deviceID protocol.DeviceID, folder string, name string, offset int64, hash []byte, flags uint32, options []protocol.Option, buf []byte) error
+	// A cluster configuration message was received
+	ClusterConfig(deviceID protocol.DeviceID, config protocol.ClusterConfigMessage)
+	// The peer device closed the connection
+	Close(deviceID protocol.DeviceID, err error)
+}
+
 // The connection service listens on TLS and dials configured unconnected
 // The connection service listens on TLS and dials configured unconnected
 // devices. Successful connections are handed to the model.
 // devices. Successful connections are handed to the model.
 type connectionSvc struct {
 type connectionSvc struct {
 	*suture.Supervisor
 	*suture.Supervisor
-	cfg        *config.Wrapper
-	myID       protocol.DeviceID
-	model      *model.Model
-	tlsCfg     *tls.Config
-	discoverer discover.Finder
-	conns      chan model.IntermediateConnection
-	relaySvc   *relay.Svc
+	cfg                  *config.Wrapper
+	myID                 protocol.DeviceID
+	model                Model
+	tlsCfg               *tls.Config
+	discoverer           discover.Finder
+	conns                chan model.IntermediateConnection
+	relaySvc             *relay.Svc
+	bepProtocolName      string
+	tlsDefaultCommonName string
+	lans                 []*net.IPNet
+	writeRateLimit       *ratelimit.Bucket
+	readRateLimit        *ratelimit.Bucket
 
 
 	lastRelayCheck map[protocol.DeviceID]time.Time
 	lastRelayCheck map[protocol.DeviceID]time.Time
 
 
@@ -54,16 +77,20 @@ type connectionSvc struct {
 	relaysEnabled bool
 	relaysEnabled bool
 }
 }
 
 
-func newConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, mdl *model.Model, tlsCfg *tls.Config, discoverer discover.Finder, relaySvc *relay.Svc) *connectionSvc {
+func NewConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *tls.Config, discoverer discover.Finder, relaySvc *relay.Svc,
+	bepProtocolName string, tlsDefaultCommonName string, lans []*net.IPNet) suture.Service {
 	svc := &connectionSvc{
 	svc := &connectionSvc{
-		Supervisor: suture.NewSimple("connectionSvc"),
-		cfg:        cfg,
-		myID:       myID,
-		model:      mdl,
-		tlsCfg:     tlsCfg,
-		discoverer: discoverer,
-		relaySvc:   relaySvc,
-		conns:      make(chan model.IntermediateConnection),
+		Supervisor:           suture.NewSimple("connectionSvc"),
+		cfg:                  cfg,
+		myID:                 myID,
+		model:                mdl,
+		tlsCfg:               tlsCfg,
+		discoverer:           discoverer,
+		relaySvc:             relaySvc,
+		conns:                make(chan model.IntermediateConnection),
+		bepProtocolName:      bepProtocolName,
+		tlsDefaultCommonName: tlsDefaultCommonName,
+		lans:                 lans,
 
 
 		connType:       make(map[protocol.DeviceID]model.ConnectionType),
 		connType:       make(map[protocol.DeviceID]model.ConnectionType),
 		relaysEnabled:  cfg.Options().RelaysEnabled,
 		relaysEnabled:  cfg.Options().RelaysEnabled,
@@ -71,6 +98,13 @@ func newConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, mdl *model.Mo
 	}
 	}
 	cfg.Subscribe(svc)
 	cfg.Subscribe(svc)
 
 
+	if svc.cfg.Options().MaxSendKbps > 0 {
+		svc.writeRateLimit = ratelimit.NewBucketWithRate(float64(1000*svc.cfg.Options().MaxSendKbps), int64(5*1000*svc.cfg.Options().MaxSendKbps))
+	}
+	if svc.cfg.Options().MaxRecvKbps > 0 {
+		svc.readRateLimit = ratelimit.NewBucketWithRate(float64(1000*svc.cfg.Options().MaxRecvKbps), int64(5*1000*svc.cfg.Options().MaxRecvKbps))
+	}
+
 	// There are several moving parts here; one routine per listening address
 	// There are several moving parts here; one routine per listening address
 	// to handle incoming connections, one routine to periodically attempt
 	// to handle incoming connections, one routine to periodically attempt
 	// outgoing connections, one routine to the the common handling
 	// outgoing connections, one routine to the the common handling
@@ -97,7 +131,7 @@ func newConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, mdl *model.Mo
 			continue
 			continue
 		}
 		}
 
 
-		if debugNet {
+		if debug {
 			l.Debugln("listening on", uri.String())
 			l.Debugln("listening on", uri.String())
 		}
 		}
 
 
@@ -123,7 +157,7 @@ next:
 		// of the TLS handshake. Unfortunately this can't be a hard error,
 		// of the TLS handshake. Unfortunately this can't be a hard error,
 		// because there are implementations out there that don't support
 		// because there are implementations out there that don't support
 		// protocol negotiation (iOS for one...).
 		// protocol negotiation (iOS for one...).
-		if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != bepProtocolName {
+		if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != s.bepProtocolName {
 			l.Infof("Peer %s did not negotiate bep/1.0", c.Conn.RemoteAddr())
 			l.Infof("Peer %s did not negotiate bep/1.0", c.Conn.RemoteAddr())
 		}
 		}
 
 
@@ -142,7 +176,7 @@ next:
 		// The device ID should not be that of ourselves. It can happen
 		// The device ID should not be that of ourselves. It can happen
 		// though, especially in the presence of NAT hairpinning, multiple
 		// though, especially in the presence of NAT hairpinning, multiple
 		// clients between the same NAT gateway, and global discovery.
 		// clients between the same NAT gateway, and global discovery.
-		if remoteID == myID {
+		if remoteID == s.myID {
 			l.Infof("Connected to myself (%s) - should not happen", remoteID)
 			l.Infof("Connected to myself (%s) - should not happen", remoteID)
 			c.Conn.Close()
 			c.Conn.Close()
 			continue
 			continue
@@ -154,7 +188,7 @@ next:
 		ct, ok := s.connType[remoteID]
 		ct, ok := s.connType[remoteID]
 		s.mut.RUnlock()
 		s.mut.RUnlock()
 		if ok && !ct.IsDirect() && c.Type.IsDirect() {
 		if ok && !ct.IsDirect() && c.Type.IsDirect() {
-			if debugNet {
+			if debug {
 				l.Debugln("Switching connections", remoteID)
 				l.Debugln("Switching connections", remoteID)
 			}
 			}
 			s.model.Close(remoteID, fmt.Errorf("switching connections"))
 			s.model.Close(remoteID, fmt.Errorf("switching connections"))
@@ -181,7 +215,7 @@ next:
 				// the certificate and used another name.
 				// the certificate and used another name.
 				certName := deviceCfg.CertName
 				certName := deviceCfg.CertName
 				if certName == "" {
 				if certName == "" {
-					certName = tlsDefaultCommonName
+					certName = s.tlsDefaultCommonName
 				}
 				}
 				err := remoteCert.VerifyHostname(certName)
 				err := remoteCert.VerifyHostname(certName)
 				if err != nil {
 				if err != nil {
@@ -199,20 +233,20 @@ next:
 				limit := s.shouldLimit(c.Conn.RemoteAddr())
 				limit := s.shouldLimit(c.Conn.RemoteAddr())
 
 
 				wr := io.Writer(c.Conn)
 				wr := io.Writer(c.Conn)
-				if limit && writeRateLimit != nil {
-					wr = &limitedWriter{c.Conn, writeRateLimit}
+				if limit && s.writeRateLimit != nil {
+					wr = NewWriteLimiter(c.Conn, s.writeRateLimit)
 				}
 				}
 
 
 				rd := io.Reader(c.Conn)
 				rd := io.Reader(c.Conn)
-				if limit && readRateLimit != nil {
-					rd = &limitedReader{c.Conn, readRateLimit}
+				if limit && s.readRateLimit != nil {
+					rd = NewReadLimiter(c.Conn, s.readRateLimit)
 				}
 				}
 
 
 				name := fmt.Sprintf("%s-%s (%s)", c.Conn.LocalAddr(), c.Conn.RemoteAddr(), c.Type)
 				name := fmt.Sprintf("%s-%s (%s)", c.Conn.LocalAddr(), c.Conn.RemoteAddr(), c.Type)
 				protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression)
 				protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression)
 
 
 				l.Infof("Established secure connection to %s at %s", remoteID, name)
 				l.Infof("Established secure connection to %s at %s", remoteID, name)
-				if debugNet {
+				if debug {
 					l.Debugf("cipher suite: %04X in lan: %t", c.Conn.ConnectionState().CipherSuite, !limit)
 					l.Debugf("cipher suite: %04X in lan: %t", c.Conn.ConnectionState().CipherSuite, !limit)
 				}
 				}
 
 
@@ -245,7 +279,7 @@ func (s *connectionSvc) connect() {
 	for {
 	for {
 	nextDevice:
 	nextDevice:
 		for deviceID, deviceCfg := range s.cfg.Devices() {
 		for deviceID, deviceCfg := range s.cfg.Devices() {
-			if deviceID == myID {
+			if deviceID == s.myID {
 				continue
 				continue
 			}
 			}
 
 
@@ -291,12 +325,12 @@ func (s *connectionSvc) connect() {
 					continue
 					continue
 				}
 				}
 
 
-				if debugNet {
+				if debug {
 					l.Debugln("dial", deviceCfg.DeviceID, uri.String())
 					l.Debugln("dial", deviceCfg.DeviceID, uri.String())
 				}
 				}
 				conn, err := dialer(uri, s.tlsCfg)
 				conn, err := dialer(uri, s.tlsCfg)
 				if err != nil {
 				if err != nil {
-					if debugNet {
+					if debug {
 						l.Debugln("dial failed", deviceCfg.DeviceID, uri.String(), err)
 						l.Debugln("dial failed", deviceCfg.DeviceID, uri.String(), err)
 					}
 					}
 					continue
 					continue
@@ -323,11 +357,11 @@ func (s *connectionSvc) connect() {
 
 
 			reconIntv := time.Duration(s.cfg.Options().RelayReconnectIntervalM) * time.Minute
 			reconIntv := time.Duration(s.cfg.Options().RelayReconnectIntervalM) * time.Minute
 			if last, ok := s.lastRelayCheck[deviceID]; ok && time.Since(last) < reconIntv {
 			if last, ok := s.lastRelayCheck[deviceID]; ok && time.Since(last) < reconIntv {
-				if debugNet {
+				if debug {
 					l.Debugln("Skipping connecting via relay to", deviceID, "last checked at", last)
 					l.Debugln("Skipping connecting via relay to", deviceID, "last checked at", last)
 				}
 				}
 				continue nextDevice
 				continue nextDevice
-			} else if debugNet {
+			} else if debug {
 				l.Debugln("Trying relay connections to", deviceID, relays)
 				l.Debugln("Trying relay connections to", deviceID, relays)
 			}
 			}
 
 
@@ -342,21 +376,21 @@ func (s *connectionSvc) connect() {
 
 
 				inv, err := client.GetInvitationFromRelay(uri, deviceID, s.tlsCfg.Certificates)
 				inv, err := client.GetInvitationFromRelay(uri, deviceID, s.tlsCfg.Certificates)
 				if err != nil {
 				if err != nil {
-					if debugNet {
+					if debug {
 						l.Debugf("Failed to get invitation for %s from %s: %v", deviceID, uri, err)
 						l.Debugf("Failed to get invitation for %s from %s: %v", deviceID, uri, err)
 					}
 					}
 					continue
 					continue
-				} else if debugNet {
+				} else if debug {
 					l.Debugln("Succesfully retrieved relay invitation", inv, "from", uri)
 					l.Debugln("Succesfully retrieved relay invitation", inv, "from", uri)
 				}
 				}
 
 
 				conn, err := client.JoinSession(inv)
 				conn, err := client.JoinSession(inv)
 				if err != nil {
 				if err != nil {
-					if debugNet {
+					if debug {
 						l.Debugf("Failed to join relay session %s: %v", inv, err)
 						l.Debugf("Failed to join relay session %s: %v", inv, err)
 					}
 					}
 					continue
 					continue
-				} else if debugNet {
+				} else if debug {
 					l.Debugln("Sucessfully joined relay session", inv)
 					l.Debugln("Sucessfully joined relay session", inv)
 				}
 				}
 
 
@@ -412,7 +446,7 @@ func (s *connectionSvc) shouldLimit(addr net.Addr) bool {
 	if !ok {
 	if !ok {
 		return true
 		return true
 	}
 	}
-	for _, lan := range lans {
+	for _, lan := range s.lans {
 		if lan.Contains(tcpaddr.IP) {
 		if lan.Contains(tcpaddr.IP) {
 			return false
 			return false
 		}
 		}
@@ -444,3 +478,10 @@ func (s *connectionSvc) CommitConfiguration(from, to config.Configuration) bool
 
 
 	return true
 	return true
 }
 }
+
+// serviceFunc wraps a function to create a suture.Service without stop
+// functionality.
+type serviceFunc func()
+
+func (f serviceFunc) Serve() { f() }
+func (f serviceFunc) Stop()  {}

+ 4 - 4
cmd/syncthing/connections_tcp.go → lib/connections/connections_tcp.go

@@ -4,7 +4,7 @@
 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 // You can obtain one at http://mozilla.org/MPL/2.0/.
 // You can obtain one at http://mozilla.org/MPL/2.0/.
 
 
-package main
+package connections
 
 
 import (
 import (
 	"crypto/tls"
 	"crypto/tls"
@@ -33,7 +33,7 @@ func tcpDialer(uri *url.URL, tlsCfg *tls.Config) (*tls.Conn, error) {
 
 
 	raddr, err := net.ResolveTCPAddr("tcp", uri.Host)
 	raddr, err := net.ResolveTCPAddr("tcp", uri.Host)
 	if err != nil {
 	if err != nil {
-		if debugNet {
+		if debug {
 			l.Debugln(err)
 			l.Debugln(err)
 		}
 		}
 		return nil, err
 		return nil, err
@@ -41,7 +41,7 @@ func tcpDialer(uri *url.URL, tlsCfg *tls.Config) (*tls.Conn, error) {
 
 
 	conn, err := net.DialTCP("tcp", nil, raddr)
 	conn, err := net.DialTCP("tcp", nil, raddr)
 	if err != nil {
 	if err != nil {
-		if debugNet {
+		if debug {
 			l.Debugln(err)
 			l.Debugln(err)
 		}
 		}
 		return nil, err
 		return nil, err
@@ -81,7 +81,7 @@ func tcpListener(uri *url.URL, tlsCfg *tls.Config, conns chan<- model.Intermedia
 			continue
 			continue
 		}
 		}
 
 
-		if debugNet {
+		if debug {
 			l.Debugln("connect from", conn.RemoteAddr())
 			l.Debugln("connect from", conn.RemoteAddr())
 		}
 		}
 
 

+ 19 - 0
lib/connections/debug.go

@@ -0,0 +1,19 @@
+// Copyright (C) 2014 The Syncthing Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package connections
+
+import (
+	"os"
+	"strings"
+
+	"github.com/calmh/logger"
+)
+
+var (
+	debug = strings.Contains(os.Getenv("STTRACE"), "connections") || os.Getenv("STTRACE") == "all"
+	l     = logger.DefaultLogger
+)

+ 12 - 5
cmd/syncthing/limitedreader.go → lib/connections/limitedreader.go

@@ -4,7 +4,7 @@
 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 // You can obtain one at http://mozilla.org/MPL/2.0/.
 // You can obtain one at http://mozilla.org/MPL/2.0/.
 
 
-package main
+package connections
 
 
 import (
 import (
 	"io"
 	"io"
@@ -12,13 +12,20 @@ import (
 	"github.com/juju/ratelimit"
 	"github.com/juju/ratelimit"
 )
 )
 
 
-type limitedReader struct {
-	r      io.Reader
+type LimitedReader struct {
+	reader io.Reader
 	bucket *ratelimit.Bucket
 	bucket *ratelimit.Bucket
 }
 }
 
 
-func (r *limitedReader) Read(buf []byte) (int, error) {
-	n, err := r.r.Read(buf)
+func NewReadLimiter(r io.Reader, b *ratelimit.Bucket) *LimitedReader {
+	return &LimitedReader{
+		reader: r,
+		bucket: b,
+	}
+}
+
+func (r *LimitedReader) Read(buf []byte) (int, error) {
+	n, err := r.reader.Read(buf)
 	if r.bucket != nil {
 	if r.bucket != nil {
 		r.bucket.Wait(int64(n))
 		r.bucket.Wait(int64(n))
 	}
 	}

+ 12 - 5
cmd/syncthing/limitedwriter.go → lib/connections/limitedwriter.go

@@ -4,7 +4,7 @@
 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 // You can obtain one at http://mozilla.org/MPL/2.0/.
 // You can obtain one at http://mozilla.org/MPL/2.0/.
 
 
-package main
+package connections
 
 
 import (
 import (
 	"io"
 	"io"
@@ -12,14 +12,21 @@ import (
 	"github.com/juju/ratelimit"
 	"github.com/juju/ratelimit"
 )
 )
 
 
-type limitedWriter struct {
-	w      io.Writer
+type LimitedWriter struct {
+	writer io.Writer
 	bucket *ratelimit.Bucket
 	bucket *ratelimit.Bucket
 }
 }
 
 
-func (w *limitedWriter) Write(buf []byte) (int, error) {
+func NewWriteLimiter(w io.Writer, b *ratelimit.Bucket) *LimitedWriter {
+	return &LimitedWriter{
+		writer: w,
+		bucket: b,
+	}
+}
+
+func (w *LimitedWriter) Write(buf []byte) (int, error) {
 	if w.bucket != nil {
 	if w.bucket != nil {
 		w.bucket.Wait(int64(len(buf)))
 		w.bucket.Wait(int64(len(buf)))
 	}
 	}
-	return w.w.Write(buf)
+	return w.writer.Write(buf)
 }
 }