1
0
Эх сурвалжийг харах

Use a lock port to ensure parent has exited (fixes #450)

Jakob Borg 11 жил өмнө
parent
commit
28220310a5
1 өөрчлөгдсөн 35 нэмэгдсэн , 18 устгасан
  1. 35 18
      cmd/syncthing/main.go

+ 35 - 18
cmd/syncthing/main.go

@@ -79,6 +79,8 @@ var (
 	rateBucket *ratelimit.Bucket
 	stop       = make(chan bool)
 	discoverer *discover.Discoverer
+	lockConn   *net.TCPListener
+	lockPort   int
 )
 
 const (
@@ -149,6 +151,12 @@ func main() {
 
 	l.SetFlags(logFlags)
 
+	var err error
+	lockPort, err = getLockPort()
+	if err != nil {
+		l.Fatalln("Opening lock port:", err)
+	}
+
 	if doUpgrade || doUpgradeCheck {
 		rel, err := upgrade.LatestRelease(strings.Contains(Version, "-beta"))
 		if err != nil {
@@ -278,6 +286,10 @@ func main() {
 		return
 	}
 
+	if len(os.Getenv("STRESTART")) > 0 {
+		waitForParentExit()
+	}
+
 	if profiler := os.Getenv("STPROFILER"); len(profiler) > 0 {
 		go func() {
 			l.Debugln("Starting profiler on", profiler)
@@ -289,10 +301,6 @@ func main() {
 		}()
 	}
 
-	if len(os.Getenv("STRESTART")) > 0 {
-		waitForParentExit()
-	}
-
 	// The TLS configuration is used for both the listening socket and outgoing
 	// connections.
 
@@ -497,16 +505,21 @@ func generateEvents() {
 
 func waitForParentExit() {
 	l.Infoln("Waiting for parent to exit...")
+	lockPortStr := os.Getenv("STRESTART")
+	lockPort, err := strconv.Atoi(lockPortStr)
+	if err != nil {
+		l.Warnln("Invalid lock port %q: %v", lockPortStr, err)
+	}
 	// Wait for the listen address to become free, indicating that the parent has exited.
 	for {
-		ln, err := net.Listen("tcp", cfg.Options.ListenAddress[0])
+		ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", lockPort))
 		if err == nil {
 			ln.Close()
 			break
 		}
 		time.Sleep(250 * time.Millisecond)
 	}
-	l.Okln("Continuing")
+	l.Infoln("Continuing")
 }
 
 func setupUPnP(r rand.Source) int {
@@ -604,16 +617,21 @@ func restart() {
 	}
 
 	env := os.Environ()
-	if len(os.Getenv("STRESTART")) == 0 {
-		env = append(env, "STRESTART=1")
+	newEnv := make([]string, 0, len(env))
+	for _, s := range env {
+		if !strings.HasPrefix(s, "STRESTART=") {
+			newEnv = append(newEnv, s)
+		}
 	}
+	newEnv = append(newEnv, fmt.Sprintf("STRESTART=%d", lockPort))
+
 	pgm, err := exec.LookPath(os.Args[0])
 	if err != nil {
 		l.Warnln("Cannot restart:", err)
 		return
 	}
 	proc, err := os.StartProcess(pgm, os.Args, &os.ProcAttr{
-		Env:   env,
+		Env:   newEnv,
 		Files: []*os.File{os.Stdin, os.Stdout, os.Stderr},
 	})
 	if err != nil {
@@ -973,18 +991,17 @@ func getFreePort(host string, ports ...int) (int, error) {
 	if err != nil {
 		return 0, err
 	}
-	addr := c.Addr().String()
+	addr := c.Addr().(*net.TCPAddr)
 	c.Close()
+	return addr.Port, nil
+}
 
-	_, portstr, err := net.SplitHostPort(addr)
-	if err != nil {
-		return 0, err
-	}
-
-	port, err := strconv.Atoi(portstr)
+func getLockPort() (int, error) {
+	var err error
+	lockConn, err = net.ListenTCP("tcp", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}})
 	if err != nil {
 		return 0, err
 	}
-
-	return port, nil
+	addr := lockConn.Addr().(*net.TCPAddr)
+	return addr.Port, nil
 }