Browse Source

Merge pull request #2436 from dnephin/reduce_cyclomatic_complexity

Reduce cyclomatic complexity
Aanand Prasad 10 years ago
parent
commit
665de9a494
7 changed files with 211 additions and 204 deletions
  1. 34 39
      compose/config/config.py
  2. 135 0
      compose/parallel.py
  3. 7 27
      compose/project.py
  4. 17 32
      compose/service.py
  5. 0 95
      compose/utils.py
  6. 17 10
      tests/integration/service_test.py
  7. 1 1
      tox.ini

+ 34 - 39
compose/config/config.py

@@ -1,5 +1,6 @@
 import codecs
 import logging
+import operator
 import os
 import sys
 from collections import namedtuple
@@ -387,54 +388,44 @@ def merge_service_dicts_from_files(base, override):
 
 
 def merge_service_dicts(base, override):
-    d = base.copy()
+    d = {}
 
-    if 'environment' in base or 'environment' in override:
-        d['environment'] = merge_environment(
-            base.get('environment'),
-            override.get('environment'),
-        )
-
-    path_mapping_keys = ['volumes', 'devices']
-
-    for key in path_mapping_keys:
-        if key in base or key in override:
-            d[key] = merge_path_mappings(
-                base.get(key),
-                override.get(key),
-            )
-
-    if 'labels' in base or 'labels' in override:
-        d['labels'] = merge_labels(
-            base.get('labels'),
-            override.get('labels'),
-        )
+    def merge_field(field, merge_func, default=None):
+        if field in base or field in override:
+            d[field] = merge_func(
+                base.get(field, default),
+                override.get(field, default))
 
-    if 'image' in override and 'build' in d:
-        del d['build']
+    merge_field('environment', merge_environment)
+    merge_field('labels', merge_labels)
+    merge_image_or_build(base, override, d)
 
-    if 'build' in override and 'image' in d:
-        del d['image']
+    for field in ['volumes', 'devices']:
+        merge_field(field, merge_path_mappings)
 
-    list_keys = ['ports', 'expose', 'external_links']
+    for field in ['ports', 'expose', 'external_links']:
+        merge_field(field, operator.add, default=[])
 
-    for key in list_keys:
-        if key in base or key in override:
-            d[key] = base.get(key, []) + override.get(key, [])
+    for field in ['dns', 'dns_search']:
+        merge_field(field, merge_list_or_string)
 
-    list_or_string_keys = ['dns', 'dns_search']
+    already_merged_keys = set(d) | {'image', 'build'}
+    for field in set(ALLOWED_KEYS) - already_merged_keys:
+        if field in base or field in override:
+            d[field] = override.get(field, base.get(field))
 
-    for key in list_or_string_keys:
-        if key in base or key in override:
-            d[key] = to_list(base.get(key)) + to_list(override.get(key))
-
-    already_merged_keys = ['environment', 'labels'] + path_mapping_keys + list_keys + list_or_string_keys
+    return d
 
-    for k in set(ALLOWED_KEYS) - set(already_merged_keys):
-        if k in override:
-            d[k] = override[k]
 
-    return d
+def merge_image_or_build(base, override, output):
+    if 'image' in override:
+        output['image'] = override['image']
+    elif 'build' in override:
+        output['build'] = override['build']
+    elif 'image' in base:
+        output['image'] = base['image']
+    elif 'build' in base:
+        output['build'] = base['build']
 
 
 def merge_environment(base, override):
@@ -602,6 +593,10 @@ def expand_path(working_dir, path):
     return os.path.abspath(os.path.join(working_dir, os.path.expanduser(path)))
 
 
+def merge_list_or_string(base, override):
+    return to_list(base) + to_list(override)
+
+
 def to_list(value):
     if value is None:
         return []

+ 135 - 0
compose/parallel.py

