瀏覽代碼

lib/connections: Try TCP punchthrough (fixes #4259) (#5753)

Audrius Butkevicius 5 年之前
父節點
當前提交
f619a7f4cc

+ 40 - 26
cmd/stdiscosrv/apisrv.go

@@ -132,11 +132,19 @@ func (s *apiSrv) handler(w http.ResponseWriter, req *http.Request) {
 		log.Println(reqID, req.Method, req.URL)
 	}
 
-	var remoteIP net.IP
+	remoteAddr := &net.TCPAddr{
+		IP:   nil,
+		Port: -1,
+	}
+
 	if s.useHTTP {
-		remoteIP = net.ParseIP(req.Header.Get("X-Forwarded-For"))
+		remoteAddr.IP = net.ParseIP(req.Header.Get("X-Forwarded-For"))
+		if parsedPort, err := strconv.ParseInt(req.Header.Get("X-Forwarded-Port"), 10, 0); err == nil {
+			remoteAddr.Port = int(parsedPort)
+		}
 	} else {
-		addr, err := net.ResolveTCPAddr("tcp", req.RemoteAddr)
+		var err error
+		remoteAddr, err = net.ResolveTCPAddr("tcp", req.RemoteAddr)
 		if err != nil {
 			log.Println("remoteAddr:", err)
 			lw.Header().Set("Retry-After", errorRetryAfterString())
@@ -144,14 +152,13 @@ func (s *apiSrv) handler(w http.ResponseWriter, req *http.Request) {
 			apiRequestsTotal.WithLabelValues("no_remote_addr").Inc()
 			return
 		}
-		remoteIP = addr.IP
 	}
 
 	switch req.Method {
 	case "GET":
 		s.handleGET(ctx, lw, req)
 	case "POST":
-		s.handlePOST(ctx, remoteIP, lw, req)
+		s.handlePOST(ctx, remoteAddr, lw, req)
 	default:
 		http.Error(lw, "Method Not Allowed", http.StatusMethodNotAllowed)
 	}
@@ -217,7 +224,7 @@ func (s *apiSrv) handleGET(ctx context.Context, w http.ResponseWriter, req *http
 	w.Write(bs)
 }
 
-func (s *apiSrv) handlePOST(ctx context.Context, remoteIP net.IP, w http.ResponseWriter, req *http.Request) {
+func (s *apiSrv) handlePOST(ctx context.Context, remoteAddr *net.TCPAddr, w http.ResponseWriter, req *http.Request) {
 	reqID := ctx.Value(idKey).(requestID)
 
 	rawCert := certificateBytes(req)
@@ -244,7 +251,7 @@ func (s *apiSrv) handlePOST(ctx context.Context, remoteIP net.IP, w http.Respons
 
 	deviceID := protocol.NewDeviceID(rawCert)
 
-	addresses := fixupAddresses(remoteIP, ann.Addresses)
+	addresses := fixupAddresses(remoteAddr, ann.Addresses)
 	if len(addresses) == 0 {
 		announceRequestsTotal.WithLabelValues("bad_request").Inc()
 		w.Header().Set("Retry-After", errorRetryAfterString())
@@ -252,7 +259,7 @@ func (s *apiSrv) handlePOST(ctx context.Context, remoteIP net.IP, w http.Respons
 		return
 	}
 
-	if err := s.handleAnnounce(remoteIP, deviceID, addresses); err != nil {
+	if err := s.handleAnnounce(deviceID, addresses); err != nil {
 		announceRequestsTotal.WithLabelValues("internal_error").Inc()
 		w.Header().Set("Retry-After", errorRetryAfterString())
 		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
@@ -269,7 +276,7 @@ func (s *apiSrv) Stop() {
 	s.listener.Close()
 }
 
-func (s *apiSrv) handleAnnounce(remote net.IP, deviceID protocol.DeviceID, addresses []string) error {
+func (s *apiSrv) handleAnnounce(deviceID protocol.DeviceID, addresses []string) error {
 	key := deviceID.String()
 	now := time.Now()
 	expire := now.Add(addressExpiryTime).UnixNano()
@@ -364,7 +371,7 @@ func certificateBytes(req *http.Request) []byte {
 
 // fixupAddresses checks the list of addresses, removing invalid ones and
 // replacing unspecified IPs with the given remote IP.
-func fixupAddresses(remote net.IP, addresses []string) []string {
+func fixupAddresses(remote *net.TCPAddr, addresses []string) []string {
 	fixed := make([]string, 0, len(addresses))
 	for _, annAddr := range addresses {
 		uri, err := url.Parse(annAddr)
@@ -384,27 +391,34 @@ func fixupAddresses(remote net.IP, addresses []string) []string {
 			continue
 		}
 
-		if host == "" || ip.IsUnspecified() {
-			// Replace the unspecified IP with the request source.
+		if remote != nil {
+			if host == "" || ip.IsUnspecified() {
+				// Replace the unspecified IP with the request source.
 
-			// ... unless the request source is the loopback address or
-			// multicast/unspecified (can't happen, really).
-			if remote.IsLoopback() || remote.IsMulticast() || remote.IsUnspecified() {
-				continue
-			}
+				// ... unless the request source is the loopback address or
+				// multicast/unspecified (can't happen, really).
+				if remote.IP.IsLoopback() || remote.IP.IsMulticast() || remote.IP.IsUnspecified() {
+					continue
+				}
 
-			// Do not use IPv6 remote address if requested scheme is ...4
-			// (i.e., tcp4, etc.)
-			if strings.HasSuffix(uri.Scheme, "4") && remote.To4() == nil {
-				continue
-			}
+				// Do not use IPv6 remote address if requested scheme is ...4
+				// (i.e., tcp4, etc.)
+				if strings.HasSuffix(uri.Scheme, "4") && remote.IP.To4() == nil {
+					continue
+				}
 
-			// Do not use IPv4 remote address if requested scheme is ...6
-			if strings.HasSuffix(uri.Scheme, "6") && remote.To4() != nil {
-				continue
+				// Do not use IPv4 remote address if requested scheme is ...6
+				if strings.HasSuffix(uri.Scheme, "6") && remote.IP.To4() != nil {
+					continue
+				}
+
+				host = remote.IP.String()
 			}
 
-			host = remote.String()
+			// If zero port was specified, use remote port.
+			if port == "0" && remote.Port > 0 {
+				port = fmt.Sprintf("%d", remote.Port)
+			}
 		}
 
 		uri.Host = net.JoinHostPort(host, port)

+ 32 - 9
cmd/stdiscosrv/apisrv_test.go

@@ -14,7 +14,7 @@ import (
 
 func TestFixupAddresses(t *testing.T) {
 	cases := []struct {
-		remote net.IP
+		remote *net.TCPAddr
 		in     []string
 		out    []string
 	}{
@@ -22,37 +22,53 @@ func TestFixupAddresses(t *testing.T) {
 			in:  []string{"tcp://1.2.3.4:22000"},
 			out: []string{"tcp://1.2.3.4:22000"},
 		}, { // unspecified replaced by remote
-			remote: net.ParseIP("1.2.3.4"),
+			remote: addr("1.2.3.4", 22000),
 			in:     []string{"tcp://:22000", "tcp://192.0.2.42:22000"},
 			out:    []string{"tcp://1.2.3.4:22000", "tcp://192.0.2.42:22000"},
 		}, { // unspecified not used as replacement
-			remote: net.ParseIP("0.0.0.0"),
+			remote: addr("0.0.0.0", 22000),
 			in:     []string{"tcp://:22000", "tcp://192.0.2.42:22000"},
 			out:    []string{"tcp://192.0.2.42:22000"},
 		}, { // unspecified not used as replacement
-			remote: net.ParseIP("::"),
+			remote: addr("::", 22000),
 			in:     []string{"tcp://:22000", "tcp://192.0.2.42:22000"},
 			out:    []string{"tcp://192.0.2.42:22000"},
 		}, { // localhost not used as replacement
-			remote: net.ParseIP("127.0.0.1"),
+			remote: addr("127.0.0.1", 22000),
 			in:     []string{"tcp://:22000", "tcp://192.0.2.42:22000"},
 			out:    []string{"tcp://192.0.2.42:22000"},
 		}, { // localhost not used as replacement
-			remote: net.ParseIP("::1"),
+			remote: addr("::1", 22000),
 			in:     []string{"tcp://:22000", "tcp://192.0.2.42:22000"},
 			out:    []string{"tcp://192.0.2.42:22000"},
 		}, { // multicast not used as replacement
-			remote: net.ParseIP("224.0.0.1"),
+			remote: addr("224.0.0.1", 22000),
 			in:     []string{"tcp://:22000", "tcp://192.0.2.42:22000"},
 			out:    []string{"tcp://192.0.2.42:22000"},
 		}, { // multicast not used as replacement
-			remote: net.ParseIP("ff80::42"),
+			remote: addr("ff80::42", 22000),
 			in:     []string{"tcp://:22000", "tcp://192.0.2.42:22000"},
 			out:    []string{"tcp://192.0.2.42:22000"},
 		}, { // explicitly announced weirdness is also filtered
-			remote: net.ParseIP("192.0.2.42"),
+			remote: addr("192.0.2.42", 22000),
 			in:     []string{"tcp://:22000", "tcp://127.1.2.3:22000", "tcp://[::1]:22000", "tcp://[ff80::42]:22000"},
 			out:    []string{"tcp://192.0.2.42:22000"},
+		}, { // port remapping
+			remote: addr("123.123.123.123", 9000),
+			in:     []string{"tcp://0.0.0.0:0"},
+			out:    []string{"tcp://123.123.123.123:9000"},
+		}, { // unspecified port remapping
+			remote: addr("123.123.123.123", 9000),
+			in:     []string{"tcp://:0"},
+			out:    []string{"tcp://123.123.123.123:9000"},
+		}, { // empty remapping
+			remote: addr("123.123.123.123", 9000),
+			in:     []string{"tcp://"},
+			out:    []string{},
+		}, { // port only remapping
+			remote: addr("123.123.123.123", 9000),
+			in:     []string{"tcp://44.44.44.44:0"},
+			out:    []string{"tcp://44.44.44.44:9000"},
 		},
 	}
 
@@ -63,3 +79,10 @@ func TestFixupAddresses(t *testing.T) {
 		}
 	}
 }
+
+func addr(host string, port int) *net.TCPAddr {
+	return &net.TCPAddr{
+		IP:   net.ParseIP(host),
+		Port: port,
+	}
+}

+ 2 - 17
lib/connections/quic_misc.go

@@ -13,6 +13,7 @@ import (
 	"net"
 
 	"github.com/lucas-clemente/quic-go"
+	"github.com/syncthing/syncthing/lib/util"
 )
 
 var (
@@ -65,21 +66,5 @@ func (q *quicTlsConn) ConnectionState() tls.ConnectionState {
 
 // Sort available packet connections by ip address, preferring unspecified local address.
 func packetConnLess(i interface{}, j interface{}) bool {
-	iIsUnspecified := false
-	jIsUnspecified := false
-	iLocalAddr := i.(net.PacketConn).LocalAddr()
-	jLocalAddr := j.(net.PacketConn).LocalAddr()
-
-	if host, _, err := net.SplitHostPort(iLocalAddr.String()); err == nil {
-		iIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified()
-	}
-	if host, _, err := net.SplitHostPort(jLocalAddr.String()); err == nil {
-		jIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified()
-	}
-
-	if jIsUnspecified == iIsUnspecified {
-		return len(iLocalAddr.Network()) < len(jLocalAddr.Network())
-	}
-
-	return iIsUnspecified
+	return util.AddressUnspecifiedLess(i.(net.PacketConn).LocalAddr(), j.(net.PacketConn).LocalAddr())
 }

+ 0 - 92
lib/connections/quic_misc_test.go

@@ -1,92 +0,0 @@
-// Copyright (C) 2019 The Syncthing Authors.
-//
-// This Source Code Form is subject to the terms of the Mozilla Public
-// License, v. 2.0. If a copy of the MPL was not distributed with this file,
-// You can obtain one at http://mozilla.org/MPL/2.0/.
-
-// +build go1.12
-
-package connections
-
-import (
-	"net"
-	"testing"
-	"time"
-)
-
-type mockPacketConn struct {
-	addr mockedAddr
-}
-
-func (mockPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
-	panic("implement me")
-}
-
-func (mockPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
-	panic("implement me")
-}
-
-func (mockPacketConn) Close() error {
-	panic("implement me")
-}
-
-func (c *mockPacketConn) LocalAddr() net.Addr {
-	return c.addr
-}
-
-func (mockPacketConn) SetDeadline(t time.Time) error {
-	panic("implement me")
-}
-
-func (mockPacketConn) SetReadDeadline(t time.Time) error {
-	panic("implement me")
-}
-
-func (mockPacketConn) SetWriteDeadline(t time.Time) error {
-	panic("implement me")
-}
-
-type mockedAddr struct {
-	network string
-	addr    string
-}
-
-func (a mockedAddr) Network() string {
-	return a.network
-}
-
-func (a mockedAddr) String() string {
-	return a.addr
-}
-
-func TestPacketConnLess(t *testing.T) {
-	cases := []struct {
-		netA  string
-		addrA string
-		netB  string
-		addrB string
-	}{
-		// B is assumed the winner.
-		{"tcp", "127.0.0.1:1234", "tcp", ":1235"},
-		{"tcp", "127.0.0.1:1234", "tcp", "0.0.0.0:1235"},
-		{"tcp4", "0.0.0.0:1234", "tcp", "0.0.0.0:1235"}, // tcp4 on the first one
-	}
-
-	for i, testCase := range cases {
-
-		conns := []*mockPacketConn{
-			{mockedAddr{testCase.netA, testCase.addrA}},
-			{mockedAddr{testCase.netB, testCase.addrB}},
-		}
-
-		if packetConnLess(conns[0], conns[1]) {
-			t.Error(i, "unexpected")
-		}
-		if !packetConnLess(conns[1], conns[0]) {
-			t.Error(i, "unexpected")
-		}
-		if packetConnLess(conns[0], conns[0]) || packetConnLess(conns[1], conns[1]) {
-			t.Error(i, "unexpected")
-		}
-	}
-}

+ 1 - 1
lib/connections/tcp_dial.go

@@ -35,7 +35,7 @@ func (d *tcpDialer) Dial(ctx context.Context, _ protocol.DeviceID, uri *url.URL)
 
 	timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
 	defer cancel()
-	conn, err := dialer.DialContext(timeoutCtx, uri.Scheme, uri.Host)
+	conn, err := dialer.DialContextReusePort(timeoutCtx, uri.Scheme, uri.Host)
 	if err != nil {
 		return internalConn{}, err
 	}

+ 22 - 3
lib/connections/tcp_listen.go

@@ -15,6 +15,7 @@ import (
 	"time"
 
 	"github.com/syncthing/syncthing/lib/config"
+	"github.com/syncthing/syncthing/lib/connections/registry"
 	"github.com/syncthing/syncthing/lib/dialer"
 	"github.com/syncthing/syncthing/lib/nat"
 	"github.com/syncthing/syncthing/lib/util"
@@ -50,14 +51,21 @@ func (t *tcpListener) serve(ctx context.Context) error {
 		return err
 	}
 
-	listener, err := net.ListenTCP(t.uri.Scheme, tcaddr)
+	lc := net.ListenConfig{
+		Control: dialer.ReusePortControl,
+	}
+
+	listener, err := lc.Listen(context.TODO(), t.uri.Scheme, tcaddr.String())
 	if err != nil {
 		l.Infoln("Listen (BEP/tcp):", err)
 		return err
 	}
 	t.notifyAddressesChanged(t)
+	registry.Register(t.uri.Scheme, tcaddr)
+
 	defer listener.Close()
 	defer t.clearAddresses(t)
+	defer registry.Unregister(t.uri.Scheme, tcaddr)
 
 	l.Infof("TCP listener (%v) starting", listener.Addr())
 	defer l.Infof("TCP listener (%v) shutting down", listener.Addr())
@@ -75,9 +83,12 @@ func (t *tcpListener) serve(ctx context.Context) error {
 	acceptFailures := 0
 	const maxAcceptFailures = 10
 
+	// :(, but what can you do.
+	tcpListener := listener.(*net.TCPListener)
+
 	for {
-		listener.SetDeadline(time.Now().Add(time.Second))
-		conn, err := listener.Accept()
+		_ = tcpListener.SetDeadline(time.Now().Add(time.Second))
+		conn, err := tcpListener.Accept()
 		select {
 		case <-ctx.Done():
 			if err == nil {
@@ -156,6 +167,14 @@ func (t *tcpListener) WANAddresses() []*url.URL {
 		}
 	}
 	t.mut.RUnlock()
+
+	// If we support ReusePort, add an unspecified zero port address, which will be resolved by the discovery server
+	// in hopes that TCP punch through works.
+	if dialer.SupportsReusePort {
+		uri := *t.uri
+		uri.Host = "0.0.0.0:0"
+		uris = append([]*url.URL{&uri}, uris...)
+	}
 	return uris
 }
 

+ 51 - 0
lib/dialer/control_unix.go

@@ -0,0 +1,51 @@
+// Copyright (C) 2019 The Syncthing Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at https://mozilla.org/MPL/2.0/.
+
+// +build !solaris,!windows
+
+package dialer
+
+import (
+	"syscall"
+
+	"golang.org/x/sys/unix"
+)
+
+var SupportsReusePort = false
+
+func init() {
+	fd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP)
+	if err != nil {
+		l.Debugln("Failed to create a socket", err)
+		return
+	}
+	defer func() { _ = unix.Close(fd) }()
+
+	err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
+	switch {
+	case err == unix.ENOPROTOOPT || err == unix.EINVAL:
+		l.Debugln("SO_REUSEPORT not supported")
+	case err != nil:
+		l.Debugln("Unknown error when determining SO_REUSEPORT support", err)
+	default:
+		l.Debugln("SO_REUSEPORT supported")
+		SupportsReusePort = true
+	}
+}
+
+func ReusePortControl(_, _ string, c syscall.RawConn) error {
+	if !SupportsReusePort {
+		return nil
+	}
+	var opErr error
+	err := c.Control(func(fd uintptr) {
+		opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
+	})
+	if err != nil {
+		return err
+	}
+	return opErr
+}

+ 17 - 0
lib/dialer/control_unsupported.go

@@ -0,0 +1,17 @@
+// Copyright (C) 2019 The Syncthing Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at https://mozilla.org/MPL/2.0/.
+
+// +build solaris
+
+package dialer
+
+import "syscall"
+
+var SupportsReusePort = false
+
+func ReusePortControl(_, _ string, _ syscall.RawConn) error {
+	return nil
+}

+ 31 - 0
lib/dialer/control_windows.go

@@ -0,0 +1,31 @@
+// Copyright (C) 2019 The Syncthing Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at https://mozilla.org/MPL/2.0/.
+
+// +build windows
+
+package dialer
+
+import (
+	"syscall"
+)
+
+var SupportsReusePort = true
+
+func ReusePortControl(_, _ string, c syscall.RawConn) error {
+	var opErr error
+	err := c.Control(func(fd uintptr) {
+		// On Windows, SO_REUSEADDR is equivalent to SO_REUSEPORT on Linux.
+		opErr = syscall.SetsockoptInt(syscall.Handle(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
+	})
+	if err != nil {
+		l.Debugln("ReusePortControl", err)
+		return err
+	}
+	if opErr != nil {
+		l.Debugln("ReusePortControl", opErr)
+	}
+	return opErr
+}

+ 6 - 0
lib/dialer/internal.go

@@ -13,6 +13,7 @@ import (
 	"os"
 	"time"
 
+	"github.com/syncthing/syncthing/lib/util"
 	"golang.org/x/net/proxy"
 )
 
@@ -60,6 +61,11 @@ func socksDialerFunction(u *url.URL, forward proxy.Dialer) (proxy.Dialer, error)
 	return proxy.SOCKS5("tcp", u.Host, auth, forward)
 }
 
+// Sort available addresses, preferring unspecified address.
+func tcpAddrLess(i interface{}, j interface{}) bool {
+	return util.AddressUnspecifiedLess(i.(*net.TCPAddr), j.(*net.TCPAddr))
+}
+
 // dialerConn is needed because proxy dialed connections have RemoteAddr() pointing at the proxy,
 // which then screws up various things such as IsLAN checks, and "let's populate the relay invitation address from
 // existing connection" shenanigans.

+ 20 - 0
lib/dialer/public.go

@@ -13,6 +13,7 @@ import (
 	"net"
 	"time"
 
+	"github.com/syncthing/syncthing/lib/connections/registry"
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv6"
 	"golang.org/x/net/proxy"
@@ -121,3 +122,22 @@ func dialContextWithFallback(ctx context.Context, fallback proxy.ContextDialer,
 func DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
 	return dialContextWithFallback(ctx, proxy.Direct, network, addr)
 }
+
+// DialContextReusePort tries dialing via proxy if a proxy is configured, and falls back to
+// a direct connection reusing the port from the connections registry, if no proxy is defined, or connecting via proxy
+// fails. If the context has a timeout, the timeout might be applied twice.
+func DialContextReusePort(ctx context.Context, network, addr string) (net.Conn, error) {
+	dialer := &net.Dialer{
+		Control: ReusePortControl,
+	}
+	localAddrInterface := registry.Get(network, tcpAddrLess)
+	if localAddrInterface != nil {
+		if addr, ok := localAddrInterface.(*net.TCPAddr); !ok {
+			return nil, errUnexpectedInterfaceType
+		} else {
+			dialer.LocalAddr = addr
+		}
+	}
+
+	return dialContextWithFallback(ctx, dialer, network, addr)
+}

+ 1 - 1
lib/discover/global.go

@@ -93,7 +93,7 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo
 	var announceClient httpClient = &contextClient{&http.Client{
 		Timeout: requestTimeout,
 		Transport: &http.Transport{
-			DialContext: dialer.DialContext,
+			DialContext: dialer.DialContextReusePort,
 			Proxy:       http.ProxyFromEnvironment,
 			TLSClientConfig: &tls.Config{
 				InsecureSkipVerify: opts.insecure,

+ 20 - 0
lib/util/utils.go

@@ -9,6 +9,7 @@ package util
 import (
 	"context"
 	"fmt"
+	"net"
 	"net/url"
 	"reflect"
 	"strconv"
@@ -176,6 +177,25 @@ func Address(network, host string) string {
 	return u.String()
 }
 
+// AddressUnspecifiedLess is a comparator function preferring least specific network address (most widely listening,
+// namely preferring 0.0.0.0 over some IP), if both IPs are equal, it prefers the less restrictive network (prefers tcp
+// over tcp4)
+func AddressUnspecifiedLess(a, b net.Addr) bool {
+	aIsUnspecified := false
+	bIsUnspecified := false
+	if host, _, err := net.SplitHostPort(a.String()); err == nil {
+		aIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified()
+	}
+	if host, _, err := net.SplitHostPort(b.String()); err == nil {
+		bIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified()
+	}
+
+	if aIsUnspecified == bIsUnspecified {
+		return len(a.Network()) < len(b.Network())
+	}
+	return aIsUnspecified
+}
+
 // AsService wraps the given function to implement suture.Service by calling
 // that function on serve and closing the passed channel when Stop is called.
 func AsService(fn func(ctx context.Context), creator string) suture.Service {

+ 44 - 0
lib/util/utils_test.go

@@ -227,6 +227,50 @@ func TestCopyMatching(t *testing.T) {
 	}
 }
 
+type mockedAddr struct {
+	network string
+	addr    string
+}
+
+func (a mockedAddr) Network() string {
+	return a.network
+}
+
+func (a mockedAddr) String() string {
+	return a.addr
+}
+
+func TestInspecifiedAddressLess(t *testing.T) {
+	cases := []struct {
+		netA  string
+		addrA string
+		netB  string
+		addrB string
+	}{
+		// B is assumed the winner.
+		{"tcp", "127.0.0.1:1234", "tcp", ":1235"},
+		{"tcp", "127.0.0.1:1234", "tcp", "0.0.0.0:1235"},
+		{"tcp4", "0.0.0.0:1234", "tcp", "0.0.0.0:1235"}, // tcp4 on the first one
+	}
+
+	for i, testCase := range cases {
+		addrs := []mockedAddr{
+			{testCase.netA, testCase.addrA},
+			{testCase.netB, testCase.addrB},
+		}
+
+		if AddressUnspecifiedLess(addrs[0], addrs[1]) {
+			t.Error(i, "unexpected")
+		}
+		if !AddressUnspecifiedLess(addrs[1], addrs[0]) {
+			t.Error(i, "unexpected")
+		}
+		if AddressUnspecifiedLess(addrs[0], addrs[0]) || AddressUnspecifiedLess(addrs[1], addrs[1]) {
+			t.Error(i, "unexpected")
+		}
+	}
+}
+
 func TestUtilStopTwicePanic(t *testing.T) {
 	name := "foo"
 	s := AsService(func(ctx context.Context) {