Browse Source

Pass region to create s3 bucket into

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof 5 years ago
parent
commit
10a384d35b
4 changed files with 56 additions and 51 deletions
  1. 2 2
      ecs/aws.go
  2. 8 8
      ecs/aws_mock.go
  3. 44 39
      ecs/sdk.go
  4. 2 2
      ecs/up.go

+ 2 - 2
ecs/aws.go

@@ -42,8 +42,8 @@ type API interface {
 	GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error)
 	GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error)
 	GetRoleArn(ctx context.Context, name string) (string, error)
 	GetRoleArn(ctx context.Context, name string) (string, error)
 	StackExists(ctx context.Context, name string) (bool, error)
 	StackExists(ctx context.Context, name string) (bool, error)
-	CreateStack(ctx context.Context, name string, template []byte) error
-	CreateChangeSet(ctx context.Context, name string, template []byte) (string, error)
+	CreateStack(ctx context.Context, name string, region string, template []byte) error
+	CreateChangeSet(ctx context.Context, name string, region string, template []byte) (string, error)
 	UpdateStack(ctx context.Context, changeset string) error
 	UpdateStack(ctx context.Context, changeset string) error
 	WaitStackComplete(ctx context.Context, name string, operation int) error
 	WaitStackComplete(ctx context.Context, name string, operation int) error
 	GetStackID(ctx context.Context, name string) (string, error)
 	GetStackID(ctx context.Context, name string) (string, error)

+ 8 - 8
ecs/aws_mock.go

@@ -66,18 +66,18 @@ func (mr *MockAPIMockRecorder) CheckVPC(arg0, arg1 interface{}) *gomock.Call {
 }
 }
 
 
 // CreateChangeSet mocks base method
 // CreateChangeSet mocks base method
-func (m *MockAPI) CreateChangeSet(arg0 context.Context, arg1 string, arg2 []byte) (string, error) {
+func (m *MockAPI) CreateChangeSet(arg0 context.Context, arg1, arg2 string, arg3 []byte) (string, error) {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "CreateChangeSet", arg0, arg1, arg2)
+	ret := m.ctrl.Call(m, "CreateChangeSet", arg0, arg1, arg2, arg3)
 	ret0, _ := ret[0].(string)
 	ret0, _ := ret[0].(string)
 	ret1, _ := ret[1].(error)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 	return ret0, ret1
 }
 }
 
 
 // CreateChangeSet indicates an expected call of CreateChangeSet
 // CreateChangeSet indicates an expected call of CreateChangeSet
-func (mr *MockAPIMockRecorder) CreateChangeSet(arg0, arg1, arg2 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) CreateChangeSet(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateChangeSet", reflect.TypeOf((*MockAPI)(nil).CreateChangeSet), arg0, arg1, arg2)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateChangeSet", reflect.TypeOf((*MockAPI)(nil).CreateChangeSet), arg0, arg1, arg2, arg3)
 }
 }
 
 
 // CreateCluster mocks base method
 // CreateCluster mocks base method
@@ -126,17 +126,17 @@ func (mr *MockAPIMockRecorder) CreateSecret(arg0, arg1 interface{}) *gomock.Call
 }
 }
 
 
 // CreateStack mocks base method
 // CreateStack mocks base method
-func (m *MockAPI) CreateStack(arg0 context.Context, arg1 string, arg2 []byte) error {
+func (m *MockAPI) CreateStack(arg0 context.Context, arg1, arg2 string, arg3 []byte) error {
 	m.ctrl.T.Helper()
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "CreateStack", arg0, arg1, arg2)
+	ret := m.ctrl.Call(m, "CreateStack", arg0, arg1, arg2, arg3)
 	ret0, _ := ret[0].(error)
 	ret0, _ := ret[0].(error)
 	return ret0
 	return ret0
 }
 }
 
 
 // CreateStack indicates an expected call of CreateStack
 // CreateStack indicates an expected call of CreateStack
