Browse Source

Merge pull request #633 from docker/ec2

Nicolas De loof 5 năm trước cách đây
mục cha
commit
767ed0c20d
9 tập tin đã thay đổi với 454 bổ sung và 133 xóa
  1. 47 71
      ecs/cloudformation.go
  2. 48 10
      ecs/cloudformation_test.go
  3. 103 36
      ecs/convert.go
  4. 113 0
      ecs/ec2.go
  5. 34 10
      ecs/gpu.go
  6. 4 4
      ecs/gpu_test.go
  7. 15 1
      ecs/iam.go
  8. 32 1
      ecs/sdk.go
  9. 58 0
      ecs/tags.go

+ 47 - 71
ecs/cloudformation.go

@@ -23,8 +23,6 @@ import (
 	"regexp"
 	"strings"
 
-	"github.com/docker/compose-cli/api/compose"
-
 	ecsapi "github.com/aws/aws-sdk-go/service/ecs"
 	"github.com/aws/aws-sdk-go/service/elbv2"
 	cloudmapapi "github.com/aws/aws-sdk-go/service/servicediscovery"
@@ -36,7 +34,6 @@ import (
 	"github.com/awslabs/goformation/v4/cloudformation/logs"
 	"github.com/awslabs/goformation/v4/cloudformation/secretsmanager"
 	cloudmap "github.com/awslabs/goformation/v4/cloudformation/servicediscovery"
-	"github.com/awslabs/goformation/v4/cloudformation/tags"
 	"github.com/compose-spec/compose-go/compatibility"
 	"github.com/compose-spec/compose-go/errdefs"
 	"github.com/compose-spec/compose-go/types"
@@ -52,7 +49,7 @@ const (
 )
 
 func (b *ecsAPIService) Convert(ctx context.Context, project *types.Project) ([]byte, error) {
-	template, err := b.convert(project)
+	template, networks, err := b.convert(project)
 	if err != nil {
 		return nil, err
 	}
@@ -97,11 +94,16 @@ func (b *ecsAPIService) Convert(ctx context.Context, project *types.Project) ([]
 		}
 	}
 
+	err = b.createCapacityProvider(ctx, project, networks, template)
+	if err != nil {
+		return nil, err
+	}
+
 	return marshall(template)
 }
 
 // Convert a compose project into a CloudFormation template
-func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Template, error) { //nolint:gocyclo
+func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Template, map[string]string, error) { //nolint:gocyclo
 	var checker compatibility.Checker = &fargateCompatibilityChecker{
 		compatibility.AllowList{
 			Supported: compatibleComposeAttributes,
@@ -116,7 +118,7 @@ func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Templat
 		}
 	}
 	if !compatibility.IsCompatible(checker) {
-		return nil, fmt.Errorf("compose file is incompatible with Amazon ECS")
+		return nil, nil, fmt.Errorf("compose file is incompatible with Amazon ECS")
 	}
 
 	template := cloudformation.NewTemplate()
@@ -152,7 +154,6 @@ func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Templat
 		Description: "Name of the LoadBalancer to connect to (optional)",
 	}
 
-	// Createmount.nfs4: Connection timed out : unsuccessful EFS utils command execution; code: 32 Cluster is `ParameterClusterName` parameter is not set
 	template.Conditions["CreateCluster"] = cloudformation.Equals("", cloudformation.Ref(parameterClusterName))
 
 	cluster := createCluster(project, template)
@@ -168,19 +169,14 @@ func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Templat
 		}
 		secret, err := ioutil.ReadFile(s.File)
 		if err != nil {
-			return nil, err
+			return nil, nil, err
 		}
 
 		name := fmt.Sprintf("%sSecret", normalizeResourceName(s.Name))
 		template.Resources[name] = &secretsmanager.Secret{
 			Description:  "",
 			SecretString: string(secret),
-			Tags: []tags.Tag{
-				{
-					Key:   compose.ProjectTag,
-					Value: project.Name,
-				},
-			},
+			Tags:         projectTags(project),
 		}
 		s.Name = cloudformation.Ref(name)
 		project.Secrets[i] = s
@@ -197,7 +193,7 @@ func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Templat
 
 		definition, err := convert(project, service)
 		if err != nil {
-			return nil, err
+			return nil, nil, err
 		}
 
 		taskExecutionRole := createTaskExecutionRole(service, definition, template)
