Browse Source

Require volumes_from a container to be explicit in V2 config.

Signed-off-by: Daniel Nephin <[email protected]>
Daniel Nephin 10 năm trước cách đây
mục cha
commit
b76dc1e05e

+ 11 - 5
compose/config/config.py

@@ -292,7 +292,7 @@ def load_volumes(config_files):
 
 
 def load_services(working_dir, filename, service_configs, version):
-    def build_service(service_name, service_dict):
+    def build_service(service_name, service_dict, service_names):
         service_config = ServiceConfig.with_abs_paths(
             working_dir,
             filename,
@@ -305,13 +305,17 @@ def load_services(working_dir, filename, service_configs, version):
         validate_against_service_schema(service_dict, service_config.name, version)
         validate_paths(service_dict)
 
-        service_dict = finalize_service(service_config._replace(config=service_dict))
+        service_dict = finalize_service(
+            service_config._replace(config=service_dict),
+            service_names,
+            version)
         service_dict['name'] = service_config.name
         return service_dict
 
     def build_services(service_config):
+        service_names = service_config.keys()
         return sort_service_dicts([
-            build_service(name, service_dict)
+            build_service(name, service_dict, service_names)
             for name, service_dict in service_config.items()
         ])
 
@@ -504,7 +508,7 @@ def process_service(service_config):
     return service_dict
 
 
-def finalize_service(service_config):
+def finalize_service(service_config, service_names, version):
     service_dict = dict(service_config.config)
 
     if 'environment' in service_dict or 'env_file' in service_dict:
@@ -513,7 +517,9 @@ def finalize_service(service_config):
 
     if 'volumes_from' in service_dict:
         service_dict['volumes_from'] = [
-            VolumeFromSpec.parse(vf) for vf in service_dict['volumes_from']]
+            VolumeFromSpec.parse(vf, service_names, version)
+            for vf in service_dict['volumes_from']
+        ]
 
     if 'volumes' in service_dict:
         service_dict['volumes'] = [

+ 41 - 3
compose/config/types.py

@@ -11,10 +11,16 @@ from compose.config.errors import ConfigurationError
 from compose.const import IS_WINDOWS_PLATFORM
 
 
-class VolumeFromSpec(namedtuple('_VolumeFromSpec', 'source mode')):
+class VolumeFromSpec(namedtuple('_VolumeFromSpec', 'source mode type')):
 
+    # TODO: drop service_names arg when v1 is removed
     @classmethod
-    def parse(cls, volume_from_config):
+    def parse(cls, volume_from_config, service_names, version):
+        func = cls.parse_v1 if version == 1 else cls.parse_v2
+        return func(service_names, volume_from_config)
+
+    @classmethod
+    def parse_v1(cls, service_names, volume_from_config):
         parts = volume_from_config.split(':')
         if len(parts) > 2:
             raise ConfigurationError(
@@ -27,7 +33,39 @@ class VolumeFromSpec(namedtuple('_VolumeFromSpec', 'source mode')):
         else:
             source, mode = parts
 
-        return cls(source, mode)
+        type = 'service' if source in service_names else 'container'
+        return cls(source, mode, type)
+
+    @classmethod
+    def parse_v2(cls, service_names, volume_from_config):
+        parts = volume_from_config.split(':')
+        if len(parts) > 3:
+            raise ConfigurationError(
+                "volume_from {} has incorrect format, should be one of "
+                "'<service name>[:<mode>]' or "
+                "'container:<container name>[:<mode>]'".format(volume_from_config))
+
+        if len(parts) == 1:
+            source = parts[0]
+            return cls(source, 'rw', 'service')
+
+        if len(parts) == 2:
+            if parts[0] == 'container':
+                type, source = parts
+                return cls(source, 'rw', type)
+
+            source, mode = parts
+            return cls(source, mode, 'service')
+
+        if len(parts) == 3:
+            type, source, mode = parts
+            if type not in ('service', 'container'):
+                raise ConfigurationError(
+                    "Unknown volumes_from type '{}' in '{}'".format(
+                        type,
+                        volume_from_config))
+
+        return cls(source, mode, type)
 
 
 def parse_restart_spec(restart_config):

+ 29 - 23
compose/project.py

@@ -60,7 +60,7 @@ class Project(object):
 
         for service_dict in config_data.services:
             links = project.get_links(service_dict)
-            volumes_from = project.get_volumes_from(service_dict)
+            volumes_from = get_volumes_from(project, service_dict)
             net = project.get_net(service_dict)
 
             project.services.append(
@@ -162,28 +162,6 @@ class Project(object):
             del service_dict['links']
         return links
 
-    def get_volumes_from(self, service_dict):
-        volumes_from = []
-        if 'volumes_from' in service_dict:
-            for volume_from_spec in service_dict.get('volumes_from', []):
-                # Get service
-                try:
-                    service = self.get_service(volume_from_spec.source)
-                    volume_from_spec = volume_from_spec._replace(source=service)
-                except NoSuchService:
-                    try:
-                        container = Container.from_id(self.client, volume_from_spec.source)
-                        volume_from_spec = volume_from_spec._replace(source=container)
-                    except APIError:
-                        raise ConfigurationError(
-                            'Service "%s" mounts volumes from "%s", which is '
-                            'not the name of a service or container.' % (
-                                service_dict['name'],
-                                volume_from_spec.source))
-                volumes_from.append(volume_from_spec)
-            del service_dict['volumes_from']
-        return volumes_from
-
     def get_net(self, service_dict):
         net = service_dict.pop('net', None)
         if not net:
@@ -465,6 +443,34 @@ def remove_links(service_dicts):
         del s['links']
 
 
+def get_volumes_from(project, service_dict):
+    volumes_from = service_dict.pop('volumes_from', None)
+    if not volumes_from:
+        return []
+
+    def build_volume_from(spec):
+        if spec.type == 'service':
+            try:
+                return spec._replace(source=project.get_service(spec.source))
+            except NoSuchService:
+                pass
+
+        if spec.type == 'container':
+            try:
+                container = Container.from_id(project.client, spec.source)
+                return spec._replace(source=container)
+            except APIError:
+                pass
+
+        raise ConfigurationError(
+            "Service \"{}\" mounts volumes from \"{}\", which is not the name "
+            "of a service or container.".format(
+                service_dict['name'],
+                spec.source))
+
+    return [build_volume_from(vf) for vf in volumes_from]
+
+
 class NoSuchService(Exception):
     def __init__(self, name):
         self.name = name

+ 1 - 1
setup.py

@@ -28,7 +28,7 @@ def find_version(*file_paths):
 
 
 install_requires = [
-    'cached-property >= 1.2.0',
+    'cached-property >= 1.2.0, < 2',
     'docopt >= 0.6.1, < 0.7',
     'PyYAML >= 3.10, < 4',
     'requests >= 2.6.1, < 2.8',

+ 1 - 1
tests/integration/project_test.py

@@ -81,7 +81,7 @@ class ProjectTest(DockerClientTestCase):
         )
         db = project.get_service('db')
         data = project.get_service('data')
-        self.assertEqual(db.volumes_from, [VolumeFromSpec(data, 'rw')])
+        self.assertEqual(db.volumes_from, [VolumeFromSpec(data, 'rw', 'service')])
 
     def test_volumes_from_container(self):
         data_container = Container.create(

+ 2 - 2
tests/integration/service_test.py

@@ -224,8 +224,8 @@ class ServiceTest(DockerClientTestCase):
         host_service = self.create_service(
             'host',
             volumes_from=[
-                VolumeFromSpec(volume_service, 'rw'),
-                VolumeFromSpec(volume_container_2, 'rw')
+                VolumeFromSpec(volume_service, 'rw', 'service'),
+                VolumeFromSpec(volume_container_2, 'rw', 'container')
             ]
         )
         host_container = host_service.create_container()

+ 1 - 1
tests/unit/config/config_test.py

@@ -19,7 +19,7 @@ from compose.const import IS_WINDOWS_PLATFORM
 from tests import mock
 from tests import unittest
 
-DEFAULT_VERSION = 2
+DEFAULT_VERSION = V2 = 2
 V1 = 1
 
 

+ 3 - 3
tests/unit/config/sort_services_test.py

@@ -77,7 +77,7 @@ class SortServiceTest(unittest.TestCase):
             },
             {
                 'name': 'parent',
-                'volumes_from': [VolumeFromSpec('child', 'rw')]
+                'volumes_from': [VolumeFromSpec('child', 'rw', 'service')]
             },
             {
                 'links': ['parent'],
@@ -120,7 +120,7 @@ class SortServiceTest(unittest.TestCase):
             },
             {
                 'name': 'parent',
-                'volumes_from': [VolumeFromSpec('child', 'ro')]
+                'volumes_from': [VolumeFromSpec('child', 'ro', 'service')]
             },
             {
                 'name': 'child'
@@ -145,7 +145,7 @@ class SortServiceTest(unittest.TestCase):
             },
             {
                 'name': 'two',
-                'volumes_from': [VolumeFromSpec('one', 'rw')]
+                'volumes_from': [VolumeFromSpec('one', 'rw', 'service')]
             },
             {
                 'name': 'one'

+ 45 - 0
tests/unit/config/types_test.py

@@ -5,8 +5,11 @@ import pytest
 
 from compose.config.errors import ConfigurationError
 from compose.config.types import parse_extra_hosts
+from compose.config.types import VolumeFromSpec
 from compose.config.types import VolumeSpec
 from compose.const import IS_WINDOWS_PLATFORM
+from tests.unit.config.config_test import V1
+from tests.unit.config.config_test import V2
 
 
 def test_parse_extra_hosts_list():
@@ -67,3 +70,45 @@ class TestVolumeSpec(object):
             "/opt/shiny/config",
             "ro"
         )
+
+
+class TestVolumesFromSpec(object):
+
+    services = ['servicea', 'serviceb']
+
+    def test_parse_v1_from_service(self):
+        volume_from = VolumeFromSpec.parse('servicea', self.services, V1)
+        assert volume_from == VolumeFromSpec('servicea', 'rw', 'service')
+
+    def test_parse_v1_from_container(self):
+        volume_from = VolumeFromSpec.parse('foo:ro', self.services, V1)
+        assert volume_from == VolumeFromSpec('foo', 'ro', 'container')
+
+    def test_parse_v1_invalid(self):
+        with pytest.raises(ConfigurationError):
+            VolumeFromSpec.parse('unknown:format:ro', self.services, V1)
+
+    def test_parse_v2_from_service(self):
+        volume_from = VolumeFromSpec.parse('servicea', self.services, V2)
+        assert volume_from == VolumeFromSpec('servicea', 'rw', 'service')
+
+    def test_parse_v2_from_service_with_mode(self):
+        volume_from = VolumeFromSpec.parse('servicea:ro', self.services, V2)
+        assert volume_from == VolumeFromSpec('servicea', 'ro', 'service')
+
+    def test_parse_v2_from_container(self):
+        volume_from = VolumeFromSpec.parse('container:foo', self.services, V2)
+        assert volume_from == VolumeFromSpec('foo', 'rw', 'container')
+
+    def test_parse_v2_from_container_with_mode(self):
+        volume_from = VolumeFromSpec.parse('container:foo:ro', self.services, V2)
+        assert volume_from == VolumeFromSpec('foo', 'ro', 'container')
+
+    def test_parse_v2_invalid_type(self):
+        with pytest.raises(ConfigurationError) as exc:
+            VolumeFromSpec.parse('bogus:foo:ro', self.services, V2)
+        assert "Unknown volumes_from type 'bogus'" in exc.exconly()
+
+    def test_parse_v2_invalid(self):
+        with pytest.raises(ConfigurationError):
+            VolumeFromSpec.parse('unknown:format:ro', self.services, V2)

+ 9 - 8
tests/unit/project_test.py

@@ -165,10 +165,10 @@ class ProjectTest(unittest.TestCase):
             {
                 'name': 'test',
                 'image': 'busybox:latest',
-                'volumes_from': [VolumeFromSpec('aaa', 'rw')]
+                'volumes_from': [VolumeFromSpec('aaa', 'rw', 'container')]
             }
         ], None), self.mock_client)
