浏览代码

Merge pull request #1841 from ndeloof/convergence

Nicolas De loof 4 年之前
父节点
当前提交
d20c3b0e22

+ 2 - 2
local/e2e/compose/ipc_test.go

@@ -59,7 +59,7 @@ func TestIPC(t *testing.T) {
 	t.Run("down", func(t *testing.T) {
 		_ = c.RunDockerCmd("compose", "--project-name", projectName, "down")
 	})
-	t.Run("stop ipc mode container", func(t *testing.T) {
-		_ = c.RunDockerCmd("stop", "ipc_mode_container")
+	t.Run("remove ipc mode container", func(t *testing.T) {
+		_ = c.RunDockerCmd("rm", "-f", "ipc_mode_container")
 	})
 }

+ 1 - 1
local/e2e/compose/networks_test.go

@@ -86,7 +86,7 @@ func TestNetworkAliassesAndLinks(t *testing.T) {
 	})
 
 	t.Run("curl links", func(t *testing.T) {
-		res := c.RunDockerCmd("compose", "-f", "./fixtures/network-alias/compose.yaml", "--project-name", projectName, "exec", "-T", "container1", "curl", "container")
+		res := c.RunDockerCmd("compose", "-f", "./fixtures/network-alias/compose.yaml", "--project-name", projectName, "exec", "-T", "container1", "curl", "http://container/")
 		assert.Assert(t, strings.Contains(res.Stdout(), "Welcome to nginx!"), res.Stdout())
 	})
 

+ 183 - 150
pkg/compose/convergence.go

@@ -21,6 +21,7 @@ import (
 	"fmt"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/compose-spec/compose-go/types"
@@ -46,76 +47,147 @@ const (
 		"Remove the custom name to scale the service.\n"
 )
 
-func (s *composeService) ensureScale(ctx context.Context, project *types.Project, service types.ServiceConfig, timeout *time.Duration) (*errgroup.Group, []moby.Container, error) {
-	cState, err := GetContextContainerState(ctx)
-	if err != nil {
-		return nil, nil, err
+// convergence manages service's container lifecycle.
+// Based on initially observed state, it reconciles the existing container with desired state, which might include
+// re-creating container, adding or removing replicas, or starting stopped containers.
+// Cross services dependencies are managed by creating services in expected order and updating `service:xx` reference
+// when a service has converged, so dependent ones can be managed with resolved containers references.
+type convergence struct {
+	service       *composeService
+	observedState map[string]Containers
+}
+
+func newConvergence(services []string, state Containers, s *composeService) *convergence {
+	observedState := map[string]Containers{}
+	for _, s := range services {
+		observedState[s] = Containers{}
 	}
-	observedState := cState.GetContainers()
-	actual := observedState.filter(isService(service.Name)).filter(isNotOneOff)
-	scale, err := getScale(service)
-	if err != nil {
-		return nil, nil, err
+	for _, c := range state.filter(isNotOneOff) {
+		service := c.Labels[api.ServiceLabel]
+		observedState[service] = append(observedState[service], c)
 	}
-	eg, _ := errgroup.WithContext(ctx)
-	if len(actual) < scale {
-		next, err := nextContainerNumber(actual)
+	return &convergence{
+		service:       s,
+		observedState: observedState,
+	}
+}
+
+func (c *convergence) apply(ctx context.Context, project *types.Project, options api.CreateOptions) error {
+	return InDependencyOrder(ctx, project, func(ctx context.Context, name string) error {
+		service, err := project.GetService(name)
 		if err != nil {
-			return nil, actual, err
+			return err
 		}
-		missing := scale - len(actual)
-		for i := 0; i < missing; i++ {
-			number := next + i
-			name := getContainerName(project.Name, service, number)
-			eg.Go(func() error {
-				return s.createContainer(ctx, project, service, name, number, false, true)
-			})
+
+		strategy := options.RecreateDependencies
+		if utils.StringContains(options.Services, name) {
+			strategy = options.Recreate
+		}
+		err = c.ensureService(ctx, project, service, strategy, options.Inherit, options.Timeout)
+		if err != nil {
+			return err
 		}
-	}
 
-	if len(actual) > scale {
-		for i := scale; i < len(actual); i++ {
-			container := actual[i]
-			eg.Go(func() error {
-				err := s.apiClient.ContainerStop(ctx, container.ID, timeout)
-				if err != nil {
-					return err
+		c.updateProject(project, name)
+		return nil
+	})
+}
+
+var mu sync.Mutex
+
+// updateProject updates project after service converged, so dependent services relying on `service:xx` can refer to actual containers.
+func (c *convergence) updateProject(project *types.Project, service string) {
+	containers := c.observedState[service]
+	container := containers[0]
+
+	// operation is protected by a Mutex so that we can safely update project.Services while running concurrent convergence on services
+	mu.Lock()
+	defer mu.Unlock()
+
+	for i, s := range project.Services {
+		if d := getDependentServiceFromMode(s.NetworkMode); d == service {
+			s.NetworkMode = types.NetworkModeContainerPrefix + container.ID
+		}
+		if d := getDependentServiceFromMode(s.Ipc); d == service {
+			s.Ipc = types.NetworkModeContainerPrefix + container.ID
+		}
+		if d := getDependentServiceFromMode(s.Pid); d == service {
+			s.Pid = types.NetworkModeContainerPrefix + container.ID
+		}
+		var links []string
+		for _, serviceLink := range s.Links {
+			parts := strings.Split(serviceLink, ":")
+			serviceName := serviceLink
+			serviceAlias := ""
+			if len(parts) == 2 {
+				serviceName = parts[0]
+				serviceAlias = parts[1]
+			}
+			if serviceName != service {
+				links = append(links, serviceLink)
+				continue
+			}
+			for _, container := range containers {
+				name := getCanonicalContainerName(container)
+				if serviceAlias != "" {
+					links = append(links,
+						fmt.Sprintf("%s:%s", name, serviceAlias))
 				}
-				return s.apiClient.ContainerRemove(ctx, container.ID, moby.ContainerRemoveOptions{})
-			})
+				links = append(links,
+					fmt.Sprintf("%s:%s", name, name),
+					fmt.Sprintf("%s:%s", name, getContainerNameWithoutProject(container)))
+			}
+			s.Links = links
 		}
-		actual = actual[:scale]
+		project.Services[i] = s
 	}
-	return eg, actual, nil
 }
 
-func (s *composeService) ensureService(ctx context.Context, project *types.Project, service types.ServiceConfig, recreate string, inherit bool, timeout *time.Duration) error {
-	eg, actual, err := s.ensureScale(ctx, project, service, timeout)
+func (c *convergence) ensureService(ctx context.Context, project *types.Project, service types.ServiceConfig, recreate string, inherit bool, timeout *time.Duration) error {
+	expected, err := getScale(service)
 	if err != nil {
 		return err
 	}
+	containers := c.observedState[service.Name]
+	actual := len(containers)
+	updated := make(Containers, expected)
 
-	if recreate == api.RecreateNever {
-		return nil
-	}
+	eg, _ := errgroup.WithContext(ctx)
 
-	expected, err := ServiceHash(service)
-	if err != nil {
-		return err
-	}
+	for i, container := range containers {
+		if i > expected {
+			// Scale Down
+			eg.Go(func() error {
+				err := c.service.apiClient.ContainerStop(ctx, container.ID, timeout)
+				if err != nil {
+					return err
+				}
+				return c.service.apiClient.ContainerRemove(ctx, container.ID, moby.ContainerRemoveOptions{})
+			})
+			continue
+		}
 
-	for _, container := range actual {
-		container := container
+		if recreate == api.RecreateNever {
+			continue
+		}
+		// Re-create diverged containers
+		configHash, err := ServiceHash(service)
+		if err != nil {
+			return err
+		}
 		name := getContainerProgressName(container)
-
-		diverged := container.Labels[api.ConfigHashLabel] != expected
+		diverged := container.Labels[api.ConfigHashLabel] != configHash
 		if diverged || recreate == api.RecreateForce || service.Extensions[extLifecycle] == forceRecreate {
+			i := i
 			eg.Go(func() error {
-				return s.recreateContainer(ctx, project, service, container, inherit, timeout)
+				recreated, err := c.service.recreateContainer(ctx, project, service, container, inherit, timeout)
+				updated[i] = recreated
+				return err
 			})
 			continue
 		}
 
+		// Enforce non-diverged containers are running
 		w := progress.ContextWriter(ctx)
 		switch container.State {
 		case ContainerRunning:
@@ -126,11 +198,31 @@ func (s *composeService) ensureService(ctx context.Context, project *types.Proje
 			w.Event(progress.CreatedEvent(name))
 		default:
 			eg.Go(func() error {
-				return s.startContainer(ctx, container)
+				return c.service.startContainer(ctx, container)
 			})
 		}
+		updated[i] = container
 	}
-	return eg.Wait()
+
+	next, err := nextContainerNumber(containers)
+	if err != nil {
+		return err
+	}
+	for i := 0; i < expected-actual; i++ {
+		// Scale UP
+		number := next + i
+		name := getContainerName(project.Name, service, number)
+		eg.Go(func() error {
+			container, err := c.service.createContainer(ctx, project, service, name, number, false, true)
+			updated[actual+i-1] = container
+			return err
+		})
+		continue
+	}
+
+	err = eg.Wait()
+	c.observedState[service.Name] = updated
+	return err
 }
 
 func getContainerName(projectName string, service types.ServiceConfig, number int) string {
@@ -220,51 +312,54 @@ func getScale(config types.ServiceConfig) (int, error) {
 	return scale, err
 }
 
-func (s *composeService) createContainer(ctx context.Context, project *types.Project, service types.ServiceConfig, name string, number int, autoRemove bool, useNetworkAliases bool) error {
+func (s *composeService) createContainer(ctx context.Context, project *types.Project, service types.ServiceConfig,
+	name string, number int, autoRemove bool, useNetworkAliases bool) (container moby.Container, err error) {
 	w := progress.ContextWriter(ctx)
 	eventName := "Container " + name
 	w.Event(progress.CreatingEvent(eventName))
-	err := s.createMobyContainer(ctx, project, service, name, number, nil, autoRemove, useNetworkAliases)
+	container, err = s.createMobyContainer(ctx, project, service, name, number, nil, autoRemove, useNetworkAliases)
 	if err != nil {
-		return err
+		return
 	}
 	w.Event(progress.CreatedEvent(eventName))
-	return nil
+	return
 }
 
-func (s *composeService) recreateContainer(ctx context.Context, project *types.Project, service types.ServiceConfig, container moby.Container, inherit bool, timeout *time.Duration) error {
+func (s *composeService) recreateContainer(ctx context.Context, project *types.Project, service types.ServiceConfig,
+	replaced moby.Container, inherit bool, timeout *time.Duration) (moby.Container, error) {
+	var created moby.Container
 	w := progress.ContextWriter(ctx)
-	w.Event(progress.NewEvent(getContainerProgressName(container), progress.Working, "Recreate"))
-	err := s.apiClient.ContainerStop(ctx, container.ID, timeout)
+	w.Event(progress.NewEvent(getContainerProgressName(replaced), progress.Working, "Recreate"))
+	err := s.apiClient.ContainerStop(ctx, replaced.ID, timeout)
 	if err != nil {
-		return err
+		return created, err
 	}
-	name := getCanonicalContainerName(container)
-	tmpName := fmt.Sprintf("%s_%s", container.ID[:12], name)
-	err = s.apiClient.ContainerRename(ctx, container.ID, tmpName)
+	name := getCanonicalContainerName(replaced)
+	tmpName := fmt.Sprintf("%s_%s", replaced.ID[:12], name)
+	err = s.apiClient.ContainerRename(ctx, replaced.ID, tmpName)
 	if err != nil {
-		return err
+		return created, err
 	}
-	number, err := strconv.Atoi(container.Labels[api.ContainerNumberLabel])
+	number, err := strconv.Atoi(replaced.Labels[api.ContainerNumberLabel])
 	if err != nil {
-		return err
+		return created, err
 	}
 
 	var inherited *moby.Container
 	if inherit {
-		inherited = &container
+		inherited = &replaced
 	}
-	err = s.createMobyContainer(ctx, project, service, name, number, inherited, false, true)
+	created, err = s.createMobyContainer(ctx, project, service, name, number, inherited, false, true)
 	if err != nil {
-		return err
+		return created, err
 	}
-	err = s.apiClient.ContainerRemove(ctx, container.ID, moby.ContainerRemoveOptions{})
+	err = s.apiClient.ContainerRemove(ctx, replaced.ID, moby.ContainerRemoveOptions{})
 	if err != nil {
-		return err
+		return created, err
 	}
-	w.Event(progress.NewEvent(getContainerProgressName(container), progress.Done, "Recreated"))
+	w.Event(progress.NewEvent(getContainerProgressName(replaced), progress.Done, "Recreated"))
 	setDependentLifecycle(project, service.Name, forceRecreate)
-	return nil
+	return created, err
 }
 
 // setDependentLifecycle define the Lifecycle strategy for all services to depend on specified service
@@ -291,35 +386,31 @@ func (s *composeService) startContainer(ctx context.Context, container moby.Cont
 	return nil
 }
 
-func (s *composeService) createMobyContainer(ctx context.Context, project *types.Project, service types.ServiceConfig, name string, number int,
-	inherit *moby.Container,
-	autoRemove bool,
-	useNetworkAliases bool) error {
-	cState, err := GetContextContainerState(ctx)
-	if err != nil {
-		return err
-	}
+func (s *composeService) createMobyContainer(ctx context.Context, project *types.Project, service types.ServiceConfig,
+	name string, number int, inherit *moby.Container, autoRemove bool, useNetworkAliases bool) (moby.Container, error) {
+	var created moby.Container
 	containerConfig, hostConfig, networkingConfig, err := s.getCreateOptions(ctx, project, service, number, inherit, autoRemove)
 	if err != nil {
-		return err
+		return created, err
 	}
 	var plat *specs.Platform
 	if service.Platform != "" {
-		p, err := platforms.Parse(service.Platform)
+		var p specs.Platform
+		p, err = platforms.Parse(service.Platform)
 		if err != nil {
-			return err
+			return created, err
 		}
 		plat = &p
 	}
-	created, err := s.apiClient.ContainerCreate(ctx, containerConfig, hostConfig, networkingConfig, plat, name)
+	response, err := s.apiClient.ContainerCreate(ctx, containerConfig, hostConfig, networkingConfig, plat, name)
 	if err != nil {
-		return err
+		return created, err
 	}
-	inspectedContainer, err := s.apiClient.ContainerInspect(ctx, created.ID)
+	inspectedContainer, err := s.apiClient.ContainerInspect(ctx, response.ID)
 	if err != nil {
-		return err
+		return created, err
 	}
-	createdContainer := moby.Container{
+	created = moby.Container{
 		ID:     inspectedContainer.ID,
 		Labels: inspectedContainer.Config.Labels,
 		Names:  []string{inspectedContainer.Name},
@@ -327,11 +418,7 @@ func (s *composeService) createMobyContainer(ctx context.Context, project *types
 			Networks: inspectedContainer.NetworkSettings.Networks,
 		},
 	}
-	cState.Add(createdContainer)
-	links, err := s.getLinks(ctx, service)
-	if err != nil {
-		return err
-	}
+	links := append(service.Links, service.ExternalLinks...)
 	for _, netName := range service.NetworksByPriority() {
 		netwrk := project.Networks[netName]
 		cfg := service.Networks[netName]
@@ -342,21 +429,21 @@ func (s *composeService) createMobyContainer(ctx context.Context, project *types
 				aliases = append(aliases, cfg.Aliases...)
 			}
 		}
-		if val, ok := createdContainer.NetworkSettings.Networks[netwrk.Name]; ok {
-			if shortIDAliasExists(createdContainer.ID, val.Aliases...) {
+		if val, ok := created.NetworkSettings.Networks[netwrk.Name]; ok {
+			if shortIDAliasExists(created.ID, val.Aliases...) {
 				continue
 			}
-			err := s.apiClient.NetworkDisconnect(ctx, netwrk.Name, createdContainer.ID, false)
+			err = s.apiClient.NetworkDisconnect(ctx, netwrk.Name, created.ID, false)
 			if err != nil {
-				return err
+				return created, err
 			}
 		}
 		err = s.connectContainerToNetwork(ctx, created.ID, netwrk.Name, cfg, links, aliases...)
 		if err != nil {
-			return err
+			return created, err
 		}
 	}
-	return nil
+	return created, err
 }
 
 func shortIDAliasExists(containerID string, aliases ...string) bool {
@@ -395,37 +482,6 @@ func (s *composeService) connectContainerToNetwork(ctx context.Context, id strin
 	return nil
 }
 
-func (s *composeService) getLinks(ctx context.Context, service types.ServiceConfig) ([]string, error) {
-	cState, err := GetContextContainerState(ctx)
-	if err != nil {
-		return nil, err
-	}
-	links := []string{}
-	for _, serviceLink := range service.Links {
-		s := strings.Split(serviceLink, ":")
-		serviceName := serviceLink
-		serviceAlias := ""
-		if len(s) == 2 {
-			serviceName = s[0]
-			serviceAlias = s[1]
-		}
-		containers := cState.GetContainers()
-		depServiceContainers := containers.filter(isService(serviceName))
-		for _, container := range depServiceContainers {
-			name := getCanonicalContainerName(container)
-			if serviceAlias != "" {
-				links = append(links,
-					fmt.Sprintf("%s:%s", name, serviceAlias))
-			}
-			links = append(links,
-				fmt.Sprintf("%s:%s", name, name),
-				fmt.Sprintf("%s:%s", name, getContainerNameWithoutProject(container)))
-		}
-	}
-	links = append(links, service.ExternalLinks...)
-	return links, nil
-}
-
 func (s *composeService) isServiceHealthy(ctx context.Context, project *types.Project, service string) (bool, error) {
 	containers, err := s.getContainers(ctx, project.Name, oneOffExclude, false, service)
 	if err != nil {
@@ -503,26 +559,3 @@ func (s *composeService) startService(ctx context.Context, project *types.Projec
 	}
 	return eg.Wait()
 }
-
-func (s *composeService) restartService(ctx context.Context, serviceName string, timeout *time.Duration) error {
-	containerState, err := GetContextContainerState(ctx)
-	if err != nil {
-		return err
-	}
-	containers := containerState.GetContainers().filter(isService(serviceName))
-	w := progress.ContextWriter(ctx)
-	eg, ctx := errgroup.WithContext(ctx)
-	for _, c := range containers {
-		container := c
-		eg.Go(func() error {
-			eventName := getContainerProgressName(container)
-			w.Event(progress.RestartingEvent(eventName))
-			err := s.apiClient.ContainerRestart(ctx, container.ID, timeout)
-			if err == nil {
-				w.Event(progress.StartedEvent(eventName))
-			}
-			return err
-		})
-	}
-	return eg.Wait()
-}

+ 5 - 39
pkg/compose/create.go

@@ -59,8 +59,6 @@ func (s *composeService) create(ctx context.Context, project *types.Project, opt
 	if err != nil {
 		return err
 	}
-	containerState := NewContainersState(observedState)
-	ctx = context.WithValue(ctx, ContainersKey{}, containerState)
 
 	err = s.ensureImagesExists(ctx, project, observedState, options.QuietPull)
 	if err != nil {
@@ -105,12 +103,7 @@ func (s *composeService) create(ctx context.Context, project *types.Project, opt
 
 	prepareServicesDependsOn(project)
 
-	return InDependencyOrder(ctx, project, func(c context.Context, service types.ServiceConfig) error {
-		if utils.StringContains(options.Services, service.Name) {
-			return s.ensureService(c, project, service, options.Recreate, options.Inherit, options.Timeout)
-		}
-		return s.ensureService(c, project, service, options.RecreateDependencies, options.Inherit, options.Timeout)
-	})
+	return newConvergence(options.Services, observedState, s).apply(ctx, project, options)
 }
 
 func prepareVolumes(p *types.Project) error {
@@ -275,12 +268,8 @@ func (s *composeService) getCreateOptions(ctx context.Context, p *types.Project,
 
 	resources := getDeployResources(service)
 
-	networkMode, err := getMode(ctx, service.Name, service.NetworkMode)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-	if networkMode == "" {
-		networkMode = getDefaultNetworkMode(p, service)
+	if service.NetworkMode == "" {
+		service.NetworkMode = getDefaultNetworkMode(p, service)
 	}
 
 	var networkConfig *network.NetworkingConfig
@@ -314,11 +303,6 @@ func (s *composeService) getCreateOptions(ctx context.Context, p *types.Project,
 		break //nolint:staticcheck
 	}
 
-	ipcmode, err := getMode(ctx, service.Name, service.Ipc)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-
 	tmpfs := map[string]string{}
 	for _, t := range service.Tmpfs {
 		if arr := strings.SplitN(t, ":", 2); len(arr) > 1 {
@@ -342,9 +326,9 @@ func (s *composeService) getCreateOptions(ctx context.Context, p *types.Project,
 		Mounts:         mounts,
 		CapAdd:         strslice.StrSlice(service.CapAdd),
 		CapDrop:        strslice.StrSlice(service.CapDrop),
-		NetworkMode:    container.NetworkMode(networkMode),
+		NetworkMode:    container.NetworkMode(service.NetworkMode),
 		Init:           service.Init,
-		IpcMode:        container.IpcMode(ipcmode),
+		IpcMode:        container.IpcMode(service.Ipc),
 		ReadonlyRootfs: service.ReadOnly,
 		RestartPolicy:  getRestartPolicy(service),
 		ShmSize:        int64(service.ShmSize),
@@ -913,24 +897,6 @@ func getAliases(s types.ServiceConfig, c *types.ServiceNetworkConfig) []string {
 	return aliases
 }
 
-func getMode(ctx context.Context, serviceName string, mode string) (string, error) {
-	cState, err := GetContextContainerState(ctx)
-	if err != nil {
-		return "", nil
-	}
-	observedState := cState.GetContainers()
-	depService := getDependentServiceFromMode(mode)
-	if depService != "" {
-		depServiceContainers := observedState.filter(isService(depService))
-		if len(depServiceContainers) > 0 {
-			return types.NetworkModeContainerPrefix + depServiceContainers[0].ID, nil
-		}
-		return "", fmt.Errorf(`no containers started for %q in service %q -> %v`,
-			mode, serviceName, observedState)
-	}
-	return mode, nil
-}
-
 func getNetworksForService(s types.ServiceConfig) map[string]*types.ServiceNetworkConfig {
 	if len(s.Networks) > 0 {
 		return s.Networks

+ 10 - 10
pkg/compose/dependencies.go

@@ -63,16 +63,16 @@ var (
 )
 
 // InDependencyOrder applies the function to the services of the project taking in account the dependency order
-func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, types.ServiceConfig) error) error {
+func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error) error {
 	return visit(ctx, project, upDirectionTraversalConfig, fn, ServiceStopped)
 }
 
 // InReverseDependencyOrder applies the function to the services of the project in reverse order of dependencies
-func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, types.ServiceConfig) error) error {
+func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error) error {
 	return visit(ctx, project, downDirectionTraversalConfig, fn, ServiceStarted)
 }
 
-func visit(ctx context.Context, project *types.Project, traversalConfig graphTraversalConfig, fn func(context.Context, types.ServiceConfig) error, initialStatus ServiceStatus) error {
+func visit(ctx context.Context, project *types.Project, traversalConfig graphTraversalConfig, fn func(context.Context, string) error, initialStatus ServiceStatus) error {
 	g := NewGraph(project.Services, initialStatus)
 	if b, err := g.HasCycles(); b {
 		return err
@@ -89,12 +89,12 @@ func visit(ctx context.Context, project *types.Project, traversalConfig graphTra
 }
 
 // Note: this could be `graph.walk` or whatever
-func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, traversalConfig graphTraversalConfig, fn func(context.Context, types.ServiceConfig) error) error {
+func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, traversalConfig graphTraversalConfig, fn func(context.Context, string) error) error {
 	for _, node := range nodes {
 		n := node
 		// Don't start this service yet if all of its children have
 		// not been started yet.
-		if len(traversalConfig.filterAdjacentByStatusFn(graph, n.Service.Name, traversalConfig.adjacentServiceStatusToSkip)) != 0 {
+		if len(traversalConfig.filterAdjacentByStatusFn(graph, n.Service, traversalConfig.adjacentServiceStatusToSkip)) != 0 {
 			continue
 		}
 
@@ -104,7 +104,7 @@ func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex,
 				return err
 			}
 
-			graph.UpdateStatus(n.Service.Name, traversalConfig.targetServiceStatus)
+			graph.UpdateStatus(n.Service, traversalConfig.targetServiceStatus)
 
 			return run(ctx, graph, eg, traversalConfig.adjacentNodesFn(n), traversalConfig, fn)
 		})
@@ -122,7 +122,7 @@ type Graph struct {
 // Vertex represents a service in the dependencies structure
 type Vertex struct {
 	Key      string
-	Service  types.ServiceConfig
+	Service  string
 	Status   ServiceStatus
 	Children map[string]*Vertex
 	Parents  map[string]*Vertex
@@ -162,7 +162,7 @@ func NewGraph(services types.Services, initialStatus ServiceStatus) *Graph {
 	}
 
 	for _, s := range services {
-		graph.AddVertex(s.Name, s, initialStatus)
+		graph.AddVertex(s.Name, s.Name, initialStatus)
 	}
 
 	for _, s := range services {
@@ -175,7 +175,7 @@ func NewGraph(services types.Services, initialStatus ServiceStatus) *Graph {
 }
 
 // NewVertex is the constructor function for the Vertex
-func NewVertex(key string, service types.ServiceConfig, initialStatus ServiceStatus) *Vertex {
+func NewVertex(key string, service string, initialStatus ServiceStatus) *Vertex {
 	return &Vertex{
 		Key:      key,
 		Service:  service,
@@ -186,7 +186,7 @@ func NewVertex(key string, service types.ServiceConfig, initialStatus ServiceSta
 }
 
 // AddVertex adds a vertex to the Graph
-func (g *Graph) AddVertex(key string, service types.ServiceConfig, initialStatus ServiceStatus) {
+func (g *Graph) AddVertex(key string, service string, initialStatus ServiceStatus) {
 	g.lock.Lock()
 	defer g.lock.Unlock()
 

+ 4 - 4
pkg/compose/dependencies_test.go

@@ -47,8 +47,8 @@ var project = types.Project{
 func TestInDependencyUpCommandOrder(t *testing.T) {
 	order := make(chan string)
 	//nolint:errcheck, unparam
-	go InDependencyOrder(context.TODO(), &project, func(ctx context.Context, config types.ServiceConfig) error {
-		order <- config.Name
+	go InDependencyOrder(context.TODO(), &project, func(ctx context.Context, config string) error {
+		order <- config
 		return nil
 	})
 	assert.Equal(t, <-order, "test3")
@@ -59,8 +59,8 @@ func TestInDependencyUpCommandOrder(t *testing.T) {
 func TestInDependencyReverseDownCommandOrder(t *testing.T) {
 	order := make(chan string)
 	//nolint:errcheck, unparam
-	go InReverseDependencyOrder(context.TODO(), &project, func(ctx context.Context, config types.ServiceConfig) error {
-		order <- config.Name
+	go InReverseDependencyOrder(context.TODO(), &project, func(ctx context.Context, config string) error {
+		order <- config
 		return nil
 	})
 	assert.Equal(t, <-order, "test1")

+ 2 - 8
pkg/compose/down.go

@@ -50,7 +50,6 @@ func (s *composeService) down(ctx context.Context, projectName string, options a
 	if err != nil {
 		return err
 	}
-	ctx = context.WithValue(ctx, ContainersKey{}, NewContainersState(containers))
 
 	if options.Project == nil {
 		project, err := s.projectFromContainerLabels(containers, projectName)
@@ -64,8 +63,8 @@ func (s *composeService) down(ctx context.Context, projectName string, options a
 		resourceToRemove = true
 	}
 
-	err = InReverseDependencyOrder(ctx, options.Project, func(c context.Context, service types.ServiceConfig) error {
-		serviceContainers := containers.filter(isService(service.Name))
+	err = InReverseDependencyOrder(ctx, options.Project, func(c context.Context, service string) error {
+		serviceContainers := containers.filter(isService(service))
 		err := s.removeContainers(ctx, w, serviceContainers, options.Timeout, options.Volumes)
 		return err
 	})
@@ -236,11 +235,6 @@ func (s *composeService) removeContainers(ctx context.Context, w progress.Writer
 				w.Event(progress.ErrorMessageEvent(eventName, "Error while Removing"))
 				return err
 			}
-			contextContainerState, err := GetContextContainerState(ctx)
-			if err != nil {
-				return err
-			}
-			contextContainerState.Remove(toDelete.ID)
 			w.Event(progress.RemovedEvent(eventName))
 			return nil
 		})

+ 0 - 4
pkg/compose/filters.go

@@ -31,10 +31,6 @@ func serviceFilter(serviceName string) filters.KeyValuePair {
 	return filters.Arg("label", fmt.Sprintf("%s=%s", api.ServiceLabel, serviceName))
 }
 
-func slugFilter(slug string) filters.KeyValuePair {
-	return filters.Arg("label", fmt.Sprintf("%s=%s", api.SlugLabel, slug))
-}
-
 func oneOffFilter(b bool) filters.KeyValuePair {
 	v := "False"
 	if b {

+ 20 - 5
pkg/compose/restart.go

@@ -21,6 +21,7 @@ import (
 
 	"github.com/compose-spec/compose-go/types"
 	"github.com/docker/compose-cli/pkg/api"
+	"golang.org/x/sync/errgroup"
 
 	"github.com/docker/compose-cli/pkg/progress"
 	"github.com/docker/compose-cli/pkg/utils"
@@ -33,7 +34,7 @@ func (s *composeService) Restart(ctx context.Context, project *types.Project, op
 }
 
 func (s *composeService) restart(ctx context.Context, project *types.Project, options api.RestartOptions) error {
-	ctx, err := s.getUpdatedContainersStateContext(ctx, project.Name)
+	observedState, err := s.getContainers(ctx, project.Name, oneOffInclude, true)
 	if err != nil {
 		return err
 	}
@@ -42,11 +43,25 @@ func (s *composeService) restart(ctx context.Context, project *types.Project, op
 		options.Services = project.ServiceNames()
 	}
 
-	err = InDependencyOrder(ctx, project, func(c context.Context, service types.ServiceConfig) error {
-		if utils.StringContains(options.Services, service.Name) {
-			return s.restartService(ctx, service.Name, options.Timeout)
+	w := progress.ContextWriter(ctx)
+	err = InDependencyOrder(ctx, project, func(c context.Context, service string) error {
+		if !utils.StringContains(options.Services, service) {
+			return nil
 		}
-		return nil
+		eg, ctx := errgroup.WithContext(ctx)
+		for _, c := range observedState.filter(isService(service)) {
+			container := c
+			eg.Go(func() error {
+				eventName := getContainerProgressName(container)
+				w.Event(progress.RestartingEvent(eventName))
+				err := s.apiClient.ContainerRestart(ctx, container.ID, options.Timeout)
+				if err == nil {
+					w.Event(progress.StartedEvent(eventName))
+				}
+				return err
+			})
+		}
+		return eg.Wait()
 	})
 	if err != nil {
 		return err

+ 5 - 15
pkg/compose/run.go

@@ -25,7 +25,6 @@ import (
 	"github.com/compose-spec/compose-go/types"
 	moby "github.com/docker/docker/api/types"
 	"github.com/docker/docker/api/types/container"
-	"github.com/docker/docker/api/types/filters"
 	"github.com/docker/docker/pkg/stringid"
 )
 
@@ -34,8 +33,6 @@ func (s *composeService) RunOneOffContainer(ctx context.Context, project *types.
 	if err != nil {
 		return 0, err
 	}
-	containerState := NewContainersState(observedState)
-	ctx = context.WithValue(ctx, ContainersKey{}, containerState)
 
 	service, err := project.GetService(opts.Service)
 	if err != nil {
@@ -63,10 +60,11 @@ func (s *composeService) RunOneOffContainer(ctx context.Context, project *types.
 	if err := s.waitDependencies(ctx, project, service); err != nil {
 		return 0, err
 	}
-	if err := s.createContainer(ctx, project, service, service.ContainerName, 1, opts.AutoRemove, opts.UseNetworkAliases); err != nil {
+	created, err := s.createContainer(ctx, project, service, service.ContainerName, 1, opts.AutoRemove, opts.UseNetworkAliases)
+	if err != nil {
 		return 0, err
 	}
-	containerID := service.ContainerName
+	containerID := created.ID
 
 	if opts.Detach {
 		err := s.apiClient.ContainerStart(ctx, containerID, moby.ContainerStartOptions{})
@@ -77,21 +75,13 @@ func (s *composeService) RunOneOffContainer(ctx context.Context, project *types.
 		return 0, nil
 	}
 
-	containers, err := s.apiClient.ContainerList(ctx, moby.ContainerListOptions{
-		Filters: filters.NewArgs(slugFilter(slug)),
-		All:     true,
-	})
-	if err != nil {
-		return 0, err
-	}
-	oneoffContainer := containers[0]
-	restore, err := s.attachContainerStreams(ctx, oneoffContainer.ID, service.Tty, opts.Reader, opts.Writer)
+	restore, err := s.attachContainerStreams(ctx, containerID, service.Tty, opts.Reader, opts.Writer)
 	if err != nil {
 		return 0, err
 	}
 	defer restore()
 
-	statusC, errC := s.apiClient.ContainerWait(context.Background(), oneoffContainer.ID, container.WaitConditionNextExit)
+	statusC, errC := s.apiClient.ContainerWait(context.Background(), containerID, container.WaitConditionNextExit)
 
 	err = s.apiClient.ContainerStart(ctx, containerID, moby.ContainerStartOptions{})
 	if err != nil {

+ 5 - 1
pkg/compose/start.go

@@ -53,7 +53,11 @@ func (s *composeService) start(ctx context.Context, project *types.Project, opti
 		})
 	}
 
-	err := InDependencyOrder(ctx, project, func(c context.Context, service types.ServiceConfig) error {
+	err := InDependencyOrder(ctx, project, func(c context.Context, name string) error {
+		service, err := project.GetService(name)
+		if err != nil {
+			return err
+		}
 		return s.startService(ctx, project, service)
 	})
 	if err != nil {

+ 0 - 111
pkg/compose/status.go

@@ -1,111 +0,0 @@
-/*
-   Copyright 2020 Docker Compose CLI authors
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License.
-*/
-
-package compose
-
-import (
-	"context"
-
-	"github.com/docker/docker/api/types"
-	"github.com/pkg/errors"
-)
-
-// ContainersKey is the context key to access context value os a ContainersStatus
-type ContainersKey struct{}
-
-// ContainersState state management interface
-type ContainersState interface {
-	Get(string) *types.Container
-	GetContainers() Containers
-	Add(c types.Container)
-	AddAll(cs Containers)
-	Remove(string) types.Container
-}
-
-// NewContainersState creates a new container state manager
-func NewContainersState(cs Containers) ContainersState {
-	s := containersState{
-		observedContainers: &cs,
-	}
-	return &s
-}
-
-// ContainersStatus works as a collection container for the observed containers
-type containersState struct {
-	observedContainers *Containers
-}
-
-func (s *containersState) AddAll(cs Containers) {
-	for _, c := range cs {
-		lValue := append(*s.observedContainers, c)
-		s.observedContainers = &lValue
-	}
-}
-
-func (s *containersState) Add(c types.Container) {
-	if s.Get(c.ID) == nil {
-		lValue := append(*s.observedContainers, c)
-		s.observedContainers = &lValue
-	}
-}
-
-func (s *containersState) Remove(id string) types.Container {
-	var c types.Container
-	var newObserved Containers
-	for _, o := range *s.observedContainers {
-		if o.ID == id {
-			c = o
-			continue
-		}
-		newObserved = append(newObserved, o)
-	}
-	s.observedContainers = &newObserved
-	return c
-}
-
-func (s *containersState) Get(id string) *types.Container {
-	for _, o := range *s.observedContainers {
-		if id == o.ID {
-			return &o
-		}
-	}
-	return nil
-}
-
-func (s *containersState) GetContainers() Containers {
-	if s.observedContainers != nil && *s.observedContainers != nil {
-		return *s.observedContainers
-	}
-	return make(Containers, 0)
-}
-
-// GetContextContainerState gets the container state manager
-func GetContextContainerState(ctx context.Context) (ContainersState, error) {
-	cState, ok := ctx.Value(ContainersKey{}).(*containersState)
-	if !ok {
-		return nil, errors.New("containers' containersState not available in context")
-	}
-	return cState, nil
-}
-
-func (s composeService) getUpdatedContainersStateContext(ctx context.Context, projectName string) (context.Context, error) {
-	observedState, err := s.getContainers(ctx, projectName, oneOffInclude, true)
-	if err != nil {
-		return nil, err
-	}
-	containerState := NewContainersState(observedState)
-	return context.WithValue(ctx, ContainersKey{}, containerState), nil
-}

+ 2 - 2
pkg/compose/stop.go

@@ -44,7 +44,7 @@ func (s *composeService) stop(ctx context.Context, project *types.Project, optio
 		return err
 	}
 
-	return InReverseDependencyOrder(ctx, project, func(c context.Context, service types.ServiceConfig) error {
-		return s.stopContainers(ctx, w, containers.filter(isService(service.Name)), options.Timeout)
+	return InReverseDependencyOrder(ctx, project, func(c context.Context, service string) error {
+		return s.stopContainers(ctx, w, containers.filter(isService(service)), options.Timeout)
 	})
 }