浏览代码

Only consider public subnets

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof 4 年之前
父节点
当前提交
b9a3025865
共有 5 个文件被更改,包括 66 次插入4 次删除
  1. 1 0
      ecs/aws.go
  2. 15 3
      ecs/awsResources.go
  3. 17 1
      ecs/aws_mock.go
  4. 2 0
      ecs/cloudformation_test.go
  5. 31 0
      ecs/sdk.go

+ 1 - 0
ecs/aws.go

@@ -40,6 +40,7 @@ type API interface {
 	CheckVPC(ctx context.Context, vpcID string) error
 	CheckVPC(ctx context.Context, vpcID string) error
 	GetDefaultVPC(ctx context.Context) (string, error)
 	GetDefaultVPC(ctx context.Context) (string, error)
 	GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error)
 	GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error)
+	IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error)
 	GetRoleArn(ctx context.Context, name string) (string, error)
 	GetRoleArn(ctx context.Context, name string) (string, error)
 	StackExists(ctx context.Context, name string) (bool, error)
 	StackExists(ctx context.Context, name string) (bool, error)
 	CreateStack(ctx context.Context, name string, region string, template []byte) error
 	CreateStack(ctx context.Context, name string, region string, template []byte) error

+ 15 - 3
ecs/awsResources.go

@@ -185,10 +185,22 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
 	if err != nil {
 	if err != nil {
 		return "", nil, err
 		return "", nil, err
 	}
 	}
-	if len(subNets) < 2 {
-		return "", nil, fmt.Errorf("VPC %s should have at least 2 associated subnets in different availability zones", vpc)
+
+	var publicSubNets []awsResource
+	for _, subNet := range subNets {
+		isPublic, err := b.aws.IsPublicSubnet(ctx, vpc, subNet.ID())
+		if err != nil {
+			return "", nil, err
+		}
+		if isPublic {
+			publicSubNets = append(publicSubNets, subNet)
+		}
+	}
+
+	if len(publicSubNets) < 2 {
+		return "", nil, fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc)
 	}
 	}
-	return vpc, subNets, nil
+	return vpc, publicSubNets, nil
 }
 }
 
 
 func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (awsResource, string, error) {
 func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (awsResource, string, error) {

+ 17 - 1
ecs/aws_mock.go

@@ -6,12 +6,13 @@ package ecs
 
 
 import (
 import (
 	context "context"
 	context "context"
+	reflect "reflect"
+
 	cloudformation "github.com/aws/aws-sdk-go/service/cloudformation"
 	cloudformation "github.com/aws/aws-sdk-go/service/cloudformation"
 	ecs "github.com/aws/aws-sdk-go/service/ecs"
 	ecs "github.com/aws/aws-sdk-go/service/ecs"
 	compose "github.com/docker/compose-cli/api/compose"
 	compose "github.com/docker/compose-cli/api/compose"
 	secrets "github.com/docker/compose-cli/api/secrets"
 	secrets "github.com/docker/compose-cli/api/secrets"
 	gomock "github.com/golang/mock/gomock"
 	gomock "github.com/golang/mock/gomock"
-	reflect "reflect"
 )
 )
 
 
 // MockAPI is a mock of API interface
 // MockAPI is a mock of API interface
@@ -453,6 +454,21 @@ func (mr *MockAPIMockRecorder) InspectSecret(arg0, arg1 interface{}) *gomock.Cal
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InspectSecret", reflect.TypeOf((*MockAPI)(nil).InspectSecret), arg0, arg1)
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InspectSecret", reflect.TypeOf((*MockAPI)(nil).InspectSecret), arg0, arg1)
 }
 }
 
 
+// IsPublicSubnet mocks base method
+func (m *MockAPI) IsPublicSubnet(ctx context.Context, arg0 string, arg1 string) (bool, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "IsPublicSubnet", arg0, arg1)
+	ret0, _ := ret[0].(bool)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// IsPublicSubnet indicates an expected call of IsPublicSubnet
+func (mr *MockAPIMockRecorder) IsPublicSubnet(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPublicSubnet", reflect.TypeOf((*MockAPI)(nil).IsPublicSubnet), arg0, arg1)
+}
+
 // ListFileSystems mocks base method
 // ListFileSystems mocks base method
 func (m *MockAPI) ListFileSystems(arg0 context.Context, arg1 map[string]string) ([]awsResource, error) {
 func (m *MockAPI) ListFileSystems(arg0 context.Context, arg1 map[string]string) ([]awsResource, error) {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()

+ 2 - 0
ecs/cloudformation_test.go

@@ -591,6 +591,8 @@ func useDefaultVPC(m *MockAPIMockRecorder) {
 		existingAWSResource{id: "subnet1"},
 		existingAWSResource{id: "subnet1"},
 		existingAWSResource{id: "subnet2"},
 		existingAWSResource{id: "subnet2"},
 	}, nil)
 	}, nil)
+	m.IsPublicSubnet(gomock.Any(), "subnet1").Return(true, nil)
+	m.IsPublicSubnet(gomock.Any(), "subnet2").Return(true, nil)
 }
 }
 
 
 func useGPU(m *MockAPIMockRecorder) {
 func useGPU(m *MockAPIMockRecorder) {

+ 31 - 0
ecs/sdk.go

@@ -211,6 +211,37 @@ func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error
 	return ids, nil
 	return ids, nil
 }
 }
 
 
+func (s sdk) IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error) {
+	tables, err := s.EC2.DescribeRouteTablesWithContext(ctx, &ec2.DescribeRouteTablesInput{
+		Filters: []*ec2.Filter{
+			{
+				Name:   aws.String("association.subnet-id"),
+				Values: []*string{aws.String(subNetID)},
+			},
+		},
+	})
+	if err != nil {
+		return false, err
+	}
+	if len(tables.RouteTables) == 0 {
+		// If a subnet is not explicitly associated with any route table, it is implicitly associated with the main route table.
+		// https://docs.aws.amazon.com/cli/latest/reference/ec2/describe-route-tables.html
+		return true, nil
+	}
+	for _, routeTable := range tables.RouteTables {
+		for _, route := range routeTable.Routes {
+			if aws.StringValue(route.State) != "active" {
+				continue
+			}
+			if strings.HasPrefix(aws.StringValue(route.GatewayId), "igw-") {
+				// Connected to an internet Gateway
+				return true, nil
+			}
+		}
+	}
+	return false, nil
+}
+
 func (s sdk) GetRoleArn(ctx context.Context, name string) (string, error) {
 func (s sdk) GetRoleArn(ctx context.Context, name string) (string, error) {
 	role, err := s.IAM.GetRoleWithContext(ctx, &iam.GetRoleInput{
 	role, err := s.IAM.GetRoleWithContext(ctx, &iam.GetRoleInput{
 		RoleName: aws.String(name),
 		RoleName: aws.String(name),