@@ -226,10 +222,8 @@ func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Templat
 			for _, port := range service.Ports {
 				protocol := strings.ToUpper(port.Protocol)
 				if getLoadBalancerType(project) == elbv2.LoadBalancerTypeEnumApplication {
-					protocol = elbv2.ProtocolEnumHttps
-					if port.Published == 80 {
-						protocol = elbv2.ProtocolEnumHttp
-					}
+					// we don't set Https as a certificate must be specified for HTTPS listeners
+					protocol = elbv2.ProtocolEnumHttp
 				}
 				if loadBalancerARN != "" {
 					targetGroupName := createTargetGroup(project, service, port, template, protocol)
@@ -255,7 +249,16 @@ func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Templat
 
 		minPercent, maxPercent, err := computeRollingUpdateLimits(service)
 		if err != nil {
-			return nil, err
+			return nil, nil, err
+		}
+
+		assignPublicIP := ecsapi.AssignPublicIpEnabled
+		launchType := ecsapi.LaunchTypeFargate
+		platformVersion := "1.4.0" // LATEST which is set to 1.3.0 (?) which doesn’t allow efs volumes.
+		if requireEC2(service) {
+			assignPublicIP = ecsapi.AssignPublicIpDisabled
+			launchType = ecsapi.LaunchTypeEc2
+			platformVersion = "" // The platform version must be null when specifying an EC2 launch type
 		}
 
 		template.Resources[serviceResourceName(service.Name)] = &ecs.Service{
@@ -269,11 +272,12 @@ func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Templat
 				MaximumPercent:        maxPercent,
 				MinimumHealthyPercent: minPercent,
 			},
-			LaunchType:    ecsapi.LaunchTypeFargate,
+			LaunchType: launchType,
+			// TODO we miss support for https://github.com/aws/containers-roadmap/issues/631 to select a capacity provider
 			LoadBalancers: serviceLB,
 			NetworkConfiguration: &ecs.Service_NetworkConfiguration{
 				AwsvpcConfiguration: &ecs.Service_AwsVpcConfiguration{
-					AssignPublicIp: ecsapi.AssignPublicIpEnabled,
+					AssignPublicIp: assignPublicIP,
 					SecurityGroups: serviceSecurityGroups,
 					Subnets: []string{
 						cloudformation.Ref(parameterSubnet1Id),
@@ -281,24 +285,15 @@ func (b *ecsAPIService) convert(project *types.Project) (*cloudformation.Templat
 					},
 				},
 			},
-			PlatformVersion:    "1.4.0", // LATEST which is set to 1.3.0 (?) which doesn’t allow efs volumes.
+			PlatformVersion:    platformVersion,
 			PropagateTags:      ecsapi.PropagateTagsService,
 			SchedulingStrategy: ecsapi.SchedulingStrategyReplica,
 			ServiceRegistries:  []ecs.Service_ServiceRegistry{serviceRegistry},
-			Tags: []tags.Tag{
-				{
-					Key:   compose.ProjectTag,
-					Value: project.Name,
-				},
-				{
-					Key:   compose.ServiceTag,
-					Value: service.Name,
-				},
-			},
-			TaskDefinition: cloudformation.Ref(normalizeResourceName(taskDefinition)),
+			Tags:               serviceTags(project, service),
+			TaskDefinition:     cloudformation.Ref(normalizeResourceName(taskDefinition)),
 		}
 	}
-	return template, nil
+	return template, networks, nil
 }
 
 func createLogGroup(project *types.Project, template *cloudformation.Template) {
@@ -413,12 +408,7 @@ func createLoadBalancer(project *types.Project, template *cloudformation.Templat
 			cloudformation.Ref(parameterSubnet1Id),
 			cloudformation.Ref(parameterSubnet2Id),
 		},
-		Tags: []tags.Tag{
-			{
-				Key:   compose.ProjectTag,
-				Value: project.Name,
-			},
-		},
+		Tags:                       projectTags(project),
 		Type:                       loadBalancerType,
 		AWSCloudFormationCondition: "CreateLoadBalancer",
 	}
@@ -462,16 +452,12 @@ func createTargetGroup(project *types.Project, service types.ServiceConfig, port
 		port.Published,
 	)
 	template.Resources[targetGroupName] = &elasticloadbalancingv2.TargetGroup{
-		Port:     int(port.Target),
-		Protocol: protocol,
-		Tags: []tags.Tag{
-			{
-				Key:   compose.ProjectTag,
-				Value: project.Name,
-			},
-		},
-		VpcId:      cloudformation.Ref(parameterVPCId),
-		TargetType: elbv2.TargetTypeEnumIp,
+		HealthCheckEnabled: false,
+		Port:               int(port.Target),
+		Protocol:           protocol,
+		Tags:               projectTags(project),
+		TargetType:         elbv2.TargetTypeEnumIp,
+		VpcId:              cloudformation.Ref(parameterVPCId),
 	}
 	return targetGroupName
 }
