浏览代码

More unig tests with aws-sdk behind an interface + mocks

Fix use of existing SecurityGroup

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof 5 年之前
父节点
当前提交
432ce79e2c
共有 20 个文件被更改,包括 910 次插入213 次删除
  1. 1 1
      ecs/autoscaling_test.go
  2. 54 0
      ecs/aws.go
  3. 14 9
      ecs/awsResources.go
  4. 617 0
      ecs/aws_mock.go
  5. 2 2
      ecs/backend.go
  6. 113 105
      ecs/cloudformation.go
  7. 69 61
      ecs/cloudformation_test.go
  8. 1 0
      ecs/compatibility.go
  9. 5 5
      ecs/down.go
  10. 1 1
      ecs/ec2.go
  11. 6 6
      ecs/list.go
  12. 1 1
      ecs/logs.go
  13. 3 3
      ecs/ps.go
  14. 5 3
      ecs/sdk.go
  15. 4 4
      ecs/secrets.go
  16. 3 3
      ecs/testdata/simple/simple-cloudformation-conversion.golden
  17. 5 5
      ecs/up.go
  18. 4 4
      ecs/wait.go
  19. 1 0
      go.mod
  20. 1 0
      go.sum

+ 1 - 1
ecs/autoscaling_test.go

@@ -30,7 +30,7 @@ services:
     image: hello_world
     image: hello_world
     deploy:
     deploy:
       x-aws-autoscaling: 75
       x-aws-autoscaling: 75
-`)
+`, useDefaultVPC)
 	target := template.Resources["FooScalableTarget"].(*autoscaling.ScalableTarget)
 	target := template.Resources["FooScalableTarget"].(*autoscaling.ScalableTarget)
 	assert.Check(t, target != nil)
 	assert.Check(t, target != nil)
 	policy := template.Resources["FooScalingPolicy"].(*autoscaling.ScalingPolicy)
 	policy := template.Resources["FooScalingPolicy"].(*autoscaling.ScalingPolicy)

+ 54 - 0
ecs/aws.go

@@ -16,7 +16,61 @@
 
 
 package ecs
 package ecs
 
 
+import (
+	"context"
+
+	"github.com/aws/aws-sdk-go/service/cloudformation"
+	"github.com/aws/aws-sdk-go/service/ecs"
+	"github.com/docker/compose-cli/api/compose"
+	"github.com/docker/compose-cli/api/secrets"
+)
+
 const (
 const (
 	awsTypeCapacityProvider = "AWS::ECS::CapacityProvider"
 	awsTypeCapacityProvider = "AWS::ECS::CapacityProvider"
 	awsTypeAutoscalingGroup = "AWS::AutoScaling::AutoScalingGroup"
 	awsTypeAutoscalingGroup = "AWS::AutoScaling::AutoScalingGroup"
 )
 )
+
+//go:generate mockgen -destination=./aws_mock.go -self_package "github.com/docker/compose-cli/ecs" -package=ecs . API
+
+// API hides aws-go-sdk into a simpler, focussed API subset
+type API interface {
+	CheckRequirements(ctx context.Context, region string) error
+	ClusterExists(ctx context.Context, name string) (bool, error)
+	CreateCluster(ctx context.Context, name string) (string, error)
+	CheckVPC(ctx context.Context, vpcID string) error
+	GetDefaultVPC(ctx context.Context) (string, error)
+	GetSubNets(ctx context.Context, vpcID string) ([]string, error)
+	GetRoleArn(ctx context.Context, name string) (string, error)
+	StackExists(ctx context.Context, name string) (bool, error)
+	CreateStack(ctx context.Context, name string, template []byte) error
+	CreateChangeSet(ctx context.Context, name string, template []byte) (string, error)
+	UpdateStack(ctx context.Context, changeset string) error
+	WaitStackComplete(ctx context.Context, name string, operation int) error
+	GetStackID(ctx context.Context, name string) (string, error)
+	ListStacks(ctx context.Context, name string) ([]compose.Stack, error)
+	GetStackClusterID(ctx context.Context, stack string) (string, error)
+	GetServiceTaskDefinition(ctx context.Context, cluster string, serviceArns []string) (map[string]string, error)
+	ListStackServices(ctx context.Context, stack string) ([]string, error)
+	GetServiceTasks(ctx context.Context, cluster string, service string, stopped bool) ([]*ecs.Task, error)
+	GetTaskStoppedReason(ctx context.Context, cluster string, taskArn string) (string, error)
+	DescribeStackEvents(ctx context.Context, stackID string) ([]*cloudformation.StackEvent, error)
+	ListStackParameters(ctx context.Context, name string) (map[string]string, error)
+	ListStackResources(ctx context.Context, name string) (stackResources, error)
+	DeleteStack(ctx context.Context, name string) error
+	CreateSecret(ctx context.Context, secret secrets.Secret) (string, error)
+	InspectSecret(ctx context.Context, id string) (secrets.Secret, error)
+	ListSecrets(ctx context.Context) ([]secrets.Secret, error)
+	DeleteSecret(ctx context.Context, id string, recover bool) error
+	GetLogs(ctx context.Context, name string, consumer func(service, container, message string)) error
+	DescribeService(ctx context.Context, cluster string, arn string) (compose.ServiceStatus, error)
+	getURLWithPortMapping(ctx context.Context, targetGroupArns []string) ([]compose.PortPublisher, error)
+	ListTasks(ctx context.Context, cluster string, family string) ([]string, error)
+	GetPublicIPs(ctx context.Context, interfaces ...string) (map[string]string, error)
+	LoadBalancerType(ctx context.Context, arn string) (string, error)
+	GetLoadBalancerURL(ctx context.Context, arn string) (string, error)
+	WithVolumeSecurityGroups(ctx context.Context, id string, fn func(securityGroups []string) error) error
+	GetParameter(ctx context.Context, name string) (string, error)
+	SecurityGroupExists(ctx context.Context, sg string) (bool, error)
+	DeleteCapacityProvider(ctx context.Context, arn string) error
+	DeleteAutoscalingGroup(ctx context.Context, arn string) error
+}

+ 14 - 9
ecs/awsResources.go

