Преглед изворни кода

Use `WithContext` SDK APIs so we can implement cancelation

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof пре 5 година
родитељ
комит
b6be4a0ac3

+ 5 - 4
ecs/cmd/main/main.go

@@ -1,6 +1,7 @@
 package main
 package main
 
 
 import (
 import (
+	"context"
 	"fmt"
 	"fmt"
 
 
 	"github.com/docker/cli/cli-plugins/manager"
 	"github.com/docker/cli/cli-plugins/manager"
@@ -97,7 +98,7 @@ func ConvertCommand(clusteropts *clusterOptions, projectOpts *compose.ProjectOpt
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
-			template, err := client.Convert(project)
+			template, err := client.Convert(context.Background(), project)
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
@@ -123,7 +124,7 @@ func UpCommand(clusteropts *clusterOptions, projectOpts *compose.ProjectOptions)
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
-			return client.ComposeUp(project)
+			return client.ComposeUp(context.Background(), project)
 		}),
 		}),
 	}
 	}
 	cmd.Flags().StringVar(&opts.loadBalancerArn, "load-balancer", "", "")
 	cmd.Flags().StringVar(&opts.loadBalancerArn, "load-balancer", "", "")
@@ -148,11 +149,11 @@ func DownCommand(clusteropts *clusterOptions, projectOpts *compose.ProjectOption
 				if err != nil {
 				if err != nil {
 					return err
 					return err
 				}
 				}
-				return client.ComposeDown(project.Name, opts.DeleteCluster)
+				return client.ComposeDown(context.Background(), project.Name, opts.DeleteCluster)
 			}
 			}
 			// project names passed as parameters
 			// project names passed as parameters
 			for _, name := range args {
 			for _, name := range args {
-				err := client.ComposeDown(name, opts.DeleteCluster)
+				err := client.ComposeDown(context.Background(), name, opts.DeleteCluster)
 				if err != nil {
 				if err != nil {
 					return err
 					return err
 				}
 				}

+ 12 - 11
ecs/pkg/amazon/cloudformation.go

@@ -1,6 +1,7 @@
 package amazon
 package amazon
 
 
 import (
 import (
+	"context"
 	"fmt"
 	"fmt"
 	"strings"
 	"strings"
 
 
@@ -15,14 +16,14 @@ import (
 	"github.com/docker/ecs-plugin/pkg/convert"
 	"github.com/docker/ecs-plugin/pkg/convert"
 )
 )
 
 
-func (c client) Convert(project *compose.Project) (*cloudformation.Template, error) {
+func (c client) Convert(ctx context.Context, project *compose.Project) (*cloudformation.Template, error) {
 	template := cloudformation.NewTemplate()
 	template := cloudformation.NewTemplate()
-	vpc, err := c.api.GetDefaultVPC()
+	vpc, err := c.api.GetDefaultVPC(ctx)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	subnets, err := c.api.GetSubNets(vpc)
+	subnets, err := c.api.GetSubNets(ctx, vpc)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -54,7 +55,7 @@ func (c client) Convert(project *compose.Project) (*cloudformation.Template, err
 			return nil, err
 			return nil, err
 		}
 		}
 
 
-		role, err := c.GetEcsTaskExecutionRole(service)
+		role, err := c.GetEcsTaskExecutionRole(ctx, service)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -87,7 +88,7 @@ const ECSTaskExecutionPolicy = "arn:aws:iam::aws:policy/service-role/AmazonECSTa
 var defaultTaskExecutionRole string
 var defaultTaskExecutionRole string
 
 
 // GetEcsTaskExecutionRole retrieve the role ARN to apply for task execution
 // GetEcsTaskExecutionRole retrieve the role ARN to apply for task execution
-func (c client) GetEcsTaskExecutionRole(spec types.ServiceConfig) (string, error) {
+func (c client) GetEcsTaskExecutionRole(ctx context.Context, spec types.ServiceConfig) (string, error) {
 	if arn, ok := spec.Extras["x-ecs-TaskExecutionRole"]; ok {
 	if arn, ok := spec.Extras["x-ecs-TaskExecutionRole"]; ok {
 		return arn.(string), nil
 		return arn.(string), nil
 	}
 	}
@@ -96,7 +97,7 @@ func (c client) GetEcsTaskExecutionRole(spec types.ServiceConfig) (string, error
 	}
 	}
 
 
 	logrus.Debug("Retrieve Task Execution Role")
 	logrus.Debug("Retrieve Task Execution Role")
-	entities, err := c.api.ListRolesForPolicy(ECSTaskExecutionPolicy)
+	entities, err := c.api.ListRolesForPolicy(ctx, ECSTaskExecutionPolicy)
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
@@ -107,7 +108,7 @@ func (c client) GetEcsTaskExecutionRole(spec types.ServiceConfig) (string, error
 		return "", fmt.Errorf("multiple Roles are attached to AmazonECSTaskExecutionRole Policy, please provide an explicit task execution role")
 		return "", fmt.Errorf("multiple Roles are attached to AmazonECSTaskExecutionRole Policy, please provide an explicit task execution role")
 	}
 	}
 
 
-	arn, err := c.api.GetRoleArn(entities[0])
+	arn, err := c.api.GetRoleArn(ctx, entities[0])
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
@@ -116,8 +117,8 @@ func (c client) GetEcsTaskExecutionRole(spec types.ServiceConfig) (string, error
 }
 }
 
 
 type convertAPI interface {
 type convertAPI interface {
-	GetDefaultVPC() (string, error)
-	GetSubNets(vpcID string) ([]string, error)
-	ListRolesForPolicy(policy string) ([]string, error)
-	GetRoleArn(name string) (string, error)
+	GetDefaultVPC(ctx context.Context) (string, error)
+	GetSubNets(ctx context.Context, vpcID string) ([]string, error)
+	ListRolesForPolicy(ctx context.Context, policy string) ([]string, error)
+	GetRoleArn(ctx context.Context, name string) (string, error)
 }
 }

+ 6 - 5
ecs/pkg/amazon/down.go

@@ -1,11 +1,12 @@
 package amazon
 package amazon
 
 
 import (
 import (
+	"context"
 	"fmt"
 	"fmt"
 )
 )
 
 
-func (c *client) ComposeDown(projectName string, deleteCluster bool) error {
-	err := c.api.DeleteStack(projectName)
+func (c *client) ComposeDown(ctx context.Context, projectName string, deleteCluster bool) error {
+	err := c.api.DeleteStack(ctx, projectName)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -16,7 +17,7 @@ func (c *client) ComposeDown(projectName string, deleteCluster bool) error {
 	}
 	}
 
 
 	fmt.Printf("Delete cluster %s", c.Cluster)
 	fmt.Printf("Delete cluster %s", c.Cluster)
-	if err = c.api.DeleteCluster(c.Cluster); err != nil {
+	if err = c.api.DeleteCluster(ctx, c.Cluster); err != nil {
 		return err
 		return err
 	}
 	}
 	fmt.Printf("... done. \n")
 	fmt.Printf("... done. \n")
@@ -24,6 +25,6 @@ func (c *client) ComposeDown(projectName string, deleteCluster bool) error {
 }
 }
 
 
 type downAPI interface {
 type downAPI interface {
-	DeleteStack(name string) error
-	DeleteCluster(name string) error
+	DeleteStack(ctx context.Context, name string) error
+	DeleteCluster(ctx context.Context, name string) error
 }
 }

+ 8 - 6
ecs/pkg/amazon/down_test.go

@@ -1,6 +1,7 @@
 package amazon
 package amazon
 
 
 import (
 import (
+	"context"
 	"testing"
 	"testing"
 
 
 	"github.com/docker/ecs-plugin/pkg/amazon/mock"
 	"github.com/docker/ecs-plugin/pkg/amazon/mock"
@@ -16,11 +17,11 @@ func TestDownDontDeleteCluster(t *testing.T) {
 		Region:  "region",
 		Region:  "region",
 		api:     m,
 		api:     m,
 	}
 	}
-
+	ctx := context.TODO()
 	recorder := m.EXPECT()
 	recorder := m.EXPECT()
-	recorder.DeleteStack("test_project").Return(nil).Times(1)
+	recorder.DeleteStack(ctx, "test_project").Return(nil).Times(1)
 
 
-	c.ComposeDown("test_project", false)
+	c.ComposeDown(ctx, "test_project", false)
 }
 }
 
 
 func TestDownDeleteCluster(t *testing.T) {
 func TestDownDeleteCluster(t *testing.T) {
@@ -33,9 +34,10 @@ func TestDownDeleteCluster(t *testing.T) {
 		api:     m,
 		api:     m,
 	}
 	}
 
 
+	ctx := context.TODO()
 	recorder := m.EXPECT()
 	recorder := m.EXPECT()
-	recorder.DeleteStack("test_project").Return(nil).Times(1)
-	recorder.DeleteCluster("test_cluster").Return(nil).Times(1)
+	recorder.DeleteStack(ctx, "test_project").Return(nil).Times(1)
+	recorder.DeleteCluster(ctx, "test_cluster").Return(nil).Times(1)
 
 
-	c.ComposeDown("test_project", true)
+	c.ComposeDown(ctx, "test_project", true)
 }
 }

+ 45 - 44
ecs/pkg/amazon/mock/api.go

@@ -5,6 +5,7 @@
 package mock
 package mock
 
 
 import (
 import (
+	context "context"
 	cloudformation "github.com/awslabs/goformation/v4/cloudformation"
 	cloudformation "github.com/awslabs/goformation/v4/cloudformation"
 	gomock "github.com/golang/mock/gomock"
 	gomock "github.com/golang/mock/gomock"
 	reflect "reflect"
 	reflect "reflect"
@@ -34,162 +35,162 @@ func (m *MockAPI) EXPECT() *MockAPIMockRecorder {
 }
 }
 
 
 // ClusterExists mocks base method
 // ClusterExists mocks base method
-func (m *MockAPI) ClusterExists(arg0 string) (bool, error) {
+func (m *MockAPI) ClusterExists(arg0 context.Context, arg1 string) (bool, error) {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "ClusterExists", arg0)
+	ret := m.ctrl.Call(m, "ClusterExists", arg0, arg1)
 	ret0, _ := ret[0].(bool)
 	ret0, _ := ret[0].(bool)
 	ret1, _ := ret[1].(error)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 	return ret0, ret1
 }
 }
 
 
 // ClusterExists indicates an expected call of ClusterExists
 // ClusterExists indicates an expected call of ClusterExists
-func (mr *MockAPIMockRecorder) ClusterExists(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) ClusterExists(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterExists", reflect.TypeOf((*MockAPI)(nil).ClusterExists), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterExists", reflect.TypeOf((*MockAPI)(nil).ClusterExists), arg0, arg1)
 }
 }
 
 
 // CreateCluster mocks base method
 // CreateCluster mocks base method
