Browse Source

up: fix various race/deadlock conditions on exit (#10934)

If running `up` in foreground mode (i.e. not `-d`),
when exiting via `Ctrl-C`, Compose stops all the
services it launched directly as part of that `up`
command.

In one of the E2E tests (`TestUpDependenciesNotStopped`),
this was occasionally flaking because the stop
behavior was racy: the return might not block on
the stop operation because it gets added to the
error group in a goroutine. As a result, it was
possible for no services to get terminated on exit.

There were a few other related pieces here that
I uncovered and tried to fix while stressing this.
For example, the printer could cause a deadlock if
an event was sent to it after it stopped.

Also, an error group wasn't really appropriate here;
each goroutine is a different operation for printing,
signal-handling, etc. If one part fails, we don't
actually want printing to stop, for example. This has
been switched to a `multierror.Group`, which has the
same API but coalesces errors instead of canceling a
context the moment the first one fails and returning
that single error.

Signed-off-by: Milas Bowman <[email protected]>
Milas Bowman 2 years ago
parent
commit
407a0d5b53
4 changed files with 105 additions and 82 deletions
  1. 56 54
      pkg/compose/printer.go
  2. 41 18
      pkg/compose/up.go
  3. 3 2
      pkg/e2e/assert.go
  4. 5 8
      pkg/e2e/up_test.go

+ 56 - 54
pkg/compose/printer.go

@@ -18,6 +18,7 @@ package compose
 
 import (
 	"fmt"
+	"sync/atomic"
 
 	"github.com/docker/compose/v2/pkg/api"
 )
@@ -33,32 +34,37 @@ type logPrinter interface {
 type printer struct {
 	queue    chan api.ContainerEvent
 	consumer api.LogConsumer
-	stopCh   chan struct{}
+	stopped  atomic.Bool
 }
 
 // newLogPrinter builds a LogPrinter passing containers logs to LogConsumer
 func newLogPrinter(consumer api.LogConsumer) logPrinter {
 	queue := make(chan api.ContainerEvent)
-	stopCh := make(chan struct{}, 1) // printer MAY stop on his own, so Stop MUST not be blocking
 	printer := printer{
 		consumer: consumer,
 		queue:    queue,
-		stopCh:   stopCh,
 	}
 	return &printer
 }
 
 func (p *printer) Cancel() {
-	p.queue <- api.ContainerEvent{
-		Type: api.UserCancel,
-	}
+	// note: HandleEvent is used to ensure this doesn't deadlock
+	p.HandleEvent(api.ContainerEvent{Type: api.UserCancel})
 }
 
 func (p *printer) Stop() {
-	p.stopCh <- struct{}{}
+	if p.stopped.CompareAndSwap(false, true) {
+		// only close if this is the first call to stop
+		close(p.queue)
+	}
 }
 
 func (p *printer) HandleEvent(event api.ContainerEvent) {
+	// prevent deadlocking, if the printer is done, there's no reader for
+	// queue, so this write could block indefinitely
+	if p.stopped.Load() {
+		return
+	}
 	p.queue <- event
 }
 
@@ -69,61 +75,57 @@ func (p *printer) Run(cascadeStop bool, exitCodeFrom string, stopFn func() error
 		exitCode int
 	)
 	containers := map[string]struct{}{}
-	for {
-		select {
-		case <-p.stopCh:
-			return exitCode, nil
-		case event := <-p.queue:
-			container, id := event.Container, event.ID
-			switch event.Type {
-			case api.UserCancel:
-				aborting = true
-			case api.ContainerEventAttach:
-				if _, ok := containers[id]; ok {
-					continue
-				}
-				containers[id] = struct{}{}
-				p.consumer.Register(container)
-			case api.ContainerEventExit, api.ContainerEventStopped, api.ContainerEventRecreated:
-				if !event.Restarting {
-					delete(containers, id)
+	for event := range p.queue {
+		container, id := event.Container, event.ID
+		switch event.Type {
+		case api.UserCancel:
+			aborting = true
+		case api.ContainerEventAttach:
+			if _, ok := containers[id]; ok {
+				continue
+			}
+			containers[id] = struct{}{}
+			p.consumer.Register(container)
+		case api.ContainerEventExit, api.ContainerEventStopped, api.ContainerEventRecreated:
+			if !event.Restarting {
+				delete(containers, id)
+			}
+			if !aborting {
+				p.consumer.Status(container, fmt.Sprintf("exited with code %d", event.ExitCode))
+				if event.Type == api.ContainerEventRecreated {
+					p.consumer.Status(container, "has been recreated")
 				}
+			}
+			if cascadeStop {
 				if !aborting {
-					p.consumer.Status(container, fmt.Sprintf("exited with code %d", event.ExitCode))
-					if event.Type == api.ContainerEventRecreated {
-						p.consumer.Status(container, "has been recreated")
+					aborting = true
+					err := stopFn()
+					if err != nil {
+						return 0, err
 					}
 				}
-				if cascadeStop {
-					if !aborting {
-						aborting = true
-						err := stopFn()
-						if err != nil {
-							return 0, err
-						}
+				if event.Type == api.ContainerEventExit {
+					if exitCodeFrom == "" {
+						exitCodeFrom = event.Service
 					}
-					if event.Type == api.ContainerEventExit {
-						if exitCodeFrom == "" {
-							exitCodeFrom = event.Service
-						}
-						if exitCodeFrom == event.Service {
-							exitCode = event.ExitCode
-						}
+					if exitCodeFrom == event.Service {
+						exitCode = event.ExitCode
 					}
 				}
-				if len(containers) == 0 {
-					// Last container terminated, done
-					return exitCode, nil
-				}
-			case api.ContainerEventLog:
-				if !aborting {
-					p.consumer.Log(container, event.Line)
-				}
-			case api.ContainerEventErr:
-				if !aborting {
-					p.consumer.Err(container, event.Line)
-				}
+			}
+			if len(containers) == 0 {
+				// Last container terminated, done
+				return exitCode, nil
+			}
+		case api.ContainerEventLog:
+			if !aborting {
+				p.consumer.Log(container, event.Line)
+			}
+		case api.ContainerEventErr:
+			if !aborting {
+				p.consumer.Err(container, event.Line)
 			}
 		}
 	}
+	return exitCode, nil
 }

+ 41 - 18
pkg/compose/up.go

@@ -21,15 +21,15 @@ import (
 	"fmt"
 	"os"
 	"os/signal"
+	"sync"
 	"syscall"
 
-	"github.com/docker/compose/v2/internal/tracing"
-
 	"github.com/compose-spec/compose-go/types"
 	"github.com/docker/cli/cli"
+	"github.com/docker/compose/v2/internal/tracing"
 	"github.com/docker/compose/v2/pkg/api"
 	"github.com/docker/compose/v2/pkg/progress"
-	"golang.org/x/sync/errgroup"
+	"github.com/hashicorp/go-multierror"
 )
 
 func (s *composeService) Up(ctx context.Context, project *types.Project, options api.UpOptions) error {
@@ -55,39 +55,60 @@ func (s *composeService) Up(ctx context.Context, project *types.Project, options
 		return err
 	}
 
-	printer := newLogPrinter(options.Start.Attach)
-
-	signalChan := make(chan os.Signal, 1)
+	// if we get a second signal during shutdown, we kill the services
+	// immediately, so the channel needs to have sufficient capacity or
+	// we might miss a signal while setting up the second channel read
+	// (this is also why signal.Notify is used vs signal.NotifyContext)
+	signalChan := make(chan os.Signal, 2)
 	signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM)
+	signalCancel := sync.OnceFunc(func() {
+		signal.Stop(signalChan)
+		close(signalChan)
+	})
+	defer signalCancel()
 
+	printer := newLogPrinter(options.Start.Attach)
 	stopFunc := func() error {
 		fmt.Fprintln(s.stdinfo(), "Aborting on container exit...")
 		ctx := context.Background()
 		return progress.Run(ctx, func(ctx context.Context) error {
+			// race two goroutines - one that blocks until another signal is received
+			// and then does a Kill() and one that immediately starts a friendly Stop()
+			errCh := make(chan error, 1)
 			go func() {
-				<-signalChan
-				s.Kill(ctx, project.Name, api.KillOptions{ //nolint:errcheck
+				if _, ok := <-signalChan; !ok {
+					// channel closed, so the outer function is done, which
+					// means the other goroutine (calling Stop()) finished
+					return
+				}
+				errCh <- s.Kill(ctx, project.Name, api.KillOptions{
 					Services: options.Create.Services,
 					Project:  project,
 				})
 			}()
 
-			return s.Stop(ctx, project.Name, api.StopOptions{
-				Services: options.Create.Services,
-				Project:  project,
-			})
+			go func() {
+				errCh <- s.Stop(ctx, project.Name, api.StopOptions{
+					Services: options.Create.Services,
+					Project:  project,
+				})
+			}()
+			return <-errCh
 		}, s.stdinfo())
 	}
 
 	var isTerminated bool
-	eg, ctx := errgroup.WithContext(ctx)
-	go func() {
-		<-signalChan
+	var eg multierror.Group
+	eg.Go(func() error {
+		if _, ok := <-signalChan; !ok {
+			// function finished without receiving a signal
+			return nil
+		}
 		isTerminated = true
 		printer.Cancel()
 		fmt.Fprintln(s.stdinfo(), "Gracefully stopping... (press Ctrl+C again to force)")
-		eg.Go(stopFunc)
-	}()
+		return stopFunc()
+	})
 
 	var exitCode int
 	eg.Go(func() error {
@@ -101,8 +122,10 @@ func (s *composeService) Up(ctx context.Context, project *types.Project, options
 		return err
 	}
 
+	// signal for the goroutines to stop & wait for them to finish any remaining work
+	signalCancel()
 	printer.Stop()
-	err = eg.Wait()
+	err = eg.Wait().ErrorOrNil()
 	if exitCode != 0 {
 		errMsg := ""
 		if err != nil {

+ 3 - 2
pkg/e2e/assert.go

@@ -28,10 +28,11 @@ import (
 // (running or exited).
 func RequireServiceState(t testing.TB, cli *CLI, service string, state string) {
 	t.Helper()
-	psRes := cli.RunDockerComposeCmd(t, "ps", "--format=json", service)
+	psRes := cli.RunDockerComposeCmd(t, "ps", "--all", "--format=json", service)
 	var svc map[string]interface{}
 	require.NoError(t, json.Unmarshal([]byte(psRes.Stdout()), &svc),
-		"Invalid `compose ps` JSON output")
+		"Invalid `compose ps` JSON: command output: %s",
+		psRes.Combined())
 
 	require.Equal(t, service, svc["Service"],
 		"Found ps output for unexpected service")

+ 5 - 8
pkg/e2e/up_test.go

@@ -21,7 +21,6 @@ package e2e
 
 import (
 	"context"
-	"os"
 	"os/exec"
 	"strings"
 	"syscall"
@@ -45,9 +44,6 @@ func TestUpServiceUnhealthy(t *testing.T) {
 }
 
 func TestUpDependenciesNotStopped(t *testing.T) {
-	if _, ok := os.LookupEnv("CI"); ok {
-		t.Skip("Skipping test on CI... flaky")
-	}
 	c := NewParallelCLI(t, WithEnv(
 		"COMPOSE_PROJECT_NAME=up-deps-stop",
 	))
@@ -76,8 +72,8 @@ func TestUpDependenciesNotStopped(t *testing.T) {
 		"app",
 	)
 
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
+	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+	t.Cleanup(cancel)
 
 	cmd, err := StartWithNewGroupID(ctx, testCmd, upOut, nil)
 	assert.NilError(t, err, "Failed to run compose up")
@@ -91,12 +87,13 @@ func TestUpDependenciesNotStopped(t *testing.T) {
 	require.NoError(t, syscall.Kill(-cmd.Process.Pid, syscall.SIGINT),
 		"Failed to send SIGINT to compose up process")
 
-	time.AfterFunc(5*time.Second, cancel)
-
 	t.Log("Waiting for `compose up` to exit")
 	err = cmd.Wait()
 	if err != nil {
 		exitErr := err.(*exec.ExitError)
+		if exitErr.ExitCode() == -1 {
+			t.Fatalf("`compose up` was killed: %v", err)
+		}
 		require.EqualValues(t, exitErr.ExitCode(), 130)
 	}