Selaa lähdekoodia

Use `WithContext` SDK APIs so we can implement cancelation

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof 5 vuotta sitten
vanhempi
sitoutus
b6be4a0ac3

+ 5 - 4
ecs/cmd/main/main.go

@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"fmt"
 
 	"github.com/docker/cli/cli-plugins/manager"
@@ -97,7 +98,7 @@ func ConvertCommand(clusteropts *clusterOptions, projectOpts *compose.ProjectOpt
 			if err != nil {
 				return err
 			}
-			template, err := client.Convert(project)
+			template, err := client.Convert(context.Background(), project)
 			if err != nil {
 				return err
 			}
@@ -123,7 +124,7 @@ func UpCommand(clusteropts *clusterOptions, projectOpts *compose.ProjectOptions)
 			if err != nil {
 				return err
 			}
-			return client.ComposeUp(project)
+			return client.ComposeUp(context.Background(), project)
 		}),
 	}
 	cmd.Flags().StringVar(&opts.loadBalancerArn, "load-balancer", "", "")
@@ -148,11 +149,11 @@ func DownCommand(clusteropts *clusterOptions, projectOpts *compose.ProjectOption
 				if err != nil {
 					return err
 				}
-				return client.ComposeDown(project.Name, opts.DeleteCluster)
+				return client.ComposeDown(context.Background(), project.Name, opts.DeleteCluster)
 			}
 			// project names passed as parameters
 			for _, name := range args {
-				err := client.ComposeDown(name, opts.DeleteCluster)
+				err := client.ComposeDown(context.Background(), name, opts.DeleteCluster)
 				if err != nil {
 					return err
 				}

+ 12 - 11
ecs/pkg/amazon/cloudformation.go

@@ -1,6 +1,7 @@
 package amazon
 
 import (
+	"context"
 	"fmt"
 	"strings"
 
@@ -15,14 +16,14 @@ import (
 	"github.com/docker/ecs-plugin/pkg/convert"
 )
 
-func (c client) Convert(project *compose.Project) (*cloudformation.Template, error) {
+func (c client) Convert(ctx context.Context, project *compose.Project) (*cloudformation.Template, error) {
 	template := cloudformation.NewTemplate()
-	vpc, err := c.api.GetDefaultVPC()
+	vpc, err := c.api.GetDefaultVPC(ctx)
 	if err != nil {
 		return nil, err
 	}
 
-	subnets, err := c.api.GetSubNets(vpc)
+	subnets, err := c.api.GetSubNets(ctx, vpc)
 	if err != nil {
 		return nil, err
 	}
@@ -54,7 +55,7 @@ func (c client) Convert(project *compose.Project) (*cloudformation.Template, err
 			return nil, err
 		}
 
-		role, err := c.GetEcsTaskExecutionRole(service)
+		role, err := c.GetEcsTaskExecutionRole(ctx, service)
 		if err != nil {
 			return nil, err
 		}
@@ -87,7 +88,7 @@ const ECSTaskExecutionPolicy = "arn:aws:iam::aws:policy/service-role/AmazonECSTa
 var defaultTaskExecutionRole string
 
 // GetEcsTaskExecutionRole retrieve the role ARN to apply for task execution
-func (c client) GetEcsTaskExecutionRole(spec types.ServiceConfig) (string, error) {
+func (c client) GetEcsTaskExecutionRole(ctx context.Context, spec types.ServiceConfig) (string, error) {
 	if arn, ok := spec.Extras["x-ecs-TaskExecutionRole"]; ok {
 		return arn.(string), nil
 	}
@@ -96,7 +97,7 @@ func (c client) GetEcsTaskExecutionRole(spec types.ServiceConfig) (string, error
 	}
 
 	logrus.Debug("Retrieve Task Execution Role")
-	entities, err := c.api.ListRolesForPolicy(ECSTaskExecutionPolicy)
+	entities, err := c.api.ListRolesForPolicy(ctx, ECSTaskExecutionPolicy)
 	if err != nil {
 		return "", err
 	}
@@ -107,7 +108,7 @@ func (c client) GetEcsTaskExecutionRole(spec types.ServiceConfig) (string, error
 		return "", fmt.Errorf("multiple Roles are attached to AmazonECSTaskExecutionRole Policy, please provide an explicit task execution role")
 	}
 
-	arn, err := c.api.GetRoleArn(entities[0])
+	arn, err := c.api.GetRoleArn(ctx, entities[0])
 	if err != nil {
 		return "", err
 	}
@@ -116,8 +117,8 @@ func (c client) GetEcsTaskExecutionRole(spec types.ServiceConfig) (string, error
 }
 
 type convertAPI interface {
-	GetDefaultVPC() (string, error)
-	GetSubNets(vpcID string) ([]string, error)
-	ListRolesForPolicy(policy string) ([]string, error)
-	GetRoleArn(name string) (string, error)
+	GetDefaultVPC(ctx context.Context) (string, error)
+	GetSubNets(ctx context.Context, vpcID string) ([]string, error)
+	ListRolesForPolicy(ctx context.Context, policy string) ([]string, error)
+	GetRoleArn(ctx context.Context, name string) (string, error)
 }

+ 6 - 5
ecs/pkg/amazon/down.go

@@ -1,11 +1,12 @@
 package amazon
 
 import (
+	"context"
 	"fmt"
 )
 
-func (c *client) ComposeDown(projectName string, deleteCluster bool) error {
-	err := c.api.DeleteStack(projectName)
+func (c *client) ComposeDown(ctx context.Context, projectName string, deleteCluster bool) error {
+	err := c.api.DeleteStack(ctx, projectName)
 	if err != nil {
 		return err
 	}
@@ -16,7 +17,7 @@ func (c *client) ComposeDown(projectName string, deleteCluster bool) error {
 	}
 
 	fmt.Printf("Delete cluster %s", c.Cluster)
-	if err = c.api.DeleteCluster(c.Cluster); err != nil {
+	if err = c.api.DeleteCluster(ctx, c.Cluster); err != nil {
 		return err
 	}
 	fmt.Printf("... done. \n")
@@ -24,6 +25,6 @@ func (c *client) ComposeDown(projectName string, deleteCluster bool) error {
 }
 
 type downAPI interface {
-	DeleteStack(name string) error
-	DeleteCluster(name string) error
+	DeleteStack(ctx context.Context, name string) error
+	DeleteCluster(ctx context.Context, name string) error
 }

+ 8 - 6
ecs/pkg/amazon/down_test.go

@@ -1,6 +1,7 @@
 package amazon
 
 import (
+	"context"
 	"testing"
 
 	"github.com/docker/ecs-plugin/pkg/amazon/mock"
@@ -16,11 +17,11 @@ func TestDownDontDeleteCluster(t *testing.T) {
 		Region:  "region",
 		api:     m,
 	}
-
+	ctx := context.TODO()
 	recorder := m.EXPECT()
-	recorder.DeleteStack("test_project").Return(nil).Times(1)
+	recorder.DeleteStack(ctx, "test_project").Return(nil).Times(1)
 
-	c.ComposeDown("test_project", false)
+	c.ComposeDown(ctx, "test_project", false)
 }
 
 func TestDownDeleteCluster(t *testing.T) {
@@ -33,9 +34,10 @@ func TestDownDeleteCluster(t *testing.T) {
 		api:     m,
 	}
 
+	ctx := context.TODO()
 	recorder := m.EXPECT()
-	recorder.DeleteStack("test_project").Return(nil).Times(1)
-	recorder.DeleteCluster("test_cluster").Return(nil).Times(1)
+	recorder.DeleteStack(ctx, "test_project").Return(nil).Times(1)
+	recorder.DeleteCluster(ctx, "test_cluster").Return(nil).Times(1)
 
-	c.ComposeDown("test_project", true)
+	c.ComposeDown(ctx, "test_project", true)
 }

+ 45 - 44
ecs/pkg/amazon/mock/api.go

@@ -5,6 +5,7 @@
 package mock
 
 import (
+	context "context"
 	cloudformation "github.com/awslabs/goformation/v4/cloudformation"
 	gomock "github.com/golang/mock/gomock"
 	reflect "reflect"
@@ -34,162 +35,162 @@ func (m *MockAPI) EXPECT() *MockAPIMockRecorder {
 }
 
 // ClusterExists mocks base method
-func (m *MockAPI) ClusterExists(arg0 string) (bool, error) {
+func (m *MockAPI) ClusterExists(arg0 context.Context, arg1 string) (bool, error) {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "ClusterExists", arg0)
+	ret := m.ctrl.Call(m, "ClusterExists", arg0, arg1)
 	ret0, _ := ret[0].(bool)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 }
 
 // ClusterExists indicates an expected call of ClusterExists
-func (mr *MockAPIMockRecorder) ClusterExists(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) ClusterExists(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterExists", reflect.TypeOf((*MockAPI)(nil).ClusterExists), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterExists", reflect.TypeOf((*MockAPI)(nil).ClusterExists), arg0, arg1)
 }
 
 // CreateCluster mocks base method
-func (m *MockAPI) CreateCluster(arg0 string) (string, error) {
+func (m *MockAPI) CreateCluster(arg0 context.Context, arg1 string) (string, error) {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "CreateCluster", arg0)
+	ret := m.ctrl.Call(m, "CreateCluster", arg0, arg1)
 	ret0, _ := ret[0].(string)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 }
 
 // CreateCluster indicates an expected call of CreateCluster
-func (mr *MockAPIMockRecorder) CreateCluster(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) CreateCluster(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCluster", reflect.TypeOf((*MockAPI)(nil).CreateCluster), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCluster", reflect.TypeOf((*MockAPI)(nil).CreateCluster), arg0, arg1)
 }
 
 // CreateStack mocks base method
-func (m *MockAPI) CreateStack(arg0 string, arg1 *cloudformation.Template) error {
+func (m *MockAPI) CreateStack(arg0 context.Context, arg1 string, arg2 *cloudformation.Template) error {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "CreateStack", arg0, arg1)
+	ret := m.ctrl.Call(m, "CreateStack", arg0, arg1, arg2)
 	ret0, _ := ret[0].(error)
 	return ret0
 }
 
 // CreateStack indicates an expected call of CreateStack
-func (mr *MockAPIMockRecorder) CreateStack(arg0, arg1 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) CreateStack(arg0, arg1, arg2 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateStack", reflect.TypeOf((*MockAPI)(nil).CreateStack), arg0, arg1)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateStack", reflect.TypeOf((*MockAPI)(nil).CreateStack), arg0, arg1, arg2)
 }
 
 // DeleteCluster mocks base method
-func (m *MockAPI) DeleteCluster(arg0 string) error {
+func (m *MockAPI) DeleteCluster(arg0 context.Context, arg1 string) error {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "DeleteCluster", arg0)
+	ret := m.ctrl.Call(m, "DeleteCluster", arg0, arg1)
 	ret0, _ := ret[0].(error)
 	return ret0
 }
 
 // DeleteCluster indicates an expected call of DeleteCluster
-func (mr *MockAPIMockRecorder) DeleteCluster(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) DeleteCluster(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCluster", reflect.TypeOf((*MockAPI)(nil).DeleteCluster), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCluster", reflect.TypeOf((*MockAPI)(nil).DeleteCluster), arg0, arg1)
 }
 
 // DeleteStack mocks base method
-func (m *MockAPI) DeleteStack(arg0 string) error {
+func (m *MockAPI) DeleteStack(arg0 context.Context, arg1 string) error {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "DeleteStack", arg0)
+	ret := m.ctrl.Call(m, "DeleteStack", arg0, arg1)
 	ret0, _ := ret[0].(error)
 	return ret0
 }
 
 // DeleteStack indicates an expected call of DeleteStack
-func (mr *MockAPIMockRecorder) DeleteStack(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) DeleteStack(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStack", reflect.TypeOf((*MockAPI)(nil).DeleteStack), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStack", reflect.TypeOf((*MockAPI)(nil).DeleteStack), arg0, arg1)
 }
 
 // DescribeStackEvents mocks base method
-func (m *MockAPI) DescribeStackEvents(arg0 string) error {
+func (m *MockAPI) DescribeStackEvents(arg0 context.Context, arg1 string) error {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "DescribeStackEvents", arg0)
+	ret := m.ctrl.Call(m, "DescribeStackEvents", arg0, arg1)
 	ret0, _ := ret[0].(error)
 	return ret0
 }
 
 // DescribeStackEvents indicates an expected call of DescribeStackEvents
-func (mr *MockAPIMockRecorder) DescribeStackEvents(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) DescribeStackEvents(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeStackEvents", reflect.TypeOf((*MockAPI)(nil).DescribeStackEvents), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeStackEvents", reflect.TypeOf((*MockAPI)(nil).DescribeStackEvents), arg0, arg1)
 }
 
 // GetDefaultVPC mocks base method
-func (m *MockAPI) GetDefaultVPC() (string, error) {
+func (m *MockAPI) GetDefaultVPC(arg0 context.Context) (string, error) {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "GetDefaultVPC")
+	ret := m.ctrl.Call(m, "GetDefaultVPC", arg0)
 	ret0, _ := ret[0].(string)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 }
 
 // GetDefaultVPC indicates an expected call of GetDefaultVPC
-func (mr *MockAPIMockRecorder) GetDefaultVPC() *gomock.Call {
+func (mr *MockAPIMockRecorder) GetDefaultVPC(arg0 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultVPC", reflect.TypeOf((*MockAPI)(nil).GetDefaultVPC))
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultVPC", reflect.TypeOf((*MockAPI)(nil).GetDefaultVPC), arg0)
 }
 
 // GetRoleArn mocks base method
-func (m *MockAPI) GetRoleArn(arg0 string) (string, error) {
+func (m *MockAPI) GetRoleArn(arg0 context.Context, arg1 string) (string, error) {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "GetRoleArn", arg0)
+	ret := m.ctrl.Call(m, "GetRoleArn", arg0, arg1)
 	ret0, _ := ret[0].(string)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 }
 
 // GetRoleArn indicates an expected call of GetRoleArn
-func (mr *MockAPIMockRecorder) GetRoleArn(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) GetRoleArn(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoleArn", reflect.TypeOf((*MockAPI)(nil).GetRoleArn), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoleArn", reflect.TypeOf((*MockAPI)(nil).GetRoleArn), arg0, arg1)
 }
 
 // GetSubNets mocks base method
-func (m *MockAPI) GetSubNets(arg0 string) ([]string, error) {
+func (m *MockAPI) GetSubNets(arg0 context.Context, arg1 string) ([]string, error) {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "GetSubNets", arg0)
+	ret := m.ctrl.Call(m, "GetSubNets", arg0, arg1)
 	ret0, _ := ret[0].([]string)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 }
 
 // GetSubNets indicates an expected call of GetSubNets
-func (mr *MockAPIMockRecorder) GetSubNets(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) GetSubNets(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubNets", reflect.TypeOf((*MockAPI)(nil).GetSubNets), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubNets", reflect.TypeOf((*MockAPI)(nil).GetSubNets), arg0, arg1)
 }
 
 // ListRolesForPolicy mocks base method
-func (m *MockAPI) ListRolesForPolicy(arg0 string) ([]string, error) {
+func (m *MockAPI) ListRolesForPolicy(arg0 context.Context, arg1 string) ([]string, error) {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "ListRolesForPolicy", arg0)
+	ret := m.ctrl.Call(m, "ListRolesForPolicy", arg0, arg1)
 	ret0, _ := ret[0].([]string)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 }
 
 // ListRolesForPolicy indicates an expected call of ListRolesForPolicy
-func (mr *MockAPIMockRecorder) ListRolesForPolicy(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) ListRolesForPolicy(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListRolesForPolicy", reflect.TypeOf((*MockAPI)(nil).ListRolesForPolicy), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListRolesForPolicy", reflect.TypeOf((*MockAPI)(nil).ListRolesForPolicy), arg0, arg1)
 }
 
 // StackExists mocks base method
-func (m *MockAPI) StackExists(arg0 string) (bool, error) {
+func (m *MockAPI) StackExists(arg0 context.Context, arg1 string) (bool, error) {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "StackExists", arg0)
+	ret := m.ctrl.Call(m, "StackExists", arg0, arg1)
 	ret0, _ := ret[0].(bool)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 }
 
 // StackExists indicates an expected call of StackExists
-func (mr *MockAPIMockRecorder) StackExists(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) StackExists(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StackExists", reflect.TypeOf((*MockAPI)(nil).StackExists), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StackExists", reflect.TypeOf((*MockAPI)(nil).StackExists), arg0, arg1)
 }

+ 23 - 22
ecs/pkg/amazon/sdk.go

@@ -1,6 +1,7 @@
 package amazon
 
 import (
+	"context"
 	"fmt"
 
 	"github.com/aws/aws-sdk-go/aws"
@@ -42,9 +43,9 @@ func NewAPI(sess *session.Session) API {
 	}
 }
 
-func (s sdk) ClusterExists(name string) (bool, error) {
+func (s sdk) ClusterExists(ctx context.Context, name string) (bool, error) {
 	logrus.Debug("Check if cluster was already created: ", name)
-	clusters, err := s.ECS.DescribeClusters(&ecs.DescribeClustersInput{
+	clusters, err := s.ECS.DescribeClustersWithContext(aws.Context(ctx), &ecs.DescribeClustersInput{
 		Clusters: []*string{aws.String(name)},
 	})
 	if err != nil {
@@ -53,18 +54,18 @@ func (s sdk) ClusterExists(name string) (bool, error) {
 	return len(clusters.Clusters) > 0, nil
 }
 
-func (s sdk) CreateCluster(name string) (string, error) {
+func (s sdk) CreateCluster(ctx context.Context, name string) (string, error) {
 	logrus.Debug("Create cluster ", name)
-	response, err := s.ECS.CreateCluster(&ecs.CreateClusterInput{ClusterName: aws.String(name)})
+	response, err := s.ECS.CreateClusterWithContext(aws.Context(ctx), &ecs.CreateClusterInput{ClusterName: aws.String(name)})
 	if err != nil {
 		return "", err
 	}
 	return *response.Cluster.Status, nil
 }
 
-func (s sdk) DeleteCluster(name string) error {
+func (s sdk) DeleteCluster(ctx context.Context, name string) error {
 	logrus.Debug("Delete cluster ", name)
-	response, err := s.ECS.DeleteCluster(&ecs.DeleteClusterInput{Cluster: aws.String(name)})
+	response, err := s.ECS.DeleteClusterWithContext(aws.Context(ctx), &ecs.DeleteClusterInput{Cluster: aws.String(name)})
 	if err != nil {
 		return err
 	}
@@ -74,9 +75,9 @@ func (s sdk) DeleteCluster(name string) error {
 	return fmt.Errorf("Failed to delete cluster, status: %s" + *response.Cluster.Status)
 }
 
-func (s sdk) GetDefaultVPC() (string, error) {
+func (s sdk) GetDefaultVPC(ctx context.Context) (string, error) {
 	logrus.Debug("Retrieve default VPC")
-	vpcs, err := s.EC2.DescribeVpcs(&ec2.DescribeVpcsInput{
+	vpcs, err := s.EC2.DescribeVpcsWithContext(aws.Context(ctx), &ec2.DescribeVpcsInput{
 		Filters: []*ec2.Filter{
 			{
 				Name:   aws.String("isDefault"),
@@ -93,9 +94,9 @@ func (s sdk) GetDefaultVPC() (string, error) {
 	return *vpcs.Vpcs[0].VpcId, nil
 }
 
-func (s sdk) GetSubNets(vpcID string) ([]string, error) {
+func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]string, error) {
 	logrus.Debug("Retrieve SubNets")
-	subnets, err := s.EC2.DescribeSubnets(&ec2.DescribeSubnetsInput{
+	subnets, err := s.EC2.DescribeSubnetsWithContext(aws.Context(ctx), &ec2.DescribeSubnetsInput{
 		DryRun: nil,
 		Filters: []*ec2.Filter{
 			{
@@ -119,8 +120,8 @@ func (s sdk) GetSubNets(vpcID string) ([]string, error) {
 	return ids, nil
 }
 
-func (s sdk) ListRolesForPolicy(policy string) ([]string, error) {
-	entities, err := s.IAM.ListEntitiesForPolicy(&iam.ListEntitiesForPolicyInput{
+func (s sdk) ListRolesForPolicy(ctx context.Context, policy string) ([]string, error) {
+	entities, err := s.IAM.ListEntitiesForPolicyWithContext(aws.Context(ctx), &iam.ListEntitiesForPolicyInput{
 		EntityFilter: aws.String("Role"),
 		PolicyArn:    aws.String(policy),
 	})
@@ -134,8 +135,8 @@ func (s sdk) ListRolesForPolicy(policy string) ([]string, error) {
 	return roles, nil
 }
 
-func (s sdk) GetRoleArn(name string) (string, error) {
-	role, err := s.IAM.GetRole(&iam.GetRoleInput{
+func (s sdk) GetRoleArn(ctx context.Context, name string) (string, error) {
+	role, err := s.IAM.GetRoleWithContext(aws.Context(ctx), &iam.GetRoleInput{
 		RoleName: aws.String(name),
 	})
 	if err != nil {
@@ -144,8 +145,8 @@ func (s sdk) GetRoleArn(name string) (string, error) {
 	return *role.Role.Arn, nil
 }
 
-func (s sdk) StackExists(name string) (bool, error) {
-	stacks, err := s.CF.DescribeStacks(&cloudformation.DescribeStacksInput{
+func (s sdk) StackExists(ctx context.Context, name string) (bool, error) {
+	stacks, err := s.CF.DescribeStacksWithContext(aws.Context(ctx), &cloudformation.DescribeStacksInput{
 		StackName: aws.String(name),
 	})
 	if err != nil {
@@ -155,14 +156,14 @@ func (s sdk) StackExists(name string) (bool, error) {
 	return len(stacks.Stacks) > 0, nil
 }
 
-func (s sdk) CreateStack(name string, template *cf.Template) error {
+func (s sdk) CreateStack(ctx context.Context, name string, template *cf.Template) error {
 	logrus.Debug("Create CloudFormation stack")
 	json, err := template.JSON()
 	if err != nil {
 		return err
 	}
 
-	_, err = s.CF.CreateStack(&cloudformation.CreateStackInput{
+	_, err = s.CF.CreateStackWithContext(aws.Context(ctx), &cloudformation.CreateStackInput{
 		OnFailure:        aws.String("DELETE"),
 		StackName:        aws.String(name),
 		TemplateBody:     aws.String(string(json)),
@@ -171,17 +172,17 @@ func (s sdk) CreateStack(name string, template *cf.Template) error {
 	return err
 }
 
-func (s sdk) DescribeStackEvents(name string) error {
+func (s sdk) DescribeStackEvents(ctx context.Context, name string) error {
 	// Fixme implement Paginator on Events and return as a chan(events)
-	_, err := s.CF.DescribeStackEvents(&cloudformation.DescribeStackEventsInput{
+	_, err := s.CF.DescribeStackEventsWithContext(aws.Context(ctx), &cloudformation.DescribeStackEventsInput{
 		StackName: aws.String(name),
 	})
 	return err
 }
 
-func (s sdk) DeleteStack(name string) error {
+func (s sdk) DeleteStack(ctx context.Context, name string) error {
 	logrus.Debug("Delete CloudFormation stack")
-	_, err := s.CF.DeleteStack(&cloudformation.DeleteStackInput{
+	_, err := s.CF.DeleteStackWithContext(aws.Context(ctx), &cloudformation.DeleteStackInput{
 		StackName: aws.String(name),
 	})
 	return err

+ 13 - 12
ecs/pkg/amazon/up.go

@@ -1,21 +1,22 @@
 package amazon
 
 import (
+	"context"
 	"fmt"
 
 	"github.com/awslabs/goformation/v4/cloudformation"
 	"github.com/docker/ecs-plugin/pkg/compose"
 )
 
-func (c *client) ComposeUp(project *compose.Project) error {
-	ok, err := c.api.ClusterExists(c.Cluster)
+func (c *client) ComposeUp(ctx context.Context, project *compose.Project) error {
+	ok, err := c.api.ClusterExists(ctx, c.Cluster)
 	if err != nil {
 		return err
 	}
 	if !ok {
-		c.api.CreateCluster(c.Cluster)
+		c.api.CreateCluster(ctx, c.Cluster)
 	}
-	update, err := c.api.StackExists(project.Name)
+	update, err := c.api.StackExists(ctx, project.Name)
 	if err != nil {
 		return err
 	}
@@ -23,17 +24,17 @@ func (c *client) ComposeUp(project *compose.Project) error {
 		return fmt.Errorf("we do not (yet) support updating an existing CloudFormation stack")
 	}
 
-	template, err := c.Convert(project)
+	template, err := c.Convert(ctx, project)
 	if err != nil {
 		return err
 	}
 
-	err = c.api.CreateStack(project.Name, template)
+	err = c.api.CreateStack(ctx, project.Name, template)
 	if err != nil {
 		return err
 	}
 
-	err = c.api.DescribeStackEvents(project.Name)
+	err = c.api.DescribeStackEvents(ctx, project.Name)
 	if err != nil {
 		return err
 	}
@@ -43,9 +44,9 @@ func (c *client) ComposeUp(project *compose.Project) error {
 }
 
 type upAPI interface {
-	ClusterExists(name string) (bool, error)
-	CreateCluster(name string) (string, error)
-	StackExists(name string) (bool, error)
-	CreateStack(name string, template *cloudformation.Template) error
-	DescribeStackEvents(stack string) error
+	ClusterExists(ctx context.Context, name string) (bool, error)
+	CreateCluster(ctx context.Context, name string) (string, error)
+	StackExists(ctx context.Context, name string) (bool, error)
+	CreateStack(ctx context.Context, name string, template *cloudformation.Template) error
+	DescribeStackEvents(ctx context.Context, stack string) error
 }

+ 8 - 4
ecs/pkg/compose/api.go

@@ -1,9 +1,13 @@
 package compose
 
-import "github.com/awslabs/goformation/v4/cloudformation"
+import (
+	"context"
+
+	"github.com/awslabs/goformation/v4/cloudformation"
+)
 
 type API interface {
-	Convert(project *Project) (*cloudformation.Template, error)
-	ComposeUp(project *Project) error
-	ComposeDown(projectName string, deleteCluster bool) error
+	Convert(ctx context.Context, project *Project) (*cloudformation.Template, error)
+	ComposeUp(ctx context.Context, project *Project) error
+	ComposeDown(ctx context.Context, projectName string, deleteCluster bool) error
 }