-func (m *MockAPI) CreateCluster(arg0 string) (string, error) {
+func (m *MockAPI) CreateCluster(arg0 context.Context, arg1 string) (string, error) {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "CreateCluster", arg0)
+	ret := m.ctrl.Call(m, "CreateCluster", arg0, arg1)
 	ret0, _ := ret[0].(string)
 	ret0, _ := ret[0].(string)
 	ret1, _ := ret[1].(error)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 	return ret0, ret1
 }
 }
 
 
 // CreateCluster indicates an expected call of CreateCluster
 // CreateCluster indicates an expected call of CreateCluster
-func (mr *MockAPIMockRecorder) CreateCluster(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) CreateCluster(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCluster", reflect.TypeOf((*MockAPI)(nil).CreateCluster), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCluster", reflect.TypeOf((*MockAPI)(nil).CreateCluster), arg0, arg1)
 }
 }
 
 
 // CreateStack mocks base method
 // CreateStack mocks base method
-func (m *MockAPI) CreateStack(arg0 string, arg1 *cloudformation.Template) error {
+func (m *MockAPI) CreateStack(arg0 context.Context, arg1 string, arg2 *cloudformation.Template) error {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "CreateStack", arg0, arg1)
+	ret := m.ctrl.Call(m, "CreateStack", arg0, arg1, arg2)
 	ret0, _ := ret[0].(error)
 	ret0, _ := ret[0].(error)
 	return ret0
 	return ret0
 }
 }
 
 
 // CreateStack indicates an expected call of CreateStack
 // CreateStack indicates an expected call of CreateStack
