Bläddra i källkod

all: Use context in lib/dialer (#6177)

* all: Use context in lib/dialer

* a bit slimmer

* https://github.com/syncthing/syncthing/pull/5753

* bot

* missed adding debug.go

* errors.Cause

* simultaneous dialing

* anti-leak
Simon Frei 5 år sedan
förälder
incheckning
1bae4b7f50

+ 2 - 1
cmd/strelaypoolsrv/main.go

@@ -7,6 +7,7 @@ package main
 import (
 	"bytes"
 	"compress/gzip"
+	"context"
 	"crypto/tls"
 	"encoding/json"
 	"flag"
@@ -480,7 +481,7 @@ func handleRelayTest(request request) {
 	if debug {
 		log.Println("Request for", request.relay)
 	}
-	if !client.TestRelay(request.relay.uri, []tls.Certificate{testCert}, time.Second, 2*time.Second, 3) {
+	if !client.TestRelay(context.TODO(), request.relay.uri, []tls.Certificate{testCert}, time.Second, 2*time.Second, 3) {
 		if debug {
 			log.Println("Test for relay", request.relay, "failed")
 		}

+ 7 - 4
cmd/strelaysrv/testutil/main.go

@@ -4,6 +4,7 @@ package main
 
 import (
 	"bufio"
+	"context"
 	"crypto/tls"
 	"flag"
 	"log"
@@ -19,6 +20,8 @@ import (
 )
 
 func main() {
+	ctx := context.Background()
+
 	log.SetOutput(os.Stdout)
 	log.SetFlags(log.LstdFlags | log.Lshortfile)
 
@@ -76,7 +79,7 @@ func main() {
 		}()
 
 		for {
-			conn, err := client.JoinSession(<-recv)
+			conn, err := client.JoinSession(ctx, <-recv)
 			if err != nil {
 				log.Fatalln("Failed to join", err)
 			}
@@ -90,13 +93,13 @@ func main() {
 			log.Fatal(err)
 		}
 
-		invite, err := client.GetInvitationFromRelay(uri, id, []tls.Certificate{cert}, 10*time.Second)
+		invite, err := client.GetInvitationFromRelay(ctx, uri, id, []tls.Certificate{cert}, 10*time.Second)
 		if err != nil {
 			log.Fatal(err)
 		}
 
 		log.Println("Received invitation", invite)
-		conn, err := client.JoinSession(invite)
+		conn, err := client.JoinSession(ctx, invite)
 		if err != nil {
 			log.Fatalln("Failed to join", err)
 		}
@@ -104,7 +107,7 @@ func main() {
 		connectToStdio(stdin, conn)
 		log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr())
 	} else if test {
-		if client.TestRelay(uri, []tls.Certificate{cert}, time.Second, 2*time.Second, 4) {
+		if client.TestRelay(ctx, uri, []tls.Certificate{cert}, time.Second, 2*time.Second, 4) {
 			log.Println("OK")
 		} else {
 			log.Println("FAIL")

+ 1 - 1
cmd/syncthing/main.go

@@ -512,7 +512,7 @@ func upgradeViaRest() error {
 	r.Header.Set("X-API-Key", cfg.GUI().APIKey)
 
 	tr := &http.Transport{
-		Dial:            dialer.Dial,
+		DialContext:     dialer.DialContext,
 		Proxy:           http.ProxyFromEnvironment,
 		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
 	}

+ 2 - 2
lib/connections/quic_dial.go

@@ -42,7 +42,7 @@ type quicDialer struct {
 	commonDialer
 }
 
-func (d *quicDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, error) {
+func (d *quicDialer) Dial(ctx context.Context, _ protocol.DeviceID, uri *url.URL) (internalConn, error) {
 	uri = fixupPort(uri, config.DefaultQUICPort)
 
 	addr, err := net.ResolveUDPAddr("udp", uri.Host)
@@ -66,7 +66,7 @@ func (d *quicDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, erro
 		}
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), quicOperationTimeout)
+	ctx, cancel := context.WithTimeout(ctx, quicOperationTimeout)
 	defer cancel()
 
 	session, err := quic.DialContext(ctx, conn, addr, uri.Host, d.tlsCfg, quicConfig)

+ 4 - 3
lib/connections/relay_dial.go

@@ -7,6 +7,7 @@
 package connections
 
 import (
+	"context"
 	"crypto/tls"
 	"net/url"
 	"time"
@@ -27,13 +28,13 @@ type relayDialer struct {
 	commonDialer
 }
 
-func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, error) {
-	inv, err := client.GetInvitationFromRelay(uri, id, d.tlsCfg.Certificates, 10*time.Second)
+func (d *relayDialer) Dial(ctx context.Context, id protocol.DeviceID, uri *url.URL) (internalConn, error) {
+	inv, err := client.GetInvitationFromRelay(ctx, uri, id, d.tlsCfg.Certificates, 10*time.Second)
 	if err != nil {
 		return internalConn{}, err
 	}
 
-	conn, err := client.JoinSession(inv)
+	conn, err := client.JoinSession(ctx, inv)
 	if err != nil {
 		return internalConn{}, err
 	}

+ 6 - 2
lib/connections/relay_listen.go

@@ -13,6 +13,8 @@ import (
 	"sync"
 	"time"
 
+	"github.com/pkg/errors"
+
 	"github.com/syncthing/syncthing/lib/config"
 	"github.com/syncthing/syncthing/lib/dialer"
 	"github.com/syncthing/syncthing/lib/nat"
@@ -70,9 +72,11 @@ func (t *relayListener) serve(ctx context.Context) error {
 				return err
 			}
 
-			conn, err := client.JoinSession(inv)
+			conn, err := client.JoinSession(ctx, inv)
 			if err != nil {
-				l.Infoln("Listen (BEP/relay): joining session:", err)
+				if errors.Cause(err) != context.Canceled {
+					l.Infoln("Listen (BEP/relay): joining session:", err)
+				}
 				continue
 			}
 

+ 8 - 4
lib/connections/service.go

@@ -9,7 +9,6 @@ package connections
 import (
 	"context"
 	"crypto/tls"
-	"errors"
 	"fmt"
 	"net"
 	"net/url"
@@ -31,6 +30,7 @@ import (
 	_ "github.com/syncthing/syncthing/lib/pmp"
 	_ "github.com/syncthing/syncthing/lib/upnp"
 
+	"github.com/pkg/errors"
 	"github.com/thejerf/suture"
 	"golang.org/x/time/rate"
 )
@@ -463,7 +463,7 @@ func (s *service) connect(ctx context.Context) {
 				})
 			}
 
-			conn, ok := s.dialParallel(deviceCfg.DeviceID, dialTargets)
+			conn, ok := s.dialParallel(ctx, deviceCfg.DeviceID, dialTargets)
 			if ok {
 				s.conns <- conn
 			}
@@ -701,6 +701,10 @@ func (s *service) ConnectionStatus() map[string]ConnectionStatusEntry {
 }
 
 func (s *service) setConnectionStatus(address string, err error) {
+	if errors.Cause(err) != context.Canceled {
+		return
+	}
+
 	status := ConnectionStatusEntry{When: time.Now().UTC().Truncate(time.Second)}
 	if err != nil {
 		errStr := err.Error()
@@ -828,7 +832,7 @@ func IsAllowedNetwork(host string, allowed []string) bool {
 	return false
 }
 
-func (s *service) dialParallel(deviceID protocol.DeviceID, dialTargets []dialTarget) (internalConn, bool) {
+func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID, dialTargets []dialTarget) (internalConn, bool) {
 	// Group targets into buckets by priority
 	dialTargetBuckets := make(map[int][]dialTarget, len(dialTargets))
 	for _, tgt := range dialTargets {
@@ -851,7 +855,7 @@ func (s *service) dialParallel(deviceID protocol.DeviceID, dialTargets []dialTar
 		for _, tgt := range tgts {
 			wg.Add(1)
 			go func(tgt dialTarget) {
-				conn, err := tgt.Dial()
+				conn, err := tgt.Dial(ctx)
 				if err == nil {
 					// Closes the connection on error
 					err = s.validateIdentity(conn, deviceID)

+ 4 - 3
lib/connections/structs.go

@@ -7,6 +7,7 @@
 package connections
 
 import (
+	"context"
 	"crypto/tls"
 	"fmt"
 	"io"
@@ -164,7 +165,7 @@ func (d *commonDialer) RedialFrequency() time.Duration {
 }
 
 type genericDialer interface {
-	Dial(protocol.DeviceID, *url.URL) (internalConn, error)
+	Dial(context.Context, protocol.DeviceID, *url.URL) (internalConn, error)
 	RedialFrequency() time.Duration
 }
 
@@ -223,7 +224,7 @@ type dialTarget struct {
 	deviceID protocol.DeviceID
 }
 
-func (t dialTarget) Dial() (internalConn, error) {
+func (t dialTarget) Dial(ctx context.Context) (internalConn, error) {
 	l.Debugln("dialing", t.deviceID, t.uri, "prio", t.priority)
-	return t.dialer.Dial(t.deviceID, t.uri)
+	return t.dialer.Dial(ctx, t.deviceID, t.uri)
 }

+ 5 - 2
lib/connections/tcp_dial.go

@@ -7,6 +7,7 @@
 package connections
 
 import (
+	"context"
 	"crypto/tls"
 	"net/url"
 	"time"
@@ -29,10 +30,12 @@ type tcpDialer struct {
 	commonDialer
 }
 
-func (d *tcpDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, error) {
+func (d *tcpDialer) Dial(ctx context.Context, _ protocol.DeviceID, uri *url.URL) (internalConn, error) {
 	uri = fixupPort(uri, config.DefaultTCPPort)
 
-	conn, err := dialer.DialTimeout(uri.Scheme, uri.Host, 10*time.Second)
+	timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
+	defer cancel()
+	conn, err := dialer.DialContext(timeoutCtx, uri.Scheme, uri.Host)
 	if err != nil {
 		return internalConn{}, err
 	}

+ 23 - 0
lib/dialer/debug.go

@@ -0,0 +1,23 @@
+// Copyright (C) 2019 The Syncthing Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at https://mozilla.org/MPL/2.0/.
+
+package dialer
+
+import (
+	"os"
+	"strings"
+
+	"github.com/syncthing/syncthing/lib/logger"
+)
+
+var (
+	l = logger.DefaultLogger.NewFacility("dialer", "Dialing connections")
+	// To run before init() of other files that log on init.
+	_ = func() error {
+		l.SetDebug("dialer", strings.Contains(os.Getenv("STTRACE"), "dialer") || os.Getenv("STTRACE") == "all")
+		return nil
+	}()
+)

+ 3 - 102
lib/dialer/internal.go

@@ -7,34 +7,24 @@
 package dialer
 
 import (
-	"net"
 	"net/http"
 	"net/url"
 	"os"
 	"time"
 
 	"golang.org/x/net/proxy"
-
-	"github.com/syncthing/syncthing/lib/logger"
 )
 
 var (
-	l           = logger.DefaultLogger.NewFacility("dialer", "Dialing connections")
-	proxyDialer proxy.Dialer
-	usingProxy  bool
-	noFallback  = os.Getenv("ALL_PROXY_NO_FALLBACK") != ""
+	noFallback = os.Getenv("ALL_PROXY_NO_FALLBACK") != ""
 )
 
-type dialFunc func(network, addr string) (net.Conn, error)
-
 func init() {
 	proxy.RegisterDialerType("socks", socksDialerFunction)
-	proxyDialer = getDialer(proxy.Direct)
-	usingProxy = proxyDialer != proxy.Direct
 
-	if usingProxy {
+	if proxyDialer := proxy.FromEnvironment(); proxyDialer != proxy.Direct {
 		http.DefaultTransport = &http.Transport{
-			Dial:                Dial,
+			DialContext:         DialContext,
 			Proxy:               http.ProxyFromEnvironment,
 			TLSHandshakeTimeout: 10 * time.Second,
 		}
@@ -55,31 +45,6 @@ func init() {
 	}
 }
 
-func dialWithFallback(proxyDialFunc dialFunc, fallbackDialFunc dialFunc, network, addr string) (net.Conn, error) {
-	conn, err := proxyDialFunc(network, addr)
-	if err == nil {
-		l.Debugf("Dialing %s address %s via proxy - success, %s -> %s", network, addr, conn.LocalAddr(), conn.RemoteAddr())
-		SetTCPOptions(conn)
-		return dialerConn{
-			conn, newDialerAddr(network, addr),
-		}, nil
-	}
-	l.Debugf("Dialing %s address %s via proxy - error %s", network, addr, err)
-
-	if noFallback {
-		return conn, err
-	}
-
-	conn, err = fallbackDialFunc(network, addr)
-	if err == nil {
-		l.Debugf("Dialing %s address %s via fallback - success, %s -> %s", network, addr, conn.LocalAddr(), conn.RemoteAddr())
-		SetTCPOptions(conn)
-	} else {
-		l.Debugf("Dialing %s address %s via fallback - error %s", network, addr, err)
-	}
-	return conn, err
-}
-
 // This is a rip off of proxy.FromURL for "socks" URL scheme
 func socksDialerFunction(u *url.URL, forward proxy.Dialer) (proxy.Dialer, error) {
 	var auth *proxy.Auth
@@ -93,67 +58,3 @@ func socksDialerFunction(u *url.URL, forward proxy.Dialer) (proxy.Dialer, error)
 
 	return proxy.SOCKS5("tcp", u.Host, auth, forward)
 }
-
-// This is a rip off of proxy.FromEnvironment with a custom forward dialer
-func getDialer(forward proxy.Dialer) proxy.Dialer {
-	allProxy := os.Getenv("all_proxy")
-	if len(allProxy) == 0 {
-		return forward
-	}
-
-	proxyURL, err := url.Parse(allProxy)
-	if err != nil {
-		return forward
-	}
-	prxy, err := proxy.FromURL(proxyURL, forward)
-	if err != nil {
-		return forward
-	}
-
-	noProxy := os.Getenv("no_proxy")
-	if len(noProxy) == 0 {
-		return prxy
-	}
-
-	perHost := proxy.NewPerHost(prxy, forward)
-	perHost.AddFromString(noProxy)
-	return perHost
-}
-
-type timeoutDirectDialer struct {
-	timeout time.Duration
-}
-
-func (d *timeoutDirectDialer) Dial(network, addr string) (net.Conn, error) {
-	return net.DialTimeout(network, addr, d.timeout)
-}
-
-type dialerConn struct {
-	net.Conn
-	addr net.Addr
-}
-
-func (c dialerConn) RemoteAddr() net.Addr {
-	return c.addr
-}
-
-func newDialerAddr(network, addr string) net.Addr {
-	netaddr, err := net.ResolveIPAddr(network, addr)
-	if err == nil {
-		return netaddr
-	}
-	return fallbackAddr{network, addr}
-}
-
-type fallbackAddr struct {
-	network string
-	addr    string
-}
-
-func (a fallbackAddr) Network() string {
-	return a.network
-}
-
-func (a fallbackAddr) String() string {
-	return a.addr
-}

+ 51 - 43
lib/dialer/public.go

@@ -7,49 +7,18 @@
 package dialer
 
 import (
+	"context"
+	"errors"
 	"fmt"
 	"net"
 	"time"
 
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv6"
+	"golang.org/x/net/proxy"
 )
 
-// Dial tries dialing via proxy if a proxy is configured, and falls back to
-// a direct connection if no proxy is defined, or connecting via proxy fails.
-func Dial(network, addr string) (net.Conn, error) {
-	if usingProxy {
-		return dialWithFallback(proxyDialer.Dial, net.Dial, network, addr)
-	}
-	return net.Dial(network, addr)
-}
-
-// DialTimeout tries dialing via proxy with a timeout if a proxy is configured,
-// and falls back to a direct connection if no proxy is defined, or connecting
-// via proxy fails. The timeout can potentially be applied twice, once trying
-// to connect via the proxy connection, and second time trying to connect
-// directly.
-func DialTimeout(network, addr string, timeout time.Duration) (net.Conn, error) {
-	if usingProxy {
-		// Because the proxy package is poorly structured, we have to
-		// construct a struct that matches proxy.Dialer but has a timeout
-		// and reconstrcut the proxy dialer using that, in order to be able to
-		// set a timeout.
-		dd := &timeoutDirectDialer{
-			timeout: timeout,
-		}
-		// Check if the dialer we are getting is not timeoutDirectDialer we just
-		// created. It could happen that usingProxy is true, but getDialer
-		// returns timeoutDirectDialer due to env vars changing.
-		if timeoutProxyDialer := getDialer(dd); timeoutProxyDialer != dd {
-			directDialFunc := func(inetwork, iaddr string) (net.Conn, error) {
-				return net.DialTimeout(inetwork, iaddr, timeout)
-			}
-			return dialWithFallback(timeoutProxyDialer.Dial, directDialFunc, network, addr)
-		}
-	}
-	return net.DialTimeout(network, addr, timeout)
-}
+var errUnexpectedInterfaceType = errors.New("unexpected interface type")
 
 // SetTCPOptions sets our default TCP options on a TCP connection, possibly
 // digging through dialerConn to extract the *net.TCPConn
@@ -70,10 +39,6 @@ func SetTCPOptions(conn net.Conn) error {
 			return err
 		}
 		return nil
-
-	case dialerConn:
-		return SetTCPOptions(conn.Conn)
-
 	default:
 		return fmt.Errorf("unknown connection type %T", conn)
 	}
@@ -89,11 +54,54 @@ func SetTrafficClass(conn net.Conn, class int) error {
 			return e1
 		}
 		return e2
-
-	case dialerConn:
-		return SetTrafficClass(conn.Conn, class)
-
 	default:
 		return fmt.Errorf("unknown connection type %T", conn)
 	}
 }
+
+func dialContextWithFallback(ctx context.Context, fallback proxy.ContextDialer, network, addr string) (net.Conn, error) {
+	dialer, ok := proxy.FromEnvironment().(proxy.ContextDialer)
+	if !ok {
+		return nil, errUnexpectedInterfaceType
+	}
+	if dialer == proxy.Direct {
+		return fallback.DialContext(ctx, network, addr)
+	}
+	if noFallback {
+		return dialer.DialContext(ctx, network, addr)
+	}
+
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	var proxyConn, fallbackConn net.Conn
+	var proxyErr, fallbackErr error
+	proxyDone := make(chan struct{})
+	fallbackDone := make(chan struct{})
+	go func() {
+		proxyConn, proxyErr = dialer.DialContext(ctx, network, addr)
+		close(proxyDone)
+	}()
+	go func() {
+		fallbackConn, fallbackErr = fallback.DialContext(ctx, network, addr)
+		close(fallbackDone)
+	}()
+	<-proxyDone
+	if proxyErr == nil {
+		go func() {
+			<-fallbackDone
+			if fallbackErr == nil {
+				fallbackConn.Close()
+			}
+		}()
+		return proxyConn, nil
+	}
+	<-fallbackDone
+	return fallbackConn, fallbackErr
+}
+
+// DialContext dials via context and/or directly, depending on how it is configured.
+// If dialing via proxy and allowing fallback, dialing for both happens simultaneously
+// and the proxy connection is returned if successful.
+func DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
+	return dialContextWithFallback(ctx, proxy.Direct, network, addr)
+}

+ 4 - 4
lib/discover/global.go

@@ -92,8 +92,8 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo
 	var announceClient httpClient = &http.Client{
 		Timeout: requestTimeout,
 		Transport: &http.Transport{
-			Dial:  dialer.Dial,
-			Proxy: http.ProxyFromEnvironment,
+			DialContext: dialer.DialContext,
+			Proxy:       http.ProxyFromEnvironment,
 			TLSClientConfig: &tls.Config{
 				InsecureSkipVerify: opts.insecure,
 				Certificates:       []tls.Certificate{cert},
@@ -109,8 +109,8 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo
 	var queryClient httpClient = &http.Client{
 		Timeout: requestTimeout,
 		Transport: &http.Transport{
-			Dial:  dialer.Dial,
-			Proxy: http.ProxyFromEnvironment,
+			DialContext: dialer.DialContext,
+			Proxy:       http.ProxyFromEnvironment,
 			TLSClientConfig: &tls.Config{
 				InsecureSkipVerify: opts.insecure,
 			},

+ 2 - 2
lib/nat/registry.go

@@ -12,7 +12,7 @@ import (
 	"time"
 )
 
-type DiscoverFunc func(renewal, timeout time.Duration) []Device
+type DiscoverFunc func(ctx context.Context, renewal, timeout time.Duration) []Device
 
 var providers []DiscoverFunc
 
@@ -30,7 +30,7 @@ func discoverAll(ctx context.Context, renewal, timeout time.Duration) map[string
 	for _, discoverFunc := range providers {
 		go func(f DiscoverFunc) {
 			defer wg.Done()
-			for _, dev := range f(renewal, timeout) {
+			for _, dev := range f(ctx, renewal, timeout) {
 				select {
 				case c <- dev:
 				case <-ctx.Done():

+ 7 - 4
lib/osutil/ping.go

@@ -7,6 +7,7 @@
 package osutil
 
 import (
+	"context"
 	"net/url"
 	"time"
 
@@ -16,9 +17,11 @@ import (
 // TCPPing returns the duration required to establish a TCP connection
 // to the given host. ICMP packets require root privileges, hence why we use
 // tcp.
-func TCPPing(address string) (time.Duration, error) {
+func TCPPing(ctx context.Context, address string) (time.Duration, error) {
 	start := time.Now()
-	conn, err := dialer.DialTimeout("tcp", address, time.Second)
+	ctx, cancel := context.WithTimeout(ctx, time.Second)
+	defer cancel()
+	conn, err := dialer.DialContext(ctx, "tcp", address)
 	if conn != nil {
 		conn.Close()
 	}
@@ -27,11 +30,11 @@ func TCPPing(address string) (time.Duration, error) {
 
 // GetLatencyForURL parses the given URL, tries opening a TCP connection to it
 // and returns the time it took to establish the connection.
-func GetLatencyForURL(addr string) (time.Duration, error) {
+func GetLatencyForURL(ctx context.Context, addr string) (time.Duration, error) {
 	uri, err := url.Parse(addr)
 	if err != nil {
 		return 0, err
 	}
 
-	return TCPPing(uri.Host)
+	return TCPPing(ctx, uri.Host)
 }

+ 5 - 2
lib/pmp/pmp.go

@@ -7,6 +7,7 @@
 package pmp
 
 import (
+	"context"
 	"fmt"
 	"net"
 	"strings"
@@ -21,7 +22,7 @@ func init() {
 	nat.Register(Discover)
 }
 
-func Discover(renewal, timeout time.Duration) []nat.Device {
+func Discover(ctx context.Context, renewal, timeout time.Duration) []nat.Device {
 	ip, err := gateway.DiscoverGateway()
 	if err != nil {
 		l.Debugln("Failed to discover gateway", err)
@@ -44,7 +45,9 @@ func Discover(renewal, timeout time.Duration) []nat.Device {
 
 	var localIP net.IP
 	// Port comes from the natpmp package
-	conn, err := net.DialTimeout("udp", net.JoinHostPort(ip.String(), "5351"), timeout)
+	timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
+	defer cancel()
+	conn, err := (&net.Dialer{}).DialContext(timeoutCtx, "udp", net.JoinHostPort(ip.String(), "5351"))
 	if err == nil {
 		conn.Close()
 		localIPAddress, _, err := net.SplitHostPort(conn.LocalAddr().String())

+ 1 - 1
lib/rc/rc.go

@@ -166,7 +166,7 @@ func (p *Process) Get(path string) ([]byte, error) {
 	client := &http.Client{
 		Timeout: 30 * time.Second,
 		Transport: &http.Transport{
-			Dial:              dialer.Dial,
+			DialContext:       dialer.DialContext,
 			Proxy:             http.ProxyFromEnvironment,
 			DisableKeepAlives: true,
 		},

+ 1 - 1
lib/relay/client/dynamic.go

@@ -153,7 +153,7 @@ func relayAddressesOrder(ctx context.Context, input []string) []string {
 	buckets := make(map[int][]string)
 
 	for _, relay := range input {
-		latency, err := osutil.GetLatencyForURL(relay)
+		latency, err := osutil.GetLatencyForURL(ctx, relay)
 		if err != nil {
 			latency = time.Hour
 		}

+ 11 - 6
lib/relay/client/methods.go

@@ -3,6 +3,7 @@
 package client
 
 import (
+	"context"
 	"crypto/tls"
 	"fmt"
 	"net"
@@ -16,12 +17,14 @@ import (
 	"github.com/syncthing/syncthing/lib/relay/protocol"
 )
 
-func GetInvitationFromRelay(uri *url.URL, id syncthingprotocol.DeviceID, certs []tls.Certificate, timeout time.Duration) (protocol.SessionInvitation, error) {
+func GetInvitationFromRelay(ctx context.Context, uri *url.URL, id syncthingprotocol.DeviceID, certs []tls.Certificate, timeout time.Duration) (protocol.SessionInvitation, error) {
 	if uri.Scheme != "relay" {
 		return protocol.SessionInvitation{}, fmt.Errorf("Unsupported relay scheme: %v", uri.Scheme)
 	}
 
-	rconn, err := dialer.DialTimeout("tcp", uri.Host, timeout)
+	ctx, cancel := context.WithTimeout(ctx, timeout)
+	defer cancel()
+	rconn, err := dialer.DialContext(ctx, "tcp", uri.Host)
 	if err != nil {
 		return protocol.SessionInvitation{}, err
 	}
@@ -63,10 +66,12 @@ func GetInvitationFromRelay(uri *url.URL, id syncthingprotocol.DeviceID, certs [
 	}
 }
 
-func JoinSession(invitation protocol.SessionInvitation) (net.Conn, error) {
+func JoinSession(ctx context.Context, invitation protocol.SessionInvitation) (net.Conn, error) {
 	addr := net.JoinHostPort(net.IP(invitation.Address).String(), strconv.Itoa(int(invitation.Port)))
 
-	conn, err := dialer.Dial("tcp", addr)
+	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+	defer cancel()
+	conn, err := dialer.DialContext(ctx, "tcp", addr)
 	if err != nil {
 		return nil, err
 	}
@@ -99,7 +104,7 @@ func JoinSession(invitation protocol.SessionInvitation) (net.Conn, error) {
 	}
 }
 
-func TestRelay(uri *url.URL, certs []tls.Certificate, sleep, timeout time.Duration, times int) bool {
+func TestRelay(ctx context.Context, uri *url.URL, certs []tls.Certificate, sleep, timeout time.Duration, times int) bool {
 	id := syncthingprotocol.NewDeviceID(certs[0].Certificate[0])
 	invs := make(chan protocol.SessionInvitation, 1)
 	c, err := NewClient(uri, certs, invs, timeout)
@@ -114,7 +119,7 @@ func TestRelay(uri *url.URL, certs []tls.Certificate, sleep, timeout time.Durati
 	}()
 
 	for i := 0; i < times; i++ {
-		_, err := GetInvitationFromRelay(uri, id, certs, timeout)
+		_, err := GetInvitationFromRelay(ctx, uri, id, certs, timeout)
 		if err == nil {
 			return true
 		}

+ 5 - 3
lib/relay/client/static.go

@@ -47,7 +47,7 @@ func newStaticClient(uri *url.URL, certs []tls.Certificate, invitations chan pro
 }
 
 func (c *staticClient) serve(ctx context.Context) error {
-	if err := c.connect(); err != nil {
+	if err := c.connect(ctx); err != nil {
 		l.Infof("Could not connect to relay %s: %s", c.uri, err)
 		return err
 	}
@@ -146,13 +146,15 @@ func (c *staticClient) URI() *url.URL {
 	return c.uri
 }
 
-func (c *staticClient) connect() error {
+func (c *staticClient) connect(ctx context.Context) error {
 	if c.uri.Scheme != "relay" {
 		return fmt.Errorf("unsupported relay scheme: %v", c.uri.Scheme)
 	}
 
 	t0 := time.Now()
-	tcpConn, err := dialer.DialTimeout("tcp", c.uri.Host, c.connectTimeout)
+	timeoutCtx, cancel := context.WithTimeout(ctx, time.Second)
+	defer cancel()
+	tcpConn, err := dialer.DialContext(timeoutCtx, "tcp", c.uri.Host)
 	if err != nil {
 		return err
 	}

+ 2 - 2
lib/upgrade/upgrade_supported.go

@@ -66,8 +66,8 @@ const (
 var insecureHTTP = &http.Client{
 	Timeout: readTimeout,
 	Transport: &http.Transport{
-		Dial:  dialer.Dial,
-		Proxy: http.ProxyFromEnvironment,
+		DialContext: dialer.DialContext,
+		Proxy:       http.ProxyFromEnvironment,
 		TLSClientConfig: &tls.Config{
 			InsecureSkipVerify: true,
 		},

+ 16 - 10
lib/upnp/upnp.go

@@ -35,8 +35,8 @@ package upnp
 import (
 	"bufio"
 	"bytes"
+	"context"
 	"encoding/xml"
-	"errors"
 	"fmt"
 	"io/ioutil"
 	"net"
@@ -47,6 +47,8 @@ import (
 	"sync"
 	"time"
 
+	"github.com/pkg/errors"
+
 	"github.com/syncthing/syncthing/lib/dialer"
 	"github.com/syncthing/syncthing/lib/nat"
 )
@@ -83,7 +85,7 @@ func (e UnsupportedDeviceTypeError) Error() string {
 
 // Discover discovers UPnP InternetGatewayDevices.
 // The order in which the devices appear in the results list is not deterministic.
-func Discover(renewal, timeout time.Duration) []nat.Device {
+func Discover(ctx context.Context, renewal, timeout time.Duration) []nat.Device {
 	var results []nat.Device
 
 	interfaces, err := net.Interfaces()
@@ -105,7 +107,7 @@ func Discover(renewal, timeout time.Duration) []nat.Device {
 		for _, deviceType := range []string{"urn:schemas-upnp-org:device:InternetGatewayDevice:1", "urn:schemas-upnp-org:device:InternetGatewayDevice:2"} {
 			wg.Add(1)
 			go func(intf net.Interface, deviceType string) {
-				discover(&intf, deviceType, timeout, resultChan)
+				discover(ctx, &intf, deviceType, timeout, resultChan)
 				wg.Done()
 			}(intf, deviceType)
 		}
@@ -135,7 +137,7 @@ nextResult:
 
 // Search for UPnP InternetGatewayDevices for <timeout> seconds.
 // The order in which the devices appear in the result list is not deterministic
-func discover(intf *net.Interface, deviceType string, timeout time.Duration, results chan<- nat.Device) {
+func discover(ctx context.Context, intf *net.Interface, deviceType string, timeout time.Duration, results chan<- nat.Device) {
 	ssdp := &net.UDPAddr{IP: []byte{239, 255, 255, 250}, Port: 1900}
 
 	tpl := `M-SEARCH * HTTP/1.1
@@ -187,13 +189,15 @@ USER-AGENT: syncthing/1.0
 			}
 			break
 		}
-		igds, err := parseResponse(deviceType, resp[:n])
+		igds, err := parseResponse(ctx, deviceType, resp[:n])
 		if err != nil {
 			switch err.(type) {
 			case *UnsupportedDeviceTypeError:
 				l.Debugln(err.Error())
 			default:
-				l.Infoln("UPnP parse:", err)
+				if errors.Cause(err) != context.Canceled {
+					l.Infoln("UPnP parse:", err)
+				}
 			}
 			continue
 		}
@@ -205,7 +209,7 @@ USER-AGENT: syncthing/1.0
 	l.Debugln("Discovery for device type", deviceType, "on", intf.Name, "finished.")
 }
 
-func parseResponse(deviceType string, resp []byte) ([]IGDService, error) {
+func parseResponse(ctx context.Context, deviceType string, resp []byte) ([]IGDService, error) {
 	l.Debugln("Handling UPnP response:\n\n" + string(resp))
 
 	reader := bufio.NewReader(bytes.NewBuffer(resp))
@@ -257,7 +261,7 @@ func parseResponse(deviceType string, resp []byte) ([]IGDService, error) {
 	// We do this in a fairly roundabout way by connecting to the IGD and
 	// checking the address of the local end of the socket. I'm open to
 	// suggestions on a better way to do this...
-	localIPAddress, err := localIP(deviceDescriptionURL)
+	localIPAddress, err := localIP(ctx, deviceDescriptionURL)
 	if err != nil {
 		return nil, err
 	}
@@ -270,8 +274,10 @@ func parseResponse(deviceType string, resp []byte) ([]IGDService, error) {
 	return services, nil
 }
 
-func localIP(url *url.URL) (net.IP, error) {
-	conn, err := dialer.DialTimeout("tcp", url.Host, time.Second)
+func localIP(ctx context.Context, url *url.URL) (net.IP, error) {
+	timeoutCtx, cancel := context.WithTimeout(ctx, time.Second)
+	defer cancel()
+	conn, err := dialer.DialContext(timeoutCtx, "tcp", url.Host)
 	if err != nil {
 		return nil, err
 	}

+ 2 - 2
lib/ur/usage_report.go

@@ -373,8 +373,8 @@ func (s *Service) sendUsageReport() error {
 
 	client := &http.Client{
 		Transport: &http.Transport{
-			Dial:  dialer.Dial,
-			Proxy: http.ProxyFromEnvironment,
+			DialContext: dialer.DialContext,
+			Proxy:       http.ProxyFromEnvironment,
 			TLSClientConfig: &tls.Config{
 				InsecureSkipVerify: s.cfg.Options().URPostInsecurely,
 			},

+ 3 - 0
lib/util/utils.go

@@ -236,6 +236,9 @@ func (s *service) Serve() {
 
 	var err error
 	defer func() {
+		if err == context.Canceled {
+			err = nil
+		}
 		s.mut.Lock()
 		s.err = err
 		close(s.stopped)