瀏覽代碼

lib/connections: Fix and optimize registry (#7996)

Registry.Get used a full sort to get the minimum of a list, and the sort
was broken because util.AddressUnspecifiedLess assumed it could find out
whether an address is IPv4 or IPv6 from its Network method. However,
net.(TCP|UDP)Addr.Network always returns "tcp"/"udp".
greatroar 4 年之前
父節點
當前提交
7c292cc812

+ 2 - 1
lib/connections/quic_dial.go

@@ -58,7 +58,8 @@ func (d *quicDialer) Dial(ctx context.Context, _ protocol.DeviceID, uri *url.URL
 	// Given we always pass the connection to quic, it assumes it's a remote connection it never closes it,
 	// So our wrapper around it needs to close it, but it only needs to close it if it's not the listening connection.
 	var createdConn net.PacketConn
-	if listenConn := registry.Get(uri.Scheme, packetConnLess); listenConn != nil {
+	listenConn := registry.Get(uri.Scheme, packetConnUnspecified)
+	if listenConn != nil {
 		conn = listenConn.(net.PacketConn)
 	} else {
 		if packetConn, err := net.ListenPacket("udp", ":0"); err != nil {

+ 6 - 4
lib/connections/quic_misc.go

@@ -15,7 +15,6 @@ import (
 	"net/url"
 
 	"github.com/lucas-clemente/quic-go"
-	"github.com/syncthing/syncthing/lib/util"
 )
 
 var (
@@ -63,7 +62,10 @@ func (q *quicTlsConn) ConnectionState() tls.ConnectionState {
 	return q.Session.ConnectionState().TLS.ConnectionState
 }
 
-// Sort available packet connections by ip address, preferring unspecified local address.
-func packetConnLess(i interface{}, j interface{}) bool {
-	return util.AddressUnspecifiedLess(i.(net.PacketConn).LocalAddr(), j.(net.PacketConn).LocalAddr())
+func packetConnUnspecified(conn interface{}) bool {
+	// Since QUIC connections are wrapped, we can't do a simple typecheck
+	// on *net.UDPAddr here.
+	addr := conn.(net.PacketConn).LocalAddr()
+	host, _, err := net.SplitHostPort(addr.String())
+	return err == nil && net.ParseIP(host).IsUnspecified()
 }

+ 27 - 17
lib/connections/registry/registry.go

@@ -10,7 +10,6 @@
 package registry
 
 import (
-	"sort"
 	"strings"
 
 	"github.com/syncthing/syncthing/lib/sync"
@@ -46,7 +45,7 @@ func (r *Registry) Unregister(scheme string, item interface{}) {
 	candidates := r.available[scheme]
 	for i, existingItem := range candidates {
 		if existingItem == item {
-			copy(candidates[i:], candidates[i+1:])
+			candidates[i] = candidates[len(candidates)-1]
 			candidates[len(candidates)-1] = nil
 			r.available[scheme] = candidates[:len(candidates)-1]
 			break
@@ -54,26 +53,37 @@ func (r *Registry) Unregister(scheme string, item interface{}) {
 	}
 }
 
-func (r *Registry) Get(scheme string, less func(i, j interface{}) bool) interface{} {
+// Get returns an item for a schema compatible with the given scheme.
+// If any item satisfies preferred, that has precedence over other items.
+func (r *Registry) Get(scheme string, preferred func(interface{}) bool) interface{} {
 	r.mut.Lock()
 	defer r.mut.Unlock()
 
-	candidates := make([]interface{}, 0)
+	var (
+		best       interface{}
+		bestPref   bool
+		bestScheme string
+	)
 	for availableScheme, items := range r.available {
 		// quic:// should be considered ok for both quic4:// and quic6://
-		if strings.HasPrefix(scheme, availableScheme) {
-			candidates = append(candidates, items...)
+		if !strings.HasPrefix(scheme, availableScheme) {
+			continue
+		}
+		for _, item := range items {
+			better := best == nil
+			pref := preferred(item)
+			if !better {
+				// In case of a tie, prefer "quic" to "quic[46]" etc.
+				better = pref &&
+					(!bestPref || len(availableScheme) < len(bestScheme))
+			}
+			if !better {
+				continue
+			}
+			best, bestPref, bestScheme = item, pref, availableScheme
 		}
 	}
-
-	if len(candidates) == 0 {
-		return nil
-	}
-
-	sort.Slice(candidates, func(i, j int) bool {
-		return less(candidates[i], candidates[j])
-	})
-	return candidates[0]
+	return best
 }
 
 func Register(scheme string, item interface{}) {
@@ -84,6 +94,6 @@ func Unregister(scheme string, item interface{}) {
 	Default.Unregister(scheme, item)
 }
 
-func Get(scheme string, less func(i, j interface{}) bool) interface{} {
-	return Default.Get(scheme, less)
+func Get(scheme string, preferred func(interface{}) bool) interface{} {
+	return Default.Get(scheme, preferred)
 }

+ 38 - 13
lib/connections/registry/registry_test.go

@@ -7,13 +7,18 @@
 package registry
 
 import (
+	"net"
 	"testing"
 )
 
 func TestRegistry(t *testing.T) {
 	r := New()
 
-	if res := r.Get("int", intLess); res != nil {
+	want := func(i int) func(interface{}) bool {
+		return func(x interface{}) bool { return x.(int) == i }
+	}
+
+	if res := r.Get("int", want(1)); res != nil {
 		t.Error("unexpected")
 	}
 
@@ -24,30 +29,28 @@ func TestRegistry(t *testing.T) {
 	r.Register("int6", 6)
 	r.Register("int6", 66)
 
-	if res := r.Get("int", intLess).(int); res != 1 {
+	if res := r.Get("int", want(1)).(int); res != 1 {
 		t.Error("unexpected", res)
 	}
 
 	// int is prefix of int4, so returns 1
-	if res := r.Get("int4", intLess).(int); res != 1 {
+	if res := r.Get("int4", want(1)).(int); res != 1 {
 		t.Error("unexpected", res)
 	}
 
 	r.Unregister("int", 1)
 
-	// Check that falls through to 11
-	if res := r.Get("int", intLess).(int); res != 11 {
+	if res := r.Get("int", want(1)).(int); res == 1 {
 		t.Error("unexpected", res)
 	}
 
-	// 6 is smaller than 11 available in int.
-	if res := r.Get("int6", intLess).(int); res != 6 {
+	if res := r.Get("int6", want(6)).(int); res != 6 {
 		t.Error("unexpected", res)
 	}
 
 	// Unregister 11, int should be impossible to find
 	r.Unregister("int", 11)
-	if res := r.Get("int", intLess); res != nil {
+	if res := r.Get("int", want(11)); res != nil {
 		t.Error("unexpected")
 	}
 
@@ -59,13 +62,35 @@ func TestRegistry(t *testing.T) {
 	r.Register("int", 1)
 	r.Unregister("int", 1)
 
-	if res := r.Get("int4", intLess).(int); res != 1 {
+	if res := r.Get("int4", want(1)).(int); res != 1 {
 		t.Error("unexpected", res)
 	}
 }
 
-func intLess(i, j interface{}) bool {
-	iInt := i.(int)
-	jInt := j.(int)
-	return iInt < jInt
+func TestShortSchemeFirst(t *testing.T) {
+	r := New()
+	r.Register("foo", 0)
+	r.Register("foobar", 1)
+
+	// If we don't care about the value, we should get the one with "foo".
+	res := r.Get("foo", func(interface{}) bool { return false })
+	if res != 0 {
+		t.Error("unexpected", res)
+	}
+}
+
+func BenchmarkGet(b *testing.B) {
+	r := New()
+	for _, addr := range []string{"192.168.1.1", "172.1.1.1", "10.1.1.1"} {
+		r.Register("tcp", &net.TCPAddr{IP: net.ParseIP(addr)})
+	}
+
+	b.ReportAllocs()
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		r.Get("tcp", func(x interface{}) bool {
+			return x.(*net.TCPAddr).IP.IsUnspecified()
+		})
+	}
 }

+ 0 - 6
lib/dialer/internal.go

@@ -13,7 +13,6 @@ import (
 	"os"
 	"time"
 
-	"github.com/syncthing/syncthing/lib/util"
 	"golang.org/x/net/proxy"
 )
 
@@ -61,11 +60,6 @@ func socksDialerFunction(u *url.URL, forward proxy.Dialer) (proxy.Dialer, error)
 	return proxy.SOCKS5("tcp", u.Host, auth, forward)
 }
 
-// Sort available addresses, preferring unspecified address.
-func tcpAddrLess(i interface{}, j interface{}) bool {
-	return util.AddressUnspecifiedLess(i.(*net.TCPAddr), j.(*net.TCPAddr))
-}
-
 // dialerConn is needed because proxy dialed connections have RemoteAddr() pointing at the proxy,
 // which then screws up various things such as IsLAN checks, and "let's populate the relay invitation address from
 // existing connection" shenanigans.

+ 3 - 1
lib/dialer/public.go

@@ -110,7 +110,9 @@ func DialContextReusePort(ctx context.Context, network, addr string) (net.Conn,
 		return DialContext(ctx, network, addr)
 	}
 
-	localAddrInterface := registry.Get(network, tcpAddrLess)
+	localAddrInterface := registry.Get(network, func(addr interface{}) bool {
+		return addr.(*net.TCPAddr).IP.IsUnspecified()
+	})
 	if localAddrInterface == nil {
 		// Nothing listening, nothing to reuse.
 		return DialContext(ctx, network, addr)

+ 0 - 20
lib/util/utils.go

@@ -9,7 +9,6 @@ package util
 import (
 	"context"
 	"fmt"
-	"net"
 	"net/url"
 	"reflect"
 	"strconv"
@@ -231,25 +230,6 @@ func Address(network, host string) string {
 	return u.String()
 }
 
-// AddressUnspecifiedLess is a comparator function preferring least specific network address (most widely listening,
-// namely preferring 0.0.0.0 over some IP), if both IPs are equal, it prefers the less restrictive network (prefers tcp
-// over tcp4)
-func AddressUnspecifiedLess(a, b net.Addr) bool {
-	aIsUnspecified := false
-	bIsUnspecified := false
-	if host, _, err := net.SplitHostPort(a.String()); err == nil {
-		aIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified()
-	}
-	if host, _, err := net.SplitHostPort(b.String()); err == nil {
-		bIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified()
-	}
-
-	if aIsUnspecified == bIsUnspecified {
-		return len(a.Network()) < len(b.Network())
-	}
-	return aIsUnspecified
-}
-
 func CallWithContext(ctx context.Context, fn func() error) error {
 	var err error
 	done := make(chan struct{})

+ 0 - 44
lib/util/utils_test.go

@@ -225,50 +225,6 @@ func TestCopyMatching(t *testing.T) {
 	}
 }
 
-type mockedAddr struct {
-	network string
-	addr    string
-}
-
-func (a mockedAddr) Network() string {
-	return a.network
-}
-
-func (a mockedAddr) String() string {
-	return a.addr
-}
-
-func TestInspecifiedAddressLess(t *testing.T) {
-	cases := []struct {
-		netA  string
-		addrA string
-		netB  string
-		addrB string
-	}{
-		// B is assumed the winner.
-		{"tcp", "127.0.0.1:1234", "tcp", ":1235"},
-		{"tcp", "127.0.0.1:1234", "tcp", "0.0.0.0:1235"},
-		{"tcp4", "0.0.0.0:1234", "tcp", "0.0.0.0:1235"}, // tcp4 on the first one
-	}
-
-	for i, testCase := range cases {
-		addrs := []mockedAddr{
-			{testCase.netA, testCase.addrA},
-			{testCase.netB, testCase.addrB},
-		}
-
-		if AddressUnspecifiedLess(addrs[0], addrs[1]) {
-			t.Error(i, "unexpected")
-		}
-		if !AddressUnspecifiedLess(addrs[1], addrs[0]) {
-			t.Error(i, "unexpected")
-		}
-		if AddressUnspecifiedLess(addrs[0], addrs[0]) || AddressUnspecifiedLess(addrs[1], addrs[1]) {
-			t.Error(i, "unexpected")
-		}
-	}
-}
-
 func TestFillNil(t *testing.T) {
 	type A struct {
 		Slice []int