-func (mr *MockAPIMockRecorder) CreateStack(arg0, arg1 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) CreateStack(arg0, arg1, arg2 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateStack", reflect.TypeOf((*MockAPI)(nil).CreateStack), arg0, arg1)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateStack", reflect.TypeOf((*MockAPI)(nil).CreateStack), arg0, arg1, arg2)
 }
 }
 
 
 // DeleteCluster mocks base method
 // DeleteCluster mocks base method
-func (m *MockAPI) DeleteCluster(arg0 string) error {
+func (m *MockAPI) DeleteCluster(arg0 context.Context, arg1 string) error {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "DeleteCluster", arg0)
+	ret := m.ctrl.Call(m, "DeleteCluster", arg0, arg1)
 	ret0, _ := ret[0].(error)
 	ret0, _ := ret[0].(error)
 	return ret0
 	return ret0
 }
 }
 
 
 // DeleteCluster indicates an expected call of DeleteCluster
 // DeleteCluster indicates an expected call of DeleteCluster
-func (mr *MockAPIMockRecorder) DeleteCluster(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) DeleteCluster(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCluster", reflect.TypeOf((*MockAPI)(nil).DeleteCluster), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCluster", reflect.TypeOf((*MockAPI)(nil).DeleteCluster), arg0, arg1)
 }
 }
 
 
 // DeleteStack mocks base method
 // DeleteStack mocks base method
