Browse Source

Merge pull request #6100 from docker/5960-parallel-pull-progress

Add progress messages to parallel pull
Joffrey F 7 years ago
parent
commit
db391c03ad

+ 7 - 0
compose/parallel.py

@@ -313,6 +313,13 @@ class ParallelStreamWriter(object):
             self._write_ansi(msg, obj_index, color_func(status))
 
 
+def get_stream_writer():
+    instance = ParallelStreamWriter.instance
+    if instance is None:
+        raise RuntimeError('ParallelStreamWriter has not yet been instantiated')
+    return instance
+
+
 def parallel_operation(containers, operation, options, message):
     parallel_execute(
         containers,

+ 1 - 4
compose/progress_stream.py

@@ -19,12 +19,11 @@ def write_to_stream(s, stream):
 def stream_output(output, stream):
     is_terminal = hasattr(stream, 'isatty') and stream.isatty()
     stream = utils.get_output_stream(stream)
-    all_events = []
     lines = {}
     diff = 0
 
     for event in utils.json_stream(output):
-        all_events.append(event)
+        yield event
         is_progress_event = 'progress' in event or 'progressDetail' in event
 
         if not is_progress_event:
@@ -57,8 +56,6 @@ def stream_output(output, stream):
 
         stream.flush()
 
-    return all_events
-
 
 def print_output_event(event, stream, is_terminal):
     if 'errorDetail' in event:

+ 23 - 2
compose/project.py

@@ -571,16 +571,37 @@ class Project(object):
     def pull(self, service_names=None, ignore_pull_failures=False, parallel_pull=False, silent=False,
              include_deps=False):
         services = self.get_services(service_names, include_deps)
+        msg = not silent and 'Pulling' or None
 
         if parallel_pull:
             def pull_service(service):
-                service.pull(ignore_pull_failures, True)
+                strm = service.pull(ignore_pull_failures, True, stream=True)
+                writer = parallel.get_stream_writer()
+
+                def trunc(s):
+                    if len(s) > 35:
+                        return s[:33] + '...'
+                    return s
+
+                for event in strm:
+                    if 'status' not in event:
+                        continue
+                    status = event['status'].lower()
+                    if 'progressDetail' in event:
+                        detail = event['progressDetail']
+                        if 'current' in detail and 'total' in detail:
+                            percentage = float(detail['current']) / float(detail['total'])
+                            status = '{} ({:.1%})'.format(status, percentage)
+
+                    writer.write(
+                        msg, service.name, trunc(status), lambda s: s
+                    )
 
             _, errors = parallel.parallel_execute(
                 services,
                 pull_service,
                 operator.attrgetter('name'),
-                not silent and 'Pulling' or None,
+                msg,
                 limit=5,
             )
             if len(errors):

+ 23 - 16
compose/service.py

@@ -1074,7 +1074,7 @@ class Service(object):
         )
 
         try:
-            all_events = stream_output(build_output, sys.stdout)
+            all_events = list(stream_output(build_output, sys.stdout))
         except StreamOutputError as e:
             raise BuildError(self, six.text_type(e))
 
@@ -1168,7 +1168,23 @@ class Service(object):
 
         return any(has_host_port(binding) for binding in self.options.get('ports', []))
 
-    def pull(self, ignore_pull_failures=False, silent=False):
+    def _do_pull(self, repo, pull_kwargs, silent, ignore_pull_failures):
+        try:
+            output = self.client.pull(repo, **pull_kwargs)
+            if silent:
+                with open(os.devnull, 'w') as devnull:
+                    for event in stream_output(output, devnull):
+                        yield event
+            else:
+                for event in stream_output(output, sys.stdout):
+                    yield event
+        except (StreamOutputError, NotFound) as e:
+            if not ignore_pull_failures:
+                raise
+            else:
+                log.error(six.text_type(e))
+
+    def pull(self, ignore_pull_failures=False, silent=False, stream=False):
         if 'image' not in self.options:
             return
 