-func (mr *MockAPIMockRecorder) CreateStack(arg0, arg1, arg2 interface{}) *gomock.Call {
+func (mr *MockAPIMockRecorder) CreateStack(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateStack", reflect.TypeOf((*MockAPI)(nil).CreateStack), arg0, arg1, arg2)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateStack", reflect.TypeOf((*MockAPI)(nil).CreateStack), arg0, arg1, arg2, arg3)
 }
 }
 
 
 // DeleteAutoscalingGroup mocks base method
 // DeleteAutoscalingGroup mocks base method

+ 44 - 39
ecs/sdk.go

@@ -21,11 +21,6 @@ import (
 	"context"
 	"context"
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
-	"github.com/aws/aws-sdk-go/aws/awserr"
-	"github.com/aws/aws-sdk-go/service/s3"
-	"github.com/aws/aws-sdk-go/service/s3/s3iface"
-	"github.com/aws/aws-sdk-go/service/s3/s3manager"
-	"github.com/hashicorp/go-uuid"
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
@@ -36,6 +31,7 @@ import (
 
 
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/arn"
 	"github.com/aws/aws-sdk-go/aws/arn"
+	"github.com/aws/aws-sdk-go/aws/awserr"
 	"github.com/aws/aws-sdk-go/aws/request"
 	"github.com/aws/aws-sdk-go/aws/request"
 	"github.com/aws/aws-sdk-go/aws/session"
 	"github.com/aws/aws-sdk-go/aws/session"
 	"github.com/aws/aws-sdk-go/service/autoscaling"
 	"github.com/aws/aws-sdk-go/service/autoscaling"
@@ -54,27 +50,31 @@ import (
 	"github.com/aws/aws-sdk-go/service/elbv2/elbv2iface"
 	"github.com/aws/aws-sdk-go/service/elbv2/elbv2iface"
 	"github.com/aws/aws-sdk-go/service/iam"
 	"github.com/aws/aws-sdk-go/service/iam"
 	"github.com/aws/aws-sdk-go/service/iam/iamiface"
 	"github.com/aws/aws-sdk-go/service/iam/iamiface"
+	"github.com/aws/aws-sdk-go/service/s3"
+	"github.com/aws/aws-sdk-go/service/s3/s3iface"
+	"github.com/aws/aws-sdk-go/service/s3/s3manager"
 	"github.com/aws/aws-sdk-go/service/secretsmanager"
 	"github.com/aws/aws-sdk-go/service/secretsmanager"
 	"github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface"
 	"github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface"
 	"github.com/aws/aws-sdk-go/service/ssm"
 	"github.com/aws/aws-sdk-go/service/ssm"
 	"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
 	"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
 	"github.com/hashicorp/go-multierror"
 	"github.com/hashicorp/go-multierror"
+	"github.com/hashicorp/go-uuid"
 	"github.com/pkg/errors"
 	"github.com/pkg/errors"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 )
 )
 
 
 type sdk struct {
 type sdk struct {
-	ECS ecsiface.ECSAPI
-	EC2 ec2iface.EC2API
-	EFS efsiface.EFSAPI
-	ELB elbv2iface.ELBV2API
-	CW  cloudwatchlogsiface.CloudWatchLogsAPI
-	IAM iamiface.IAMAPI
-	CF  cloudformationiface.CloudFormationAPI
-	SM  secretsmanageriface.SecretsManagerAPI
-	SSM ssmiface.SSMAPI
-	AG  autoscalingiface.AutoScalingAPI
-	S3 s3iface.S3API
+	ECS      ecsiface.ECSAPI
+	EC2      ec2iface.EC2API
+	EFS      efsiface.EFSAPI
+	ELB      elbv2iface.ELBV2API
+	CW       cloudwatchlogsiface.CloudWatchLogsAPI
+	IAM      iamiface.IAMAPI
+	CF       cloudformationiface.CloudFormationAPI
+	SM       secretsmanageriface.SecretsManagerAPI
+	SSM      ssmiface.SSMAPI
+	AG       autoscalingiface.AutoScalingAPI
+	S3       s3iface.S3API
 	uploader *s3manager.Uploader
 	uploader *s3manager.Uploader
 }
 }
 
 
@@ -86,17 +86,17 @@ func newSDK(sess *session.Session) sdk {
 		request.AddToUserAgent(r, internal.ECSUserAgentName+"/"+internal.Version)
 		request.AddToUserAgent(r, internal.ECSUserAgentName+"/"+internal.Version)
 	})
 	})
 	return sdk{
 	return sdk{
-		ECS: ecs.New(sess),
-		EC2: ec2.New(sess),
-		EFS: efs.New(sess),
-		ELB: elbv2.New(sess),
-		CW:  cloudwatchlogs.New(sess),
-		IAM: iam.New(sess),
-		CF:  cloudformation.New(sess),
-		SM:  secretsmanager.New(sess),
-		SSM: ssm.New(sess),
-		AG:  autoscaling.New(sess),
-		S3:  s3.New(sess),
+		ECS:      ecs.New(sess),
+		EC2:      ec2.New(sess),
+		EFS:      efs.New(sess),
+		ELB:      elbv2.New(sess),
+		CW:       cloudwatchlogs.New(sess),
+		IAM:      iam.New(sess),
+		CF:       cloudformation.New(sess),
+		SM:       secretsmanager.New(sess),
+		SSM:      ssm.New(sess),
+		AG:       autoscaling.New(sess),
+		S3:       s3.New(sess),
 		uploader: s3manager.NewUploader(sess),
 		uploader: s3manager.NewUploader(sess),
 	}
 	}
 }
 }