-        self.assertEqual(project.get_service('test')._get_volumes_from(), [container_id + ":rw"])
+        assert project.get_service('test')._get_volumes_from() == [container_id + ":rw"]
 
     def test_use_volumes_from_service_no_container(self):
         container_name = 'test_vol_1'
@@ -188,10 +188,10 @@ class ProjectTest(unittest.TestCase):
             {
                 'name': 'test',
                 'image': 'busybox:latest',
-                'volumes_from': [VolumeFromSpec('vol', 'rw')]
+                'volumes_from': [VolumeFromSpec('vol', 'rw', 'service')]
             }
         ], None), self.mock_client)
-        self.assertEqual(project.get_service('test')._get_volumes_from(), [container_name + ":rw"])
+        assert project.get_service('test')._get_volumes_from() == [container_name + ":rw"]
 
     def test_use_volumes_from_service_container(self):
         container_ids = ['aabbccddee', '12345']
@@ -204,16 +204,17 @@ class ProjectTest(unittest.TestCase):
             {
                 'name': 'test',
                 'image': 'busybox:latest',
-                'volumes_from': [VolumeFromSpec('vol', 'rw')]
+                'volumes_from': [VolumeFromSpec('vol', 'rw', 'service')]
             }
         ], None), None)
         with mock.patch.object(Service, 'containers') as mock_return:
             mock_return.return_value = [
                 mock.Mock(id=container_id, spec=Container)
                 for container_id in container_ids]