-func (m *MockAPI) DeleteStack(arg0 string) error {
+func (m *MockAPI) DeleteStack(arg0 context.Context, arg1 string) error {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "DeleteStack", arg0)
+	ret := m.ctrl.Call(m, "DeleteStack", arg0, arg1)
 	ret0, _ := ret[0].(error)
 	ret0, _ := ret[0].(error)
 	return ret0
 	return ret0
 }
 }
 
 
 // DeleteStack indicates an expected call of DeleteStack
 // DeleteStack indicates an expected call of DeleteStack
-func (mr *MockAPIMockRecorder) DeleteStack(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) DeleteStack(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStack", reflect.TypeOf((*MockAPI)(nil).DeleteStack), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStack", reflect.TypeOf((*MockAPI)(nil).DeleteStack), arg0, arg1)
 }
 }
 
 
 // DescribeStackEvents mocks base method
 // DescribeStackEvents mocks base method
-func (m *MockAPI) DescribeStackEvents(arg0 string) error {
+func (m *MockAPI) DescribeStackEvents(arg0 context.Context, arg1 string) error {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "DescribeStackEvents", arg0)
+	ret := m.ctrl.Call(m, "DescribeStackEvents", arg0, arg1)
 	ret0, _ := ret[0].(error)
 	ret0, _ := ret[0].(error)
 	return ret0
 	return ret0
 }
 }
 
 
 // DescribeStackEvents indicates an expected call of DescribeStackEvents
 // DescribeStackEvents indicates an expected call of DescribeStackEvents
-func (mr *MockAPIMockRecorder) DescribeStackEvents(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) DescribeStackEvents(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeStackEvents", reflect.TypeOf((*MockAPI)(nil).DescribeStackEvents), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeStackEvents", reflect.TypeOf((*MockAPI)(nil).DescribeStackEvents), arg0, arg1)
 }
 }
 
 
 // GetDefaultVPC mocks base method
 // GetDefaultVPC mocks base method
-func (m *MockAPI) GetDefaultVPC() (string, error) {
+func (m *MockAPI) GetDefaultVPC(arg0 context.Context) (string, error) {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "GetDefaultVPC")
+	ret := m.ctrl.Call(m, "GetDefaultVPC", arg0)
 	ret0, _ := ret[0].(string)
 	ret0, _ := ret[0].(string)
 	ret1, _ := ret[1].(error)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 	return ret0, ret1
 }
 }
 
 
 // GetDefaultVPC indicates an expected call of GetDefaultVPC
 // GetDefaultVPC indicates an expected call of GetDefaultVPC
