Просмотр исходного кода

cmd/syncthing, lib/syncthing: Create library utils (ref #4085) (#5871)

Simon Frei 6 лет назад
Родитель
Сommit
46e72d76b5
3 измененных файлов с 134 добавлено и 107 удалено
  1. 6 93
      cmd/syncthing/main.go
  2. 2 14
      lib/syncthing/syncthing.go
  3. 126 0
      lib/syncthing/utils.go

+ 6 - 93
cmd/syncthing/main.go

@@ -437,7 +437,7 @@ func generate(generateDir string) error {
 		l.Warnln("Config exists; will not overwrite.")
 		return nil
 	}
-	cfg, err := defaultConfig(cfgFile, myID)
+	cfg, err := syncthing.DefaultConfig(cfgFile, myID, noDefaultFolder)
 	if err != nil {
 		return err
 	}
@@ -549,25 +549,16 @@ func upgradeViaRest() error {
 
 func syncthingMain(runtimeOptions RuntimeOptions) {
 	// Ensure that we have a certificate and key.
-	cert, err := tls.LoadX509KeyPair(
+	cert, err := syncthing.LoadOrGenerateCertificate(
 		locations.Get(locations.CertFile),
 		locations.Get(locations.KeyFile),
 	)
 	if err != nil {
-		l.Infof("Generating ECDSA key and certificate for %s...", tlsDefaultCommonName)
-		cert, err = tlsutil.NewCertificate(
-			locations.Get(locations.CertFile),
-			locations.Get(locations.KeyFile),
-			tlsDefaultCommonName,
-		)
-		if err != nil {
-			l.Warnln("Failed to generate certificate:", err)
-			os.Exit(1)
-		}
+		l.Warnln("Failed to load/generate certificate:", err)
+		os.Exit(1)
 	}
-	myID := protocol.NewDeviceID(cert.Certificate[0])
 
-	cfg, err := loadConfigAtStartup(runtimeOptions.allowNewerConfig, myID)
+	cfg, err := syncthing.LoadConfigAtStartup(locations.Get(locations.ConfigFile), cert, runtimeOptions.allowNewerConfig, noDefaultFolder)
 	if err != nil {
 		l.Warnln("Failed to initialize config:", err)
 		os.Exit(exitError)
@@ -690,74 +681,12 @@ func loadOrDefaultConfig(myID protocol.DeviceID) (config.Wrapper, error) {
 	cfg, err := config.Load(cfgFile, myID)
 
 	if err != nil {
-		cfg, err = defaultConfig(cfgFile, myID)
+		cfg, err = syncthing.DefaultConfig(cfgFile, myID, noDefaultFolder)
 	}
 
 	return cfg, err
 }
 
-func loadConfigAtStartup(allowNewerConfig bool, myID protocol.DeviceID) (config.Wrapper, error) {
-	cfgFile := locations.Get(locations.ConfigFile)
-	cfg, err := config.Load(cfgFile, myID)
-	if os.IsNotExist(err) {
-		cfg, err = defaultConfig(cfgFile, myID)
-		if err != nil {
-			return nil, errors.Wrap(err, "failed to generate default config")
-		}
-		err = cfg.Save()
-		if err != nil {
-			return nil, errors.Wrap(err, "failed to save default config")
-		}
-		l.Infof("Default config saved. Edit %s to taste (with Syncthing stopped) or use the GUI", cfg.ConfigPath())
-	} else if err == io.EOF {
-		return nil, errors.New("Failed to load config: unexpected end of file. Truncated or empty configuration?")
-	} else if err != nil {
-		return nil, errors.Wrap(err, "failed to load config")
-	}
-
-	if cfg.RawCopy().OriginalVersion != config.CurrentVersion {
-		if cfg.RawCopy().OriginalVersion == config.CurrentVersion+1101 {
-			l.Infof("Now, THAT's what we call a config from the future! Don't worry. As long as you hit that wire with the connecting hook at precisely eighty-eight miles per hour the instant the lightning strikes the tower... everything will be fine.")
-		}
-		if cfg.RawCopy().OriginalVersion > config.CurrentVersion && !allowNewerConfig {
-			return nil, fmt.Errorf("Config file version (%d) is newer than supported version (%d). If this is expected, use -allow-newer-config to override.", cfg.RawCopy().OriginalVersion, config.CurrentVersion)
-		}
-		err = archiveAndSaveConfig(cfg)
-		if err != nil {
-			return nil, errors.Wrap(err, "config archive")
-		}
-	}
-
-	return cfg, nil
-}
-
-func archiveAndSaveConfig(cfg config.Wrapper) error {
-	// Copy the existing config to an archive copy
-	archivePath := cfg.ConfigPath() + fmt.Sprintf(".v%d", cfg.RawCopy().OriginalVersion)
-	l.Infoln("Archiving a copy of old config file format at:", archivePath)
-	if err := copyFile(cfg.ConfigPath(), archivePath); err != nil {
-		return err
-	}
-
-	// Do a regular atomic config sve
-	return cfg.Save()
-}
-
-func copyFile(src, dst string) error {
-	bs, err := ioutil.ReadFile(src)
-	if err != nil {
-		return err
-	}
-
-	if err := ioutil.WriteFile(dst, bs, 0600); err != nil {
-		// Attempt to clean up
-		os.Remove(dst)
-		return err
-	}
-
-	return nil
-}
-
 func auditWriter(auditFile string) io.Writer {
 	var fd io.Writer
 	var err error
@@ -790,22 +719,6 @@ func auditWriter(auditFile string) io.Writer {
 	return fd
 }
 
-func defaultConfig(cfgFile string, myID protocol.DeviceID) (config.Wrapper, error) {
-	newCfg, err := config.NewWithFreePorts(myID)
-	if err != nil {
-		return nil, err
-	}
-
-	if noDefaultFolder {
-		l.Infoln("We will skip creation of a default folder on first start since the proper envvar is set")
-		return config.Wrap(cfgFile, newCfg), nil
-	}
-
-	newCfg.Folders = append(newCfg.Folders, config.NewFolderConfiguration(myID, "default", "Default Folder", fs.FilesystemTypeBasic, locations.Get(locations.DefFolder)))
-	l.Infoln("Default folder created and/or linked to new config")
-	return config.Wrap(cfgFile, newCfg), nil
-}
-
 func resetDB() error {
 	return os.RemoveAll(locations.Get(locations.Database))
 }

+ 2 - 14
lib/syncthing/syncthing.go

@@ -16,6 +16,8 @@ import (
 	"sync"
 	"time"
 
+	"github.com/thejerf/suture"
+
 	"github.com/syncthing/syncthing/lib/api"
 	"github.com/syncthing/syncthing/lib/build"
 	"github.com/syncthing/syncthing/lib/config"
@@ -32,8 +34,6 @@ import (
 	"github.com/syncthing/syncthing/lib/sha256"
 	"github.com/syncthing/syncthing/lib/tlsutil"
 	"github.com/syncthing/syncthing/lib/ur"
-
-	"github.com/thejerf/suture"
 )
 
 const (
@@ -473,15 +473,3 @@ func (e *controller) Shutdown() {
 func (e *controller) ExitUpgrading() {
 	e.Stop(ExitUpgrade)
 }
-
-func LoadCertificate(certFile, keyFile string) (tls.Certificate, error) {
-	return tls.LoadX509KeyPair(certFile, keyFile)
-}
-
-func LoadConfig(path string, cert tls.Certificate) (config.Wrapper, error) {
-	return config.Load(path, protocol.NewDeviceID(cert.Certificate[0]))
-}
-
-func OpenGoleveldb(path string) (*db.Lowlevel, error) {
-	return db.Open(path)
-}

+ 126 - 0
lib/syncthing/utils.go

@@ -0,0 +1,126 @@
+// Copyright (C) 2014 The Syncthing Authors.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at https://mozilla.org/MPL/2.0/.
+
+package syncthing
+
+import (
+	"crypto/tls"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"os"
+
+	"github.com/pkg/errors"
+
+	"github.com/syncthing/syncthing/lib/config"
+	"github.com/syncthing/syncthing/lib/db"
+	"github.com/syncthing/syncthing/lib/fs"
+	"github.com/syncthing/syncthing/lib/locations"
+	"github.com/syncthing/syncthing/lib/protocol"
+	"github.com/syncthing/syncthing/lib/tlsutil"
+)
+
+func LoadOrGenerateCertificate(certFile, keyFile string) (tls.Certificate, error) {
+	cert, err := tls.LoadX509KeyPair(
+		locations.Get(locations.CertFile),
+		locations.Get(locations.KeyFile),
+	)
+	if err != nil {
+		l.Infof("Generating ECDSA key and certificate for %s...", tlsDefaultCommonName)
+		return tlsutil.NewCertificate(
+			locations.Get(locations.CertFile),
+			locations.Get(locations.KeyFile),
+			tlsDefaultCommonName,
+		)
+	}
+	return cert, nil
+}
+
+func DefaultConfig(path string, myID protocol.DeviceID, noDefaultFolder bool) (config.Wrapper, error) {
+	newCfg, err := config.NewWithFreePorts(myID)
+	if err != nil {
+		return nil, err
+	}
+
+	if noDefaultFolder {
+		l.Infoln("We will skip creation of a default folder on first start")
+		return config.Wrap(path, newCfg), nil
+	}
+
+	newCfg.Folders = append(newCfg.Folders, config.NewFolderConfiguration(myID, "default", "Default Folder", fs.FilesystemTypeBasic, locations.Get(locations.DefFolder)))
+	l.Infoln("Default folder created and/or linked to new config")
+	return config.Wrap(path, newCfg), nil
+}
+
+// LoadConfigAtStartup loads an existing config. If it doesn't yet exist, it
+// creates a default one, without the default folder if noDefaultFolder is ture.
+// Otherwise it checks the version, and archives and upgrades the config if
+// necessary or returns an error, if the version isn't compatible.
+func LoadConfigAtStartup(path string, cert tls.Certificate, allowNewerConfig, noDefaultFolder bool) (config.Wrapper, error) {
+	myID := protocol.NewDeviceID(cert.Certificate[0])
+	cfg, err := config.Load(path, myID)
+	if fs.IsNotExist(err) {
+		cfg, err = DefaultConfig(path, myID, noDefaultFolder)
+		if err != nil {
+			return nil, errors.Wrap(err, "failed to generate default config")
+		}
+		err = cfg.Save()
+		if err != nil {
+			return nil, errors.Wrap(err, "failed to save default config")
+		}
+		l.Infof("Default config saved. Edit %s to taste (with Syncthing stopped) or use the GUI", cfg.ConfigPath())
+	} else if err == io.EOF {
+		return nil, errors.New("failed to load config: unexpected end of file. Truncated or empty configuration?")
+	} else if err != nil {
+		return nil, errors.Wrap(err, "failed to load config")
+	}
+
+	if cfg.RawCopy().OriginalVersion != config.CurrentVersion {
+		if cfg.RawCopy().OriginalVersion == config.CurrentVersion+1101 {
+			l.Infof("Now, THAT's what we call a config from the future! Don't worry. As long as you hit that wire with the connecting hook at precisely eighty-eight miles per hour the instant the lightning strikes the tower... everything will be fine.")
+		}
+		if cfg.RawCopy().OriginalVersion > config.CurrentVersion && !allowNewerConfig {
+			return nil, fmt.Errorf("config file version (%d) is newer than supported version (%d). If this is expected, use -allow-newer-config to override.", cfg.RawCopy().OriginalVersion, config.CurrentVersion)
+		}
+		err = archiveAndSaveConfig(cfg)
+		if err != nil {
+			return nil, errors.Wrap(err, "config archive")
+		}
+	}
+
+	return cfg, nil
+}
+
+func archiveAndSaveConfig(cfg config.Wrapper) error {
+	// Copy the existing config to an archive copy
+	archivePath := cfg.ConfigPath() + fmt.Sprintf(".v%d", cfg.RawCopy().OriginalVersion)
+	l.Infoln("Archiving a copy of old config file format at:", archivePath)
+	if err := copyFile(cfg.ConfigPath(), archivePath); err != nil {
+		return err
+	}
+
+	// Do a regular atomic config sve
+	return cfg.Save()
+}
+
+func copyFile(src, dst string) error {
+	bs, err := ioutil.ReadFile(src)
+	if err != nil {
+		return err
+	}
+
+	if err := ioutil.WriteFile(dst, bs, 0600); err != nil {
+		// Attempt to clean up
+		os.Remove(dst)
+		return err
+	}
+
+	return nil
+}
+
+func OpenGoleveldb(path string) (*db.Lowlevel, error) {
+	return db.Open(path)
+}