Browse Source

Query stack events by stack ID (not name)

This prevent a race condition on `down` as stack is deleted and we still
ask for stack events as we didn't recieved the DELETE_COMPLETE one

Use WaitUntilStack* to detect stack operation completion

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof 5 years ago
parent
commit
a8e963a304

+ 1 - 2
ecs/pkg/amazon/down.go

@@ -10,9 +10,8 @@ func (c *client) ComposeDown(ctx context.Context, projectName string, deleteClus
 	if err != nil {
 		return err
 	}
-	fmt.Printf("Delete stack ")
 
-	err = c.WaitStackCompletion(ctx, projectName)
+	err = c.WaitStackCompletion(ctx, projectName, StackDelete)
 	if err != nil {
 		return err
 	}

+ 16 - 1
ecs/pkg/amazon/mock/api.go

@@ -152,6 +152,21 @@ func (mr *MockAPIMockRecorder) GetRoleArn(arg0, arg1 interface{}) *gomock.Call {
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoleArn", reflect.TypeOf((*MockAPI)(nil).GetRoleArn), arg0, arg1)
 }
 
+// GetStackID mocks base method
+func (m *MockAPI) GetStackID(arg0 context.Context, arg1 string) (string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetStackID", arg0, arg1)
+	ret0, _ := ret[0].(string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// GetStackID indicates an expected call of GetStackID
+func (mr *MockAPIMockRecorder) GetStackID(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStackID", reflect.TypeOf((*MockAPI)(nil).GetStackID), arg0, arg1)
+}
+
 // GetSubNets mocks base method
 func (m *MockAPI) GetSubNets(arg0 context.Context, arg1 string) ([]string, error) {
 	m.ctrl.T.Helper()
@@ -213,7 +228,7 @@ func (mr *MockAPIMockRecorder) VpcExists(arg0, arg1 interface{}) *gomock.Call {
 }
 
 // WaitStackComplete mocks base method
-func (m *MockAPI) WaitStackComplete(arg0 context.Context, arg1 string, arg2 func() error) error {
+func (m *MockAPI) WaitStackComplete(arg0 context.Context, arg1 string, arg2 int) error {
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "WaitStackComplete", arg0, arg1, arg2)
 	ret0, _ := ret[0].(error)

+ 34 - 35
ecs/pkg/amazon/sdk.go

@@ -3,8 +3,6 @@ package amazon
 import (
 	"context"
 	"fmt"
-	"strings"
-	"time"
 
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/session"
@@ -53,7 +51,7 @@ func NewAPI(sess *session.Session) API {
 
 func (s sdk) ClusterExists(ctx context.Context, name string) (bool, error) {
 	logrus.Debug("Check if cluster was already created: ", name)
-	clusters, err := s.ECS.DescribeClustersWithContext(aws.Context(ctx), &ecs.DescribeClustersInput{
+	clusters, err := s.ECS.DescribeClustersWithContext(ctx, &ecs.DescribeClustersInput{
 		Clusters: []*string{aws.String(name)},
 	})
 	if err != nil {
@@ -64,7 +62,7 @@ func (s sdk) ClusterExists(ctx context.Context, name string) (bool, error) {
 
 func (s sdk) CreateCluster(ctx context.Context, name string) (string, error) {
 	logrus.Debug("Create cluster ", name)
-	response, err := s.ECS.CreateClusterWithContext(aws.Context(ctx), &ecs.CreateClusterInput{ClusterName: aws.String(name)})
+	response, err := s.ECS.CreateClusterWithContext(ctx, &ecs.CreateClusterInput{ClusterName: aws.String(name)})
 	if err != nil {
 		return "", err
 	}
@@ -73,7 +71,7 @@ func (s sdk) CreateCluster(ctx context.Context, name string) (string, error) {
 
 func (s sdk) DeleteCluster(ctx context.Context, name string) error {
 	logrus.Debug("Delete cluster ", name)
-	response, err := s.ECS.DeleteClusterWithContext(aws.Context(ctx), &ecs.DeleteClusterInput{Cluster: aws.String(name)})
+	response, err := s.ECS.DeleteClusterWithContext(ctx, &ecs.DeleteClusterInput{Cluster: aws.String(name)})
 	if err != nil {
 		return err
 	}
@@ -85,13 +83,13 @@ func (s sdk) DeleteCluster(ctx context.Context, name string) error {
 
 func (s sdk) VpcExists(ctx context.Context, vpcID string) (bool, error) {
 	logrus.Debug("Check if VPC exists: ", vpcID)
-	_, err := s.EC2.DescribeVpcsWithContext(aws.Context(ctx), &ec2.DescribeVpcsInput{VpcIds: []*string{&vpcID}})
+	_, err := s.EC2.DescribeVpcsWithContext(ctx, &ec2.DescribeVpcsInput{VpcIds: []*string{&vpcID}})
 	return err == nil, err
 }
 
 func (s sdk) GetDefaultVPC(ctx context.Context) (string, error) {
 	logrus.Debug("Retrieve default VPC")
-	vpcs, err := s.EC2.DescribeVpcsWithContext(aws.Context(ctx), &ec2.DescribeVpcsInput{
+	vpcs, err := s.EC2.DescribeVpcsWithContext(ctx, &ec2.DescribeVpcsInput{
 		Filters: []*ec2.Filter{
 			{
 				Name:   aws.String("isDefault"),
@@ -110,7 +108,7 @@ func (s sdk) GetDefaultVPC(ctx context.Context) (string, error) {
 
 func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]string, error) {
 	logrus.Debug("Retrieve SubNets")
-	subnets, err := s.EC2.DescribeSubnetsWithContext(aws.Context(ctx), &ec2.DescribeSubnetsInput{
+	subnets, err := s.EC2.DescribeSubnetsWithContext(ctx, &ec2.DescribeSubnetsInput{
 		DryRun: nil,
 		Filters: []*ec2.Filter{
 			{
@@ -135,7 +133,7 @@ func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]string, error) {
 }
 
 func (s sdk) ListRolesForPolicy(ctx context.Context, policy string) ([]string, error) {
-	entities, err := s.IAM.ListEntitiesForPolicyWithContext(aws.Context(ctx), &iam.ListEntitiesForPolicyInput{
+	entities, err := s.IAM.ListEntitiesForPolicyWithContext(ctx, &iam.ListEntitiesForPolicyInput{
 		EntityFilter: aws.String("Role"),
 		PolicyArn:    aws.String(policy),
 	})
@@ -150,7 +148,7 @@ func (s sdk) ListRolesForPolicy(ctx context.Context, policy string) ([]string, e
 }
 
 func (s sdk) GetRoleArn(ctx context.Context, name string) (string, error) {
-	role, err := s.IAM.GetRoleWithContext(aws.Context(ctx), &iam.GetRoleInput{
+	role, err := s.IAM.GetRoleWithContext(ctx, &iam.GetRoleInput{
 		RoleName: aws.String(name),
 	})
 	if err != nil {
@@ -160,7 +158,7 @@ func (s sdk) GetRoleArn(ctx context.Context, name string) (string, error) {
 }
 
 func (s sdk) StackExists(ctx context.Context, name string) (bool, error) {
-	stacks, err := s.CF.DescribeStacksWithContext(aws.Context(ctx), &cloudformation.DescribeStacksInput{
+	stacks, err := s.CF.DescribeStacksWithContext(ctx, &cloudformation.DescribeStacksInput{
 		StackName: aws.String(name),
 	})
 	if err != nil {
@@ -177,7 +175,7 @@ func (s sdk) CreateStack(ctx context.Context, name string, template *cf.Template
 		return err
 	}
 
-	_, err = s.CF.CreateStackWithContext(aws.Context(ctx), &cloudformation.CreateStackInput{
+	_, err = s.CF.CreateStackWithContext(ctx, &cloudformation.CreateStackInput{
 		OnFailure:        aws.String("DELETE"),
 		StackName:        aws.String(name),
 		TemplateBody:     aws.String(string(json)),
@@ -185,36 +183,37 @@ func (s sdk) CreateStack(ctx context.Context, name string, template *cf.Template
 	})
 	return err
 }
-func (s sdk) WaitStackComplete(ctx context.Context, name string, fn func() error) error {
-	for i := 0; i < 120; i++ {
-		stacks, err := s.CF.DescribeStacks(&cloudformation.DescribeStacksInput{
-			StackName: aws.String(name),
-		})
-		if err != nil {
-			return err
-		}
-
-		err = fn()
-		if err != nil {
-			return err
-		}
+func (s sdk) WaitStackComplete(ctx context.Context, name string, operation int) error {
+	input := &cloudformation.DescribeStacksInput{
+		StackName: aws.String(name),
+	}
+	switch operation {
+	case StackCreate:
+		return s.CF.WaitUntilStackCreateCompleteWithContext(ctx, input)
+	case StackDelete:
+		return s.CF.WaitUntilStackDeleteCompleteWithContext(ctx, input)
+	default:
+		return fmt.Errorf("internal error: unexpected stack operation %d", operation)
+	}
+}
 
-		status := *stacks.Stacks[0].StackStatus
-		if strings.HasSuffix(status, "_COMPLETE") || strings.HasSuffix(status, "_FAILED") {
-			return nil
-		}
-		time.Sleep(1 * time.Second)
+func (s sdk) GetStackID(ctx context.Context, name string) (string, error) {
+	stacks, err := s.CF.DescribeStacksWithContext(ctx, &cloudformation.DescribeStacksInput{
+		StackName: aws.String(name),
+	})
+	if err != nil {
+		return "", err
 	}
-	return fmt.Errorf("120s timeout waiting for CloudFormation stack %s to complete", name)
+	return *stacks.Stacks[0].StackId, nil
 }
 
-func (s sdk) DescribeStackEvents(ctx context.Context, name string) ([]*cloudformation.StackEvent, error) {
+func (s sdk) DescribeStackEvents(ctx context.Context, stackID string) ([]*cloudformation.StackEvent, error) {
 	// Fixme implement Paginator on Events and return as a chan(events)
 	events := []*cloudformation.StackEvent{}
 	var nextToken *string
 	for {
-		resp, err := s.CF.DescribeStackEventsWithContext(aws.Context(ctx), &cloudformation.DescribeStackEventsInput{
-			StackName: aws.String(name),
+		resp, err := s.CF.DescribeStackEventsWithContext(ctx, &cloudformation.DescribeStackEventsInput{
+			StackName: aws.String(stackID),
 			NextToken: nextToken,
 		})
 		if err != nil {
@@ -230,7 +229,7 @@ func (s sdk) DescribeStackEvents(ctx context.Context, name string) ([]*cloudform
 
 func (s sdk) DeleteStack(ctx context.Context, name string) error {
 	logrus.Debug("Delete CloudFormation stack")
-	_, err := s.CF.DeleteStackWithContext(aws.Context(ctx), &cloudformation.DeleteStackInput{
+	_, err := s.CF.DeleteStackWithContext(ctx, &cloudformation.DeleteStackInput{
 		StackName: aws.String(name),
 	})
 	return err

+ 1 - 1
ecs/pkg/amazon/up.go

@@ -34,7 +34,7 @@ func (c *client) ComposeUp(ctx context.Context, project *compose.Project) error
 		return err
 	}
 
-	return c.WaitStackCompletion(ctx, project.Name)
+	return c.WaitStackCompletion(ctx, project.Name, StackCreate)
 }
 
 type upAPI interface {

+ 40 - 13
ecs/pkg/amazon/wait.go

@@ -4,17 +4,42 @@ import (
 	"context"
 	"fmt"
 	"sort"
+	"time"
 
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/service/cloudformation"
 	"github.com/docker/ecs-plugin/pkg/console"
 )
 
-func (c *client) WaitStackCompletion(ctx context.Context, name string) error {
+func (c *client) WaitStackCompletion(ctx context.Context, name string, operation int) error {
 	w := console.NewProgressWriter()
-	known := map[string]struct{}{}
-	err := c.api.WaitStackComplete(ctx, name, func() error {
-		events, err := c.api.DescribeStackEvents(ctx, name)
+	knownEvents := map[string]struct{}{}
+
+	// Get the unique Stack ID so we can collect events without getting some from previous deployments with same name
+	stackID, err := c.api.GetStackID(ctx, name)
+	if err != nil {
+		return err
+	}
+
+	ticker := time.NewTicker(1 * time.Second)
+	done := make(chan error)
+
+	go func() {
+		err := c.api.WaitStackComplete(ctx, name, operation)
+		ticker.Stop()
+		done <- err
+	}()
+
+	var completed bool
+	var waitErr error
+	for !completed {
+		select {
+		case err := <-done:
+			completed = true
+			waitErr = err
+		case <-ticker.C:
+		}
+		events, err := c.api.DescribeStackEvents(ctx, stackID)
 		if err != nil {
 			return err
 		}
@@ -24,23 +49,25 @@ func (c *client) WaitStackCompletion(ctx context.Context, name string) error {
 		})
 
 		for _, event := range events {
-			if _, ok := known[*event.EventId]; ok {
+			if _, ok := knownEvents[*event.EventId]; ok {
 				continue
 			}
-			known[*event.EventId] = struct{}{}
+			knownEvents[*event.EventId] = struct{}{}
 
 			resource := fmt.Sprintf("%s %q", aws.StringValue(event.ResourceType), aws.StringValue(event.LogicalResourceId))
 			w.ResourceEvent(resource, aws.StringValue(event.ResourceStatus), aws.StringValue(event.ResourceStatusReason))
 		}
-		return nil
-	})
-	if err != nil {
-		return err
 	}
-	return nil
+	return waitErr
 }
 
 type waitAPI interface {
-	WaitStackComplete(ctx context.Context, name string, fn func() error) error
-	DescribeStackEvents(ctx context.Context, stack string) ([]*cloudformation.StackEvent, error)
+	GetStackID(ctx context.Context, name string) (string, error)
+	WaitStackComplete(ctx context.Context, name string, operation int) error
+	DescribeStackEvents(ctx context.Context, stackID string) ([]*cloudformation.StackEvent, error)
 }
+
+const (
+	StackCreate = iota
+	StackDelete
+)

+ 0 - 6
ecs/pkg/console/progress.go

@@ -100,16 +100,10 @@ func (c ansiConsole) Printf(format string, a ...interface{}) {
 }
 
 func (c ansiConsole) MoveUp(i int) {
-	if i == 0 {
-		return
-	}
 	fmt.Fprintf(c.out, "\033[%dA", i) // nolint:errcheck
 }
 
 func (c ansiConsole) MoveDown(i int) {
-	if i == 0 {
-		return
-	}
 	fmt.Fprintf(c.out, "\033[%dB", i) // nolint:errcheck
 }
 

+ 0 - 1
ecs/pkg/convert/convert.go

@@ -71,7 +71,6 @@ func Convert(project *compose.Project, service types.ServiceConfig) (*ecs.TaskDe
 		Tags:                    nil,
 		Volumes:                 []ecs.TaskDefinition_Volume{},
 	}, nil
-
 }
 
 func toCPU(service types.ServiceConfig) string {