Explorar el Código

introduce WithRootNodesAndDown to walk the graph from specified nodes and down

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof hace 2 años
padre
commit
ca19b7fcc9
Se han modificado 4 ficheros con 174 adiciones y 11 borrados
  1. 49 4
      pkg/compose/dependencies.go
  2. 88 0
      pkg/compose/dependencies_test.go
  3. 25 5
      pkg/compose/down.go
  4. 12 2
      pkg/utils/set.go

+ 49 - 4
pkg/compose/dependencies.go

@@ -38,8 +38,9 @@ const (
 )
 
 type graphTraversal struct {
-	mu   sync.Mutex
-	seen map[string]struct{}
+	mu      sync.Mutex
+	seen    map[string]struct{}
+	ignored map[string]struct{}
 
 	extremityNodesFn            func(*Graph) []*Vertex                        // leaves or roots
 	adjacentNodesFn             func(*Vertex) []*Vertex                       // getParents or getChildren
@@ -87,15 +88,46 @@ func InDependencyOrder(ctx context.Context, project *types.Project, fn func(cont
 }
 
 // InReverseDependencyOrder applies the function to the services of the project in reverse order of dependencies
-func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error) error {
+func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error, options ...func(*graphTraversal)) error {
 	graph, err := NewGraph(project.Services, ServiceStarted)
 	if err != nil {
 		return err
 	}
 	t := downDirectionTraversal(fn)
+	for _, option := range options {
+		option(t)
+	}
 	return t.visit(ctx, graph)
 }
 
+func WithRootNodesAndDown(nodes []string) func(*graphTraversal) {
+	return func(t *graphTraversal) {
+		if len(nodes) == 0 {
+			return
+		}
+		originalFn := t.extremityNodesFn
+		t.extremityNodesFn = func(graph *Graph) []*Vertex {
+			var want []string
+			for _, node := range nodes {
+				vertex := graph.Vertices[node]
+				want = append(want, vertex.Service)
+				for _, v := range getAncestors(vertex) {
+					want = append(want, v.Service)
+				}
+			}
+
+			t.ignored = map[string]struct{}{}
+			for k := range graph.Vertices {
+				if !utils.Contains(want, k) {
+					t.ignored[k] = struct{}{}
+				}
+			}
+
+			return originalFn(graph)
+		}
+	}
+}
+
 func (t *graphTraversal) visit(ctx context.Context, g *Graph) error {
 	expect := len(g.Vertices)
 	if expect == 0 {
@@ -142,7 +174,10 @@ func (t *graphTraversal) run(ctx context.Context, graph *Graph, eg *errgroup.Gro
 		}
 
 		eg.Go(func() error {
-			err := t.visitorFn(ctx, node.Service)
+			var err error
+			if _, ignore := t.ignored[node.Service]; !ignore {
+				err = t.visitorFn(ctx, node.Service)
+			}
 			if err == nil {
 				graph.UpdateStatus(node.Key, t.targetServiceStatus)
 			}
@@ -197,6 +232,16 @@ func getChildren(v *Vertex) []*Vertex {
 	return v.GetChildren()
 }
 
+// getAncestors return all descendents for a vertex, might contain duplicates
+func getAncestors(v *Vertex) []*Vertex {
+	var descendents []*Vertex
+	for _, parent := range v.GetParents() {
+		descendents = append(descendents, parent)
+		descendents = append(descendents, getAncestors(parent)...)
+	}
+	return descendents
+}
+
 // GetChildren returns a slice with the child vertices of the a Vertex
 func (v *Vertex) GetChildren() []*Vertex {
 	var res []*Vertex

+ 88 - 0
pkg/compose/dependencies_test.go

@@ -19,9 +19,12 @@ package compose
 import (
 	"context"
 	"fmt"
+	"sort"
+	"sync"
 	"testing"
 
 	"github.com/compose-spec/compose-go/types"
+	"github.com/docker/compose/v2/pkg/utils"
 	testify "github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"gotest.tools/v3/assert"
@@ -297,3 +300,88 @@ func isVertexEqual(a, b Vertex) bool {
 		childrenEquality &&
 		parentEquality
 }
+
+func TestWith_RootNodesAndUp(t *testing.T) {
+	graph := &Graph{
+		lock:     sync.RWMutex{},
+		Vertices: map[string]*Vertex{},
+	}
+
+	/** graph topology:
+	           A   B
+		      / \ / \
+		     G   C   E
+		          \ /
+		           D
+		           |
+		           F
+	*/
+
+	graph.AddVertex("A", "A", 0)
+	graph.AddVertex("B", "B", 0)
+	graph.AddVertex("C", "C", 0)
+	graph.AddVertex("D", "D", 0)
+	graph.AddVertex("E", "E", 0)
+	graph.AddVertex("F", "F", 0)
+	graph.AddVertex("G", "G", 0)
+
+	_ = graph.AddEdge("C", "A")
+	_ = graph.AddEdge("C", "B")
+	_ = graph.AddEdge("E", "B")
+	_ = graph.AddEdge("D", "C")
+	_ = graph.AddEdge("D", "E")
+	_ = graph.AddEdge("F", "D")
+	_ = graph.AddEdge("G", "A")
+
+	tests := []struct {
+		name  string
+		nodes []string
+		want  []string
+	}{
+		{
+			name:  "whole graph",
+			nodes: []string{"A", "B"},
+			want:  []string{"A", "B", "C", "D", "E", "F", "G"},
+		},
+		{
+			name:  "only leaves",
+			nodes: []string{"F", "G"},
+			want:  []string{"F", "G"},
+		},
+		{
+			name:  "simple dependent",
+			nodes: []string{"D"},
+			want:  []string{"D", "F"},
+		},
+		{
+			name:  "diamond dependents",
+			nodes: []string{"B"},
+			want:  []string{"B", "C", "D", "E", "F"},
+		},
+		{
+			name:  "partial graph",
+			nodes: []string{"A"},
+			want:  []string{"A", "C", "D", "F", "G"},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			mx := sync.Mutex{}
+			expected := utils.Set[string]{}
+			expected.AddAll("C", "G", "D", "F")
+			var visited []string
+
+			gt := downDirectionTraversal(func(ctx context.Context, s string) error {
+				mx.Lock()
+				defer mx.Unlock()
+				visited = append(visited, s)
+				return nil
+			})
+			WithRootNodesAndDown(tt.nodes)(gt)
+			err := gt.visit(context.TODO(), graph)
+			assert.NilError(t, err)
+			sort.Strings(visited)
+			assert.DeepEqual(t, tt.want, visited)
+		})
+	}
+}

+ 25 - 5
pkg/compose/down.go

@@ -44,7 +44,7 @@ func (s *composeService) Down(ctx context.Context, projectName string, options a
 	}, s.stdinfo())
 }
 
-func (s *composeService) down(ctx context.Context, projectName string, options api.DownOptions) error { //golint:nocyclo
+func (s *composeService) down(ctx context.Context, projectName string, options api.DownOptions) error { //nolint:gocyclo
 	w := progress.ContextWriter(ctx)
 	resourceToRemove := false
 
@@ -65,18 +65,21 @@ func (s *composeService) down(ctx context.Context, projectName string, options a
 		}
 	}
 
+	// Check requested services exists in model
+	options.Services, err = checkSelectedServices(options, project)
+	if err != nil {
+		return err
+	}
+
 	if len(containers) > 0 {
 		resourceToRemove = true
 	}
 
 	err = InReverseDependencyOrder(ctx, project, func(c context.Context, service string) error {
-		if len(options.Services) > 0 && !utils.StringContains(options.Services, service) {
-			return nil
-		}
 		serviceContainers := containers.filter(isService(service))
 		err := s.removeContainers(ctx, w, serviceContainers, options.Timeout, options.Volumes)
 		return err
-	})
+	}, WithRootNodesAndDown(options.Services))
 	if err != nil {
 		return err
 	}
@@ -114,6 +117,23 @@ func (s *composeService) down(ctx context.Context, projectName string, options a
 	return eg.Wait()
 }
 
+func checkSelectedServices(options api.DownOptions, project *types.Project) ([]string, error) {
+	var services []string
+	for _, service := range options.Services {
+		_, err := project.GetService(service)
+		if err != nil {
+			if options.Project != nil {
+				// ran with an explicit compose.yaml file, so we should not ignore
+				return nil, err
+			}
+			// ran without an explicit compose.yaml file, so can't distinguish typo vs container already removed
+		} else {
+			services = append(services, service)
+		}
+	}
+	return services, nil
+}
+
 func (s *composeService) ensureVolumesDown(ctx context.Context, project *types.Project, w progress.Writer) []downOp {
 	var ops []downOp
 	for _, vol := range project.Volumes {

+ 12 - 2
pkg/utils/set.go

@@ -20,8 +20,18 @@ func (s Set[T]) Add(v T) {
 	s[v] = struct{}{}
 }
 
-func (s Set[T]) Remove(v T) {
-	delete(s, v)
+func (s Set[T]) AddAll(v ...T) {
+	for _, e := range v {
+		s[e] = struct{}{}
+	}
+}
+
+func (s Set[T]) Remove(v T) bool {
+	_, ok := s[v]
+	if ok {
+		delete(s, v)
+	}
+	return ok
 }
 
 func (s Set[T]) Clear() {