@@ -0,0 +1,135 @@
+from __future__ import absolute_import
+from __future__ import unicode_literals
+
+import operator
+import sys
+from threading import Thread
+
+from docker.errors import APIError
+from six.moves.queue import Empty
+from six.moves.queue import Queue
+
+from compose.utils import get_output_stream
+
+
+def perform_operation(func, arg, callback, index):
+    try:
+        callback((index, func(arg)))
+    except Exception as e:
+        callback((index, e))
+
+
+def parallel_execute(objects, func, index_func, msg):
+    """For a given list of objects, call the callable passing in the first
+    object we give it.
+    """
+    objects = list(objects)
+    stream = get_output_stream(sys.stdout)
+    writer = ParallelStreamWriter(stream, msg)
+
+    for obj in objects:
+        writer.initialize(index_func(obj))
+
+    q = Queue()
+
+    # TODO: limit the number of threads #1828
+    for obj in objects:
+        t = Thread(
+            target=perform_operation,
+            args=(func, obj, q.put, index_func(obj)))
+        t.daemon = True
+        t.start()
+
+    done = 0
+    errors = {}
+
+    while done < len(objects):
+        try:
+            msg_index, result = q.get(timeout=1)
+        except Empty:
+            continue
+
+        if isinstance(result, APIError):
+            errors[msg_index] = "error", result.explanation
+            writer.write(msg_index, 'error')
+        elif isinstance(result, Exception):
+            errors[msg_index] = "unexpected_exception", result
+        else:
+            writer.write(msg_index, 'done')
+        done += 1
+
+    if not errors:
+        return
+
+    stream.write("\n")
+    for msg_index, (result, error) in errors.items():
+        stream.write("ERROR: for {}  {} \n".format(msg_index, error))
+        if result == 'unexpected_exception':
+            raise error
+
+
+class ParallelStreamWriter(object):
+    """Write out messages for operations happening in parallel.
+
+    Each operation has it's own line, and ANSI code characters are used
+    to jump to the correct line, and write over the line.
+    """
+
+    def __init__(self, stream, msg):
+        self.stream = stream
+        self.msg = msg
+        self.lines = []
+
+    def initialize(self, obj_index):
+        self.lines.append(obj_index)
+        self.stream.write("{} {} ... \r\n".format(self.msg, obj_index))
+        self.stream.flush()
+
+    def write(self, obj_index, status):
+        position = self.lines.index(obj_index)
+        diff = len(self.lines) - position
+        # move up
+        self.stream.write("%c[%dA" % (27, diff))
+        # erase
+        self.stream.write("%c[2K\r" % 27)
+        self.stream.write("{} {} ... {}\r".format(self.msg, obj_index, status))
+        # move back down
+        self.stream.write("%c[%dB" % (27, diff))
+        self.stream.flush()
+
+
+def parallel_operation(containers, operation, options, message):
+    parallel_execute(
+        containers,
+        operator.methodcaller(operation, **options),
+        operator.attrgetter('name'),
+        message)
+
+
+def parallel_remove(containers, options):
+    stopped_containers = [c for c in containers if not c.is_running]
+    parallel_operation(stopped_containers, 'remove', options, 'Removing')
+
+
+def parallel_stop(containers, options):
+    parallel_operation(containers, 'stop', options, 'Stopping')
+
+
+def parallel_start(containers, options):
+    parallel_operation(containers, 'start', options, 'Starting')
+
+
+def parallel_pause(containers, options):
+    parallel_operation(containers, 'pause', options, 'Pausing')
+
+
+def parallel_unpause(containers, options):
+    parallel_operation(containers, 'unpause', options, 'Unpausing')
+
+
+def parallel_kill(containers, options):
+    parallel_operation(containers, 'kill', options, 'Killing')
+
+
+def parallel_restart(containers, options):
+    parallel_operation(containers, 'restart', options, 'Restarting')

+ 7 - 27
compose/project.py

@@ -7,6 +7,7 @@ from functools import reduce
 from docker.errors import APIError
 from docker.errors import NotFound
 
+from . import parallel
 from .config import ConfigurationError
 from .config import get_service_name_from_net
 from .const import DEFAULT_TIMEOUT
@@ -22,7 +23,6 @@ from .service import parse_volume_from_spec
 from .service import Service
 from .service import ServiceNet
 from .service import VolumeFromSpec
-from .utils import parallel_execute
 
 
 log = logging.getLogger(__name__)
@@ -241,42 +241,22 @@ class Project(object):
             service.start(**options)
 
     def stop(self, service_names=None, **options):
-        parallel_execute(
-            objects=self.containers(service_names),
-            obj_callable=lambda c: c.stop(**options),
-            msg_index=lambda c: c.name,
-            msg="Stopping"
-        )
+        parallel.parallel_stop(self.containers(service_names), options)
 
     def pause(self, service_names=None, **options):
-        for service in reversed(self.get_services(service_names)):
-            service.pause(**options)
+        parallel.parallel_pause(reversed(self.containers(service_names)), options)
 
     def unpause(self, service_names=None, **options):
-        for service in self.get_services(service_names):
-            service.unpause(**options)
+        parallel.parallel_unpause(self.containers(service_names), options)
 
     def kill(self, service_names=None, **options):
-        parallel_execute(
-            objects=self.containers(service_names),
-            obj_callable=lambda c: c.kill(**options),
-            msg_index=lambda c: c.name,
-            msg="Killing"
-        )
+        parallel.parallel_kill(self.containers(service_names), options)
 
     def remove_stopped(self, service_names=None, **options):
