Browse Source

Only consider public subnets

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof 4 years ago
parent
commit
b9a3025865
5 changed files with 66 additions and 4 deletions
  1. 1 0
      ecs/aws.go
  2. 15 3
      ecs/awsResources.go
  3. 17 1
      ecs/aws_mock.go
  4. 2 0
      ecs/cloudformation_test.go
  5. 31 0
      ecs/sdk.go

+ 1 - 0
ecs/aws.go

@@ -40,6 +40,7 @@ type API interface {
 	CheckVPC(ctx context.Context, vpcID string) error
 	GetDefaultVPC(ctx context.Context) (string, error)
 	GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error)
+	IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error)
 	GetRoleArn(ctx context.Context, name string) (string, error)
 	StackExists(ctx context.Context, name string) (bool, error)
 	CreateStack(ctx context.Context, name string, region string, template []byte) error

+ 15 - 3
ecs/awsResources.go

@@ -185,10 +185,22 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
 	if err != nil {
 		return "", nil, err
 	}
-	if len(subNets) < 2 {
-		return "", nil, fmt.Errorf("VPC %s should have at least 2 associated subnets in different availability zones", vpc)
+
+	var publicSubNets []awsResource
+	for _, subNet := range subNets {
+		isPublic, err := b.aws.IsPublicSubnet(ctx, vpc, subNet.ID())
+		if err != nil {
+			return "", nil, err
+		}
+		if isPublic {
+			publicSubNets = append(publicSubNets, subNet)
+		}
+	}
+
+	if len(publicSubNets) < 2 {
+		return "", nil, fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc)
 	}
-	return vpc, subNets, nil
+	return vpc, publicSubNets, nil
 }
 
 func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (awsResource, string, error) {

+ 17 - 1
ecs/aws_mock.go

@@ -6,12 +6,13 @@ package ecs
 
 import (
 	context "context"
+	reflect "reflect"
+
 	cloudformation "github.com/aws/aws-sdk-go/service/cloudformation"
 	ecs "github.com/aws/aws-sdk-go/service/ecs"
 	compose "github.com/docker/compose-cli/api/compose"
 	secrets "github.com/docker/compose-cli/api/secrets"
 	gomock "github.com/golang/mock/gomock"
-	reflect "reflect"
 )
 
 // MockAPI is a mock of API interface
@@ -453,6 +454,21 @@ func (mr *MockAPIMockRecorder) InspectSecret(arg0, arg1 interface{}) *gomock.Cal
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InspectSecret", reflect.TypeOf((*MockAPI)(nil).InspectSecret), arg0, arg1)
 }
 
+// IsPublicSubnet mocks base method
+func (m *MockAPI) IsPublicSubnet(ctx context.Context, arg0 string, arg1 string) (bool, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "IsPublicSubnet", arg0, arg1)
+	ret0, _ := ret[0].(bool)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// IsPublicSubnet indicates an expected call of IsPublicSubnet
+func (mr *MockAPIMockRecorder) IsPublicSubnet(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPublicSubnet", reflect.TypeOf((*MockAPI)(nil).IsPublicSubnet), arg0, arg1)
+}
+
 // ListFileSystems mocks base method
 func (m *MockAPI) ListFileSystems(arg0 context.Context, arg1 map[string]string) ([]awsResource, error) {
 	m.ctrl.T.Helper()

+ 2 - 0
ecs/cloudformation_test.go

@@ -591,6 +591,8 @@ func useDefaultVPC(m *MockAPIMockRecorder) {
 		existingAWSResource{id: "subnet1"},
 		existingAWSResource{id: "subnet2"},
 	}, nil)
+	m.IsPublicSubnet(gomock.Any(), "subnet1").Return(true, nil)
+	m.IsPublicSubnet(gomock.Any(), "subnet2").Return(true, nil)
 }
 
 func useGPU(m *MockAPIMockRecorder) {

+ 31 - 0
ecs/sdk.go

@@ -211,6 +211,37 @@ func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error
 	return ids, nil
 }
 
+func (s sdk) IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error) {
+	tables, err := s.EC2.DescribeRouteTablesWithContext(ctx, &ec2.DescribeRouteTablesInput{
+		Filters: []*ec2.Filter{
+			{
+				Name:   aws.String("association.subnet-id"),
+				Values: []*string{aws.String(subNetID)},
+			},
+		},
+	})
+	if err != nil {
+		return false, err
+	}
+	if len(tables.RouteTables) == 0 {
+		// If a subnet is not explicitly associated with any route table, it is implicitly associated with the main route table.
+		// https://docs.aws.amazon.com/cli/latest/reference/ec2/describe-route-tables.html
+		return true, nil
+	}
+	for _, routeTable := range tables.RouteTables {
+		for _, route := range routeTable.Routes {
+			if aws.StringValue(route.State) != "active" {
+				continue
+			}
+			if strings.HasPrefix(aws.StringValue(route.GatewayId), "igw-") {
+				// Connected to an internet Gateway
+				return true, nil
+			}
+		}
+	}
+	return false, nil
+}
+
 func (s sdk) GetRoleArn(ctx context.Context, name string) (string, error) {
 	role, err := s.IAM.GetRoleWithContext(ctx, &iam.GetRoleInput{
 		RoleName: aws.String(name),