-func (mr *MockAPIMockRecorder) GetDefaultVPC() *gomock.Call {
+func (mr *MockAPIMockRecorder) GetDefaultVPC(arg0 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultVPC", reflect.TypeOf((*MockAPI)(nil).GetDefaultVPC))
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultVPC", reflect.TypeOf((*MockAPI)(nil).GetDefaultVPC), arg0)
 }
 }
 
 
 // GetRoleArn mocks base method
 // GetRoleArn mocks base method
-func (m *MockAPI) GetRoleArn(arg0 string) (string, error) {
+func (m *MockAPI) GetRoleArn(arg0 context.Context, arg1 string) (string, error) {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "GetRoleArn", arg0)
+	ret := m.ctrl.Call(m, "GetRoleArn", arg0, arg1)
 	ret0, _ := ret[0].(string)
 	ret0, _ := ret[0].(string)
 	ret1, _ := ret[1].(error)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 	return ret0, ret1
 }
 }
 
 
 // GetRoleArn indicates an expected call of GetRoleArn
 // GetRoleArn indicates an expected call of GetRoleArn
-func (mr *MockAPIMockRecorder) GetRoleArn(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) GetRoleArn(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoleArn", reflect.TypeOf((*MockAPI)(nil).GetRoleArn), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoleArn", reflect.TypeOf((*MockAPI)(nil).GetRoleArn), arg0, arg1)
 }
 }
 
 
 // GetSubNets mocks base method
 // GetSubNets mocks base method
-func (m *MockAPI) GetSubNets(arg0 string) ([]string, error) {
+func (m *MockAPI) GetSubNets(arg0 context.Context, arg1 string) ([]string, error) {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "GetSubNets", arg0)
+	ret := m.ctrl.Call(m, "GetSubNets", arg0, arg1)
 	ret0, _ := ret[0].([]string)
 	ret0, _ := ret[0].([]string)
 	ret1, _ := ret[1].(error)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 	return ret0, ret1
 }
 }
 
 
 // GetSubNets indicates an expected call of GetSubNets
 // GetSubNets indicates an expected call of GetSubNets
-func (mr *MockAPIMockRecorder) GetSubNets(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) GetSubNets(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubNets", reflect.TypeOf((*MockAPI)(nil).GetSubNets), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubNets", reflect.TypeOf((*MockAPI)(nil).GetSubNets), arg0, arg1)
 }
 }
 
 
 // ListRolesForPolicy mocks base method
 // ListRolesForPolicy mocks base method
-func (m *MockAPI) ListRolesForPolicy(arg0 string) ([]string, error) {
+func (m *MockAPI) ListRolesForPolicy(arg0 context.Context, arg1 string) ([]string, error) {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "ListRolesForPolicy", arg0)
+	ret := m.ctrl.Call(m, "ListRolesForPolicy", arg0, arg1)
 	ret0, _ := ret[0].([]string)
 	ret0, _ := ret[0].([]string)
 	ret1, _ := ret[1].(error)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 	return ret0, ret1
 }
 }
 
 
 // ListRolesForPolicy indicates an expected call of ListRolesForPolicy
 // ListRolesForPolicy indicates an expected call of ListRolesForPolicy
-func (mr *MockAPIMockRecorder) ListRolesForPolicy(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) ListRolesForPolicy(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListRolesForPolicy", reflect.TypeOf((*MockAPI)(nil).ListRolesForPolicy), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListRolesForPolicy", reflect.TypeOf((*MockAPI)(nil).ListRolesForPolicy), arg0, arg1)
 }
 }
 
 
 // StackExists mocks base method
 // StackExists mocks base method
-func (m *MockAPI) StackExists(arg0 string) (bool, error) {
+func (m *MockAPI) StackExists(arg0 context.Context, arg1 string) (bool, error) {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "StackExists", arg0)
+	ret := m.ctrl.Call(m, "StackExists", arg0, arg1)
 	ret0, _ := ret[0].(bool)
 	ret0, _ := ret[0].(bool)
 	ret1, _ := ret[1].(error)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 	return ret0, ret1
 }
 }
 
 
 // StackExists indicates an expected call of StackExists
 // StackExists indicates an expected call of StackExists
