1
0
Эх сурвалжийг харах

Refactor project network initlization.

Signed-off-by: Daniel Nephin <[email protected]>
Daniel Nephin 9 жил өмнө
parent
commit
8e838968fe

+ 64 - 0
compose/network.py

@@ -104,3 +104,67 @@ def create_ipam_config_from_dict(ipam_dict):
             for config in ipam_dict.get('config', [])
         ],
     )
+
+
+def build_networks(name, config_data, client):
+    network_config = config_data.networks or {}
+    networks = {
+        network_name: Network(
+            client=client, project=name, name=network_name,
+            driver=data.get('driver'),
+            driver_opts=data.get('driver_opts'),
+            ipam=data.get('ipam'),
+            external_name=data.get('external_name'),
+        )
+        for network_name, data in network_config.items()
+    }
+
+    if 'default' not in networks:
+        networks['default'] = Network(client, name, 'default')
+
+    return networks
+
+
+class ProjectNetworks(object):
+
+    def __init__(self, networks, use_networking):
+        self.networks = networks or {}
+        self.use_networking = use_networking
+
+    @classmethod
+    def from_services(cls, services, networks, use_networking):
+        networks = {
+            network: networks[network]
+            for service in services
+            for network in service.get('networks', ['default'])
+        }
+        return cls(networks, use_networking)
+
+    def remove(self):
+        if not self.use_networking:
+            return
+        for network in self.networks.values():
+            network.remove()
+
+    def initialize(self):
+        if not self.use_networking:
+            return
+
+        for network in self.networks.values():
+            network.ensure()
+
+
+def get_networks(service_dict, network_definitions):
+    if 'network_mode' in service_dict:
+        return []
+
+    networks = []
+    for name in service_dict.pop('networks', ['default']):
+        network = network_definitions.get(name)
+        if network:
+            networks.append(network.full_name)
+        else:
+            raise ConfigurationError(
+                'Service "{}" uses an undefined network "{}"'
+                .format(service_dict['name'], name))
+    return networks

+ 20 - 73
compose/project.py

@@ -19,7 +19,9 @@ from .const import LABEL_ONE_OFF
 from .const import LABEL_PROJECT
 from .const import LABEL_SERVICE
 from .container import Container
-from .network import Network
+from .network import build_networks
+from .network import get_networks
+from .network import ProjectNetworks
 from .service import ContainerNetworkMode
 from .service import ConvergenceStrategy
 from .service import NetworkMode
@@ -36,15 +38,12 @@ class Project(object):
     """
     A collection of services.
     """
-    def __init__(self, name, services, client, networks=None, volumes=None,
-                 use_networking=False, network_driver=None):
+    def __init__(self, name, services, client, networks=None, volumes=None):
         self.name = name
         self.services = services
         self.client = client
-        self.use_networking = use_networking
-        self.network_driver = network_driver
-        self.networks = networks or []
         self.volumes = volumes or {}
