Ver código fonte

handle API pagination

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof 4 anos atrás
pai
commit
59d17a45a0
1 arquivos alterados com 161 adições e 109 exclusões
  1. 161 109
      ecs/sdk.go

+ 161 - 109
ecs/sdk.go

@@ -181,7 +181,7 @@ func (s sdk) GetDefaultVPC(ctx context.Context) (string, error) {
 
 func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error) {
 	logrus.Debug("Retrieve SubNets")
-	ids := []awsResource{}
+	var ids []awsResource
 	var token *string
 	for {
 		subnets, err := s.EC2.DescribeSubnetsWithContext(ctx, &ec2.DescribeSubnetsInput{
@@ -453,48 +453,61 @@ func (s sdk) ListStacks(ctx context.Context, name string) ([]compose.Stack, erro
 	if name != "" {
 		params.StackName = &name
 	}
-	cfStacks, err := s.CF.DescribeStacksWithContext(ctx, &params)
-	if err != nil {
-		return nil, err
-	}
-	stacks := []compose.Stack{}
-	for _, stack := range cfStacks.Stacks {
-		for _, t := range stack.Tags {
-			if *t.Key == compose.ProjectTag {
-				status := compose.RUNNING
-				switch aws.StringValue(stack.StackStatus) {
-				case "CREATE_IN_PROGRESS":
-					status = compose.STARTING
-				case "DELETE_IN_PROGRESS":
-					status = compose.REMOVING
-				case "UPDATE_IN_PROGRESS":
-					status = compose.UPDATING
-				default:
+	var token *string
+	var stacks []compose.Stack
+	for {
+		response, err := s.CF.DescribeStacksWithContext(ctx, &params)
+		if err != nil {
+			return nil, err
+		}
+		for _, stack := range response.Stacks {
+			for _, t := range stack.Tags {
+				if *t.Key == compose.ProjectTag {
+					status := compose.RUNNING
+					switch aws.StringValue(stack.StackStatus) {
+					case "CREATE_IN_PROGRESS":
+						status = compose.STARTING
+					case "DELETE_IN_PROGRESS":
+						status = compose.REMOVING
+					case "UPDATE_IN_PROGRESS":
+						status = compose.UPDATING
+					default:
+					}
+					stacks = append(stacks, compose.Stack{
+						ID:     aws.StringValue(stack.StackId),
+						Name:   aws.StringValue(stack.StackName),
+						Status: status,
+					})
+					break
 				}
-				stacks = append(stacks, compose.Stack{
-					ID:     aws.StringValue(stack.StackId),
-					Name:   aws.StringValue(stack.StackName),
-					Status: status,
-				})
-				break
 			}
 		}
+		if token == response.NextToken {
+			return stacks, nil
+		}
+		token = response.NextToken
 	}
-	return stacks, nil
 }
 
 func (s sdk) GetStackClusterID(ctx context.Context, stack string) (string, error) {
 	// Note: could use DescribeStackResource but we only can detect `does not exist` case by matching string error message
-	resources, err := s.CF.ListStackResourcesWithContext(ctx, &cloudformation.ListStackResourcesInput{
-		StackName: aws.String(stack),
-	})
-	if err != nil {
-		return "", err
-	}
-	for _, r := range resources.StackResourceSummaries {
-		if aws.StringValue(r.ResourceType) == "AWS::ECS::Cluster" {
-			return aws.StringValue(r.PhysicalResourceId), nil
+	var token *string
+	for {
+		response, err := s.CF.ListStackResourcesWithContext(ctx, &cloudformation.ListStackResourcesInput{
+			StackName: aws.String(stack),
+		})
+		if err != nil {
+			return "", err
 		}
+		for _, r := range response.StackResourceSummaries {
+			if aws.StringValue(r.ResourceType) == "AWS::ECS::Cluster" {
+				return aws.StringValue(r.PhysicalResourceId), nil
+			}
+		}
+		if token == response.NextToken {
+			break
+		}
+		token = response.NextToken
 	}
 	// stack is using user-provided cluster
 	res, err := s.CF.GetTemplateSummaryWithContext(ctx, &cloudformation.GetTemplateSummaryInput{
@@ -522,19 +535,27 @@ type templateMetadata struct {
 
 func (s sdk) GetServiceTaskDefinition(ctx context.Context, cluster string, serviceArns []string) (map[string]string, error) {
 	defs := map[string]string{}
+
 	svc := []*string{}
 	for _, s := range serviceArns {
 		svc = append(svc, aws.String(s))
 	}
-	services, err := s.ECS.DescribeServicesWithContext(ctx, &ecs.DescribeServicesInput{
-		Cluster:  aws.String(cluster),
-		Services: svc,
-	})
-	if err != nil {
-		return nil, err
-	}
-	for _, s := range services.Services {
-		defs[aws.StringValue(s.ServiceArn)] = aws.StringValue(s.TaskDefinition)
+	for i := 0; i < len(svc); i += 10 {
+		end := i + 10
+		if end > len(svc) {
+			end = len(svc)
+		}
+		chunk := svc[i:end]
+		services, err := s.ECS.DescribeServicesWithContext(ctx, &ecs.DescribeServicesInput{
+			Cluster:  aws.String(cluster),
+			Services: chunk,
+		})
+		if err != nil {
+			return nil, err
+		}
+		for _, s := range services.Services {
+			defs[aws.StringValue(s.ServiceArn)] = aws.StringValue(s.TaskDefinition)
+		}
 	}
 	return defs, nil
 }
@@ -570,25 +591,32 @@ func (s sdk) GetServiceTasks(ctx context.Context, cluster string, service string
 	if stopped {
 		state = "STOPPED"
 	}
-	tasks, err := s.ECS.ListTasksWithContext(ctx, &ecs.ListTasksInput{
-		Cluster:       aws.String(cluster),
-		ServiceName:   aws.String(service),
-		DesiredStatus: aws.String(state),
-	})
-	if err != nil {
-		return nil, err
-	}
-	if len(tasks.TaskArns) > 0 {
-		taskDescriptions, err := s.ECS.DescribeTasksWithContext(ctx, &ecs.DescribeTasksInput{
-			Cluster: aws.String(cluster),
-			Tasks:   tasks.TaskArns,
+	var token *string
+	var tasks []*ecs.Task
+	for {
+		response, err := s.ECS.ListTasksWithContext(ctx, &ecs.ListTasksInput{
+			Cluster:       aws.String(cluster),
+			ServiceName:   aws.String(service),
+			DesiredStatus: aws.String(state),
 		})
 		if err != nil {
 			return nil, err
 		}
-		return taskDescriptions.Tasks, nil
+		if len(response.TaskArns) > 0 {
+			taskDescriptions, err := s.ECS.DescribeTasksWithContext(ctx, &ecs.DescribeTasksInput{
+				Cluster: aws.String(cluster),
+				Tasks:   response.TaskArns,
+			})
+			if err != nil {
+				return nil, err
+			}
+			tasks = append(tasks, taskDescriptions.Tasks...)
+		}
+		if token == response.NextToken {
+			return tasks, nil
+		}
+		token = response.NextToken
 	}
-	return nil, nil
 }
 
 func (s sdk) GetTaskStoppedReason(ctx context.Context, cluster string, taskArn string) (string, error) {
@@ -671,24 +699,29 @@ func (resources stackResources) apply(awsType string, fn stackResourceFn) error
 }
 
 func (s sdk) ListStackResources(ctx context.Context, name string) (stackResources, error) {
-	// FIXME handle pagination
-	res, err := s.CF.ListStackResourcesWithContext(ctx, &cloudformation.ListStackResourcesInput{
-		StackName: aws.String(name),
-	})
-	if err != nil {
-		return nil, err
-	}
-
-	resources := stackResources{}
-	for _, r := range res.StackResourceSummaries {
-		resources = append(resources, stackResource{
-			LogicalID: aws.StringValue(r.LogicalResourceId),
-			Type:      aws.StringValue(r.ResourceType),
-			ARN:       aws.StringValue(r.PhysicalResourceId),
-			Status:    aws.StringValue(r.ResourceStatus),
+	var token *string
+	var resources stackResources
+	for {
+		response, err := s.CF.ListStackResourcesWithContext(ctx, &cloudformation.ListStackResourcesInput{
+			StackName: aws.String(name),
 		})
+		if err != nil {
+			return nil, err
+		}
+
+		for _, r := range response.StackResourceSummaries {
+			resources = append(resources, stackResource{
+				LogicalID: aws.StringValue(r.LogicalResourceId),
+				Type:      aws.StringValue(r.ResourceType),
+				ARN:       aws.StringValue(r.PhysicalResourceId),
+				Status:    aws.StringValue(r.ResourceStatus),
+			})
+		}
+		if token == response.NextToken {
+			return resources, nil
+		}
+		token = response.NextToken
 	}
-	return resources, nil
 }
 
 func (s sdk) DeleteStack(ctx context.Context, name string) error {
@@ -744,25 +777,32 @@ func (s sdk) InspectSecret(ctx context.Context, id string) (secrets.Secret, erro
 
 func (s sdk) ListSecrets(ctx context.Context) ([]secrets.Secret, error) {
 	logrus.Debug("List secrets ...")
-	response, err := s.SM.ListSecrets(&secretsmanager.ListSecretsInput{})
-	if err != nil {
-		return nil, err
-	}
-
 	var ls []secrets.Secret
-	for _, sec := range response.SecretList {
+	var token *string
+	for {
+		response, err := s.SM.ListSecrets(&secretsmanager.ListSecretsInput{})
+		if err != nil {
+			return nil, err
+		}
+
+		for _, sec := range response.SecretList {
 
-		tags := map[string]string{}
-		for _, tag := range sec.Tags {
-			tags[*tag.Key] = *tag.Value
+			tags := map[string]string{}
+			for _, tag := range sec.Tags {
+				tags[*tag.Key] = *tag.Value
+			}
+			ls = append(ls, secrets.Secret{
+				ID:     *sec.ARN,
+				Name:   *sec.Name,
+				Labels: tags,
+			})
 		}
-		ls = append(ls, secrets.Secret{
-			ID:     *sec.ARN,
-			Name:   *sec.Name,
-			Labels: tags,
-		})
+
+		if token == response.NextToken {
+			return ls, nil
+		}
+		token = response.NextToken
 	}
-	return ls, nil
 }
 
 func (s sdk) DeleteSecret(ctx context.Context, id string, recover bool) error {
@@ -967,34 +1007,46 @@ func (s sdk) getURLWithPortMapping(ctx context.Context, targetGroupArns []string
 }
 
 func (s sdk) ListTasks(ctx context.Context, cluster string, family string) ([]string, error) {
-	tasks, err := s.ECS.ListTasksWithContext(ctx, &ecs.ListTasksInput{
-		Cluster: aws.String(cluster),
-		Family:  aws.String(family),
-	})
-	if err != nil {
-		return nil, err
-	}
-	arns := []string{}
-	for _, arn := range tasks.TaskArns {
-		arns = append(arns, *arn)
+	var token *string
+	var arns []string
+	for {
+		response, err := s.ECS.ListTasksWithContext(ctx, &ecs.ListTasksInput{
+			Cluster: aws.String(cluster),
+			Family:  aws.String(family),
+		})
+		if err != nil {
+			return nil, err
+		}
+		for _, arn := range response.TaskArns {
+			arns = append(arns, *arn)
+		}
+		if token == response.NextToken {
+			return arns, nil
+		}
+		token = response.NextToken
 	}
-	return arns, nil
 }
 
 func (s sdk) GetPublicIPs(ctx context.Context, interfaces ...string) (map[string]string, error) {
-	desc, err := s.EC2.DescribeNetworkInterfaces(&ec2.DescribeNetworkInterfacesInput{
-		NetworkInterfaceIds: aws.StringSlice(interfaces),
-	})
-	if err != nil {
-		return nil, err
-	}
+	var token *string
 	publicIPs := map[string]string{}
-	for _, interf := range desc.NetworkInterfaces {
-		if interf.Association != nil {
-			publicIPs[aws.StringValue(interf.NetworkInterfaceId)] = aws.StringValue(interf.Association.PublicIp)
+	for {
+		response, err := s.EC2.DescribeNetworkInterfaces(&ec2.DescribeNetworkInterfacesInput{
+			NetworkInterfaceIds: aws.StringSlice(interfaces),
+		})
+		if err != nil {
+			return nil, err
+		}
+		for _, interf := range response.NetworkInterfaces {
+			if interf.Association != nil {
+				publicIPs[aws.StringValue(interf.NetworkInterfaceId)] = aws.StringValue(interf.Association.PublicIp)
+			}
+		}
+		if token == response.NextToken {
+			return publicIPs, nil
 		}
+		token = response.NextToken
 	}
-	return publicIPs, nil
 }
 
 func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrarn string) (awsResource, string, error) {