-        all_containers = self.containers(service_names, stopped=True)
-        stopped_containers = [c for c in all_containers if not c.is_running]
-        parallel_execute(
-            objects=stopped_containers,
-            obj_callable=lambda c: c.remove(**options),
-            msg_index=lambda c: c.name,
-            msg="Removing"
-        )
+        parallel.parallel_remove(self.containers(service_names, stopped=True), options)
 
     def restart(self, service_names=None, **options):
-        for service in self.get_services(service_names):
-            service.restart(**options)
+        parallel.parallel_restart(self.containers(service_names, stopped=True), options)
 
     def build(self, service_names=None, no_cache=False, pull=False, force_rm=False):
         for service in self.get_services(service_names):

+ 17 - 32
compose/service.py

@@ -29,10 +29,13 @@ from .const import LABEL_SERVICE
 from .const import LABEL_VERSION
 from .container import Container
 from .legacy import check_for_legacy_containers
+from .parallel import parallel_execute
+from .parallel import parallel_remove
+from .parallel import parallel_start
+from .parallel import parallel_stop
 from .progress_stream import stream_output
 from .progress_stream import StreamOutputError
 from .utils import json_hash
-from .utils import parallel_execute
 
 
 log = logging.getLogger(__name__)
@@ -241,12 +244,7 @@ class Service(object):
                 else:
                     containers_to_start = stopped_containers
 
-                parallel_execute(
-                    objects=containers_to_start,
-                    obj_callable=lambda c: c.start(),
-                    msg_index=lambda c: c.name,
-                    msg="Starting"
-                )
+                parallel_start(containers_to_start, {})
 
                 num_running += len(containers_to_start)
 
@@ -259,35 +257,22 @@ class Service(object):
             ]
 
             parallel_execute(
-                objects=container_numbers,
-                obj_callable=lambda n: create_and_start(service=self, number=n),
-                msg_index=lambda n: n,
-                msg="Creating and starting"
+                container_numbers,
+                lambda n: create_and_start(service=self, number=n),
+                lambda n: n,
+                "Creating and starting"
             )
 
         if desired_num < num_running:
             num_to_stop = num_running - desired_num
