Browse Source

lib/syncthing: Prevent hangup on error during startup (fixes #6043) (#6047)

Simon Frei 6 years ago
parent
commit
b8907b49f9
3 changed files with 90 additions and 29 deletions
  1. 3 1
      cmd/syncthing/main.go
  2. 19 27
      lib/syncthing/syncthing.go
  3. 68 1
      lib/syncthing/syncthing_test.go

+ 3 - 1
cmd/syncthing/main.go

@@ -647,7 +647,9 @@ func syncthingMain(runtimeOptions RuntimeOptions) {
 		}
 	}
 
-	app.Start()
+	if err := app.Start(); err != nil {
+		os.Exit(int(syncthing.ExitError))
+	}
 
 	cleanConfigDirectory()
 

+ 19 - 27
lib/syncthing/syncthing.go

@@ -73,14 +73,13 @@ type App struct {
 	opts        Options
 	exitStatus  ExitStatus
 	err         error
-	startOnce   sync.Once
 	stopOnce    sync.Once
 	stop        chan struct{}
 	stopped     chan struct{}
 }
 
 func New(cfg config.Wrapper, ll *db.Lowlevel, evLogger events.Logger, cert tls.Certificate, opts Options) *App {
-	return &App{
+	a := &App{
 		cfg:      cfg,
 		ll:       ll,
 		evLogger: evLogger,
@@ -89,25 +88,21 @@ func New(cfg config.Wrapper, ll *db.Lowlevel, evLogger events.Logger, cert tls.C
 		stop:     make(chan struct{}),
 		stopped:  make(chan struct{}),
 	}
-}
-
-// Run does the same as start, but then does not return until the app stops. It
-// is equivalent to calling Start and then Wait.
-func (a *App) Run() ExitStatus {
-	a.Start()
-	return a.Wait()
+	close(a.stopped) // Hasn't been started, so shouldn't block on Wait.
+	return a
 }
 
 // Start executes the app and returns once all the startup operations are done,
 // e.g. the API is ready for use.
-func (a *App) Start() {
-	a.startOnce.Do(func() {
-		if err := a.startup(); err != nil {
-			a.stopWithErr(ExitError, err)
-			return
-		}
-		go a.run()
-	})
+// Must be called once only.
+func (a *App) Start() error {
+	if err := a.startup(); err != nil {
+		a.stopWithErr(ExitError, err)
+		return err
+	}
+	a.stopped = make(chan struct{})
+	go a.run()
+	return nil
 }
 
 func (a *App) startup() error {
@@ -378,7 +373,8 @@ func (a *App) run() {
 	close(a.stopped)
 }
 
-// Wait blocks until the app stops running.
+// Wait blocks until the app stops running. Also returns if the app hasn't been
+// started yet.
 func (a *App) Wait() ExitStatus {
 	<-a.stopped
 	return a.exitStatus
@@ -388,11 +384,11 @@ func (a *App) Wait() ExitStatus {
 // for the app to stop before returning.
 func (a *App) Error() error {
 	select {
-	case <-a.stopped:
-		return nil
+	case <-a.stop:
+		return a.err
 	default:
 	}
-	return a.err
+	return nil
 }
 
 // Stop stops the app and sets its exit status to given reason, unless the app
@@ -403,12 +399,8 @@ func (a *App) Stop(stopReason ExitStatus) ExitStatus {
 
 func (a *App) stopWithErr(stopReason ExitStatus, err error) ExitStatus {
 	a.stopOnce.Do(func() {
-		// ExitSuccess is the default value for a.exitStatus. If another status
-		// was already set, ignore the stop reason given as argument to Stop.
-		if a.exitStatus == ExitSuccess {
-			a.exitStatus = stopReason
-			a.err = err
-		}
+		a.exitStatus = stopReason
+		a.err = err
 		close(a.stop)
 	})
 	return a.exitStatus

+ 68 - 1
lib/syncthing/syncthing_test.go

@@ -7,20 +7,36 @@
 package syncthing
 
 import (
+	"io/ioutil"
+	"os"
+	"path/filepath"
 	"testing"
+	"time"
 
 	"github.com/syncthing/syncthing/lib/config"
 	"github.com/syncthing/syncthing/lib/events"
 	"github.com/syncthing/syncthing/lib/protocol"
+	"github.com/syncthing/syncthing/lib/tlsutil"
 )
 
+func tempCfgFilename(t *testing.T) string {
+	t.Helper()
+	f, err := ioutil.TempFile("", "syncthing-testConfig-")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer f.Close()
+	return f.Name()
+}
+
 func TestShortIDCheck(t *testing.T) {
-	cfg := config.Wrap("/tmp/test", config.Configuration{
+	cfg := config.Wrap(tempCfgFilename(t), config.Configuration{
 		Devices: []config.DeviceConfiguration{
 			{DeviceID: protocol.DeviceID{8, 16, 24, 32, 40, 48, 56, 0, 0}},
 			{DeviceID: protocol.DeviceID{8, 16, 24, 32, 40, 48, 56, 1, 1}}, // first 56 bits same, differ in the first 64 bits
 		},
 	}, events.NoopLogger)
+	defer os.Remove(cfg.ConfigPath())
 
 	if err := checkShortIDs(cfg); err != nil {
 		t.Error("Unexpected error:", err)
@@ -37,3 +53,54 @@ func TestShortIDCheck(t *testing.T) {
 		t.Error("Should have gotten an error")
 	}
 }
+
+func TestStartupFail(t *testing.T) {
+	tmpDir, err := ioutil.TempDir("", "syncthing-TestStartupFail-")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(tmpDir)
+
+	cert, err := tlsutil.NewCertificate(filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key"), "syncthing")
+	if err != nil {
+		t.Fatal(err)
+	}
+	id := protocol.NewDeviceID(cert.Certificate[0])
+	conflID := protocol.DeviceID{}
+	copy(conflID[:8], id[:8])
+
+	cfg := config.Wrap(tempCfgFilename(t), config.Configuration{
+		Devices: []config.DeviceConfiguration{
+			{DeviceID: id},
+			{DeviceID: conflID},
+		},
+	}, events.NoopLogger)
+	defer os.Remove(cfg.ConfigPath())
+
+	app := New(cfg, nil, events.NoopLogger, cert, Options{})
+	startErr := app.Start()
+	if startErr == nil {
+		t.Fatal("Expected an error from Start, got nil")
+	}
+
+	done := make(chan struct{})
+	var waitE ExitStatus
+	go func() {
+		waitE = app.Wait()
+		close(done)
+	}()
+
+	select {
+	case <-time.After(time.Second):
+		t.Fatal("Wait did not return within 1s")
+	case <-done:
+	}
+
+	if waitE != ExitError {
+		t.Errorf("Got exit status %v, expected %v", waitE, ExitError)
+	}
+
+	if err = app.Error(); err != startErr {
+		t.Errorf(`Got different errors "%v" from Start and "%v" from Error`, startErr, err)
+	}
+}