Procházet zdrojové kódy

deps: fix race condition during graph traversal (#9878)

Keep track of visited nodes to prevent visiting a service multiple
times. This is possible when a service depends on multiple others,
as an attempt could be made to visit it from multiple parents.

Signed-off-by: Milas Bowman <[email protected]>
Milas Bowman před 3 roky
rodič
revize
616777eb4a
2 změnil soubory, kde provedl 93 přidání a 20 odebrání
  1. 47 20
      pkg/compose/dependencies.go
  2. 46 0
      pkg/compose/dependencies_test.go

+ 47 - 20
pkg/compose/dependencies.go

@@ -37,38 +37,49 @@ const (
 	ServiceStarted
 )
 
-type graphTraversalConfig struct {
+type graphTraversal struct {
+	mu   sync.Mutex
+	seen map[string]struct{}
+
 	extremityNodesFn            func(*Graph) []*Vertex                        // leaves or roots
 	adjacentNodesFn             func(*Vertex) []*Vertex                       // getParents or getChildren
 	filterAdjacentByStatusFn    func(*Graph, string, ServiceStatus) []*Vertex // filterChildren or filterParents
 	targetServiceStatus         ServiceStatus
 	adjacentServiceStatusToSkip ServiceStatus
+
+	visitorFn func(context.Context, string) error
 }
 
-var (
-	upDirectionTraversalConfig = graphTraversalConfig{
+func upDirectionTraversal(visitorFn func(context.Context, string) error) *graphTraversal {
+	return &graphTraversal{
 		extremityNodesFn:            leaves,
 		adjacentNodesFn:             getParents,
 		filterAdjacentByStatusFn:    filterChildren,
 		adjacentServiceStatusToSkip: ServiceStopped,
 		targetServiceStatus:         ServiceStarted,
+		visitorFn:                   visitorFn,
 	}
-	downDirectionTraversalConfig = graphTraversalConfig{
+}
+
+func downDirectionTraversal(visitorFn func(context.Context, string) error) *graphTraversal {
+	return &graphTraversal{
 		extremityNodesFn:            roots,
 		adjacentNodesFn:             getChildren,
 		filterAdjacentByStatusFn:    filterParents,
 		adjacentServiceStatusToSkip: ServiceStarted,
 		targetServiceStatus:         ServiceStopped,
+		visitorFn:                   visitorFn,
 	}
-)
+}
 
 // InDependencyOrder applies the function to the services of the project taking in account the dependency order
-func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error, options ...func(*graphTraversalConfig)) error {
+func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error, options ...func(*graphTraversal)) error {
 	graph, err := NewGraph(project.Services, ServiceStopped)
 	if err != nil {
 		return err
 	}
-	return visit(ctx, graph, upDirectionTraversalConfig, fn)
+	t := upDirectionTraversal(fn)
+	return t.visit(ctx, graph)
 }
 
 // InReverseDependencyOrder applies the function to the services of the project in reverse order of dependencies
@@ -77,43 +88,59 @@ func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn fu
 	if err != nil {
 		return err
 	}
-	return visit(ctx, graph, downDirectionTraversalConfig, fn)
+	t := downDirectionTraversal(fn)
+	return t.visit(ctx, graph)
 }
 
-func visit(ctx context.Context, g *Graph, traversalConfig graphTraversalConfig, fn func(context.Context, string) error) error {
-	nodes := traversalConfig.extremityNodesFn(g)
+func (t *graphTraversal) visit(ctx context.Context, g *Graph) error {
+	nodes := t.extremityNodesFn(g)
 
-	eg, _ := errgroup.WithContext(ctx)
-	eg.Go(func() error {
-		return run(ctx, g, eg, nodes, traversalConfig, fn)
-	})
+	eg, ctx := errgroup.WithContext(ctx)
+	t.run(ctx, g, eg, nodes)
 
 	return eg.Wait()
 }
 
 // Note: this could be `graph.walk` or whatever
-func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, traversalConfig graphTraversalConfig, fn func(context.Context, string) error) error {
+func (t *graphTraversal) run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex) {
 	for _, node := range nodes {
 		// Don't start this service yet if all of its children have
 		// not been started yet.
-		if len(traversalConfig.filterAdjacentByStatusFn(graph, node.Key, traversalConfig.adjacentServiceStatusToSkip)) != 0 {
+		if len(t.filterAdjacentByStatusFn(graph, node.Key, t.adjacentServiceStatusToSkip)) != 0 {
 			continue
 		}
 
 		node := node
+		if !t.consume(node.Key) {
+			// another worker already visited this node
+			continue
+		}
+
 		eg.Go(func() error {
-			err := fn(ctx, node.Service)
+			err := t.visitorFn(ctx, node.Service)
 			if err != nil {
 				return err
 			}
 
-			graph.UpdateStatus(node.Key, traversalConfig.targetServiceStatus)
+			graph.UpdateStatus(node.Key, t.targetServiceStatus)
 
-			return run(ctx, graph, eg, traversalConfig.adjacentNodesFn(node), traversalConfig, fn)
+			t.run(ctx, graph, eg, t.adjacentNodesFn(node))
+			return nil
 		})
 	}
+}
 
-	return nil
+func (t *graphTraversal) consume(nodeKey string) bool {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.seen == nil {
+		t.seen = make(map[string]struct{})
+	}
+	if _, ok := t.seen[nodeKey]; ok {
+		return false
+	}
+	t.seen[nodeKey] = struct{}{}
+	return true
 }
 
 // Graph represents project as service dependencies

+ 46 - 0
pkg/compose/dependencies_test.go

@@ -22,6 +22,7 @@ import (
 	"testing"
 
 	"github.com/compose-spec/compose-go/types"
+	testify "github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"gotest.tools/assert"
 )
@@ -46,6 +47,51 @@ var project = types.Project{
 	},
 }
 
+func TestTraversalWithMultipleParents(t *testing.T) {
+	dependent := types.ServiceConfig{
+		Name:      "dependent",
+		DependsOn: make(types.DependsOnConfig),
+	}
+
+	project := types.Project{
+		Services: []types.ServiceConfig{dependent},
+	}
+
+	for i := 1; i <= 100; i++ {
+		name := fmt.Sprintf("svc_%d", i)
+		dependent.DependsOn[name] = types.ServiceDependency{}
+
+		svc := types.ServiceConfig{Name: name}
+		project.Services = append(project.Services, svc)
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	t.Cleanup(cancel)
+
+	svc := make(chan string, 10)
+	seen := make(map[string]int)
+	done := make(chan struct{})
+	go func() {
+		for service := range svc {
+			seen[service]++
+		}
+		done <- struct{}{}
+	}()
+
+	err := InDependencyOrder(ctx, &project, func(ctx context.Context, service string) error {
+		svc <- service
+		return nil
+	})
+	require.NoError(t, err, "Error during iteration")
+	close(svc)
+	<-done
+
+	testify.Len(t, seen, 101)
+	for svc, count := range seen {
+		assert.Equal(t, 1, count, "Service: %s", svc)
+	}
+}
+
 func TestInDependencyUpCommandOrder(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	t.Cleanup(cancel)