Selaa lähdekoodia

Use LoadBalancer's VPC and subnet when x-aws-loadbalancer is set

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof 4 vuotta sitten
vanhempi
sitoutus
075f54713e
4 muutettua tiedostoa jossa 58 lisäystä ja 33 poistoa
  1. 2 2
      ecs/aws.go
  2. 34 16
      ecs/awsResources.go
  3. 7 6
      ecs/aws_mock.go
  4. 15 9
      ecs/sdk.go

+ 2 - 2
ecs/aws.go

@@ -40,7 +40,7 @@ type API interface {
 	CheckVPC(ctx context.Context, vpcID string) error
 	GetDefaultVPC(ctx context.Context) (string, error)
 	GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error)
-	IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error)
+	IsPublicSubnet(ctx context.Context, subNetID string) (bool, error)
 	GetRoleArn(ctx context.Context, name string) (string, error)
 	StackExists(ctx context.Context, name string) (bool, error)
 	CreateStack(ctx context.Context, name string, region string, template []byte) error
@@ -68,7 +68,7 @@ type API interface {
 	getURLWithPortMapping(ctx context.Context, targetGroupArns []string) ([]compose.PortPublisher, error)
 	ListTasks(ctx context.Context, cluster string, family string) ([]string, error)
 	GetPublicIPs(ctx context.Context, interfaces ...string) (map[string]string, error)
-	ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, error)
+	ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, string, []awsResource, error)
 	GetLoadBalancerURL(ctx context.Context, arn string) (string, error)
 	GetParameter(ctx context.Context, name string) (string, error)
 	SecurityGroupExists(ctx context.Context, sg string) (bool, error)

+ 34 - 16
ecs/awsResources.go

@@ -129,11 +129,11 @@ func (b *ecsAPIService) parse(ctx context.Context, project *types.Project, templ
 	if err != nil {
 		return r, err
 	}
-	r.vpc, r.subnets, err = b.parseVPCExtension(ctx, project)
+	err = b.parseLoadBalancerExtension(ctx, project, &r)
 	if err != nil {
 		return r, err
 	}
-	r.loadBalancer, r.loadBalancerType, err = b.parseLoadBalancerExtension(ctx, project)
+	err = b.parseVPCExtension(ctx, project, &r)
 	if err != nil {
 		return r, err
 	}
@@ -165,7 +165,7 @@ func (b *ecsAPIService) parseClusterExtension(ctx context.Context, project *type
 	return nil, nil
 }
 
-func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Project) (string, []awsResource, error) {
+func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Project, r *awsResources) error {
 	var vpc string
 	if x, ok := project.Extensions[extensionVPC]; ok {
 		vpc = x.(string)
@@ -177,29 +177,40 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
 			vpc = id[i+1:]
 		}
 
+		if r.vpc != "" {
+			if r.vpc != vpc {
+				return fmt.Errorf("load balancer set by %s is attached to VPC %s", extensionLoadBalancer, r.vpc)
+			}
+			return nil
+		}
+
 		err = b.aws.CheckVPC(ctx, vpc)
 		if err != nil {
-			return "", nil, err
+			return err
 		}
 
 	} else {
+		if r.vpc != "" {
+			return nil
+		}
+
 		defaultVPC, err := b.aws.GetDefaultVPC(ctx)
 		if err != nil {
-			return "", nil, err
+			return err
 		}
 		vpc = defaultVPC
 	}
 
 	subNets, err := b.aws.GetSubNets(ctx, vpc)
 	if err != nil {
-		return "", nil, err
+		return err
 	}
 
 	var publicSubNets []awsResource
 	for _, subNet := range subNets {
-		isPublic, err := b.aws.IsPublicSubnet(ctx, vpc, subNet.ID())
+		isPublic, err := b.aws.IsPublicSubnet(ctx, subNet.ID())
 		if err != nil {
-			return "", nil, err
+			return err
 		}
 		if isPublic {
 			publicSubNets = append(publicSubNets, subNet)
@@ -207,27 +218,34 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
 	}
 
 	if len(publicSubNets) < 2 {
-		return "", nil, fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc)
+		return fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc)
 	}
-	return vpc, publicSubNets, nil
+
+	r.vpc = vpc
+	r.subnets = subNets
+	return nil
 }
 
