Forráskód Böngészése

lib: Factor out getting IP address from net.Addr (#8538)

... and add fast paths for common cases.
greatroar 3 éve
szülő
commit
8065cf7e97

+ 3 - 4
lib/connections/quic_misc.go

@@ -16,6 +16,7 @@ import (
 	"time"
 	"time"
 
 
 	"github.com/lucas-clemente/quic-go"
 	"github.com/lucas-clemente/quic-go"
+	"github.com/syncthing/syncthing/lib/osutil"
 )
 )
 
 
 var (
 var (
@@ -65,9 +66,7 @@ func (q *quicTlsConn) ConnectionState() tls.ConnectionState {
 }
 }
 
 
 func packetConnUnspecified(conn interface{}) bool {
 func packetConnUnspecified(conn interface{}) bool {
-	// Since QUIC connections are wrapped, we can't do a simple typecheck
-	// on *net.UDPAddr here.
 	addr := conn.(net.PacketConn).LocalAddr()
 	addr := conn.(net.PacketConn).LocalAddr()
-	host, _, err := net.SplitHostPort(addr.String())
-	return err == nil && net.ParseIP(host).IsUnspecified()
+	ip, err := osutil.IPFromAddr(addr)
+	return err == nil && ip.IsUnspecified()
 }
 }

+ 2 - 5
lib/connections/structs.go

@@ -18,6 +18,7 @@ import (
 	"github.com/syncthing/syncthing/lib/config"
 	"github.com/syncthing/syncthing/lib/config"
 	"github.com/syncthing/syncthing/lib/connections/registry"
 	"github.com/syncthing/syncthing/lib/connections/registry"
 	"github.com/syncthing/syncthing/lib/nat"
 	"github.com/syncthing/syncthing/lib/nat"
+	"github.com/syncthing/syncthing/lib/osutil"
 	"github.com/syncthing/syncthing/lib/protocol"
 	"github.com/syncthing/syncthing/lib/protocol"
 	"github.com/syncthing/syncthing/lib/stats"
 	"github.com/syncthing/syncthing/lib/stats"
 
 
@@ -117,14 +118,10 @@ func (c internalConn) Crypto() string {
 
 
 func (c internalConn) Transport() string {
 func (c internalConn) Transport() string {
 	transport := c.connType.Transport()
 	transport := c.connType.Transport()
-	host, _, err := net.SplitHostPort(c.LocalAddr().String())
+	ip, err := osutil.IPFromAddr(c.LocalAddr())
 	if err != nil {
 	if err != nil {
 		return transport
 		return transport
 	}
 	}
-	ip := net.ParseIP(host)
-	if ip == nil {
-		return transport
-	}
 	if ip.To4() != nil {
 	if ip.To4() != nil {
 		return transport + "4"
 		return transport + "4"
 	}
 	}

+ 12 - 0
lib/osutil/lan.go → lib/osutil/net.go

@@ -37,3 +37,15 @@ func GetLans() ([]*net.IPNet, error) {
 	}
 	}
 	return nets, nil
 	return nets, nil
 }
 }
+
+func IPFromAddr(addr net.Addr) (net.IP, error) {
+	switch a := addr.(type) {
+	case *net.TCPAddr:
+		return a.IP, nil
+	case *net.UDPAddr:
+		return a.IP, nil
+	default:
+		host, _, err := net.SplitHostPort(addr.String())
+		return net.ParseIP(host), err
+	}
+}

+ 3 - 4
lib/pmp/pmp.go

@@ -18,6 +18,7 @@ import (
 	natpmp "github.com/jackpal/go-nat-pmp"
 	natpmp "github.com/jackpal/go-nat-pmp"
 
 
 	"github.com/syncthing/syncthing/lib/nat"
 	"github.com/syncthing/syncthing/lib/nat"
+	"github.com/syncthing/syncthing/lib/osutil"
 	"github.com/syncthing/syncthing/lib/util"
 	"github.com/syncthing/syncthing/lib/util"
 )
 )
 
 
@@ -66,10 +67,8 @@ func Discover(ctx context.Context, renewal, timeout time.Duration) []nat.Device
 	conn, err := (&net.Dialer{}).DialContext(timeoutCtx, "udp", net.JoinHostPort(ip.String(), "5351"))
 	conn, err := (&net.Dialer{}).DialContext(timeoutCtx, "udp", net.JoinHostPort(ip.String(), "5351"))
 	if err == nil {
 	if err == nil {
 		conn.Close()
 		conn.Close()
-		localIPAddress, _, err := net.SplitHostPort(conn.LocalAddr().String())
-		if err == nil {
-			localIP = net.ParseIP(localIPAddress)
-		} else {
+		localIP, err = osutil.IPFromAddr(conn.LocalAddr())
+		if localIP == nil {
 			l.Debugln("Failed to lookup local IP", err)
 			l.Debugln("Failed to lookup local IP", err)
 		}
 		}
 	}
 	}

+ 2 - 9
lib/relay/client/methods.go

@@ -12,6 +12,7 @@ import (
 	"time"
 	"time"
 
 
 	"github.com/syncthing/syncthing/lib/dialer"
 	"github.com/syncthing/syncthing/lib/dialer"
+	"github.com/syncthing/syncthing/lib/osutil"
 	syncthingprotocol "github.com/syncthing/syncthing/lib/protocol"
 	syncthingprotocol "github.com/syncthing/syncthing/lib/protocol"
 	"github.com/syncthing/syncthing/lib/relay/protocol"
 	"github.com/syncthing/syncthing/lib/relay/protocol"
 )
 )
