Merge pull request #5537 from docker/Rozelette-branch_limit_test

Add configurable parallel operations limit
Joffrey F, 7 years ago
parent
commit
b97b76d24f
4 changed files, with 145 additions and 86 deletions
  1. compose/cli/command.py (+19 −0)
  2. compose/const.py (+1 −0)
  3. compose/parallel.py (+16 −1)
  4. tests/unit/parallel_test.py (+109 −85)
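
In effect, this PR lets users cap how many operations Compose runs in parallel: a COMPOSE_PARALLEL_LIMIT environment variable is read at project load time, validated, and used to size a global semaphore (default 64, from the new PARALLEL_LIMIT constant). A hedged usage sketch; only the variable name comes from this diff, the subprocess invocation is illustrative:

    import os
    import subprocess

    # Cap Compose at 8 concurrent operations for this invocation.
    env = dict(os.environ, COMPOSE_PARALLEL_LIMIT='8')
    subprocess.check_call(['docker-compose', 'up', '-d'], env=env)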

+ 19 - 0
compose/cli/command.py

@@ -10,6 +10,7 @@ import six
 from . import errors
 from . import verbose_proxy
 from .. import config
+from .. import parallel
 from ..config.environment import Environment
 from ..const import API_VERSIONS
 from ..project import Project
@@ -23,6 +24,8 @@ log = logging.getLogger(__name__)

 def project_from_options(project_dir, options):
     environment = Environment.from_env_file(project_dir)
+    set_parallel_limit(environment)
+
     host = options.get('--host')
     if host is not None:
         host = host.lstrip('=')
@@ -38,6 +41,22 @@ def project_from_options(project_dir, options):
     )


+def set_parallel_limit(environment):
+    parallel_limit = environment.get('COMPOSE_PARALLEL_LIMIT')
+    if parallel_limit:
+        try:
+            parallel_limit = int(parallel_limit)
+        except ValueError:
+            raise errors.UserError(
+                'COMPOSE_PARALLEL_LIMIT must be an integer (found: "{}")'.format(
+                    environment.get('COMPOSE_PARALLEL_LIMIT')
+                )
+            )
+        if parallel_limit <= 1:
+            raise errors.UserError('COMPOSE_PARALLEL_LIMIT can not be less than 2')
+        parallel.GlobalLimit.set_global_limit(parallel_limit)
+
+
 def get_config_from_options(base_dir, options):
     environment = Environment.from_env_file(base_dir)
     config_path = get_config_path_from_options(
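
To make the accepted values concrete, here is a self-contained sketch of the validation above; UserError and parse_parallel_limit are stand-ins for compose.cli.errors.UserError and the real set_parallel_limit, which additionally applies the result via parallel.GlobalLimit.set_global_limit:

    class UserError(Exception):
        """Stand-in for compose.cli.errors.UserError."""

    def parse_parallel_limit(raw):
        # Unset or empty means "keep the default limit".
        if not raw:
            return None
        try:
            value = int(raw)
        except ValueError:
            raise UserError(
                'COMPOSE_PARALLEL_LIMIT must be an integer (found: "{}")'.format(raw)
            )
        if value <= 1:
            # As in the diff: anything below 2 is rejected.
            raise UserError('COMPOSE_PARALLEL_LIMIT can not be less than 2')
        return value

    assert parse_parallel_limit('8') == 8
    assert parse_parallel_limit(None) is None  # default limit stays in effect
    # parse_parallel_limit('wat') raises UserError (not an integer)
    # parse_parallel_limit('1') raises UserError (limit must be at least 2)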

+ 1 - 0
compose/const.py

@@ -18,6 +18,7 @@ LABEL_VERSION = 'com.docker.compose.version'
 LABEL_VOLUME = 'com.docker.compose.volume'
 LABEL_CONFIG_HASH = 'com.docker.compose.config-hash'
 NANOCPUS_SCALE = 1000000000
+PARALLEL_LIMIT = 64

 SECRETS_PATH = '/run/secrets'


+ 16 - 1
compose/parallel.py

@@ -15,6 +15,7 @@ from six.moves.queue import Queue
 from compose.cli.colors import green
 from compose.cli.colors import red
 from compose.cli.signals import ShutdownException
+from compose.const import PARALLEL_LIMIT
 from compose.errors import HealthCheckFailed
 from compose.errors import NoHealthCheckConfigured
 from compose.errors import OperationFailedError
@@ -26,6 +27,20 @@ log = logging.getLogger(__name__)
 STOP = object()


+class GlobalLimit(object):
+    """Simple class to hold a global semaphore limiter for a project. This class
+    should be treated as a singleton that is instantiated when the project is.
+    """
+
+    global_limiter = Semaphore(PARALLEL_LIMIT)
+
+    @classmethod
+    def set_global_limit(cls, value):
+        if value is None:
+            value = PARALLEL_LIMIT
+        cls.global_limiter = Semaphore(value)
+
+
 def parallel_execute(objects, func, get_name, msg, get_deps=None, limit=None, parent_objects=None):
     """Runs func on objects in parallel while ensuring that func is
     ran on object only after it is ran on all its dependencies.
@@ -173,7 +188,7 @@ def producer(obj, func, results, limiter):
     The entry point for a producer thread which runs func on a single object.
     Places a tuple on the results queue once func has either returned or raised.
     """
-    with limiter:
+    with limiter, GlobalLimit.global_limiter:
         try:
             result = func(obj)
             results.put((obj, result, None))
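
The producer change above is the heart of the feature: each worker thread must now hold both the per-call limiter (from parallel_execute's limit= argument) and the shared GlobalLimit.global_limiter, so effective concurrency is the minimum of the two. A minimal standalone sketch of that interaction, with illustrative thread counts:

    import threading
    import time

    local_limiter = threading.Semaphore(2)    # as with limit=2
    global_limiter = threading.Semaphore(64)  # mirrors the PARALLEL_LIMIT default

    lock = threading.Lock()
    inside = [0]  # bodies currently executing
    peak = [0]    # highest concurrency observed

    def producer_demo():
        # Same shape as producer(): both semaphores must be held at once,
        # so at most min(2, 64) = 2 bodies run concurrently.
        with local_limiter, global_limiter:
            with lock:
                inside[0] += 1
                peak[0] = max(peak[0], inside[0])
            time.sleep(0.01)
            with lock:
                inside[0] -= 1

    threads = [threading.Thread(target=producer_demo) for _ in range(10)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert peak[0] <= 2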

+ 109 - 85
tests/unit/parallel_test.py

@@ -1,11 +1,13 @@
 from __future__ import absolute_import
 from __future__ import unicode_literals

+import unittest
 from threading import Lock

 import six
 from docker.errors import APIError

+from compose.parallel import GlobalLimit
 from compose.parallel import parallel_execute
 from compose.parallel import parallel_execute_iter
 from compose.parallel import ParallelStreamWriter
@@ -31,91 +33,113 @@ def get_deps(obj):
     return [(dep, None) for dep in deps[obj]]


-def test_parallel_execute():
-    results, errors = parallel_execute(
-        objects=[1, 2, 3, 4, 5],
-        func=lambda x: x * 2,
-        get_name=six.text_type,
-        msg="Doubling",
-    )
-
-    assert sorted(results) == [2, 4, 6, 8, 10]
-    assert errors == {}
-
-
-def test_parallel_execute_with_limit():
-    limit = 1
-    tasks = 20
-    lock = Lock()
-
-    def f(obj):
-        locked = lock.acquire(False)
-        # we should always get the lock because we're the only thread running
-        assert locked
-        lock.release()
-        return None
-
-    results, errors = parallel_execute(
-        objects=list(range(tasks)),
-        func=f,
-        get_name=six.text_type,
-        msg="Testing",
-        limit=limit,
-    )
-
-    assert results == tasks * [None]
-    assert errors == {}
-
-
-def test_parallel_execute_with_deps():
-    log = []
-
-    def process(x):
-        log.append(x)
-
-    parallel_execute(
-        objects=objects,
-        func=process,
-        get_name=lambda obj: obj,
-        msg="Processing",
-        get_deps=get_deps,
-    )
-
-    assert sorted(log) == sorted(objects)
-
-    assert log.index(data_volume) < log.index(db)
-    assert log.index(db) < log.index(web)
-    assert log.index(cache) < log.index(web)
-
-
-def test_parallel_execute_with_upstream_errors():
-    log = []
-
-    def process(x):
-        if x is data_volume:
-            raise APIError(None, None, "Something went wrong")
-        log.append(x)
-
-    parallel_execute(
-        objects=objects,
-        func=process,
-        get_name=lambda obj: obj,
-        msg="Processing",
-        get_deps=get_deps,
-    )
-
-    assert log == [cache]
-
-    events = [
-        (obj, result, type(exception))
-        for obj, result, exception
-        in parallel_execute_iter(objects, process, get_deps, None)
-    ]
-
-    assert (cache, None, type(None)) in events
-    assert (data_volume, None, APIError) in events
-    assert (db, None, UpstreamError) in events
-    assert (web, None, UpstreamError) in events
+class ParallelTest(unittest.TestCase):
+
+    def test_parallel_execute(self):
+        results, errors = parallel_execute(
+            objects=[1, 2, 3, 4, 5],
+            func=lambda x: x * 2,
+            get_name=six.text_type,
+            msg="Doubling",
+        )
+
+        assert sorted(results) == [2, 4, 6, 8, 10]
+        assert errors == {}
+
+    def test_parallel_execute_with_limit(self):
+        limit = 1
+        tasks = 20
+        lock = Lock()
+
+        def f(obj):
+            locked = lock.acquire(False)
+            # we should always get the lock because we're the only thread running
+            assert locked
+            lock.release()
+            return None
+
+        results, errors = parallel_execute(
+            objects=list(range(tasks)),
+            func=f,
+            get_name=six.text_type,
+            msg="Testing",
+            limit=limit,
+        )
+
+        assert results == tasks * [None]
+        assert errors == {}
+
+    def test_parallel_execute_with_global_limit(self):
+        GlobalLimit.set_global_limit(1)
+        self.addCleanup(GlobalLimit.set_global_limit, None)
+        tasks = 20
+        lock = Lock()
+
+        def f(obj):
+            locked = lock.acquire(False)
+            # we should always get the lock because we're the only thread running
+            assert locked
+            lock.release()
+            return None
+
+        results, errors = parallel_execute(
+            objects=list(range(tasks)),
+            func=f,
+            get_name=six.text_type,
+            msg="Testing",
+        )
+
+        assert results == tasks * [None]
+        assert errors == {}
+
+    def test_parallel_execute_with_deps(self):
+        log = []
+
+        def process(x):
+            log.append(x)
+
+        parallel_execute(
+            objects=objects,
+            func=process,
+            get_name=lambda obj: obj,
+            msg="Processing",
+            get_deps=get_deps,
+        )
+
+        assert sorted(log) == sorted(objects)
+
+        assert log.index(data_volume) < log.index(db)
+        assert log.index(db) < log.index(web)
+        assert log.index(cache) < log.index(web)
+
+    def test_parallel_execute_with_upstream_errors(self):
+        log = []
+
+        def process(x):
+            if x is data_volume:
+                raise APIError(None, None, "Something went wrong")
+            log.append(x)
+
+        parallel_execute(
+            objects=objects,
+            func=process,
+            get_name=lambda obj: obj,
+            msg="Processing",
+            get_deps=get_deps,
+        )
+
+        assert log == [cache]
+
+        events = [
+            (obj, result, type(exception))
+            for obj, result, exception
+            in parallel_execute_iter(objects, process, get_deps, None)
+        ]
+
+        assert (cache, None, type(None)) in events
+        assert (data_volume, None, APIError) in events
+        assert (db, None, UpstreamError) in events
+        assert (web, None, UpstreamError) in events


 def test_parallel_execute_alignment(capsys):
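
One design note on the tests: GlobalLimit.global_limiter is class-level shared state, which is why test_parallel_execute_with_global_limit registers addCleanup(GlobalLimit.set_global_limit, None); passing None restores the PARALLEL_LIMIT default of 64 so a limit of 1 cannot leak into later tests. A hedged sketch of that reset semantics (values illustrative):

    from compose.parallel import GlobalLimit

    GlobalLimit.set_global_limit(2)     # shared semaphore now admits 2 producers
    # ... run parallel work ...
    GlobalLimit.set_global_limit(None)  # back to the PARALLEL_LIMIT default (64)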