瀏覽代碼

Merge pull request #2333 from calmh/brokenupgrade

Remove global cfg variable (fixes #2294)
Audrius Butkevicius 10 年之前
父節點
當前提交
6f6c1cd330
共有 4 個文件被更改,包括 101 次插入76 次删除
  1. 24 23
      cmd/syncthing/gui.go
  2. 53 41
      cmd/syncthing/main.go
  3. 6 3
      cmd/syncthing/summarysvc.go
  4. 18 9
      cmd/syncthing/usage_report.go

+ 24 - 23
cmd/syncthing/gui.go

@@ -55,7 +55,7 @@ var (
 
 type apiSvc struct {
 	id              protocol.DeviceID
-	cfg             config.GUIConfiguration
+	cfg             *config.Wrapper
 	assetDir        string
 	model           *model.Model
 	eventSub        *events.BufferedSubscription
@@ -67,7 +67,7 @@ type apiSvc struct {
 	systemConfigMut sync.Mutex
 }
 
-func newAPISvc(id protocol.DeviceID, cfg config.GUIConfiguration, assetDir string, m *model.Model, eventSub *events.BufferedSubscription, discoverer *discover.CachingMux, relaySvc *relay.Svc) (*apiSvc, error) {
+func newAPISvc(id protocol.DeviceID, cfg *config.Wrapper, assetDir string, m *model.Model, eventSub *events.BufferedSubscription, discoverer *discover.CachingMux, relaySvc *relay.Svc) (*apiSvc, error) {
 	svc := &apiSvc{
 		id:              id,
 		cfg:             cfg,
@@ -80,7 +80,7 @@ func newAPISvc(id protocol.DeviceID, cfg config.GUIConfiguration, assetDir strin
 	}
 
 	var err error
-	svc.listener, err = svc.getListener(cfg)
+	svc.listener, err = svc.getListener(cfg.GUI())
 	return svc, err
 }
 
@@ -195,20 +195,22 @@ func (s *apiSvc) Serve() {
 		assets:   auto.Assets(),
 	})
 
+	guiCfg := s.cfg.GUI()
+
 	// Wrap everything in CSRF protection. The /rest prefix should be
 	// protected, other requests will grant cookies.
-	handler := csrfMiddleware(s.id.String()[:5], "/rest", s.cfg.APIKey, mux)
+	handler := csrfMiddleware(s.id.String()[:5], "/rest", guiCfg.APIKey, mux)
 
 	// Add our version and ID as a header to responses
 	handler = withDetailsMiddleware(s.id, handler)
 
 	// Wrap everything in basic auth, if user/password is set.
-	if len(s.cfg.User) > 0 && len(s.cfg.Password) > 0 {
-		handler = basicAuthAndSessionMiddleware("sessionid-"+s.id.String()[:5], s.cfg, handler)
+	if len(guiCfg.User) > 0 && len(guiCfg.Password) > 0 {
+		handler = basicAuthAndSessionMiddleware("sessionid-"+s.id.String()[:5], guiCfg, handler)
 	}
 
 	// Redirect to HTTPS if we are supposed to
-	if s.cfg.UseTLS {
+	if guiCfg.UseTLS {
 		handler = redirectToHTTPSMiddleware(handler)
 	}
 
@@ -221,7 +223,7 @@ func (s *apiSvc) Serve() {
 		ReadTimeout: 10 * time.Second,
 	}
 
-	s.fss = newFolderSummarySvc(s.model)
+	s.fss = newFolderSummarySvc(s.cfg, s.model)
 	defer s.fss.Stop()
 	s.fss.ServeBackground()
 
@@ -273,7 +275,6 @@ func (s *apiSvc) CommitConfiguration(from, to config.Configuration) bool {
 		// method.
 		return false
 	}
-	s.cfg = to.GUI
 
 	close(s.stop)
 
@@ -409,12 +410,12 @@ func (s *apiSvc) getDBCompletion(w http.ResponseWriter, r *http.Request) {
 func (s *apiSvc) getDBStatus(w http.ResponseWriter, r *http.Request) {
 	qs := r.URL.Query()
 	folder := qs.Get("folder")
-	res := folderSummary(s.model, folder)
+	res := folderSummary(s.cfg, s.model, folder)
 	w.Header().Set("Content-Type", "application/json; charset=utf-8")
 	json.NewEncoder(w).Encode(res)
 }
 
-func folderSummary(m *model.Model, folder string) map[string]interface{} {
+func folderSummary(cfg *config.Wrapper, m *model.Model, folder string) map[string]interface{} {
 	var res = make(map[string]interface{})
 
 	res["invalid"] = cfg.Folders()[folder].Invalid
@@ -524,7 +525,7 @@ func (s *apiSvc) getDBFile(w http.ResponseWriter, r *http.Request) {
 
 func (s *apiSvc) getSystemConfig(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json; charset=utf-8")
-	json.NewEncoder(w).Encode(cfg.Raw())
+	json.NewEncoder(w).Encode(s.cfg.Raw())
 }
 
 func (s *apiSvc) postSystemConfig(w http.ResponseWriter, r *http.Request) {
@@ -539,7 +540,7 @@ func (s *apiSvc) postSystemConfig(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	if to.GUI.Password != cfg.GUI().Password {
+	if to.GUI.Password != s.cfg.GUI().Password {
 		if to.GUI.Password != "" {
 			hash, err := bcrypt.GenerateFromPassword([]byte(to.GUI.Password), 0)
 			if err != nil {
@@ -554,7 +555,7 @@ func (s *apiSvc) postSystemConfig(w http.ResponseWriter, r *http.Request) {
 
 	// Fixup usage reporting settings
 
-	if curAcc := cfg.Options().URAccepted; to.Options.URAccepted > curAcc {
+	if curAcc := s.cfg.Options().URAccepted; to.Options.URAccepted > curAcc {
 		// UR was enabled
 		to.Options.URAccepted = usageReportVersion
 		to.Options.URUniqueID = randomString(8)
@@ -566,9 +567,9 @@ func (s *apiSvc) postSystemConfig(w http.ResponseWriter, r *http.Request) {
 
 	// Activate and save
 
-	resp := cfg.Replace(to)
+	resp := s.cfg.Replace(to)
 	configInSync = !resp.RequiresRestart
-	cfg.Save()
+	s.cfg.Save()
 }
 
 func (s *apiSvc) getSystemConfigInsync(w http.ResponseWriter, r *http.Request) {
@@ -586,7 +587,7 @@ func (s *apiSvc) postSystemReset(w http.ResponseWriter, r *http.Request) {
 	folder := qs.Get("folder")
 
 	if len(folder) > 0 {
-		if _, ok := cfg.Folders()[folder]; !ok {
+		if _, ok := s.cfg.Folders()[folder]; !ok {
 			http.Error(w, "Invalid folder ID", 500)
 			return
 		}
@@ -594,7 +595,7 @@ func (s *apiSvc) postSystemReset(w http.ResponseWriter, r *http.Request) {
 
 	if len(folder) == 0 {
 		// Reset all folders.
-		for folder := range cfg.Folders() {
+		for folder := range s.cfg.Folders() {
 			s.model.ResetFolder(folder)
 		}
 		s.flushResponse(`{"ok": "resetting database"}`, w)
@@ -632,7 +633,7 @@ func (s *apiSvc) getSystemStatus(w http.ResponseWriter, r *http.Request) {
 	res["alloc"] = m.Alloc
 	res["sys"] = m.Sys - m.HeapReleased
 	res["tilde"] = tilde
-	if cfg.Options().LocalAnnEnabled || cfg.Options().GlobalAnnEnabled {
+	if s.cfg.Options().LocalAnnEnabled || s.cfg.Options().GlobalAnnEnabled {
 		res["discoveryEnabled"] = true
 		discoErrors := make(map[string]string)
 		discoMethods := 0
@@ -718,7 +719,7 @@ func (s *apiSvc) getSystemDiscovery(w http.ResponseWriter, r *http.Request) {
 
 func (s *apiSvc) getReport(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json; charset=utf-8")
-	json.NewEncoder(w).Encode(reportData(s.model))
+	json.NewEncoder(w).Encode(reportData(s.cfg, s.model))
 }
 
 func (s *apiSvc) getDBIgnores(w http.ResponseWriter, r *http.Request) {
@@ -787,7 +788,7 @@ func (s *apiSvc) getSystemUpgrade(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, upgrade.ErrUpgradeUnsupported.Error(), 500)
 		return
 	}
-	rel, err := upgrade.LatestRelease(cfg.Options().ReleasesURL, Version)
+	rel, err := upgrade.LatestRelease(s.cfg.Options().ReleasesURL, Version)
 	if err != nil {
 		http.Error(w, err.Error(), 500)
 		return
@@ -830,7 +831,7 @@ func (s *apiSvc) getLang(w http.ResponseWriter, r *http.Request) {
 }
 
 func (s *apiSvc) postSystemUpgrade(w http.ResponseWriter, r *http.Request) {
-	rel, err := upgrade.LatestRelease(cfg.Options().ReleasesURL, Version)
+	rel, err := upgrade.LatestRelease(s.cfg.Options().ReleasesURL, Version)
 	if err != nil {
 		l.Warnln("getting latest release:", err)
 		http.Error(w, err.Error(), 500)
@@ -928,7 +929,7 @@ func (s *apiSvc) getPeerCompletion(w http.ResponseWriter, r *http.Request) {
 	tot := map[string]float64{}
 	count := map[string]float64{}
 
-	for _, folder := range cfg.Folders() {
+	for _, folder := range s.cfg.Folders() {
 		for _, device := range folder.DeviceIDs() {
 			deviceStr := device.String()
 			if s.model.ConnectedTo(device) {

+ 53 - 41
cmd/syncthing/main.go

@@ -26,7 +26,6 @@ import (
 	"time"
 
 	"github.com/calmh/logger"
-	"github.com/juju/ratelimit"
 	"github.com/syncthing/syncthing/lib/config"
 	"github.com/syncthing/syncthing/lib/connections"
 	"github.com/syncthing/syncthing/lib/db"
@@ -108,15 +107,12 @@ func init() {
 }
 
 var (
-	cfg            *config.Wrapper
-	myID           protocol.DeviceID
-	confDir        string
-	logFlags       = log.Ltime
-	writeRateLimit *ratelimit.Bucket
-	readRateLimit  *ratelimit.Bucket
-	stop           = make(chan int)
-	cert           tls.Certificate
-	lans           []*net.IPNet
+	myID     protocol.DeviceID
+	confDir  string
+	logFlags = log.Ltime
+	stop     = make(chan int)
+	cert     tls.Certificate
+	lans     []*net.IPNet
 )
 
 const (
@@ -346,7 +342,11 @@ func main() {
 	}
 
 	if doUpgrade || doUpgradeCheck {
-		rel, err := upgrade.LatestRelease(cfg.Options().ReleasesURL, Version)
+		releasesURL := "https://api.github.com/repos/syncthing/syncthing/releases?per_page=30"
+		if cfg, _, err := loadConfig(locations[locConfigFile]); err == nil {
+			releasesURL = cfg.Options().ReleasesURL
+		}
+		rel, err := upgrade.LatestRelease(releasesURL, Version)
 		if err != nil {
 			l.Fatalln("Upgrade:", err) // exits 1
 		}
@@ -497,33 +497,21 @@ func syncthingMain() {
 
 	cfgFile := locations[locConfigFile]
 
-	var myName string
-
 	// Load the configuration file, if it exists.
 	// If it does not, create a template.
 
-	if info, err := os.Stat(cfgFile); err == nil {
-		if !info.Mode().IsRegular() {
-			l.Fatalln("Config file is not a file?")
-		}
-		cfg, err = config.Load(cfgFile, myID)
-		if err == nil {
-			myCfg := cfg.Devices()[myID]
-			if myCfg.Name == "" {
-				myName, _ = os.Hostname()
-			} else {
-				myName = myCfg.Name
-			}
+	cfg, myName, err := loadConfig(cfgFile)
+	if err != nil {
+		if os.IsNotExist(err) {
+			l.Infoln("No config file; starting with empty defaults")
+			myName, _ = os.Hostname()
+			newCfg := defaultConfig(myName)
+			cfg = config.Wrap(cfgFile, newCfg)
+			cfg.Save()
+			l.Infof("Edit %s to taste or use the GUI\n", cfgFile)
 		} else {
-			l.Fatalln("Configuration:", err)
+			l.Fatalln("Loading config:", err)
 		}
-	} else {
-		l.Infoln("No config file; starting with empty defaults")
-		myName, _ = os.Hostname()
-		newCfg := defaultConfig(myName)
-		cfg = config.Wrap(cfgFile, newCfg)
-		cfg.Save()
-		l.Infof("Edit %s to taste or use the GUI\n", cfgFile)
 	}
 
 	if cfg.Raw().OriginalVersion != config.CurrentVersion {
@@ -596,9 +584,10 @@ func syncthingMain() {
 	}
 
 	dbFile := locations[locDatabase]
-	ldb, err := leveldb.OpenFile(dbFile, dbOpts())
+	dbOpts := dbOpts(cfg)
+	ldb, err := leveldb.OpenFile(dbFile, dbOpts)
 	if leveldbIsCorrupted(err) {
-		ldb, err = leveldb.RecoverFile(dbFile, dbOpts())
+		ldb, err = leveldb.RecoverFile(dbFile, dbOpts)
 	}
 	if leveldbIsCorrupted(err) {
 		// The database is corrupted, and we've tried to recover it but it
@@ -608,7 +597,7 @@ func syncthingMain() {
 		if err := resetDB(); err != nil {
 			l.Fatalln("Remove database:", err)
 		}
-		ldb, err = leveldb.OpenFile(dbFile, dbOpts())
+		ldb, err = leveldb.OpenFile(dbFile, dbOpts)
 	}
 	if err != nil {
 		l.Fatalln("Cannot open database:", err, "- Is another copy of Syncthing already running?")
@@ -780,7 +769,7 @@ func syncthingMain() {
 	// The usageReportingManager registers itself to listen to configuration
 	// changes, and there's nothing more we need to tell it from the outside.
 	// Hence we don't keep the returned pointer.
-	newUsageReportingManager(m, cfg)
+	newUsageReportingManager(cfg, m)
 
 	if opts.RestartOnWakeup {
 		go standbyMonitor()
@@ -790,7 +779,7 @@ func syncthingMain() {
 		if noUpgrade {
 			l.Infof("No automatic upgrades; STNOUPGRADE environment variable defined.")
 		} else if IsRelease {
-			go autoUpgrade()
+			go autoUpgrade(cfg)
 		} else {
 			l.Infof("No automatic upgrades; %s is not a release version.", Version)
 		}
@@ -816,7 +805,30 @@ func syncthingMain() {
 	os.Exit(code)
 }
 
-func dbOpts() *opt.Options {
+func loadConfig(cfgFile string) (*config.Wrapper, string, error) {
+	info, err := os.Stat(cfgFile)
+	if err != nil {
+		return nil, "", err
+	}
+	if !info.Mode().IsRegular() {
+		return nil, "", errors.New("configuration is not a file")
+	}
+
+	cfg, err := config.Load(cfgFile, myID)
+	if err != nil {
+		return nil, "", err
+	}
+
+	myCfg := cfg.Devices()[myID]
+	myName := myCfg.Name
+	if myName == "" {
+		myName, _ = os.Hostname()
+	}
+
+	return cfg, myName, nil
+}
+
+func dbOpts(cfg *config.Wrapper) *opt.Options {
 	// Calculate a suitable database block cache capacity.
 
 	// Default is 8 MiB.
@@ -896,7 +908,7 @@ func setupGUI(mainSvc *suture.Supervisor, cfg *config.Wrapper, m *model.Model, a
 
 			urlShow := fmt.Sprintf("%s://%s/", proto, net.JoinHostPort(hostShow, strconv.Itoa(addr.Port)))
 			l.Infoln("Starting web GUI on", urlShow)
-			api, err := newAPISvc(myID, guiCfg, guiAssets, m, apiSub, discoverer, relaySvc)
+			api, err := newAPISvc(myID, cfg, guiAssets, m, apiSub, discoverer, relaySvc)
 			if err != nil {
 				l.Fatalln("Cannot start GUI:", err)
 			}
@@ -1066,7 +1078,7 @@ func standbyMonitor() {
 	}
 }
 
-func autoUpgrade() {
+func autoUpgrade(cfg *config.Wrapper) {
 	timer := time.NewTimer(0)
 	sub := events.Default.Subscribe(events.DeviceConnected)
 	for {

+ 6 - 3
cmd/syncthing/summarysvc.go

@@ -9,6 +9,7 @@ package main
 import (
 	"time"
 
+	"github.com/syncthing/syncthing/lib/config"
 	"github.com/syncthing/syncthing/lib/events"
 	"github.com/syncthing/syncthing/lib/model"
 	"github.com/syncthing/syncthing/lib/sync"
@@ -20,6 +21,7 @@ import (
 type folderSummarySvc struct {
 	*suture.Supervisor
 
+	cfg       *config.Wrapper
 	model     *model.Model
 	stop      chan struct{}
 	immediate chan string
@@ -33,9 +35,10 @@ type folderSummarySvc struct {
 	lastEventReqMut sync.Mutex
 }
 
-func newFolderSummarySvc(m *model.Model) *folderSummarySvc {
+func newFolderSummarySvc(cfg *config.Wrapper, m *model.Model) *folderSummarySvc {
 	svc := &folderSummarySvc{
 		Supervisor:      suture.NewSimple("folderSummarySvc"),
+		cfg:             cfg,
 		model:           m,
 		stop:            make(chan struct{}),
 		immediate:       make(chan string),
@@ -162,13 +165,13 @@ func (c *folderSummarySvc) foldersToHandle() []string {
 func (c *folderSummarySvc) sendSummary(folder string) {
 	// The folder summary contains how many bytes, files etc
 	// are in the folder and how in sync we are.
-	data := folderSummary(c.model, folder)
+	data := folderSummary(c.cfg, c.model, folder)
 	events.Default.Log(events.FolderSummary, map[string]interface{}{
 		"folder":  folder,
 		"summary": data,
 	})
 
-	for _, devCfg := range cfg.Folders()[folder].Devices {
+	for _, devCfg := range c.cfg.Folders()[folder].Devices {
 		if devCfg.DeviceID.Equals(myID) {
 			// We already know about ourselves.
 			continue

+ 18 - 9
cmd/syncthing/usage_report.go

@@ -32,12 +32,14 @@ import (
 const usageReportVersion = 2
 
 type usageReportingManager struct {
+	cfg   *config.Wrapper
 	model *model.Model
 	sup   *suture.Supervisor
 }
 
-func newUsageReportingManager(m *model.Model, cfg *config.Wrapper) *usageReportingManager {
+func newUsageReportingManager(cfg *config.Wrapper, m *model.Model) *usageReportingManager {
 	mgr := &usageReportingManager{
+		cfg:   cfg,
 		model: m,
 	}
 
@@ -58,9 +60,7 @@ func (m *usageReportingManager) VerifyConfiguration(from, to config.Configuratio
 func (m *usageReportingManager) CommitConfiguration(from, to config.Configuration) bool {
 	if to.Options.URAccepted >= usageReportVersion && m.sup == nil {
 		// Usage reporting was turned on; lets start it.
-		svc := &usageReportingService{
-			model: m.model,
-		}
+		svc := newUsageReportingService(m.cfg, m.model)
 		m.sup = suture.NewSimple("usageReporting")
 		m.sup.Add(svc)
 		m.sup.ServeBackground()
@@ -79,7 +79,7 @@ func (m *usageReportingManager) String() string {
 
 // reportData returns the data to be sent in a usage report. It's used in
 // various places, so not part of the usageReportingSvc object.
-func reportData(m *model.Model) map[string]interface{} {
+func reportData(cfg *config.Wrapper, m *model.Model) map[string]interface{} {
 	res := make(map[string]interface{})
 	res["urVersion"] = usageReportVersion
 	res["uniqueID"] = cfg.Options().URUniqueID
@@ -238,12 +238,21 @@ func stringIn(needle string, haystack []string) bool {
 }
 
 type usageReportingService struct {
+	cfg   *config.Wrapper
 	model *model.Model
 	stop  chan struct{}
 }
 
+func newUsageReportingService(cfg *config.Wrapper, model *model.Model) *usageReportingService {
+	return &usageReportingService{
+		cfg:   cfg,
+		model: model,
+		stop:  make(chan struct{}),
+	}
+}
+
 func (s *usageReportingService) sendUsageReport() error {
-	d := reportData(s.model)
+	d := reportData(s.cfg, s.model)
 	var b bytes.Buffer
 	json.NewEncoder(&b).Encode(d)
 
@@ -256,12 +265,12 @@ func (s *usageReportingService) sendUsageReport() error {
 		}
 	}
 
-	if cfg.Options().URPostInsecurely {
+	if s.cfg.Options().URPostInsecurely {
 		transp.TLSClientConfig = &tls.Config{
 			InsecureSkipVerify: true,
 		}
 	}
-	_, err := client.Post(cfg.Options().URURL, "application/json", &b)
+	_, err := client.Post(s.cfg.Options().URURL, "application/json", &b)
 	return err
 }
 
@@ -271,7 +280,7 @@ func (s *usageReportingService) Serve() {
 	l.Infoln("Starting usage reporting")
 	defer l.Infoln("Stopping usage reporting")
 
-	t := time.NewTimer(time.Duration(cfg.Options().URInitialDelayS) * time.Second) // time to initial report at start
+	t := time.NewTimer(time.Duration(s.cfg.Options().URInitialDelayS) * time.Second) // time to initial report at start
 	for {
 		select {
 		case <-s.stop: