1
0
Эх сурвалжийг харах

Merge pull request #2786 from dnephin/refactor_project

Fix a few bugs around networking and project initilization
Daniel Nephin 10 жил өмнө
parent
commit
c290e560cb

+ 1 - 1
compose/cli/main.py

@@ -693,7 +693,7 @@ def run_one_off_container(container_options, project, service, options):
                 start_deps=True,
                 strategy=ConvergenceStrategy.never)
 
-    project.initialize_networks()
+    project.initialize()
 
     container = service.create_container(
         quiet=True,

+ 73 - 0
compose/network.py

@@ -104,3 +104,76 @@ def create_ipam_config_from_dict(ipam_dict):
             for config in ipam_dict.get('config', [])
         ],
     )
+
+
+def build_networks(name, config_data, client):
+    network_config = config_data.networks or {}
+    networks = {
+        network_name: Network(
+            client=client, project=name, name=network_name,
+            driver=data.get('driver'),
+            driver_opts=data.get('driver_opts'),
+            ipam=data.get('ipam'),
+            external_name=data.get('external_name'),
+        )
+        for network_name, data in network_config.items()
+    }
+
+    if 'default' not in networks:
+        networks['default'] = Network(client, name, 'default')
+
+    return networks
+
+
+class ProjectNetworks(object):
+
+    def __init__(self, networks, use_networking):
+        self.networks = networks or {}
+        self.use_networking = use_networking
+
+    @classmethod
+    def from_services(cls, services, networks, use_networking):
+        service_networks = {
+            network: networks.get(network)
+            for service in services
+            for network in get_network_names_for_service(service)
+        }
+        unused = set(networks) - set(service_networks) - {'default'}
+        if unused:
+            log.warn(
+                "Some networks were defined but are not used by any service: "
+                "{}".format(", ".join(unused)))
+        return cls(service_networks, use_networking)
+
+    def remove(self):
+        if not self.use_networking:
+            return
+        for network in self.networks.values():
+            network.remove()
+
+    def initialize(self):
+        if not self.use_networking:
+            return
+
+        for network in self.networks.values():
+            network.ensure()
+
+
+def get_network_names_for_service(service_dict):
+    if 'network_mode' in service_dict:
+        return []
+    return service_dict.get('networks', ['default'])
+
+
+def get_networks(service_dict, network_definitions):
+    networks = []
+    for name in get_network_names_for_service(service_dict):
+        network = network_definitions.get(name)
+        if network:
+            networks.append(network.full_name)
+        else:
+            raise ConfigurationError(
+                'Service "{}" uses an undefined network "{}"'
+                .format(service_dict['name'], name))
+
+    return networks

+ 36 - 137
compose/project.py

@@ -6,7 +6,6 @@ import logging
 from functools import reduce
 
 from docker.errors import APIError
-from docker.errors import NotFound
 
 from . import parallel
 from .config import ConfigurationError
@@ -19,14 +18,16 @@ from .const import LABEL_ONE_OFF
 from .const import LABEL_PROJECT
 from .const import LABEL_SERVICE
 from .container import Container
-from .network import Network
+from .network import build_networks
+from .network import get_networks
+from .network import ProjectNetworks
 from .service import ContainerNetworkMode
 from .service import ConvergenceStrategy
 from .service import NetworkMode
 from .service import Service
 from .service import ServiceNetworkMode
 from .utils import microseconds_from_time_nano
-from .volume import Volume
+from .volume import ProjectVolumes
 
 
 log = logging.getLogger(__name__)
@@ -36,15 +37,12 @@ class Project(object):
     """
     A collection of services.
     """
-    def __init__(self, name, services, client, networks=None, volumes=None,
-                 use_networking=False, network_driver=None):
+    def __init__(self, name, services, client, networks=None, volumes=None):
         self.name = name
         self.services = services
         self.client = client
