Przeglądaj źródła

Merge pull request #5537 from docker/Rozelette-branch_limit_test

Add configurable parallel operations limit
Joffrey F 7 lat temu
rodzic
commit
b97b76d24f
4 zmienionych plików z 145 dodań i 86 usunięć
  1. 19 0
      compose/cli/command.py
  2. 1 0
      compose/const.py
  3. 16 1
      compose/parallel.py
  4. 109 85
      tests/unit/parallel_test.py

+ 19 - 0
compose/cli/command.py

@@ -10,6 +10,7 @@ import six
 from . import errors
 from . import verbose_proxy
 from .. import config
+from .. import parallel
 from ..config.environment import Environment
 from ..const import API_VERSIONS
 from ..project import Project
@@ -23,6 +24,8 @@ log = logging.getLogger(__name__)
 
 def project_from_options(project_dir, options):
     environment = Environment.from_env_file(project_dir)
+    set_parallel_limit(environment)
+
     host = options.get('--host')
     if host is not None:
         host = host.lstrip('=')
@@ -38,6 +41,22 @@ def project_from_options(project_dir, options):
     )
 
 
+def set_parallel_limit(environment):
+    parallel_limit = environment.get('COMPOSE_PARALLEL_LIMIT')
+    if parallel_limit:
+        try:
+            parallel_limit = int(parallel_limit)
+        except ValueError:
+            raise errors.UserError(
+                'COMPOSE_PARALLEL_LIMIT must be an integer (found: "{}")'.format(
+                    environment.get('COMPOSE_PARALLEL_LIMIT')
+                )
+            )
+        if parallel_limit <= 1:
+            raise errors.UserError('COMPOSE_PARALLEL_LIMIT can not be less than 2')
+        parallel.GlobalLimit.set_global_limit(parallel_limit)
+
+
 def get_config_from_options(base_dir, options):
     environment = Environment.from_env_file(base_dir)
     config_path = get_config_path_from_options(

+ 1 - 0
compose/const.py

@@ -18,6 +18,7 @@ LABEL_VERSION = 'com.docker.compose.version'
 LABEL_VOLUME = 'com.docker.compose.volume'
 LABEL_CONFIG_HASH = 'com.docker.compose.config-hash'
 NANOCPUS_SCALE = 1000000000
+PARALLEL_LIMIT = 64
 
 SECRETS_PATH = '/run/secrets'
 

+ 16 - 1
compose/parallel.py

@@ -15,6 +15,7 @@ from six.moves.queue import Queue
 from compose.cli.colors import green
 from compose.cli.colors import red
 from compose.cli.signals import ShutdownException
+from compose.const import PARALLEL_LIMIT
 from compose.errors import HealthCheckFailed
 from compose.errors import NoHealthCheckConfigured
 from compose.errors import OperationFailedError
@@ -26,6 +27,20 @@ log = logging.getLogger(__name__)
 STOP = object()
 
 
+class GlobalLimit(object):
+    """Simple class to hold a global semaphore limiter for a project. This class
+    should be treated as a singleton that is instantiated when the project is.
+    """
+
+    global_limiter = Semaphore(PARALLEL_LIMIT)
+
+    @classmethod
+    def set_global_limit(cls, value):
+        if value is None:
+            value = PARALLEL_LIMIT
+        cls.global_limiter = Semaphore(value)
+
+
 def parallel_execute(objects, func, get_name, msg, get_deps=None, limit=None, parent_objects=None):
     """Runs func on objects in parallel while ensuring that func is
     ran on object only after it is ran on all its dependencies.
@@ -173,7 +188,7 @@ def producer(obj, func, results, limiter):
     The entry point for a producer thread which runs func on a single object.
     Places a tuple on the results queue once func has either returned or raised.
     """
-    with limiter:
+    with limiter, GlobalLimit.global_limiter:
         try:
             result = func(obj)
             results.put((obj, result, None))

+ 109 - 85
tests/unit/parallel_test.py

@@ -1,11 +1,13 @@
 from __future__ import absolute_import
 from __future__ import unicode_literals
 
+import unittest
 from threading import Lock
 
 import six
 from docker.errors import APIError
 
+from compose.parallel import GlobalLimit
 from compose.parallel import parallel_execute
 from compose.parallel import parallel_execute_iter
 from compose.parallel import ParallelStreamWriter
@@ -31,91 +33,113 @@ def get_deps(obj):
     return [(dep, None) for dep in deps[obj]]
 
 
-def test_parallel_execute():
-    results, errors = parallel_execute(
-        objects=[1, 2, 3, 4, 5],
-        func=lambda x: x * 2,
-        get_name=six.text_type,
-        msg="Doubling",
-    )
-
-    assert sorted(results) == [2, 4, 6, 8, 10]
-    assert errors == {}
-
-
-def test_parallel_execute_with_limit():
-    limit = 1
-    tasks = 20
-    lock = Lock()
-
-    def f(obj):
-        locked = lock.acquire(False)
-        # we should always get the lock because we're the only thread running
-        assert locked
-        lock.release()
-        return None
-
-    results, errors = parallel_execute(
-        objects=list(range(tasks)),
-        func=f,
-        get_name=six.text_type,
-        msg="Testing",
-        limit=limit,
-    )
-
-    assert results == tasks * [None]
-    assert errors == {}
-
-
-def test_parallel_execute_with_deps():
-    log = []
-
-    def process(x):
-        log.append(x)
-
-    parallel_execute(
-        objects=objects,
-        func=process,
-        get_name=lambda obj: obj,
-        msg="Processing",
-        get_deps=get_deps,
-    )
-
-    assert sorted(log) == sorted(objects)
-
-    assert log.index(data_volume) < log.index(db)
-    assert log.index(db) < log.index(web)
-    assert log.index(cache) < log.index(web)
-
-
-def test_parallel_execute_with_upstream_errors():
-    log = []
-
-    def process(x):
-        if x is data_volume:
-            raise APIError(None, None, "Something went wrong")
-        log.append(x)
-
-    parallel_execute(
-        objects=objects,
-        func=process,
-        get_name=lambda obj: obj,
-        msg="Processing",
-        get_deps=get_deps,
-    )
-
-    assert log == [cache]
-
-    events = [
-        (obj, result, type(exception))
-        for obj, result, exception
-        in parallel_execute_iter(objects, process, get_deps, None)
-    ]
-
-    assert (cache, None, type(None)) in events
-    assert (data_volume, None, APIError) in events
-    assert (db, None, UpstreamError) in events
-    assert (web, None, UpstreamError) in events
+class ParallelTest(unittest.TestCase):
+
+    def test_parallel_execute(self):
+        results, errors = parallel_execute(
+            objects=[1, 2, 3, 4, 5],
+            func=lambda x: x * 2,
+            get_name=six.text_type,
+            msg="Doubling",
+        )
+
+        assert sorted(results) == [2, 4, 6, 8, 10]
+        assert errors == {}
+
+    def test_parallel_execute_with_limit(self):
+        limit = 1
+        tasks = 20
+        lock = Lock()
+
+        def f(obj):
+            locked = lock.acquire(False)
+            # we should always get the lock because we're the only thread running
+            assert locked
+            lock.release()
+            return None
+
+        results, errors = parallel_execute(
+            objects=list(range(tasks)),
+            func=f,
+            get_name=six.text_type,
+            msg="Testing",
+            limit=limit,
+        )
+
+        assert results == tasks * [None]
+        assert errors == {}
+
+    def test_parallel_execute_with_global_limit(self):
+        GlobalLimit.set_global_limit(1)
+        self.addCleanup(GlobalLimit.set_global_limit, None)
+        tasks = 20
+        lock = Lock()
+
+        def f(obj):
+            locked = lock.acquire(False)
+            # we should always get the lock because we're the only thread running
+            assert locked
+            lock.release()
+            return None
+
+        results, errors = parallel_execute(
+            objects=list(range(tasks)),
+            func=f,
+            get_name=six.text_type,
+            msg="Testing",
+        )
+
+        assert results == tasks * [None]
+        assert errors == {}
+
+    def test_parallel_execute_with_deps(self):
+        log = []
+
+        def process(x):
+            log.append(x)
+
+        parallel_execute(
+            objects=objects,
+            func=process,
+            get_name=lambda obj: obj,
+            msg="Processing",
+            get_deps=get_deps,
+        )
+
+        assert sorted(log) == sorted(objects)
+
+        assert log.index(data_volume) < log.index(db)
+        assert log.index(db) < log.index(web)
+        assert log.index(cache) < log.index(web)
+
+    def test_parallel_execute_with_upstream_errors(self):
+        log = []
+
+        def process(x):
+            if x is data_volume:
+                raise APIError(None, None, "Something went wrong")
+            log.append(x)
+
+        parallel_execute(
+            objects=objects,
+            func=process,
+            get_name=lambda obj: obj,
+            msg="Processing",
+            get_deps=get_deps,
+        )
+
+        assert log == [cache]
+
+        events = [
+            (obj, result, type(exception))
+            for obj, result, exception
+            in parallel_execute_iter(objects, process, get_deps, None)
+        ]
+
+        assert (cache, None, type(None)) in events
+        assert (data_volume, None, APIError) in events
+        assert (db, None, UpstreamError) in events
+        assert (web, None, UpstreamError) in events
 
 
 def test_parallel_execute_alignment(capsys):