@@ -507,7 +493,7 @@ func createTaskExecutionRole(service types.ServiceConfig, definition *ecs.TaskDe
 	taskExecutionRole := fmt.Sprintf("%sTaskExecutionRole", normalizeResourceName(service.Name))
 	policies := createPolicies(service, definition)
 	template.Resources[taskExecutionRole] = &iam.Role{
-		AssumeRolePolicyDocument: assumeRolePolicyDocument,
+		AssumeRolePolicyDocument: ecsTaskAssumeRolePolicyDocument,
 		Policies:                 policies,
 		ManagedPolicyArns: []string{
 			ecsTaskExecutionPolicy,
@@ -535,7 +521,7 @@ func createTaskRole(service types.ServiceConfig, template *cloudformation.Templa
 		return ""
 	}
 	template.Resources[taskRole] = &iam.Role{
-		AssumeRolePolicyDocument: assumeRolePolicyDocument,
+		AssumeRolePolicyDocument: ecsTaskAssumeRolePolicyDocument,
 		Policies:                 rolePolicies,
 		ManagedPolicyArns:        managedPolicies,
 	}
@@ -544,13 +530,8 @@ func createTaskRole(service types.ServiceConfig, template *cloudformation.Templa
 
 func createCluster(project *types.Project, template *cloudformation.Template) string {
 	template.Resources["Cluster"] = &ecs.Cluster{
-		ClusterName: project.Name,
-		Tags: []tags.Tag{
-			{
-				Key:   compose.ProjectTag,
-				Value: project.Name,
-			},
-		},
+		ClusterName:                project.Name,
+		Tags:                       projectTags(project),
 		AWSCloudFormationCondition: "CreateCluster",
 	}
 	cluster := cloudformation.If("CreateCluster", cloudformation.Ref("Cluster"), cloudformation.Ref(parameterClusterName))
@@ -580,11 +561,15 @@ func convertNetwork(project *types.Project, net types.NetworkConfig, vpc string,
 		for _, service := range project.Services {
 			if _, ok := service.Networks[net.Name]; ok {
 				for _, port := range service.Ports {
+					protocol := strings.ToUpper(port.Protocol)
+					if protocol == "" {
+						protocol = "-1"
+					}
 					ingresses = append(ingresses, ec2.SecurityGroup_Ingress{
 						CidrIp:      "0.0.0.0/0",
 						Description: fmt.Sprintf("%s:%d/%s", service.Name, port.Target, port.Protocol),
 						FromPort:    int(port.Target),
-						IpProtocol:  strings.ToUpper(port.Protocol),
+						IpProtocol:  protocol,
 						ToPort:      int(port.Target),
 					})
 				}
@@ -598,16 +583,7 @@ func convertNetwork(project *types.Project, net types.NetworkConfig, vpc string,
 		GroupName:            securityGroup,
 		SecurityGroupIngress: ingresses,
 		VpcId:                vpc,
-		Tags: []tags.Tag{
-			{
-				Key:   compose.ProjectTag,
-				Value: project.Name,
-			},
-			{
-				Key:   compose.NetworkTag,
-				Value: net.Name,
-			},
-		},
+		Tags:                 networkTags(project, net),
 	}
 
 	ingress := securityGroup + "Ingress"

+ 48 - 10
ecs/cloudformation_test.go

@@ -245,6 +245,15 @@ services:
 
 func TestTaskSizeConvert(t *testing.T) {
 	template := convertYaml(t, `
+services:
+  test:
+    image: nginx
+`)
+	def := template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
+	assert.Equal(t, def.Cpu, "256")
+	assert.Equal(t, def.Memory, "512")
+
+	template = convertYaml(t, `
 services:
   test:
     image: nginx
@@ -253,11 +262,8 @@ services:
         limits:
           cpus: '0.5'
           memory: 2048M
-        reservations:
-          cpus: '0.5'
-          memory: 2048M
 `)
-	def := template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
+	def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
 	assert.Equal(t, def.Cpu, "512")
 	assert.Equal(t, def.Memory, "2048")
 
@@ -270,13 +276,45 @@ services:
         limits:
           cpus: '4'
           memory: 8192M
-        reservations:
-          cpus: '4'
-          memory: 8192M
 `)
 	def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
 	assert.Equal(t, def.Cpu, "4096")
 	assert.Equal(t, def.Memory, "8192")
+
+	template = convertYaml(t, `
+services:
+  test:
+    image: nginx
+    deploy:
+      resources:
+        limits:
+          cpus: '4'
+          memory: 792Mb
+        reservations:
+          generic_resources: 
+            - discrete_resource_spec:
+                kind: gpus
+                value: 2
+`)
+	def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
+	assert.Equal(t, def.Cpu, "4000")
+	assert.Equal(t, def.Memory, "792")
+
+	template = convertYaml(t, `
+services:
+  test:
+    image: nginx
+    deploy:
+      resources:
+        reservations:
+          generic_resources: 
+            - discrete_resource_spec:
+                kind: gpus
+                value: 2
+`)
+	def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
+	assert.Equal(t, def.Cpu, "")
+	assert.Equal(t, def.Memory, "")
 }
 func TestTaskSizeConvertFailure(t *testing.T) {
 	model := loadConfig(t, `
@@ -290,7 +328,7 @@ services:
           memory: 2043248M
 `)
 	backend := &ecsAPIService{}
-	_, err := backend.convert(model)
+	_, _, err := backend.convert(model)
 	assert.ErrorContains(t, err, "the resources requested are not supported by ECS/Fargate")
 }
 
@@ -374,7 +412,7 @@ services:
 
 func convertResultAsString(t *testing.T, project *types.Project) string {
 	backend := &ecsAPIService{}
-	template, err := backend.convert(project)
+	template, _, err := backend.convert(project)
 	assert.NilError(t, err)
 	resultAsJSON, err := marshall(template)
 	assert.NilError(t, err)
@@ -394,7 +432,7 @@ func load(t *testing.T, paths ...string) *types.Project {
 func convertYaml(t *testing.T, yaml string) *cloudformation.Template {
 	project := loadConfig(t, yaml)
 	backend := &ecsAPIService{}
-	template, err := backend.convert(project)
+	template, _, err := backend.convert(project)
 	assert.NilError(t, err)
 	return template
 }

+ 103 - 36
ecs/convert.go

@@ -102,6 +102,11 @@ func convert(project *types.Project, service types.ServiceConfig) (*ecs.TaskDefi
 		return nil, err
 	}
 
+	var reservations *types.Resource
+	if service.Deploy != nil && service.Deploy.Resources.Reservations != nil {
+		reservations = service.Deploy.Resources.Reservations
+	}
+
 	containers := append(initContainers, ecs.TaskDefinition_ContainerDefinition{
 		Command:                service.Command,
 		DisableNetworking:      service.NetworkMode == "none",
@@ -129,7 +134,7 @@ func convert(project *types.Project, service types.ServiceConfig) (*ecs.TaskDefi
 		PseudoTerminal:         service.Tty,
 		ReadonlyRootFilesystem: service.ReadOnly,
 		RepositoryCredentials:  credential,
-		ResourceRequirements:   nil,
+		ResourceRequirements:   toTaskResourceRequirements(reservations),
 		StartTimeout:           0,
 		StopTimeout:            durationToInt(service.StopGracePeriod),
 		SystemControls:         toSystemControls(service.Sysctls),
@@ -139,21 +144,44 @@ func convert(project *types.Project, service types.ServiceConfig) (*ecs.TaskDefi
 		WorkingDirectory:       service.WorkingDir,
 	})
 
+	launchType := ecsapi.LaunchTypeFargate
+	if requireEC2(service) {
+		launchType = ecsapi.LaunchTypeEc2
+	}
+
 	return &ecs.TaskDefinition{
-		ContainerDefinitions:    containers,
-		Cpu:                     cpu,
-		Family:                  fmt.Sprintf("%s-%s", project.Name, service.Name),
-		IpcMode:                 service.Ipc,
-		Memory:                  mem,
-		NetworkMode:             ecsapi.NetworkModeAwsvpc, // FIXME could be set by service.NetworkMode, Fargate only supports network mode ‘awsvpc’.
-		PidMode:                 service.Pid,
-		PlacementConstraints:    toPlacementConstraints(service.Deploy),
-		ProxyConfiguration:      nil,
-		RequiresCompatibilities: []string{ecsapi.LaunchTypeFargate},
-		Volumes:                 volumes,
+		ContainerDefinitions: containers,
+		Cpu:                  cpu,
+		Family:               fmt.Sprintf("%s-%s", project.Name, service.Name),
+		IpcMode:              service.Ipc,
+		Memory:               mem,
+		NetworkMode:          ecsapi.NetworkModeAwsvpc, // FIXME could be set by service.NetworkMode, Fargate only supports network mode ‘awsvpc’.
+		PidMode:              service.Pid,
+		PlacementConstraints: toPlacementConstraints(service.Deploy),
+		ProxyConfiguration:   nil,
+		RequiresCompatibilities: []string{
+			launchType,
+		},
+		Volumes: volumes,
 	}, nil
 }
 
+// toTaskResourceRequirements translates the generic resource reservations of a
+// compose service into ECS task definition resource requirements.
+//
+// Only discrete resources whose kind is "gpus" are translated; every other
+// entry is ignored. It returns nil when reservations is nil or when no GPU
+// reservation is present.
+func toTaskResourceRequirements(reservations *types.Resource) []ecs.TaskDefinition_ResourceRequirement {
+	if reservations == nil {
+		return nil
+	}
+	var requirements []ecs.TaskDefinition_ResourceRequirement
+	for _, r := range reservations.GenericResources {
+		spec := r.DiscreteResourceSpec
+		// DiscreteResourceSpec is a pointer and is nil when the entry carries no
+		// discrete_resource_spec; skip such entries instead of panicking.
+		if spec == nil || spec.Kind != "gpus" {
+			continue
+		}
+		requirements = append(requirements, ecs.TaskDefinition_ResourceRequirement{
+			Type:  ecsapi.ResourceTypeGpu,
+			Value: fmt.Sprint(spec.Value),
+		})
+	}
+	return requirements
+}
+
 func createSecretsSideCar(project *types.Project, service types.ServiceConfig, logConfiguration *ecs.TaskDefinition_LogConfiguration) (
 	ecs.TaskDefinition_Volume,
 	ecs.TaskDefinition_MountPoint,
@@ -295,8 +323,24 @@ func toSystemControls(sysctls types.Mapping) []ecs.TaskDefinition_SystemControl
 const miB = 1024 * 1024
 
 func toLimits(service types.ServiceConfig) (string, string, error) {
+	mem, cpu, err := getConfiguredLimits(service)
+	if err != nil {
+		return "", "", err
+	}
+	if requireEC2(service) {
+		// just return configured limits expressed in Mb and CPU units
+		var cpuLimit, memLimit string
+		if cpu > 0 {
+			cpuLimit = fmt.Sprint(cpu)
+		}
+		if mem > 0 {
+			memLimit = fmt.Sprint(mem / miB)
+		}
+		return cpuLimit, memLimit, nil
+	}
+
 	// All possible cpu/mem values for Fargate
-	cpuToMem := map[int64][]types.UnitBytes{
+	fargateCPUToMem := map[int64][]types.UnitBytes{
 		256:  {512, 1024, 2048},
 		512:  {1024, 2048, 3072, 4096},
 		1024: {2048, 3072, 4096, 5120, 6144, 7168, 8192},
@@ -305,37 +349,22 @@ func toLimits(service types.ServiceConfig) (string, string, error) {
 	}
 	cpuLimit := "256"
 	memLimit := "512"
-
-	if service.Deploy == nil {
+	if mem == 0 && cpu == 0 {
 		return cpuLimit, memLimit, nil
 	}
 
-	limits := service.Deploy.Resources.Limits
-	if limits == nil {
-		return cpuLimit, memLimit, nil
-	}
-
-	if limits.NanoCPUs == "" {
-		return cpuLimit, memLimit, nil
-	}
-
-	v, err := opts.ParseCPUs(limits.NanoCPUs)
-	if err != nil {
-		return "", "", err
-	}
-
 	var cpus []int64
-	for k := range cpuToMem {
+	for k := range fargateCPUToMem {
 		cpus = append(cpus, k)
 	}
 	sort.Slice(cpus, func(i, j int) bool { return cpus[i] < cpus[j] })
 
-	for _, cpu := range cpus {
-		mem := cpuToMem[cpu]
-		if v <= cpu*miB {
-			for _, m := range mem {
-				if limits.MemoryBytes <= m*miB {
-					cpuLimit = strconv.FormatInt(cpu, 10)
+	for _, fargateCPU := range cpus {
+		options := fargateCPUToMem[fargateCPU]
+		if cpu <= fargateCPU {
+			for _, m := range options {
+				if mem <= m*miB {
+					cpuLimit = strconv.FormatInt(fargateCPU, 10)
 					memLimit = strconv.FormatInt(int64(m), 10)
 					return cpuLimit, memLimit, nil
 				}
@@ -345,6 +374,27 @@ func toLimits(service types.ServiceConfig) (string, string, error) {
 	return "", "", fmt.Errorf("the resources requested are not supported by ECS/Fargate")
 }
 
+// getConfiguredLimits reads the resource limits declared under the service's
+// deploy section.
+//
+// It returns the configured memory limit (limits.MemoryBytes) and the CPU limit
+// parsed by opts.ParseCPUs divided by 1e6. Both values are zero when the service
+// has no deploy section or no limits; the CPU value is zero when no CPU limit is
+// set. A CPU limit that cannot be parsed yields an error.
+func getConfiguredLimits(service types.ServiceConfig) (types.UnitBytes, int64, error) {
+	deploy := service.Deploy
+	if deploy == nil || deploy.Resources.Limits == nil {
+		return 0, 0, nil
+	}
+	configured := deploy.Resources.Limits
+
+	var cpu int64
+	if configured.NanoCPUs != "" {
+		parsed, err := opts.ParseCPUs(configured.NanoCPUs)
+		if err != nil {
+			return 0, 0, err
+		}
+		cpu = parsed / 1e6
+	}
+
+	return configured.MemoryBytes, cpu, nil
+}
+
 func toContainerReservation(service types.ServiceConfig) (string, int) {
 	cpuReservation := ".0"
 	memReservation := 0
@@ -490,3 +540,20 @@ func getRepoCredentials(service types.ServiceConfig) *ecs.TaskDefinition_Reposit
 	}
 	return nil
 }
+
+// requireEC2 reports whether the service reserves at least one GPU, as computed
+// by gpuRequirements.
+func requireEC2(s types.ServiceConfig) bool {
+	gpus := gpuRequirements(s)
+	return gpus > 0
+}
+
+// gpuRequirements returns the number of GPUs reserved by the service.
+//
+// It looks at deploy.resources.reservations.generic_resources and returns the
+// value of the first discrete resource whose kind is "gpus". It returns 0 when
+// the service has no deploy section, no reservations, or no such resource.
+func gpuRequirements(s types.ServiceConfig) int64 {
+	if s.Deploy == nil || s.Deploy.Resources.Reservations == nil {
+		return 0
+	}
+	for _, resource := range s.Deploy.Resources.Reservations.GenericResources {
+		spec := resource.DiscreteResourceSpec
+		// DiscreteResourceSpec is a pointer and is nil when the entry carries no
+		// discrete_resource_spec; ignore such entries instead of panicking.
+		if spec != nil && spec.Kind == "gpus" {
+			return spec.Value
+		}
+	}
+	return 0
+}

+ 113 - 0
ecs/ec2.go

@@ -0,0 +1,113 @@
+/*
+   Copyright 2020 Docker Compose CLI authors
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package ecs
+
+import (
+	"context"
+	"encoding/base64"
+	"fmt"
+	"sort"
+
+	"github.com/awslabs/goformation/v4/cloudformation"
+	"github.com/awslabs/goformation/v4/cloudformation/autoscaling"
+	"github.com/awslabs/goformation/v4/cloudformation/ecs"
+	"github.com/awslabs/goformation/v4/cloudformation/iam"
+	"github.com/compose-spec/compose-go/types"
+)
+
+// createCapacityProvider adds to template the resources needed to run services
+// on EC2 instances: a capacity provider, an auto scaling group, a launch
+// configuration, and the instance profile and role used by the instances. The
+// capacity provider is then registered on the template's "Cluster" resource.
+//
+// It does nothing unless at least one service of project satisfies requireEC2.
+// The values of networks are used as the security groups of the launch
+// configuration. Every resource created here carries the "CreateCluster"
+// condition.
+//
+// An error is returned when the template holds no *ecs.Cluster under "Cluster",
+// when the AMI parameter cannot be fetched, or when guessMachineType fails.
+func (b *ecsAPIService) createCapacityProvider(ctx context.Context, project *types.Project, networks map[string]string, template *cloudformation.Template) error {
+	var ec2 bool
+	for _, s := range project.Services {
+		if requireEC2(s) {
+			ec2 = true
+			break
+		}
+	}
+
+	if !ec2 {
+		return nil
+	}
+
+	// Look the cluster up before touching the template so a missing or
+	// unexpected resource is reported as an error rather than a panic, and the
+	// template is left unmodified.
+	cluster, ok := template.Resources["Cluster"].(*ecs.Cluster)
+	if !ok {
+		return fmt.Errorf("template has no ECS cluster resource to attach a capacity provider to")
+	}
+
+	ami, err := b.SDK.GetParameter(ctx, "/aws/service/ecs/optimized-ami/amazon-linux-2/gpu/recommended")
+	if err != nil {
+		return err
+	}
+
+	machineType, err := guessMachineType(project)
+	if err != nil {
+		return err
+	}
+
+	var securityGroups []string
+	for _, r := range networks {
+		securityGroups = append(securityGroups, r)
+	}
+	// Map iteration order is random; sort so the generated template is identical
+	// from one run to the next.
+	sort.Strings(securityGroups)
+
+	template.Resources["CapacityProvider"] = &ecs.CapacityProvider{
+		AutoScalingGroupProvider: &ecs.CapacityProvider_AutoScalingGroupProvider{
+			AutoScalingGroupArn: cloudformation.Ref("AutoscalingGroup"),
+			ManagedScaling: &ecs.CapacityProvider_ManagedScaling{
+				TargetCapacity: 100,
+			},
+		},
+		Tags:                       projectTags(project),
+		AWSCloudFormationCondition: "CreateCluster",
+	}
+
+	template.Resources["AutoscalingGroup"] = &autoscaling.AutoScalingGroup{
+		LaunchConfigurationName: cloudformation.Ref("LaunchConfiguration"),
+		MaxSize:                 "10", //TODO
+		MinSize:                 "1",
+		VPCZoneIdentifier: []string{
+			cloudformation.Ref(parameterSubnet1Id),
+			cloudformation.Ref(parameterSubnet2Id),
+		},
+		AWSCloudFormationCondition: "CreateCluster",
+	}
+
+	// The instance boot script writes the cluster name into the ECS agent
+	// configuration file.
+	userData := base64.StdEncoding.EncodeToString([]byte(
+		fmt.Sprintf("#!/bin/bash\necho ECS_CLUSTER=%s >> /etc/ecs/ecs.config", project.Name)))
+
+	template.Resources["LaunchConfiguration"] = &autoscaling.LaunchConfiguration{
+		ImageId:                    ami,
+		InstanceType:               machineType,
+		SecurityGroups:             securityGroups,
+		IamInstanceProfile:         cloudformation.Ref("EC2InstanceProfile"),
+		UserData:                   userData,
+		AWSCloudFormationCondition: "CreateCluster",
+	}
+
+	template.Resources["EC2InstanceProfile"] = &iam.InstanceProfile{
+		Roles:                      []string{cloudformation.Ref("EC2InstanceRole")},
+		AWSCloudFormationCondition: "CreateCluster",
+	}
+
+	template.Resources["EC2InstanceRole"] = &iam.Role{
+		AssumeRolePolicyDocument: ec2InstanceAssumeRolePolicyDocument,
+		ManagedPolicyArns: []string{
+			ecsEC2InstanceRole,
+		},
+		Tags:                       projectTags(project),
+		AWSCloudFormationCondition: "CreateCluster",
+	}
+
+	cluster.CapacityProviders = []string{
+		cloudformation.Ref("CapacityProvider"),
+	}
+
+	return nil
+}

+ 34 - 10
ecs/gpu.go

@@ -34,23 +34,47 @@ type machine struct {
 
 type family []machine
 
-var p3family = family{
+var gpufamily = family{
 	{
-		id:     "p3.2xlarge",
+		id:     "g4dn.xlarge",
+		cpus:   4,
+		memory: 16 * units.GiB,
+		gpus:   1,
+	},
+	{
+		id:     "g4dn.2xlarge",
 		cpus:   8,
+		memory: 32 * units.GiB,
+		gpus:   1,
+	},
+	{
+		id:     "g4dn.4xlarge",
+		cpus:   16,
 		memory: 64 * units.GiB,
-		gpus:   2,
+		gpus:   1,
 	},
 	{
-		id:     "p3.8xlarge",
+		id:     "g4dn.8xlarge",
 		cpus:   32,
-		memory: 244 * units.GiB,
+		memory: 128 * units.GiB,
+		gpus:   1,
+	},
+	{
+		id:     "g4dn.12xlarge",
+		cpus:   48,
+		memory: 192 * units.GiB,
 		gpus:   4,
 	},
 	{
-		id:     "p3.16xlarge",
+		id:     "g4dn.16xlarge",
 		cpus:   64,
-		memory: 488 * units.GiB,
+		memory: 256 * units.GiB,
+		gpus:   1,
+	},
+	{
+		id:     "g4dn.metal",
+		cpus:   96,
+		memory: 384 * units.GiB,
 		gpus:   8,
 	},
 }
@@ -82,9 +106,9 @@ func guessMachineType(project *types.Project) (string, error) {
 		return "", err
 	}
 
-	instanceType, err := p3family.
+	instanceType, err := gpufamily.
 		filter(func(m machine) bool {
-			return m.memory >= requirements.memory
+			return m.memory > requirements.memory // actual memory available for ECS tasks < total machine memory
 		}).
 		filter(func(m machine) bool {
 			return m.cpus >= requirements.cpus
@@ -92,7 +116,7 @@ func guessMachineType(project *types.Project) (string, error) {
 		filter(func(m machine) bool {
 			return m.gpus >= requirements.gpus
 		}).
-		firstOrError("none of the Amazon EC2 P3 instance types meet the requirements for memory:%d cpu:%f gpus:%d", requirements.memory, requirements.cpus, requirements.gpus)
+		firstOrError("none of the Amazon EC2 G4 instance types meet the requirements for memory:%d cpu:%f gpus:%d", requirements.memory, requirements.cpus, requirements.gpus)
 	if err != nil {
 		return "", err
 	}

+ 4 - 4
ecs/gpu_test.go

@@ -41,7 +41,7 @@ services:
                          kind: gpus
                          value: 1
 `,
-			want:    "p3.2xlarge",
+			want:    "g4dn.xlarge",
 			wantErr: false,
 		},
 		{
@@ -58,7 +58,7 @@ services:
                          kind: gpus
                          value: 4
 `,
-			want:    "p3.8xlarge",
+			want:    "g4dn.12xlarge",
 			wantErr: false,
 		},
 		{
@@ -76,7 +76,7 @@ services:
                          kind: gpus
                          value: 2
 `,
-			want:    "p3.16xlarge",
+			want:    "g4dn.metal",
 			wantErr: false,
 		},
 		{
@@ -95,7 +95,7 @@ services:
                          kind: gpus
                          value: 2
 `,
-			want:    "p3.8xlarge",
+			want:    "g4dn.12xlarge",
 			wantErr: false,
 		},
 	}

+ 15 - 1
ecs/iam.go

@@ -19,13 +19,14 @@ package ecs
 const (
 	ecsTaskExecutionPolicy = "arn:aws:iam::aws:policy/service-role/AmazonECSTaskExecutionRolePolicy"
 	ecrReadOnlyPolicy      = "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryReadOnly"
+	ecsEC2InstanceRole     = "arn:aws:iam::aws:policy/service-role/AmazonEC2ContainerServiceforEC2Role"
 
 	actionGetSecretValue = "secretsmanager:GetSecretValue"
 	actionGetParameters  = "ssm:GetParameters"
 	actionDecrypt        = "kms:Decrypt"
 )
 
-var assumeRolePolicyDocument = PolicyDocument{
+var ecsTaskAssumeRolePolicyDocument = PolicyDocument{
 	Version: "2012-10-17", // https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements_version.html
 	Statement: []PolicyStatement{
 		{
@@ -38,6 +39,19 @@ var assumeRolePolicyDocument = PolicyDocument{
 	},
 }
 
+// ec2InstanceAssumeRolePolicyDocument is the trust policy that allows the
+// ec2.amazonaws.com service principal to perform sts:AssumeRole.
+var ec2InstanceAssumeRolePolicyDocument = PolicyDocument{
+	Version: "2012-10-17", // https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements_version.html
+	Statement: []PolicyStatement{
+		{
+			Effect: "Allow",
+			Principal: PolicyPrincipal{
+				Service: "ec2.amazonaws.com",
+			},
+			Action: []string{"sts:AssumeRole"},
+		},
+	},
+}
+
 // PolicyDocument describes an IAM policy document
 // could alternatively depend on https://github.com/kubernetes-sigs/cluster-api-provider-aws/blob/master/cmd/clusterawsadm/api/iam/v1alpha1/types.go
 type PolicyDocument struct {

+ 32 - 1
ecs/sdk.go

@@ -18,10 +18,14 @@ package ecs
 
 import (
 	"context"
+	"encoding/json"
 	"fmt"
 	"strings"
 	"time"
 
+	"github.com/aws/aws-sdk-go/service/ssm"
+	"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
+
 	"github.com/docker/compose-cli/api/compose"
 	"github.com/docker/compose-cli/api/secrets"
 
@@ -56,6 +60,7 @@ type sdk struct {
 	IAM iamiface.IAMAPI
 	CF  cloudformationiface.CloudFormationAPI
 	SM  secretsmanageriface.SecretsManagerAPI
+	SSM ssmiface.SSMAPI
 }
 
 func newSDK(sess *session.Session) sdk {
@@ -71,6 +76,7 @@ func newSDK(sess *session.Session) sdk {
 		IAM: iam.New(sess),
 		CF:  cloudformation.New(sess),
 		SM:  secretsmanager.New(sess),
+		SSM: ssm.New(sess),
 	}
 }
 
@@ -182,7 +188,7 @@ func (s sdk) StackExists(ctx context.Context, name string) (bool, error) {
 		StackName: aws.String(name),
 	})
 	if err != nil {
-		if strings.HasPrefix(err.Error(), fmt.Sprintf("ValidationError: Stack with id %s does not exist", name)) {
+		if strings.HasPrefix(err.Error(), fmt.Sprintf("ValidationError: Stack with ID %s does not exist", name)) {
 			return false, nil
 		}
 		return false, nil
@@ -688,3 +694,28 @@ func (s sdk) WithVolumeSecurityGroups(ctx context.Context, id string, fn func(se
 	}
 	return nil
 }
+
+// GetParameter fetches the SSM parameter called name and returns the image ID
+// found in its value.
+//
+// The parameter value is decoded as a JSON document describing an AMI, and only
+// its "image_id" field is returned. An error is returned when the SSM call
+// fails, when the response carries no value, or when the value is not valid
+// JSON.
+func (s sdk) GetParameter(ctx context.Context, name string) (string, error) {
+	parameter, err := s.SSM.GetParameterWithContext(ctx, &ssm.GetParameterInput{
+		Name: aws.String(name),
+	})
+	if err != nil {
+		return "", err
+	}
+	// Parameter and Value are pointers in the SDK response; report a missing
+	// value as an error instead of dereferencing nil.
+	if parameter == nil || parameter.Parameter == nil || parameter.Parameter.Value == nil {
+		return "", fmt.Errorf("parameter %q has no value", name)
+	}
+
+	var ami struct {
+		SchemaVersion     int    `json:"schema_version"`
+		ImageName         string `json:"image_name"`
+		ImageID           string `json:"image_id"`
+		OS                string `json:"os"`
+		ECSRuntimeVersion string `json:"ecs_runtime_version"`
+		ECSAgentVersion   string `json:"ecs_agent_version"`
+	}
+	if err := json.Unmarshal([]byte(*parameter.Parameter.Value), &ami); err != nil {
+		return "", err
+	}
+
+	return ami.ImageID, nil
+}

+ 58 - 0
ecs/tags.go

@@ -0,0 +1,58 @@
+/*
+   Copyright 2020 Docker Compose CLI authors
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+package ecs
+
+import (
+	"github.com/awslabs/goformation/v4/cloudformation/tags"
+	"github.com/compose-spec/compose-go/types"
+	"github.com/docker/compose-cli/api/compose"
+)
+
+// projectTags builds the tag list for project-wide resources: a single
+// compose.ProjectTag entry whose value is the project name.
+func projectTags(project *types.Project) []tags.Tag {
+	projectTag := tags.Tag{Key: compose.ProjectTag, Value: project.Name}
+	return []tags.Tag{projectTag}
+}
+
+// serviceTags builds the tag list for resources that belong to one service: the
+// compose.ProjectTag entry (project name) followed by the compose.ServiceTag
+// entry (service name).
+func serviceTags(project *types.Project, service types.ServiceConfig) []tags.Tag {
+	result := make([]tags.Tag, 0, 2)
+	result = append(result, tags.Tag{Key: compose.ProjectTag, Value: project.Name})
+	result = append(result, tags.Tag{Key: compose.ServiceTag, Value: service.Name})
+	return result
+}
+
+// networkTags builds the tag list for resources that belong to one network: the
+// compose.ProjectTag entry (project name) followed by the compose.NetworkTag
+// entry (network name).
+func networkTags(project *types.Project, net types.NetworkConfig) []tags.Tag {
+	var (
+		projectTag = tags.Tag{Key: compose.ProjectTag, Value: project.Name}
+		networkTag = tags.Tag{Key: compose.NetworkTag, Value: net.Name}
+	)
+	return []tags.Tag{projectTag, networkTag}
+}