+        self.networks = networks or ProjectNetworks({}, False)
 
     def labels(self, one_off=False):
         return [
@@ -58,23 +57,12 @@ class Project(object):
         Construct a Project from a config.Config object.
         """
         use_networking = (config_data.version and config_data.version != V1)
-        project = cls(name, [], client, use_networking=use_networking)
-
-        network_config = config_data.networks or {}
-        custom_networks = [
-            Network(
-                client=client, project=name, name=network_name,
-                driver=data.get('driver'),
-                driver_opts=data.get('driver_opts'),
-                ipam=data.get('ipam'),
-                external_name=data.get('external_name'),
-            )
-            for network_name, data in network_config.items()
-        ]
-
-        all_networks = custom_networks[:]
-        if 'default' not in network_config:
-            all_networks.append(project.default_network)
+        networks = build_networks(name, config_data, client)
+        project_networks = ProjectNetworks.from_services(
+            config_data.services,
+            networks,
+            use_networking)
+        project = cls(name, [], client, project_networks)
 
         if config_data.volumes:
             for vol_name, data in config_data.volumes.items():
@@ -86,13 +74,15 @@ class Project(object):
                 )
 
         for service_dict in config_data.services:
+            service_dict = dict(service_dict)
             if use_networking:
-                networks = get_networks(service_dict, all_networks)
+                service_networks = get_networks(service_dict, networks)
             else:
-                networks = []
+                service_networks = []
 
+            service_dict.pop('networks', None)
             links = project.get_links(service_dict)
-            network_mode = project.get_network_mode(service_dict, networks)
+            network_mode = project.get_network_mode(service_dict, service_networks)
             volumes_from = get_volumes_from(project, service_dict)
 
             if config_data.version != V1:
@@ -109,17 +99,13 @@ class Project(object):
                     client=client,
                     project=name,
                     use_networking=use_networking,
-                    networks=networks,
+                    networks=service_networks,
                     links=links,
                     network_mode=network_mode,
                     volumes_from=volumes_from,
                     **service_dict)
             )
 
-        project.networks += custom_networks
-        if 'default' not in network_config and project.uses_default_network():
-            project.networks.append(project.default_network)
-
         return project
 
     @property
@@ -201,7 +187,7 @@ class Project(object):
     def get_network_mode(self, service_dict, networks):
         network_mode = service_dict.pop('network_mode', None)
         if not network_mode:
-            if self.use_networking:
+            if self.networks.use_networking:
                 return NetworkMode(networks[0]) if networks else NetworkMode('none')
             return NetworkMode(None)
 
@@ -285,7 +271,7 @@ class Project(object):
     def down(self, remove_image_type, include_volumes):
         self.stop()
         self.remove_stopped(v=include_volumes)
-        self.remove_networks()
+        self.networks.remove()
 
         if include_volumes:
             self.remove_volumes()
@@ -296,33 +282,10 @@ class Project(object):
         for service in self.get_services():
             service.remove_image(remove_image_type)
 
-    def remove_networks(self):
-        if not self.use_networking:
-            return
-        for network in self.networks:
-            network.remove()
-
     def remove_volumes(self):
         for volume in self.volumes.values():
             volume.remove()
 
-    def initialize_networks(self):
-        if not self.use_networking:
-            return
-
-        for network in self.networks:
-            network.ensure()
-
-    def uses_default_network(self):
-        return any(
-            self.default_network.full_name in service.networks
-            for service in self.services
-        )
-
-    @property
-    def default_network(self):
-        return Network(client=self.client, project=self.name, name='default')
-
     def restart(self, service_names=None, **options):
         containers = self.containers(service_names, stopped=True)
         parallel.parallel_restart(containers, options)
@@ -392,7 +355,7 @@ class Project(object):
 
         plans = self._get_convergence_plans(services, strategy)
 
-        self.initialize_networks()
+        self.networks.initialize()
         self.initialize_volumes()
 
         return [
@@ -465,22 +428,6 @@ class Project(object):
         return acc + dep_services
 
 
-def get_networks(service_dict, network_definitions):
-    if 'network_mode' in service_dict:
-        return []
-
-    networks = []
-    for name in service_dict.pop('networks', ['default']):
-        matches = [n for n in network_definitions if n.name == name]
-        if matches:
-            networks.append(matches[0].full_name)
-        else:
-            raise ConfigurationError(
-                'Service "{}" uses an undefined network "{}"'
-                .format(service_dict['name'], name))
-    return networks
-
-
 def get_volumes_from(project, service_dict):
     volumes_from = service_dict.pop('volumes_from', None)
     if not volumes_from:

+ 4 - 4
tests/unit/project_test.py

@@ -45,7 +45,7 @@ class ProjectTest(unittest.TestCase):
         self.assertEqual(project.get_service('web').options['image'], 'busybox:latest')
         self.assertEqual(project.get_service('db').name, 'db')
         self.assertEqual(project.get_service('db').options['image'], 'busybox:latest')
-        self.assertFalse(project.use_networking)
+        self.assertFalse(project.networks.use_networking)
 
     def test_from_config_v2(self):
         config = Config(
@@ -65,7 +65,7 @@ class ProjectTest(unittest.TestCase):
         )
         project = Project.from_config('composetest', config, None)
         self.assertEqual(len(project.services), 2)
-        self.assertTrue(project.use_networking)
+        self.assertTrue(project.networks.use_networking)
 
     def test_get_service(self):
         web = Service(
@@ -426,7 +426,7 @@ class ProjectTest(unittest.TestCase):
             ),
         )
 
-        assert project.uses_default_network()
+        assert 'default' in project.networks.networks
 
     def test_uses_default_network_false(self):
         project = Project.from_config(
@@ -446,7 +446,7 @@ class ProjectTest(unittest.TestCase):
             ),
         )
 
-        assert not project.uses_default_network()
+        assert 'default' not in project.networks.networks
 
     def test_container_without_name(self):
         self.mock_client.containers.return_value = [