-func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (awsResource, string, error) {
+func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project, r *awsResources) error {
 	if x, ok := project.Extensions[extensionLoadBalancer]; ok {
 		nameOrArn := x.(string)
-		loadBalancer, loadBalancerType, err := b.aws.ResolveLoadBalancer(ctx, nameOrArn)
+		loadBalancer, loadBalancerType, vpc, subnets, err := b.aws.ResolveLoadBalancer(ctx, nameOrArn)
 		if err != nil {
-			return nil, "", err
+			return err
 		}
 
 		required := getRequiredLoadBalancerType(project)
 		if loadBalancerType != required {
-			return nil, "", fmt.Errorf("load balancer %q is of type %s, project require a %s", nameOrArn, loadBalancerType, required)
+			return fmt.Errorf("load balancer %q is of type %s, project require a %s", nameOrArn, loadBalancerType, required)
 		}
 
-		return loadBalancer, loadBalancerType, err
+		r.loadBalancer = loadBalancer
+		r.loadBalancerType = loadBalancerType
+		r.vpc = vpc
+		r.subnets = subnets
+		return err
 	}
-	return nil, "", nil
+	return nil
 }
 
 func (b *ecsAPIService) parseExternalNetworks(ctx context.Context, project *types.Project) (map[string]string, error) {

+ 7 - 6
ecs/aws_mock.go

@@ -6,13 +6,12 @@ package ecs
 
 import (
 	context "context"
-	reflect "reflect"
-
 	cloudformation "github.com/aws/aws-sdk-go/service/cloudformation"
 	ecs "github.com/aws/aws-sdk-go/service/ecs"
 	compose "github.com/docker/compose-cli/api/compose"
 	secrets "github.com/docker/compose-cli/api/secrets"
 	gomock "github.com/golang/mock/gomock"
+	reflect "reflect"
 )
 
 // MockAPI is a mock of API interface
@@ -455,7 +454,7 @@ func (mr *MockAPIMockRecorder) InspectSecret(arg0, arg1 interface{}) *gomock.Cal
 }
 
 // IsPublicSubnet mocks base method
-func (m *MockAPI) IsPublicSubnet(ctx context.Context, arg0 string, arg1 string) (bool, error) {
+func (m *MockAPI) IsPublicSubnet(arg0 context.Context, arg1 string) (bool, error) {
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "IsPublicSubnet", arg0, arg1)
 	ret0, _ := ret[0].(bool)
@@ -605,13 +604,15 @@ func (mr *MockAPIMockRecorder) ResolveFileSystem(arg0, arg1 interface{}) *gomock
 }
 
 // ResolveLoadBalancer mocks base method
-func (m *MockAPI) ResolveLoadBalancer(arg0 context.Context, arg1 string) (awsResource, string, error) {
+func (m *MockAPI) ResolveLoadBalancer(arg0 context.Context, arg1 string) (awsResource, string, string, []awsResource, error) {
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "ResolveLoadBalancer", arg0, arg1)
 	ret0, _ := ret[0].(awsResource)
 	ret1, _ := ret[1].(string)
-	ret2, _ := ret[2].(error)
-	return ret0, ret1, ret2
+	ret2, _ := ret[2].(string)
+	ret3, _ := ret[3].([]awsResource)
+	ret4, _ := ret[4].(error)
+	return ret0, ret1, ret2, ret3, ret4
 }
 
 // ResolveLoadBalancer indicates an expected call of ResolveLoadBalancer

+ 15 - 9
ecs/sdk.go

@@ -210,7 +210,7 @@ func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error
 	return ids, nil
 }
 
-func (s sdk) IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error) {
+func (s sdk) IsPublicSubnet(ctx context.Context, subNetID string) (bool, error) {
 	tables, err := s.EC2.DescribeRouteTablesWithContext(ctx, &ec2.DescribeRouteTablesInput{
 		Filters: []*ec2.Filter{
 			{
@@ -1045,14 +1045,14 @@ func (s sdk) GetPublicIPs(ctx context.Context, interfaces ...string) (map[string
 	}
 }
 
-func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrarn string) (awsResource, string, error) {
-	logrus.Debug("Check if LoadBalancer exists: ", nameOrarn)
+func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, string, []awsResource, error) {
+	logrus.Debug("Check if LoadBalancer exists: ", nameOrArn)
 	var arns []*string
 	var names []*string
-	if arn.IsARN(nameOrarn) {
-		arns = append(arns, aws.String(nameOrarn))
+	if arn.IsARN(nameOrArn) {
+		arns = append(arns, aws.String(nameOrArn))
 	} else {
-		names = append(names, aws.String(nameOrarn))
+		names = append(names, aws.String(nameOrArn))
 	}
 
 	lbs, err := s.ELB.DescribeLoadBalancersWithContext(ctx, &elbv2.DescribeLoadBalancersInput{
@@ -1060,16 +1060,22 @@ func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrarn string) (awsReso
 		Names:            names,
 	})
 	if err != nil {
-		return nil, "", err
+		return nil, "", "", nil, err
 	}
 	if len(lbs.LoadBalancers) == 0 {
-		return nil, "", errors.Wrapf(errdefs.ErrNotFound, "load balancer %q does not exist", nameOrarn)
+		return nil, "", "", nil, errors.Wrapf(errdefs.ErrNotFound, "load balancer %q does not exist", nameOrArn)
 	}
 	it := lbs.LoadBalancers[0]
+	var subNets []awsResource
+	for _, az := range it.AvailabilityZones {
+		subNets = append(subNets, existingAWSResource{
+			id: aws.StringValue(az.SubnetId),
+		})
+	}
 	return existingAWSResource{
 		arn: aws.StringValue(it.LoadBalancerArn),
 		id:  aws.StringValue(it.LoadBalancerName),
-	}, aws.StringValue(it.Type), nil
+	}, aws.StringValue(it.Type), aws.StringValue(it.VpcId), subNets, nil
 }
 
 func (s sdk) GetLoadBalancerURL(ctx context.Context, arn string) (string, error) {