Bladeren bron

refactor task failure detection

Signed-off-by: aiordache <[email protected]>
aiordache 5 jaren geleden
bovenliggende
commit
c15d05f7f4
5 gewijzigde bestanden met toevoegingen van 229 en 135 verwijderingen
  1. 2 3
      cli/cmd/compose/list.go
  2. 83 1
      ecs/list.go
  3. 11 19
      ecs/ps.go
  4. 121 101
      ecs/sdk.go
  5. 12 11
      ecs/wait.go

+ 2 - 3
cli/cmd/compose/list.go

@@ -61,8 +61,7 @@ func runList(ctx context.Context, opts composeOptions) error {
 	view := viewFromStackList(stackList)
 	return formatter.Print(view, opts.Format, os.Stdout, func(w io.Writer) {
 		for _, stack := range view {
-			_, _ = fmt.Fprintf(w, "%s\t%s\n", stack.Name, strings.TrimSpace(
-				fmt.Sprintf("%s %s", stack.Status, stack.Reason))
+			_, _ = fmt.Fprintf(w, "%s\t%s\n", stack.Name, stack.Status)
 		}
 	}, "NAME", "STATUS")
 }
@@ -77,7 +76,7 @@ func viewFromStackList(stackList []compose.Stack) []stackView {
 	for i, s := range stackList {
 		retList[i] = stackView{
 			Name:   s.Name,
-			Status: s.Status,
+			Status: strings.TrimSpace(fmt.Sprintf("%s %s", s.Status, s.Reason)),
 		}
 	}
 	return retList

+ 83 - 1
ecs/list.go

@@ -18,11 +18,93 @@ package ecs
 
 import (
 	"context"
+	"fmt"
 
 	"github.com/docker/compose-cli/api/compose"
 )
 
 func (b *ecsAPIService) List(ctx context.Context, project string) ([]compose.Stack, error) {
-	return b.SDK.ListStacks(ctx, project)
+	stacks, err := b.SDK.ListStacks(ctx, project)
+	if err != nil {
+		return nil, err
+	}
 
+	for _, stack := range stacks {
+		if stack.Status == compose.STARTING {
+			if err := b.checkStackState(ctx, stack.Name); err != nil {
+				stack.Status = compose.FAILED
+				stack.Reason = err.Error()
+			}
+		}
+	}
+	return stacks, nil
+
+}
+
+func (b *ecsAPIService) checkStackState(ctx context.Context, name string) error {
+	resources, err := b.SDK.ListStackResources(ctx, name)
+	if err != nil {
+		return err
+	}
+	svcArns := []string{}
+	svcNames := map[string]string{}
+	var cluster string
+	for _, r := range resources {
+		if r.Type == "AWS::ECS::Cluster" {
+			cluster = r.ARN
+			continue
+		}
+		if r.Type == "AWS::ECS::Service" {
+			if r.ARN == "" {
+				continue
+			}
+			svcArns = append(svcArns, r.ARN)
+			svcNames[r.ARN] = r.LogicalID
+		}
+	}
+	if len(svcArns) == 0 {
+		return nil
+	}
+	services, err := b.SDK.GetServiceTaskDefinition(ctx, cluster, svcArns)
+	if err != nil {
+		return err
+	}
+	for service, taskDef := range services {
+		if err := b.checkServiceState(ctx, cluster, service, taskDef); err != nil {
+			return fmt.Errorf("%s %s", svcNames[service], err.Error())
+		}
+	}
+	return nil
+}
+
+func (b *ecsAPIService) checkServiceState(ctx context.Context, cluster string, service string, taskdef string) error {
+	runningTasks, err := b.SDK.GetServiceTasks(ctx, cluster, service, false)
+	if err != nil {
+		return err
+	}
+	if len(runningTasks) > 0 {
+		return nil
+	}
+	stoppedTasks, err := b.SDK.GetServiceTasks(ctx, cluster, service, true)
+	if err != nil {
+		return err
+	}
+	if len(stoppedTasks) == 0 {
+		return nil
+	}
+	// filter tasks by task definition
+	tasks := []string{}
+	for _, t := range stoppedTasks {
+		if t.TaskDefinitionArn != nil && *t.TaskDefinitionArn == taskdef {
+			tasks = append(tasks, *t.TaskArn)
+		}
+	}
+	if len(tasks) == 0 {
+		return nil
+	}
+	reason, err := b.SDK.GetTaskStoppedReason(ctx, cluster, tasks[0])
+	if err != nil {
+		return err
+	}
+	return fmt.Errorf("%s", reason)
 }

+ 11 - 19
ecs/ps.go

@@ -25,33 +25,25 @@ import (
 )
 
 func (b *ecsAPIService) Ps(ctx context.Context, project string) ([]compose.ServiceStatus, error) {
-	resources, err := b.SDK.ListStackResources(ctx, project)
+	cluster, err := b.SDK.GetStackClusterID(ctx, project)
 	if err != nil {
 		return nil, err
 	}
-
-	var (
-		cluster     = project
-		servicesARN []string
-	)
-	for _, r := range resources {
-		switch r.Type {
-		case "AWS::ECS::Service":
-			servicesARN = append(servicesARN, r.ARN)
-		case "AWS::ECS::Cluster":
-			cluster = r.ARN
-		}
+	servicesARN, err := b.SDK.ListStackServices(ctx, project)
+	if err != nil {
+		return nil, err
 	}
 
 	if len(servicesARN) == 0 {
 		return nil, nil
 	}
-	status, err := b.SDK.DescribeServices(ctx, cluster, servicesARN)
-	if err != nil {
-		return nil, err
-	}
 
-	for i, state := range status {
+	status := []compose.ServiceStatus{}
+	for _, arn := range servicesARN {
+		state, err := b.SDK.DescribeService(ctx, cluster, arn)
+		if err != nil {
+			return nil, err
+		}
 		ports := []string{}
 		for _, lb := range state.Publishers {
 			ports = append(ports, fmt.Sprintf(
@@ -62,7 +54,7 @@ func (b *ecsAPIService) Ps(ctx context.Context, project string) ([]compose.Servi
 				strings.ToLower(lb.Protocol)))
 		}
 		state.Ports = ports
-		status[i] = state
+		status = append(status, state)
 	}
 	return status, nil
 }

+ 121 - 101
ecs/sdk.go

@@ -304,93 +304,102 @@ func (s sdk) ListStacks(ctx context.Context, name string) ([]compose.Stack, erro
 	}
 	stacks := []compose.Stack{}
 	for _, stack := range cfStacks.Stacks {
-		skip := true
 		for _, t := range stack.Tags {
 			if *t.Key == compose.ProjectTag {
-				skip = false
+				status := compose.RUNNING
+				switch aws.StringValue(stack.StackStatus) {
+				case "CREATE_IN_PROGRESS":
+					status = compose.STARTING
+				case "DELETE_IN_PROGRESS":
+					status = compose.REMOVING
+				case "UPDATE_IN_PROGRESS":
+					status = compose.UPDATING
+				default:
+				}
+				stacks = append(stacks, compose.Stack{
+					ID:     aws.StringValue(stack.StackId),
+					Name:   aws.StringValue(stack.StackName),
+					Status: status,
+				})
 				break
 			}
 		}
-		if skip {
-			continue
-		}
-		status := compose.RUNNING
-		reason := ""
-		switch aws.StringValue(stack.StackStatus) {
-		case "CREATE_IN_PROGRESS":
-			status = compose.STARTING
-		case "DELETE_IN_PROGRESS":
-			status = compose.REMOVING
-		case "UPDATE_IN_PROGRESS":
-			status = compose.UPDATING
-		}
-		if status == compose.STARTING {
-			if err := s.CheckStackState(ctx, aws.StringValue(stack.StackName)); err != nil {
-				status = compose.FAILED
-				reason = err.Error()
-			}
-		}
-		stacks = append(stacks, compose.Stack{
-			ID:     aws.StringValue(stack.StackId),
-			Name:   aws.StringValue(stack.StackName),
-			Status: status,
-			Reason: reason,
-		})
-
 	}
 	return stacks, nil
 }
 
-func (s sdk) CheckStackState(ctx context.Context, name string) error {
+func (s sdk) GetStackClusterID(ctx context.Context, stack string) (string, error) {
 	resources, err := s.CF.ListStackResourcesWithContext(ctx, &cloudformation.ListStackResourcesInput{
-		StackName: aws.String(name),
+		StackName: aws.String(stack),
 	})
 	if err != nil {
-		return err
+		return "", err
 	}
-	services := []*string{}
-	serviceNames := []string{}
-	var cluster *string
 	for _, r := range resources.StackResourceSummaries {
 		if aws.StringValue(r.ResourceType) == "AWS::ECS::Cluster" {
-			cluster = r.PhysicalResourceId
-			continue
-		}
-		if aws.StringValue(r.ResourceType) == "AWS::ECS::Service" {
-			if r.PhysicalResourceId == nil {
-				continue
-			}
-			services = append(services, r.PhysicalResourceId)
-			serviceNames = append(serviceNames, *r.LogicalResourceId)
+			return aws.StringValue(r.PhysicalResourceId), nil
 		}
 	}
-	for i, service := range services {
-		err := s.CheckTaskState(ctx, aws.StringValue(cluster), aws.StringValue(service))
-		if err != nil {
-			return fmt.Errorf("%s error: %s", serviceNames[i], err.Error())
-		}
-	}
-	return nil
+	return "", nil
 }
 
-func (s sdk) CheckTaskState(ctx context.Context, cluster string, serviceName string) error {
-	tasks, err := s.ECS.ListTasksWithContext(ctx, &ecs.ListTasksInput{
-		Cluster:     aws.String(cluster),
-		ServiceName: aws.String(serviceName),
+func (s sdk) GetServiceTaskDefinition(ctx context.Context, cluster string, serviceArns []string) (map[string]string, error) {
+	defs := map[string]string{}
+	svc := []*string{}
+	for _, s := range serviceArns {
+		svc = append(svc, aws.String(s))
+	}
+	services, err := s.ECS.DescribeServicesWithContext(ctx, &ecs.DescribeServicesInput{
+		Cluster:  aws.String(cluster),
+		Services: svc,
 	})
 	if err != nil {
-		return err
+		return nil, err
 	}
-	if len(tasks.TaskArns) > 0 {
-		return nil
+	for _, s := range services.Services {
+		defs[aws.StringValue(s.ServiceArn)] = aws.StringValue(s.TaskDefinition)
 	}
-	tasks, err = s.ECS.ListTasksWithContext(ctx, &ecs.ListTasksInput{
+	return defs, nil
+}
+
+func (s sdk) ListStackServices(ctx context.Context, stack string) ([]string, error) {
+	arns := []string{}
+	var nextToken *string
+	for {
+		response, err := s.CF.ListStackResourcesWithContext(ctx, &cloudformation.ListStackResourcesInput{
+			StackName: aws.String(stack),
+			NextToken: nextToken,
+		})
+		if err != nil {
+			return nil, err
+		}
+		for _, r := range response.StackResourceSummaries {
+			if aws.StringValue(r.ResourceType) == "AWS::ECS::Service" {
+				if r.PhysicalResourceId != nil {
+					arns = append(arns, aws.StringValue(r.PhysicalResourceId))
+				}
+			}
+		}
+		nextToken = response.NextToken
+		if nextToken == nil {
+			break
+		}
+	}
+	return arns, nil
+}
+
+func (s sdk) GetServiceTasks(ctx context.Context, cluster string, service string, stopped bool) ([]*ecs.Task, error) {
+	state := "RUNNING"
+	if stopped {
+		state = "STOPPED"
+	}
+	tasks, err := s.ECS.ListTasksWithContext(ctx, &ecs.ListTasksInput{
 		Cluster:       aws.String(cluster),
-		ServiceName:   aws.String(serviceName),
-		DesiredStatus: aws.String("STOPPED"),
+		ServiceName:   aws.String(service),
+		DesiredStatus: aws.String(state),
 	})
 	if err != nil {
-		return err
+		return nil, err
 	}
 	if len(tasks.TaskArns) > 0 {
 		taskDescriptions, err := s.ECS.DescribeTasksWithContext(ctx, &ecs.DescribeTasksInput{
@@ -398,17 +407,30 @@ func (s sdk) CheckTaskState(ctx context.Context, cluster string, serviceName str
 			Tasks:   tasks.TaskArns,
 		})
 		if err != nil {
-			return err
-		}
-		if len(taskDescriptions.Tasks) > 0 {
-			recentTask := taskDescriptions.Tasks[0]
-			switch aws.StringValue(recentTask.StopCode) {
-			case "TaskFailedToStart":
-				return fmt.Errorf(aws.StringValue(recentTask.StoppedReason))
-			}
+			return nil, err
 		}
+		return taskDescriptions.Tasks, nil
 	}
-	return nil
+	return nil, nil
+}
+
+func (s sdk) GetTaskStoppedReason(ctx context.Context, cluster string, taskArn string) (string, error) {
+	taskDescriptions, err := s.ECS.DescribeTasksWithContext(ctx, &ecs.DescribeTasksInput{
+		Cluster: aws.String(cluster),
+		Tasks:   []*string{aws.String(taskArn)},
+	})
+	if err != nil {
+		return "", err
+	}
+	if len(taskDescriptions.Tasks) == 0 {
+		return "", nil
+	}
+	task := taskDescriptions.Tasks[0]
+	return fmt.Sprintf(
+		"%s: %s",
+		aws.StringValue(task.StopCode),
+		aws.StringValue(task.StoppedReason)), nil
+
 }
 
 func (s sdk) DescribeStackEvents(ctx context.Context, stackID string) ([]*cloudformation.StackEvent, error) {
@@ -423,6 +445,7 @@ func (s sdk) DescribeStackEvents(ctx context.Context, stackID string) ([]*cloudf
 		if err != nil {
 			return nil, err
 		}
+
 		events = append(events, resp.StackEvents...)
 		if resp.NextToken == nil {
 			return events, nil
@@ -609,46 +632,43 @@ func (s sdk) GetLogs(ctx context.Context, name string, consumer func(service, co
 	}
 }
 
-func (s sdk) DescribeServices(ctx context.Context, cluster string, arns []string) ([]compose.ServiceStatus, error) {
+func (s sdk) DescribeService(ctx context.Context, cluster string, arn string) (compose.ServiceStatus, error) {
 	services, err := s.ECS.DescribeServicesWithContext(ctx, &ecs.DescribeServicesInput{
 		Cluster:  aws.String(cluster),
-		Services: aws.StringSlice(arns),
+		Services: []*string{aws.String(arn)},
 		Include:  aws.StringSlice([]string{"TAGS"}),
 	})
 	if err != nil {
-		return nil, err
+		return compose.ServiceStatus{}, err
 	}
 
-	status := []compose.ServiceStatus{}
-	for _, service := range services.Services {
-		var name string
-		for _, t := range service.Tags {
-			if *t.Key == compose.ServiceTag {
-				name = aws.StringValue(t.Value)
-			}
-		}
-		if name == "" {
-			return nil, fmt.Errorf("service %s doesn't have a %s tag", *service.ServiceArn, compose.ServiceTag)
+	service := services.Services[0]
+	var name string
+	for _, t := range service.Tags {
+		if *t.Key == compose.ServiceTag {
+			name = aws.StringValue(t.Value)
 		}
-		targetGroupArns := []string{}
-		for _, lb := range service.LoadBalancers {
-			targetGroupArns = append(targetGroupArns, *lb.TargetGroupArn)
-		}
-		// getURLwithPortMapping makes 2 queries
-		// one to get the target groups and another for load balancers
-		loadBalancers, err := s.getURLWithPortMapping(ctx, targetGroupArns)
-		if err != nil {
-			return nil, err
-		}
-		status = append(status, compose.ServiceStatus{
-			ID:         aws.StringValue(service.ServiceName),
-			Name:       name,
-			Replicas:   int(aws.Int64Value(service.RunningCount)),
-			Desired:    int(aws.Int64Value(service.DesiredCount)),
-			Publishers: loadBalancers,
-		})
 	}
-	return status, nil
+	if name == "" {
+		return compose.ServiceStatus{}, fmt.Errorf("service %s doesn't have a %s tag", *service.ServiceArn, compose.ServiceTag)
+	}
+	targetGroupArns := []string{}
+	for _, lb := range service.LoadBalancers {
+		targetGroupArns = append(targetGroupArns, *lb.TargetGroupArn)
+	}
+	// getURLwithPortMapping makes 2 queries
+	// one to get the target groups and another for load balancers
+	loadBalancers, err := s.getURLWithPortMapping(ctx, targetGroupArns)
+	if err != nil {
+		return compose.ServiceStatus{}, err
+	}
+	return compose.ServiceStatus{
+		ID:         aws.StringValue(service.ServiceName),
+		Name:       name,
+		Replicas:   int(aws.Int64Value(service.RunningCount)),
+		Desired:    int(aws.Int64Value(service.DesiredCount)),
+		Publishers: loadBalancers,
+	}, nil
 }
 
 func (s sdk) getURLWithPortMapping(ctx context.Context, targetGroupArns []string) ([]compose.PortPublisher, error) {

+ 12 - 11
ecs/wait.go

@@ -52,7 +52,6 @@ func (b *ecsAPIService) WaitStackCompletion(ctx context.Context, name string, op
 
 	var completed bool
 	var stackErr error
-
 	for !completed {
 		select {
 		case <-done:
@@ -75,11 +74,12 @@ func (b *ecsAPIService) WaitStackCompletion(ctx context.Context, name string, op
 			knownEvents[*event.EventId] = struct{}{}
 
 			resource := aws.StringValue(event.LogicalResourceId)
-			reason := aws.StringValue(event.ResourceStatusReason)
+			reason := shortenMessage(
+				aws.StringValue(event.ResourceStatusReason))
 			status := aws.StringValue(event.ResourceStatus)
 			progressStatus := progress.Working
-			switch status {
 
+			switch status {
 			case "CREATE_COMPLETE":
 				if operation == stackCreate {
 					progressStatus = progress.Done
@@ -101,7 +101,6 @@ func (b *ecsAPIService) WaitStackCompletion(ctx context.Context, name string, op
 					}
 				}
 			}
-
 			w.Event(progress.Event{
 				ID:         resource,
 				Status:     progressStatus,
@@ -111,25 +110,27 @@ func (b *ecsAPIService) WaitStackCompletion(ctx context.Context, name string, op
 		if operation != stackCreate || stackErr != nil {
 			continue
 		}
-		if err := b.SDK.CheckStackState(ctx, name); err != nil {
+		if err := b.checkStackState(ctx, name); err != nil {
 			if e := b.SDK.DeleteStack(ctx, name); e != nil {
 				return e
 			}
 			stackErr = err
 			operation = stackDelete
-			reason := err.Error()
-			if len(reason) > 30 {
-				reason = reason[:30] + "..."
-			}
+			reason := shortenMessage(err.Error())
 			w.Event(progress.Event{
 				ID:         name,
 				Status:     progress.Error,
 				StatusText: reason,
 			})
-
 		}
-
 	}
 
 	return stackErr
 }
+
+func shortenMessage(message string) string {
+	if len(message) < 30 {
+		return message
+	}
+	return message[:30] + "..."
+}