浏览代码

lib: More contextification (#6343)

Simon Frei 5 年之前
父节点
当前提交
f0e33d052a

+ 2 - 2
lib/api/api.go

@@ -1066,7 +1066,7 @@ func (s *service) getSupportBundle(w http.ResponseWriter, r *http.Request) {
 	}
 
 	// Report Data as a JSON
-	if usageReportingData, err := json.MarshalIndent(s.urService.ReportData(), "", "  "); err != nil {
+	if usageReportingData, err := json.MarshalIndent(s.urService.ReportData(context.TODO()), "", "  "); err != nil {
 		l.Warnln("Support bundle: failed to create versionPlatform.json:", err)
 	} else {
 		files = append(files, fileEntry{name: "usage-reporting.json.txt", data: usageReportingData})
@@ -1151,7 +1151,7 @@ func (s *service) getReport(w http.ResponseWriter, r *http.Request) {
 	if val, _ := strconv.Atoi(r.URL.Query().Get("version")); val > 0 {
 		version = val
 	}
-	sendJSON(w, s.urService.ReportDataPreview(version))
+	sendJSON(w, s.urService.ReportDataPreview(context.TODO(), version))
 }
 
 func (s *service) getRandomString(w http.ResponseWriter, r *http.Request) {

+ 28 - 1
lib/model/bytesemaphore.go

@@ -7,6 +7,7 @@
 package model
 
 import (
+	"context"
 	"sync"
 )
 
@@ -29,19 +30,45 @@ func newByteSemaphore(max int) *byteSemaphore {
 	return &s
 }
 
+func (s *byteSemaphore) takeWithContext(ctx context.Context, bytes int) error {
+	done := make(chan struct{})
+	var err error
+	go func() {
+		err = s.takeInner(ctx, bytes)
+		close(done)
+	}()
+	select {
+	case <-done:
+	case <-ctx.Done():
+		s.cond.Broadcast()
+		<-done
+	}
+	return err
+}
+
 func (s *byteSemaphore) take(bytes int) {
+	_ = s.takeInner(context.Background(), bytes)
+}
+
+func (s *byteSemaphore) takeInner(ctx context.Context, bytes int) error {
 	s.mut.Lock()
+	defer s.mut.Unlock()
 	if bytes > s.max {
 		bytes = s.max
 	}
 	for bytes > s.available {
 		s.cond.Wait()
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+		}
 		if bytes > s.max {
 			bytes = s.max
 		}
 	}
 	s.available -= bytes
-	s.mut.Unlock()
+	return nil
 }
 
 func (s *byteSemaphore) give(bytes int) {

+ 6 - 2
lib/model/folder.go

@@ -301,7 +301,9 @@ func (f *folder) pull() bool {
 	f.setState(FolderSyncWaiting)
 	defer f.setState(FolderIdle)
 
-	f.ioLimiter.take(1)
+	if err := f.ioLimiter.takeWithContext(f.ctx, 1); err != nil {
+		return true
+	}
 	defer f.ioLimiter.give(1)
 
 	return f.puller.pull()
@@ -340,7 +342,9 @@ func (f *folder) scanSubdirs(subDirs []string) error {
 	f.setError(nil)
 	f.setState(FolderScanWaiting)
 
-	f.ioLimiter.take(1)
+	if err := f.ioLimiter.takeWithContext(f.ctx, 1); err != nil {
+		return err
+	}
 	defer f.ioLimiter.give(1)
 
 	for i := range subDirs {

+ 4 - 1
lib/model/folder_sendrecv.go

@@ -1392,7 +1392,10 @@ func (f *sendReceiveFolder) pullerRoutine(in <-chan pullBlockState, out chan<- *
 		state := state
 		bytes := int(state.block.Size)
 
-		requestLimiter.take(bytes)
+		if err := requestLimiter.takeWithContext(f.ctx, bytes); err != nil {
+			break
+		}
+
 		wg.Add(1)
 
 		go func() {

+ 3 - 0
lib/model/model_test.go

@@ -3211,6 +3211,9 @@ func TestParentOfUnignored(t *testing.T) {
 // restarts would leave more than one folder runner alive.
 func TestFolderRestartZombies(t *testing.T) {
 	wrapper := createTmpWrapper(defaultCfg.Copy())
+	opts := wrapper.Options()
+	opts.RawMaxFolderConcurrency = -1
+	wrapper.SetOptions(opts)
 	folderCfg, _ := wrapper.Folder("default")
 	folderCfg.FilesystemType = fs.FilesystemTypeFake
 	wrapper.SetFolder(folderCfg)

+ 3 - 2
lib/nat/interface.go

@@ -7,6 +7,7 @@
 package nat
 
 import (
+	"context"
 	"net"
 	"time"
 )
@@ -21,6 +22,6 @@ const (
 type Device interface {
 	ID() string
 	GetLocalIPAddress() net.IP
-	AddPortMapping(protocol Protocol, internalPort, externalPort int, description string, duration time.Duration) (int, error)
-	GetExternalIPAddress() (net.IP, error)
+	AddPortMapping(ctx context.Context, protocol Protocol, internalPort, externalPort int, description string, duration time.Duration) (int, error)
+	GetExternalIPAddress(ctx context.Context) (net.IP, error)
 }

+ 3 - 3
lib/nat/service.go

@@ -304,7 +304,7 @@ func (s *Service) tryNATDevice(ctx context.Context, natd Device, intPort, extPor
 	if extPort != 0 {
 		// First try renewing our existing mapping, if we have one.
 		name := fmt.Sprintf("syncthing-%d", extPort)
-		port, err = natd.AddPortMapping(TCP, intPort, extPort, name, leaseTime)
+		port, err = natd.AddPortMapping(ctx, TCP, intPort, extPort, name, leaseTime)
 		if err == nil {
 			extPort = port
 			goto findIP
@@ -322,7 +322,7 @@ func (s *Service) tryNATDevice(ctx context.Context, natd Device, intPort, extPor
 		// Then try up to ten random ports.
 		extPort = 1024 + predictableRand.Intn(65535-1024)
 		name := fmt.Sprintf("syncthing-%d", extPort)
-		port, err = natd.AddPortMapping(TCP, intPort, extPort, name, leaseTime)
+		port, err = natd.AddPortMapping(ctx, TCP, intPort, extPort, name, leaseTime)
 		if err == nil {
 			extPort = port
 			goto findIP
@@ -333,7 +333,7 @@ func (s *Service) tryNATDevice(ctx context.Context, natd Device, intPort, extPor
 	return Address{}, err
 
 findIP:
-	ip, err := natd.GetExternalIPAddress()
+	ip, err := natd.GetExternalIPAddress(ctx)
 	if err != nil {
 		l.Debugln("Error getting external ip on", natd.ID(), err)
 		ip = nil

+ 22 - 5
lib/pmp/pmp.go

@@ -15,7 +15,9 @@ import (
 
 	"github.com/jackpal/gateway"
 	"github.com/jackpal/go-nat-pmp"
+
 	"github.com/syncthing/syncthing/lib/nat"
+	"github.com/syncthing/syncthing/lib/util"
 )
 
 func init() {
@@ -23,7 +25,12 @@ func init() {
 }
 
 func Discover(ctx context.Context, renewal, timeout time.Duration) []nat.Device {
-	ip, err := gateway.DiscoverGateway()
+	var ip net.IP
+	err := util.CallWithContext(ctx, func() error {
+		var err error
+		ip, err = gateway.DiscoverGateway()
+		return err
+	})
 	if err != nil {
 		l.Debugln("Failed to discover gateway", err)
 		return nil
@@ -81,14 +88,19 @@ func (w *wrapper) GetLocalIPAddress() net.IP {
 	return w.localIP
 }
 
-func (w *wrapper) AddPortMapping(protocol nat.Protocol, internalPort, externalPort int, description string, duration time.Duration) (int, error) {
+func (w *wrapper) AddPortMapping(ctx context.Context, protocol nat.Protocol, internalPort, externalPort int, description string, duration time.Duration) (int, error) {
 	// NAT-PMP says that if duration is 0, the mapping is actually removed
 	// Swap the zero with the renewal value, which should make the lease for the
 	// exact amount of time between the calls.
 	if duration == 0 {
 		duration = w.renewal
 	}
-	result, err := w.client.AddPortMapping(strings.ToLower(string(protocol)), internalPort, externalPort, int(duration/time.Second))
+	var result *natpmp.AddPortMappingResult
+	err := util.CallWithContext(ctx, func() error {
+		var err error
+		result, err = w.client.AddPortMapping(strings.ToLower(string(protocol)), internalPort, externalPort, int(duration/time.Second))
+		return err
+	})
 	port := 0
 	if result != nil {
 		port = int(result.MappedExternalPort)
@@ -96,8 +108,13 @@ func (w *wrapper) AddPortMapping(protocol nat.Protocol, internalPort, externalPo
 	return port, err
 }
 
-func (w *wrapper) GetExternalIPAddress() (net.IP, error) {
-	result, err := w.client.GetExternalAddress()
+func (w *wrapper) GetExternalIPAddress(ctx context.Context) (net.IP, error) {
+	var result *natpmp.GetExternalAddressResult
+	err := util.CallWithContext(ctx, func() error {
+		var err error
+		result, err = w.client.GetExternalAddress()
+		return err
+	})
 	ip := net.IPv4zero
 	if result != nil {
 		ip = net.IPv4(

+ 7 - 1
lib/relay/client/dynamic.go

@@ -45,7 +45,13 @@ func (c *dynamicClient) serve(ctx context.Context) error {
 
 	l.Debugln(c, "looking up dynamic relays")
 
-	data, err := http.Get(uri.String())
+	req, err := http.NewRequest("GET", uri.String(), nil)
+	if err != nil {
+		l.Debugln(c, "failed to lookup dynamic relays", err)
+		return err
+	}
+	req.Cancel = ctx.Done()
+	data, err := http.DefaultClient.Do(req)
 	if err != nil {
 		l.Debugln(c, "failed to lookup dynamic relays", err)
 		return err

+ 6 - 1
lib/stun/stun.go

@@ -185,7 +185,12 @@ func (s *Service) runStunForServer(ctx context.Context, addr string) {
 	}
 	s.client.SetServerAddr(udpAddr.String())
 
-	natType, extAddr, err := s.client.Discover()
+	var natType stun.NATType
+	var extAddr *stun.Host
+	err = util.CallWithContext(ctx, func() error {
+		natType, extAddr, err = s.client.Discover()
+		return err
+	})
 	if err != nil || extAddr == nil {
 		l.Debugf("%s stun discovery on %s: %s", s, addr, err)
 		return

+ 2 - 1
lib/syncthing/syncthing.go

@@ -7,6 +7,7 @@
 package syncthing
 
 import (
+	"context"
 	"crypto/tls"
 	"fmt"
 	"io"
@@ -179,7 +180,7 @@ func (a *App) startup() error {
 		}()
 	}
 
-	perf := ur.CpuBench(3, 150*time.Millisecond, true)
+	perf := ur.CpuBench(context.Background(), 3, 150*time.Millisecond, true)
 	l.Infof("Hashing performance is %.02f MB/s", perf)
 
 	if err := db.UpdateSchema(a.ll); err != nil {

+ 8 - 7
lib/upnp/igd_service.go

@@ -33,6 +33,7 @@
 package upnp
 
 import (
+	"context"
 	"encoding/xml"
 	"fmt"
 	"net"
@@ -52,7 +53,7 @@ type IGDService struct {
 }
 
 // AddPortMapping adds a port mapping to the specified IGD service.
-func (s *IGDService) AddPortMapping(protocol nat.Protocol, internalPort, externalPort int, description string, duration time.Duration) (int, error) {
+func (s *IGDService) AddPortMapping(ctx context.Context, protocol nat.Protocol, internalPort, externalPort int, description string, duration time.Duration) (int, error) {
 	tpl := `<u:AddPortMapping xmlns:u="%s">
 	<NewRemoteHost></NewRemoteHost>
 	<NewExternalPort>%d</NewExternalPort>
@@ -65,7 +66,7 @@ func (s *IGDService) AddPortMapping(protocol nat.Protocol, internalPort, externa
 	</u:AddPortMapping>`
 	body := fmt.Sprintf(tpl, s.URN, externalPort, protocol, internalPort, s.LocalIP, description, duration/time.Second)
 
-	response, err := soapRequest(s.URL, s.URN, "AddPortMapping", body)
+	response, err := soapRequest(ctx, s.URL, s.URN, "AddPortMapping", body)
 	if err != nil && duration > 0 {
 		// Try to repair error code 725 - OnlyPermanentLeasesSupported
 		envelope := &soapErrorResponse{}
@@ -73,7 +74,7 @@ func (s *IGDService) AddPortMapping(protocol nat.Protocol, internalPort, externa
 			return externalPort, unmarshalErr
 		}
 		if envelope.ErrorCode == 725 {
-			return s.AddPortMapping(protocol, internalPort, externalPort, description, 0)
+			return s.AddPortMapping(ctx, protocol, internalPort, externalPort, description, 0)
 		}
 	}
 
@@ -81,7 +82,7 @@ func (s *IGDService) AddPortMapping(protocol nat.Protocol, internalPort, externa
 }
 
 // DeletePortMapping deletes a port mapping from the specified IGD service.
-func (s *IGDService) DeletePortMapping(protocol nat.Protocol, externalPort int) error {
+func (s *IGDService) DeletePortMapping(ctx context.Context, protocol nat.Protocol, externalPort int) error {
 	tpl := `<u:DeletePortMapping xmlns:u="%s">
 	<NewRemoteHost></NewRemoteHost>
 	<NewExternalPort>%d</NewExternalPort>
@@ -89,19 +90,19 @@ func (s *IGDService) DeletePortMapping(protocol nat.Protocol, externalPort int)
 	</u:DeletePortMapping>`
 	body := fmt.Sprintf(tpl, s.URN, externalPort, protocol)
 
-	_, err := soapRequest(s.URL, s.URN, "DeletePortMapping", body)
+	_, err := soapRequest(ctx, s.URL, s.URN, "DeletePortMapping", body)
 	return err
 }
 
 // GetExternalIPAddress queries the IGD service for its external IP address.
 // Returns nil if the external IP address is invalid or undefined, along with
 // any relevant errors
-func (s *IGDService) GetExternalIPAddress() (net.IP, error) {
+func (s *IGDService) GetExternalIPAddress(ctx context.Context) (net.IP, error) {
 	tpl := `<u:GetExternalIPAddress xmlns:u="%s" />`
 
 	body := fmt.Sprintf(tpl, s.URN)
 
-	response, err := soapRequest(s.URL, s.URN, "GetExternalIPAddress", body)
+	response, err := soapRequest(ctx, s.URL, s.URN, "GetExternalIPAddress", body)
 
 	if err != nil {
 		return nil, err

+ 2 - 1
lib/upnp/upnp.go

@@ -423,7 +423,7 @@ func replaceRawPath(u *url.URL, rp string) {
 	}
 }
 
-func soapRequest(url, service, function, message string) ([]byte, error) {
+func soapRequest(ctx context.Context, url, service, function, message string) ([]byte, error) {
 	tpl := `<?xml version="1.0" ?>
 	<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" s:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/">
 	<s:Body>%s</s:Body>
@@ -437,6 +437,7 @@ func soapRequest(url, service, function, message string) ([]byte, error) {
 	if err != nil {
 		return resp, err
 	}
+	req.Cancel = ctx.Done()
 	req.Close = true
 	req.Header.Set("Content-Type", `text/xml; charset="utf-8"`)
 	req.Header.Set("User-Agent", "syncthing/1.0")

+ 27 - 15
lib/ur/usage_report.go

@@ -63,18 +63,18 @@ func New(cfg config.Wrapper, m model.Model, connectionsService connections.Servi
 
 // ReportData returns the data to be sent in a usage report with the currently
 // configured usage reporting version.
-func (s *Service) ReportData() map[string]interface{} {
+func (s *Service) ReportData(ctx context.Context) map[string]interface{} {
 	urVersion := s.cfg.Options().URAccepted
-	return s.reportData(urVersion, false)
+	return s.reportData(ctx, urVersion, false)
 }
 
 // ReportDataPreview returns a preview of the data to be sent in a usage report
 // with the given version.
-func (s *Service) ReportDataPreview(urVersion int) map[string]interface{} {
-	return s.reportData(urVersion, true)
+func (s *Service) ReportDataPreview(ctx context.Context, urVersion int) map[string]interface{} {
+	return s.reportData(ctx, urVersion, true)
 }
 
-func (s *Service) reportData(urVersion int, preview bool) map[string]interface{} {
+func (s *Service) reportData(ctx context.Context, urVersion int, preview bool) map[string]interface{} {
 	opts := s.cfg.Options()
 	res := make(map[string]interface{})
 	res["urVersion"] = urVersion
@@ -112,8 +112,8 @@ func (s *Service) reportData(urVersion int, preview bool) map[string]interface{}
 	var mem runtime.MemStats
 	runtime.ReadMemStats(&mem)
 	res["memoryUsageMiB"] = (mem.Sys - mem.HeapReleased) / 1024 / 1024
-	res["sha256Perf"] = CpuBench(5, 125*time.Millisecond, false)
-	res["hashPerf"] = CpuBench(5, 125*time.Millisecond, true)
+	res["sha256Perf"] = CpuBench(ctx, 5, 125*time.Millisecond, false)
+	res["hashPerf"] = CpuBench(ctx, 5, 125*time.Millisecond, true)
 
 	bytes, err := memorySize()
 	if err == nil {
@@ -368,8 +368,8 @@ func (s *Service) UptimeS() int {
 	return int(time.Since(StartTime).Seconds())
 }
 
-func (s *Service) sendUsageReport() error {
-	d := s.ReportData()
+func (s *Service) sendUsageReport(ctx context.Context) error {
+	d := s.ReportData(ctx)
 	var b bytes.Buffer
 	if err := json.NewEncoder(&b).Encode(d); err != nil {
 		return err
@@ -384,7 +384,15 @@ func (s *Service) sendUsageReport() error {
 			},
 		},
 	}
-	_, err := client.Post(s.cfg.Options().URURL, "application/json", &b)
+	req, err := http.NewRequest("POST", s.cfg.Options().URURL, &b)
+	if err == nil {
+		req.Header.Set("Content-Type", "application/json")
+		req.Cancel = ctx.Done()
+		var resp *http.Response
+		resp, err = client.Do(req)
+		resp.Body.Close()
+	}
+
 	return err
 }
 
@@ -401,7 +409,7 @@ func (s *Service) serve(ctx context.Context) {
 			t.Reset(0)
 		case <-t.C:
 			if s.cfg.Options().URAccepted >= 2 {
-				err := s.sendUsageReport()
+				err := s.sendUsageReport(ctx)
 				if err != nil {
 					l.Infoln("Usage report:", err)
 				} else {
@@ -439,7 +447,7 @@ var (
 )
 
 // CpuBench returns CPU performance as a measure of single threaded SHA-256 MiB/s
-func CpuBench(iterations int, duration time.Duration, useWeakHash bool) float64 {
+func CpuBench(ctx context.Context, iterations int, duration time.Duration, useWeakHash bool) float64 {
 	blocksResultMut.Lock()
 	defer blocksResultMut.Unlock()
 
@@ -449,7 +457,7 @@ func CpuBench(iterations int, duration time.Duration, useWeakHash bool) float64
 
 	var perf float64
 	for i := 0; i < iterations; i++ {
-		if v := cpuBenchOnce(duration, useWeakHash, bs); v > perf {
+		if v := cpuBenchOnce(ctx, duration, useWeakHash, bs); v > perf {
 			perf = v
 		}
 	}
@@ -457,12 +465,16 @@ func CpuBench(iterations int, duration time.Duration, useWeakHash bool) float64
 	return perf
 }
 
-func cpuBenchOnce(duration time.Duration, useWeakHash bool, bs []byte) float64 {
+func cpuBenchOnce(ctx context.Context, duration time.Duration, useWeakHash bool, bs []byte) float64 {
 	t0 := time.Now()
 	b := 0
+	var err error
 	for time.Since(t0) < duration {
 		r := bytes.NewReader(bs)
-		blocksResult, _ = scanner.Blocks(context.TODO(), r, protocol.MinBlockSize, int64(len(bs)), nil, useWeakHash)
+		blocksResult, err = scanner.Blocks(ctx, r, protocol.MinBlockSize, int64(len(bs)), nil, useWeakHash)
+		if err != nil {
+			return 0 // Context done
+		}
 		b += len(bs)
 	}
 	d := time.Since(t0)

+ 15 - 0
lib/util/utils.go

@@ -274,3 +274,18 @@ func (s *service) SetError(err error) {
 func (s *service) String() string {
 	return fmt.Sprintf("Service@%p created by %v", s, s.creator)
 }
+
+func CallWithContext(ctx context.Context, fn func() error) error {
+	var err error
+	done := make(chan struct{})
+	go func() {
+		err = fn()
+		close(done)
+	}()
+	select {
+	case <-done:
+		return err
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}