Prechádzať zdrojové kódy

`parse` to return awsResources then convert into CF template

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof 5 rokov pred
rodič
commit
d5e0ec7aa6
6 zmenil súbory, kde vykonal 116 pridanie a 102 odobranie
  1. 68 62
      ecs/awsResources.go
  2. 6 8
      ecs/backend.go
  3. 22 20
      ecs/cloudformation.go
  4. 7 9
      ecs/cloudformation_test.go
  5. 3 3
      ecs/ec2.go
  6. 10 0
      ecs/sdk.go

+ 68 - 62
ecs/awsResources.go

@@ -32,7 +32,6 @@ import (
 
 // awsResources hold the AWS component being used or created to support services definition
 type awsResources struct {
-	sdk              sdk
 	vpc              string
 	subnets          []string
 	cluster          string
@@ -58,101 +57,120 @@ func (r *awsResources) allSecurityGroups() []string {
 }
 
 // parse look into compose project for configured resource to use, and check they are valid
-func (r *awsResources) parse(ctx context.Context, project *types.Project) error {
-	return findProjectFnError(ctx, project,
-		r.parseClusterExtension,
-		r.parseVPCExtension,
-		r.parseLoadBalancerExtension,
-		r.parseSecurityGroupExtension,
-	)
+func (b *ecsAPIService) parse(ctx context.Context, project *types.Project) (awsResources, error) {
+	r := awsResources{}
+	var err error
+	r.cluster, err = b.parseClusterExtension(ctx, project)
+	if err != nil {
+		return r, err
+	}
+	r.vpc, r.subnets, err = b.parseVPCExtension(ctx, project)
+	if err != nil {
+		return r, err
+	}
+	r.loadBalancer, r.loadBalancerType, err = b.parseLoadBalancerExtension(ctx, project)
+	if err != nil {
+		return r, err
+	}
+	r.securityGroups, err = b.parseSecurityGroupExtension(ctx, project)
+	if err != nil {
+		return r, err
+	}
+	return r, nil
 }
 
-func (r *awsResources) parseClusterExtension(ctx context.Context, project *types.Project) error {
+func (b *ecsAPIService) parseClusterExtension(ctx context.Context, project *types.Project) (string, error) {
 	if x, ok := project.Extensions[extensionCluster]; ok {
 		cluster := x.(string)
-		ok, err := r.sdk.ClusterExists(ctx, cluster)
+		ok, err := b.SDK.ClusterExists(ctx, cluster)
 		if err != nil {
-			return err
+			return "", err
 		}
 		if !ok {
-			return fmt.Errorf("cluster does not exist: %s", cluster)
+			return "", fmt.Errorf("cluster does not exist: %s", cluster)
 		}
-		r.cluster = cluster
+		return cluster, nil
 	}
-	return nil
+	return "", nil
 }
 
-func (r *awsResources) parseVPCExtension(ctx context.Context, project *types.Project) error {
+func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Project) (string, []string, error) {
+	var vpc string
 	if x, ok := project.Extensions[extensionVPC]; ok {
-		vpc := x.(string)
-		err := r.sdk.CheckVPC(ctx, vpc)
+		vpc = x.(string)
+		err := b.SDK.CheckVPC(ctx, vpc)
 		if err != nil {
-			return err
+			return "", nil, err
 		}
-		r.vpc = vpc
+
 	} else {
-		defaultVPC, err := r.sdk.GetDefaultVPC(ctx)
+		defaultVPC, err := b.SDK.GetDefaultVPC(ctx)
 		if err != nil {
-			return err
+			return "", nil, err
 		}
-		r.vpc = defaultVPC
+		vpc = defaultVPC
 	}
 
-	subNets, err := r.sdk.GetSubNets(ctx, r.vpc)
+	subNets, err := b.SDK.GetSubNets(ctx, vpc)
 	if err != nil {
-		return err
+		return "", nil, err
 	}
 	if len(subNets) < 2 {
-		return fmt.Errorf("VPC %s should have at least 2 associated subnets in different availability zones", r.vpc)
+		return "", nil, fmt.Errorf("VPC %s should have at least 2 associated subnets in different availability zones", vpc)
 	}
-	r.subnets = subNets
-	return nil
+	return vpc, subNets, nil
 }
 
-func (r *awsResources) parseLoadBalancerExtension(ctx context.Context, project *types.Project) error {
+func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (string, string, error) {
 	if x, ok := project.Extensions[extensionLoadBalancer]; ok {
 		loadBalancer := x.(string)
-		loadBalancerType, err := r.sdk.LoadBalancerType(ctx, loadBalancer)
+		loadBalancerType, err := b.SDK.LoadBalancerType(ctx, loadBalancer)
 		if err != nil {
-			return err
+			return "", "", err
 		}
 
 		required := getRequiredLoadBalancerType(project)
 		if loadBalancerType != required {
-			return fmt.Errorf("load balancer %s is of type %s, project require a %s", loadBalancer, loadBalancerType, required)
+			return "", "", fmt.Errorf("load balancer %s is of type %s, project require a %s", loadBalancer, loadBalancerType, required)
 		}
 
-		r.loadBalancer = loadBalancer
-		r.loadBalancerType = loadBalancerType
+		return loadBalancer, loadBalancerType, nil
 	}
-	return nil
+	return "", "", nil
 }
 
-func (r *awsResources) parseSecurityGroupExtension(ctx context.Context, project *types.Project) error {
-	if r.securityGroups == nil {
-		r.securityGroups = make(map[string]string, len(project.Networks))
-	}
+func (b *ecsAPIService) parseSecurityGroupExtension(ctx context.Context, project *types.Project) (map[string]string, error) {
+	securityGroups := make(map[string]string, len(project.Networks))
 	for name, net := range project.Networks {
+		var sg string
 		if net.External.External {
-			r.securityGroups[name] = net.Name
+			sg = net.Name
 		}
 		if x, ok := net.Extensions[extensionSecurityGroup]; ok {
 			logrus.Warn("to use an existing security-group, use `network.external` and `network.name` in your compose file")
 			logrus.Debugf("Security Group for network %q set by user to %q", net.Name, x)
-			r.securityGroups[name] = x.(string)
+			sg = x.(string)
+		}
+		exists, err := b.SDK.SecurityGroupExists(ctx, sg)
+		if err != nil {
+			return nil, err
 		}
+		if !exists {
+			return nil, fmt.Errorf("security group %s doesn't exist", sg)
+		}
+		securityGroups[name] = sg
 	}
-	return nil
+	return securityGroups, nil
 }
 
-// ensure all required resources pre-exists or are defined as cloudformation resources
-func (r *awsResources) ensure(project *types.Project, template *cloudformation.Template) {
-	r.ensureCluster(project, template)
-	r.ensureNetworks(project, template)
-	r.ensureLoadBalancer(project, template)
+// ensureResources create required resources in template if not yet defined
+func (b *ecsAPIService) ensureResources(resources *awsResources, project *types.Project, template *cloudformation.Template) {
+	b.ensureCluster(resources, project, template)
+	b.ensureNetworks(resources, project, template)
+	b.ensureLoadBalancer(resources, project, template)
 }
 
-func (r *awsResources) ensureCluster(project *types.Project, template *cloudformation.Template) {
+func (b *ecsAPIService) ensureCluster(r *awsResources, project *types.Project, template *cloudformation.Template) {
 	if r.cluster != "" {
 		return
 	}
@@ -163,7 +181,7 @@ func (r *awsResources) ensureCluster(project *types.Project, template *cloudform
 	r.cluster = cloudformation.Ref("Cluster")
 }
 
-func (r *awsResources) ensureNetworks(project *types.Project, template *cloudformation.Template) {
+func (b *ecsAPIService) ensureNetworks(r *awsResources, project *types.Project, template *cloudformation.Template) {
 	if r.securityGroups == nil {
 		r.securityGroups = make(map[string]string, len(project.Networks))
 	}
@@ -179,7 +197,7 @@ func (r *awsResources) ensureNetworks(project *types.Project, template *cloudfor
 		ingress := securityGroup + "Ingress"
 		template.Resources[ingress] = &ec2.SecurityGroupIngress{
 			Description:           fmt.Sprintf("Allow communication within network %s", name),
-			IpProtocol:            "-1", // all protocols
+			IpProtocol:            allProtocols,
 			GroupId:               cloudformation.Ref(securityGroup),
 			SourceSecurityGroupId: cloudformation.Ref(securityGroup),
 		}
@@ -188,7 +206,7 @@ func (r *awsResources) ensureNetworks(project *types.Project, template *cloudfor
 	}
 }
 
-func (r *awsResources) ensureLoadBalancer(project *types.Project, template *cloudformation.Template) {
+func (b *ecsAPIService) ensureLoadBalancer(r *awsResources, project *types.Project, template *cloudformation.Template) {
 	if r.loadBalancer != "" {
 		return
 	}
@@ -239,18 +257,6 @@ func portIsHTTP(it types.ServicePortConfig) bool {
 	return it.Target == 80 || it.Target == 443
 }
 
-type projectFn func(ctx context.Context, project *types.Project) error
-
-func findProjectFnError(ctx context.Context, project *types.Project, funcs ...projectFn) error {
-	for _, fn := range funcs {
-		err := fn(ctx, project)
-		if err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
 // predicate[types.ServiceConfig]
 type servicePredicate func(it types.ServiceConfig) bool
 

+ 6 - 8
ecs/backend.go

@@ -75,18 +75,16 @@ func getEcsAPIService(ecsCtx store.EcsContext) (*ecsAPIService, error) {
 
 	sdk := newSDK(sess)
 	return &ecsAPIService{
-		ctx:       ecsCtx,
-		Region:    ecsCtx.Region,
-		SDK:       sdk,
-		resources: awsResources{sdk: sdk},
+		ctx:    ecsCtx,
+		Region: ecsCtx.Region,
+		SDK:    sdk,
 	}, nil
 }
 
 type ecsAPIService struct {
-	ctx       store.EcsContext
-	Region    string
-	SDK       sdk
-	resources awsResources
+	ctx    store.EcsContext
+	Region string
+	SDK    sdk
 }
 
 func (a *ecsAPIService) ContainerService() containers.Service {

+ 22 - 20
ecs/cloudformation.go

@@ -43,12 +43,12 @@ func (b *ecsAPIService) Convert(ctx context.Context, project *types.Project) ([]
 		return nil, err
 	}
 
-	err = b.resources.parse(ctx, project)
+	resources, err := b.parse(ctx, project)
 	if err != nil {
 		return nil, err
 	}
 
-	template, err := b.convert(project)
+	template, err := b.convert(project, resources)
 	if err != nil {
 		return nil, err
 	}
@@ -64,7 +64,7 @@ func (b *ecsAPIService) Convert(ctx context.Context, project *types.Project) ([]
 		}
 	}
 
-	err = b.createCapacityProvider(ctx, project, template)
+	err = b.createCapacityProvider(ctx, project, template, resources)
 	if err != nil {
 		return nil, err
 	}
@@ -73,9 +73,9 @@ func (b *ecsAPIService) Convert(ctx context.Context, project *types.Project) ([]
 }
 
 // Convert a compose project into a CloudFormation template
-func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Template, error) { //nolint:gocyclo
+func (b *ecsAPIService) convert(project *types.Project, resources awsResources) (*cloudformation.Template, error) {
 	template := cloudformation.NewTemplate()
-	b.resources.ensure(project, template)
+	b.ensureResources(&resources, project, template)
 
 	for name, secret := range project.Secrets {
 		err := b.createSecret(project, name, secret, template)
@@ -87,7 +87,7 @@ func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Templat
 	b.createLogGroup(project, template)
 
 	// Private DNS namespace will allow DNS name for the services to be <service>.<project>.local
-	b.createCloudMap(project, template)
+	b.createCloudMap(project, template, resources.vpc)
 
 	for _, service := range project.Services {
 		taskExecutionRole := b.createTaskExecutionRole(project, service, template)
@@ -114,16 +114,16 @@ func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Templat
 		)
 		for _, port := range service.Ports {
 			for net := range service.Networks {
-				b.createIngress(service, net, port, template)
+				b.createIngress(service, net, port, template, resources)
 			}
 
 			protocol := strings.ToUpper(port.Protocol)
-			if b.resources.loadBalancerType == elbv2.LoadBalancerTypeEnumApplication {
+			if resources.loadBalancerType == elbv2.LoadBalancerTypeEnumApplication {
 				// we don't set Https as a certificate must be specified for HTTPS listeners
 				protocol = elbv2.ProtocolEnumHttp
 			}
-			targetGroupName := b.createTargetGroup(project, service, port, template, protocol)
-			listenerName := b.createListener(service, port, template, targetGroupName, b.resources.loadBalancer, protocol)
+			targetGroupName := b.createTargetGroup(project, service, port, template, protocol, resources.vpc)
+			listenerName := b.createListener(service, port, template, targetGroupName, resources.loadBalancer, protocol)
 			dependsOn = append(dependsOn, listenerName)
 			serviceLB = append(serviceLB, ecs.Service_LoadBalancer{
 				ContainerName:  service.Name,
@@ -157,7 +157,7 @@ func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Templat
 
 		template.Resources[serviceResourceName(service.Name)] = &ecs.Service{
 			AWSCloudFormationDependsOn: dependsOn,
-			Cluster:                    b.resources.cluster,
+			Cluster:                    resources.cluster,
 			DesiredCount:               desiredCount,
 			DeploymentController: &ecs.Service_DeploymentController{
 				Type: ecsapi.DeploymentControllerTypeEcs,
@@ -172,8 +172,8 @@ func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Templat
 			NetworkConfiguration: &ecs.Service_NetworkConfiguration{
 				AwsvpcConfiguration: &ecs.Service_AwsVpcConfiguration{
 					AssignPublicIp: assignPublicIP,
-					SecurityGroups: b.resources.serviceSecurityGroups(service),
-					Subnets:        b.resources.subnets,
+					SecurityGroups: resources.serviceSecurityGroups(service),
+					Subnets:        resources.subnets,
 				},
 			},
 			PlatformVersion:    platformVersion,
@@ -187,16 +187,18 @@ func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Templat
 	return template, nil
 }
 
-func (b *ecsAPIService) createIngress(service types.ServiceConfig, net string, port types.ServicePortConfig, template *cloudformation.Template) {
+const allProtocols = "-1"
+
+func (b *ecsAPIService) createIngress(service types.ServiceConfig, net string, port types.ServicePortConfig, template *cloudformation.Template, resources awsResources) {
 	protocol := strings.ToUpper(port.Protocol)
 	if protocol == "" {
-		protocol = "-1"
+		protocol = allProtocols
 	}
 	ingress := fmt.Sprintf("%s%dIngress", normalizeResourceName(net), port.Target)
 	template.Resources[ingress] = &ec2.SecurityGroupIngress{
 		CidrIp:      "0.0.0.0/0",
 		Description: fmt.Sprintf("%s:%d/%s on %s nextwork", service.Name, port.Target, port.Protocol, net),
-		GroupId:     b.resources.securityGroups[net],
+		GroupId:     resources.securityGroups[net],
 		FromPort:    int(port.Target),
 		IpProtocol:  protocol,
 		ToPort:      int(port.Target),
@@ -306,7 +308,7 @@ func (b *ecsAPIService) createListener(service types.ServiceConfig, port types.S
 	return listenerName
 }
 
-func (b *ecsAPIService) createTargetGroup(project *types.Project, service types.ServiceConfig, port types.ServicePortConfig, template *cloudformation.Template, protocol string) string {
+func (b *ecsAPIService) createTargetGroup(project *types.Project, service types.ServiceConfig, port types.ServicePortConfig, template *cloudformation.Template, protocol string, vpc string) string {
 	targetGroupName := fmt.Sprintf(
 		"%s%s%dTargetGroup",
 		normalizeResourceName(service.Name),
@@ -319,7 +321,7 @@ func (b *ecsAPIService) createTargetGroup(project *types.Project, service types.
 		Protocol:           protocol,
 		Tags:               projectTags(project),
 		TargetType:         elbv2.TargetTypeEnumIp,
-		VpcId:              b.resources.vpc,
+		VpcId:              vpc,
 	}
 	return targetGroupName
 }
@@ -390,11 +392,11 @@ func (b *ecsAPIService) createTaskRole(service types.ServiceConfig, template *cl
 	return taskRole
 }
 
-func (b *ecsAPIService) createCloudMap(project *types.Project, template *cloudformation.Template) {
+func (b *ecsAPIService) createCloudMap(project *types.Project, template *cloudformation.Template, vpc string) {
 	template.Resources["CloudMap"] = &cloudmap.PrivateDnsNamespace{
 		Description: fmt.Sprintf("Service Map for Docker Compose project %s", project.Name),
 		Name:        fmt.Sprintf("%s.local", project.Name),
-		Vpc:         b.resources.vpc,
+		Vpc:         vpc,
 	}
 }
 

+ 7 - 9
ecs/cloudformation_test.go

@@ -321,7 +321,7 @@ services:
           memory: 2043248M
 `)
 	backend := &ecsAPIService{}
-	_, err := backend.convert(model)
+	_, err := backend.convert(model, awsResources{})
 	assert.ErrorContains(t, err, "the resources requested are not supported by ECS/Fargate")
 }
 
@@ -404,13 +404,11 @@ services:
 }
 
 func convertResultAsString(t *testing.T, project *types.Project) string {
-	backend := &ecsAPIService{
-		resources: awsResources{
-			vpc:     "vpcID",
-			subnets: []string{"subnet1", "subnet2"},
-		},
-	}
-	template, err := backend.convert(project)
+	backend := &ecsAPIService{}
+	template, err := backend.convert(project, awsResources{
+		vpc:     "vpcID",
+		subnets: []string{"subnet1", "subnet2"},
+	})
 	assert.NilError(t, err)
 	resultAsJSON, err := marshall(template)
 	assert.NilError(t, err)
@@ -430,7 +428,7 @@ func load(t *testing.T, paths ...string) *types.Project {
 func convertYaml(t *testing.T, yaml string) *cloudformation.Template {
 	project := loadConfig(t, yaml)
 	backend := &ecsAPIService{}
-	template, err := backend.convert(project)
+	template, err := backend.convert(project, awsResources{})
 	assert.NilError(t, err)
 	return template
 }

+ 3 - 3
ecs/ec2.go

@@ -28,7 +28,7 @@ import (
 	"github.com/compose-spec/compose-go/types"
 )
 
-func (b *ecsAPIService) createCapacityProvider(ctx context.Context, project *types.Project, template *cloudformation.Template) error {
+func (b *ecsAPIService) createCapacityProvider(ctx context.Context, project *types.Project, template *cloudformation.Template, resources awsResources) error {
 	var ec2 bool
 	for _, s := range project.Services {
 		if requireEC2(s) {
@@ -65,7 +65,7 @@ func (b *ecsAPIService) createCapacityProvider(ctx context.Context, project *typ
 		LaunchConfigurationName: cloudformation.Ref("LaunchConfiguration"),
 		MaxSize:                 "10", //TODO
 		MinSize:                 "1",
-		VPCZoneIdentifier:       b.resources.subnets,
+		VPCZoneIdentifier:       resources.subnets,
 	}
 
 	userData := base64.StdEncoding.EncodeToString([]byte(
@@ -74,7 +74,7 @@ func (b *ecsAPIService) createCapacityProvider(ctx context.Context, project *typ
 	template.Resources["LaunchConfiguration"] = &autoscaling.LaunchConfiguration{
 		ImageId:            ami,
 		InstanceType:       machineType,
-		SecurityGroups:     b.resources.allSecurityGroups(),
+		SecurityGroups:     resources.allSecurityGroups(),
 		IamInstanceProfile: cloudformation.Ref("EC2InstanceProfile"),
 		UserData:           userData,
 	}

+ 10 - 0
ecs/sdk.go

@@ -704,3 +704,13 @@ func (s sdk) GetParameter(ctx context.Context, name string) (string, error) {
 
 	return ami.ImageID, nil
 }
+
+func (s sdk) SecurityGroupExists(ctx context.Context, sg string) (bool, error) {
+	desc, err := s.EC2.DescribeSecurityGroupsWithContext(ctx, &ec2.DescribeSecurityGroupsInput{
+		GroupIds: aws.StringSlice([]string{sg}),
+	})
+	if err != nil {
+		return false, err
+	}
+	return len(desc.SecurityGroups) > 0, nil
+}