Explorar o código

Move parsing of volumes_from to the last step of config parsing.

Includes creating a new compose.config.types module for all the domain objects.

Signed-off-by: Daniel Nephin <[email protected]>
Daniel Nephin %!s(int64=10) %!d(string=hai) anos
pai
achega
068edfa313

+ 19 - 0
compose/config/config.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import codecs
 import logging
 import operator
@@ -12,6 +14,7 @@ from .errors import CircularReference
 from .errors import ComposeFileNotFound
 from .errors import ConfigurationError
 from .interpolation import interpolate_environment_variables
+from .types import VolumeFromSpec
 from .validation import validate_against_fields_schema
 from .validation import validate_against_service_schema
 from .validation import validate_extends_file_path
@@ -198,8 +201,12 @@ def load(config_details):
             service_dict)
         resolver = ServiceExtendsResolver(service_config)
         service_dict = process_service(resolver.run())
+
+        # TODO: move to validate_service()
         validate_against_service_schema(service_dict, service_config.name)
         validate_paths(service_dict)
+
+        service_dict = finalize_service(service_config._replace(config=service_dict))
         service_dict['name'] = service_config.name
         return service_dict
 
@@ -353,6 +360,7 @@ def validate_ulimits(ulimit_config):
                     "than 'hard' value".format(ulimit_config))
 
 
+# TODO: rename to normalize_service
 def process_service(service_config):
     working_dir = service_config.working_dir
     service_dict = dict(service_config.config)
@@ -370,12 +378,23 @@ def process_service(service_config):
     if 'labels' in service_dict:
         service_dict['labels'] = parse_labels(service_dict['labels'])
 
+    # TODO: move to a validate_service()
     if 'ulimits' in service_dict:
         validate_ulimits(service_dict['ulimits'])
 
     return service_dict
 
 
