Przeglądaj źródła

use a sync.Map to avoid concurrency issue accessing log presenters

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof 4 lat temu
rodzic
commit
cc47e5385b
1 zmienionych plików z 22 dodań i 15 usunięć
  1. 22 15
      cmd/formatter/logs.go

+ 22 - 15
cmd/formatter/logs.go

@@ -22,6 +22,7 @@ import (
 	"io"
 	"strconv"
 	"strings"
+	"sync"
 
 	"github.com/docker/compose/v2/pkg/api"
 )
@@ -30,7 +31,7 @@ import (
 func NewLogConsumer(ctx context.Context, w io.Writer, color bool, prefix bool) api.LogConsumer {
 	return &logConsumer{
 		ctx:        ctx,
-		presenters: map[string]*presenter{},
+		presenters: sync.Map{},
 		width:      0,
 		writer:     w,
 		color:      color,
@@ -51,53 +52,59 @@ func (l *logConsumer) register(name string) *presenter {
 		colors: cf,
 		name:   name,
 	}
-	l.presenters[name] = p
+	l.presenters.Store(name, p)
 	if l.prefix {
 		l.computeWidth()
-		for _, p := range l.presenters {
+		l.presenters.Range(func(key, value interface{}) bool {
+			p := value.(*presenter)
 			p.setPrefix(l.width)
-		}
+			return true
+		})
 	}
 	return p
 }
 
+func (l *logConsumer) getPresenter(container string) *presenter {
+	p, ok := l.presenters.Load(container)
+	if !ok { // should have been registered, but ¯\_(ツ)_/¯
+		return l.register(container)
+	}
+	return p.(*presenter)
+}
+
 // Log formats a log message as received from name/container
 func (l *logConsumer) Log(container, service, message string) {
 	if l.ctx.Err() != nil {
 		return
 	}
-	p, ok := l.presenters[container]
-	if !ok { // should have been registered, but ¯\_(ツ)_/¯
-		p = l.register(container)
-	}
+	p := l.getPresenter(container)
 	for _, line := range strings.Split(message, "\n") {
 		fmt.Fprintf(l.writer, "%s %s\n", p.prefix, line) // nolint:errcheck
 	}
 }
 
 func (l *logConsumer) Status(container, msg string) {
-	p, ok := l.presenters[container]
-	if !ok {
-		p = l.register(container)
-	}
+	p := l.getPresenter(container)
 	s := p.colors(fmt.Sprintf("%s %s\n", container, msg))
 	l.writer.Write([]byte(s)) // nolint:errcheck
 }
 
 func (l *logConsumer) computeWidth() {
 	width := 0
-	for _, p := range l.presenters {
+	l.presenters.Range(func(key, value interface{}) bool {
+		p := value.(*presenter)
 		if len(p.name) > width {
 			width = len(p.name)
 		}
-	}
+		return true
+	})
 	l.width = width + 1
 }
 
 // LogConsumer consume logs from services and format them
 type logConsumer struct {
 	ctx        context.Context
-	presenters map[string]*presenter
+	presenters sync.Map // map[string]*presenter
 	width      int
 	writer     io.Writer
 	color      bool