Browse Source

Use LoadBalancer's VPC and subnet when x-aws-loadbalancer is set

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof 4 years ago
parent
commit
075f54713e
4 changed files with 58 additions and 33 deletions
  1. 2 2
      ecs/aws.go
  2. 34 16
      ecs/awsResources.go
  3. 7 6
      ecs/aws_mock.go
  4. 15 9
      ecs/sdk.go

+ 2 - 2
ecs/aws.go

@@ -40,7 +40,7 @@ type API interface {
 	CheckVPC(ctx context.Context, vpcID string) error
 	CheckVPC(ctx context.Context, vpcID string) error
 	GetDefaultVPC(ctx context.Context) (string, error)
 	GetDefaultVPC(ctx context.Context) (string, error)
 	GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error)
 	GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error)
-	IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error)
+	IsPublicSubnet(ctx context.Context, subNetID string) (bool, error)
 	GetRoleArn(ctx context.Context, name string) (string, error)
 	GetRoleArn(ctx context.Context, name string) (string, error)
 	StackExists(ctx context.Context, name string) (bool, error)
 	StackExists(ctx context.Context, name string) (bool, error)
 	CreateStack(ctx context.Context, name string, region string, template []byte) error
 	CreateStack(ctx context.Context, name string, region string, template []byte) error
@@ -68,7 +68,7 @@ type API interface {
 	getURLWithPortMapping(ctx context.Context, targetGroupArns []string) ([]compose.PortPublisher, error)
 	getURLWithPortMapping(ctx context.Context, targetGroupArns []string) ([]compose.PortPublisher, error)
 	ListTasks(ctx context.Context, cluster string, family string) ([]string, error)
 	ListTasks(ctx context.Context, cluster string, family string) ([]string, error)
 	GetPublicIPs(ctx context.Context, interfaces ...string) (map[string]string, error)
 	GetPublicIPs(ctx context.Context, interfaces ...string) (map[string]string, error)
-	ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, error)
+	ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, string, []awsResource, error)
 	GetLoadBalancerURL(ctx context.Context, arn string) (string, error)
 	GetLoadBalancerURL(ctx context.Context, arn string) (string, error)
 	GetParameter(ctx context.Context, name string) (string, error)
 	GetParameter(ctx context.Context, name string) (string, error)
 	SecurityGroupExists(ctx context.Context, sg string) (bool, error)
 	SecurityGroupExists(ctx context.Context, sg string) (bool, error)

+ 34 - 16
ecs/awsResources.go

@@ -129,11 +129,11 @@ func (b *ecsAPIService) parse(ctx context.Context, project *types.Project, templ
 	if err != nil {
 	if err != nil {
 		return r, err
 		return r, err
 	}
 	}
-	r.vpc, r.subnets, err = b.parseVPCExtension(ctx, project)
+	err = b.parseLoadBalancerExtension(ctx, project, &r)
 	if err != nil {
 	if err != nil {
 		return r, err
 		return r, err
 	}
 	}
-	r.loadBalancer, r.loadBalancerType, err = b.parseLoadBalancerExtension(ctx, project)
+	err = b.parseVPCExtension(ctx, project, &r)
 	if err != nil {
 	if err != nil {
 		return r, err
 		return r, err
 	}
 	}
@@ -165,7 +165,7 @@ func (b *ecsAPIService) parseClusterExtension(ctx context.Context, project *type
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Project) (string, []awsResource, error) {
+func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Project, r *awsResources) error {
 	var vpc string
 	var vpc string
 	if x, ok := project.Extensions[extensionVPC]; ok {
 	if x, ok := project.Extensions[extensionVPC]; ok {
 		vpc = x.(string)
 		vpc = x.(string)
@@ -177,29 +177,40 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
 			vpc = id[i+1:]
 			vpc = id[i+1:]
 		}
 		}
 
 
+		if r.vpc != "" {
+			if r.vpc != vpc {
+				return fmt.Errorf("load balancer set by %s is attached to VPC %s", extensionLoadBalancer, r.vpc)
+			}
+			return nil
+		}
+
 		err = b.aws.CheckVPC(ctx, vpc)
 		err = b.aws.CheckVPC(ctx, vpc)
 		if err != nil {
 		if err != nil {
-			return "", nil, err
+			return err
 		}
 		}
 
 
 	} else {
 	} else {
+		if r.vpc != "" {
+			return nil
+		}
+
 		defaultVPC, err := b.aws.GetDefaultVPC(ctx)
 		defaultVPC, err := b.aws.GetDefaultVPC(ctx)
 		if err != nil {
 		if err != nil {
-			return "", nil, err
+			return err
 		}
 		}
 		vpc = defaultVPC
 		vpc = defaultVPC
 	}
 	}
 
 
 	subNets, err := b.aws.GetSubNets(ctx, vpc)
 	subNets, err := b.aws.GetSubNets(ctx, vpc)
 	if err != nil {
 	if err != nil {
-		return "", nil, err
+		return err
 	}
 	}
 
 
 	var publicSubNets []awsResource
 	var publicSubNets []awsResource
 	for _, subNet := range subNets {
 	for _, subNet := range subNets {
-		isPublic, err := b.aws.IsPublicSubnet(ctx, vpc, subNet.ID())
+		isPublic, err := b.aws.IsPublicSubnet(ctx, subNet.ID())
 		if err != nil {
 		if err != nil {
-			return "", nil, err
+			return err
 		}
 		}
 		if isPublic {
 		if isPublic {
 			publicSubNets = append(publicSubNets, subNet)
 			publicSubNets = append(publicSubNets, subNet)
@@ -207,27 +218,34 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
 	}
 	}
 
 
 	if len(publicSubNets) < 2 {
 	if len(publicSubNets) < 2 {
-		return "", nil, fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc)
+		return fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc)
 	}
 	}
