Browse Source

Connections have types

Audrius Butkevicius 10 năm trước cách đây
mục cha
commit
bb876eac82

+ 47 - 27
cmd/syncthing/connections.go

@@ -12,6 +12,7 @@ import (
 	"io"
 	"net"
 	"net/url"
+	"sync"
 	"time"
 
 	"github.com/syncthing/protocol"
@@ -22,7 +23,7 @@ import (
 )
 
 type DialerFactory func(*url.URL, *tls.Config) (*tls.Conn, error)
-type ListenerFactory func(*url.URL, *tls.Config, chan<- *tls.Conn)
+type ListenerFactory func(*url.URL, *tls.Config, chan<- intermediateConnection)
 
 var (
 	dialers   = make(map[string]DialerFactory, 0)
@@ -37,17 +38,27 @@ type connectionSvc struct {
 	myID   protocol.DeviceID
 	model  *model.Model
 	tlsCfg *tls.Config
-	conns  chan *tls.Conn
+	conns  chan intermediateConnection
+
+	mut      sync.RWMutex
+	connType map[protocol.DeviceID]model.ConnectionType
+}
+
+type intermediateConnection struct {
+	conn     *tls.Conn
+	connType model.ConnectionType
 }
 
-func newConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, model *model.Model, tlsCfg *tls.Config) *connectionSvc {
+func newConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, mdl *model.Model, tlsCfg *tls.Config) *connectionSvc {
 	svc := &connectionSvc{
 		Supervisor: suture.NewSimple("connectionSvc"),
 		cfg:        cfg,
 		myID:       myID,
-		model:      model,
+		model:      mdl,
 		tlsCfg:     tlsCfg,
-		conns:      make(chan *tls.Conn),
+		conns:      make(chan intermediateConnection),
+
+		connType: make(map[protocol.DeviceID]model.ConnectionType),
 	}
 
 	// There are several moving parts here; one routine per listening address
@@ -114,15 +125,15 @@ func newConnectionSvc(cfg *config.Wrapper, myID protocol.DeviceID, model *model.
 
 func (s *connectionSvc) handle() {
 next:
-	for conn := range s.conns {
-		cs := conn.ConnectionState()
+	for c := range s.conns {
+		cs := c.conn.ConnectionState()
 
 		// We should have negotiated the next level protocol "bep/1.0" as part
 		// of the TLS handshake. Unfortunately this can't be a hard error,
 		// because there are implementations out there that don't support
 		// protocol negotiation (iOS for one...).
 		if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != bepProtocolName {
-			l.Infof("Peer %s did not negotiate bep/1.0", conn.RemoteAddr())
+			l.Infof("Peer %s did not negotiate bep/1.0", c.conn.RemoteAddr())
 		}
 
 		// We should have received exactly one certificate from the other
@@ -130,8 +141,8 @@ next:
 		// connection.
 		certs := cs.PeerCertificates
 		if cl := len(certs); cl != 1 {
-			l.Infof("Got peer certificate list of length %d != 1 from %s; protocol error", cl, conn.RemoteAddr())
-			conn.Close()
+			l.Infof("Got peer certificate list of length %d != 1 from %s; protocol error", cl, c.conn.RemoteAddr())
+			c.conn.Close()
 			continue
 		}
 		remoteCert := certs[0]
@@ -142,7 +153,7 @@ next:
 		// clients between the same NAT gateway, and global discovery.
 		if remoteID == myID {
 			l.Infof("Connected to myself (%s) - should not happen", remoteID)
-			conn.Close()
+			c.conn.Close()
 			continue
 		}
 
@@ -154,7 +165,7 @@ next:
 		// connections still established...
 		if s.model.ConnectedTo(remoteID) {
 			l.Infof("Connected to already connected device (%s)", remoteID)
-			conn.Close()
+			c.conn.Close()
 			continue
 		}
 
@@ -172,35 +183,42 @@ next:
 					// Incorrect certificate name is something the user most
 					// likely wants to know about, since it's an advanced
 					// config. Warn instead of Info.
-					l.Warnf("Bad certificate from %s (%v): %v", remoteID, conn.RemoteAddr(), err)
-					conn.Close()
+					l.Warnf("Bad certificate from %s (%v): %v", remoteID, c.conn.RemoteAddr(), err)
+					c.conn.Close()
 					continue next
 				}
 
 				// If rate limiting is set, and based on the address we should
 				// limit the connection, then we wrap it in a limiter.
 
-				limit := s.shouldLimit(conn.RemoteAddr())
+				limit := s.shouldLimit(c.conn.RemoteAddr())
 
-				wr := io.Writer(conn)
+				wr := io.Writer(c.conn)
 				if limit && writeRateLimit != nil {
-					wr = &limitedWriter{conn, writeRateLimit}
+					wr = &limitedWriter{c.conn, writeRateLimit}
 				}
 
-				rd := io.Reader(conn)
+				rd := io.Reader(c.conn)
 				if limit && readRateLimit != nil {
-					rd = &limitedReader{conn, readRateLimit}
+					rd = &limitedReader{c.conn, readRateLimit}
 				}
 
-				name := fmt.Sprintf("%s-%s", conn.LocalAddr(), conn.RemoteAddr())
+				name := fmt.Sprintf("%s-%s (%s)", c.conn.LocalAddr(), c.conn.RemoteAddr(), c.connType)
 				protoConn := protocol.NewConnection(remoteID, rd, wr, s.model, name, deviceCfg.Compression)
 
 				l.Infof("Established secure connection to %s at %s", remoteID, name)
 				if debugNet {
-					l.Debugf("cipher suite: %04X in lan: %t", conn.ConnectionState().CipherSuite, !limit)
+					l.Debugf("cipher suite: %04X in lan: %t", c.conn.ConnectionState().CipherSuite, !limit)
 				}
 
-				s.model.AddConnection(conn, protoConn)
+				s.model.AddConnection(model.Connection{
+					c.conn,
+					protoConn,
+					c.connType,
+				})
+				s.mut.Lock()
+				s.connType[remoteID] = c.connType
+				s.mut.Unlock()
 				continue next
 			}
 		}
@@ -208,14 +226,14 @@ next:
 		if !s.cfg.IgnoredDevice(remoteID) {
 			events.Default.Log(events.DeviceRejected, map[string]string{
 				"device":  remoteID.String(),
-				"address": conn.RemoteAddr().String(),
+				"address": c.conn.RemoteAddr().String(),
 			})
-			l.Infof("Connection from %s with unknown device ID %s", conn.RemoteAddr(), remoteID)
+			l.Infof("Connection from %s (%s) with unknown device ID %s", c.conn.RemoteAddr(), c.connType, remoteID)
 		} else {
-			l.Infof("Connection from %s with ignored device ID %s", conn.RemoteAddr(), remoteID)
+			l.Infof("Connection from %s (%s) with ignored device ID %s", c.conn.RemoteAddr(), c.connType, remoteID)
 		}
 
-		conn.Close()
+		c.conn.Close()
 	}
 }
 
@@ -271,7 +289,9 @@ func (s *connectionSvc) connect() {
 					continue
 				}
 
-				s.conns <- conn
+				s.conns <- intermediateConnection{
+					conn, model.ConnectionTypeBasicDial,
+				}
 				continue nextDevice
 			}
 		}

