Browse Source

cmd/containerboot: fix unclean shutdown (#10035)

* cmd/containerboot: shut down cleanly on SIGTERM

Make sure that the tailscaled watcher returns when
SIGTERM is received, and that it shuts down
before tailscaled exits.

Updates tailscale/tailscale#10090

Signed-off-by: Irbe Krumina <[email protected]>
Irbe Krumina 2 years ago
parent
commit
664ebb14d9
1 changed file with 125 additions and 76 deletions
  1. 125 76
      cmd/containerboot/main.go

+ 125 - 76
cmd/containerboot/main.go

@@ -69,6 +69,7 @@ import (
 	"reflect"
 	"strconv"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"syscall"
 	"time"
@@ -181,10 +182,16 @@ func main() {
 		}
 	}
 
-	client, daemonPid, err := startTailscaled(bootCtx, cfg)
+	client, daemonProcess, err := startTailscaled(bootCtx, cfg)
 	if err != nil {
 		log.Fatalf("failed to bring up tailscale: %v", err)
 	}
+	killTailscaled := func() {
+		if err := daemonProcess.Signal(unix.SIGTERM); err != nil {
+			log.Fatalf("error shutting tailscaled down: %v", err)
+		}
+	}
+	defer killTailscaled()
 
 	w, err := client.WatchIPNBus(bootCtx, ipn.NotifyInitialNetMap|ipn.NotifyInitialPrefs|ipn.NotifyInitialState)
 	if err != nil {
@@ -252,7 +259,7 @@ authLoop:
 
 	w.Close()
 
-	ctx, cancel := context.WithCancel(context.Background()) // no deadline now that we're in steady state
+	ctx, cancel := contextWithExitSignalWatch()
 	defer cancel()
 
 	if cfg.AuthOnce {
@@ -306,84 +313,111 @@ authLoop:
 			log.Fatalf("error creating new netfilter runner: %v", err)
 		}
 	}
+	notifyChan := make(chan ipn.Notify)
+	errChan := make(chan error)
+	go func() {
+		for {
+			n, err := w.Next()
+			if err != nil {
+				errChan <- err
+				break
+			} else {
+				notifyChan <- n
+			}
+		}
+	}()
+	var wg sync.WaitGroup
+runLoop:
 	for {
-		n, err := w.Next()
-		if err != nil {
+		select {
+		case <-ctx.Done():
+			// Although killTailscaled() is deferred earlier, if we
+			// have started the reaper defined below, we need to
+			// kill tailscaled and let reaper clean up child
+			// processes.
+			killTailscaled()
+			break runLoop
+		case err := <-errChan:
 			log.Fatalf("failed to read from tailscaled: %v", err)
-		}
-
-		if n.State != nil && *n.State != ipn.Running {
-			// Something's gone wrong and we've left the authenticated state.
-			// Our container image never recovered gracefully from this, and the
-			// control flow required to make it work now is hard. So, just crash
-			// the container and rely on the container runtime to restart us,
-			// whereupon we'll go through initial auth again.
-			log.Fatalf("tailscaled left running state (now in state %q), exiting", *n.State)
-		}
-		if n.NetMap != nil {
-			addrs := n.NetMap.SelfNode.Addresses().AsSlice()
-			newCurrentIPs := deephash.Hash(&addrs)
-			ipsHaveChanged := newCurrentIPs != currentIPs
-			if cfg.ProxyTo != "" && len(addrs) > 0 && ipsHaveChanged {
-				log.Printf("Installing proxy rules")
-				if err := installIngressForwardingRule(ctx, cfg.ProxyTo, addrs, nfr); err != nil {
-					log.Fatalf("installing ingress proxy rules: %v", err)
-				}
+		case n := <-notifyChan:
+			if n.State != nil && *n.State != ipn.Running {
+				// Something's gone wrong and we've left the authenticated state.
+				// Our container image never recovered gracefully from this, and the
+				// control flow required to make it work now is hard. So, just crash
+				// the container and rely on the container runtime to restart us,
+				// whereupon we'll go through initial auth again.
+				log.Fatalf("tailscaled left running state (now in state %q), exiting", *n.State)
 			}
-			if cfg.ServeConfigPath != "" && len(n.NetMap.DNS.CertDomains) > 0 {
-				cd := n.NetMap.DNS.CertDomains[0]
-				prev := certDomain.Swap(ptr.To(cd))
-				if prev == nil || *prev != cd {
-					select {
-					case certDomainChanged <- true:
-					default:
+			if n.NetMap != nil {
+				addrs := n.NetMap.SelfNode.Addresses().AsSlice()
+				newCurrentIPs := deephash.Hash(&addrs)
+				ipsHaveChanged := newCurrentIPs != currentIPs
+				if cfg.ProxyTo != "" && len(addrs) > 0 && ipsHaveChanged {
+					log.Printf("Installing proxy rules")
+					if err := installIngressForwardingRule(ctx, cfg.ProxyTo, addrs, nfr); err != nil {
+						log.Fatalf("installing ingress proxy rules: %v", err)
 					}
 				}
-			}
-			if cfg.TailnetTargetIP != "" && ipsHaveChanged && len(addrs) > 0 {
-				if err := installEgressForwardingRule(ctx, cfg.TailnetTargetIP, addrs, nfr); err != nil {
-					log.Fatalf("installing egress proxy rules: %v", err)
+				if cfg.ServeConfigPath != "" && len(n.NetMap.DNS.CertDomains) > 0 {
+					cd := n.NetMap.DNS.CertDomains[0]
+					prev := certDomain.Swap(ptr.To(cd))
+					if prev == nil || *prev != cd {
+						select {
+						case certDomainChanged <- true:
+						default:
+						}
+					}
 				}
-			}
-			currentIPs = newCurrentIPs
+				if cfg.TailnetTargetIP != "" && ipsHaveChanged && len(addrs) > 0 {
+					if err := installEgressForwardingRule(ctx, cfg.TailnetTargetIP, addrs, nfr); err != nil {
+						log.Fatalf("installing egress proxy rules: %v", err)
+					}
+				}
+				currentIPs = newCurrentIPs
 
-			deviceInfo := []any{n.NetMap.SelfNode.StableID(), n.NetMap.SelfNode.Name()}
-			if cfg.InKubernetes && cfg.KubernetesCanPatch && cfg.KubeSecret != "" && deephash.Update(&currentDeviceInfo, &deviceInfo) {
-				if err := storeDeviceInfo(ctx, cfg.KubeSecret, n.NetMap.SelfNode.StableID(), n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses().AsSlice()); err != nil {
-					log.Fatalf("storing device ID in kube secret: %v", err)
+				deviceInfo := []any{n.NetMap.SelfNode.StableID(), n.NetMap.SelfNode.Name()}
+				if cfg.InKubernetes && cfg.KubernetesCanPatch && cfg.KubeSecret != "" && deephash.Update(&currentDeviceInfo, &deviceInfo) {
+					if err := storeDeviceInfo(ctx, cfg.KubeSecret, n.NetMap.SelfNode.StableID(), n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses().AsSlice()); err != nil {
+						log.Fatalf("storing device ID in kube secret: %v", err)
+					}
 				}
 			}
-		}
-		if !startupTasksDone {
-			if (!wantProxy || currentIPs != deephash.Sum{}) && (!wantDeviceInfo || currentDeviceInfo != deephash.Sum{}) {
-				// This log message is used in tests to detect when all
-				// post-auth configuration is done.
-				log.Println("Startup complete, waiting for shutdown signal")
-				startupTasksDone = true
-
-				// Reap all processes, since we are PID1 and need to collect zombies. We can
-				// only start doing this once we've stopped shelling out to things
-				// `tailscale up`, otherwise this goroutine can reap the CLI subprocesses
-				// and wedge bringup.
-				go func() {
-					for {
-						var status unix.WaitStatus
-						pid, err := unix.Wait4(-1, &status, 0, nil)
-						if errors.Is(err, unix.EINTR) {
-							continue
-						}
-						if err != nil {
-							log.Fatalf("Waiting for exited processes: %v", err)
-						}
-						if pid == daemonPid {
-							log.Printf("Tailscaled exited")
-							os.Exit(0)
+			if !startupTasksDone {
+				if (!wantProxy || currentIPs != deephash.Sum{}) && (!wantDeviceInfo || currentDeviceInfo != deephash.Sum{}) {
+					// This log message is used in tests to detect when all
+					// post-auth configuration is done.
+					log.Println("Startup complete, waiting for shutdown signal")
+					startupTasksDone = true
+
+					// 		// Reap all processes, since we are PID1 and need to collect zombies. We can
+					// 		// only start doing this once we've stopped shelling out to things
+					// 		// `tailscale up`, otherwise this goroutine can reap the CLI subprocesses
+					// 		// and wedge bringup.
+					reaper := func() {
+						defer wg.Done()
+						for {
+							var status unix.WaitStatus
+							pid, err := unix.Wait4(-1, &status, 0, nil)
+							if errors.Is(err, unix.EINTR) {
+								continue
+							}
+							if err != nil {
+								log.Fatalf("Waiting for exited processes: %v", err)
+							}
+							if pid == daemonProcess.Pid {
+								log.Printf("Tailscaled exited")
+								os.Exit(0)
+							}
 						}
+
 					}
-				}()
+					wg.Add(1)
+					go reaper()
+				}
 			}
 		}
 	}
+	wg.Wait()
 }
 
 // watchServeConfigChanges watches path for changes, and when it sees one, reads
@@ -460,10 +494,8 @@ func readServeConfig(path, certDomain string) (*ipn.ServeConfig, error) {
 	return &sc, nil
 }
 
-func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient, int, error) {
+func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient, *os.Process, error) {
 	args := tailscaledArgs(cfg)
-	sigCh := make(chan os.Signal, 1)
-	signal.Notify(sigCh, unix.SIGTERM, unix.SIGINT)
 	// tailscaled runs without context, since it needs to persist
 	// beyond the startup timeout in ctx.
 	cmd := exec.Command("tailscaled", args...)
@@ -474,13 +506,8 @@ func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient
 	}
 	log.Printf("Starting tailscaled")
 	if err := cmd.Start(); err != nil {
-		return nil, 0, fmt.Errorf("starting tailscaled failed: %v", err)
+		return nil, nil, fmt.Errorf("starting tailscaled failed: %v", err)
 	}
-	go func() {
-		<-sigCh
-		log.Printf("Received SIGTERM from container runtime, shutting down tailscaled")
-		cmd.Process.Signal(unix.SIGTERM)
-	}()
 
 	// Wait for the socket file to appear, otherwise API ops will racily fail.
 	log.Printf("Waiting for tailscaled socket")
@@ -503,7 +530,7 @@ func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient
 		UseSocketOnly: true,
 	}
 
-	return tsClient, cmd.Process.Pid, nil
+	return tsClient, cmd.Process, nil
 }
 
 // tailscaledArgs uses cfg to construct the argv for tailscaled.
@@ -801,3 +828,25 @@ func defaultBool(name string, defVal bool) bool {
 	}
 	return ret
 }
+
// contextWithExitSignalWatch watches for SIGTERM/SIGINT signals. It returns a
// context that gets cancelled when one of those signals is received, and a
// stop function that ends the watch and frees its resources.
//
// The stop function never blocks: it is safe to call it after a signal has
// already been delivered and to call it more than once. Calling it also
// cancels the returned context and unregisters the signal handlers, so the
// default signal behavior may be restored afterwards.
func contextWithExitSignalWatch() (context.Context, func()) {
	// signal.NotifyContext keeps the handlers registered until stop is
	// called, so further SIGTERM/SIGINT deliveries during a graceful shutdown
	// are still absorbed rather than killing the process.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	return ctx, stop
}