Abort operations if their dependencies fail

Signed-off-by: Aanand Prasad <[email protected]>
Aanand Prasad 9 years ago
parent commit 141b96bb31
2 changed files with 124 additions and 35 deletions
  1. compose/parallel.py (+51 -35)
  2. tests/unit/parallel_test.py (+73 -0)
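
For context, the caller-visible effect of this change can be sketched as follows. This is a minimal example, assuming the compose.parallel module from this commit is importable; the service names and the simulated failure are illustrative, not part of the commit:

    from docker.errors import APIError

    from compose.parallel import parallel_execute

    deps = {'db': [], 'web': ['db'], 'cache': []}
    started = []

    def start(name):
        # Simulate the dependency failing, raised as an APIError so that
        # parallel_execute records it rather than re-raising it.
        if name == 'db':
            raise APIError(None, None, "db failed to start")
        started.append(name)

    parallel_execute(
        objects=['db', 'web', 'cache'],
        func=start,
        get_name=lambda name: name,
        msg='Starting',
        get_deps=lambda name: deps[name],
    )

    # 'db' failed, so 'web' is aborted with an UpstreamError before start()
    # ever runs for it; only the independent 'cache' succeeds.
    assert started == ['cache']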

compose/parallel.py (+51 -35)

@@ -32,7 +32,7 @@ def parallel_execute(objects, func, get_name, msg, get_deps=None):
     for obj in objects:
         writer.initialize(get_name(obj))
 
-    q = setup_queue(objects, func, get_deps, get_name)
+    q = setup_queue(objects, func, get_deps)
 
     done = 0
     errors = {}
@@ -54,6 +54,8 @@ def parallel_execute(objects, func, get_name, msg, get_deps=None):
         elif isinstance(exception, APIError):
             errors[get_name(obj)] = exception.explanation
             writer.write(get_name(obj), 'error')
+        elif isinstance(exception, UpstreamError):
+            writer.write(get_name(obj), 'error')
         else:
             errors[get_name(obj)] = exception
             error_to_reraise = exception
@@ -72,58 +74,72 @@ def _no_deps(x):
     return []
 
 
-def setup_queue(objects, func, get_deps, get_name):
+def setup_queue(objects, func, get_deps):
     if get_deps is None:
         get_deps = _no_deps
 
     results = Queue()
     output = Queue()
 
-    def consumer():
-        started = set()   # objects being processed
-        finished = set()  # objects which have been processed
-
-        def ready(obj):
-            """
-            Returns true if obj is ready to be processed:
-              - all dependencies have been processed
-              - obj is not already being processed
-            """
-            return obj not in started and all(
-                dep not in objects or dep in finished
-                for dep in get_deps(obj)
-            )
+    t = Thread(target=queue_consumer, args=(objects, func, get_deps, results, output))
+    t.daemon = True
+    t.start()
+
+    return output
+
+
+def queue_producer(obj, func, results):
+    try:
+        result = func(obj)
+        results.put((obj, result, None))
+    except Exception as e:
+        results.put((obj, None, e))
+
+
+def queue_consumer(objects, func, get_deps, results, output):
+    started = set()   # objects being processed
+    finished = set()  # objects which have been processed
+    failed = set()    # objects which either failed or whose dependencies failed
+
+    while len(finished) + len(failed) < len(objects):
+        pending = set(objects) - started - finished - failed
+        log.debug('Pending: {}'.format(pending))
 
-        while len(finished) < len(objects):
-            for obj in filter(ready, objects):
+        for obj in pending:
+            deps = get_deps(obj)
+
+            if any(dep in failed for dep in deps):
+                log.debug('{} has upstream errors - not processing'.format(obj))
+                output.put((obj, None, UpstreamError()))
+                failed.add(obj)
+            elif all(
+                dep not in objects or dep in finished
+                for dep in deps
+            ):
                 log.debug('Starting producer thread for {}'.format(obj))
-                t = Thread(target=producer, args=(obj,))
+                t = Thread(target=queue_producer, args=(obj, func, results))
                 t.daemon = True
                 t.start()
                 started.add(obj)
 
-            try:
-                event = results.get(timeout=1)
-            except Empty:
-                continue
+        try:
+            event = results.get(timeout=1)
+        except Empty:
+            continue
 
-            obj = event[0]
+        obj, _, exception = event
+        if exception is None:
             log.debug('Finished processing: {}'.format(obj))
             finished.add(obj)
-            output.put(event)
+        else:
+            log.debug('Failed: {}'.format(obj))
+            failed.add(obj)
 
-    def producer(obj):
-        try:
-            result = func(obj)
-            results.put((obj, result, None))
-        except Exception as e:
-            results.put((obj, None, e))
+        output.put(event)
 
-    t = Thread(target=consumer)
-    t.daemon = True
-    t.start()
 
-    return output
+class UpstreamError(Exception):
+    pass
 
 
 class ParallelStreamWriter(object):
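
A subtlety in queue_consumer above: a dependency that is not in objects is treated as already satisfied (dep not in objects or dep in finished), so callers may schedule only a subset of the dependency graph. A minimal sketch of that readiness check in isolation, with illustrative names:

    objects = {'db', 'web'}        # 'cache' is deliberately not scheduled
    finished = {'db'}
    deps_of_web = ['db', 'cache']

    # 'cache' is outside the scheduled set, so it does not block 'web';
    # 'db' is finished, so 'web' is ready to start.
    assert all(dep not in objects or dep in finished for dep in deps_of_web)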

tests/unit/parallel_test.py (+73 -0)

@@ -0,0 +1,73 @@
+from __future__ import absolute_import
+from __future__ import unicode_literals
+
+import six
+from docker.errors import APIError
+
+from compose.parallel import parallel_execute
+
+
+web = 'web'
+db = 'db'
+data_volume = 'data_volume'
+cache = 'cache'
+
+objects = [web, db, data_volume, cache]
+
+deps = {
+    web: [db, cache],
+    db: [data_volume],
+    data_volume: [],
+    cache: [],
+}
+
+
+def test_parallel_execute():
+    results = parallel_execute(
+        objects=[1, 2, 3, 4, 5],
+        func=lambda x: x * 2,
+        get_name=six.text_type,
+        msg="Doubling",
+    )
+
+    assert sorted(results) == [2, 4, 6, 8, 10]
+
+
+def test_parallel_execute_with_deps():
+    log = []
+
+    def process(x):
+        log.append(x)
+
+    parallel_execute(
+        objects=objects,
+        func=process,
+        get_name=lambda obj: obj,
+        msg="Processing",
+        get_deps=lambda obj: deps[obj],
+    )
+
+    assert sorted(log) == sorted(objects)
+
+    assert log.index(data_volume) < log.index(db)
+    assert log.index(db) < log.index(web)
+    assert log.index(cache) < log.index(web)
+
+
+def test_parallel_execute_with_upstream_errors():
+    log = []
+
+    def process(x):
+        if x is data_volume:
+            raise APIError(None, None, "Something went wrong")
+        log.append(x)
+
+    parallel_execute(
+        objects=objects,
+        func=process,
+        get_name=lambda obj: obj,
+        msg="Processing",
+        get_deps=lambda obj: deps[obj],
+    )
+
+    assert log == [cache]
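
The final assertion exercises the new behaviour end to end: data_volume raises, so db is aborted with an UpstreamError before process() ever runs for it, and web is aborted in turn because its failed dependency db is in the failed set; cache, which has no dependencies, is the only object actually processed.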