-	return vpc, publicSubNets, nil
+
+	r.vpc = vpc
+	r.subnets = subNets
+	return nil
 }
 }
 
 
-func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (awsResource, string, error) {
+func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project, r *awsResources) error {
 	if x, ok := project.Extensions[extensionLoadBalancer]; ok {
 	if x, ok := project.Extensions[extensionLoadBalancer]; ok {
 		nameOrArn := x.(string)
 		nameOrArn := x.(string)
-		loadBalancer, loadBalancerType, err := b.aws.ResolveLoadBalancer(ctx, nameOrArn)
+		loadBalancer, loadBalancerType, vpc, subnets, err := b.aws.ResolveLoadBalancer(ctx, nameOrArn)
 		if err != nil {
 		if err != nil {
-			return nil, "", err
+			return err
 		}
 		}
 
 
 		required := getRequiredLoadBalancerType(project)
 		required := getRequiredLoadBalancerType(project)
 		if loadBalancerType != required {
 		if loadBalancerType != required {
-			return nil, "", fmt.Errorf("load balancer %q is of type %s, project require a %s", nameOrArn, loadBalancerType, required)
+			return fmt.Errorf("load balancer %q is of type %s, project require a %s", nameOrArn, loadBalancerType, required)
 		}
 		}
 
 
-		return loadBalancer, loadBalancerType, err
+		r.loadBalancer = loadBalancer
+		r.loadBalancerType = loadBalancerType
+		r.vpc = vpc
+		r.subnets = subnets
+		return err
 	}
 	}
-	return nil, "", nil
+	return nil
 }
 }
 
 
 func (b *ecsAPIService) parseExternalNetworks(ctx context.Context, project *types.Project) (map[string]string, error) {
 func (b *ecsAPIService) parseExternalNetworks(ctx context.Context, project *types.Project) (map[string]string, error) {

+ 7 - 6
ecs/aws_mock.go

@@ -6,13 +6,12 @@ package ecs
 
 
 import (
 import (
 	context "context"
 	context "context"
-	reflect "reflect"
-
 	cloudformation "github.com/aws/aws-sdk-go/service/cloudformation"
 	cloudformation "github.com/aws/aws-sdk-go/service/cloudformation"
 	ecs "github.com/aws/aws-sdk-go/service/ecs"
 	ecs "github.com/aws/aws-sdk-go/service/ecs"
 	compose "github.com/docker/compose-cli/api/compose"
 	compose "github.com/docker/compose-cli/api/compose"
 	secrets "github.com/docker/compose-cli/api/secrets"
 	secrets "github.com/docker/compose-cli/api/secrets"
 	gomock "github.com/golang/mock/gomock"
 	gomock "github.com/golang/mock/gomock"
+	reflect "reflect"
 )
 )
 
 
 // MockAPI is a mock of API interface
 // MockAPI is a mock of API interface
@@ -455,7 +454,7 @@ func (mr *MockAPIMockRecorder) InspectSecret(arg0, arg1 interface{}) *gomock.Cal
 }
 }
 
 
 // IsPublicSubnet mocks base method
 // IsPublicSubnet mocks base method
-func (m *MockAPI) IsPublicSubnet(ctx context.Context, arg0 string, arg1 string) (bool, error) {
+func (m *MockAPI) IsPublicSubnet(arg0 context.Context, arg1 string) (bool, error) {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "IsPublicSubnet", arg0, arg1)
 	ret := m.ctrl.Call(m, "IsPublicSubnet", arg0, arg1)
 	ret0, _ := ret[0].(bool)
 	ret0, _ := ret[0].(bool)
@@ -605,13 +604,15 @@ func (mr *MockAPIMockRecorder) ResolveFileSystem(arg0, arg1 interface{}) *gomock
 }
 }
 
 
 // ResolveLoadBalancer mocks base method
 // ResolveLoadBalancer mocks base method
