浏览代码

lib/connections: Handle wrapped connection in SetTCPOptions (fixes #3223)

GitHub-Pull-Request: https://github.com/syncthing/syncthing/pull/3225
Jakob Borg 9 年之前
父节点
当前提交
ac40b27c79

+ 1 - 2
lib/connections/relay_dial.go

@@ -8,7 +8,6 @@ package connections
 
 import (
 	"crypto/tls"
-	"net"
 	"net/url"
 	"time"
 
@@ -40,7 +39,7 @@ func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (IntermediateConn
 		return IntermediateConnection{}, err
 	}
 
-	err = dialer.SetTCPOptions(conn.(*net.TCPConn))
+	err = dialer.SetTCPOptions(conn)
 	if err != nil {
 		conn.Close()
 		return IntermediateConnection{}, err

+ 1 - 2
lib/connections/relay_listen.go

@@ -8,7 +8,6 @@ package connections
 
 import (
 	"crypto/tls"
-	"net"
 	"net/url"
 	"sync"
 	"time"
@@ -74,7 +73,7 @@ func (t *relayListener) Serve() {
 				continue
 			}
 
-			err = dialer.SetTCPOptions(conn.(*net.TCPConn))
+			err = dialer.SetTCPOptions(conn)
 			if err != nil {
 				l.Infoln(err)
 			}

+ 1 - 1
lib/connections/tcp_listen.go

@@ -102,7 +102,7 @@ func (t *tcpListener) Serve() {
 
 		l.Debugln("connect from", conn.RemoteAddr())
 
-		err = dialer.SetTCPOptions(conn.(*net.TCPConn))
+		err = dialer.SetTCPOptions(conn)
 		if err != nil {
 			l.Infoln(err)
 		}

+ 2 - 6
lib/dialer/internal.go

@@ -57,9 +57,7 @@ func dialWithFallback(proxyDialFunc dialFunc, fallbackDialFunc dialFunc, network
 	conn, err := proxyDialFunc(network, addr)
 	if err == nil {
 		l.Debugf("Dialing %s address %s via proxy - success, %s -> %s", network, addr, conn.LocalAddr(), conn.RemoteAddr())
-		if tcpconn, ok := conn.(*net.TCPConn); ok {
-			SetTCPOptions(tcpconn)
-		}
+		SetTCPOptions(conn)
 		return dialerConn{
 			conn, newDialerAddr(network, addr),
 		}, nil
@@ -73,9 +71,7 @@ func dialWithFallback(proxyDialFunc dialFunc, fallbackDialFunc dialFunc, network
 	conn, err = fallbackDialFunc(network, addr)
 	if err == nil {
 		l.Debugf("Dialing %s address %s via fallback - success, %s -> %s", network, addr, conn.LocalAddr(), conn.RemoteAddr())
-		if tcpconn, ok := conn.(*net.TCPConn); ok {
-			SetTCPOptions(tcpconn)
-		}
+		SetTCPOptions(conn)
 	} else {
 		l.Debugf("Dialing %s address %s via fallback - error %s", network, addr, err)
 	}

+ 26 - 15
lib/dialer/public.go

@@ -7,6 +7,7 @@
 package dialer
 
 import (
+	"fmt"
 	"net"
 	"time"
 )
@@ -47,20 +48,30 @@ func DialTimeout(network, addr string, timeout time.Duration) (net.Conn, error)
 	return net.DialTimeout(network, addr, timeout)
 }
 
-// SetTCPOptions sets syncthings default TCP options on a TCP connection
-func SetTCPOptions(conn *net.TCPConn) error {
-	var err error
-	if err = conn.SetLinger(0); err != nil {
-		return err
-	}
-	if err = conn.SetNoDelay(false); err != nil {
-		return err
-	}
-	if err = conn.SetKeepAlivePeriod(60 * time.Second); err != nil {
-		return err
-	}
-	if err = conn.SetKeepAlive(true); err != nil {
-		return err
+// SetTCPOptions sets our default TCP options on a TCP connection, possibly
+// digging through dialerConn to extract the *net.TCPConn
+func SetTCPOptions(conn net.Conn) error {
+	switch conn := conn.(type) {
+	case *net.TCPConn:
+		var err error
+		if err = conn.SetLinger(0); err != nil {
+			return err
+		}
+		if err = conn.SetNoDelay(false); err != nil {
+			return err
+		}
+		if err = conn.SetKeepAlivePeriod(60 * time.Second); err != nil {
+			return err
+		}
+		if err = conn.SetKeepAlive(true); err != nil {
+			return err
+		}
+		return nil
+
+	case dialerConn:
+		return SetTCPOptions(conn.Conn)
+
+	default:
+		return fmt.Errorf("unknown connection type %T", conn)
 	}
-	return nil
 }

+ 2 - 2
lib/protocol/benchmark_test.go

@@ -131,8 +131,8 @@ func getTCPConnectionPair() (net.Conn, net.Conn, error) {
 	}
 
 	// Set the buffer sizes etc as usual
-	dialer.SetTCPOptions(conn0.(*net.TCPConn))
-	dialer.SetTCPOptions(conn1.(*net.TCPConn))
+	dialer.SetTCPOptions(conn0)
+	dialer.SetTCPOptions(conn1)
 
 	return conn0, conn1, nil
 }

+ 2 - 1
lib/relay/client/static.go

@@ -122,7 +122,8 @@ func (c *staticClient) Serve() {
 			case protocol.SessionInvitation:
 				ip := net.IP(msg.Address)
 				if len(ip) == 0 || ip.IsUnspecified() {
-					msg.Address = c.conn.RemoteAddr().(*net.TCPAddr).IP[:]
+					ip := net.ParseIP(c.conn.RemoteAddr().String())
+					msg.Address = ip[:]
 				}
 				c.invitations <- msg