@@ -1185,20 +1201,11 @@ class Service(object):
             raise OperationFailedError(
                 'Impossible to perform platform-targeted pulls for API version < 1.35'
             )
-        try:
-            output = self.client.pull(repo, **kwargs)
-            if silent:
-                with open(os.devnull, 'w') as devnull:
-                    return progress_stream.get_digest_from_pull(
-                        stream_output(output, devnull))
-            else:
-                return progress_stream.get_digest_from_pull(
-                    stream_output(output, sys.stdout))
-        except (StreamOutputError, NotFound) as e:
-            if not ignore_pull_failures:
-                raise
-            else:
-                log.error(six.text_type(e))
+
+        event_stream = self._do_pull(repo, kwargs, silent, ignore_pull_failures)
+        if stream:
+            return event_stream
+        return progress_stream.get_digest_from_pull(event_stream)
 
     def push(self, ignore_push_failures=False):
         if 'image' not in self.options or 'build' not in self.options:

+ 3 - 1
tests/integration/testcases.py

@@ -139,7 +139,9 @@ class DockerClientTestCase(unittest.TestCase):
     def check_build(self, *args, **kwargs):
         kwargs.setdefault('rm', True)
         build_output = self.client.build(*args, **kwargs)
-        stream_output(build_output, open('/dev/null', 'w'))
+        with open(os.devnull, 'w') as devnull:
+            for event in stream_output(build_output, devnull):
+                pass
 
     def require_api_version(self, minimum):
         api_version = self.client.version()['ApiVersion']

+ 6 - 6
tests/unit/progress_stream_test.py

@@ -21,7 +21,7 @@ class ProgressStreamTestCase(unittest.TestCase):
             b'31019763, "start": 1413653874, "total": 62763875}, '
             b'"progress": "..."}',
         ]
-        events = progress_stream.stream_output(output, StringIO())
+        events = list(progress_stream.stream_output(output, StringIO()))
         assert len(events) == 1
 
     def test_stream_output_div_zero(self):
@@ -30,7 +30,7 @@ class ProgressStreamTestCase(unittest.TestCase):
             b'0, "start": 1413653874, "total": 0}, '
             b'"progress": "..."}',
         ]
-        events = progress_stream.stream_output(output, StringIO())
+        events = list(progress_stream.stream_output(output, StringIO()))
         assert len(events) == 1
 
     def test_stream_output_null_total(self):
@@ -39,7 +39,7 @@ class ProgressStreamTestCase(unittest.TestCase):
             b'0, "start": 1413653874, "total": null}, '
             b'"progress": "..."}',
         ]
-        events = progress_stream.stream_output(output, StringIO())
+        events = list(progress_stream.stream_output(output, StringIO()))
         assert len(events) == 1
 
     def test_stream_output_progress_event_tty(self):
@@ -52,7 +52,7 @@ class ProgressStreamTestCase(unittest.TestCase):
                 return True
 
         output = TTYStringIO()
-        events = progress_stream.stream_output(events, output)
+        events = list(progress_stream.stream_output(events, output))
         assert len(output.getvalue()) > 0
 
     def test_stream_output_progress_event_no_tty(self):
@@ -61,7 +61,7 @@ class ProgressStreamTestCase(unittest.TestCase):
         ]
         output = StringIO()
 
-        events = progress_stream.stream_output(events, output)
+        events = list(progress_stream.stream_output(events, output))
         assert len(output.getvalue()) == 0
 
     def test_stream_output_no_progress_event_no_tty(self):
@@ -70,7 +70,7 @@ class ProgressStreamTestCase(unittest.TestCase):
         ]
         output = StringIO()
 
-        events = progress_stream.stream_output(events, output)
+        events = list(progress_stream.stream_output(events, output))
         assert len(output.getvalue()) > 0
 
     def test_mismatched_encoding_stream_write(self):