@@ -82,7 +82,7 @@ func (b *ecsAPIService) parse(ctx context.Context, project *types.Project) (awsR
 func (b *ecsAPIService) parseClusterExtension(ctx context.Context, project *types.Project) (string, error) {
 func (b *ecsAPIService) parseClusterExtension(ctx context.Context, project *types.Project) (string, error) {
 	if x, ok := project.Extensions[extensionCluster]; ok {
 	if x, ok := project.Extensions[extensionCluster]; ok {
 		cluster := x.(string)
 		cluster := x.(string)
-		ok, err := b.SDK.ClusterExists(ctx, cluster)
+		ok, err := b.aws.ClusterExists(ctx, cluster)
 		if err != nil {
 		if err != nil {
 			return "", err
 			return "", err
 		}
 		}
@@ -98,20 +98,20 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
 	var vpc string
 	var vpc string
 	if x, ok := project.Extensions[extensionVPC]; ok {
 	if x, ok := project.Extensions[extensionVPC]; ok {
 		vpc = x.(string)
 		vpc = x.(string)
-		err := b.SDK.CheckVPC(ctx, vpc)
+		err := b.aws.CheckVPC(ctx, vpc)
 		if err != nil {
 		if err != nil {
 			return "", nil, err
 			return "", nil, err
 		}
 		}
 
 
 	} else {
 	} else {
-		defaultVPC, err := b.SDK.GetDefaultVPC(ctx)
+		defaultVPC, err := b.aws.GetDefaultVPC(ctx)
 		if err != nil {
 		if err != nil {
 			return "", nil, err
 			return "", nil, err
 		}
 		}
 		vpc = defaultVPC
 		vpc = defaultVPC
 	}
 	}
 
 
-	subNets, err := b.SDK.GetSubNets(ctx, vpc)
+	subNets, err := b.aws.GetSubNets(ctx, vpc)
 	if err != nil {
 	if err != nil {
 		return "", nil, err
 		return "", nil, err
 	}
 	}
@@ -124,7 +124,7 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
 func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (string, string, error) {
 func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (string, string, error) {
 	if x, ok := project.Extensions[extensionLoadBalancer]; ok {
 	if x, ok := project.Extensions[extensionLoadBalancer]; ok {
 		loadBalancer := x.(string)
 		loadBalancer := x.(string)
-		loadBalancerType, err := b.SDK.LoadBalancerType(ctx, loadBalancer)
+		loadBalancerType, err := b.aws.LoadBalancerType(ctx, loadBalancer)
 		if err != nil {
 		if err != nil {
 			return "", "", err
 			return "", "", err
 		}
 		}
@@ -142,16 +142,16 @@ func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project
 func (b *ecsAPIService) parseSecurityGroupExtension(ctx context.Context, project *types.Project) (map[string]string, error) {
 func (b *ecsAPIService) parseSecurityGroupExtension(ctx context.Context, project *types.Project) (map[string]string, error) {
 	securityGroups := make(map[string]string, len(project.Networks))
 	securityGroups := make(map[string]string, len(project.Networks))
 	for name, net := range project.Networks {
 	for name, net := range project.Networks {
-		var sg string
-		if net.External.External {
-			sg = net.Name
+		if !net.External.External {
+			continue
 		}
 		}
+		sg := net.Name
 		if x, ok := net.Extensions[extensionSecurityGroup]; ok {
 		if x, ok := net.Extensions[extensionSecurityGroup]; ok {
 			logrus.Warn("to use an existing security-group, use `network.external` and `network.name` in your compose file")
 			logrus.Warn("to use an existing security-group, use `network.external` and `network.name` in your compose file")
 			logrus.Debugf("Security Group for network %q set by user to %q", net.Name, x)
 			logrus.Debugf("Security Group for network %q set by user to %q", net.Name, x)
 			sg = x.(string)
 			sg = x.(string)
 		}
 		}
-		exists, err := b.SDK.SecurityGroupExists(ctx, sg)
+		exists, err := b.aws.SecurityGroupExists(ctx, sg)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -186,6 +186,11 @@ func (b *ecsAPIService) ensureNetworks(r *awsResources, project *types.Project,
 		r.securityGroups = make(map[string]string, len(project.Networks))
 		r.securityGroups = make(map[string]string, len(project.Networks))
 	}
 	}
 	for name, net := range project.Networks {
 	for name, net := range project.Networks {
+		if net.External.External {
+			r.securityGroups[name] = net.Name
+			continue
+		}
+
 		securityGroup := networkResourceName(name)
 		securityGroup := networkResourceName(name)
 		template.Resources[securityGroup] = &ec2.SecurityGroup{
 		template.Resources[securityGroup] = &ec2.SecurityGroup{
 			GroupDescription: fmt.Sprintf("%s Security Group for %s network", project.Name, name),
 			GroupDescription: fmt.Sprintf("%s Security Group for %s network", project.Name, name),

+ 617 - 0
ecs/aws_mock.go

@@ -0,0 +1,617 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/docker/compose-cli/ecs (interfaces: API)
+
+// Package ecs is a generated GoMock package.
+package ecs
+
+import (
+	context "context"
+	cloudformation "github.com/aws/aws-sdk-go/service/cloudformation"
+	ecs "github.com/aws/aws-sdk-go/service/ecs"
+	compose "github.com/docker/compose-cli/api/compose"
+	secrets "github.com/docker/compose-cli/api/secrets"
+	gomock "github.com/golang/mock/gomock"
+	reflect "reflect"
+)
+
+// MockAPI is a mock of API interface
+type MockAPI struct {
+	ctrl     *gomock.Controller
+	recorder *MockAPIMockRecorder
+}
+
+// MockAPIMockRecorder is the mock recorder for MockAPI
+type MockAPIMockRecorder struct {
+	mock *MockAPI
+}
+
+// NewMockAPI creates a new mock instance
+func NewMockAPI(ctrl *gomock.Controller) *MockAPI {
+	mock := &MockAPI{ctrl: ctrl}
+	mock.recorder = &MockAPIMockRecorder{mock}
+	return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use
+func (m *MockAPI) EXPECT() *MockAPIMockRecorder {
+	return m.recorder
+}
+
+// CheckRequirements mocks base method
+func (m *MockAPI) CheckRequirements(arg0 context.Context, arg1 string) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "CheckRequirements", arg0, arg1)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// CheckRequirements indicates an expected call of CheckRequirements
+func (mr *MockAPIMockRecorder) CheckRequirements(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckRequirements", reflect.TypeOf((*MockAPI)(nil).CheckRequirements), arg0, arg1)
+}
+
+// CheckVPC mocks base method
+func (m *MockAPI) CheckVPC(arg0 context.Context, arg1 string) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "CheckVPC", arg0, arg1)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// CheckVPC indicates an expected call of CheckVPC
+func (mr *MockAPIMockRecorder) CheckVPC(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckVPC", reflect.TypeOf((*MockAPI)(nil).CheckVPC), arg0, arg1)
+}
+
+// ClusterExists mocks base method
+func (m *MockAPI) ClusterExists(arg0 context.Context, arg1 string) (bool, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "ClusterExists", arg0, arg1)
+	ret0, _ := ret[0].(bool)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// ClusterExists indicates an expected call of ClusterExists
+func (mr *MockAPIMockRecorder) ClusterExists(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterExists", reflect.TypeOf((*MockAPI)(nil).ClusterExists), arg0, arg1)
+}
+
+// CreateChangeSet mocks base method
+func (m *MockAPI) CreateChangeSet(arg0 context.Context, arg1 string, arg2 []byte) (string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "CreateChangeSet", arg0, arg1, arg2)
+	ret0, _ := ret[0].(string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// CreateChangeSet indicates an expected call of CreateChangeSet
+func (mr *MockAPIMockRecorder) CreateChangeSet(arg0, arg1, arg2 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateChangeSet", reflect.TypeOf((*MockAPI)(nil).CreateChangeSet), arg0, arg1, arg2)
+}
+
+// CreateCluster mocks base method
+func (m *MockAPI) CreateCluster(arg0 context.Context, arg1 string) (string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "CreateCluster", arg0, arg1)
+	ret0, _ := ret[0].(string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// CreateCluster indicates an expected call of CreateCluster
+func (mr *MockAPIMockRecorder) CreateCluster(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCluster", reflect.TypeOf((*MockAPI)(nil).CreateCluster), arg0, arg1)
+}
+
+// CreateSecret mocks base method
+func (m *MockAPI) CreateSecret(arg0 context.Context, arg1 secrets.Secret) (string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "CreateSecret", arg0, arg1)
+	ret0, _ := ret[0].(string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// CreateSecret indicates an expected call of CreateSecret
+func (mr *MockAPIMockRecorder) CreateSecret(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSecret", reflect.TypeOf((*MockAPI)(nil).CreateSecret), arg0, arg1)
+}
+
+// CreateStack mocks base method
+func (m *MockAPI) CreateStack(arg0 context.Context, arg1 string, arg2 []byte) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "CreateStack", arg0, arg1, arg2)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// CreateStack indicates an expected call of CreateStack
+func (mr *MockAPIMockRecorder) CreateStack(arg0, arg1, arg2 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateStack", reflect.TypeOf((*MockAPI)(nil).CreateStack), arg0, arg1, arg2)
+}
+
+// DeleteAutoscalingGroup mocks base method
+func (m *MockAPI) DeleteAutoscalingGroup(arg0 context.Context, arg1 string) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "DeleteAutoscalingGroup", arg0, arg1)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// DeleteAutoscalingGroup indicates an expected call of DeleteAutoscalingGroup
+func (mr *MockAPIMockRecorder) DeleteAutoscalingGroup(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAutoscalingGroup", reflect.TypeOf((*MockAPI)(nil).DeleteAutoscalingGroup), arg0, arg1)
+}
+
+// DeleteCapacityProvider mocks base method
+func (m *MockAPI) DeleteCapacityProvider(arg0 context.Context, arg1 string) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "DeleteCapacityProvider", arg0, arg1)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// DeleteCapacityProvider indicates an expected call of DeleteCapacityProvider
+func (mr *MockAPIMockRecorder) DeleteCapacityProvider(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCapacityProvider", reflect.TypeOf((*MockAPI)(nil).DeleteCapacityProvider), arg0, arg1)
+}
+
+// DeleteSecret mocks base method
+func (m *MockAPI) DeleteSecret(arg0 context.Context, arg1 string, arg2 bool) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "DeleteSecret", arg0, arg1, arg2)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// DeleteSecret indicates an expected call of DeleteSecret
+func (mr *MockAPIMockRecorder) DeleteSecret(arg0, arg1, arg2 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSecret", reflect.TypeOf((*MockAPI)(nil).DeleteSecret), arg0, arg1, arg2)
+}
+
+// DeleteStack mocks base method
+func (m *MockAPI) DeleteStack(arg0 context.Context, arg1 string) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "DeleteStack", arg0, arg1)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// DeleteStack indicates an expected call of DeleteStack
+func (mr *MockAPIMockRecorder) DeleteStack(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStack", reflect.TypeOf((*MockAPI)(nil).DeleteStack), arg0, arg1)
+}
+
+// DescribeService mocks base method
+func (m *MockAPI) DescribeService(arg0 context.Context, arg1, arg2 string) (compose.ServiceStatus, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "DescribeService", arg0, arg1, arg2)
+	ret0, _ := ret[0].(compose.ServiceStatus)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// DescribeService indicates an expected call of DescribeService
+func (mr *MockAPIMockRecorder) DescribeService(arg0, arg1, arg2 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeService", reflect.TypeOf((*MockAPI)(nil).DescribeService), arg0, arg1, arg2)
+}
+
+// DescribeStackEvents mocks base method
+func (m *MockAPI) DescribeStackEvents(arg0 context.Context, arg1 string) ([]*cloudformation.StackEvent, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "DescribeStackEvents", arg0, arg1)
+	ret0, _ := ret[0].([]*cloudformation.StackEvent)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// DescribeStackEvents indicates an expected call of DescribeStackEvents
+func (mr *MockAPIMockRecorder) DescribeStackEvents(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeStackEvents", reflect.TypeOf((*MockAPI)(nil).DescribeStackEvents), arg0, arg1)
+}
+
+// GetDefaultVPC mocks base method
+func (m *MockAPI) GetDefaultVPC(arg0 context.Context) (string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetDefaultVPC", arg0)
+	ret0, _ := ret[0].(string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// GetDefaultVPC indicates an expected call of GetDefaultVPC
+func (mr *MockAPIMockRecorder) GetDefaultVPC(arg0 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultVPC", reflect.TypeOf((*MockAPI)(nil).GetDefaultVPC), arg0)
+}
+
+// GetLoadBalancerURL mocks base method
+func (m *MockAPI) GetLoadBalancerURL(arg0 context.Context, arg1 string) (string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetLoadBalancerURL", arg0, arg1)
+	ret0, _ := ret[0].(string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// GetLoadBalancerURL indicates an expected call of GetLoadBalancerURL
+func (mr *MockAPIMockRecorder) GetLoadBalancerURL(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLoadBalancerURL", reflect.TypeOf((*MockAPI)(nil).GetLoadBalancerURL), arg0, arg1)
+}
+
+// GetLogs mocks base method
+func (m *MockAPI) GetLogs(arg0 context.Context, arg1 string, arg2 func(string, string, string)) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetLogs", arg0, arg1, arg2)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// GetLogs indicates an expected call of GetLogs
+func (mr *MockAPIMockRecorder) GetLogs(arg0, arg1, arg2 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogs", reflect.TypeOf((*MockAPI)(nil).GetLogs), arg0, arg1, arg2)
+}
+
+// GetParameter mocks base method
+func (m *MockAPI) GetParameter(arg0 context.Context, arg1 string) (string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetParameter", arg0, arg1)
+	ret0, _ := ret[0].(string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// GetParameter indicates an expected call of GetParameter
+func (mr *MockAPIMockRecorder) GetParameter(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetParameter", reflect.TypeOf((*MockAPI)(nil).GetParameter), arg0, arg1)
+}
+
+// GetPublicIPs mocks base method
+func (m *MockAPI) GetPublicIPs(arg0 context.Context, arg1 ...string) (map[string]string, error) {
+	m.ctrl.T.Helper()
+	varargs := []interface{}{arg0}
+	for _, a := range arg1 {
+		varargs = append(varargs, a)
+	}
+	ret := m.ctrl.Call(m, "GetPublicIPs", varargs...)
+	ret0, _ := ret[0].(map[string]string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// GetPublicIPs indicates an expected call of GetPublicIPs
+func (mr *MockAPIMockRecorder) GetPublicIPs(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	varargs := append([]interface{}{arg0}, arg1...)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicIPs", reflect.TypeOf((*MockAPI)(nil).GetPublicIPs), varargs...)
+}
+
+// GetRoleArn mocks base method
+func (m *MockAPI) GetRoleArn(arg0 context.Context, arg1 string) (string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetRoleArn", arg0, arg1)
+	ret0, _ := ret[0].(string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// GetRoleArn indicates an expected call of GetRoleArn
+func (mr *MockAPIMockRecorder) GetRoleArn(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoleArn", reflect.TypeOf((*MockAPI)(nil).GetRoleArn), arg0, arg1)
+}
+
+// GetServiceTaskDefinition mocks base method
+func (m *MockAPI) GetServiceTaskDefinition(arg0 context.Context, arg1 string, arg2 []string) (map[string]string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetServiceTaskDefinition", arg0, arg1, arg2)
+	ret0, _ := ret[0].(map[string]string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// GetServiceTaskDefinition indicates an expected call of GetServiceTaskDefinition
+func (mr *MockAPIMockRecorder) GetServiceTaskDefinition(arg0, arg1, arg2 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTaskDefinition", reflect.TypeOf((*MockAPI)(nil).GetServiceTaskDefinition), arg0, arg1, arg2)
+}
+
+// GetServiceTasks mocks base method
+func (m *MockAPI) GetServiceTasks(arg0 context.Context, arg1, arg2 string, arg3 bool) ([]*ecs.Task, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetServiceTasks", arg0, arg1, arg2, arg3)
+	ret0, _ := ret[0].([]*ecs.Task)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// GetServiceTasks indicates an expected call of GetServiceTasks
+func (mr *MockAPIMockRecorder) GetServiceTasks(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTasks", reflect.TypeOf((*MockAPI)(nil).GetServiceTasks), arg0, arg1, arg2, arg3)
+}
+
+// GetStackClusterID mocks base method
+func (m *MockAPI) GetStackClusterID(arg0 context.Context, arg1 string) (string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetStackClusterID", arg0, arg1)
+	ret0, _ := ret[0].(string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// GetStackClusterID indicates an expected call of GetStackClusterID
+func (mr *MockAPIMockRecorder) GetStackClusterID(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStackClusterID", reflect.TypeOf((*MockAPI)(nil).GetStackClusterID), arg0, arg1)
+}
+
+// GetStackID mocks base method
+func (m *MockAPI) GetStackID(arg0 context.Context, arg1 string) (string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetStackID", arg0, arg1)
+	ret0, _ := ret[0].(string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// GetStackID indicates an expected call of GetStackID
+func (mr *MockAPIMockRecorder) GetStackID(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStackID", reflect.TypeOf((*MockAPI)(nil).GetStackID), arg0, arg1)
+}
+
+// GetSubNets mocks base method
+func (m *MockAPI) GetSubNets(arg0 context.Context, arg1 string) ([]string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetSubNets", arg0, arg1)
+	ret0, _ := ret[0].([]string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// GetSubNets indicates an expected call of GetSubNets
+func (mr *MockAPIMockRecorder) GetSubNets(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubNets", reflect.TypeOf((*MockAPI)(nil).GetSubNets), arg0, arg1)
+}
+
+// GetTaskStoppedReason mocks base method
+func (m *MockAPI) GetTaskStoppedReason(arg0 context.Context, arg1, arg2 string) (string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetTaskStoppedReason", arg0, arg1, arg2)
+	ret0, _ := ret[0].(string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// GetTaskStoppedReason indicates an expected call of GetTaskStoppedReason
+func (mr *MockAPIMockRecorder) GetTaskStoppedReason(arg0, arg1, arg2 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskStoppedReason", reflect.TypeOf((*MockAPI)(nil).GetTaskStoppedReason), arg0, arg1, arg2)
+}
+
+// InspectSecret mocks base method
+func (m *MockAPI) InspectSecret(arg0 context.Context, arg1 string) (secrets.Secret, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "InspectSecret", arg0, arg1)
+	ret0, _ := ret[0].(secrets.Secret)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// InspectSecret indicates an expected call of InspectSecret
+func (mr *MockAPIMockRecorder) InspectSecret(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InspectSecret", reflect.TypeOf((*MockAPI)(nil).InspectSecret), arg0, arg1)
+}
+
+// ListSecrets mocks base method
+func (m *MockAPI) ListSecrets(arg0 context.Context) ([]secrets.Secret, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "ListSecrets", arg0)
+	ret0, _ := ret[0].([]secrets.Secret)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// ListSecrets indicates an expected call of ListSecrets
+func (mr *MockAPIMockRecorder) ListSecrets(arg0 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListSecrets", reflect.TypeOf((*MockAPI)(nil).ListSecrets), arg0)
+}
+
+// ListStackParameters mocks base method
+func (m *MockAPI) ListStackParameters(arg0 context.Context, arg1 string) (map[string]string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "ListStackParameters", arg0, arg1)
+	ret0, _ := ret[0].(map[string]string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// ListStackParameters indicates an expected call of ListStackParameters
+func (mr *MockAPIMockRecorder) ListStackParameters(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListStackParameters", reflect.TypeOf((*MockAPI)(nil).ListStackParameters), arg0, arg1)
+}
+
+// ListStackResources mocks base method
+func (m *MockAPI) ListStackResources(arg0 context.Context, arg1 string) (stackResources, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "ListStackResources", arg0, arg1)
+	ret0, _ := ret[0].(stackResources)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// ListStackResources indicates an expected call of ListStackResources
+func (mr *MockAPIMockRecorder) ListStackResources(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListStackResources", reflect.TypeOf((*MockAPI)(nil).ListStackResources), arg0, arg1)
+}
+
+// ListStackServices mocks base method
+func (m *MockAPI) ListStackServices(arg0 context.Context, arg1 string) ([]string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "ListStackServices", arg0, arg1)
+	ret0, _ := ret[0].([]string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// ListStackServices indicates an expected call of ListStackServices
+func (mr *MockAPIMockRecorder) ListStackServices(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListStackServices", reflect.TypeOf((*MockAPI)(nil).ListStackServices), arg0, arg1)
+}
+
+// ListStacks mocks base method
+func (m *MockAPI) ListStacks(arg0 context.Context, arg1 string) ([]compose.Stack, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "ListStacks", arg0, arg1)
+	ret0, _ := ret[0].([]compose.Stack)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// ListStacks indicates an expected call of ListStacks
+func (mr *MockAPIMockRecorder) ListStacks(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListStacks", reflect.TypeOf((*MockAPI)(nil).ListStacks), arg0, arg1)
+}
+
+// ListTasks mocks base method
+func (m *MockAPI) ListTasks(arg0 context.Context, arg1, arg2 string) ([]string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "ListTasks", arg0, arg1, arg2)
+	ret0, _ := ret[0].([]string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// ListTasks indicates an expected call of ListTasks
+func (mr *MockAPIMockRecorder) ListTasks(arg0, arg1, arg2 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTasks", reflect.TypeOf((*MockAPI)(nil).ListTasks), arg0, arg1, arg2)
+}
+
+// LoadBalancerType mocks base method
+func (m *MockAPI) LoadBalancerType(arg0 context.Context, arg1 string) (string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "LoadBalancerType", arg0, arg1)
+	ret0, _ := ret[0].(string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// LoadBalancerType indicates an expected call of LoadBalancerType
+func (mr *MockAPIMockRecorder) LoadBalancerType(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadBalancerType", reflect.TypeOf((*MockAPI)(nil).LoadBalancerType), arg0, arg1)
+}
+
+// SecurityGroupExists mocks base method
+func (m *MockAPI) SecurityGroupExists(arg0 context.Context, arg1 string) (bool, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "SecurityGroupExists", arg0, arg1)
+	ret0, _ := ret[0].(bool)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// SecurityGroupExists indicates an expected call of SecurityGroupExists
+func (mr *MockAPIMockRecorder) SecurityGroupExists(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SecurityGroupExists", reflect.TypeOf((*MockAPI)(nil).SecurityGroupExists), arg0, arg1)
+}
+
+// StackExists mocks base method
+func (m *MockAPI) StackExists(arg0 context.Context, arg1 string) (bool, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "StackExists", arg0, arg1)
+	ret0, _ := ret[0].(bool)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// StackExists indicates an expected call of StackExists
+func (mr *MockAPIMockRecorder) StackExists(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StackExists", reflect.TypeOf((*MockAPI)(nil).StackExists), arg0, arg1)
+}
+
+// UpdateStack mocks base method
+func (m *MockAPI) UpdateStack(arg0 context.Context, arg1 string) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "UpdateStack", arg0, arg1)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// UpdateStack indicates an expected call of UpdateStack
+func (mr *MockAPIMockRecorder) UpdateStack(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateStack", reflect.TypeOf((*MockAPI)(nil).UpdateStack), arg0, arg1)
+}
+
+// WaitStackComplete mocks base method
+func (m *MockAPI) WaitStackComplete(arg0 context.Context, arg1 string, arg2 int) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "WaitStackComplete", arg0, arg1, arg2)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// WaitStackComplete indicates an expected call of WaitStackComplete
+func (mr *MockAPIMockRecorder) WaitStackComplete(arg0, arg1, arg2 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitStackComplete", reflect.TypeOf((*MockAPI)(nil).WaitStackComplete), arg0, arg1, arg2)
+}
+
+// WithVolumeSecurityGroups mocks base method
+func (m *MockAPI) WithVolumeSecurityGroups(arg0 context.Context, arg1 string, arg2 func([]string) error) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "WithVolumeSecurityGroups", arg0, arg1, arg2)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// WithVolumeSecurityGroups indicates an expected call of WithVolumeSecurityGroups
+func (mr *MockAPIMockRecorder) WithVolumeSecurityGroups(arg0, arg1, arg2 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithVolumeSecurityGroups", reflect.TypeOf((*MockAPI)(nil).WithVolumeSecurityGroups), arg0, arg1, arg2)
+}
+
+// getURLWithPortMapping mocks base method
+func (m *MockAPI) getURLWithPortMapping(arg0 context.Context, arg1 []string) ([]compose.PortPublisher, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "getURLWithPortMapping", arg0, arg1)
+	ret0, _ := ret[0].([]compose.PortPublisher)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// getURLWithPortMapping indicates an expected call of getURLWithPortMapping
+func (mr *MockAPIMockRecorder) getURLWithPortMapping(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getURLWithPortMapping", reflect.TypeOf((*MockAPI)(nil).getURLWithPortMapping), arg0, arg1)
+}

+ 2 - 2
ecs/backend.go

@@ -77,14 +77,14 @@ func getEcsAPIService(ecsCtx store.EcsContext) (*ecsAPIService, error) {
 	return &ecsAPIService{
 	return &ecsAPIService{
 		ctx:    ecsCtx,
 		ctx:    ecsCtx,
 		Region: ecsCtx.Region,
 		Region: ecsCtx.Region,
-		SDK:    sdk,
+		aws:    sdk,
 	}, nil
 	}, nil
 }
 }
 
 
 type ecsAPIService struct {
 type ecsAPIService struct {
 	ctx    store.EcsContext
 	ctx    store.EcsContext
 	Region string
 	Region string
-	SDK    sdk
+	aws    API
 }
 }
 
 
 func (a *ecsAPIService) ContainerService() containers.Service {
 func (a *ecsAPIService) ContainerService() containers.Service {

+ 113 - 105
ecs/cloudformation.go

@@ -38,25 +38,53 @@ import (
 )
 )
 
 
 func (b *ecsAPIService) Convert(ctx context.Context, project *types.Project) ([]byte, error) {
 func (b *ecsAPIService) Convert(ctx context.Context, project *types.Project) ([]byte, error) {
-	err := b.checkCompatibility(project)
+	template, err := b.convert(ctx, project)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	resources, err := b.parse(ctx, project)
+	return marshall(template)
+}
+
+func (b *ecsAPIService) convert(ctx context.Context, project *types.Project) (*cloudformation.Template, error) {
+	err := b.checkCompatibility(project)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	template, err := b.convert(project, resources)
+	resources, err := b.parse(ctx, project)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
+	template := cloudformation.NewTemplate()
+	b.ensureResources(&resources, project, template)
+
+	for name, secret := range project.Secrets {
+		err := b.createSecret(project, name, secret, template)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	b.createLogGroup(project, template)
+
+	// Private DNS namespace will allow DNS name for the services to be <service>.<project>.local
+	b.createCloudMap(project, template, resources.vpc)
+
+	for _, service := range project.Services {
+		err := b.createService(project, service, template, resources)
+		if err != nil {
+			return nil, err
+		}
+
+		b.createAutoscalingPolicy(project, resources, template, service)
+	}
+
 	// Create a NFS inbound rule on each mount target for volumes
 	// Create a NFS inbound rule on each mount target for volumes
 	// as "source security group" use an arbitrary network attached to service(s) who mounts target volume
 	// as "source security group" use an arbitrary network attached to service(s) who mounts target volume
 	for n, vol := range project.Volumes {
 	for n, vol := range project.Volumes {
-		err := b.SDK.WithVolumeSecurityGroups(ctx, vol.Name, func(securityGroups []string) error {
+		err := b.aws.WithVolumeSecurityGroups(ctx, vol.Name, func(securityGroups []string) error {
 			return b.createNFSmountIngress(securityGroups, project, n, template)
 			return b.createNFSmountIngress(securityGroups, project, n, template)
 		})
 		})
 		if err != nil {
 		if err != nil {
@@ -69,124 +97,104 @@ func (b *ecsAPIService) Convert(ctx context.Context, project *types.Project) ([]
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	return marshall(template)
+	return template, nil
 }
 }
 
 
-// Convert a compose project into a CloudFormation template
-func (b *ecsAPIService) convert(project *types.Project, resources awsResources) (*cloudformation.Template, error) {
-	template := cloudformation.NewTemplate()
-	b.ensureResources(&resources, project, template)
+func (b *ecsAPIService) createService(project *types.Project, service types.ServiceConfig, template *cloudformation.Template, resources awsResources) error {
+	taskExecutionRole := b.createTaskExecutionRole(project, service, template)
+	taskRole := b.createTaskRole(project, service, template)
 
 
-	for name, secret := range project.Secrets {
-		err := b.createSecret(project, name, secret, template)
-		if err != nil {
-			return nil, err
-		}
+	definition, err := b.createTaskDefinition(project, service)
+	if err != nil {
+		return err
+	}
+	definition.ExecutionRoleArn = cloudformation.Ref(taskExecutionRole)
+	if taskRole != "" {
+		definition.TaskRoleArn = cloudformation.Ref(taskRole)
 	}
 	}
 
 
-	b.createLogGroup(project, template)
-
-	// Private DNS namespace will allow DNS name for the services to be <service>.<project>.local
-	b.createCloudMap(project, template, resources.vpc)
+	taskDefinition := fmt.Sprintf("%sTaskDefinition", normalizeResourceName(service.Name))
+	template.Resources[taskDefinition] = definition
 
 
-	for _, service := range project.Services {
-		taskExecutionRole := b.createTaskExecutionRole(project, service, template)
-		taskRole := b.createTaskRole(project, service, template)
+	var healthCheck *cloudmap.Service_HealthCheckConfig
+	serviceRegistry := b.createServiceRegistry(service, template, healthCheck)
 
 
-		definition, err := b.createTaskDefinition(project, service)
-		if err != nil {
-			return nil, err
-		}
-		definition.ExecutionRoleArn = cloudformation.Ref(taskExecutionRole)
-		if taskRole != "" {
-			definition.TaskRoleArn = cloudformation.Ref(taskRole)
+	var (
+		dependsOn []string
+		serviceLB []ecs.Service_LoadBalancer
+	)
+	for _, port := range service.Ports {
+		for net := range service.Networks {
+			b.createIngress(service, net, port, template, resources)
 		}
 		}
 
 
-		taskDefinition := fmt.Sprintf("%sTaskDefinition", normalizeResourceName(service.Name))
-		template.Resources[taskDefinition] = definition
-
-		var healthCheck *cloudmap.Service_HealthCheckConfig
-		serviceRegistry := b.createServiceRegistry(service, template, healthCheck)
-
-		var (
-			dependsOn []string
-			serviceLB []ecs.Service_LoadBalancer
-		)
-		for _, port := range service.Ports {
-			for net := range service.Networks {
-				b.createIngress(service, net, port, template, resources)
-			}
-
-			protocol := strings.ToUpper(port.Protocol)
-			if resources.loadBalancerType == elbv2.LoadBalancerTypeEnumApplication {
-				// we don't set Https as a certificate must be specified for HTTPS listeners
-				protocol = elbv2.ProtocolEnumHttp
-			}
-			targetGroupName := b.createTargetGroup(project, service, port, template, protocol, resources.vpc)
-			listenerName := b.createListener(service, port, template, targetGroupName, resources.loadBalancer, protocol)
-			dependsOn = append(dependsOn, listenerName)
-			serviceLB = append(serviceLB, ecs.Service_LoadBalancer{
-				ContainerName:  service.Name,
-				ContainerPort:  int(port.Target),
-				TargetGroupArn: cloudformation.Ref(targetGroupName),
-			})
+		protocol := strings.ToUpper(port.Protocol)
+		if resources.loadBalancerType == elbv2.LoadBalancerTypeEnumApplication {
+			// we don't set Https as a certificate must be specified for HTTPS listeners
+			protocol = elbv2.ProtocolEnumHttp
 		}
 		}
+		targetGroupName := b.createTargetGroup(project, service, port, template, protocol, resources.vpc)
+		listenerName := b.createListener(service, port, template, targetGroupName, resources.loadBalancer, protocol)
+		dependsOn = append(dependsOn, listenerName)
+		serviceLB = append(serviceLB, ecs.Service_LoadBalancer{
+			ContainerName:  service.Name,
+			ContainerPort:  int(port.Target),
+			TargetGroupArn: cloudformation.Ref(targetGroupName),
+		})
+	}
 
 
-		desiredCount := 1
-		if service.Deploy != nil && service.Deploy.Replicas != nil {
-			desiredCount = int(*service.Deploy.Replicas)
-		}
+	desiredCount := 1
+	if service.Deploy != nil && service.Deploy.Replicas != nil {
+		desiredCount = int(*service.Deploy.Replicas)
+	}
 
 
-		for dependency := range service.DependsOn {
-			dependsOn = append(dependsOn, serviceResourceName(dependency))
-		}
+	for dependency := range service.DependsOn {
+		dependsOn = append(dependsOn, serviceResourceName(dependency))
+	}
 
 
-		minPercent, maxPercent, err := computeRollingUpdateLimits(service)
-		if err != nil {
-			return nil, err
-		}
+	minPercent, maxPercent, err := computeRollingUpdateLimits(service)
+	if err != nil {
+		return err
+	}
 
 
-		assignPublicIP := ecsapi.AssignPublicIpEnabled
-		launchType := ecsapi.LaunchTypeFargate
-		platformVersion := "1.4.0" // LATEST which is set to 1.3.0 (?) which doesn’t allow efs volumes.
-		if requireEC2(service) {
-			assignPublicIP = ecsapi.AssignPublicIpDisabled
-			launchType = ecsapi.LaunchTypeEc2
-			platformVersion = "" // The platform version must be null when specifying an EC2 launch type
-		}
+	assignPublicIP := ecsapi.AssignPublicIpEnabled
+	launchType := ecsapi.LaunchTypeFargate
+	platformVersion := "1.4.0" // LATEST which is set to 1.3.0 (?) which doesn’t allow efs volumes.
+	if requireEC2(service) {
+		assignPublicIP = ecsapi.AssignPublicIpDisabled
+		launchType = ecsapi.LaunchTypeEc2
+		platformVersion = "" // The platform version must be null when specifying an EC2 launch type
+	}
 
 
-		template.Resources[serviceResourceName(service.Name)] = &ecs.Service{
-			AWSCloudFormationDependsOn: dependsOn,
-			Cluster:                    resources.cluster,
-			DesiredCount:               desiredCount,
-			DeploymentController: &ecs.Service_DeploymentController{
-				Type: ecsapi.DeploymentControllerTypeEcs,
-			},
-			DeploymentConfiguration: &ecs.Service_DeploymentConfiguration{
-				MaximumPercent:        maxPercent,
-				MinimumHealthyPercent: minPercent,
-			},
-			LaunchType: launchType,
-			// TODO we miss support for https://github.com/aws/containers-roadmap/issues/631 to select a capacity provider
-			LoadBalancers: serviceLB,
-			NetworkConfiguration: &ecs.Service_NetworkConfiguration{
-				AwsvpcConfiguration: &ecs.Service_AwsVpcConfiguration{
-					AssignPublicIp: assignPublicIP,
-					SecurityGroups: resources.serviceSecurityGroups(service),
-					Subnets:        resources.subnets,
-				},
+	template.Resources[serviceResourceName(service.Name)] = &ecs.Service{
+		AWSCloudFormationDependsOn: dependsOn,
+		Cluster:                    resources.cluster,
+		DesiredCount:               desiredCount,
+		DeploymentController: &ecs.Service_DeploymentController{
+			Type: ecsapi.DeploymentControllerTypeEcs,
+		},
+		DeploymentConfiguration: &ecs.Service_DeploymentConfiguration{
+			MaximumPercent:        maxPercent,
+			MinimumHealthyPercent: minPercent,
+		},
+		LaunchType: launchType,
+		// TODO we miss support for https://github.com/aws/containers-roadmap/issues/631 to select a capacity provider
+		LoadBalancers: serviceLB,
+		NetworkConfiguration: &ecs.Service_NetworkConfiguration{
+			AwsvpcConfiguration: &ecs.Service_AwsVpcConfiguration{
+				AssignPublicIp: assignPublicIP,
+				SecurityGroups: resources.serviceSecurityGroups(service),
+				Subnets:        resources.subnets,
 			},
 			},
-			PlatformVersion:    platformVersion,
-			PropagateTags:      ecsapi.PropagateTagsService,
-			SchedulingStrategy: ecsapi.SchedulingStrategyReplica,
-			ServiceRegistries:  []ecs.Service_ServiceRegistry{serviceRegistry},
-			Tags:               serviceTags(project, service),
-			TaskDefinition:     cloudformation.Ref(normalizeResourceName(taskDefinition)),
-		}
-
-		b.createAutoscalingPolicy(project, resources, template, service)
+		},
+		PlatformVersion:    platformVersion,
+		PropagateTags:      ecsapi.PropagateTagsService,
+		SchedulingStrategy: ecsapi.SchedulingStrategyReplica,
+		ServiceRegistries:  []ecs.Service_ServiceRegistry{serviceRegistry},
+		Tags:               serviceTags(project, service),
+		TaskDefinition:     cloudformation.Ref(normalizeResourceName(taskDefinition)),
 	}
 	}
-	return template, nil
+	return nil
 }
 }
 
 
 const allProtocols = "-1"
 const allProtocols = "-1"

+ 69 - 61
ecs/cloudformation_test.go

@@ -17,10 +17,14 @@
 package ecs
 package ecs
 
 
 import (
 import (
+	"context"
 	"fmt"
 	"fmt"
+	"io/ioutil"
 	"reflect"
 	"reflect"
 	"testing"
 	"testing"
 
 
+	"github.com/golang/mock/gomock"
+
 	"github.com/docker/compose-cli/api/compose"
 	"github.com/docker/compose-cli/api/compose"
 
 
 	"github.com/aws/aws-sdk-go/service/elbv2"
 	"github.com/aws/aws-sdk-go/service/elbv2"
@@ -30,7 +34,6 @@ import (
 	"github.com/awslabs/goformation/v4/cloudformation/elasticloadbalancingv2"
 	"github.com/awslabs/goformation/v4/cloudformation/elasticloadbalancingv2"
 	"github.com/awslabs/goformation/v4/cloudformation/iam"
 	"github.com/awslabs/goformation/v4/cloudformation/iam"
 	"github.com/awslabs/goformation/v4/cloudformation/logs"
 	"github.com/awslabs/goformation/v4/cloudformation/logs"
-	"github.com/compose-spec/compose-go/cli"
 	"github.com/compose-spec/compose-go/loader"
 	"github.com/compose-spec/compose-go/loader"
 	"github.com/compose-spec/compose-go/types"
 	"github.com/compose-spec/compose-go/types"
 	"gotest.tools/v3/assert"
 	"gotest.tools/v3/assert"
@@ -38,8 +41,12 @@ import (
 )
 )
 
 
 func TestSimpleConvert(t *testing.T) {
 func TestSimpleConvert(t *testing.T) {
-	project := load(t, "testdata/input/simple-single-service.yaml")
-	result := convertResultAsString(t, project)
+	bytes, err := ioutil.ReadFile("testdata/input/simple-single-service.yaml")
+	assert.NilError(t, err)
+	template := convertYaml(t, string(bytes), useDefaultVPC)
+	resultAsJSON, err := marshall(template)
+	assert.NilError(t, err)
+	result := fmt.Sprintf("%s\n", string(resultAsJSON))
 	expected := "simple/simple-cloudformation-conversion.golden"
 	expected := "simple/simple-cloudformation-conversion.golden"
 	golden.Assert(t, result, expected)
 	golden.Assert(t, result, expected)
 }
 }
@@ -54,7 +61,7 @@ services:
         awslogs-datetime-pattern: "FOO"
         awslogs-datetime-pattern: "FOO"
 
 
 x-aws-logs_retention: 10
 x-aws-logs_retention: 10
-`)
+`, useDefaultVPC)
 	def := template.Resources["FooTaskDefinition"].(*ecs.TaskDefinition)
 	def := template.Resources["FooTaskDefinition"].(*ecs.TaskDefinition)
 	logging := getMainContainer(def, t).LogConfiguration
 	logging := getMainContainer(def, t).LogConfiguration
 	if logging != nil {
 	if logging != nil {
@@ -74,7 +81,7 @@ services:
     image: hello_world
     image: hello_world
     env_file:
     env_file:
       - testdata/input/envfile
       - testdata/input/envfile
-`)
+`, useDefaultVPC)
 	def := template.Resources["FooTaskDefinition"].(*ecs.TaskDefinition)
 	def := template.Resources["FooTaskDefinition"].(*ecs.TaskDefinition)
 	env := getMainContainer(def, t).Environment
 	env := getMainContainer(def, t).Environment
 	var found bool
 	var found bool
@@ -96,7 +103,7 @@ services:
       - testdata/input/envfile
       - testdata/input/envfile
     environment:
     environment:
       - "FOO=ZOT"
       - "FOO=ZOT"
-`)
+`, useDefaultVPC)
 	def := template.Resources["FooTaskDefinition"].(*ecs.TaskDefinition)
 	def := template.Resources["FooTaskDefinition"].(*ecs.TaskDefinition)
 	env := getMainContainer(def, t).Environment
 	env := getMainContainer(def, t).Environment
 	var found bool
 	var found bool
@@ -118,7 +125,7 @@ services:
       replicas: 4 
       replicas: 4 
       update_config:
       update_config:
         parallelism: 2
         parallelism: 2
-`)
+`, useDefaultVPC)
 	service := template.Resources["FooService"].(*ecs.Service)
 	service := template.Resources["FooService"].(*ecs.Service)
 	assert.Check(t, service.DeploymentConfiguration.MaximumPercent == 150)
 	assert.Check(t, service.DeploymentConfiguration.MaximumPercent == 150)
 	assert.Check(t, service.DeploymentConfiguration.MinimumHealthyPercent == 50)
 	assert.Check(t, service.DeploymentConfiguration.MinimumHealthyPercent == 50)
@@ -133,7 +140,7 @@ services:
       update_config:
       update_config:
         x-aws-min_percent: 25
         x-aws-min_percent: 25
         x-aws-max_percent: 125
         x-aws-max_percent: 125
-`)
+`, useDefaultVPC)
 	service := template.Resources["FooService"].(*ecs.Service)
 	service := template.Resources["FooService"].(*ecs.Service)
 	assert.Check(t, service.DeploymentConfiguration.MaximumPercent == 125)
 	assert.Check(t, service.DeploymentConfiguration.MaximumPercent == 125)
 	assert.Check(t, service.DeploymentConfiguration.MinimumHealthyPercent == 25)
 	assert.Check(t, service.DeploymentConfiguration.MinimumHealthyPercent == 25)
@@ -145,7 +152,7 @@ services:
   foo:
   foo:
     image: hello_world
     image: hello_world
     x-aws-pull_credentials: "secret"
     x-aws-pull_credentials: "secret"
-`)
+`, useDefaultVPC)
 	x := template.Resources["FooTaskExecutionRole"]
 	x := template.Resources["FooTaskExecutionRole"]
 	assert.Check(t, x != nil)
 	assert.Check(t, x != nil)
 	role := *(x.(*iam.Role))
 	role := *(x.(*iam.Role))
@@ -173,7 +180,7 @@ networks:
     name: public
     name: public
   back-tier:
   back-tier:
     internal: true
     internal: true
-`)
+`, useDefaultVPC)
 	assert.Check(t, template.Resources["FronttierNetwork"] != nil)
 	assert.Check(t, template.Resources["FronttierNetwork"] != nil)
 	assert.Check(t, template.Resources["BacktierNetwork"] != nil)
 	assert.Check(t, template.Resources["BacktierNetwork"] != nil)
 	assert.Check(t, template.Resources["BacktierNetworkIngress"] != nil)
 	assert.Check(t, template.Resources["BacktierNetworkIngress"] != nil)
@@ -201,7 +208,7 @@ func TestLoadBalancerTypeApplication(t *testing.T) {
 `,
 `,
 	}
 	}
 	for _, y := range cases {
 	for _, y := range cases {
-		template := convertYaml(t, y)
+		template := convertYaml(t, y, useDefaultVPC)
 		lb := template.Resources["LoadBalancer"]
 		lb := template.Resources["LoadBalancer"]
 		assert.Check(t, lb != nil)
 		assert.Check(t, lb != nil)
 		loadBalancer := *lb.(*elasticloadbalancingv2.LoadBalancer)
 		loadBalancer := *lb.(*elasticloadbalancingv2.LoadBalancer)
@@ -218,7 +225,7 @@ services:
     image: nginx
     image: nginx
   foo:
   foo:
     image: bar
     image: bar
-`)
+`, useDefaultVPC)
 	for _, r := range template.Resources {
 	for _, r := range template.Resources {
 		assert.Check(t, r.AWSCloudFormationType() != "AWS::ElasticLoadBalancingV2::TargetGroup")
 		assert.Check(t, r.AWSCloudFormationType() != "AWS::ElasticLoadBalancingV2::TargetGroup")
 		assert.Check(t, r.AWSCloudFormationType() != "AWS::ElasticLoadBalancingV2::Listener")
 		assert.Check(t, r.AWSCloudFormationType() != "AWS::ElasticLoadBalancingV2::Listener")
@@ -233,7 +240,7 @@ services:
     image: nginx
     image: nginx
     deploy:
     deploy:
       replicas: 10
       replicas: 10
-`)
+`, useDefaultVPC)
 	s := template.Resources["TestService"]
 	s := template.Resources["TestService"]
 	assert.Check(t, s != nil)
 	assert.Check(t, s != nil)
 	service := *s.(*ecs.Service)
 	service := *s.(*ecs.Service)
@@ -245,7 +252,7 @@ func TestTaskSizeConvert(t *testing.T) {
 services:
 services:
   test:
   test:
     image: nginx
     image: nginx
-`)
+`, useDefaultVPC)
 	def := template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
 	def := template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
 	assert.Equal(t, def.Cpu, "256")
 	assert.Equal(t, def.Cpu, "256")
 	assert.Equal(t, def.Memory, "512")
 	assert.Equal(t, def.Memory, "512")
@@ -259,7 +266,7 @@ services:
         limits:
         limits:
           cpus: '0.5'
           cpus: '0.5'
           memory: 2048M
           memory: 2048M
-`)
+`, useDefaultVPC)
 	def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
 	def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
 	assert.Equal(t, def.Cpu, "512")
 	assert.Equal(t, def.Cpu, "512")
 	assert.Equal(t, def.Memory, "2048")
 	assert.Equal(t, def.Memory, "2048")
@@ -273,7 +280,7 @@ services:
         limits:
         limits:
           cpus: '4'
           cpus: '4'
           memory: 8192M
           memory: 8192M
-`)
+`, useDefaultVPC)
 	def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
 	def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
 	assert.Equal(t, def.Cpu, "4096")
 	assert.Equal(t, def.Cpu, "4096")
 	assert.Equal(t, def.Memory, "8192")
 	assert.Equal(t, def.Memory, "8192")
@@ -292,7 +299,7 @@ services:
             - discrete_resource_spec:
             - discrete_resource_spec:
                 kind: gpus
                 kind: gpus
                 value: 2
                 value: 2
-`)
+`, useDefaultVPC, useGPU)
 	def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
 	def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
 	assert.Equal(t, def.Cpu, "4000")
 	assert.Equal(t, def.Cpu, "4000")
 	assert.Equal(t, def.Memory, "792")
 	assert.Equal(t, def.Memory, "792")
@@ -308,26 +315,11 @@ services:
             - discrete_resource_spec:
             - discrete_resource_spec:
                 kind: gpus
                 kind: gpus
                 value: 2
                 value: 2
-`)
+`, useDefaultVPC, useGPU)
 	def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
 	def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
 	assert.Equal(t, def.Cpu, "")
 	assert.Equal(t, def.Cpu, "")
 	assert.Equal(t, def.Memory, "")
 	assert.Equal(t, def.Memory, "")
 }
 }
-func TestTaskSizeConvertFailure(t *testing.T) {
-	model := loadConfig(t, `
-services:
-  test:
-    image: nginx
-    deploy:
-      resources:
-        limits:
-          cpus: '0.5'
-          memory: 2043248M
-`)
-	backend := &ecsAPIService{}
-	_, err := backend.convert(model, awsResources{})
-	assert.ErrorContains(t, err, "the resources requested are not supported by ECS/Fargate")
-}
 
 
 func TestLoadBalancerTypeNetwork(t *testing.T) {
 func TestLoadBalancerTypeNetwork(t *testing.T) {
 	template := convertYaml(t, `
 	template := convertYaml(t, `
@@ -337,13 +329,32 @@ services:
     ports:
     ports:
       - 80:80
       - 80:80
       - 88:88
       - 88:88
-`)
+`, useDefaultVPC)
 	lb := template.Resources["LoadBalancer"]
 	lb := template.Resources["LoadBalancer"]
 	assert.Check(t, lb != nil)
 	assert.Check(t, lb != nil)
 	loadBalancer := *lb.(*elasticloadbalancingv2.LoadBalancer)
 	loadBalancer := *lb.(*elasticloadbalancingv2.LoadBalancer)
 	assert.Check(t, loadBalancer.Type == elbv2.LoadBalancerTypeEnumNetwork)
 	assert.Check(t, loadBalancer.Type == elbv2.LoadBalancerTypeEnumNetwork)
 }
 }
 
 
+func TestUseCustomNetwork(t *testing.T) {
+	template := convertYaml(t, `
+services:
+  test:
+    image: nginx
+networks:
+  default:
+    external: true
+    name: sg-123abc
+`, useDefaultVPC, func(m *MockAPIMockRecorder) {
+		m.SecurityGroupExists(gomock.Any(), "sg-123abc").Return(true, nil)
+	})
+	assert.Check(t, template.Resources["DefaultNetwork"] == nil)
+	assert.Check(t, template.Resources["DefaultNetworkIngress"] == nil)
+	s := template.Resources["TestService"].(*ecs.Service)
+	assert.Check(t, s != nil)
+	assert.Check(t, s.NetworkConfiguration.AwsvpcConfiguration.SecurityGroups[0] == "sg-123abc") //nolint:staticcheck
+}
+
 func TestServiceMapping(t *testing.T) {
 func TestServiceMapping(t *testing.T) {
 	template := convertYaml(t, `
 	template := convertYaml(t, `
 services:
 services:
@@ -360,7 +371,7 @@ services:
     init: true
     init: true
     user: "user"
     user: "user"
     working_dir: "working_dir"
     working_dir: "working_dir"
-`)
+`, useDefaultVPC)
 	def := template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
 	def := template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
 	container := getMainContainer(def, t)
 	container := getMainContainer(def, t)
 	assert.Equal(t, container.Image, "image")
 	assert.Equal(t, container.Image, "image")
@@ -391,7 +402,7 @@ services:
     ports:
     ports:
       - 80:80
       - 80:80
       - 88:88
       - 88:88
-`)
+`, useDefaultVPC)
 	for _, r := range template.Resources {
 	for _, r := range template.Resources {
 		tags := reflect.Indirect(reflect.ValueOf(r)).FieldByName("Tags")
 		tags := reflect.Indirect(reflect.ValueOf(r)).FieldByName("Tags")
 		if !tags.IsValid() {
 		if !tags.IsValid() {
@@ -401,38 +412,26 @@ services:
 			k := tags.Index(i).FieldByName("Key").String()
 			k := tags.Index(i).FieldByName("Key").String()
 			v := tags.Index(i).FieldByName("Value").String()
 			v := tags.Index(i).FieldByName("Value").String()
 			if k == compose.ProjectTag {
 			if k == compose.ProjectTag {
-				assert.Equal(t, v, "Test")
+				assert.Equal(t, v, t.Name())
 			}
 			}
 		}
 		}
 	}
 	}
 }
 }
 
 
-func convertResultAsString(t *testing.T, project *types.Project) string {
-	backend := &ecsAPIService{}
-	template, err := backend.convert(project, awsResources{
-		vpc:     "vpcID",
-		subnets: []string{"subnet1", "subnet2"},
-	})
-	assert.NilError(t, err)
-	resultAsJSON, err := marshall(template)
-	assert.NilError(t, err)
-	return fmt.Sprintf("%s\n", string(resultAsJSON))
-}
+func convertYaml(t *testing.T, yaml string, fn ...func(m *MockAPIMockRecorder)) *cloudformation.Template {
+	project := loadConfig(t, yaml)
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
 
 
-func load(t *testing.T, paths ...string) *types.Project {
-	options := cli.ProjectOptions{
-		Name:        t.Name(),
-		ConfigPaths: paths,
+	m := NewMockAPI(ctrl)
+	for _, f := range fn {
+		f(m.EXPECT())
 	}
 	}
-	project, err := cli.ProjectFromOptions(&options)
-	assert.NilError(t, err)
-	return project
-}
 
 
-func convertYaml(t *testing.T, yaml string) *cloudformation.Template {
-	project := loadConfig(t, yaml)
-	backend := &ecsAPIService{}
-	template, err := backend.convert(project, awsResources{})
+	backend := &ecsAPIService{
+		aws: m,
+	}
+	template, err := backend.convert(context.TODO(), project)
 	assert.NilError(t, err)
 	assert.NilError(t, err)
 	return template
 	return template
 }
 }
@@ -445,7 +444,7 @@ func loadConfig(t *testing.T, yaml string) *types.Project {
 			{Config: dict},
 			{Config: dict},
 		},
 		},
 	}, func(options *loader.Options) {
 	}, func(options *loader.Options) {
-		options.Name = "Test"
+		options.Name = t.Name()
 	})
 	})
 	assert.NilError(t, err)
 	assert.NilError(t, err)
 	return model
 	return model
@@ -460,3 +459,12 @@ func getMainContainer(def *ecs.TaskDefinition, t *testing.T) ecs.TaskDefinition_
 	t.Fail()
 	t.Fail()
 	return def.ContainerDefinitions[0]
 	return def.ContainerDefinitions[0]
 }
 }
+
+func useDefaultVPC(m *MockAPIMockRecorder) {
+	m.GetDefaultVPC(gomock.Any()).Return("vpc-123", nil)
+	m.GetSubNets(gomock.Any(), "vpc-123").Return([]string{"subnet1", "subnet2"}, nil)
+}
+
+func useGPU(m *MockAPIMockRecorder) {
+	m.GetParameter(gomock.Any(), gomock.Any()).Return("", nil)
+}

+ 1 - 0
ecs/compatibility.go

@@ -97,6 +97,7 @@ var compatibleComposeAttributes = []string{
 	"secrets.file",
 	"secrets.file",
 	"volumes",
 	"volumes",
 	"volumes.external",
 	"volumes.external",
+	"networks.external",
 }
 }
 
 
 func (c *fargateCompatibilityChecker) CheckImage(service *types.ServiceConfig) {
 func (c *fargateCompatibilityChecker) CheckImage(service *types.ServiceConfig) {

+ 5 - 5
ecs/down.go

@@ -23,17 +23,17 @@ import (
 )
 )
 
 
 func (b *ecsAPIService) Down(ctx context.Context, project string) error {
 func (b *ecsAPIService) Down(ctx context.Context, project string) error {
-	resources, err := b.SDK.ListStackResources(ctx, project)
+	resources, err := b.aws.ListStackResources(ctx, project)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	err = resources.apply(awsTypeCapacityProvider, delete(ctx, b.SDK.DeleteCapacityProvider))
+	err = resources.apply(awsTypeCapacityProvider, delete(ctx, b.aws.DeleteCapacityProvider))
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	err = resources.apply(awsTypeAutoscalingGroup, delete(ctx, b.SDK.DeleteAutoscalingGroup))
+	err = resources.apply(awsTypeAutoscalingGroup, delete(ctx, b.aws.DeleteAutoscalingGroup))
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -43,7 +43,7 @@ func (b *ecsAPIService) Down(ctx context.Context, project string) error {
 		return err
 		return err
 	}
 	}
 
 
-	err = b.SDK.DeleteStack(ctx, project)
+	err = b.aws.DeleteStack(ctx, project)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -51,7 +51,7 @@ func (b *ecsAPIService) Down(ctx context.Context, project string) error {
 }
 }
 
 
 func (b *ecsAPIService) previousStackEvents(ctx context.Context, project string) ([]string, error) {
 func (b *ecsAPIService) previousStackEvents(ctx context.Context, project string) ([]string, error) {
-	events, err := b.SDK.DescribeStackEvents(ctx, project)
+	events, err := b.aws.DescribeStackEvents(ctx, project)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}

+ 1 - 1
ecs/ec2.go

@@ -41,7 +41,7 @@ func (b *ecsAPIService) createCapacityProvider(ctx context.Context, project *typ
 		return nil
 		return nil
 	}
 	}
 
 
-	ami, err := b.SDK.GetParameter(ctx, "/aws/service/ecs/optimized-ami/amazon-linux-2/gpu/recommended")
+	ami, err := b.aws.GetParameter(ctx, "/aws/service/ecs/optimized-ami/amazon-linux-2/gpu/recommended")
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}

+ 6 - 6
ecs/list.go

@@ -24,7 +24,7 @@ import (
 )
 )
 
 
 func (b *ecsAPIService) List(ctx context.Context, project string) ([]compose.Stack, error) {
 func (b *ecsAPIService) List(ctx context.Context, project string) ([]compose.Stack, error) {
-	stacks, err := b.SDK.ListStacks(ctx, project)
+	stacks, err := b.aws.ListStacks(ctx, project)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -42,7 +42,7 @@ func (b *ecsAPIService) List(ctx context.Context, project string) ([]compose.Sta
 }
 }
 
 
 func (b *ecsAPIService) checkStackState(ctx context.Context, name string) error {
 func (b *ecsAPIService) checkStackState(ctx context.Context, name string) error {
-	resources, err := b.SDK.ListStackResources(ctx, name)
+	resources, err := b.aws.ListStackResources(ctx, name)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -65,7 +65,7 @@ func (b *ecsAPIService) checkStackState(ctx context.Context, name string) error
 	if len(svcArns) == 0 {
 	if len(svcArns) == 0 {
 		return nil
 		return nil
 	}
 	}
-	services, err := b.SDK.GetServiceTaskDefinition(ctx, cluster, svcArns)
+	services, err := b.aws.GetServiceTaskDefinition(ctx, cluster, svcArns)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -78,14 +78,14 @@ func (b *ecsAPIService) checkStackState(ctx context.Context, name string) error
 }
 }
 
 
 func (b *ecsAPIService) checkServiceState(ctx context.Context, cluster string, service string, taskdef string) error {
 func (b *ecsAPIService) checkServiceState(ctx context.Context, cluster string, service string, taskdef string) error {
-	runningTasks, err := b.SDK.GetServiceTasks(ctx, cluster, service, false)
+	runningTasks, err := b.aws.GetServiceTasks(ctx, cluster, service, false)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 	if len(runningTasks) > 0 {
 	if len(runningTasks) > 0 {
 		return nil
 		return nil
 	}
 	}
-	stoppedTasks, err := b.SDK.GetServiceTasks(ctx, cluster, service, true)
+	stoppedTasks, err := b.aws.GetServiceTasks(ctx, cluster, service, true)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -102,7 +102,7 @@ func (b *ecsAPIService) checkServiceState(ctx context.Context, cluster string, s
 	if len(tasks) == 0 {
 	if len(tasks) == 0 {
 		return nil
 		return nil
 	}
 	}
-	reason, err := b.SDK.GetTaskStoppedReason(ctx, cluster, tasks[0])
+	reason, err := b.aws.GetTaskStoppedReason(ctx, cluster, tasks[0])
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}

+ 1 - 1
ecs/logs.go

@@ -31,7 +31,7 @@ func (b *ecsAPIService) Logs(ctx context.Context, project string, w io.Writer) e
 		width:  0,
 		width:  0,
 		writer: w,
 		writer: w,
 	}
 	}
-	err := b.SDK.GetLogs(ctx, project, consumer.Log)
+	err := b.aws.GetLogs(ctx, project, consumer.Log)
 	return err
 	return err
 }
 }
 
 

+ 3 - 3
ecs/ps.go

@@ -25,11 +25,11 @@ import (
 )
 )
 
 
 func (b *ecsAPIService) Ps(ctx context.Context, project string) ([]compose.ServiceStatus, error) {
 func (b *ecsAPIService) Ps(ctx context.Context, project string) ([]compose.ServiceStatus, error) {
-	cluster, err := b.SDK.GetStackClusterID(ctx, project)
+	cluster, err := b.aws.GetStackClusterID(ctx, project)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	servicesARN, err := b.SDK.ListStackServices(ctx, project)
+	servicesARN, err := b.aws.ListStackServices(ctx, project)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -40,7 +40,7 @@ func (b *ecsAPIService) Ps(ctx context.Context, project string) ([]compose.Servi
 
 
 	status := []compose.ServiceStatus{}
 	status := []compose.ServiceStatus{}
 	for _, arn := range servicesARN {
 	for _, arn := range servicesARN {
-		state, err := b.SDK.DescribeService(ctx, cluster, arn)
+		state, err := b.aws.DescribeService(ctx, cluster, arn)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}

+ 5 - 3
ecs/sdk.go

@@ -23,9 +23,6 @@ import (
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
-	"github.com/aws/aws-sdk-go/service/ssm"
-	"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
-
 	"github.com/docker/compose-cli/api/compose"
 	"github.com/docker/compose-cli/api/compose"
 	"github.com/docker/compose-cli/api/secrets"
 	"github.com/docker/compose-cli/api/secrets"
 
 
@@ -50,6 +47,8 @@ import (
 	"github.com/aws/aws-sdk-go/service/iam/iamiface"
 	"github.com/aws/aws-sdk-go/service/iam/iamiface"
 	"github.com/aws/aws-sdk-go/service/secretsmanager"
 	"github.com/aws/aws-sdk-go/service/secretsmanager"
 	"github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface"
 	"github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface"
+	"github.com/aws/aws-sdk-go/service/ssm"
+	"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
 	"github.com/hashicorp/go-multierror"
 	"github.com/hashicorp/go-multierror"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 )
 )
@@ -67,6 +66,9 @@ type sdk struct {
 	AG  autoscalingiface.AutoScalingAPI
 	AG  autoscalingiface.AutoScalingAPI
 }
 }
 
 
+// sdk implement API
+var _ API = sdk{}
+
 func newSDK(sess *session.Session) sdk {
 func newSDK(sess *session.Session) sdk {
 	sess.Handlers.Build.PushBack(func(r *request.Request) {
 	sess.Handlers.Build.PushBack(func(r *request.Request) {
 		request.AddToUserAgent(r, "Docker CLI")
 		request.AddToUserAgent(r, "Docker CLI")

+ 4 - 4
ecs/secrets.go

@@ -23,17 +23,17 @@ import (
 )
 )
 
 
 func (b *ecsAPIService) CreateSecret(ctx context.Context, secret secrets.Secret) (string, error) {
 func (b *ecsAPIService) CreateSecret(ctx context.Context, secret secrets.Secret) (string, error) {
-	return b.SDK.CreateSecret(ctx, secret)
+	return b.aws.CreateSecret(ctx, secret)
 }
 }
 
 
 func (b *ecsAPIService) InspectSecret(ctx context.Context, id string) (secrets.Secret, error) {
 func (b *ecsAPIService) InspectSecret(ctx context.Context, id string) (secrets.Secret, error) {
-	return b.SDK.InspectSecret(ctx, id)
+	return b.aws.InspectSecret(ctx, id)
 }
 }
 
 
 func (b *ecsAPIService) ListSecrets(ctx context.Context) ([]secrets.Secret, error) {
 func (b *ecsAPIService) ListSecrets(ctx context.Context) ([]secrets.Secret, error) {
-	return b.SDK.ListSecrets(ctx)
+	return b.aws.ListSecrets(ctx)
 }
 }
 
 
 func (b *ecsAPIService) DeleteSecret(ctx context.Context, id string, recover bool) error {
 func (b *ecsAPIService) DeleteSecret(ctx context.Context, id string, recover bool) error {
-	return b.SDK.DeleteSecret(ctx, id, recover)
+	return b.aws.DeleteSecret(ctx, id, recover)
 }
 }

+ 3 - 3
ecs/testdata/simple/simple-cloudformation-conversion.golden

@@ -5,7 +5,7 @@
       "Properties": {
       "Properties": {
         "Description": "Service Map for Docker Compose project TestSimpleConvert",
         "Description": "Service Map for Docker Compose project TestSimpleConvert",
         "Name": "TestSimpleConvert.local",
         "Name": "TestSimpleConvert.local",
-        "Vpc": "vpcID"
+        "Vpc": "vpc-123"
       },
       },
       "Type": "AWS::ServiceDiscovery::PrivateDnsNamespace"
       "Type": "AWS::ServiceDiscovery::PrivateDnsNamespace"
     },
     },
@@ -47,7 +47,7 @@
             "Value": "default"
             "Value": "default"
           }
           }
         ],
         ],
-        "VpcId": "vpcID"
+        "VpcId": "vpc-123"
       },
       },
       "Type": "AWS::EC2::SecurityGroup"
       "Type": "AWS::EC2::SecurityGroup"
     },
     },
@@ -218,7 +218,7 @@
           }
           }
         ],
         ],
         "TargetType": "ip",
         "TargetType": "ip",
-        "VpcId": "vpcID"
+        "VpcId": "vpc-123"
       },
       },
       "Type": "AWS::ElasticLoadBalancingV2::TargetGroup"
       "Type": "AWS::ElasticLoadBalancingV2::TargetGroup"
     },
     },

+ 5 - 5
ecs/up.go

@@ -27,7 +27,7 @@ import (
 )
 )
 
 
 func (b *ecsAPIService) Up(ctx context.Context, project *types.Project, detach bool) error {
 func (b *ecsAPIService) Up(ctx context.Context, project *types.Project, detach bool) error {
-	err := b.SDK.CheckRequirements(ctx, b.Region)
+	err := b.aws.CheckRequirements(ctx, b.Region)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -37,23 +37,23 @@ func (b *ecsAPIService) Up(ctx context.Context, project *types.Project, detach b
 		return err
 		return err
 	}
 	}
 
 
-	update, err := b.SDK.StackExists(ctx, project.Name)
+	update, err := b.aws.StackExists(ctx, project.Name)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 	operation := stackCreate
 	operation := stackCreate
 	if update {
 	if update {
 		operation = stackUpdate
 		operation = stackUpdate
-		changeset, err := b.SDK.CreateChangeSet(ctx, project.Name, template)
+		changeset, err := b.aws.CreateChangeSet(ctx, project.Name, template)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		err = b.SDK.UpdateStack(ctx, changeset)
+		err = b.aws.UpdateStack(ctx, changeset)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
 	} else {
 	} else {
-		err = b.SDK.CreateStack(ctx, project.Name, template)
+		err = b.aws.CreateStack(ctx, project.Name, template)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}

+ 4 - 4
ecs/wait.go

@@ -37,7 +37,7 @@ func (b *ecsAPIService) WaitStackCompletion(ctx context.Context, name string, op
 	// progress writer
 	// progress writer
 	w := progress.ContextWriter(ctx)
 	w := progress.ContextWriter(ctx)
 	// Get the unique Stack ID so we can collect events without getting some from previous deployments with same name
 	// Get the unique Stack ID so we can collect events without getting some from previous deployments with same name
-	stackID, err := b.SDK.GetStackID(ctx, name)
+	stackID, err := b.aws.GetStackID(ctx, name)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -45,7 +45,7 @@ func (b *ecsAPIService) WaitStackCompletion(ctx context.Context, name string, op
 	ticker := time.NewTicker(1 * time.Second)
 	ticker := time.NewTicker(1 * time.Second)
 	done := make(chan bool)
 	done := make(chan bool)
 	go func() {
 	go func() {
-		b.SDK.WaitStackComplete(ctx, stackID, operation) //nolint:errcheck
+		b.aws.WaitStackComplete(ctx, stackID, operation) //nolint:errcheck
 		ticker.Stop()
 		ticker.Stop()
 		done <- true
 		done <- true
 	}()
 	}()
@@ -58,7 +58,7 @@ func (b *ecsAPIService) WaitStackCompletion(ctx context.Context, name string, op
 			completed = true
 			completed = true
 		case <-ticker.C:
 		case <-ticker.C:
 		}
 		}
-		events, err := b.SDK.DescribeStackEvents(ctx, stackID)
+		events, err := b.aws.DescribeStackEvents(ctx, stackID)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -111,7 +111,7 @@ func (b *ecsAPIService) WaitStackCompletion(ctx context.Context, name string, op
 			continue
 			continue
 		}
 		}
 		if err := b.checkStackState(ctx, name); err != nil {
 		if err := b.checkStackState(ctx, name); err != nil {
-			if e := b.SDK.DeleteStack(ctx, name); e != nil {
+			if e := b.aws.DeleteStack(ctx, name); e != nil {
 				return e
 				return e
 			}
 			}
 			stackErr = err
 			stackErr = err

+ 1 - 0
go.mod

@@ -36,6 +36,7 @@ require (
 	github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect
 	github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect
 	github.com/gobwas/pool v0.2.0 // indirect
 	github.com/gobwas/pool v0.2.0 // indirect
 	github.com/gobwas/ws v1.0.4
 	github.com/gobwas/ws v1.0.4
+	github.com/golang/mock v1.4.4
 	github.com/golang/protobuf v1.4.2
 	github.com/golang/protobuf v1.4.2
 	github.com/google/go-cmp v0.5.2
 	github.com/google/go-cmp v0.5.2
 	github.com/google/uuid v1.1.2
 	github.com/google/uuid v1.1.2

+ 1 - 0
go.sum

@@ -220,6 +220,7 @@ github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFU
 github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
 github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
 github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
 github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
 github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
 github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
+github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc=
 github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4=
 github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4=
 github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=