Selaa lähdekoodia

Set TCP options on connections

Jakob Borg 11 vuotta sitten
vanhempi
sitoutus
fccdd85cc1
1 muutettua tiedostoa jossa 135 lisäystä ja 89 poistoa
  1. 135 89
      cmd/syncthing/main.go

+ 135 - 89
cmd/syncthing/main.go

@@ -611,98 +611,11 @@ func listenConnect(myID protocol.NodeID, m *model.Model, tlsCfg *tls.Config) {
 
 
 	// Listen
 	// Listen
 	for _, addr := range cfg.Options.ListenAddress {
 	for _, addr := range cfg.Options.ListenAddress {
-		addr := addr
-		go func() {
-			if debugNet {
-				l.Debugln("listening on", addr)
-			}
-			listener, err := tls.Listen("tcp", addr, tlsCfg)
-			l.FatalErr(err)
-
-			for {
-				conn, err := listener.Accept()
-				if err != nil {
-					l.Warnln(err)
-					continue
-				}
-
-				if debugNet {
-					l.Debugln("connect from", conn.RemoteAddr())
-				}
-
-				tc := conn.(*tls.Conn)
-				err = tc.Handshake()
-				if err != nil {
-					l.Warnln(err)
-					tc.Close()
-					continue
-				}
-
-				conns <- tc
-			}
-		}()
+		go listenTLS(conns, addr, tlsCfg)
 	}
 	}
 
 
 	// Connect
 	// Connect
-	go func() {
-		var delay time.Duration = 1 * time.Second
-		for {
-		nextNode:
-			for _, nodeCfg := range cfg.Nodes {
-				if nodeCfg.NodeID == myID {
-					continue
-				}
-				if m.ConnectedTo(nodeCfg.NodeID) {
-					continue
-				}
-
-				var addrs []string
-				for _, addr := range nodeCfg.Addresses {
-					if addr == "dynamic" {
-						if discoverer != nil {
-							t := discoverer.Lookup(nodeCfg.NodeID)
-							if len(t) == 0 {
-								continue
-							}
-							addrs = append(addrs, t...)
-						}
-					} else {
-						addrs = append(addrs, addr)
-					}
-				}
-
-				for _, addr := range addrs {
-					host, port, err := net.SplitHostPort(addr)
-					if err != nil && strings.HasPrefix(err.Error(), "missing port") {
-						// addr is on the form "1.2.3.4"
-						addr = net.JoinHostPort(addr, "22000")
-					} else if err == nil && port == "" {
-						// addr is on the form "1.2.3.4:"
-						addr = net.JoinHostPort(host, "22000")
-					}
-					if debugNet {
-						l.Debugln("dial", nodeCfg.NodeID, addr)
-					}
-					conn, err := tls.Dial("tcp", addr, tlsCfg)
-					if err != nil {
-						if debugNet {
-							l.Debugln(err)
-						}
-						continue
-					}
-
-					conns <- conn
-					continue nextNode
-				}
-			}
-
-			time.Sleep(delay)
-			delay *= 2
-			if maxD := time.Duration(cfg.Options.ReconnectIntervalS) * time.Second; delay > maxD {
-				delay = maxD
-			}
-		}
-	}()
+	go dialTLS(m, conns, tlsCfg)
 
 
 next:
 next:
 	for conn := range conns {
 	for conn := range conns {
@@ -753,6 +666,139 @@ next:
 	}
 	}
 }
 }
 
 
+func listenTLS(conns chan *tls.Conn, addr string, tlsCfg *tls.Config) {
+	if debugNet {
+		l.Debugln("listening on", addr)
+	}
+
+	tcaddr, err := net.ResolveTCPAddr("tcp", addr)
+	l.FatalErr(err)
+	listener, err := net.ListenTCP("tcp", tcaddr)
+	l.FatalErr(err)
+
+	for {
+		conn, err := listener.Accept()
+		if err != nil {
+			l.Warnln(err)
+			continue
+		}
+
+		if debugNet {
+			l.Debugln("connect from", conn.RemoteAddr())
+		}
+
+		tcpConn := conn.(*net.TCPConn)
+		setTCPOptions(tcpConn)
+
+		tc := tls.Server(conn, tlsCfg)
+		err = tc.Handshake()
+		if err != nil {
+			l.Warnln(err)
+			tc.Close()
+			continue
+		}
+
+		conns <- tc
+	}
+
+}
+
+func dialTLS(m *model.Model, conns chan *tls.Conn, tlsCfg *tls.Config) {
+	var delay time.Duration = 1 * time.Second
+	for {
+	nextNode:
+		for _, nodeCfg := range cfg.Nodes {
+			if nodeCfg.NodeID == myID {
+				continue
+			}
+
+			if m.ConnectedTo(nodeCfg.NodeID) {
+				continue
+			}
+
+			var addrs []string
+			for _, addr := range nodeCfg.Addresses {
+				if addr == "dynamic" {
+					if discoverer != nil {
+						t := discoverer.Lookup(nodeCfg.NodeID)
+						if len(t) == 0 {
+							continue
+						}
+						addrs = append(addrs, t...)
+					}
+				} else {
+					addrs = append(addrs, addr)
+				}
+			}
+
+			for _, addr := range addrs {
+				host, port, err := net.SplitHostPort(addr)
+				if err != nil && strings.HasPrefix(err.Error(), "missing port") {
+					// addr is on the form "1.2.3.4"
+					addr = net.JoinHostPort(addr, "22000")
+				} else if err == nil && port == "" {
+					// addr is on the form "1.2.3.4:"
+					addr = net.JoinHostPort(host, "22000")
+				}
+				if debugNet {
+					l.Debugln("dial", nodeCfg.NodeID, addr)
+				}
+
+				raddr, err := net.ResolveTCPAddr("tcp", addr)
+				if err != nil {
+					if debugNet {
+						l.Debugln(err)
+					}
+					continue
+				}
+
+				conn, err := net.DialTCP("tcp", nil, raddr)
+				if err != nil {
+					if debugNet {
+						l.Debugln(err)
+					}
+					continue
+				}
+
+				setTCPOptions(conn)
+
+				tc := tls.Client(conn, tlsCfg)
+				err = tc.Handshake()
+				if err != nil {
+					l.Warnln(err)
+					tc.Close()
+					continue
+				}
+
+				conns <- tc
+				continue nextNode
+			}
+		}
+
+		time.Sleep(delay)
+		delay *= 2
+		if maxD := time.Duration(cfg.Options.ReconnectIntervalS) * time.Second; delay > maxD {
+			delay = maxD
+		}
+	}
+}
+
+func setTCPOptions(conn *net.TCPConn) {
+	var err error
+	if err = conn.SetLinger(0); err != nil {
+		l.Infoln(err)
+	}
+	if err = conn.SetNoDelay(false); err != nil {
+		l.Infoln(err)
+	}
+	if err = conn.SetKeepAlivePeriod(60 * time.Second); err != nil {
+		l.Infoln(err)
+	}
+	if err = conn.SetKeepAlive(true); err != nil {
+		l.Infoln(err)
+	}
+}
+
 func discovery(extPort int) *discover.Discoverer {
 func discovery(extPort int) *discover.Discoverer {
 	disc, err := discover.NewDiscoverer(myID, cfg.Options.ListenAddress, cfg.Options.LocalAnnPort)
 	disc, err := discover.NewDiscoverer(myID, cfg.Options.ListenAddress, cfg.Options.LocalAnnPort)
 	if err != nil {
 	if err != nil {