
Merge pull request #3291 from aanand/parallel-execute-refactor

Parallel execution refactor/fixes
Daniel Nephin 9 years ago
parent
commit e5443717fb
3 changed files with 169 additions and 37 deletions
  1. compose/parallel.py (76 additions, 37 deletions)
  2. compose/service.py (3 additions, 0 deletions)
  3. tests/unit/parallel_test.py (90 additions, 0 deletions)

+ 76 - 37
compose/parallel.py

@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 from __future__ import unicode_literals
 
+import logging
 import operator
 import sys
 from threading import Thread
@@ -14,6 +15,9 @@ from compose.cli.signals import ShutdownException
 from compose.utils import get_output_stream
 
 
+log = logging.getLogger(__name__)
+
+
 def parallel_execute(objects, func, get_name, msg, get_deps=None):
     """Runs func on objects in parallel while ensuring that func is
     ran on object only after it is ran on all its dependencies.
@@ -28,32 +32,24 @@ def parallel_execute(objects, func, get_name, msg, get_deps=None):
     for obj in objects:
         writer.initialize(get_name(obj))
 
-    q = setup_queue(objects, func, get_deps, get_name)
+    events = parallel_execute_stream(objects, func, get_deps)
 
-    done = 0
     errors = {}
     results = []
     error_to_reraise = None
 
-    while done < len(objects):
-        try:
-            obj, result, exception = q.get(timeout=1)
-        except Empty:
-            continue
-        # See https://github.com/docker/compose/issues/189
-        except thread.error:
-            raise ShutdownException()
-
+    for obj, result, exception in events:
         if exception is None:
             writer.write(get_name(obj), 'done')
             results.append(result)
         elif isinstance(exception, APIError):
             errors[get_name(obj)] = exception.explanation
             writer.write(get_name(obj), 'error')
+        elif isinstance(exception, UpstreamError):
+            writer.write(get_name(obj), 'error')
         else:
             errors[get_name(obj)] = exception
             error_to_reraise = exception
-        done += 1
 
     for obj_name, error in errors.items():
         stream.write("\nERROR: for {}  {}\n".format(obj_name, error))
@@ -68,40 +64,83 @@ def _no_deps(x):
     return []
 
 
-def setup_queue(objects, func, get_deps, get_name):
+class State(object):
+    def __init__(self, objects):
+        self.objects = objects
+
+        self.started = set()   # objects being processed
+        self.finished = set()  # objects which have been processed
+        self.failed = set()    # objects which either failed or whose dependencies failed
+
+    def is_done(self):
+        return len(self.finished) + len(self.failed) >= len(self.objects)
+
+    def pending(self):
+        return set(self.objects) - self.started - self.finished - self.failed
+
+
+def parallel_execute_stream(objects, func, get_deps):
     if get_deps is None:
         get_deps = _no_deps
 
     results = Queue()
-    started = set()   # objects being processed
-    finished = set()  # objects which have been processed
+    state = State(objects)
+
+    while not state.is_done():
+        for event in feed_queue(objects, func, get_deps, results, state):
+            yield event
 
-    def do_op(obj):
         try:
-            result = func(obj)
-            results.put((obj, result, None))
-        except Exception as e:
-            results.put((obj, None, e))
-
-        finished.add(obj)
-        feed()
-
-    def ready(obj):
-        # Is object ready for performing operation
-        return obj not in started and all(
-            dep not in objects or dep in finished
-            for dep in get_deps(obj)
-        )
-
-    def feed():
-        for obj in filter(ready, objects):
-            started.add(obj)
-            t = Thread(target=do_op, args=(obj,))
+            event = results.get(timeout=0.1)
+        except Empty:
+            continue
+        # See https://github.com/docker/compose/issues/189
+        except thread.error:
+            raise ShutdownException()
+
+        obj, _, exception = event
+        if exception is None:
+            log.debug('Finished processing: {}'.format(obj))
+            state.finished.add(obj)
+        else:
+            log.debug('Failed: {}'.format(obj))
+            state.failed.add(obj)
+
+        yield event
+
+
+def queue_producer(obj, func, results):
+    try:
+        result = func(obj)
+        results.put((obj, result, None))
+    except Exception as e:
+        results.put((obj, None, e))
+
+
+def feed_queue(objects, func, get_deps, results, state):
+    pending = state.pending()
+    log.debug('Pending: {}'.format(pending))
+
+    for obj in pending:
+        deps = get_deps(obj)
+
+        if any(dep in state.failed for dep in deps):
+            log.debug('{} has upstream errors - not processing'.format(obj))
+            yield (obj, None, UpstreamError())
+            state.failed.add(obj)
+        elif all(
+            dep not in objects or dep in state.finished
+            for dep in deps
+        ):
+            log.debug('Starting producer thread for {}'.format(obj))
+            t = Thread(target=queue_producer, args=(obj, func, results))
             t.daemon = True
             t.start()
+            state.started.add(obj)
 
-    feed()
-    return results
+
+class UpstreamError(Exception):
+    pass
 
 
 class ParallelStreamWriter(object):
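
The net effect of the refactor: parallel_execute keeps its old signature, while the scheduling logic moves into parallel_execute_stream, a generator that yields one (obj, result, exception) event per object. feed_queue starts a producer thread for each object whose dependencies have finished, and short-circuits objects whose dependencies failed by yielding an UpstreamError instead of running them. A rough consumption sketch, not part of this change; the process function and the dependency map below are invented for illustration:

from compose.parallel import UpstreamError, parallel_execute_stream

# Hypothetical objects and dependency map, similar to the unit tests below.
deps = {'web': ['db'], 'db': ['volume'], 'volume': []}

def process(name):
    if name == 'volume':
        raise RuntimeError('simulated failure')
    return name.upper()

# Events arrive as producer threads finish. 'db' and 'web' are never run
# because their (transitive) dependency failed; they are reported with an
# UpstreamError instead.
for obj, result, exception in parallel_execute_stream(list(deps), process, deps.get):
    if exception is None:
        print('done: {} -> {}'.format(obj, result))
    elif isinstance(exception, UpstreamError):
        print('skipped (upstream failure): {}'.format(obj))
    else:
        print('failed: {} ({})'.format(obj, exception))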

+ 3 - 0
compose/service.py

@@ -135,6 +135,9 @@ class Service(object):
         self.networks = networks or {}
         self.options = options
 
+    def __repr__(self):
+        return '<Service: {}>'.format(self.name)
+
     def containers(self, stopped=False, one_off=False, filters={}):
         filters.update({'label': self.labels(one_off=one_off)})
 
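
The __repr__ above mainly serves the new debug logging in compose/parallel.py, which formats sets of Service objects directly (for example log.debug('Pending: {}'.format(pending))). A minimal stand-in sketch of the difference it makes; FakeService is invented here, since constructing a real Service needs a client and options:

class FakeService(object):
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return '<Service: {}>'.format(self.name)

pending = {FakeService('web'), FakeService('db')}
print('Pending: {}'.format(pending))
# roughly: Pending: set([<Service: web>, <Service: db>])
# instead of: set([<__main__.FakeService object at 0x...>, ...])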

+ 90 - 0
tests/unit/parallel_test.py

@@ -0,0 +1,90 @@
+from __future__ import absolute_import
+from __future__ import unicode_literals
+
+import six
+from docker.errors import APIError
+
+from compose.parallel import parallel_execute
+from compose.parallel import parallel_execute_stream
+from compose.parallel import UpstreamError
+
+
+web = 'web'
+db = 'db'
+data_volume = 'data_volume'
+cache = 'cache'
+
+objects = [web, db, data_volume, cache]
+
+deps = {
+    web: [db, cache],
+    db: [data_volume],
+    data_volume: [],
+    cache: [],
+}
+
+
+def get_deps(obj):
+    return deps[obj]
+
+
+def test_parallel_execute():
+    results = parallel_execute(
+        objects=[1, 2, 3, 4, 5],
+        func=lambda x: x * 2,
+        get_name=six.text_type,
+        msg="Doubling",
+    )
+
+    assert sorted(results) == [2, 4, 6, 8, 10]
+
+
+def test_parallel_execute_with_deps():
+    log = []
+
+    def process(x):
+        log.append(x)
+
+    parallel_execute(
+        objects=objects,
+        func=process,
+        get_name=lambda obj: obj,
+        msg="Processing",
+        get_deps=get_deps,
+    )
+
+    assert sorted(log) == sorted(objects)
+
+    assert log.index(data_volume) < log.index(db)
+    assert log.index(db) < log.index(web)
+    assert log.index(cache) < log.index(web)
+
+
+def test_parallel_execute_with_upstream_errors():
+    log = []
+
+    def process(x):
+        if x is data_volume:
+            raise APIError(None, None, "Something went wrong")
+        log.append(x)
+
+    parallel_execute(
+        objects=objects,
+        func=process,
+        get_name=lambda obj: obj,
+        msg="Processing",
+        get_deps=get_deps,
+    )
+
+    assert log == [cache]
+
+    events = [
+        (obj, result, type(exception))
+        for obj, result, exception
+        in parallel_execute_stream(objects, process, get_deps)
+    ]
+
+    assert (cache, None, type(None)) in events
+    assert (data_volume, None, APIError) in events
+    assert (db, None, UpstreamError) in events
+    assert (web, None, UpstreamError) in events
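
One case the tests above do not exercise: exceptions that are neither APIError nor UpstreamError land in error_to_reraise inside parallel_execute (second hunk above). The raise itself falls outside the hunks shown here, but the variable name and the upstream code suggest it is re-raised to the caller once all events are drained. A hedged sketch of that assumed behaviour; boom is an invented helper:

from compose.parallel import parallel_execute

def boom(obj):
    raise ValueError('unexpected failure for {}'.format(obj))

# Assuming the tail of parallel_execute re-raises error_to_reraise,
# a plain ValueError surfaces to the caller rather than being swallowed.
try:
    parallel_execute(objects=[1, 2], func=boom, get_name=str, msg="Booming")
except ValueError as exc:
    print('re-raised: {}'.format(exc))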