-func (m *MockAPI) ResolveLoadBalancer(arg0 context.Context, arg1 string) (awsResource, string, error) {
+func (m *MockAPI) ResolveLoadBalancer(arg0 context.Context, arg1 string) (awsResource, string, string, []awsResource, error) {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "ResolveLoadBalancer", arg0, arg1)
 	ret := m.ctrl.Call(m, "ResolveLoadBalancer", arg0, arg1)
 	ret0, _ := ret[0].(awsResource)
 	ret0, _ := ret[0].(awsResource)
 	ret1, _ := ret[1].(string)
 	ret1, _ := ret[1].(string)
-	ret2, _ := ret[2].(error)
-	return ret0, ret1, ret2
+	ret2, _ := ret[2].(string)
+	ret3, _ := ret[3].([]awsResource)
+	ret4, _ := ret[4].(error)
+	return ret0, ret1, ret2, ret3, ret4
 }
 }
 
 
 // ResolveLoadBalancer indicates an expected call of ResolveLoadBalancer
 // ResolveLoadBalancer indicates an expected call of ResolveLoadBalancer

+ 15 - 9
ecs/sdk.go

@@ -210,7 +210,7 @@ func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error
 	return ids, nil
 	return ids, nil
 }
 }
 
 
-func (s sdk) IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error) {
+func (s sdk) IsPublicSubnet(ctx context.Context, subNetID string) (bool, error) {
 	tables, err := s.EC2.DescribeRouteTablesWithContext(ctx, &ec2.DescribeRouteTablesInput{
 	tables, err := s.EC2.DescribeRouteTablesWithContext(ctx, &ec2.DescribeRouteTablesInput{
 		Filters: []*ec2.Filter{
 		Filters: []*ec2.Filter{
 			{
 			{
@@ -1045,14 +1045,14 @@ func (s sdk) GetPublicIPs(ctx context.Context, interfaces ...string) (map[string
 	}
 	}
 }
 }
 
 
-func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrarn string) (awsResource, string, error) {
-	logrus.Debug("Check if LoadBalancer exists: ", nameOrarn)
+func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, string, []awsResource, error) {
+	logrus.Debug("Check if LoadBalancer exists: ", nameOrArn)
 	var arns []*string
 	var arns []*string
 	var names []*string
 	var names []*string
-	if arn.IsARN(nameOrarn) {
-		arns = append(arns, aws.String(nameOrarn))
+	if arn.IsARN(nameOrArn) {
+		arns = append(arns, aws.String(nameOrArn))
 	} else {
 	} else {
-		names = append(names, aws.String(nameOrarn))
+		names = append(names, aws.String(nameOrArn))
 	}
 	}
 
 
 	lbs, err := s.ELB.DescribeLoadBalancersWithContext(ctx, &elbv2.DescribeLoadBalancersInput{
 	lbs, err := s.ELB.DescribeLoadBalancersWithContext(ctx, &elbv2.DescribeLoadBalancersInput{
@@ -1060,16 +1060,22 @@ func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrarn string) (awsReso
 		Names:            names,
 		Names:            names,
 	})
 	})
 	if err != nil {
 	if err != nil {
-		return nil, "", err
+		return nil, "", "", nil, err
 	}
 	}
 	if len(lbs.LoadBalancers) == 0 {
 	if len(lbs.LoadBalancers) == 0 {
-		return nil, "", errors.Wrapf(errdefs.ErrNotFound, "load balancer %q does not exist", nameOrarn)
+		return nil, "", "", nil, errors.Wrapf(errdefs.ErrNotFound, "load balancer %q does not exist", nameOrArn)
 	}
 	}
 	it := lbs.LoadBalancers[0]
 	it := lbs.LoadBalancers[0]
+	var subNets []awsResource
+	for _, az := range it.AvailabilityZones {
+		subNets = append(subNets, existingAWSResource{
+			id: aws.StringValue(az.SubnetId),
+		})
+	}
 	return existingAWSResource{
 	return existingAWSResource{
 		arn: aws.StringValue(it.LoadBalancerArn),
 		arn: aws.StringValue(it.LoadBalancerArn),
 		id:  aws.StringValue(it.LoadBalancerName),
 		id:  aws.StringValue(it.LoadBalancerName),
-	}, aws.StringValue(it.Type), nil
+	}, aws.StringValue(it.Type), aws.StringValue(it.VpcId), subNets, nil
 }
 }
 
 
 func (s sdk) GetLoadBalancerURL(ctx context.Context, arn string) (string, error) {
 func (s sdk) GetLoadBalancerURL(ctx context.Context, arn string) (string, error) {