Bläddra i källkod

Merge pull request #2189 from burkemw3/lib-ify-connections

Decouple connections service from model
Jakob Borg 10 år sedan
förälder
incheckning
b158072a15

+ 2 - 8
cmd/syncthing/main.go

@@ -28,6 +28,7 @@ import (
 	"github.com/calmh/logger"
 	"github.com/juju/ratelimit"
 	"github.com/syncthing/syncthing/lib/config"
+	"github.com/syncthing/syncthing/lib/connections"
 	"github.com/syncthing/syncthing/lib/db"
 	"github.com/syncthing/syncthing/lib/discover"
 	"github.com/syncthing/syncthing/lib/events"
@@ -577,13 +578,6 @@ func syncthingMain() {
 		symlinks.Supported = false
 	}
 
-	if opts.MaxSendKbps > 0 {
-		writeRateLimit = ratelimit.NewBucketWithRate(float64(1000*opts.MaxSendKbps), int64(5*1000*opts.MaxSendKbps))
-	}
-	if opts.MaxRecvKbps > 0 {
-		readRateLimit = ratelimit.NewBucketWithRate(float64(1000*opts.MaxRecvKbps), int64(5*1000*opts.MaxRecvKbps))
-	}
-
 	if (opts.MaxRecvKbps > 0 || opts.MaxSendKbps > 0) && !opts.LimitBandwidthInLan {
 		lans, _ = osutil.GetLans()
 		networks := make([]string, 0, len(lans))
@@ -750,7 +744,7 @@ func syncthingMain() {
 
 	// Start connection management
 
-	connectionSvc := newConnectionSvc(cfg, myID, m, tlsCfg, cachedDiscovery, relaySvc)
+	connectionSvc := connections.NewConnectionSvc(cfg, myID, m, tlsCfg, cachedDiscovery, relaySvc, bepProtocolName, tlsDefaultCommonName, lans)
 	mainSvc.Add(connectionSvc)
 
 	if cpuProfile {

+ 78 - 37
cmd/syncthing/connections.go → lib/connections/connections.go

@@ -4,7 +4,7 @@
 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 // You can obtain one at http://mozilla.org/MPL/2.0/.
 
-package main
+package connections
 
 import (
 	"crypto/tls"
@@ -15,6 +15,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/juju/ratelimit"
 	"github.com/syncthing/syncthing/lib/config"
 	"github.com/syncthing/syncthing/lib/discover"
 	"github.com/syncthing/syncthing/lib/events"
@@ -35,17 +36,39 @@ var (
 	listeners = make(map[string]ListenerFactory, 0)
 )
 
+type Model interface {
+	AddConnection(conn model.Connection)
+	ConnectedTo(remoteID protocol.DeviceID) bool
+	IsPaused(remoteID protocol.DeviceID) bool
+
+	// An index was received from the peer device
+	Index(deviceID protocol.DeviceID, folder string, files []protocol.FileInfo, flags uint32, options []protocol.Option)
+	// An index update was received from the peer device
+	IndexUpdate(deviceID protocol.DeviceID, folder string, files []protocol.FileInfo, flags uint32, options []protocol.Option)
+	// A request was made by the peer device
+	Request(deviceID protocol.DeviceID, folder string, name string, offset int64, hash []byte, flags uint32, options []protocol.Option, buf []byte) error
+	// A cluster configuration message was received
+	ClusterConfig(deviceID protocol.DeviceID, config protocol.ClusterConfigMessage)
+	// The peer device closed the connection
+	Close(deviceID protocol.DeviceID, err error)
+}
+
 // The connection service listens on TLS and dials configured unconnected
 // devices. Successful connections are handed to the model.
 type connectionSvc struct {
 	*suture.Supervisor
-	cfg        *config.Wrapper
-	myID       protocol.DeviceID
-	model      *model.Model
-	tlsCfg     *tls.Config
-	discoverer discover.Finder
-	conns      chan model.IntermediateConnection
-	relaySvc   *relay.Svc
+	cfg                  *config.Wrapper
+	myID                 protocol.DeviceID
+	model                Model
+	tlsCfg               *tls.Config
+	discoverer           discover.Finder
+	conns                chan model.IntermediateConnection
+	relaySvc             *relay.Svc
+	bepProtocolName      string
+	tlsDefaultCommonName string
+	lans                 []*net.IPNet
+	writeRateLimit       *ratelimit.Bucket
+	readRateLimit        *ratelimit.Bucket
 
 	lastRelayCheck map[protocol.DeviceID]time.Time
 
@@ -54,16 +77,20 @@ type connectionSvc struct {
 	relaysEnabled bool
 }
 
-func newConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, mdl *model.Model, tlsCfg *tls.Config, discoverer discover.Finder, relaySvc *relay.Svc) *connectionSvc {
+func NewConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *tls.Config, discoverer discover.Finder, relaySvc *relay.Svc,
+	bepProtocolName string, tlsDefaultCommonName string, lans []*net.IPNet) suture.Service {
 	svc := &connectionSvc{
-		Supervisor: suture.NewSimple("connectionSvc"),
-		cfg:        cfg,
-		myID:       myID,
-		model:      mdl,
-		tlsCfg:     tlsCfg,
-		discoverer: discoverer,
-		relaySvc:   relaySvc,
-		conns:      make(chan model.IntermediateConnection),
+		Supervisor:           suture.NewSimple("connectionSvc"),
+		cfg:                  cfg,
+		myID:                 myID,
+		model:                mdl,
+		tlsCfg:               tlsCfg,
+		discoverer:           discoverer,
+		relaySvc:             relaySvc,
+		conns:                make(chan model.IntermediateConnection),
+		bepProtocolName:      bepProtocolName,
+		tlsDefaultCommonName: tlsDefaultCommonName,
+		lans:                 lans,
 
 		connType:       make(map[protocol.DeviceID]model.ConnectionType),
 		relaysEnabled:  cfg.Options().RelaysEnabled,
@@ -71,6 +98,13 @@ func newConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, mdl *model.Mo
 	}
 	cfg.Subscribe(svc)
 
+	if svc.cfg.Options().MaxSendKbps > 0 {
+		svc.writeRateLimit = ratelimit.NewBucketWithRate(float64(1000*svc.cfg.Options().MaxSendKbps), int64(5*1000*svc.cfg.Options().MaxSendKbps))
+	}
+	if svc.cfg.Options().MaxRecvKbps > 0 {
+		svc.readRateLimit = ratelimit.NewBucketWithRate(float64(1000*svc.cfg.Options().MaxRecvKbps), int64(5*1000*svc.cfg.Options().MaxRecvKbps))
+	}
+
 	// There are several moving parts here; one routine per listening address
 	// to handle incoming connections, one routine to periodically attempt
 	// outgoing connections, one routine to the the common handling
@@ -97,7 +131,7 @@ func newConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, mdl *model.Mo
 			continue
 		}
 
-		if debugNet {
+		if debug {
 			l.Debugln("listening on", uri.String())
 		}
 
@@ -123,7 +157,7 @@ next:
 		// of the TLS handshake. Unfortunately this can't be a hard error,
 		// because there are implementations out there that don't support
 		// protocol negotiation (iOS for one...).
-		if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != bepProtocolName {
+		if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != s.bepProtocolName {
 			l.Infof("Peer %s did not negotiate bep/1.0", c.Conn.RemoteAddr())
 		}
 
@@ -142,7 +176,7 @@ next:
 		// The device ID should not be that of ourselves. It can happen
 		// though, especially in the presence of NAT hairpinning, multiple
 		// clients between the same NAT gateway, and global discovery.
-		if remoteID == myID {
+		if remoteID == s.myID {
 			l.Infof("Connected to myself (%s) - should not happen", remoteID)
 			c.Conn.Close()
 			continue
@@ -154,7 +188,7 @@ next:
 		ct, ok := s.connType[remoteID]
 		s.mut.RUnlock()
 		if ok && !ct.IsDirect() && c.Type.IsDirect() {
-			if debugNet {
+			if debug {
 				l.Debugln("Switching connections", remoteID)
 			}
 			s.model.Close(remoteID, fmt.Errorf("switching connections"))
@@ -181,7 +215,7 @@ next:
 				// the certificate and used another name.
 				certName := deviceCfg.CertName
 				if certName == "" {
-					certName = tlsDefaultCommonName
+					certName = s.tlsDefaultCommonName
 				}
 				err := remoteCert.VerifyHostname(certName)
 				if err != nil {
@@ -199,20 +233,20 @@ next:
 				limit := s.shouldLimit(c.Conn.RemoteAddr())
 
 				wr := io.Writer(c.Conn)
-				if limit && writeRateLimit != nil {
-					wr = &limitedWriter{c.Conn, writeRateLimit}
+				if limit && s.writeRateLimit != nil {
+					wr = NewWriteLimiter(c.Conn, s.writeRateLimit)
 				}
 
 				rd := io.Reader(c.Conn)
-				if limit && readRateLimit != nil {
-					rd = &limitedReader{c.Conn, readRateLimit}
+				if limit && s.readRateLimit != nil {
+					rd = NewReadLimiter(c.Conn, s.readRateLimit)
 				}
 
 				name := fmt.Sprintf("%s-%s (%s)", c.Conn.LocalAddr(), c.Conn.RemoteAddr(), c.Type)
 				protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression)
 
 				l.Infof("Established secure connection to %s at %s", remoteID, name)
-				if debugNet {
+				if debug {
 					l.Debugf("cipher suite: %04X in lan: %t", c.Conn.ConnectionState().CipherSuite, !limit)
 				}
 
@@ -245,7 +279,7 @@ func (s *connectionSvc) connect() {
 	for {
 	nextDevice:
 		for deviceID, deviceCfg := range s.cfg.Devices() {
-			if deviceID == myID {
+			if deviceID == s.myID {
 				continue
 			}
 
@@ -291,12 +325,12 @@ func (s *connectionSvc) connect() {
 					continue
 				}
 
-				if debugNet {
+				if debug {
 					l.Debugln("dial", deviceCfg.DeviceID, uri.String())
 				}
 				conn, err := dialer(uri, s.tlsCfg)
 				if err != nil {
-					if debugNet {
+					if debug {
 						l.Debugln("dial failed", deviceCfg.DeviceID, uri.String(), err)
 					}
 					continue
@@ -323,11 +357,11 @@ func (s *connectionSvc) connect() {
 
 			reconIntv := time.Duration(s.cfg.Options().RelayReconnectIntervalM) * time.Minute
 			if last, ok := s.lastRelayCheck[deviceID]; ok && time.Since(last) < reconIntv {
-				if debugNet {
+				if debug {
 					l.Debugln("Skipping connecting via relay to", deviceID, "last checked at", last)
 				}
 				continue nextDevice
-			} else if debugNet {
+			} else if debug {
 				l.Debugln("Trying relay connections to", deviceID, relays)
 			}
 
@@ -342,21 +376,21 @@ func (s *connectionSvc) connect() {
 
 				inv, err := client.GetInvitationFromRelay(uri, deviceID, s.tlsCfg.Certificates)
 				if err != nil {
-					if debugNet {
+					if debug {
 						l.Debugf("Failed to get invitation for %s from %s: %v", deviceID, uri, err)
 					}
 					continue
-				} else if debugNet {
+				} else if debug {
 					l.Debugln("Succesfully retrieved relay invitation", inv, "from", uri)
 				}
 
 				conn, err := client.JoinSession(inv)
 				if err != nil {
-					if debugNet {
+					if debug {
 						l.Debugf("Failed to join relay session %s: %v", inv, err)
 					}
 					continue
-				} else if debugNet {
+				} else if debug {
 					l.Debugln("Sucessfully joined relay session", inv)
 				}
 
@@ -412,7 +446,7 @@ func (s *connectionSvc) shouldLimit(addr net.Addr) bool {
 	if !ok {
 		return true
 	}
-	for _, lan := range lans {
+	for _, lan := range s.lans {
 		if lan.Contains(tcpaddr.IP) {
 			return false
 		}
@@ -444,3 +478,10 @@ func (s *connectionSvc) CommitConfiguration(from, to config.Configuration) bool
 
 	return true
 }
+
+// serviceFunc wraps a function to create a suture.Service without stop
+// functionality.
+type serviceFunc func()
+
+func (f serviceFunc) Serve() { f() }
+func (f serviceFunc) Stop()  {}

+ 4 - 4
cmd/syncthing/connections_tcp.go → lib/connections/connections_tcp.go

@@ -4,7 +4,7 @@
 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 // You can obtain one at http://mozilla.org/MPL/2.0/.
 
-package main
+package connections
 
 import (
 	"crypto/tls"
@@ -33,7 +33,7 @@ func tcpDialer(uri *url.URL, tlsCfg *tls.Config) (*tls.Conn, error) {
 
 	raddr, err := net.ResolveTCPAddr("tcp", uri.Host)
 	if err != nil {
-		if debugNet {
+		if debug {
 			l.Debugln(err)
 		}
 		return nil, err
@@ -41,7 +41,7 @@ func tcpDialer(uri *url.URL, tlsCfg *tls.Config) (*tls.Conn, error) {
 
 	conn, err := net.DialTCP("tcp", nil, raddr)
 	if err != nil {
-		if debugNet {
+		if debug {
 			l.Debugln(err)
 		}
 		return nil, err
@@ -81,7 +81,7 @@ func tcpListener(uri *url.URL, tlsCfg *tls.Config, conns chan<- model.Intermedia
 			continue
 		}
 
-		if debugNet {
+		if debug {
 			l.Debugln("connect from", conn.RemoteAddr())
 		}
 

+ 19 - 0
lib/connections/debug.go

@@ -0,0 +1,19 @@
+// Copyright (C) 2014 The Syncthing Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package connections
+
+import (
+	"os"
+	"strings"
+
+	"github.com/calmh/logger"
+)
+
+var (
+	debug = strings.Contains(os.Getenv("STTRACE"), "connections") || os.Getenv("STTRACE") == "all"
+	l     = logger.DefaultLogger
+)

+ 12 - 5
cmd/syncthing/limitedreader.go → lib/connections/limitedreader.go

@@ -4,7 +4,7 @@
 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 // You can obtain one at http://mozilla.org/MPL/2.0/.
 
-package main
+package connections
 
 import (
 	"io"
@@ -12,13 +12,20 @@ import (
 	"github.com/juju/ratelimit"
 )
 
-type limitedReader struct {
-	r      io.Reader
+type LimitedReader struct {
+	reader io.Reader
 	bucket *ratelimit.Bucket
 }
 
-func (r *limitedReader) Read(buf []byte) (int, error) {
-	n, err := r.r.Read(buf)
+func NewReadLimiter(r io.Reader, b *ratelimit.Bucket) *LimitedReader {
+	return &LimitedReader{
+		reader: r,
+		bucket: b,
+	}
+}
+
+func (r *LimitedReader) Read(buf []byte) (int, error) {
+	n, err := r.reader.Read(buf)
 	if r.bucket != nil {
 		r.bucket.Wait(int64(n))
 	}

+ 12 - 5
cmd/syncthing/limitedwriter.go → lib/connections/limitedwriter.go

@@ -4,7 +4,7 @@
 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 // You can obtain one at http://mozilla.org/MPL/2.0/.
 
-package main
+package connections
 
 import (
 	"io"
@@ -12,14 +12,21 @@ import (
 	"github.com/juju/ratelimit"
 )
 
-type limitedWriter struct {
-	w      io.Writer
+type LimitedWriter struct {
+	writer io.Writer
 	bucket *ratelimit.Bucket
 }
 
-func (w *limitedWriter) Write(buf []byte) (int, error) {
+func NewWriteLimiter(w io.Writer, b *ratelimit.Bucket) *LimitedWriter {
+	return &LimitedWriter{
+		writer: w,
+		bucket: b,
+	}
+}
+
+func (w *LimitedWriter) Write(buf []byte) (int, error) {
 	if w.bucket != nil {
 		w.bucket.Wait(int64(len(buf)))
 	}
-	return w.w.Write(buf)
+	return w.writer.Write(buf)
 }