Parcourir la source

fix deadlock collecting large logs

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof il y a 1 an
Parent
commit
07bda5960e
3 fichiers modifiés avec 100 ajouts et 60 suppressions
  1. 70 60
      pkg/compose/printer.go
  2. 6 0
      pkg/e2e/fixtures/logs-test/cat.yaml
  3. 24 0
      pkg/e2e/logs_test.go

+ 70 - 60
pkg/compose/printer.go

@@ -32,18 +32,19 @@ type logPrinter interface {
 }
 
 type printer struct {
-	sync.Mutex
 	queue    chan api.ContainerEvent
 	consumer api.LogConsumer
-	stopped  bool
+	stopCh   chan struct{} // stopCh is a signal channel for producers to stop sending events to the queue
+	stop     sync.Once
 }
 
 // newLogPrinter builds a LogPrinter passing containers logs to LogConsumer
 func newLogPrinter(consumer api.LogConsumer) logPrinter {
-	queue := make(chan api.ContainerEvent)
 	printer := printer{
 		consumer: consumer,
-		queue:    queue,
+		queue:    make(chan api.ContainerEvent),
+		stopCh:   make(chan struct{}),
+		stop:     sync.Once{},
 	}
 	return &printer
 }
@@ -54,24 +55,27 @@ func (p *printer) Cancel() {
 }
 
 func (p *printer) Stop() {
-	p.Lock()
-	defer p.Unlock()
-	if !p.stopped {
-		// only close if this is the first call to stop
-		p.stopped = true
-		close(p.queue)
-	}
+	p.stop.Do(func() {
+		close(p.stopCh)
+		for {
+			select {
+			case <-p.queue:
+				// purge the queue to free producers goroutines
+				// p.queue will be garbage collected
+			default:
+				return
+			}
+		}
+	})
 }
 
 func (p *printer) HandleEvent(event api.ContainerEvent) {
-	p.Lock()
-	defer p.Unlock()
-	if p.stopped {
-		// prevent deadlocking, if the printer is done, there's no reader for
-		// queue, so this write could block indefinitely
+	select {
+	case <-p.stopCh:
 		return
+	default:
+		p.queue <- event
 	}
-	p.queue <- event
 }
 
 //nolint:gocyclo
@@ -80,58 +84,64 @@ func (p *printer) Run(cascadeStop bool, exitCodeFrom string, stopFn func() error
 		aborting bool
 		exitCode int
 	)
+	defer p.Stop()
+
 	containers := map[string]struct{}{}
-	for event := range p.queue {
-		container, id := event.Container, event.ID
-		switch event.Type {
-		case api.UserCancel:
-			aborting = true
-		case api.ContainerEventAttach:
-			if _, ok := containers[id]; ok {
-				continue
-			}
-			containers[id] = struct{}{}
-			p.consumer.Register(container)
-		case api.ContainerEventExit, api.ContainerEventStopped, api.ContainerEventRecreated:
-			if !event.Restarting {
-				delete(containers, id)
-			}
-			if !aborting {
-				p.consumer.Status(container, fmt.Sprintf("exited with code %d", event.ExitCode))
-				if event.Type == api.ContainerEventRecreated {
-					p.consumer.Status(container, "has been recreated")
+	for {
+		select {
+		case <-p.stopCh:
+			return exitCode, nil
+		case event := <-p.queue:
+			container, id := event.Container, event.ID
+			switch event.Type {
+			case api.UserCancel:
+				aborting = true
+			case api.ContainerEventAttach:
+				if _, ok := containers[id]; ok {
+					continue
+				}
+				containers[id] = struct{}{}
+				p.consumer.Register(container)
+			case api.ContainerEventExit, api.ContainerEventStopped, api.ContainerEventRecreated:
+				if !event.Restarting {
+					delete(containers, id)
 				}
-			}
-			if cascadeStop {
 				if !aborting {
-					aborting = true
-					err := stopFn()
-					if err != nil {
-						return 0, err
+					p.consumer.Status(container, fmt.Sprintf("exited with code %d", event.ExitCode))
+					if event.Type == api.ContainerEventRecreated {
+						p.consumer.Status(container, "has been recreated")
 					}
 				}
-				if event.Type == api.ContainerEventExit {
-					if exitCodeFrom == "" {
-						exitCodeFrom = event.Service
+				if cascadeStop {
+					if !aborting {
+						aborting = true
+						err := stopFn()
+						if err != nil {
+							return 0, err
+						}
 					}
-					if exitCodeFrom == event.Service {
-						exitCode = event.ExitCode
+					if event.Type == api.ContainerEventExit {
+						if exitCodeFrom == "" {
+							exitCodeFrom = event.Service
+						}
+						if exitCodeFrom == event.Service {
+							exitCode = event.ExitCode
+						}
 					}
 				}
-			}
-			if len(containers) == 0 {
-				// Last container terminated, done
-				return exitCode, nil
-			}
-		case api.ContainerEventLog:
-			if !aborting {
-				p.consumer.Log(container, event.Line)
-			}
-		case api.ContainerEventErr:
-			if !aborting {
-				p.consumer.Err(container, event.Line)
+				if len(containers) == 0 {
+					// Last container terminated, done
+					return exitCode, nil
+				}
+			case api.ContainerEventLog:
+				if !aborting {
+					p.consumer.Log(container, event.Line)
+				}
+			case api.ContainerEventErr:
+				if !aborting {
+					p.consumer.Err(container, event.Line)
+				}
 			}
 		}
 	}
-	return exitCode, nil
 }

+ 6 - 0
pkg/e2e/fixtures/logs-test/cat.yaml

@@ -0,0 +1,6 @@
+services:
+  test:
+    image: alpine
+    command: cat /text_file.txt
+    volumes:
+      - ${FILE}:/text_file.txt

+ 24 - 0
pkg/e2e/logs_test.go

@@ -17,6 +17,10 @@
 package e2e
 
 import (
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
 	"strings"
 	"testing"
 	"time"
@@ -96,6 +100,26 @@ func TestLocalComposeLogsFollow(t *testing.T) {
 	poll.WaitOn(t, expectOutput(res, "ping-2 "), poll.WithDelay(100*time.Millisecond), poll.WithTimeout(20*time.Second))
 }
 
+func TestLocalComposeLargeLogs(t *testing.T) {
+	const projectName = "compose-e2e-large_logs"
+	file := filepath.Join(t.TempDir(), "large.txt")
+	c := NewCLI(t, WithEnv("FILE="+file))
+	t.Cleanup(func() {
+		c.RunDockerComposeCmd(t, "--project-name", projectName, "down")
+	})
+
+	f, err := os.Create(file)
+	assert.NilError(t, err)
+	for i := 0; i < 300_000; i++ {
+		_, err := io.WriteString(f, fmt.Sprintf("This is line %d in a laaaarge text file\n", i))
+		assert.NilError(t, err)
+	}
+	assert.NilError(t, f.Close())
+
+	res := c.RunDockerComposeCmd(t, "-f", "./fixtures/logs-test/cat.yaml", "--project-name", projectName, "up", "--abort-on-container-exit")
+	res.Assert(t, icmd.Expected{Out: "test-1 exited with code 0"})
+}
+
 func expectOutput(res *icmd.Result, expected string) func(t poll.LogT) poll.Result {
 	return func(t poll.LogT) poll.Result {
 		if strings.Contains(res.Stdout(), expected) {