Browse Source

Add order to down command

Signed-off-by: Ulysses Souza <[email protected]>
Ulysses Souza 5 years ago
parent
commit
7e4cfc0e3b
6 changed files with 246 additions and 56 deletions
  1. 1 1
      cli/cmd/compose/up.go
  2. 97 28
      local/compose.go
  3. 107 8
      local/dependencies.go
  4. 32 19
      local/dependencies_test.go
  5. 4 0
      local/labels.go
  6. 5 0
      progress/event.go

+ 1 - 1
cli/cmd/compose/up.go

@@ -64,7 +64,7 @@ func runUp(ctx context.Context, opts composeOptions, services []string) error {
 			return "", err
 		}
 		if opts.DomainName != "" {
-			//arbitrarily set the domain name on the first service ; ACI backend will expose the entire project
+			// arbitrarily set the domain name on the first service ; ACI backend will expose the entire project
 			project.Services[0].DomainName = opts.DomainName
 		}
 

+ 97 - 28
local/compose.go

@@ -29,6 +29,7 @@ import (
 	"strconv"
 	"strings"
 
+	"github.com/compose-spec/compose-go/cli"
 	"github.com/compose-spec/compose-go/types"
 	"github.com/docker/buildx/build"
 	"github.com/docker/cli/cli/config"
@@ -203,7 +204,12 @@ func (s *composeService) Up(ctx context.Context, project *types.Project, detach
 		}
 	}
 