+ 6 - 2
cmd/syncthing/connections_tcp.go

@@ -11,6 +11,8 @@ import (
 	"net"
 	"net/url"
 	"strings"
+
+	"github.com/syncthing/syncthing/lib/model"
 )
 
 func init() {
@@ -56,7 +58,7 @@ func tcpDialer(uri *url.URL, tlsCfg *tls.Config) (*tls.Conn, error) {
 	return tc, nil
 }
 
-func tcpListener(uri *url.URL, tlsCfg *tls.Config, conns chan<- *tls.Conn) {
+func tcpListener(uri *url.URL, tlsCfg *tls.Config, conns chan<- intermediateConnection) {
 	tcaddr, err := net.ResolveTCPAddr("tcp", uri.Host)
 	if err != nil {
 		l.Fatalln("listen (BEP/tcp):", err)
@@ -90,6 +92,8 @@ func tcpListener(uri *url.URL, tlsCfg *tls.Config, conns chan<- *tls.Conn) {
 			continue
 		}
 
-		conns <- tc
+		conns <- intermediateConnection{
+			tc, model.ConnectionTypeBasicAccept,
+		}
 	}
 }

+ 36 - 0
lib/model/connection.go

@@ -0,0 +1,36 @@
+// Copyright (C) 2015 The Syncthing Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package model
+
+import (
+	"net"
+
+	"github.com/syncthing/protocol"
+)
+
+type Connection struct {
+	net.Conn
+	protocol.Connection
+	Type ConnectionType
+}
+
+const (
+	ConnectionTypeBasicAccept ConnectionType = iota
+	ConnectionTypeBasicDial
+)
+
+type ConnectionType int
+
+func (t ConnectionType) String() string {
+	switch t {
+	case ConnectionTypeBasicAccept:
+		return "basic-accept"
+	case ConnectionTypeBasicDial:
+		return "basic-dial"
+	}
+	return "unknown"
+}

+ 26 - 30
lib/model/model.go

@@ -87,10 +87,9 @@ type Model struct {
 	folderStatRefs map[string]*stats.FolderStatisticsReference            // folder -> statsRef
 	fmut           sync.RWMutex                                           // protects the above
 
-	protoConn map[protocol.DeviceID]protocol.Connection
-	rawConn   map[protocol.DeviceID]io.Closer
+	conn      map[protocol.DeviceID]Connection
 	deviceVer map[protocol.DeviceID]string
-	pmut      sync.RWMutex // protects protoConn and rawConn
+	pmut      sync.RWMutex // protects conn and deviceVer
 
 	reqValidationCache map[string]time.Time // folder / file name => time when confirmed to exist
 	rvmut              sync.RWMutex         // protects reqValidationCache
@@ -130,8 +129,7 @@ func NewModel(cfg *config.Wrapper, id protocol.DeviceID, deviceName, clientName,
 		folderIgnores:      make(map[string]*ignore.Matcher),
 		folderRunners:      make(map[string]service),
 		folderStatRefs:     make(map[string]*stats.FolderStatisticsReference),
-		protoConn:          make(map[protocol.DeviceID]protocol.Connection),
-		rawConn:            make(map[protocol.DeviceID]io.Closer),
+		conn:               make(map[protocol.DeviceID]Connection),
 		deviceVer:          make(map[protocol.DeviceID]string),
 		reqValidationCache: make(map[string]time.Time),
 
@@ -243,14 +241,14 @@ func (m *Model) ConnectionStats() map[string]interface{} {
 	m.fmut.RLock()
 
 	var res = make(map[string]interface{})
-	conns := make(map[string]ConnectionInfo, len(m.protoConn))
-	for device, conn := range m.protoConn {
+	conns := make(map[string]ConnectionInfo, len(m.conn))
+	for device, conn := range m.conn {
 		ci := ConnectionInfo{
 			Statistics:    conn.Statistics(),
 			ClientVersion: m.deviceVer[device],
 		}
-		if nc, ok := m.rawConn[device].(remoteAddrer); ok {
-			ci.Address = nc.RemoteAddr().String()
+		if addr := m.conn[device].RemoteAddr(); addr != nil {
+			ci.Address = addr.String()
 		}
 
 		conns[device.String()] = ci
@@ -586,8 +584,11 @@ func (m *Model) ClusterConfig(deviceID protocol.DeviceID, cm protocol.ClusterCon
 		"clientVersion": cm.ClientVersion,
 	}
 
-	if conn, ok := m.rawConn[deviceID].(*tls.Conn); ok {
-		event["addr"] = conn.RemoteAddr().String()
+	if conn, ok := m.conn[deviceID]; ok {
+		addr := conn.RemoteAddr()
+		if addr != nil {
+			event["addr"] = addr.String()
+		}
 	}
 
 	m.pmut.Unlock()
@@ -693,12 +694,11 @@ func (m *Model) Close(device protocol.DeviceID, err error) {
 	}
 	m.fmut.RUnlock()
 
-	conn, ok := m.rawConn[device]
+	conn, ok := m.conn[device]
 	if ok {
 		closeRawConn(conn)
 	}
-	delete(m.protoConn, device)
-	delete(m.rawConn, device)
+	delete(m.conn, device)
 	delete(m.deviceVer, device)
 	m.pmut.Unlock()
 }
@@ -860,7 +860,7 @@ func (cf cFiler) CurrentFile(file string) (protocol.FileInfo, bool) {
 // ConnectedTo returns true if we are connected to the named device.
 func (m *Model) ConnectedTo(deviceID protocol.DeviceID) bool {
 	m.pmut.RLock()
-	_, ok := m.protoConn[deviceID]
+	_, ok := m.conn[deviceID]
 	m.pmut.RUnlock()
 	if ok {
 		m.deviceWasSeen(deviceID)
@@ -927,28 +927,24 @@ func (m *Model) SetIgnores(folder string, content []string) error {
 // AddConnection adds a new peer connection to the model. An initial index will
 // be sent to the connected peer, thereafter index updates whenever the local
 // folder changes.
-func (m *Model) AddConnection(rawConn io.Closer, protoConn protocol.Connection) {
-	deviceID := protoConn.ID()
+func (m *Model) AddConnection(conn Connection) {
+	deviceID := conn.ID()
 
 	m.pmut.Lock()
-	if _, ok := m.protoConn[deviceID]; ok {
-		panic("add existing device")
-	}
-	m.protoConn[deviceID] = protoConn
-	if _, ok := m.rawConn[deviceID]; ok {
+	if _, ok := m.conn[deviceID]; ok {
 		panic("add existing device")
 	}
-	m.rawConn[deviceID] = rawConn
+	m.conn[deviceID] = conn
 
-	protoConn.Start()
+	conn.Start()
 
 	cm := m.clusterConfig(deviceID)
-	protoConn.ClusterConfig(cm)
+	conn.ClusterConfig(cm)
 
 	m.fmut.RLock()
 	for _, folder := range m.deviceFolders[deviceID] {
 		fs := m.folderFiles[folder]
-		go sendIndexes(protoConn, folder, fs, m.folderIgnores[folder])
+		go sendIndexes(conn, folder, fs, m.folderIgnores[folder])
 	}
 	m.fmut.RUnlock()
 	m.pmut.Unlock()
@@ -1114,7 +1110,7 @@ func (m *Model) updateLocals(folder string, fs []protocol.FileInfo) {
 
 func (m *Model) requestGlobal(deviceID protocol.DeviceID, folder, name string, offset int64, size int, hash []byte, flags uint32, options []protocol.Option) ([]byte, error) {
 	m.pmut.RLock()
-	nc, ok := m.protoConn[deviceID]
+	nc, ok := m.conn[deviceID]
 	m.pmut.RUnlock()
 
 	if !ok {
@@ -1640,7 +1636,7 @@ func (m *Model) Availability(folder, file string) []protocol.DeviceID {
 
 	availableDevices := []protocol.DeviceID{}
 	for _, device := range fs.Availability(file) {
-		_, ok := m.protoConn[device]
+		_, ok := m.conn[device]
 		if ok {
 			availableDevices = append(availableDevices, device)
 		}
@@ -1764,7 +1760,7 @@ func (m *Model) CommitConfiguration(from, to config.Configuration) bool {
 			// folder.
 			m.pmut.Lock()
 			for _, dev := range cfg.DeviceIDs() {
-				if conn, ok := m.rawConn[dev]; ok {
+				if conn, ok := m.conn[dev]; ok {
 					closeRawConn(conn)
 				}
 			}
@@ -1812,7 +1808,7 @@ func (m *Model) CommitConfiguration(from, to config.Configuration) bool {
 				// disconnect it so that we start sharing the folder with it.
 				// We close the underlying connection and let the normal error
 				// handling kick in to clean up and reconnect.
-				if conn, ok := m.rawConn[dev]; ok {
+				if conn, ok := m.conn[dev]; ok {
 					closeRawConn(conn)
 				}
 

+ 18 - 1
lib/model/model_test.go

@@ -12,6 +12,7 @@ import (
 	"fmt"
 	"io/ioutil"
 	"math/rand"
+	"net"
 	"os"
 	"path/filepath"
 	"strconv"
@@ -281,7 +282,11 @@ func BenchmarkRequest(b *testing.B) {
 		id:          device1,
 		requestData: []byte("some data to return"),
 	}
-	m.AddConnection(fc, fc)
+	m.AddConnection(Connection{
+		&net.TCPConn{},
+		fc,
+		ConnectionTypeBasicAccept,
+	})
 	m.Index(device1, "default", files, 0, nil)
 
 	b.ResetTimer()
@@ -314,6 +319,18 @@ func TestDeviceRename(t *testing.T) {
 
 	db, _ := leveldb.Open(storage.NewMemStorage(), nil)
 	m := NewModel(cfg, protocol.LocalDeviceID, "device", "syncthing", "dev", db)
+
+	fc := FakeConnection{
+		id:          device1,
+		requestData: []byte("some data to return"),
+	}
+
+	m.AddConnection(Connection{
+		&net.TCPConn{},
+		fc,
+		ConnectionTypeBasicAccept,
+	})
+
 	m.ServeBackground()
 	if cfg.Devices()[device1].Name != "" {
 		t.Errorf("Device already has a name")