@@ -197,11 +197,9 @@ func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error
 			return nil, err
 			return nil, err
 		}
 		}
 		for _, subnet := range subnets.Subnets {
 		for _, subnet := range subnets.Subnets {
-			id := aws.StringValue(subnet.SubnetId)
-			logrus.Debugf("Found SubNet %s", id)
 			ids = append(ids, existingAWSResource{
 			ids = append(ids, existingAWSResource{
 				arn: aws.StringValue(subnet.SubnetArn),
 				arn: aws.StringValue(subnet.SubnetArn),
-				id:  id,
+				id:  aws.StringValue(subnet.SubnetId),
 			})
 			})
 		}
 		}
 
 
@@ -238,10 +236,17 @@ func (s sdk) StackExists(ctx context.Context, name string) (bool, error) {
 
 
 type uploadedTemplateFunc func(ctx context.Context, name string, url string) (string, error)
 type uploadedTemplateFunc func(ctx context.Context, name string, url string) (string, error)
 
 
-func (s sdk) withTemplate(ctx context.Context, name string, template []byte, fn uploadedTemplateFunc) (string, error) {
+func (s sdk) withTemplate(ctx context.Context, name string, template []byte, region string, fn uploadedTemplateFunc) (string, error) {
 	logrus.Debug("Create s3 bucket to store cloudformation template")
 	logrus.Debug("Create s3 bucket to store cloudformation template")
+	var configuration *s3.CreateBucketConfiguration
+	if region != "us-east-1" {
+		configuration = &s3.CreateBucketConfiguration{
+			LocationConstraint: aws.String(region),
+		}
+	}
 	_, err := s.S3.CreateBucket(&s3.CreateBucketInput{
 	_, err := s.S3.CreateBucket(&s3.CreateBucketInput{
-		Bucket: aws.String("com.docker.compose"),
+		Bucket:                    aws.String("com.docker.compose." + region),
+		CreateBucketConfiguration: configuration,
 	})
 	})
 	if err != nil {
 	if err != nil {
 		ae, ok := err.(awserr.Error)
 		ae, ok := err.(awserr.Error)
@@ -261,7 +266,7 @@ func (s sdk) withTemplate(ctx context.Context, name string, template []byte, fn
 	upload, err := s.uploader.UploadWithContext(ctx, &s3manager.UploadInput{
 	upload, err := s.uploader.UploadWithContext(ctx, &s3manager.UploadInput{
 		Key:         aws.String(key),
 		Key:         aws.String(key),
 		Body:        bytes.NewReader(template),
 		Body:        bytes.NewReader(template),
-		Bucket:      aws.String("com.docker.compose"),
+		Bucket:      aws.String("com.docker.compose." + region),
 		ContentType: aws.String("application/json"),
 		ContentType: aws.String("application/json"),
 		Tagging:     aws.String(name),
 		Tagging:     aws.String(name),
 	})
 	})
@@ -270,7 +275,7 @@ func (s sdk) withTemplate(ctx context.Context, name string, template []byte, fn
 		return "", err
 		return "", err
 	}
 	}
 
 
-	defer s.S3.DeleteObjects(&s3.DeleteObjectsInput{
+	defer s.S3.DeleteObjects(&s3.DeleteObjectsInput{ //nolint: errcheck
 		Bucket: aws.String("com.docker.compose"),
 		Bucket: aws.String("com.docker.compose"),
 		Delete: &s3.Delete{
 		Delete: &s3.Delete{
 			Objects: []*s3.ObjectIdentifier{
 			Objects: []*s3.ObjectIdentifier{
@@ -285,10 +290,10 @@ func (s sdk) withTemplate(ctx context.Context, name string, template []byte, fn
 	return fn(ctx, name, upload.Location)
 	return fn(ctx, name, upload.Location)
 }
 }
 
 
-func (s sdk) CreateStack(ctx context.Context, name string, template []byte) error {
+func (s sdk) CreateStack(ctx context.Context, name string, region string, template []byte) error {
 	logrus.Debug("Create CloudFormation stack")
 	logrus.Debug("Create CloudFormation stack")
 
 
-	stackId, err := s.withTemplate(ctx, name, template, func(ctx context.Context, name string, url string) (string, error) {
+	stackID, err := s.withTemplate(ctx, name, template, region, func(ctx context.Context, name string, url string) (string, error) {
 		stack, err := s.CF.CreateStackWithContext(ctx, &cloudformation.CreateStackInput{
 		stack, err := s.CF.CreateStackWithContext(ctx, &cloudformation.CreateStackInput{
 			OnFailure:        aws.String("DELETE"),
 			OnFailure:        aws.String("DELETE"),
 			StackName:        aws.String(name),
 			StackName:        aws.String(name),
@@ -309,14 +314,14 @@ func (s sdk) CreateStack(ctx context.Context, name string, template []byte) erro
 		}
 		}
 		return aws.StringValue(stack.StackId), nil
 		return aws.StringValue(stack.StackId), nil
 	})
 	})
-	logrus.Debugf("Stack %s created", stackId)
+	logrus.Debugf("Stack %s created", stackID)
 	return err
 	return err
 }
 }
 
 
-func (s sdk) CreateChangeSet(ctx context.Context, name string, template []byte) (string, error) {
+func (s sdk) CreateChangeSet(ctx context.Context, name string, region string, template []byte) (string, error) {
 	logrus.Debug("Create CloudFormation Changeset")
 	logrus.Debug("Create CloudFormation Changeset")
 
 
-	changeset, err := s.withTemplate(ctx, name, template, func(ctx context.Context, name string, url string) (string, error) {
+	changeset, err := s.withTemplate(ctx, name, template, region, func(ctx context.Context, name string, url string) (string, error) {
 		update := fmt.Sprintf("Update%s", time.Now().Format("2006-01-02-15-04-05"))
 		update := fmt.Sprintf("Update%s", time.Now().Format("2006-01-02-15-04-05"))
 		changeset, err := s.CF.CreateChangeSetWithContext(ctx, &cloudformation.CreateChangeSetInput{
 		changeset, err := s.CF.CreateChangeSetWithContext(ctx, &cloudformation.CreateChangeSetInput{
 			ChangeSetName: aws.String(update),
 			ChangeSetName: aws.String(update),

+ 2 - 2
ecs/up.go

@@ -44,7 +44,7 @@ func (b *ecsAPIService) Up(ctx context.Context, project *types.Project, detach b
 	operation := stackCreate
 	operation := stackCreate
 	if update {
 	if update {
 		operation = stackUpdate
 		operation = stackUpdate
-		changeset, err := b.aws.CreateChangeSet(ctx, project.Name, template)
+		changeset, err := b.aws.CreateChangeSet(ctx, project.Name, b.Region, template)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -53,7 +53,7 @@ func (b *ecsAPIService) Up(ctx context.Context, project *types.Project, detach b
 			return err
 			return err
 		}
 		}
 	} else {
 	} else {
-		err = b.aws.CreateStack(ctx, project.Name, template)
+		err = b.aws.CreateStack(ctx, project.Name, b.Region, template)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}