-            self.assertEqual(
-                project.get_service('test')._get_volumes_from(),
-                [container_ids[0] + ':rw'])
+            assert (
+                project.get_service('test')._get_volumes_from() ==
+                [container_ids[0] + ':rw']
+            )
 
     def test_events(self):
         services = [Service(name='web'), Service(name='db')]

+ 23 - 6
tests/unit/service_test.py

@@ -70,7 +70,11 @@ class ServiceTest(unittest.TestCase):
         service = Service(
             'test',
             image='foo',
-            volumes_from=[VolumeFromSpec(mock.Mock(id=container_id, spec=Container), 'rw')])
+            volumes_from=[
+                VolumeFromSpec(
+                    mock.Mock(id=container_id, spec=Container),
+                    'rw',
+                    'container')])
 
         self.assertEqual(service._get_volumes_from(), [container_id + ':rw'])
 
@@ -79,7 +83,11 @@ class ServiceTest(unittest.TestCase):
         service = Service(
             'test',
             image='foo',
-            volumes_from=[VolumeFromSpec(mock.Mock(id=container_id, spec=Container), 'ro')])
+            volumes_from=[
+                VolumeFromSpec(
+                    mock.Mock(id=container_id, spec=Container),
+                    'ro',
+                    'container')])
 
         self.assertEqual(service._get_volumes_from(), [container_id + ':ro'])
 
