浏览代码

lib: Faster termination on exit (ref #6319) (#6329)

Simon Frei 5 年之前
父节点
当前提交
c3637f2191

+ 2 - 1
cmd/stfinddevice/main.go

@@ -7,6 +7,7 @@
 package main
 
 import (
+	"context"
 	"crypto/tls"
 	"errors"
 	"flag"
@@ -95,7 +96,7 @@ func checkServer(deviceID protocol.DeviceID, server string) checkResult {
 	})
 
 	go func() {
-		addresses, err := disco.Lookup(deviceID)
+		addresses, err := disco.Lookup(context.Background(), deviceID)
 		res <- checkResult{addresses: addresses, error: err}
 	}()
 

+ 2 - 1
lib/api/mocked_discovery_test.go

@@ -7,6 +7,7 @@
 package api
 
 import (
+	"context"
 	"time"
 
 	"github.com/syncthing/syncthing/lib/discover"
@@ -26,7 +27,7 @@ func (m *mockedCachingMux) Stop() {
 
 // from events.Finder
 
-func (m *mockedCachingMux) Lookup(deviceID protocol.DeviceID) (direct []string, err error) {
+func (m *mockedCachingMux) Lookup(ctx context.Context, deviceID protocol.DeviceID) (direct []string, err error) {
 	return nil, nil
 }
 

+ 7 - 1
lib/connections/service.go

@@ -360,6 +360,12 @@ func (s *service) connect(ctx context.Context) {
 		var seen []string
 
 		for _, deviceCfg := range cfg.Devices {
+			select {
+			case <-ctx.Done():
+				return
+			default:
+			}
+
 			deviceID := deviceCfg.DeviceID
 			if deviceID == s.myID {
 				continue
@@ -380,7 +386,7 @@ func (s *service) connect(ctx context.Context) {
 			for _, addr := range deviceCfg.Addresses {
 				if addr == "dynamic" {
 					if s.discoverer != nil {
-						if t, err := s.discoverer.Lookup(deviceID); err == nil {
+						if t, err := s.discoverer.Lookup(ctx, deviceID); err == nil {
 							addrs = append(addrs, t...)
 						}
 					}

+ 3 - 2
lib/discover/cache.go

@@ -7,6 +7,7 @@
 package discover
 
 import (
+	"context"
 	"sort"
 	stdsync "sync"
 	"time"
@@ -73,7 +74,7 @@ func (m *cachingMux) Add(finder Finder, cacheTime, negCacheTime time.Duration) {
 
 // Lookup attempts to resolve the device ID using any of the added Finders,
 // while obeying the cache settings.
-func (m *cachingMux) Lookup(deviceID protocol.DeviceID) (addresses []string, err error) {
+func (m *cachingMux) Lookup(ctx context.Context, deviceID protocol.DeviceID) (addresses []string, err error) {
 	m.mut.RLock()
 	for i, finder := range m.finders {
 		if cacheEntry, ok := m.caches[i].Get(deviceID); ok {
@@ -99,7 +100,7 @@ func (m *cachingMux) Lookup(deviceID protocol.DeviceID) (addresses []string, err
 		}
 
 		// Perform the actual lookup and cache the result.
-		if addrs, err := finder.Lookup(deviceID); err == nil {
+		if addrs, err := finder.Lookup(ctx, deviceID); err == nil {
 			l.Debugln("lookup for", deviceID, "at", finder)
 			l.Debugln("  addresses:", addrs)
 			addresses = append(addresses, addrs...)

+ 8 - 5
lib/discover/cache_test.go

@@ -7,6 +7,7 @@
 package discover
 
 import (
+	"context"
 	"reflect"
 	"testing"
 	"time"
@@ -39,7 +40,9 @@ func TestCacheUnique(t *testing.T) {
 	f1 := &fakeDiscovery{addresses0}
 	c.Add(f1, time.Minute, 0)
 
-	addr, err := c.Lookup(protocol.LocalDeviceID)
+	ctx := context.Background()
+
+	addr, err := c.Lookup(ctx, protocol.LocalDeviceID)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -53,7 +56,7 @@ func TestCacheUnique(t *testing.T) {
 	f2 := &fakeDiscovery{addresses1}
 	c.Add(f2, time.Minute, 0)
 
-	addr, err = c.Lookup(protocol.LocalDeviceID)
+	addr, err = c.Lookup(ctx, protocol.LocalDeviceID)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -66,7 +69,7 @@ type fakeDiscovery struct {
 	addresses []string
 }
 
-func (f *fakeDiscovery) Lookup(deviceID protocol.DeviceID) (addresses []string, err error) {
+func (f *fakeDiscovery) Lookup(_ context.Context, deviceID protocol.DeviceID) (addresses []string, err error) {
 	return f.addresses, nil
 }
 
@@ -96,7 +99,7 @@ func TestCacheSlowLookup(t *testing.T) {
 	// Start a lookup, which will take at least a second
 
 	t0 := time.Now()
-	go c.Lookup(protocol.LocalDeviceID)
+	go c.Lookup(context.Background(), protocol.LocalDeviceID)
 	<-started // The slow lookup method has been called so we're inside the lock
 
 	// It should be possible to get ChildErrors while it's running
@@ -116,7 +119,7 @@ type slowDiscovery struct {
 	started chan struct{}
 }
 
-func (f *slowDiscovery) Lookup(deviceID protocol.DeviceID) (addresses []string, err error) {
+func (f *slowDiscovery) Lookup(_ context.Context, deviceID protocol.DeviceID) (addresses []string, err error) {
 	close(f.started)
 	time.Sleep(f.delay)
 	return nil, nil

+ 2 - 1
lib/discover/discover.go

@@ -7,6 +7,7 @@
 package discover
 
 import (
+	"context"
 	"time"
 
 	"github.com/syncthing/syncthing/lib/protocol"
@@ -15,7 +16,7 @@ import (
 
 // A Finder provides lookup services of some kind.
 type Finder interface {
-	Lookup(deviceID protocol.DeviceID) (address []string, err error)
+	Lookup(ctx context.Context, deviceID protocol.DeviceID) (address []string, err error)
 	Error() error
 	String() string
 	Cache() map[protocol.DeviceID]CacheEntry

+ 44 - 15
lib/discover/global.go

@@ -41,8 +41,8 @@ type globalClient struct {
 }
 
 type httpClient interface {
-	Get(url string) (*http.Response, error)
-	Post(url, ctype string, data io.Reader) (*http.Response, error)
+	Get(ctx context.Context, url string) (*http.Response, error)
+	Post(ctx context.Context, url, ctype string, data io.Reader) (*http.Response, error)
 }
 
 const (
@@ -89,7 +89,7 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo
 	// The http.Client used for announcements. It needs to have our
 	// certificate to prove our identity, and may or may not verify the server
 	// certificate depending on the insecure setting.
-	var announceClient httpClient = &http.Client{
+	var announceClient httpClient = &contextClient{&http.Client{
 		Timeout: requestTimeout,
 		Transport: &http.Transport{
 			DialContext: dialer.DialContext,
@@ -99,14 +99,14 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo
 				Certificates:       []tls.Certificate{cert},
 			},
 		},
-	}
+	}}
 	if opts.id != "" {
 		announceClient = newIDCheckingHTTPClient(announceClient, devID)
 	}
 
 	// The http.Client used for queries. We don't need to present our
 	// certificate here, so lets not include it. May be insecure if requested.
-	var queryClient httpClient = &http.Client{
+	var queryClient httpClient = &contextClient{&http.Client{
 		Timeout: requestTimeout,
 		Transport: &http.Transport{
 			DialContext: dialer.DialContext,
@@ -115,7 +115,7 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo
 				InsecureSkipVerify: opts.insecure,
 			},
 		},
-	}
+	}}
 	if opts.id != "" {
 		queryClient = newIDCheckingHTTPClient(queryClient, devID)
 	}
@@ -139,7 +139,7 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo
 }
 
 // Lookup returns the list of addresses where the given device is available
-func (c *globalClient) Lookup(device protocol.DeviceID) (addresses []string, err error) {
+func (c *globalClient) Lookup(ctx context.Context, device protocol.DeviceID) (addresses []string, err error) {
 	if c.noLookup {
 		return nil, lookupError{
 			error:    errors.New("lookups not supported"),
@@ -156,7 +156,7 @@ func (c *globalClient) Lookup(device protocol.DeviceID) (addresses []string, err
 	q.Set("device", device.String())
 	qURL.RawQuery = q.Encode()
 
-	resp, err := c.queryClient.Get(qURL.String())
+	resp, err := c.queryClient.Get(ctx, qURL.String())
 	if err != nil {
 		l.Debugln("globalClient.Lookup", qURL, err)
 		return nil, err
@@ -211,7 +211,7 @@ func (c *globalClient) serve(ctx context.Context) {
 			timer.Reset(2 * time.Second)
 
 		case <-timer.C:
-			c.sendAnnouncement(timer)
+			c.sendAnnouncement(ctx, timer)
 
 		case <-ctx.Done():
 			return
@@ -219,7 +219,7 @@ func (c *globalClient) serve(ctx context.Context) {
 	}
 }
 
-func (c *globalClient) sendAnnouncement(timer *time.Timer) {
+func (c *globalClient) sendAnnouncement(ctx context.Context, timer *time.Timer) {
 	var ann announcement
 	if c.addrList != nil {
 		ann.Addresses = c.addrList.ExternalAddresses()
@@ -239,7 +239,7 @@ func (c *globalClient) sendAnnouncement(timer *time.Timer) {
 
 	l.Debugf("Announcement: %s", postData)
 
-	resp, err := c.announceClient.Post(c.server, "application/json", bytes.NewReader(postData))
+	resp, err := c.announceClient.Post(ctx, c.server, "application/json", bytes.NewReader(postData))
 	if err != nil {
 		l.Debugln("announce POST:", err)
 		c.setError(err)
@@ -362,8 +362,8 @@ func (c *idCheckingHTTPClient) check(resp *http.Response) error {
 	return nil
 }
 
-func (c *idCheckingHTTPClient) Get(url string) (*http.Response, error) {
-	resp, err := c.httpClient.Get(url)
+func (c *idCheckingHTTPClient) Get(ctx context.Context, url string) (*http.Response, error) {
+	resp, err := c.httpClient.Get(ctx, url)
 	if err != nil {
 		return nil, err
 	}
@@ -374,8 +374,8 @@ func (c *idCheckingHTTPClient) Get(url string) (*http.Response, error) {
 	return resp, nil
 }
 
-func (c *idCheckingHTTPClient) Post(url, ctype string, data io.Reader) (*http.Response, error) {
-	resp, err := c.httpClient.Post(url, ctype, data)
+func (c *idCheckingHTTPClient) Post(ctx context.Context, url, ctype string, data io.Reader) (*http.Response, error) {
+	resp, err := c.httpClient.Post(ctx, url, ctype, data)
 	if err != nil {
 		return nil, err
 	}
@@ -403,3 +403,32 @@ func (e *errorHolder) Error() error {
 	e.mut.Unlock()
 	return err
 }
+
+type contextClient struct {
+	*http.Client
+}
+
+func (c *contextClient) Get(ctx context.Context, url string) (*http.Response, error) {
+	// For <go1.13 compatibility. Use the following commented line once that
+	// isn't required anymore.
+	// req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+	req, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		return nil, err
+	}
+	req.Cancel = ctx.Done()
+	return c.Client.Do(req)
+}
+
+func (c *contextClient) Post(ctx context.Context, url, ctype string, data io.Reader) (*http.Response, error) {
+	// For <go1.13 compatibility. Use the following commented line once that
+	// isn't required anymore.
+	// req, err := http.NewRequestWithContext(ctx, "POST", url, data)
+	req, err := http.NewRequest("POST", url, data)
+	if err != nil {
+		return nil, err
+	}
+	req.Cancel = ctx.Done()
+	req.Header.Set("Content-Type", ctype)
+	return c.Client.Do(req)
+}

+ 2 - 1
lib/discover/global_test.go

@@ -7,6 +7,7 @@
 package discover
 
 import (
+	"context"
 	"crypto/tls"
 	"io/ioutil"
 	"net"
@@ -225,7 +226,7 @@ func testLookup(url string) ([]string, error) {
 	go disco.Serve()
 	defer disco.Stop()
 
-	return disco.Lookup(protocol.LocalDeviceID)
+	return disco.Lookup(context.Background(), protocol.LocalDeviceID)
 }
 
 type fakeDiscoveryServer struct {

+ 1 - 1
lib/discover/local.go

@@ -91,7 +91,7 @@ func NewLocal(id protocol.DeviceID, addr string, addrList AddressLister, evLogge
 }
 
 // Lookup returns a list of addresses the device is available at.
-func (c *localClient) Lookup(device protocol.DeviceID) (addresses []string, err error) {
+func (c *localClient) Lookup(_ context.Context, device protocol.DeviceID) (addresses []string, err error) {
 	if cache, ok := c.Get(device); ok {
 		if time.Since(cache.when) < CacheLifeTime {
 			addresses = cache.Addresses