@@ -66,7 +67,7 @@ func GetInvitationFromRelay(ctx context.Context, uri *url.URL, id syncthingproto
 		l.Debugln("Received invitation", msg, "via", conn.LocalAddr())
 		l.Debugln("Received invitation", msg, "via", conn.LocalAddr())
 		ip := net.IP(msg.Address)
 		ip := net.IP(msg.Address)
 		if len(ip) == 0 || ip.IsUnspecified() {
 		if len(ip) == 0 || ip.IsUnspecified() {
-			msg.Address = remoteIPBytes(conn)
+			msg.Address, _ = osutil.IPFromAddr(conn.RemoteAddr())
 		}
 		}
 		return msg, nil
 		return msg, nil
 	default:
 	default:
@@ -163,11 +164,3 @@ func configForCerts(certs []tls.Certificate) *tls.Config {
 		},
 		},
 	}
 	}
 }
 }
-
-func remoteIPBytes(conn net.Conn) []byte {
-	addr := conn.RemoteAddr().String()
-	if host, _, err := net.SplitHostPort(addr); err == nil {
-		addr = host
-	}
-	return net.ParseIP(addr)[:]
-}

+ 2 - 1
lib/relay/client/static.go

@@ -12,6 +12,7 @@ import (
 	"time"
 	"time"
 
 
 	"github.com/syncthing/syncthing/lib/dialer"
 	"github.com/syncthing/syncthing/lib/dialer"
+	"github.com/syncthing/syncthing/lib/osutil"
 	syncthingprotocol "github.com/syncthing/syncthing/lib/protocol"
 	syncthingprotocol "github.com/syncthing/syncthing/lib/protocol"
 	"github.com/syncthing/syncthing/lib/relay/protocol"
 	"github.com/syncthing/syncthing/lib/relay/protocol"
 )
 )
@@ -87,7 +88,7 @@ func (c *staticClient) serve(ctx context.Context) error {
 			case protocol.SessionInvitation:
 			case protocol.SessionInvitation:
 				ip := net.IP(msg.Address)
 				ip := net.IP(msg.Address)
 				if len(ip) == 0 || ip.IsUnspecified() {
 				if len(ip) == 0 || ip.IsUnspecified() {
-					msg.Address = remoteIPBytes(c.conn)
+					msg.Address, _ = osutil.IPFromAddr(c.conn.RemoteAddr())
 				}
 				}
 				select {
 				select {
 				case c.invitations <- msg:
 				case c.invitations <- msg:

+ 2 - 6
lib/upnp/upnp.go

@@ -50,6 +50,7 @@ import (
 	"github.com/syncthing/syncthing/lib/build"
 	"github.com/syncthing/syncthing/lib/build"
 	"github.com/syncthing/syncthing/lib/dialer"
 	"github.com/syncthing/syncthing/lib/dialer"
 	"github.com/syncthing/syncthing/lib/nat"
 	"github.com/syncthing/syncthing/lib/nat"
+	"github.com/syncthing/syncthing/lib/osutil"
 )
 )
 
 
 func init() {
 func init() {
@@ -303,12 +304,7 @@ func localIP(ctx context.Context, url *url.URL) (net.IP, error) {
 	}
 	}
 	defer conn.Close()
 	defer conn.Close()
 
 
-	localIPAddress, _, err := net.SplitHostPort(conn.LocalAddr().String())
-	if err != nil {
-		return nil, err
-	}
-
-	return net.ParseIP(localIPAddress), nil
+	return osutil.IPFromAddr(conn.LocalAddr())
 }
 }
 
 
 func getChildDevices(d upnpDevice, deviceType string) []upnpDevice {
 func getChildDevices(d upnpDevice, deviceType string) []upnpDevice {