ソースを参照

Allow parallel limit to be set in env file.

Signed-off-by: Ashlie Martinez <[email protected]>
Ashlie Martinez 7 年 前
コミット
acf76c15a2
4 ファイル変更24 行追加21 行削除
  1. 6 1
      compose/cli/command.py
  2. 10 10
      compose/parallel.py
  3. 6 3
      compose/project.py
  4. 2 7
      tests/unit/parallel_test.py

+ 6 - 1
compose/cli/command.py

@@ -99,8 +99,13 @@ def get_project(project_dir, config_path=None, project_name=None, verbose=False,
         host=host, environment=environment
     )
 
+    global_parallel_limit = environment.get('COMPOSE_PARALLEL_LIMIT')
+    if global_parallel_limit:
+        global_parallel_limit = int(global_parallel_limit)
+
     with errors.handle_connection_errors(client):
-        return Project.from_config(project_name, config_data, client)
+        return Project.from_config(project_name, config_data, client,
+                                   global_parallel_limit=global_parallel_limit)
 
 
 def get_project_name(working_dir, project_name=None, environment=None):

+ 10 - 10
compose/parallel.py

@@ -15,7 +15,6 @@ from six.moves.queue import Queue
 from compose.cli.colors import green
 from compose.cli.colors import red
 from compose.cli.signals import ShutdownException
-from compose.config.environment import Environment
 from compose.const import PARALLEL_LIMIT
 from compose.errors import HealthCheckFailed
 from compose.errors import NoHealthCheckConfigured
@@ -28,16 +27,17 @@ log = logging.getLogger(__name__)
 STOP = object()
 
 
-def get_configured_limit():
-    limit = Environment.from_command_line({'COMPOSE_PARALLEL_LIMIT': None})['COMPOSE_PARALLEL_LIMIT']
-    if limit:
-        limit = int(limit)
-    else:
-        limit = PARALLEL_LIMIT
-    return limit
+class GlobalLimit(object):
+    """Simple class to hold a global semaphore limiter for a project. This class
+    should be treated as a singleton that is instantiated when the project is.
+    """
 
+    global_limiter = Semaphore(PARALLEL_LIMIT)
 
-global_limiter = Semaphore(get_configured_limit())
+    @classmethod
+    def set_global_limit(cls, value=None):
+        if value is not None:
+            cls.global_limiter = Semaphore(value)
 
 
 def parallel_execute(objects, func, get_name, msg, get_deps=None, limit=None, parent_objects=None):
@@ -187,7 +187,7 @@ def producer(obj, func, results, limiter):
     The entry point for a producer thread which runs func on a single object.
     Places a tuple on the results queue once func has either returned or raised.
     """
-    with limiter, global_limiter:
+    with limiter, GlobalLimit.global_limiter:
         try:
             result = func(obj)
             results.put((obj, result, None))

+ 6 - 3
compose/project.py

@@ -61,13 +61,15 @@ class Project(object):
     """
     A collection of services.
     """
-    def __init__(self, name, services, client, networks=None, volumes=None, config_version=None):
+    def __init__(self, name, services, client, networks=None, volumes=None, config_version=None,
+                 parallel_limit=None):
         self.name = name
         self.services = services
         self.client = client
         self.volumes = volumes or ProjectVolumes({})
         self.networks = networks or ProjectNetworks({}, False)
         self.config_version = config_version
+        parallel.GlobalLimit.set_global_limit(value=parallel_limit)
 
     def labels(self, one_off=OneOffFilter.exclude):
         labels = ['{0}={1}'.format(LABEL_PROJECT, self.name)]
@@ -76,7 +78,7 @@ class Project(object):
         return labels
 
     @classmethod
-    def from_config(cls, name, config_data, client):
+    def from_config(cls, name, config_data, client, global_parallel_limit=None):
         """
         Construct a Project from a config.Config object.
         """
@@ -87,7 +89,8 @@ class Project(object):
             networks,
             use_networking)
         volumes = ProjectVolumes.from_config(name, config_data, client)
-        project = cls(name, [], client, project_networks, volumes, config_data.version)
+        project = cls(name, [], client, project_networks, volumes, config_data.version,
+                      parallel_limit=global_parallel_limit)
 
         for service_dict in config_data.services:
             service_dict = dict(service_dict)

+ 2 - 7
tests/unit/parallel_test.py

@@ -1,14 +1,12 @@
 from __future__ import absolute_import
 from __future__ import unicode_literals
 
-import os
 from threading import Lock
 
 import six
 from docker.errors import APIError
 
-from .. import mock
-from compose.parallel import get_configured_limit
+from compose.parallel import GlobalLimit
 from compose.parallel import parallel_execute
 from compose.parallel import parallel_execute_iter
 from compose.parallel import ParallelStreamWriter
@@ -70,14 +68,11 @@ def test_parallel_execute_with_limit():
     assert errors == {}
 
 
[email protected](os.environ)
 def test_parallel_execute_with_global_limit():
-    os.environ['COMPOSE_PARALLEL_LIMIT'] = '1'
+    GlobalLimit.set_global_limit(1)
     tasks = 20
     lock = Lock()
 
-    assert get_configured_limit() == 1
-
     def f(obj):
         locked = lock.acquire(False)
         # we should always get the lock because we're the only thread running