Browse Source

lib/connections: Add connection benchmarks, allow binding to port zero addresses (#7648)

* Add connbench

* Refactor port fixup

* More cleanup

* touch for build

Co-authored-by: Jakob Borg <[email protected]>
Audrius Butkevicius 4 years ago
parent
commit
eb178caf3a

+ 207 - 0
lib/connections/connections_test.go

@@ -8,12 +8,26 @@ package connections
 
 import (
 	"context"
+	"crypto/tls"
 	"errors"
+	"fmt"
+	"io/ioutil"
+	"math/rand"
+	"net"
 	"net/url"
+	"os"
+	"strings"
 	"testing"
+	"time"
+
+	"github.com/thejerf/suture/v4"
 
 	"github.com/syncthing/syncthing/lib/config"
+	"github.com/syncthing/syncthing/lib/events"
+	"github.com/syncthing/syncthing/lib/nat"
 	"github.com/syncthing/syncthing/lib/protocol"
+	"github.com/syncthing/syncthing/lib/sync"
+	"github.com/syncthing/syncthing/lib/tlsutil"
 )
 
 func TestFixupPort(t *testing.T) {
@@ -216,3 +230,196 @@ func TestConnectionStatus(t *testing.T) {
 
 	check(nil, nil)
 }
+
+func BenchmarkConnections(pb *testing.B) {
+	addrs := []string{
+		"tcp://127.0.0.1:0",
+		"quic://127.0.0.1:0",
+		"relay://127.0.0.1:22067",
+	}
+	sizes := []int{
+		1 << 10,
+		1 << 15,
+		1 << 20,
+		1 << 22,
+	}
+	haveRelay := false
+	// Check if we have a relay running locally
+	conn, err := net.DialTimeout("tcp", "127.0.0.1:22067", 100*time.Millisecond)
+	if err == nil {
+		haveRelay = true
+		_ = conn.Close()
+	}
+	for _, addr := range addrs {
+		for _, sz := range sizes {
+			for _, direction := range []string{"cs", "sc"} {
+				proto := strings.SplitN(addr, ":", 2)[0]
+				pb.Run(fmt.Sprintf("%s_%d_%s", proto, sz, direction), func(b *testing.B) {
+					if proto == "relay" && !haveRelay {
+						b.Skip("could not connect to relay")
+					}
+					withConnectionPair(b, addr, func(client, server internalConn) {
+						if direction == "sc" {
+							server, client = client, server
+						}
+						data := make([]byte, sz)
+						if _, err := rand.Read(data); err != nil {
+							b.Fatal(err)
+						}
+
+						total := 0
+						wg := sync.NewWaitGroup()
+						b.ResetTimer()
+						for i := 0; i < b.N; i++ {
+							wg.Add(2)
+							go func() {
+								if err := sendMsg(client, data); err != nil {
+									b.Fatal(err)
+								}
+								wg.Done()
+							}()
+							go func() {
+								if err := recvMsg(server, data); err != nil {
+									b.Fatal(err)
+								}
+								total += sz
+								wg.Done()
+							}()
+							wg.Wait()
+						}
+						b.ReportAllocs()
+						b.SetBytes(int64(total / b.N))
+					})
+				})
+			}
+
+		}
+	}
+}
+
+func sendMsg(c internalConn, buf []byte) error {
+	n, err := c.Write(buf)
+	if n != len(buf) || err != nil {
+		return err
+	}
+	return nil
+}
+
+func recvMsg(c internalConn, buf []byte) error {
+	for read := 0; read != len(buf); {
+		n, err := c.Read(buf)
+		read += n
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func withConnectionPair(b *testing.B, connUri string, h func(client, server internalConn)) {
+	// Root of the service tree.
+	supervisor := suture.New("main", suture.Spec{
+		PassThroughPanics: true,
+	})
+
+	cert := mustGetCert(b)
+	deviceId := protocol.NewDeviceID(cert.Certificate[0])
+	tlsCfg := tlsutil.SecureDefaultTLS13()
+	tlsCfg.Certificates = []tls.Certificate{cert}
+	tlsCfg.NextProtos = []string{"bench"}
+	tlsCfg.ClientAuth = tls.RequestClientCert
+	tlsCfg.SessionTicketsDisabled = true
+	tlsCfg.InsecureSkipVerify = true
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	supervisor.ServeBackground(ctx)
+
+	cfg := config.Configuration{
+		Options: config.OptionsConfiguration{
+			RelaysEnabled: true,
+		},
+	}
+	wcfg := config.Wrap("", cfg, deviceId, events.NoopLogger)
+	uri, err := url.Parse(connUri)
+	if err != nil {
+		b.Fatal(err)
+	}
+	lf, err := getListenerFactory(cfg, uri)
+	if err != nil {
+		b.Fatal(err)
+	}
+	natSvc := nat.NewService(deviceId, wcfg)
+	conns := make(chan internalConn, 1)
+	listenSvc := lf.New(uri, wcfg, tlsCfg, conns, natSvc)
+	supervisor.Add(listenSvc)
+
+	var addr *url.URL
+	for {
+		addrs := listenSvc.LANAddresses()
+		if len(addrs) > 0 {
+			if !strings.HasSuffix(addrs[0].Host, ":0") {
+				addr = addrs[0]
+				break
+			}
+		}
+		time.Sleep(time.Millisecond)
+	}
+
+	df, err := getDialerFactory(cfg, addr)
+	if err != nil {
+		b.Fatal(err)
+	}
+	dialer := df.New(cfg.Options, tlsCfg)
+
+	// Relays might take some time to register the device, so dial multiple times
+	clientConn, err := dialer.Dial(ctx, deviceId, addr)
+	if err != nil {
+		for i := 0; i < 10 && err != nil; i++ {
+			clientConn, err = dialer.Dial(ctx, deviceId, addr)
+			time.Sleep(100 * time.Millisecond)
+		}
+		if err != nil {
+			b.Fatal(err)
+		}
+	}
+
+	data := []byte("hello")
+
+	// Quic does not start a stream until some data is sent through, so send something for the AcceptStream
+	// to fire on the other side.
+	if err := sendMsg(clientConn, data); err != nil {
+		b.Fatal(err)
+	}
+
+	serverConn := <-conns
+
+	if err := recvMsg(serverConn, data); err != nil {
+		b.Fatal(err)
+	}
+
+	h(clientConn, serverConn)
+
+	_ = clientConn.Close()
+	_ = serverConn.Close()
+}
+
+func mustGetCert(b *testing.B) tls.Certificate {
+	f1, err := ioutil.TempFile("", "")
+	if err != nil {
+		b.Fatal(err)
+	}
+	f1.Close()
+	f2, err := ioutil.TempFile("", "")
+	if err != nil {
+		b.Fatal(err)
+	}
+	f2.Close()
+	cert, err := tlsutil.NewCertificate(f1.Name(), f2.Name(), "bench", 10)
+	if err != nil {
+		b.Fatal(err)
+	}
+	_ = os.Remove(f1.Name())
+	_ = os.Remove(f2.Name())
+	return cert
+}

+ 23 - 10
lib/connections/quic_listen.go

@@ -48,6 +48,7 @@ type quicListener struct {
 	factory listenerFactory
 
 	address *url.URL
+	laddr   net.Addr
 	mut     sync.Mutex
 }
 
@@ -87,10 +88,8 @@ func (t *quicListener) serve(ctx context.Context) error {
 		l.Infoln("Listen (BEP/quic):", err)
 		return err
 	}
-	defer func() { _ = packetConn.Close() }()
 
 	svc, conn := stun.New(t.cfg, t, packetConn)
-	defer func() { _ = conn.Close() }()
 	wrapped := &stunConnQUICWrapper{
 		PacketConn: conn,
 		underlying: packetConn.(*net.UDPConn),
@@ -99,7 +98,6 @@ func (t *quicListener) serve(ctx context.Context) error {
 	go svc.Serve(ctx)
 
 	registry.Register(t.uri.Scheme, wrapped)
-	defer registry.Unregister(t.uri.Scheme, wrapped)
 
 	listener, err := quic.Listen(wrapped, t.tlsCfg, quicConfig)
 	if err != nil {
@@ -107,11 +105,23 @@ func (t *quicListener) serve(ctx context.Context) error {
 		return err
 	}
 	t.notifyAddressesChanged(t)
-	defer listener.Close()
-	defer t.clearAddresses(t)
 
 	l.Infof("QUIC listener (%v) starting", packetConn.LocalAddr())
-	defer l.Infof("QUIC listener (%v) shutting down", packetConn.LocalAddr())
+	t.mut.Lock()
+	t.laddr = packetConn.LocalAddr()
+	t.mut.Unlock()
+
+	defer func() {
+		l.Infof("QUIC listener (%v) shutting down", packetConn.LocalAddr())
+		t.mut.Lock()
+		t.laddr = nil
+		t.mut.Unlock()
+		registry.Unregister(t.uri.Scheme, wrapped)
+		t.clearAddresses(t)
+		_ = listener.Close()
+		_ = conn.Close()
+		_ = packetConn.Close()
+	}()
 
 	acceptFailures := 0
 	const maxAcceptFailures = 10
@@ -164,8 +174,8 @@ func (t *quicListener) URI() *url.URL {
 }
 
 func (t *quicListener) WANAddresses() []*url.URL {
-	uris := []*url.URL{t.uri}
 	t.mut.Lock()
+	uris := []*url.URL{maybeReplacePort(t.uri, t.laddr)}
 	if t.address != nil {
 		uris = append(uris, t.address)
 	}
@@ -174,9 +184,12 @@ func (t *quicListener) WANAddresses() []*url.URL {
 }
 
 func (t *quicListener) LANAddresses() []*url.URL {
-	addrs := []*url.URL{t.uri}
-	network := strings.ReplaceAll(t.uri.Scheme, "quic", "udp")
-	addrs = append(addrs, getURLsForAllAdaptersIfUnspecified(network, t.uri)...)
+	t.mut.Lock()
+	uri := maybeReplacePort(t.uri, t.laddr)
+	t.mut.Unlock()
+	addrs := []*url.URL{uri}
+	network := strings.ReplaceAll(uri.Scheme, "quic", "udp")
+	addrs = append(addrs, getURLsForAllAdaptersIfUnspecified(network, uri)...)
 	return addrs
 }
 

+ 26 - 10
lib/connections/tcp_listen.go

@@ -40,6 +40,7 @@ type tcpListener struct {
 
 	natService *nat.Service
 	mapping    *nat.Mapping
+	laddr      net.Addr
 
 	mut sync.RWMutex
 }
@@ -60,26 +61,36 @@ func (t *tcpListener) serve(ctx context.Context) error {
 		l.Infoln("Listen (BEP/tcp):", err)
 		return err
 	}
+
+	// We might bind to :0, so use the port we've been given.
+	tcaddr = listener.Addr().(*net.TCPAddr)
+
 	t.notifyAddressesChanged(t)
 	registry.Register(t.uri.Scheme, tcaddr)
 
-	defer listener.Close()
-	defer t.clearAddresses(t)
-	defer registry.Unregister(t.uri.Scheme, tcaddr)
-
-	l.Infof("TCP listener (%v) starting", listener.Addr())
-	defer l.Infof("TCP listener (%v) shutting down", listener.Addr())
+	l.Infof("TCP listener (%v) starting", tcaddr)
 
 	mapping := t.natService.NewMapping(nat.TCP, tcaddr.IP, tcaddr.Port)
 	mapping.OnChanged(func(_ *nat.Mapping, _, _ []nat.Address) {
 		t.notifyAddressesChanged(t)
 	})
-	defer t.natService.RemoveMapping(mapping)
 
 	t.mut.Lock()
 	t.mapping = mapping
+	t.laddr = tcaddr
 	t.mut.Unlock()
 
+	defer func() {
+		l.Infof("TCP listener (%v) shutting down", tcaddr)
+		t.natService.RemoveMapping(mapping)
+		t.mut.Lock()
+		t.laddr = nil
+		t.mut.Unlock()
+		registry.Unregister(t.uri.Scheme, tcaddr)
+		t.clearAddresses(t)
+		_ = listener.Close()
+	}()
+
 	acceptFailures := 0
 	const maxAcceptFailures = 10
 
@@ -146,8 +157,10 @@ func (t *tcpListener) URI() *url.URL {
 }
 
 func (t *tcpListener) WANAddresses() []*url.URL {
-	uris := []*url.URL{t.uri}
 	t.mut.RLock()
+	uris := []*url.URL{
+		maybeReplacePort(t.uri, t.laddr),
+	}
 	if t.mapping != nil {
 		addrs := t.mapping.ExternalAddresses()
 		for _, addr := range addrs {
@@ -179,8 +192,11 @@ func (t *tcpListener) WANAddresses() []*url.URL {
 }
 
 func (t *tcpListener) LANAddresses() []*url.URL {
-	addrs := []*url.URL{t.uri}
-	addrs = append(addrs, getURLsForAllAdaptersIfUnspecified(t.uri.Scheme, t.uri)...)
+	t.mut.RLock()
+	uri := maybeReplacePort(t.uri, t.laddr)
+	t.mut.RUnlock()
+	addrs := []*url.URL{uri}
+	addrs = append(addrs, getURLsForAllAdaptersIfUnspecified(uri.Scheme, uri)...)
 	return addrs
 }
 

+ 27 - 0
lib/connections/util.go

@@ -117,3 +117,30 @@ func isV4Local(ip net.IP) bool {
 	}
 	return false
 }
+
+func maybeReplacePort(uri *url.URL, laddr net.Addr) *url.URL {
+	if laddr == nil {
+		return uri
+	}
+
+	host, portStr, err := net.SplitHostPort(uri.Host)
+	if err != nil {
+		return uri
+	}
+	port, err := strconv.Atoi(portStr)
+	if err != nil {
+		return uri
+	}
+	if port != 0 {
+		return uri
+	}
+
+	_, lportStr, err := net.SplitHostPort(laddr.String())
+	if err != nil {
+		return uri
+	}
+
+	uriCopy := *uri
+	uriCopy.Host = net.JoinHostPort(host, lportStr)
+	return &uriCopy
+}