
Clean up limit setting code and add reasonable input guards

Signed-off-by: Joffrey F <[email protected]>
Joffrey F · 7 years ago
commit 8d3c7d4bce

4 changed files with 135 additions and 123 deletions
  1. compose/cli/command.py (+20 -6)
  2. compose/parallel.py (+4 -3)
  3. compose/project.py (+3 -6)
  4. tests/unit/parallel_test.py (+108 -108)

compose/cli/command.py (+20 -6)

@@ -10,6 +10,7 @@ import six
 from . import errors
 from . import verbose_proxy
 from .. import config
+from .. import parallel
 from ..config.environment import Environment
 from ..const import API_VERSIONS
 from ..project import Project
@@ -23,6 +24,8 @@ log = logging.getLogger(__name__)
 
 def project_from_options(project_dir, options):
     environment = Environment.from_env_file(project_dir)
+    set_parallel_limit(environment)
+
     host = options.get('--host')
     if host is not None:
         host = host.lstrip('=')
@@ -38,6 +41,22 @@ def project_from_options(project_dir, options):
     )
 
 
+def set_parallel_limit(environment):
+    parallel_limit = environment.get('COMPOSE_PARALLEL_LIMIT')
+    if parallel_limit:
+        try:
+            parallel_limit = int(parallel_limit)
+        except ValueError:
+            raise errors.UserError(
+                'COMPOSE_PARALLEL_LIMIT must be an integer (found: "{}")'.format(
+                    environment.get('COMPOSE_PARALLEL_LIMIT')
+                )
+            )
+        if parallel_limit <= 1:
+            raise errors.UserError('COMPOSE_PARALLEL_LIMIT can not be less than 2')
+        parallel.GlobalLimit.set_global_limit(parallel_limit)
+
+
 def get_config_from_options(base_dir, options):
     environment = Environment.from_env_file(base_dir)
     config_path = get_config_path_from_options(
@@ -99,13 +118,8 @@ def get_project(project_dir, config_path=None, project_name=None, verbose=False,
         host=host, environment=environment
     )
 
-    global_parallel_limit = environment.get('COMPOSE_PARALLEL_LIMIT')
-    if global_parallel_limit:
-        global_parallel_limit = int(global_parallel_limit)
-
     with errors.handle_connection_errors(client):
-        return Project.from_config(project_name, config_data, client,
-                                   global_parallel_limit=global_parallel_limit)
+        return Project.from_config(project_name, config_data, client)
 
 
 def get_project_name(working_dir, project_name=None, environment=None):

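The new set_parallel_limit helper is the input guard the commit message promises: it validates COMPOSE_PARALLEL_LIMIT once at CLI startup, rejecting non-integer values and anything below 2 before the value ever reaches the semaphore. A minimal standalone sketch of the same guard, assuming a plain os.environ lookup and SystemExit in place of Compose's Environment object and errors.UserError:

    import os


    def read_parallel_limit(env=os.environ):
        # Hypothetical standalone version of set_parallel_limit above.
        raw = env.get('COMPOSE_PARALLEL_LIMIT')
        if not raw:
            return None  # unset or empty: keep the default limit
        try:
            value = int(raw)
        except ValueError:
            raise SystemExit(
                'COMPOSE_PARALLEL_LIMIT must be an integer (found: "{}")'.format(raw)
            )
        if value <= 1:
            raise SystemExit('COMPOSE_PARALLEL_LIMIT can not be less than 2')
        return value
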
compose/parallel.py (+4 -3)

@@ -35,9 +35,10 @@ class GlobalLimit(object):
     global_limiter = Semaphore(PARALLEL_LIMIT)
 
     @classmethod
-    def set_global_limit(cls, value=None):
-        if value is not None:
-            cls.global_limiter = Semaphore(value)
+    def set_global_limit(cls, value):
+        if value is None:
+            value = PARALLEL_LIMIT
+        cls.global_limiter = Semaphore(value)
 
 
 def parallel_execute(objects, func, get_name, msg, get_deps=None, limit=None, parent_objects=None):

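Before this change, set_global_limit(None) was a no-op that left whatever semaphore was already installed; now it always replaces the class-level limiter, with None restoring the default. A short sketch of the pattern, assuming a default PARALLEL_LIMIT of 64 (the value this version of Compose uses):

    from threading import Semaphore

    PARALLEL_LIMIT = 64  # assumed default, matching compose.parallel


    class GlobalLimit(object):
        global_limiter = Semaphore(PARALLEL_LIMIT)

        @classmethod
        def set_global_limit(cls, value):
            # None now means "back to the default" rather than "no change".
            if value is None:
                value = PARALLEL_LIMIT
            cls.global_limiter = Semaphore(value)


    # Workers share the class attribute, so replacing it caps all
    # subsequent acquisitions:
    GlobalLimit.set_global_limit(2)
    with GlobalLimit.global_limiter:
        pass  # at most two threads hold the limiter at once
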
compose/project.py (+3 -6)

@@ -61,15 +61,13 @@ class Project(object):
     """
     A collection of services.
     """
-    def __init__(self, name, services, client, networks=None, volumes=None, config_version=None,
-                 parallel_limit=None):
+    def __init__(self, name, services, client, networks=None, volumes=None, config_version=None):
         self.name = name
         self.services = services
         self.client = client
         self.volumes = volumes or ProjectVolumes({})
         self.networks = networks or ProjectNetworks({}, False)
         self.config_version = config_version
-        parallel.GlobalLimit.set_global_limit(value=parallel_limit)
 
     def labels(self, one_off=OneOffFilter.exclude):
         labels = ['{0}={1}'.format(LABEL_PROJECT, self.name)]
@@ -78,7 +76,7 @@ class Project(object):
         return labels
 
     @classmethod
-    def from_config(cls, name, config_data, client, global_parallel_limit=None):
+    def from_config(cls, name, config_data, client):
         """
         Construct a Project from a config.Config object.
         """
@@ -89,8 +87,7 @@ class Project(object):
             networks,
             use_networking)
         volumes = ProjectVolumes.from_config(name, config_data, client)
-        project = cls(name, [], client, project_networks, volumes, config_data.version,
-                      parallel_limit=global_parallel_limit)
+        project = cls(name, [], client, project_networks, volumes, config_data.version)
 
         for service_dict in config_data.services:
             service_dict = dict(service_dict)

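Dropping the parallel_limit argument also removes an ordering hazard: under the new set_global_limit semantics, None resets the limiter to the default, so a constructor that kept forwarding its (usually absent) parallel_limit would silently undo a limit set earlier at CLI startup. A hedged sketch of the resulting flow, assuming compose is importable; the names come from the diff:

    from compose import parallel

    # The CLI layer sets the limit exactly once, before any Project exists:
    parallel.GlobalLimit.set_global_limit(2)

    # Construction is now side-effect free, so the cap above survives.
    # (name, config_data, and client come from the surrounding CLI code.)
    # project = Project.from_config(name, config_data, client)
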
tests/unit/parallel_test.py (+108 -108)

@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 from __future__ import unicode_literals
 
+import unittest
 from threading import Lock
 
 import six
@@ -32,114 +33,113 @@ def get_deps(obj):
     return [(dep, None) for dep in deps[obj]]
 
 
-def test_parallel_execute():
-    results, errors = parallel_execute(
-        objects=[1, 2, 3, 4, 5],
-        func=lambda x: x * 2,
-        get_name=six.text_type,
-        msg="Doubling",
-    )
-
-    assert sorted(results) == [2, 4, 6, 8, 10]
-    assert errors == {}
-
-
-def test_parallel_execute_with_limit():
-    limit = 1
-    tasks = 20
-    lock = Lock()
-
-    def f(obj):
-        locked = lock.acquire(False)
-        # we should always get the lock because we're the only thread running
-        assert locked
-        lock.release()
-        return None
-
-    results, errors = parallel_execute(
-        objects=list(range(tasks)),
-        func=f,
-        get_name=six.text_type,
-        msg="Testing",
-        limit=limit,
-    )
-
-    assert results == tasks * [None]
-    assert errors == {}
-
-
-def test_parallel_execute_with_global_limit():
-    GlobalLimit.set_global_limit(1)
-    tasks = 20
-    lock = Lock()
-
-    def f(obj):
-        locked = lock.acquire(False)
-        # we should always get the lock because we're the only thread running
-        assert locked
-        lock.release()
-        return None
-
-    results, errors = parallel_execute(
-        objects=list(range(tasks)),
-        func=f,
-        get_name=six.text_type,
-        msg="Testing",
-    )
-
-    assert results == tasks * [None]
-    assert errors == {}
-
-
-def test_parallel_execute_with_deps():
-    log = []
-
-    def process(x):
-        log.append(x)
-
-    parallel_execute(
-        objects=objects,
-        func=process,
-        get_name=lambda obj: obj,
-        msg="Processing",
-        get_deps=get_deps,
-    )
-
-    assert sorted(log) == sorted(objects)
-
-    assert log.index(data_volume) < log.index(db)
-    assert log.index(db) < log.index(web)
-    assert log.index(cache) < log.index(web)
-
-
-def test_parallel_execute_with_upstream_errors():
-    log = []
-
-    def process(x):
-        if x is data_volume:
-            raise APIError(None, None, "Something went wrong")
-        log.append(x)
-
-    parallel_execute(
-        objects=objects,
-        func=process,
-        get_name=lambda obj: obj,
-        msg="Processing",
-        get_deps=get_deps,
-    )
-
-    assert log == [cache]
-
-    events = [
-        (obj, result, type(exception))
-        for obj, result, exception
-        in parallel_execute_iter(objects, process, get_deps, None)
-    ]
-
-    assert (cache, None, type(None)) in events
-    assert (data_volume, None, APIError) in events
-    assert (db, None, UpstreamError) in events
-    assert (web, None, UpstreamError) in events
+class ParallelTest(unittest.TestCase):
+
+    def test_parallel_execute(self):
+        results, errors = parallel_execute(
+            objects=[1, 2, 3, 4, 5],
+            func=lambda x: x * 2,
+            get_name=six.text_type,
+            msg="Doubling",
+        )
+
+        assert sorted(results) == [2, 4, 6, 8, 10]
+        assert errors == {}
+
+    def test_parallel_execute_with_limit(self):
+        limit = 1
+        tasks = 20
+        lock = Lock()
+
+        def f(obj):
+            locked = lock.acquire(False)
+            # we should always get the lock because we're the only thread running
+            assert locked
+            lock.release()
+            return None
+
+        results, errors = parallel_execute(
+            objects=list(range(tasks)),
+            func=f,
+            get_name=six.text_type,
+            msg="Testing",
+            limit=limit,
+        )
+
+        assert results == tasks * [None]
+        assert errors == {}
+
+    def test_parallel_execute_with_global_limit(self):
+        GlobalLimit.set_global_limit(1)
+        self.addCleanup(GlobalLimit.set_global_limit, None)
+        tasks = 20
+        lock = Lock()
+
+        def f(obj):
+            locked = lock.acquire(False)
+            # we should always get the lock because we're the only thread running
+            assert locked
+            lock.release()
+            return None
+
+        results, errors = parallel_execute(
+            objects=list(range(tasks)),
+            func=f,
+            get_name=six.text_type,
+            msg="Testing",
+        )
+
+        assert results == tasks * [None]
+        assert errors == {}
+
+    def test_parallel_execute_with_deps(self):
+        log = []
+
+        def process(x):
+            log.append(x)
+
+        parallel_execute(
+            objects=objects,
+            func=process,
+            get_name=lambda obj: obj,
+            msg="Processing",
+            get_deps=get_deps,
+        )
+
+        assert sorted(log) == sorted(objects)
+
+        assert log.index(data_volume) < log.index(db)
+        assert log.index(db) < log.index(web)
+        assert log.index(cache) < log.index(web)
+
+    def test_parallel_execute_with_upstream_errors(self):
+        log = []
+
+        def process(x):
+            if x is data_volume:
+                raise APIError(None, None, "Something went wrong")
+            log.append(x)
+
+        parallel_execute(
+            objects=objects,
+            func=process,
+            get_name=lambda obj: obj,
+            msg="Processing",
+            get_deps=get_deps,
+        )
+
+        assert log == [cache]
+
+        events = [
+            (obj, result, type(exception))
+            for obj, result, exception
+            in parallel_execute_iter(objects, process, get_deps, None)
+        ]
+
+        assert (cache, None, type(None)) in events
+        assert (data_volume, None, APIError) in events
+        assert (db, None, UpstreamError) in events
+        assert (web, None, UpstreamError) in events
 
 
 def test_parallel_execute_alignment(capsys):
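
Moving the tests into a unittest.TestCase is what makes the addCleanup call in test_parallel_execute_with_global_limit possible: the semaphore is process-global state, and a test that lowers it must restore the default or every later test inherits the cap of 1. A minimal sketch of the pattern, reusing GlobalLimit from the diff and assuming compose is on the path:

    import unittest

    from compose.parallel import GlobalLimit


    class GlobalStateTest(unittest.TestCase):
        def test_with_temporary_limit(self):
            GlobalLimit.set_global_limit(1)
            # Runs after the test finishes, pass or fail, so the default
            # limit is restored for whatever test runs next.
            self.addCleanup(GlobalLimit.set_global_limit, None)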