浏览代码

fix SIGTERM support to stop/kill stack

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof 2 年之前
父节点
当前提交
dd0803dba1
共有 2 个文件被更改,包括 53 次插入49 次删除
  1. 2 0
      cmd/compose/compose.go
  2. 51 49
      pkg/compose/up.go

+ 2 - 0
cmd/compose/compose.go

@@ -83,6 +83,8 @@ func AdaptCmd(fn CobraCommand) func(cmd *cobra.Command, args []string) error {
 			go func() {
 				<-s
 				cancel()
+				signal.Stop(s)
+				close(s)
 			}()
 		}
 		err := fn(ctx, cmd, args)

+ 51 - 49
pkg/compose/up.go

@@ -21,7 +21,6 @@ import (
 	"fmt"
 	"os"
 	"os/signal"
-	"sync"
 	"syscall"
 
 	"github.com/compose-spec/compose-go/types"
@@ -55,76 +54,79 @@ func (s *composeService) Up(ctx context.Context, project *types.Project, options
 		return err
 	}
 
+	var eg multierror.Group
+
 	// if we get a second signal during shutdown, we kill the services
 	// immediately, so the channel needs to have sufficient capacity or
 	// we might miss a signal while setting up the second channel read
 	// (this is also why signal.Notify is used vs signal.NotifyContext)
 	signalChan := make(chan os.Signal, 2)
 	signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM)
-	signalCancel := sync.OnceFunc(func() {
-		signal.Stop(signalChan)
-		close(signalChan)
-	})
-	defer signalCancel()
-
-	printer := newLogPrinter(options.Start.Attach)
-	stopFunc := func() error {
-		fmt.Fprintln(s.stdinfo(), "Aborting on container exit...")
-		ctx := context.Background()
-		return progress.Run(ctx, func(ctx context.Context) error {
-			// race two goroutines - one that blocks until another signal is received
-			// and then does a Kill() and one that immediately starts a friendly Stop()
-			errCh := make(chan error, 1)
-			go func() {
-				if _, ok := <-signalChan; !ok {
-					// channel closed, so the outer function is done, which
-					// means the other goroutine (calling Stop()) finished
-					return
-				}
-				errCh <- s.Kill(ctx, project.Name, api.KillOptions{
-					Services: options.Create.Services,
-					Project:  project,
-				})
-			}()
-
-			go func() {
-				errCh <- s.Stop(ctx, project.Name, api.StopOptions{
-					Services: options.Create.Services,
-					Project:  project,
-				})
-			}()
-			return <-errCh
-		}, s.stdinfo())
-	}
-
+	defer close(signalChan)
 	var isTerminated bool
-	var eg multierror.Group
+
+	doneCh := make(chan bool)
 	eg.Go(func() error {
-		if _, ok := <-signalChan; !ok {
-			// function finished without receiving a signal
-			return nil
+		first := true
+		for {
+			select {
+			case <-doneCh:
+				return nil
+			case <-signalChan:
+				if first {
+					fmt.Fprintln(s.stdinfo(), "Gracefully stopping... (press Ctrl+C again to force)")
+					eg.Go(func() error {
+						err := s.Stop(context.Background(), project.Name, api.StopOptions{
+							Services: options.Create.Services,
+							Project:  project,
+						})
+						isTerminated = true
+						close(doneCh)
+						return err
+					})
+					first = false
+				} else {
+					eg.Go(func() error {
+						return s.Kill(context.Background(), project.Name, api.KillOptions{
+							Services: options.Create.Services,
+							Project:  project,
+						})
+					})
+					return nil
+				}
+			}
 		}
-		isTerminated = true
-		printer.Cancel()
-		fmt.Fprintln(s.stdinfo(), "Gracefully stopping... (press Ctrl+C again to force)")
-		return stopFunc()
 	})
 
+	printer := newLogPrinter(options.Start.Attach)
+
 	var exitCode int
 	eg.Go(func() error {
-		code, err := printer.Run(options.Start.CascadeStop, options.Start.ExitCodeFrom, stopFunc)
+		code, err := printer.Run(options.Start.CascadeStop, options.Start.ExitCodeFrom, func() error {
+			fmt.Fprintln(s.stdinfo(), "Aborting on container exit...")
+			return progress.Run(ctx, func(ctx context.Context) error {
+				return s.Stop(ctx, project.Name, api.StopOptions{
+					Services: options.Create.Services,
+					Project:  project,
+				})
+			}, s.stdinfo())
+		})
 		exitCode = code
 		return err
 	})
 
-	err = s.start(ctx, project.Name, options.Start, printer.HandleEvent)
+	// We don't use parent (cancelable) context as we manage sigterm to stop the stack
+	err = s.start(context.Background(), project.Name, options.Start, printer.HandleEvent)
 	if err != nil && !isTerminated { // Ignore error if the process is terminated
 		return err
 	}
 
-	// signal for the goroutines to stop & wait for them to finish any remaining work
-	signalCancel()
 	printer.Stop()
+
+	if !isTerminated {
+		// signal for the signal-handler goroutines to stop
+		close(doneCh)
+	}
 	err = eg.Wait().ErrorOrNil()
 	if exitCode != 0 {
 		errMsg := ""