Browse Source

Merge pull request #706 from aiordache/ecs_task_failure_detection

Nicolas De loof 5 years ago
parent
commit
60bd0e5303
6 changed files with 258 additions and 54 deletions
  1. 1 0
      api/compose/api.go
  2. 2 1
      cli/cmd/compose/list.go
  3. 83 1
      ecs/list.go
  4. 11 19
      ecs/ps.go
  5. 135 31
      ecs/sdk.go
  6. 26 2
      ecs/wait.go

+ 1 - 0
api/compose/api.go

@@ -77,4 +77,5 @@ type Stack struct {
 	ID     string
 	Name   string
 	Status string
+	Reason string
 }

+ 2 - 1
cli/cmd/compose/list.go

@@ -21,6 +21,7 @@ import (
 	"fmt"
 	"io"
 	"os"
+	"strings"
 
 	"github.com/spf13/cobra"
 	"github.com/spf13/pflag"
@@ -75,7 +76,7 @@ func viewFromStackList(stackList []compose.Stack) []stackView {
 	for i, s := range stackList {
 		retList[i] = stackView{
 			Name:   s.Name,
-			Status: s.Status,
+			Status: strings.TrimSpace(fmt.Sprintf("%s %s", s.Status, s.Reason)),
 		}
 	}
 	return retList

+ 83 - 1
ecs/list.go

@@ -18,11 +18,93 @@ package ecs
 
 import (
 	"context"
+	"fmt"
 
 	"github.com/docker/compose-cli/api/compose"
 )
 
 func (b *ecsAPIService) List(ctx context.Context, project string) ([]compose.Stack, error) {
-	return b.SDK.ListStacks(ctx, project)
+	stacks, err := b.SDK.ListStacks(ctx, project)
+	if err != nil {
+		return nil, err
+	}
 
+	for _, stack := range stacks {
+		if stack.Status == compose.STARTING {
+			if err := b.checkStackState(ctx, stack.Name); err != nil {
+				stack.Status = compose.FAILED
+				stack.Reason = err.Error()
+			}
+		}
+	}
+	return stacks, nil
+
+}
+
+func (b *ecsAPIService) checkStackState(ctx context.Context, name string) error {
+	resources, err := b.SDK.ListStackResources(ctx, name)
+	if err != nil {
+		return err
+	}
+	svcArns := []string{}
+	svcNames := map[string]string{}
+	var cluster string
+	for _, r := range resources {
+		if r.Type == "AWS::ECS::Cluster" {
+			cluster = r.ARN
+			continue
+		}
+		if r.Type == "AWS::ECS::Service" {
+			if r.ARN == "" {
+				continue
+			}
+			svcArns = append(svcArns, r.ARN)
+			svcNames[r.ARN] = r.LogicalID
+		}
+	}
+	if len(svcArns) == 0 {
+		return nil
+	}
+	services, err := b.SDK.GetServiceTaskDefinition(ctx, cluster, svcArns)
+	if err != nil {
+		return err
+	}
+	for service, taskDef := range services {
+		if err := b.checkServiceState(ctx, cluster, service, taskDef); err != nil {
+			return fmt.Errorf("%s %s", svcNames[service], err.Error())
+		}
+	}
+	return nil
+}
+
+func (b *ecsAPIService) checkServiceState(ctx context.Context, cluster string, service string, taskdef string) error {
+	runningTasks, err := b.SDK.GetServiceTasks(ctx, cluster, service, false)
+	if err != nil {
+		return err
+	}
+	if len(runningTasks) > 0 {
+		return nil
+	}
+	stoppedTasks, err := b.SDK.GetServiceTasks(ctx, cluster, service, true)
+	if err != nil {
+		return err
+	}
+	if len(stoppedTasks) == 0 {
+		return nil
+	}
+	// filter tasks by task definition
+	tasks := []string{}
+	for _, t := range stoppedTasks {
+		if t.TaskDefinitionArn != nil && *t.TaskDefinitionArn == taskdef {
+			tasks = append(tasks, *t.TaskArn)
+		}
+	}
+	if len(tasks) == 0 {
+		return nil
+	}
+	reason, err := b.SDK.GetTaskStoppedReason(ctx, cluster, tasks[0])
+	if err != nil {
+		return err
+	}
+	return fmt.Errorf("%s", reason)
 }

+ 11 - 19
ecs/ps.go

@@ -25,33 +25,25 @@ import (
 )
 
 func (b *ecsAPIService) Ps(ctx context.Context, project string) ([]compose.ServiceStatus, error) {
-	resources, err := b.SDK.ListStackResources(ctx, project)
+	cluster, err := b.SDK.GetStackClusterID(ctx, project)
 	if err != nil {
 		return nil, err
 	}
-
-	var (
-		cluster     = project
-		servicesARN []string
-	)
-	for _, r := range resources {
-		switch r.Type {
-		case "AWS::ECS::Service":
-			servicesARN = append(servicesARN, r.ARN)
-		case "AWS::ECS::Cluster":
-			cluster = r.ARN
-		}
+	servicesARN, err := b.SDK.ListStackServices(ctx, project)
+	if err != nil {
+		return nil, err
 	}
 
 	if len(servicesARN) == 0 {
 		return nil, nil
 	}
-	status, err := b.SDK.DescribeServices(ctx, cluster, servicesARN)
-	if err != nil {
-		return nil, err
-	}
 
-	for i, state := range status {
+	status := []compose.ServiceStatus{}
+	for _, arn := range servicesARN {
+		state, err := b.SDK.DescribeService(ctx, cluster, arn)
+		if err != nil {
+			return nil, err
+		}
 		ports := []string{}
 		for _, lb := range state.Publishers {
 			ports = append(ports, fmt.Sprintf(
@@ -62,7 +54,7 @@ func (b *ecsAPIService) Ps(ctx context.Context, project string) ([]compose.Servi
 				strings.ToLower(lb.Protocol)))
 		}
 		state.Ports = ports
-		status[i] = state
+		status = append(status, state)
 	}
 	return status, nil
 }

+ 135 - 31
ecs/sdk.go

@@ -314,6 +314,7 @@ func (s sdk) ListStacks(ctx context.Context, name string) ([]compose.Stack, erro
 					status = compose.REMOVING
 				case "UPDATE_IN_PROGRESS":
 					status = compose.UPDATING
+				default:
 				}
 				stacks = append(stacks, compose.Stack{
 					ID:     aws.StringValue(stack.StackId),
@@ -327,6 +328,111 @@ func (s sdk) ListStacks(ctx context.Context, name string) ([]compose.Stack, erro
 	return stacks, nil
 }
 
+func (s sdk) GetStackClusterID(ctx context.Context, stack string) (string, error) {
+	resources, err := s.CF.ListStackResourcesWithContext(ctx, &cloudformation.ListStackResourcesInput{
+		StackName: aws.String(stack),
+	})
+	if err != nil {
+		return "", err
+	}
+	for _, r := range resources.StackResourceSummaries {
+		if aws.StringValue(r.ResourceType) == "AWS::ECS::Cluster" {
+			return aws.StringValue(r.PhysicalResourceId), nil
+		}
+	}
+	return "", nil
+}
+
+func (s sdk) GetServiceTaskDefinition(ctx context.Context, cluster string, serviceArns []string) (map[string]string, error) {
+	defs := map[string]string{}
+	svc := []*string{}
+	for _, s := range serviceArns {
+		svc = append(svc, aws.String(s))
+	}
+	services, err := s.ECS.DescribeServicesWithContext(ctx, &ecs.DescribeServicesInput{
+		Cluster:  aws.String(cluster),
+		Services: svc,
+	})
+	if err != nil {
+		return nil, err
+	}
+	for _, s := range services.Services {
+		defs[aws.StringValue(s.ServiceArn)] = aws.StringValue(s.TaskDefinition)
+	}
+	return defs, nil
+}
+
+func (s sdk) ListStackServices(ctx context.Context, stack string) ([]string, error) {
+	arns := []string{}
+	var nextToken *string
+	for {
+		response, err := s.CF.ListStackResourcesWithContext(ctx, &cloudformation.ListStackResourcesInput{
+			StackName: aws.String(stack),
+			NextToken: nextToken,
+		})
+		if err != nil {
+			return nil, err
+		}
+		for _, r := range response.StackResourceSummaries {
+			if aws.StringValue(r.ResourceType) == "AWS::ECS::Service" {
+				if r.PhysicalResourceId != nil {
+					arns = append(arns, aws.StringValue(r.PhysicalResourceId))
+				}
+			}
+		}
+		nextToken = response.NextToken
+		if nextToken == nil {
+			break
+		}
+	}
+	return arns, nil
+}
+
+func (s sdk) GetServiceTasks(ctx context.Context, cluster string, service string, stopped bool) ([]*ecs.Task, error) {
+	state := "RUNNING"
+	if stopped {
+		state = "STOPPED"
+	}
+	tasks, err := s.ECS.ListTasksWithContext(ctx, &ecs.ListTasksInput{
+		Cluster:       aws.String(cluster),
+		ServiceName:   aws.String(service),
+		DesiredStatus: aws.String(state),
+	})
+	if err != nil {
+		return nil, err
+	}
+	if len(tasks.TaskArns) > 0 {
+		taskDescriptions, err := s.ECS.DescribeTasksWithContext(ctx, &ecs.DescribeTasksInput{
+			Cluster: aws.String(cluster),
+			Tasks:   tasks.TaskArns,
+		})
+		if err != nil {
+			return nil, err
+		}
+		return taskDescriptions.Tasks, nil
+	}
+	return nil, nil
+}
+
+func (s sdk) GetTaskStoppedReason(ctx context.Context, cluster string, taskArn string) (string, error) {
+	taskDescriptions, err := s.ECS.DescribeTasksWithContext(ctx, &ecs.DescribeTasksInput{
+		Cluster: aws.String(cluster),
+		Tasks:   []*string{aws.String(taskArn)},
+	})
+	if err != nil {
+		return "", err
+	}
+	if len(taskDescriptions.Tasks) == 0 {
+		return "", nil
+	}
+	task := taskDescriptions.Tasks[0]
+	return fmt.Sprintf(
+		"%s: %s",
+		aws.StringValue(task.StopCode),
+		aws.StringValue(task.StoppedReason)), nil
+
+}
+
 func (s sdk) DescribeStackEvents(ctx context.Context, stackID string) ([]*cloudformation.StackEvent, error) {
 	// Fixme implement Paginator on Events and return as a chan(events)
 	events := []*cloudformation.StackEvent{}
@@ -339,6 +445,7 @@ func (s sdk) DescribeStackEvents(ctx context.Context, stackID string) ([]*cloudf
 		if err != nil {
 			return nil, err
 		}
+
 		events = append(events, resp.StackEvents...)
 		if resp.NextToken == nil {
 			return events, nil
@@ -525,46 +632,43 @@ func (s sdk) GetLogs(ctx context.Context, name string, consumer func(service, co
 	}
 }
 
-func (s sdk) DescribeServices(ctx context.Context, cluster string, arns []string) ([]compose.ServiceStatus, error) {
+func (s sdk) DescribeService(ctx context.Context, cluster string, arn string) (compose.ServiceStatus, error) {
 	services, err := s.ECS.DescribeServicesWithContext(ctx, &ecs.DescribeServicesInput{
 		Cluster:  aws.String(cluster),
-		Services: aws.StringSlice(arns),
+		Services: []*string{aws.String(arn)},
 		Include:  aws.StringSlice([]string{"TAGS"}),
 	})
 	if err != nil {
-		return nil, err
+		return compose.ServiceStatus{}, err
 	}
 
-	status := []compose.ServiceStatus{}
-	for _, service := range services.Services {
-		var name string
-		for _, t := range service.Tags {
-			if *t.Key == compose.ServiceTag {
-				name = aws.StringValue(t.Value)
-			}
-		}
-		if name == "" {
-			return nil, fmt.Errorf("service %s doesn't have a %s tag", *service.ServiceArn, compose.ServiceTag)
+	service := services.Services[0]
+	var name string
+	for _, t := range service.Tags {
+		if *t.Key == compose.ServiceTag {
+			name = aws.StringValue(t.Value)
 		}
-		targetGroupArns := []string{}
-		for _, lb := range service.LoadBalancers {
-			targetGroupArns = append(targetGroupArns, *lb.TargetGroupArn)
-		}
-		// getURLwithPortMapping makes 2 queries
-		// one to get the target groups and another for load balancers
-		loadBalancers, err := s.getURLWithPortMapping(ctx, targetGroupArns)
-		if err != nil {
-			return nil, err
-		}
-		status = append(status, compose.ServiceStatus{
-			ID:         aws.StringValue(service.ServiceName),
-			Name:       name,
-			Replicas:   int(aws.Int64Value(service.RunningCount)),
-			Desired:    int(aws.Int64Value(service.DesiredCount)),
-			Publishers: loadBalancers,
-		})
 	}
-	return status, nil
+	if name == "" {
+		return compose.ServiceStatus{}, fmt.Errorf("service %s doesn't have a %s tag", *service.ServiceArn, compose.ServiceTag)
+	}
+	targetGroupArns := []string{}
+	for _, lb := range service.LoadBalancers {
+		targetGroupArns = append(targetGroupArns, *lb.TargetGroupArn)
+	}
+	// getURLwithPortMapping makes 2 queries
+	// one to get the target groups and another for load balancers
+	loadBalancers, err := s.getURLWithPortMapping(ctx, targetGroupArns)
+	if err != nil {
+		return compose.ServiceStatus{}, err
+	}
+	return compose.ServiceStatus{
+		ID:         aws.StringValue(service.ServiceName),
+		Name:       name,
+		Replicas:   int(aws.Int64Value(service.RunningCount)),
+		Desired:    int(aws.Int64Value(service.DesiredCount)),
+		Publishers: loadBalancers,
+	}, nil
 }
 
 func (s sdk) getURLWithPortMapping(ctx context.Context, targetGroupArns []string) ([]compose.PortPublisher, error) {

+ 26 - 2
ecs/wait.go

@@ -74,7 +74,8 @@ func (b *ecsAPIService) WaitStackCompletion(ctx context.Context, name string, op
 			knownEvents[*event.EventId] = struct{}{}
 
 			resource := aws.StringValue(event.LogicalResourceId)
-			reason := aws.StringValue(event.ResourceStatusReason)
+			reason := shortenMessage(
+				aws.StringValue(event.ResourceStatusReason))
 			status := aws.StringValue(event.ResourceStatus)
 			progressStatus := progress.Working
 
@@ -103,10 +104,33 @@ func (b *ecsAPIService) WaitStackCompletion(ctx context.Context, name string, op
 			w.Event(progress.Event{
 				ID:         resource,
 				Status:     progressStatus,
-				StatusText: status,
+				StatusText: reason,
+			})
+		}
+		if operation != stackCreate || stackErr != nil {
+			continue
+		}
+		if err := b.checkStackState(ctx, name); err != nil {
+			if e := b.SDK.DeleteStack(ctx, name); e != nil {
+				return e
+			}
+			stackErr = err
+			operation = stackDelete
+			reason := shortenMessage(err.Error())
+			w.Event(progress.Event{
+				ID:         name,
+				Status:     progress.Error,
+				StatusText: reason,
 			})
 		}
 	}
 
 	return stackErr
 }
+
+func shortenMessage(message string) string {
+	if len(message) < 30 {
+		return message
+	}
+	return message[:30] + "..."
+}