-        self.use_networking = use_networking
-        self.network_driver = network_driver
-        self.networks = networks or []
-        self.volumes = volumes or {}
+        self.volumes = volumes or ProjectVolumes({})
+        self.networks = networks or ProjectNetworks({}, False)
 
     def labels(self, one_off=False):
         return [
@@ -58,68 +56,45 @@ class Project(object):
         Construct a Project from a config.Config object.
         """
         use_networking = (config_data.version and config_data.version != V1)
-        project = cls(name, [], client, use_networking=use_networking)
-
-        network_config = config_data.networks or {}
-        custom_networks = [
-            Network(
-                client=client, project=name, name=network_name,
-                driver=data.get('driver'),
-                driver_opts=data.get('driver_opts'),
-                ipam=data.get('ipam'),
-                external_name=data.get('external_name'),
-            )
-            for network_name, data in network_config.items()
-        ]
-
-        all_networks = custom_networks[:]
-        if 'default' not in network_config:
-            all_networks.append(project.default_network)
-
-        if config_data.volumes:
-            for vol_name, data in config_data.volumes.items():
-                project.volumes[vol_name] = Volume(
-                    client=client, project=name, name=vol_name,
-                    driver=data.get('driver'),
-                    driver_opts=data.get('driver_opts'),
-                    external_name=data.get('external_name')
-                )
+        networks = build_networks(name, config_data, client)
+        project_networks = ProjectNetworks.from_services(
+            config_data.services,
+            networks,
+            use_networking)
+        volumes = ProjectVolumes.from_config(name, config_data, client)
+        project = cls(name, [], client, project_networks, volumes)
 
         for service_dict in config_data.services:
+            service_dict = dict(service_dict)
             if use_networking:
-                networks = get_networks(service_dict, all_networks)
+                service_networks = get_networks(service_dict, networks)
             else:
-                networks = []
+                service_networks = []
 
+            service_dict.pop('networks', None)
             links = project.get_links(service_dict)
-            network_mode = project.get_network_mode(service_dict, networks)
+            network_mode = project.get_network_mode(service_dict, service_networks)
             volumes_from = get_volumes_from(project, service_dict)
 
             if config_data.version != V1:
-                service_volumes = service_dict.get('volumes', [])
-                for volume_spec in service_volumes:
-                    if volume_spec.is_named_volume:
-                        declared_volume = project.volumes[volume_spec.external]
-                        service_volumes[service_volumes.index(volume_spec)] = (
-                            volume_spec._replace(external=declared_volume.full_name)
-                        )
+                service_dict['volumes'] = [
+                    volumes.namespace_spec(volume_spec)
+                    for volume_spec in service_dict.get('volumes', [])
+                ]
 
             project.services.append(
                 Service(
+                    service_dict.pop('name'),
                     client=client,
                     project=name,
                     use_networking=use_networking,
-                    networks=networks,
+                    networks=service_networks,
                     links=links,
                     network_mode=network_mode,
                     volumes_from=volumes_from,
                     **service_dict)
             )
 
-        project.networks += custom_networks
-        if 'default' not in network_config and project.uses_default_network():
-            project.networks.append(project.default_network)
-
         return project
 
     @property
@@ -201,7 +176,7 @@ class Project(object):
     def get_network_mode(self, service_dict, networks):
         network_mode = service_dict.pop('network_mode', None)
         if not network_mode:
-            if self.use_networking:
+            if self.networks.use_networking:
                 return NetworkMode(networks[0]) if networks else NetworkMode('none')
             return NetworkMode(None)
 
@@ -246,49 +221,13 @@ class Project(object):
     def remove_stopped(self, service_names=None, **options):
         parallel.parallel_remove(self.containers(service_names, stopped=True), options)
 
-    def initialize_volumes(self):
-        try:
-            for volume in self.volumes.values():
-                if volume.external:
-                    log.debug(
-                        'Volume {0} declared as external. No new '
-                        'volume will be created.'.format(volume.name)
-                    )
-                    if not volume.exists():
-                        raise ConfigurationError(
-                            'Volume {name} declared as external, but could'
-                            ' not be found. Please create the volume manually'
-                            ' using `{command}{name}` and try again.'.format(
-                                name=volume.full_name,
-                                command='docker volume create --name='
-                            )
-                        )
-                    continue
-                volume.create()
-        except NotFound:
-            raise ConfigurationError(
-                'Volume %s specifies nonexistent driver %s' % (volume.name, volume.driver)
-            )
-        except APIError as e:
-            if 'Choose a different volume name' in str(e):
-                raise ConfigurationError(
-                    'Configuration for volume {0} specifies driver {1}, but '
-                    'a volume with the same name uses a different driver '
-                    '({3}). If you wish to use the new configuration, please '
-                    'remove the existing volume "{2}" first:\n'
-                    '$ docker volume rm {2}'.format(
-                        volume.name, volume.driver, volume.full_name,
-                        volume.inspect()['Driver']
-                    )
-                )
-
     def down(self, remove_image_type, include_volumes):
         self.stop()
         self.remove_stopped(v=include_volumes)
-        self.remove_networks()
+        self.networks.remove()
 
         if include_volumes:
-            self.remove_volumes()
+            self.volumes.remove()
 
         self.remove_images(remove_image_type)
 
@@ -296,33 +235,6 @@ class Project(object):
         for service in self.get_services():
             service.remove_image(remove_image_type)
 
-    def remove_networks(self):
-        if not self.use_networking:
-            return
-        for network in self.networks:
-            network.remove()
-
-    def remove_volumes(self):
-        for volume in self.volumes.values():
-            volume.remove()
-
-    def initialize_networks(self):
-        if not self.use_networking:
-            return
-
-        for network in self.networks:
-            network.ensure()
-
-    def uses_default_network(self):
-        return any(
-            self.default_network.full_name in service.networks
-            for service in self.services
-        )
-
-    @property
-    def default_network(self):
-        return Network(client=self.client, project=self.name, name='default')
-
     def restart(self, service_names=None, **options):
         containers = self.containers(service_names, stopped=True)
         parallel.parallel_restart(containers, options)
@@ -388,13 +300,12 @@ class Project(object):
            timeout=DEFAULT_TIMEOUT,
            detached=False):
 
-        services = self.get_services_without_duplicate(service_names, include_deps=start_deps)
+        self.initialize()
+        services = self.get_services_without_duplicate(
+            service_names,
+            include_deps=start_deps)
 
         plans = self._get_convergence_plans(services, strategy)
-
-        self.initialize_networks()
-        self.initialize_volumes()
-
         return [
             container
             for service in services
@@ -406,6 +317,10 @@ class Project(object):
             )
         ]
 
+    def initialize(self):
+        self.networks.initialize()
+        self.volumes.initialize()
+
     def _get_convergence_plans(self, services, strategy):
         plans = {}
 
@@ -465,22 +380,6 @@ class Project(object):
         return acc + dep_services
 
 
-def get_networks(service_dict, network_definitions):
-    if 'network_mode' in service_dict:
-        return []
-
-    networks = []
-    for name in service_dict.pop('networks', ['default']):
-        matches = [n for n in network_definitions if n.name == name]
-        if matches:
-            networks.append(matches[0].full_name)
-        else:
-            raise ConfigurationError(
-                'Service "{}" uses an undefined network "{}"'
-                .format(service_dict['name'], name))
-    return networks
-
-
 def get_volumes_from(project, service_dict):
     volumes_from = service_dict.pop('volumes_from', None)
     if not volumes_from:

+ 1 - 0
compose/service.py

@@ -472,6 +472,7 @@ class Service(object):
             'image_id': self.image()['Id'],
             'links': self.get_link_names(),
             'net': self.network_mode.id,
+            'networks': self.networks,
             'volumes_from': [
                 (v.source.name, v.mode)
                 for v in self.volumes_from if isinstance(v.source, Service)

+ 70 - 0
compose/volume.py

@@ -3,8 +3,10 @@ from __future__ import unicode_literals
 
 import logging
 
+from docker.errors import APIError
 from docker.errors import NotFound
 
+from .config import ConfigurationError
 
 log = logging.getLogger(__name__)
 
@@ -50,3 +52,71 @@ class Volume(object):
         if self.external_name:
             return self.external_name
         return '{0}_{1}'.format(self.project, self.name)
+
+
+class ProjectVolumes(object):
+
+    def __init__(self, volumes):
+        self.volumes = volumes
+
+    @classmethod
+    def from_config(cls, name, config_data, client):
+        config_volumes = config_data.volumes or {}
+        volumes = {
+            vol_name: Volume(
+                    client=client,
+                    project=name,
+                    name=vol_name,
+                    driver=data.get('driver'),
+                    driver_opts=data.get('driver_opts'),
+                    external_name=data.get('external_name'))
+            for vol_name, data in config_volumes.items()
+        }
+        return cls(volumes)
+
+    def remove(self):
+        for volume in self.volumes.values():
+            volume.remove()
+
+    def initialize(self):
+        try:
+            for volume in self.volumes.values():
+                if volume.external:
+                    log.debug(
+                        'Volume {0} declared as external. No new '
+                        'volume will be created.'.format(volume.name)
+                    )
+                    if not volume.exists():
+                        raise ConfigurationError(
+                            'Volume {name} declared as external, but could'
+                            ' not be found. Please create the volume manually'
+                            ' using `{command}{name}` and try again.'.format(
+                                name=volume.full_name,
+                                command='docker volume create --name='
+                            )
+                        )
+                    continue
+                volume.create()
+        except NotFound:
+            raise ConfigurationError(
+                'Volume %s specifies nonexistent driver %s' % (volume.name, volume.driver)
+            )
+        except APIError as e:
+            if 'Choose a different volume name' in str(e):
+                raise ConfigurationError(
+                    'Configuration for volume {0} specifies driver {1}, but '
+                    'a volume with the same name uses a different driver '
+                    '({3}). If you wish to use the new configuration, please '
+                    'remove the existing volume "{2}" first:\n'
+                    '$ docker volume rm {2}'.format(
+                        volume.name, volume.driver, volume.full_name,
+                        volume.inspect()['Driver']
+                    )
+                )
+
+    def namespace_spec(self, volume_spec):
+        if not volume_spec.is_named_volume:
+            return volume_spec
+
+        volume = self.volumes[volume_spec.external]
+        return volume_spec._replace(external=volume.full_name)

+ 8 - 8
tests/acceptance/cli_test.py

@@ -406,7 +406,8 @@ class CLITestCase(DockerClientTestCase):
 
         services = self.project.get_services()
 
-        networks = self.client.networks(names=[self.project.default_network.full_name])
+        network_name = self.project.networks.networks['default'].full_name
+        networks = self.client.networks(names=[network_name])
         self.assertEqual(len(networks), 1)
         self.assertEqual(networks[0]['Driver'], 'bridge')
         assert 'com.docker.network.bridge.enable_icc' not in networks[0]['Options']
@@ -439,7 +440,9 @@ class CLITestCase(DockerClientTestCase):
 
         self.dispatch(['-f', filename, 'up', '-d'], None)
 
-        networks = self.client.networks(names=[self.project.default_network.full_name])
+        network_name = self.project.networks.networks['default'].full_name
+        networks = self.client.networks(names=[network_name])
+
         assert networks[0]['Options']['com.docker.network.bridge.enable_icc'] == 'false'
 
     @v2_only()
@@ -586,18 +589,15 @@ class CLITestCase(DockerClientTestCase):
             n['Name'] for n in self.client.networks()
             if n['Name'].startswith('{}_'.format(self.project.name))
         ]
-
-        assert sorted(network_names) == [
-            '{}_{}'.format(self.project.name, name)
-            for name in ['bar', 'foo']
-        ]
+        assert network_names == []
 
     def test_up_with_links_v1(self):
         self.base_dir = 'tests/fixtures/links-composefile'
         self.dispatch(['up', '-d', 'web'], None)
 
         # No network was created
-        networks = self.client.networks(names=[self.project.default_network.full_name])
+        network_name = self.project.networks.networks['default'].full_name
+        networks = self.client.networks(names=[network_name])
         assert networks == []
 
         web = self.project.get_service('web')

+ 14 - 8
tests/integration/project_test.py

@@ -565,6 +565,7 @@ class ProjectTest(DockerClientTestCase):
                 'name': 'web',
                 'image': 'busybox:latest',
                 'command': 'top',
+                'networks': ['foo', 'bar', 'baz'],
             }],
             volumes={},
             networks={
@@ -594,7 +595,11 @@ class ProjectTest(DockerClientTestCase):
     def test_up_with_ipam_config(self):
         config_data = config.Config(
             version=V2_0,
-            services=[],
+            services=[{
+                'name': 'web',
+                'image': 'busybox:latest',
+                'networks': ['front'],
+            }],
             volumes={},
             networks={
                 'front': {
@@ -744,7 +749,7 @@ class ProjectTest(DockerClientTestCase):
             name='composetest',
             config_data=config_data, client=self.client
         )
-        project.initialize_volumes()
+        project.volumes.initialize()
 
         volume_data = self.client.inspect_volume(full_vol_name)
         self.assertEqual(volume_data['Name'], full_vol_name)
@@ -795,7 +800,7 @@ class ProjectTest(DockerClientTestCase):
             config_data=config_data, client=self.client
         )
         with self.assertRaises(config.ConfigurationError):
-            project.initialize_volumes()
+            project.volumes.initialize()
 
     @v2_only()
     def test_initialize_volumes_updated_driver(self):
@@ -816,7 +821,7 @@ class ProjectTest(DockerClientTestCase):
             name='composetest',
             config_data=config_data, client=self.client
         )
-        project.initialize_volumes()
+        project.volumes.initialize()
 
         volume_data = self.client.inspect_volume(full_vol_name)
         self.assertEqual(volume_data['Name'], full_vol_name)
@@ -827,10 +832,11 @@ class ProjectTest(DockerClientTestCase):
         )
         project = Project.from_config(
             name='composetest',
-            config_data=config_data, client=self.client
+            config_data=config_data,
+            client=self.client
         )
         with self.assertRaises(config.ConfigurationError) as e:
-            project.initialize_volumes()
+            project.volumes.initialize()
         assert 'Configuration for volume {0} specifies driver smb'.format(
             vol_name
         ) in str(e.exception)
@@ -857,7 +863,7 @@ class ProjectTest(DockerClientTestCase):
             name='composetest',
             config_data=config_data, client=self.client
         )
-        project.initialize_volumes()
+        project.volumes.initialize()
 
         with self.assertRaises(NotFound):
             self.client.inspect_volume(full_vol_name)
@@ -883,7 +889,7 @@ class ProjectTest(DockerClientTestCase):
             config_data=config_data, client=self.client
         )
         with self.assertRaises(config.ConfigurationError) as e:
-            project.initialize_volumes()
+            project.volumes.initialize()
         assert 'Volume {0} declared as external'.format(
             vol_name
         ) in str(e.exception)

+ 4 - 4
tests/unit/project_test.py

@@ -45,7 +45,7 @@ class ProjectTest(unittest.TestCase):
         self.assertEqual(project.get_service('web').options['image'], 'busybox:latest')
         self.assertEqual(project.get_service('db').name, 'db')
         self.assertEqual(project.get_service('db').options['image'], 'busybox:latest')
-        self.assertFalse(project.use_networking)
+        self.assertFalse(project.networks.use_networking)
 
     def test_from_config_v2(self):
         config = Config(
@@ -65,7 +65,7 @@ class ProjectTest(unittest.TestCase):
         )
         project = Project.from_config('composetest', config, None)
         self.assertEqual(len(project.services), 2)
-        self.assertTrue(project.use_networking)
+        self.assertTrue(project.networks.use_networking)
 
     def test_get_service(self):
         web = Service(
@@ -426,7 +426,7 @@ class ProjectTest(unittest.TestCase):
             ),
         )
 
-        assert project.uses_default_network()
+        assert 'default' in project.networks.networks
 
     def test_uses_default_network_false(self):
         project = Project.from_config(
@@ -446,7 +446,7 @@ class ProjectTest(unittest.TestCase):
             ),
         )
 
-        assert not project.uses_default_network()
+        assert 'default' not in project.networks.networks
 
     def test_container_without_name(self):
         self.mock_client.containers.return_value = [

+ 5 - 3
tests/unit/service_test.py

@@ -266,7 +266,7 @@ class ServiceTest(unittest.TestCase):
 
         self.assertEqual(
             opts['labels'][LABEL_CONFIG_HASH],
-            '3c85881a8903b9d73a06c41860c8be08acce1494ab4cf8408375966dccd714de')
+            'f8bfa1058ad1f4231372a0b1639f0dfdb574dafff4e8d7938049ae993f7cf1fc')
         self.assertEqual(
             opts['environment'],
             {
@@ -417,9 +417,10 @@ class ServiceTest(unittest.TestCase):
             'options': {'image': 'example.com/foo'},
             'links': [('one', 'one')],
             'net': 'other',
+            'networks': [],
             'volumes_from': [('two', 'rw')],
         }
-        self.assertEqual(config_dict, expected)
+        assert config_dict == expected
 
     def test_config_dict_with_network_mode_from_container(self):
         self.mock_client.inspect_image.return_value = {'Id': 'abcd'}
@@ -437,10 +438,11 @@ class ServiceTest(unittest.TestCase):
             'image_id': 'abcd',
             'options': {'image': 'example.com/foo'},
             'links': [],
+            'networks': [],
             'net': 'aaabbb',
             'volumes_from': [],
         }
-        self.assertEqual(config_dict, expected)
+        assert config_dict == expected
 
     def test_remove_image_none(self):
         web = Service('web', image='example', client=self.mock_client)