+def finalize_service(service_config):
+    service_dict = dict(service_config.config)
+
+    if 'volumes_from' in service_dict:
+        service_dict['volumes_from'] = [
+            VolumeFromSpec.parse(vf) for vf in service_dict['volumes_from']]
+
+    return service_dict
+
+
 def merge_service_dicts_from_files(base, override):
     """When merging services from multiple files we need to merge the `extends`
     field. This is not handled by `merge_service_dicts()` which is used to

+ 28 - 0
compose/config/types.py

@@ -0,0 +1,28 @@
+"""
+Types for objects parsed from the configuration.
+"""
+from __future__ import absolute_import
+from __future__ import unicode_literals
+
+from collections import namedtuple
+
+from compose.config.errors import ConfigurationError
+
+
+class VolumeFromSpec(namedtuple('_VolumeFromSpec', 'source mode')):
+
+    @classmethod
+    def parse(cls, volume_from_config):
+        parts = volume_from_config.split(':')
+        if len(parts) > 2:
+            raise ConfigurationError(
+                "volume_from {} has incorrect format, should be "
+                "service[:mode]".format(volume_from_config))
+
+        if len(parts) == 1:
+            source = parts[0]
+            mode = 'rw'
+        else:
+            source, mode = parts
+
+        return cls(source, mode)

+ 6 - 12
compose/project.py

@@ -19,10 +19,8 @@ from .legacy import check_for_legacy_containers
 from .service import ContainerNet
 from .service import ConvergenceStrategy
 from .service import Net
-from .service import parse_volume_from_spec
 from .service import Service
 from .service import ServiceNet
-from .service import VolumeFromSpec
 
 
 log = logging.getLogger(__name__)
@@ -38,10 +36,7 @@ def sort_service_dicts(services):
         return [link.split(':')[0] for link in links]
 
     def get_service_names_from_volumes_from(volumes_from):
-        return [
-            parse_volume_from_spec(volume_from).source
-            for volume_from in volumes_from
-        ]
+        return [volume_from.source for volume_from in volumes_from]
 
     def get_service_dependents(service_dict, services):
         name = service_dict['name']
@@ -192,16 +187,15 @@ class Project(object):
     def get_volumes_from(self, service_dict):
         volumes_from = []
         if 'volumes_from' in service_dict:
-            for volume_from_config in service_dict.get('volumes_from', []):
-                volume_from_spec = parse_volume_from_spec(volume_from_config)
+            for volume_from_spec in service_dict.get('volumes_from', []):
                 # Get service
                 try:
-                    service_name = self.get_service(volume_from_spec.source)
-                    volume_from_spec = VolumeFromSpec(service_name, volume_from_spec.mode)
+                    service = self.get_service(volume_from_spec.source)
+                    volume_from_spec = volume_from_spec._replace(source=service)
                 except NoSuchService:
                     try:
-                        container_name = Container.from_id(self.client, volume_from_spec.source)
-                        volume_from_spec = VolumeFromSpec(container_name, volume_from_spec.mode)
+                        container = Container.from_id(self.client, volume_from_spec.source)
+                        volume_from_spec = volume_from_spec._replace(source=container)
                     except APIError:
                         raise ConfigurationError(
                             'Service "%s" mounts volumes from "%s", which is '

+ 1 - 18
compose/service.py

@@ -70,6 +70,7 @@ class BuildError(Exception):
         self.reason = reason
 
 
+# TODO: remove
 class ConfigError(ValueError):
     pass
 
@@ -86,9 +87,6 @@ class NoSuchImageError(Exception):
 VolumeSpec = namedtuple('VolumeSpec', 'external internal mode')
 
 
-VolumeFromSpec = namedtuple('VolumeFromSpec', 'source mode')
-
-
 ServiceName = namedtuple('ServiceName', 'project service number')
 
 
@@ -1029,21 +1027,6 @@ def build_volume_from(volume_from_spec):
         return ["{}:{}".format(volume_from_spec.source.id, volume_from_spec.mode)]
 
 
-def parse_volume_from_spec(volume_from_config):
-    parts = volume_from_config.split(':')
-    if len(parts) > 2:
-        raise ConfigError("Volume %s has incorrect format, should be "
-                          "external:internal[:mode]" % volume_from_config)
-
-    if len(parts) == 1:
-        source = parts[0]
-        mode = 'rw'
-    else:
-        source, mode = parts
-
-    return VolumeFromSpec(source, mode)
-
-
 # Labels
 
 

+ 1 - 1
tests/integration/project_test.py

@@ -3,12 +3,12 @@ from __future__ import unicode_literals
 from .testcases import DockerClientTestCase
 from compose.cli.docker_client import docker_client
 from compose.config import config
+from compose.config.types import VolumeFromSpec
 from compose.const import LABEL_PROJECT
 from compose.container import Container
 from compose.project import Project
 from compose.service import ConvergenceStrategy
 from compose.service import Net
-from compose.service import VolumeFromSpec
 
 
 def build_service_dicts(service_config):

+ 1 - 1
tests/integration/service_test.py

@@ -14,6 +14,7 @@ from .. import mock
 from .testcases import DockerClientTestCase
 from .testcases import pull_busybox
 from compose import __version__
+from compose.config.types import VolumeFromSpec
 from compose.const import LABEL_CONFIG_HASH
 from compose.const import LABEL_CONTAINER_NUMBER
 from compose.const import LABEL_ONE_OFF
@@ -27,7 +28,6 @@ from compose.service import ConvergencePlan
 from compose.service import ConvergenceStrategy
 from compose.service import Net
 from compose.service import Service
-from compose.service import VolumeFromSpec
 
 
 def create_and_start_container(service, **override_options):

+ 13 - 10
tests/unit/project_test.py

@@ -4,6 +4,7 @@ import docker
 
 from .. import mock
 from .. import unittest
+from compose.config.types import VolumeFromSpec
 from compose.const import LABEL_SERVICE
 from compose.container import Container
 from compose.project import Project
@@ -43,7 +44,7 @@ class ProjectTest(unittest.TestCase):
             {
                 'name': 'db',
                 'image': 'busybox:latest',
-                'volumes_from': ['volume']
+                'volumes_from': [VolumeFromSpec('volume', 'ro')]
             },
             {
                 'name': 'volume',
@@ -167,7 +168,7 @@ class ProjectTest(unittest.TestCase):
             {
                 'name': 'test',
                 'image': 'busybox:latest',
-                'volumes_from': ['aaa']
+                'volumes_from': [VolumeFromSpec('aaa', 'rw')]
             }
         ], self.mock_client)
         self.assertEqual(project.get_service('test')._get_volumes_from(), [container_id + ":rw"])
@@ -190,17 +191,13 @@ class ProjectTest(unittest.TestCase):
             {
                 'name': 'test',
                 'image': 'busybox:latest',
-                'volumes_from': ['vol']
+                'volumes_from': [VolumeFromSpec('vol', 'rw')]
             }
         ], self.mock_client)
         self.assertEqual(project.get_service('test')._get_volumes_from(), [container_name + ":rw"])
 
-    @mock.patch.object(Service, 'containers')
-    def test_use_volumes_from_service_container(self, mock_return):
+    def test_use_volumes_from_service_container(self):
         container_ids = ['aabbccddee', '12345']
-        mock_return.return_value = [
-            mock.Mock(id=container_id, spec=Container)
-            for container_id in container_ids]
 
         project = Project.from_dicts('test', [
             {
@@ -210,10 +207,16 @@ class ProjectTest(unittest.TestCase):
             {
                 'name': 'test',
                 'image': 'busybox:latest',
-                'volumes_from': ['vol']
+                'volumes_from': [VolumeFromSpec('vol', 'rw')]
             }
         ], None)
-        self.assertEqual(project.get_service('test')._get_volumes_from(), [container_ids[0] + ':rw'])
+        with mock.patch.object(Service, 'containers') as mock_return:
+            mock_return.return_value = [
+                mock.Mock(id=container_id, spec=Container)
+                for container_id in container_ids]
+            self.assertEqual(
+                project.get_service('test')._get_volumes_from(),
+                [container_ids[0] + ':rw'])
 
     def test_net_unset(self):
         project = Project.from_dicts('test', [

+ 1 - 0
tests/unit/service_test.py

@@ -6,6 +6,7 @@ import pytest
 
 from .. import mock
 from .. import unittest
+from compose.config.types import VolumeFromSpec
 from compose.const import IS_WINDOWS_PLATFORM
 from compose.const import LABEL_CONFIG_HASH
 from compose.const import LABEL_ONE_OFF

+ 4 - 3
tests/unit/sort_service_test.py

@@ -1,4 +1,5 @@
 from .. import unittest
+from compose.config.types import VolumeFromSpec
 from compose.project import DependencyError
 from compose.project import sort_service_dicts
 
@@ -73,7 +74,7 @@ class SortServiceTest(unittest.TestCase):
             },
             {
                 'name': 'parent',
-                'volumes_from': ['child']
+                'volumes_from': [VolumeFromSpec('child', 'rw')]
             },
             {
                 'links': ['parent'],
@@ -116,7 +117,7 @@ class SortServiceTest(unittest.TestCase):
             },
             {
                 'name': 'parent',
-                'volumes_from': ['child']
+                'volumes_from': [VolumeFromSpec('child', 'ro')]
             },
             {
                 'name': 'child'
@@ -141,7 +142,7 @@ class SortServiceTest(unittest.TestCase):
             },
             {
                 'name': 'two',
-                'volumes_from': ['one']
+                'volumes_from': [VolumeFromSpec('one', 'rw')]
             },
             {
                 'name': 'one'