-            sorted_running_containers = sorted(running_containers, key=attrgetter('number'))
-            containers_to_stop = sorted_running_containers[-num_to_stop:]
-
-            parallel_execute(
-                objects=containers_to_stop,
-                obj_callable=lambda c: c.stop(timeout=timeout),
-                msg_index=lambda c: c.name,
-                msg="Stopping"
-            )
-
-        self.remove_stopped()
-
-    def remove_stopped(self, **options):
-        containers = [c for c in self.containers(stopped=True) if not c.is_running]
-
-        parallel_execute(
-            objects=containers,
-            obj_callable=lambda c: c.remove(**options),
-            msg_index=lambda c: c.name,
-            msg="Removing"
-        )
+            sorted_running_containers = sorted(
+                running_containers,
+                key=attrgetter('number'))
+            parallel_stop(
+                sorted_running_containers[-num_to_stop:],
+                dict(timeout=timeout))
+
+        parallel_remove(self.containers(stopped=True), {})
 
     def create_container(self,
                          one_off=False,

+ 0 - 95
compose/utils.py

@@ -2,84 +2,13 @@ import codecs
 import hashlib
 import json
 import json.decoder
-import logging
-import sys
-from threading import Thread
 
 import six
-from docker.errors import APIError
-from six.moves.queue import Empty
-from six.moves.queue import Queue
 
 
-log = logging.getLogger(__name__)
-
 json_decoder = json.JSONDecoder()
 
 
-def parallel_execute(objects, obj_callable, msg_index, msg):
-    """
-    For a given list of objects, call the callable passing in the first
-    object we give it.
-    """
-    stream = get_output_stream(sys.stdout)
-    lines = []
-
-    for obj in objects:
-        write_out_msg(stream, lines, msg_index(obj), msg)
-
-    q = Queue()
-
-    def inner_execute_function(an_callable, parameter, msg_index):
-        error = None
-        try:
-            result = an_callable(parameter)
-        except APIError as e:
-            error = e.explanation
-            result = "error"
-        except Exception as e:
-            error = e
-            result = 'unexpected_exception'
-
-        q.put((msg_index, result, error))
-
-    for an_object in objects:
-        t = Thread(
-            target=inner_execute_function,
-            args=(obj_callable, an_object, msg_index(an_object)),
-        )
-        t.daemon = True
-        t.start()
-
-    done = 0
-    errors = {}
-    total_to_execute = len(objects)
-
-    while done < total_to_execute:
-        try:
-            msg_index, result, error = q.get(timeout=1)
-
-            if result == 'unexpected_exception':
-                errors[msg_index] = result, error
-            if result == 'error':
-                errors[msg_index] = result, error
-                write_out_msg(stream, lines, msg_index, msg, status='error')
-            else:
-                write_out_msg(stream, lines, msg_index, msg)
-            done += 1
-        except Empty:
-            pass
-
-    if not errors:
-        return
-
-    stream.write("\n")
-    for msg_index, (result, error) in errors.items():
-        stream.write("ERROR: for {}  {} \n".format(msg_index, error))
-        if result == 'unexpected_exception':
-            raise error
-
-
 def get_output_stream(stream):
     if six.PY3:
         return stream
@@ -151,30 +80,6 @@ def json_stream(stream):
     return split_buffer(stream, json_splitter, json_decoder.decode)
 
 
-def write_out_msg(stream, lines, msg_index, msg, status="done"):
-    """
-    Using special ANSI code characters we can write out the msg over the top of
-    a previous status message, if it exists.
-    """
-    obj_index = msg_index
-    if msg_index in lines:
-        position = lines.index(obj_index)
-        diff = len(lines) - position
-        # move up
-        stream.write("%c[%dA" % (27, diff))
-        # erase
-        stream.write("%c[2K\r" % 27)
-        stream.write("{} {} ... {}\r".format(msg, obj_index, status))
-        # move back down
-        stream.write("%c[%dB" % (27, diff))
-    else:
-        diff = 0
-        lines.append(obj_index)
-        stream.write("{} {} ... \r\n".format(msg, obj_index))
-
-    stream.flush()
-
-
 def json_hash(obj):
     dump = json.dumps(obj, sort_keys=True, separators=(',', ':'))
     h = hashlib.sha256()

+ 17 - 10
tests/integration/service_test.py

@@ -36,6 +36,12 @@ def create_and_start_container(service, **override_options):
     return container
 
 
+def remove_stopped(service):
+    containers = [c for c in service.containers(stopped=True) if not c.is_running]
+    for container in containers:
+        container.remove()
+
+
 class ServiceTest(DockerClientTestCase):
     def test_containers(self):
         foo = self.create_service('foo')
@@ -94,14 +100,14 @@ class ServiceTest(DockerClientTestCase):
         create_and_start_container(service)
         self.assertEqual(len(service.containers()), 1)
 
-        service.remove_stopped()
+        remove_stopped(service)
         self.assertEqual(len(service.containers()), 1)
 
         service.kill()
         self.assertEqual(len(service.containers()), 0)
         self.assertEqual(len(service.containers(stopped=True)), 1)
 
-        service.remove_stopped()
+        remove_stopped(service)
         self.assertEqual(len(service.containers(stopped=True)), 0)
 
     def test_create_container_with_one_off(self):
@@ -659,9 +665,8 @@ class ServiceTest(DockerClientTestCase):
         self.assertIn('Creating', captured_output)
         self.assertIn('Starting', captured_output)
 
-    def test_scale_with_api_returns_errors(self):
-        """
-        Test that when scaling if the API returns an error, that error is handled
+    def test_scale_with_api_error(self):
+        """Test that when scaling if the API returns an error, that error is handled
         and the remaining threads continue.
         """
         service = self.create_service('web')
@@ -670,7 +675,10 @@ class ServiceTest(DockerClientTestCase):
 
         with mock.patch(
             'compose.container.Container.create',
-                side_effect=APIError(message="testing", response={}, explanation="Boom")):
+            side_effect=APIError(
+                message="testing",
+                response={},
+                explanation="Boom")):
 
             with mock.patch('sys.stdout', new_callable=StringIO) as mock_stdout:
                 service.scale(3)
@@ -679,9 +687,8 @@ class ServiceTest(DockerClientTestCase):
         self.assertTrue(service.containers()[0].is_running)
         self.assertIn("ERROR: for 2  Boom", mock_stdout.getvalue())
 
-    def test_scale_with_api_returns_unexpected_exception(self):
-        """
-        Test that when scaling if the API returns an error, that is not of type
+    def test_scale_with_unexpected_exception(self):
+        """Test that when scaling if the API returns an error, that is not of type
         APIError, that error is re-raised.
         """
         service = self.create_service('web')
@@ -903,7 +910,7 @@ class ServiceTest(DockerClientTestCase):
             self.assertIn(pair, labels)
 
         service.kill()
-        service.remove_stopped()
+        remove_stopped(service)
 
         labels_list = ["%s=%s" % pair for pair in labels_dict.items()]
 

+ 1 - 1
tox.ini

@@ -44,5 +44,5 @@ directory = coverage-html
 # Allow really long lines for now
 max-line-length = 140
 # Set this high for now
-max-complexity = 20
+max-complexity = 12
 exclude = compose/packages