-	err = inDependencyOrder(ctx, project, func(c context.Context, service types.ServiceConfig) error {
+	err = s.ensureImagesExists(ctx, project)
+	if err != nil {
+		return err
+	}
+
+	err = InDependencyOrder(ctx, project, func(c context.Context, service types.ServiceConfig) error {
 		return s.ensureService(c, project, service)
 	})
 	return err
@@ -220,7 +226,27 @@ func getContainerName(c moby.Container) string {
 }
 
 func (s *composeService) Down(ctx context.Context, projectName string) error {
-	list, err := s.apiClient.ContainerList(ctx, moby.ContainerListOptions{
+	eg, _ := errgroup.WithContext(ctx)
+	w := progress.ContextWriter(ctx)
+
+	project, err := s.projectFromContainerLabels(ctx, projectName)
+	if err != nil || project == nil {
+		return err
+	}
+
+	err = InReverseDependencyOrder(ctx, project, func(c context.Context, service types.ServiceConfig) error {
+		filter := filters.NewArgs(projectFilter(project.Name), serviceFilter(service.Name))
+		return s.removeContainers(ctx, w, eg, filter)
+	})
+
+	if err != nil {
+		return err
+	}
+	err = eg.Wait()
+	if err != nil {
+		return err
+	}
+	networks, err := s.apiClient.NetworkList(ctx, moby.NetworkListOptions{
 		Filters: filters.NewArgs(
 			projectFilter(projectName),
 		),
@@ -228,48 +254,91 @@ func (s *composeService) Down(ctx context.Context, projectName string) error {
 	if err != nil {
 		return err
 	}
+	for _, n := range networks {
+		networkID := n.ID
+		networkName := n.Name
+		eg.Go(func() error {
+			return s.ensureNetworkDown(ctx, networkID, networkName)
+		})
+	}
 
-	eg, _ := errgroup.WithContext(ctx)
-	w := progress.ContextWriter(ctx)
-	for _, c := range list {
-		container := c
+	return eg.Wait()
+}
+
+func (s *composeService) removeContainers(ctx context.Context, w progress.Writer, eg *errgroup.Group, filter filters.Args) error {
+	cnts, err := s.apiClient.ContainerList(ctx, moby.ContainerListOptions{
+		Filters: filter,
+	})
+	if err != nil {
+		return err
+	}
+	for _, c := range cnts {
 		eg.Go(func() error {
-			w.Event(progress.NewEvent(getContainerName(container), progress.Working, "Stopping"))
-			err := s.apiClient.ContainerStop(ctx, container.ID, nil)
+			cName := getContainerName(c)
+			w.Event(progress.StoppingEvent(cName))
+			err := s.apiClient.ContainerStop(ctx, c.ID, nil)
 			if err != nil {
-				w.Event(progress.ErrorMessageEvent(getContainerName(container), "Error while Stopping"))
+				w.Event(progress.ErrorMessageEvent(cName, "Error while Stopping"))
 				return err
 			}
-			w.Event(progress.RemovingEvent(getContainerName(container)))
-			err = s.apiClient.ContainerRemove(ctx, container.ID, moby.ContainerRemoveOptions{})
+			w.Event(progress.RemovingEvent(cName))
+			err = s.apiClient.ContainerRemove(ctx, c.ID, moby.ContainerRemoveOptions{})
 			if err != nil {
-				w.Event(progress.ErrorMessageEvent(getContainerName(container), "Error while Removing"))
+				w.Event(progress.ErrorMessageEvent(cName, "Error while Removing"))
 				return err
 			}
-			w.Event(progress.RemovedEvent(getContainerName(container)))
+			w.Event(progress.RemovedEvent(cName))
 			return nil
 		})
 	}
-	err = eg.Wait()
-	if err != nil {
-		return err
-	}
-	networks, err := s.apiClient.NetworkList(ctx, moby.NetworkListOptions{
+	return nil
+}
+
+func (s *composeService) projectFromContainerLabels(ctx context.Context, projectName string) (*types.Project, error) {
+	cnts, err := s.apiClient.ContainerList(ctx, moby.ContainerListOptions{
 		Filters: filters.NewArgs(
 			projectFilter(projectName),
 		),
 	})
 	if err != nil {
-		return err
+		return nil, err
 	}
-	for _, network := range networks {
-		networkID := network.ID
-		networkName := network.Name
-		eg.Go(func() error {
-			return s.ensureNetworkDown(ctx, networkID, networkName)
-		})
+	if len(cnts) == 0 {
+		return nil, nil
 	}
-	return eg.Wait()
+	options, err := loadProjectOptionsFromLabels(cnts[0])
+	if err != nil {
+		return nil, err
+	}
+	if options.ConfigPaths[0] == "-" {
+		fakeProject := &types.Project{
+			Name: projectName,
+		}
+		for _, c := range cnts {
+			fakeProject.Services = append(fakeProject.Services, types.ServiceConfig{
+				Name: c.Labels[serviceLabel],
+			})
+		}
+		return fakeProject, nil
+	}
+	project, err := cli.ProjectFromOptions(options)
+	if err != nil {
+		return nil, err
+	}
+
+	return project, nil
+}
+
+func loadProjectOptionsFromLabels(c moby.Container) (*cli.ProjectOptions, error) {
+	var configFiles []string
+	relativePathConfigFiles := strings.Split(c.Labels[configFilesLabel], ",")
+	for _, c := range relativePathConfigFiles {
+		configFiles = append(configFiles, filepath.Base(c))
+	}
+	return cli.NewProjectOptions(configFiles,
+		cli.WithOsEnv,
+		cli.WithWorkingDirectory(c.Labels[workingDirLabel]),
+		cli.WithName(c.Labels[projectLabel]))
 }
 
 func (s *composeService) Logs(ctx context.Context, projectName string, w io.Writer) error {
@@ -443,7 +512,7 @@ func getContainerCreateOptions(p *types.Project, s types.ServiceConfig, number i
 	if err != nil {
 		return nil, nil, nil, err
 	}
-	//TODO: change oneoffLabel value for containers started with `docker compose run`
+	// TODO: change oneoffLabel value for containers started with `docker compose run`
 	labels := map[string]string{
 		projectLabel:         p.Name,
 		serviceLabel:         s.Name,
@@ -653,7 +722,7 @@ func getNetworkMode(p *types.Project, service types.ServiceConfig) container.Net
 		return container.NetworkMode("none")
 	}
 
-	/// FIXME incomplete implementation
+	// FIXME incomplete implementation
 	if strings.HasPrefix(mode, "service:") {
 		panic("Not yet implemented")
 	}

+ 107 - 8
local/dependencies.go

@@ -37,29 +37,64 @@ const (
 	ServiceStarted
 )
 
-func inDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, types.ServiceConfig) error) error {
+type graphTraversalConfig struct {
+	extremityNodesFn            func(*Graph) []*Vertex                        // leaves or roots
+	adjacentNodesFn             func(*Vertex) []*Vertex                       // getParents or getChildren
+	filterAdjacentByStatusFn    func(*Graph, string, ServiceStatus) []*Vertex // filterChildren or filterParents
+	targetServiceStatus         ServiceStatus
+	adjacentServiceStatusToSkip ServiceStatus
+}
+
+var (
+	upDirectionTraversalConfig = graphTraversalConfig{
+		extremityNodesFn:            leaves,
+		adjacentNodesFn:             getParents,
+		filterAdjacentByStatusFn:    filterChildren,
+		adjacentServiceStatusToSkip: ServiceStopped,
+		targetServiceStatus:         ServiceStarted,
+	}
+	downDirectionTraversalConfig = graphTraversalConfig{
+		extremityNodesFn:            roots,
+		adjacentNodesFn:             getChildren,
+		filterAdjacentByStatusFn:    filterParents,
+		adjacentServiceStatusToSkip: ServiceStarted,
+		targetServiceStatus:         ServiceStopped,
+	}
+)
+
+// InDependencyOrder applies the function to the services of the project taking in account the dependency order
+func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, types.ServiceConfig) error) error {
+	return visit(ctx, project, upDirectionTraversalConfig, fn)
+}
+
+// InReverseDependencyOrder applies the function to the services of the project in reverse order of dependencies
+func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, types.ServiceConfig) error) error {
+	return visit(ctx, project, downDirectionTraversalConfig, fn)
+}
+
+func visit(ctx context.Context, project *types.Project, traversalConfig graphTraversalConfig, fn func(context.Context, types.ServiceConfig) error) error {
 	g := NewGraph(project.Services)
 	if b, err := g.HasCycles(); b {
 		return err
 	}
 
-	leaves := g.Leaves()
+	nodes := traversalConfig.extremityNodesFn(g)
 
 	eg, _ := errgroup.WithContext(ctx)
 	eg.Go(func() error {
-		return run(ctx, g, eg, leaves, fn)
+		return run(ctx, g, eg, nodes, traversalConfig, fn)
 	})
 
 	return eg.Wait()
 }
 
 // Note: this could be `graph.walk` or whatever
-func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, fn func(context.Context, types.ServiceConfig) error) error {
+func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, traversalConfig graphTraversalConfig, fn func(context.Context, types.ServiceConfig) error) error {
 	for _, node := range nodes {
 		n := node
 		// Don't start this service yet if all of its children have
 		// not been started yet.
-		if len(graph.FilterChildren(n.Service.Name, ServiceStopped)) != 0 {
+		if len(traversalConfig.filterAdjacentByStatusFn(graph, n.Service.Name, traversalConfig.adjacentServiceStatusToSkip)) != 0 {
 			continue
 		}
 
@@ -69,9 +104,9 @@ func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex,
 				return err
 			}
 
-			graph.UpdateStatus(n.Service.Name, ServiceStarted)
+			graph.UpdateStatus(n.Service.Name, traversalConfig.targetServiceStatus)
 
-			return run(ctx, graph, eg, n.GetParents(), fn)
+			return run(ctx, graph, eg, traversalConfig.adjacentNodesFn(n), traversalConfig, fn)
 		})
 	}
 
@@ -93,7 +128,11 @@ type Vertex struct {
 	Parents  map[string]*Vertex
 }
 
-// GetParents returns a slice with the parent vertexes of the current Vertex
+func getParents(v *Vertex) []*Vertex {
+	return v.GetParents()
+}
+
+// GetParents returns a slice with the parent vertexes of the a Vertex
 func (v *Vertex) GetParents() []*Vertex {
 	var res []*Vertex
 	for _, p := range v.Parents {
@@ -102,6 +141,19 @@ func (v *Vertex) GetParents() []*Vertex {
 	return res
 }
 
+func getChildren(v *Vertex) []*Vertex {
+	return v.GetChildren()
+}
+
+// GetChildren returns a slice with the child vertexes of the a Vertex
+func (v *Vertex) GetChildren() []*Vertex {
+	var res []*Vertex
+	for _, p := range v.Children {
+		res = append(res, p)
+	}
+	return res
+}
+
 // NewGraph returns the dependency graph of the services
 func NewGraph(services types.Services) *Graph {
 	graph := &Graph{
@@ -168,6 +220,10 @@ func (g *Graph) AddEdge(source string, destination string) error {
 	return nil
 }
 
+func leaves(g *Graph) []*Vertex {
+	return g.Leaves()
+}
+
 // Leaves returns the slice of leaves of the graph
 func (g *Graph) Leaves() []*Vertex {
 	g.lock.Lock()
@@ -183,6 +239,24 @@ func (g *Graph) Leaves() []*Vertex {
 	return res
 }
 
+func roots(g *Graph) []*Vertex {
+	return g.Roots()
+}
+
+// Roots returns the slice of "Roots" of the graph
+func (g *Graph) Roots() []*Vertex {
+	g.lock.Lock()
+	defer g.lock.Unlock()
+
+	var res []*Vertex
+	for _, v := range g.Vertices {
+		if len(v.Parents) == 0 {
+			res = append(res, v)
+		}
+	}
+	return res
+}
+
 // UpdateStatus updates the status of a certain vertex
 func (g *Graph) UpdateStatus(key string, status ServiceStatus) {
 	g.lock.Lock()
@@ -190,6 +264,10 @@ func (g *Graph) UpdateStatus(key string, status ServiceStatus) {
 	g.Vertices[key].Status = status
 }
 
+func filterChildren(g *Graph, k string, s ServiceStatus) []*Vertex {
+	return g.FilterChildren(k, s)
+}
+
 // FilterChildren returns children of a certain vertex that are in a certain status
 func (g *Graph) FilterChildren(key string, status ServiceStatus) []*Vertex {
 	g.lock.Lock()
@@ -207,6 +285,27 @@ func (g *Graph) FilterChildren(key string, status ServiceStatus) []*Vertex {
 	return res
 }
 
+func filterParents(g *Graph, k string, s ServiceStatus) []*Vertex {
+	return g.FilterParents(k, s)
+}
+
+// FilterParents returns the parents of a certain vertex that are in a certain status
+func (g *Graph) FilterParents(key string, status ServiceStatus) []*Vertex {
+	g.lock.Lock()
+	defer g.lock.Unlock()
+
+	var res []*Vertex
+	vertex := g.Vertices[key]
+
+	for _, parent := range vertex.Parents {
+		if parent.Status == status {
+			res = append(res, parent)
+		}
+	}
+
+	return res
+}
+
 // HasCycles detects cycles in the graph
 func (g *Graph) HasCycles() (bool, error) {
 	discovered := []string{}

+ 32 - 19
local/dependencies_test.go

@@ -27,29 +27,30 @@ import (
 	"github.com/compose-spec/compose-go/types"
 )
 
-func TestInDependencyOrder(t *testing.T) {
-	order := make(chan string)
-	project := types.Project{
-		Services: []types.ServiceConfig{
-			{
-				Name: "test1",
-				DependsOn: map[string]types.ServiceDependency{
-					"test2": {},
-				},
-			},
-			{
-				Name: "test2",
-				DependsOn: map[string]types.ServiceDependency{
-					"test3": {},
-				},
+var project = types.Project{
+	Services: []types.ServiceConfig{
+		{
+			Name: "test1",
+			DependsOn: map[string]types.ServiceDependency{
+				"test2": {},
 			},
-			{
-				Name: "test3",
+		},
+		{
+			Name: "test2",
+			DependsOn: map[string]types.ServiceDependency{
+				"test3": {},
 			},
 		},
-	}
+		{
+			Name: "test3",
+		},
+	},
+}
+
+func TestInDependencyUpCommandOrder(t *testing.T) {
+	order := make(chan string)
 	//nolint:errcheck, unparam
-	go inDependencyOrder(context.TODO(), &project, func(ctx context.Context, config types.ServiceConfig) error {
+	go InDependencyOrder(context.TODO(), &project, func(ctx context.Context, config types.ServiceConfig) error {
 		order <- config.Name
 		return nil
 	})
@@ -57,3 +58,15 @@ func TestInDependencyOrder(t *testing.T) {
 	assert.Equal(t, <-order, "test2")
 	assert.Equal(t, <-order, "test1")
 }
+
+func TestInDependencyReverseDownCommandOrder(t *testing.T) {
+	order := make(chan string)
+	//nolint:errcheck, unparam
+	go InReverseDependencyOrder(context.TODO(), &project, func(ctx context.Context, config types.ServiceConfig) error {
+		order <- config.Name
+		return nil
+	})
+	assert.Equal(t, <-order, "test1")
+	assert.Equal(t, <-order, "test2")
+	assert.Equal(t, <-order, "test3")
+}

+ 4 - 0
local/labels.go

@@ -44,6 +44,10 @@ func projectFilter(projectName string) filters.KeyValuePair {
 	return filters.Arg("label", fmt.Sprintf("%s=%s", projectLabel, projectName))
 }
 
+func serviceFilter(serviceName string) filters.KeyValuePair {
+	return filters.Arg("label", fmt.Sprintf("%s=%s", serviceLabel, serviceName))
+}
+
 func hasProjectLabelFilter() filters.KeyValuePair {
 	return filters.Arg("label", projectLabel)
 }

+ 5 - 0
progress/event.go

@@ -62,6 +62,11 @@ func CreatedEvent(ID string) Event {
 	return NewEvent(ID, Done, "Created")
 }
 
+// StoppingEvent stops a new Removing in progress Event
+func StoppingEvent(ID string) Event {
+	return NewEvent(ID, Working, "Stopping")
+}
+
 // RemovingEvent creates a new Removing in progress Event
 func RemovingEvent(ID string) Event {
 	return NewEvent(ID, Working, "Removing")