Browse Source

Get more from DescribeTask

Signed-off-by: Nicolas De Loof <[email protected]>
Nicolas De Loof 5 years ago
parent
commit
be1c65d441

+ 2 - 2
ecs/pkg/amazon/api.go

@@ -1,11 +1,11 @@
 package amazon
 
-//go:generate mockgen -destination=./mock/api.go -package=mock . API
+//go:generate mockgen -destination=./api_mock.go -self_package "github.com/docker/ecs-plugin/pkg/amazon" -package=amazon . API
 
 type API interface {
 	downAPI
 	upAPI
 	logsAPI
 	secretsAPI
-	psAPI
+	listAPI
 }

+ 37 - 37
ecs/pkg/amazon/api_mock.go

@@ -137,6 +137,26 @@ func (mr *MockAPIMockRecorder) DescribeStackEvents(arg0, arg1 interface{}) *gomo
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeStackEvents", reflect.TypeOf((*MockAPI)(nil).DescribeStackEvents), arg0, arg1)
 }
 
+// DescribeTasks mocks base method
+func (m *MockAPI) DescribeTasks(arg0 context.Context, arg1 string, arg2 ...string) ([]TaskStatus, error) {
+	m.ctrl.T.Helper()
+	varargs := []interface{}{arg0, arg1}
+	for _, a := range arg2 {
+		varargs = append(varargs, a)
+	}
+	ret := m.ctrl.Call(m, "DescribeTasks", varargs...)
+	ret0, _ := ret[0].([]TaskStatus)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// DescribeTasks indicates an expected call of DescribeTasks
+func (mr *MockAPIMockRecorder) DescribeTasks(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	varargs := append([]interface{}{arg0, arg1}, arg2...)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeTasks", reflect.TypeOf((*MockAPI)(nil).DescribeTasks), varargs...)
+}
+
 // GetDefaultVPC mocks base method
 func (m *MockAPI) GetDefaultVPC(arg0 context.Context) (string, error) {
 	m.ctrl.T.Helper()
@@ -166,35 +186,15 @@ func (mr *MockAPIMockRecorder) GetLogs(arg0, arg1, arg2 interface{}) *gomock.Cal
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogs", reflect.TypeOf((*MockAPI)(nil).GetLogs), arg0, arg1, arg2)
 }
 
-// GetNetworkInterfaces mocks base method
-func (m *MockAPI) GetNetworkInterfaces(arg0 context.Context, arg1 string, arg2 ...string) ([]string, error) {
-	m.ctrl.T.Helper()
-	varargs := []interface{}{arg0, arg1}
-	for _, a := range arg2 {
-		varargs = append(varargs, a)
-	}
-	ret := m.ctrl.Call(m, "GetNetworkInterfaces", varargs...)
-	ret0, _ := ret[0].([]string)
-	ret1, _ := ret[1].(error)
-	return ret0, ret1
-}
-
-// GetNetworkInterfaces indicates an expected call of GetNetworkInterfaces
-func (mr *MockAPIMockRecorder) GetNetworkInterfaces(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
-	mr.mock.ctrl.T.Helper()
-	varargs := append([]interface{}{arg0, arg1}, arg2...)
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNetworkInterfaces", reflect.TypeOf((*MockAPI)(nil).GetNetworkInterfaces), varargs...)
-}
-
 // GetPublicIPs mocks base method
-func (m *MockAPI) GetPublicIPs(arg0 context.Context, arg1 ...string) ([]string, error) {
+func (m *MockAPI) GetPublicIPs(arg0 context.Context, arg1 ...string) (map[string]string, error) {
 	m.ctrl.T.Helper()
 	varargs := []interface{}{arg0}
 	for _, a := range arg1 {
 		varargs = append(varargs, a)
 	}
 	ret := m.ctrl.Call(m, "GetPublicIPs", varargs...)
-	ret0, _ := ret[0].([]string)
+	ret0, _ := ret[0].(map[string]string)
 	ret1, _ := ret[1].(error)
 	return ret0, ret1
 }
@@ -236,21 +236,6 @@ func (mr *MockAPIMockRecorder) GetSubNets(arg0, arg1 interface{}) *gomock.Call {
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubNets", reflect.TypeOf((*MockAPI)(nil).GetSubNets), arg0, arg1)
 }
 
-// GetTasks mocks base method
-func (m *MockAPI) GetTasks(arg0 context.Context, arg1, arg2 string) ([]string, error) {
-	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "GetTasks", arg0, arg1, arg2)
-	ret0, _ := ret[0].([]string)
-	ret1, _ := ret[1].(error)
-	return ret0, ret1
-}
-
-// GetTasks indicates an expected call of GetTasks
-func (mr *MockAPIMockRecorder) GetTasks(arg0, arg1, arg2 interface{}) *gomock.Call {
-	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTasks", reflect.TypeOf((*MockAPI)(nil).GetTasks), arg0, arg1, arg2)
-}
-
 // InspectSecret mocks base method
 func (m *MockAPI) InspectSecret(arg0 context.Context, arg1 string) (docker.Secret, error) {
 	m.ctrl.T.Helper()
@@ -281,6 +266,21 @@ func (mr *MockAPIMockRecorder) ListSecrets(arg0 interface{}) *gomock.Call {
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListSecrets", reflect.TypeOf((*MockAPI)(nil).ListSecrets), arg0)
 }
 
+// ListTasks mocks base method
+func (m *MockAPI) ListTasks(arg0 context.Context, arg1, arg2 string) ([]string, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "ListTasks", arg0, arg1, arg2)
+	ret0, _ := ret[0].([]string)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// ListTasks indicates an expected call of ListTasks
+func (mr *MockAPIMockRecorder) ListTasks(arg0, arg1, arg2 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTasks", reflect.TypeOf((*MockAPI)(nil).ListTasks), arg0, arg1, arg2)
+}
+
 // StackExists mocks base method
 func (m *MockAPI) StackExists(arg0 context.Context, arg1 string) (bool, error) {
 	m.ctrl.T.Helper()

+ 1 - 1
ecs/pkg/amazon/convert.go

@@ -77,7 +77,7 @@ func Convert(project *compose.Project, service types.ServiceConfig) (*ecs.TaskDe
 			},
 		},
 		Cpu:                     cpu,
-		Family:                  fmt.Sprintf("%s-%s", project.Name, service.Name),
+		Family:                  project.Name,
 		IpcMode:                 service.Ipc,
 		Memory:                  mem,
 		NetworkMode:             ecsapi.NetworkModeAwsvpc, // FIXME could be set by service.NetworkMode, Fargate only supports network mode ‘awsvpc’.

+ 36 - 24
ecs/pkg/amazon/list.go

@@ -17,39 +17,51 @@ func (c *client) ComposePs(ctx context.Context, project *compose.Project) error
 	}
 	w := tabwriter.NewWriter(os.Stdout, 20, 2, 3, ' ', 0)
 	fmt.Fprintf(w, "Name\tState\tPorts\n")
-	for _, s := range project.Services {
-		tasks, err := c.api.GetTasks(ctx, cluster, s.Name)
-		if err != nil {
-			return err
-		}
-		if len(tasks) == 0 {
-			continue
-		}
-		// TODO get more data from DescribeTask, including tasks status
-		networkInterfaces, err := c.api.GetNetworkInterfaces(ctx, cluster, tasks...)
-		if err != nil {
-			return err
-		}
-		if len(networkInterfaces) == 0 {
-			fmt.Fprintf(w, "%s\t%s\t\n", s.Name, "Provisioning")
-			continue
+	arns, err := c.api.ListTasks(ctx, cluster, project.Name)
+	if err != nil {
+		return err
+	}
+
+	tasks, err := c.api.DescribeTasks(ctx, cluster, arns...)
+	if err != nil {
+		return err
+	}
+
+	networkInterfaces := []string{}
+	for _, t := range tasks {
+		if t.NetworkInterface != "" {
+			networkInterfaces = append(networkInterfaces, t.NetworkInterface)
 		}
-		publicIps, err := c.api.GetPublicIPs(ctx, networkInterfaces...)
+	}
+	publicIps, err := c.api.GetPublicIPs(ctx, networkInterfaces...)
+	if err != nil {
+		return err
+	}
+
+	for _, t := range tasks {
+		ports := []string{}
+		s, err := project.GetService(t.Service)
 		if err != nil {
 			return err
 		}
-		ports := []string{}
 		for _, p := range s.Ports {
-			ports = append(ports, fmt.Sprintf("%s:%d->%d/%s", strings.Join(publicIps, ","), p.Published, p.Target, p.Protocol))
+			ports = append(ports, fmt.Sprintf("%s:%d->%d/%s", publicIps[t.NetworkInterface], p.Published, p.Target, p.Protocol))
 		}
-		fmt.Fprintf(w, "%s\t%s\t%s\n", s.Name, "Up", strings.Join(ports, ", "))
+		fmt.Fprintf(w, "%s\t%s\t%s\n", s.Name, t.State, strings.Join(ports, ", "))
 	}
 	w.Flush()
 	return nil
 }
 
-type psAPI interface {
-	GetTasks(ctx context.Context, cluster string, name string) ([]string, error)
-	GetNetworkInterfaces(ctx context.Context, cluster string, arns ...string) ([]string, error)
-	GetPublicIPs(ctx context.Context, interfaces ...string) ([]string, error)
+type TaskStatus struct {
+	State            string
+	Service          string
+	NetworkInterface string
+	PublicIP         string
+}
+
+type listAPI interface {
+	ListTasks(ctx context.Context, cluster string, name string) ([]string, error)
+	DescribeTasks(ctx context.Context, cluster string, arns ...string) ([]TaskStatus, error)
+	GetPublicIPs(ctx context.Context, interfaces ...string) (map[string]string, error)
 }

+ 16 - 10
ecs/pkg/amazon/sdk.go

@@ -341,10 +341,10 @@ func (s sdk) GetLogs(ctx context.Context, name string, consumer LogConsumer) err
 	}
 }
 
-func (s sdk) GetTasks(ctx context.Context, cluster string, name string) ([]string, error) {
+func (s sdk) ListTasks(ctx context.Context, cluster string, name string) ([]string, error) {
 	tasks, err := s.ECS.ListTasksWithContext(ctx, &ecs.ListTasksInput{
-		Cluster:     aws.String(cluster),
-		ServiceName: aws.String(name),
+		Cluster: aws.String(cluster),
+		Family:  aws.String(name),
 	})
 	if err != nil {
 		return nil, err
@@ -356,7 +356,7 @@ func (s sdk) GetTasks(ctx context.Context, cluster string, name string) ([]strin
 	return arns, nil
 }
 
-func (s sdk) GetNetworkInterfaces(ctx context.Context, cluster string, arns ...string) ([]string, error) {
+func (s sdk) DescribeTasks(ctx context.Context, cluster string, arns ...string) ([]TaskStatus, error) {
 	tasks, err := s.ECS.DescribeTasksWithContext(ctx, &ecs.DescribeTasksInput{
 		Cluster: aws.String(cluster),
 		Tasks:   aws.StringSlice(arns),
@@ -364,32 +364,38 @@ func (s sdk) GetNetworkInterfaces(ctx context.Context, cluster string, arns ...s
 	if err != nil {
 		return nil, err
 	}
-	interfaces := []string{}
+	result := []TaskStatus{}
 	for _, task := range tasks.Tasks {
+		var networkInterface string
 		for _, attachement := range task.Attachments {
 			if *attachement.Type == "ElasticNetworkInterface" {
 				for _, pair := range attachement.Details {
 					if *pair.Name == "networkInterfaceId" {
-						interfaces = append(interfaces, *pair.Value)
+						networkInterface = *pair.Value
 					}
 				}
 			}
 		}
+		result = append(result, TaskStatus{
+			State:            *task.LastStatus,
+			Service:          strings.Replace(*task.Group, "service:", "", 1),
+			NetworkInterface: networkInterface,
+		})
 	}
-	return interfaces, nil
+	return result, nil
 }
 
-func (s sdk) GetPublicIPs(ctx context.Context, interfaces ...string) ([]string, error) {
+func (s sdk) GetPublicIPs(ctx context.Context, interfaces ...string) (map[string]string, error) {
 	desc, err := s.EC2.DescribeNetworkInterfaces(&ec2.DescribeNetworkInterfacesInput{
 		NetworkInterfaceIds: aws.StringSlice(interfaces),
 	})
 	if err != nil {
 		return nil, err
 	}
-	publicIPs := []string{}
+	publicIPs := map[string]string{}
 	for _, interf := range desc.NetworkInterfaces {
 		if interf.Association != nil {
-			publicIPs = append(publicIPs, *interf.Association.PublicIp)
+			publicIPs[*interf.NetworkInterfaceId] = *interf.Association.PublicIp
 		}
 	}
 	return publicIPs, nil

+ 1 - 1
ecs/pkg/amazon/testdata/simple/simple-cloudformation-conversion.golden

@@ -181,7 +181,7 @@
         "ExecutionRoleArn": {
           "Ref": "SimpleTaskExecutionRole"
         },
-        "Family": "TestSimpleConvert-simple",
+        "Family": "TestSimpleConvert",
         "Memory": "512",
         "NetworkMode": "awsvpc",
         "RequiresCompatibilities": [

+ 1 - 1
ecs/pkg/amazon/testdata/simple/simple-cloudformation-with-overrides-conversion.golden

@@ -181,7 +181,7 @@
         "ExecutionRoleArn": {
           "Ref": "SimpleTaskExecutionRole"
         },
-        "Family": "TestSimpleWithOverrides-simple",
+        "Family": "TestSimpleWithOverrides",
         "Memory": "512",
         "NetworkMode": "awsvpc",
         "RequiresCompatibilities": [

+ 3 - 3
ecs/tests/e2e_deploy_services_test.go

@@ -48,10 +48,10 @@ func composeUpSimpleService(t *testing.T, cmd icmd.Cmd, awsContext docker.AwsCon
 	})
 	assert.NilError(t, err)
 	sdk := amazon.NewAPI(session)
-	arns, err := sdk.GetTasks(bgContext, t.Name(), "simple")
+	arns, err := sdk.ListTasks(bgContext, t.Name(), t.Name())
 	assert.NilError(t, err)
-	networkInterfaces, err := sdk.GetNetworkInterfaces(bgContext, t.Name(), arns...)
-	publicIps, err := sdk.GetPublicIPs(context.Background(), networkInterfaces...)
+	tasks, err := sdk.DescribeTasks(bgContext, t.Name(), arns...)
+	publicIps, err := sdk.GetPublicIPs(context.Background(), tasks[0].NetworkInterface)
 	assert.NilError(t, err)
 	for _, ip := range publicIps {
 		icmd.RunCommand("curl", "-I", "http://"+ip).Assert(t, icmd.Success)