@@ -90,7 +98,10 @@ class ServiceTest(unittest.TestCase):
             mock.Mock(id=container_id, spec=Container)
             for container_id in container_ids
         ]
-        service = Service('test', volumes_from=[VolumeFromSpec(from_service, 'rw')], image='foo')
+        service = Service(
+            'test',
+            volumes_from=[VolumeFromSpec(from_service, 'rw', 'service')],
+            image='foo')
 
         self.assertEqual(service._get_volumes_from(), [container_ids[0] + ":rw"])
 
@@ -102,7 +113,10 @@ class ServiceTest(unittest.TestCase):
                 mock.Mock(id=container_id.split(':')[0], spec=Container)
                 for container_id in container_ids
             ]
-            service = Service('test', volumes_from=[VolumeFromSpec(from_service, mode)], image='foo')
+            service = Service(
+                'test',
+                volumes_from=[VolumeFromSpec(from_service, mode, 'service')],
+                image='foo')
 
             self.assertEqual(service._get_volumes_from(), [container_ids[0]])
 
@@ -113,7 +127,10 @@ class ServiceTest(unittest.TestCase):
         from_service.create_container.return_value = mock.Mock(
             id=container_id,
             spec=Container)
-        service = Service('test', image='foo', volumes_from=[VolumeFromSpec(from_service, 'rw')])
+        service = Service(
+            'test',
+            image='foo',
+            volumes_from=[VolumeFromSpec(from_service, 'rw', 'service')])
 
         self.assertEqual(service._get_volumes_from(), [container_id + ':rw'])
         from_service.create_container.assert_called_once_with()
@@ -389,7 +406,7 @@ class ServiceTest(unittest.TestCase):
             client=self.mock_client,
             net=ServiceNet(Service('other')),
             links=[(Service('one'), 'one')],
-            volumes_from=[VolumeFromSpec(Service('two'), 'rw')])
+            volumes_from=[VolumeFromSpec(Service('two'), 'rw', 'service')])
 
         config_dict = service.config_dict()
         expected = {