-func (mr *MockAPIMockRecorder) StackExists(arg0 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) StackExists(arg0, arg1 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StackExists", reflect.TypeOf((*MockAPI)(nil).StackExists), arg0)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StackExists", reflect.TypeOf((*MockAPI)(nil).StackExists), arg0, arg1)
 }
 }

+ 23 - 22
ecs/pkg/amazon/sdk.go

@@ -1,6 +1,7 @@
 package amazon
 package amazon
 
 
 import (
 import (
+	"context"
 	"fmt"
 	"fmt"
 
 
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws"
@@ -42,9 +43,9 @@ func NewAPI(sess *session.Session) API {
 	}
 	}
 }
 }
 
 
-func (s sdk) ClusterExists(name string) (bool, error) {
+func (s sdk) ClusterExists(ctx context.Context, name string) (bool, error) {
 	logrus.Debug("Check if cluster was already created: ", name)
 	logrus.Debug("Check if cluster was already created: ", name)
-	clusters, err := s.ECS.DescribeClusters(&ecs.DescribeClustersInput{
+	clusters, err := s.ECS.DescribeClustersWithContext(aws.Context(ctx), &ecs.DescribeClustersInput{
 		Clusters: []*string{aws.String(name)},
 		Clusters: []*string{aws.String(name)},
 	})
 	})
 	if err != nil {
 	if err != nil {
@@ -53,18 +54,18 @@ func (s sdk) ClusterExists(name string) (bool, error) {
 	return len(clusters.Clusters) > 0, nil
 	return len(clusters.Clusters) > 0, nil
 }
 }
 
 
-func (s sdk) CreateCluster(name string) (string, error) {
+func (s sdk) CreateCluster(ctx context.Context, name string) (string, error) {
 	logrus.Debug("Create cluster ", name)
 	logrus.Debug("Create cluster ", name)
-	response, err := s.ECS.CreateCluster(&ecs.CreateClusterInput{ClusterName: aws.String(name)})
+	response, err := s.ECS.CreateClusterWithContext(aws.Context(ctx), &ecs.CreateClusterInput{ClusterName: aws.String(name)})
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
 	return *response.Cluster.Status, nil
 	return *response.Cluster.Status, nil
 }
 }
 
 
-func (s sdk) DeleteCluster(name string) error {
+func (s sdk) DeleteCluster(ctx context.Context, name string) error {
 	logrus.Debug("Delete cluster ", name)
 	logrus.Debug("Delete cluster ", name)
-	response, err := s.ECS.DeleteCluster(&ecs.DeleteClusterInput{Cluster: aws.String(name)})
+	response, err := s.ECS.DeleteClusterWithContext(aws.Context(ctx), &ecs.DeleteClusterInput{Cluster: aws.String(name)})
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -74,9 +75,9 @@ func (s sdk) DeleteCluster(name string) error {
 	return fmt.Errorf("Failed to delete cluster, status: %s" + *response.Cluster.Status)
 	return fmt.Errorf("Failed to delete cluster, status: %s" + *response.Cluster.Status)
 }
 }
 
 
-func (s sdk) GetDefaultVPC() (string, error) {
+func (s sdk) GetDefaultVPC(ctx context.Context) (string, error) {
 	logrus.Debug("Retrieve default VPC")
 	logrus.Debug("Retrieve default VPC")
-	vpcs, err := s.EC2.DescribeVpcs(&ec2.DescribeVpcsInput{
+	vpcs, err := s.EC2.DescribeVpcsWithContext(aws.Context(ctx), &ec2.DescribeVpcsInput{
 		Filters: []*ec2.Filter{
 		Filters: []*ec2.Filter{
 			{
 			{
 				Name:   aws.String("isDefault"),
 				Name:   aws.String("isDefault"),
@@ -93,9 +94,9 @@ func (s sdk) GetDefaultVPC() (string, error) {
 	return *vpcs.Vpcs[0].VpcId, nil
 	return *vpcs.Vpcs[0].VpcId, nil
 }
 }
 
 
-func (s sdk) GetSubNets(vpcID string) ([]string, error) {
+func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]string, error) {
 	logrus.Debug("Retrieve SubNets")
 	logrus.Debug("Retrieve SubNets")
-	subnets, err := s.EC2.DescribeSubnets(&ec2.DescribeSubnetsInput{
+	subnets, err := s.EC2.DescribeSubnetsWithContext(aws.Context(ctx), &ec2.DescribeSubnetsInput{
 		DryRun: nil,
 		DryRun: nil,
 		Filters: []*ec2.Filter{
 		Filters: []*ec2.Filter{
 			{
 			{
@@ -119,8 +120,8 @@ func (s sdk) GetSubNets(vpcID string) ([]string, error) {
 	return ids, nil
 	return ids, nil
 }
 }
 
 
-func (s sdk) ListRolesForPolicy(policy string) ([]string, error) {
-	entities, err := s.IAM.ListEntitiesForPolicy(&iam.ListEntitiesForPolicyInput{
+func (s sdk) ListRolesForPolicy(ctx context.Context, policy string) ([]string, error) {
+	entities, err := s.IAM.ListEntitiesForPolicyWithContext(aws.Context(ctx), &iam.ListEntitiesForPolicyInput{
 		EntityFilter: aws.String("Role"),
 		EntityFilter: aws.String("Role"),
 		PolicyArn:    aws.String(policy),
 		PolicyArn:    aws.String(policy),
 	})
 	})
@@ -134,8 +135,8 @@ func (s sdk) ListRolesForPolicy(policy string) ([]string, error) {
 	return roles, nil
 	return roles, nil
 }
 }
 
 
-func (s sdk) GetRoleArn(name string) (string, error) {
-	role, err := s.IAM.GetRole(&iam.GetRoleInput{
+func (s sdk) GetRoleArn(ctx context.Context, name string) (string, error) {
+	role, err := s.IAM.GetRoleWithContext(aws.Context(ctx), &iam.GetRoleInput{
 		RoleName: aws.String(name),
 		RoleName: aws.String(name),
 	})
 	})
 	if err != nil {
 	if err != nil {
@@ -144,8 +145,8 @@ func (s sdk) GetRoleArn(name string) (string, error) {
 	return *role.Role.Arn, nil
 	return *role.Role.Arn, nil
 }
 }
 
 
-func (s sdk) StackExists(name string) (bool, error) {
-	stacks, err := s.CF.DescribeStacks(&cloudformation.DescribeStacksInput{
+func (s sdk) StackExists(ctx context.Context, name string) (bool, error) {
+	stacks, err := s.CF.DescribeStacksWithContext(aws.Context(ctx), &cloudformation.DescribeStacksInput{
 		StackName: aws.String(name),
 		StackName: aws.String(name),
 	})
 	})
 	if err != nil {
 	if err != nil {
@@ -155,14 +156,14 @@ func (s sdk) StackExists(name string) (bool, error) {
 	return len(stacks.Stacks) > 0, nil
 	return len(stacks.Stacks) > 0, nil
 }
 }
 
 
-func (s sdk) CreateStack(name string, template *cf.Template) error {
+func (s sdk) CreateStack(ctx context.Context, name string, template *cf.Template) error {
 	logrus.Debug("Create CloudFormation stack")
 	logrus.Debug("Create CloudFormation stack")
 	json, err := template.JSON()
 	json, err := template.JSON()
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	_, err = s.CF.CreateStack(&cloudformation.CreateStackInput{
+	_, err = s.CF.CreateStackWithContext(aws.Context(ctx), &cloudformation.CreateStackInput{
 		OnFailure:        aws.String("DELETE"),
 		OnFailure:        aws.String("DELETE"),
 		StackName:        aws.String(name),
 		StackName:        aws.String(name),
 		TemplateBody:     aws.String(string(json)),
 		TemplateBody:     aws.String(string(json)),
@@ -171,17 +172,17 @@ func (s sdk) CreateStack(name string, template *cf.Template) error {
 	return err
 	return err
 }
 }
 
 
-func (s sdk) DescribeStackEvents(name string) error {
+func (s sdk) DescribeStackEvents(ctx context.Context, name string) error {
 	// Fixme implement Paginator on Events and return as a chan(events)
 	// Fixme implement Paginator on Events and return as a chan(events)
-	_, err := s.CF.DescribeStackEvents(&cloudformation.DescribeStackEventsInput{
+	_, err := s.CF.DescribeStackEventsWithContext(aws.Context(ctx), &cloudformation.DescribeStackEventsInput{
 		StackName: aws.String(name),
 		StackName: aws.String(name),
 	})
 	})
 	return err
 	return err
 }
 }
 
 
-func (s sdk) DeleteStack(name string) error {
+func (s sdk) DeleteStack(ctx context.Context, name string) error {
 	logrus.Debug("Delete CloudFormation stack")
 	logrus.Debug("Delete CloudFormation stack")
-	_, err := s.CF.DeleteStack(&cloudformation.DeleteStackInput{
+	_, err := s.CF.DeleteStackWithContext(aws.Context(ctx), &cloudformation.DeleteStackInput{
 		StackName: aws.String(name),
 		StackName: aws.String(name),
 	})
 	})
 	return err
 	return err

+ 13 - 12
ecs/pkg/amazon/up.go

@@ -1,21 +1,22 @@
 package amazon
 package amazon
 
 
 import (
 import (
+	"context"
 	"fmt"
 	"fmt"
 
 
 	"github.com/awslabs/goformation/v4/cloudformation"
 	"github.com/awslabs/goformation/v4/cloudformation"
 	"github.com/docker/ecs-plugin/pkg/compose"
 	"github.com/docker/ecs-plugin/pkg/compose"
 )
 )
 
 
-func (c *client) ComposeUp(project *compose.Project) error {
-	ok, err := c.api.ClusterExists(c.Cluster)
+func (c *client) ComposeUp(ctx context.Context, project *compose.Project) error {
+	ok, err := c.api.ClusterExists(ctx, c.Cluster)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 	if !ok {
 	if !ok {
-		c.api.CreateCluster(c.Cluster)
+		c.api.CreateCluster(ctx, c.Cluster)
 	}
 	}
-	update, err := c.api.StackExists(project.Name)
+	update, err := c.api.StackExists(ctx, project.Name)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -23,17 +24,17 @@ func (c *client) ComposeUp(project *compose.Project) error {
 		return fmt.Errorf("we do not (yet) support updating an existing CloudFormation stack")
 		return fmt.Errorf("we do not (yet) support updating an existing CloudFormation stack")
 	}
 	}
 
 
-	template, err := c.Convert(project)
+	template, err := c.Convert(ctx, project)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	err = c.api.CreateStack(project.Name, template)
+	err = c.api.CreateStack(ctx, project.Name, template)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	err = c.api.DescribeStackEvents(project.Name)
+	err = c.api.DescribeStackEvents(ctx, project.Name)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -43,9 +44,9 @@ func (c *client) ComposeUp(project *compose.Project) error {
 }
 }
 
 
 type upAPI interface {
 type upAPI interface {
-	ClusterExists(name string) (bool, error)
-	CreateCluster(name string) (string, error)
-	StackExists(name string) (bool, error)
-	CreateStack(name string, template *cloudformation.Template) error
-	DescribeStackEvents(stack string) error
+	ClusterExists(ctx context.Context, name string) (bool, error)
+	CreateCluster(ctx context.Context, name string) (string, error)
+	StackExists(ctx context.Context, name string) (bool, error)
+	CreateStack(ctx context.Context, name string, template *cloudformation.Template) error
+	DescribeStackEvents(ctx context.Context, stack string) error
 }
 }

+ 8 - 4
ecs/pkg/compose/api.go

@@ -1,9 +1,13 @@
 package compose
 package compose
 
 
-import "github.com/awslabs/goformation/v4/cloudformation"
+import (
+	"context"
+
+	"github.com/awslabs/goformation/v4/cloudformation"
+)
 
 
 type API interface {
 type API interface {
-	Convert(project *Project) (*cloudformation.Template, error)
-	ComposeUp(project *Project) error
-	ComposeDown(projectName string, deleteCluster bool) error
+	Convert(ctx context.Context, project *Project) (*cloudformation.Template, error)
+	ComposeUp(ctx context.Context, project *Project) error
+	ComposeDown(ctx context.Context, projectName string, deleteCluster bool) error
 }
 }