浏览代码

Refactor network_mode logic out of Service.

Signed-off-by: Daniel Nephin <[email protected]>
Daniel Nephin 10 年之前
父节点
当前提交
187ad4ce26
共有 4 个文件被更改,包括 116 次插入37 次删除
  1. 7 4
      compose/project.py
  2. 59 29
      compose/service.py
  3. 3 3
      tests/unit/project_test.py
  4. 47 1
      tests/unit/service_test.py

+ 7 - 4
compose/project.py

@@ -14,7 +14,10 @@ from .const import LABEL_PROJECT
 from .const import LABEL_SERVICE
 from .container import Container
 from .legacy import check_for_legacy_containers
+from .service import ContainerNet
+from .service import Net
 from .service import Service
+from .service import ServiceNet
 from .utils import parallel_execute
 
 
@@ -192,18 +195,18 @@ class Project(object):
     def get_net(self, service_dict):
         net = service_dict.pop('net', None)
         if not net:
-            return
+            return Net(None)
 
         net_name = get_service_name_from_net(net)
         if not net_name:
-            return net
+            return Net(net)
 
         try:
-            return self.get_service(net_name)
+            return ServiceNet(self.get_service(net_name))
         except NoSuchService:
             pass
         try:
-            return Container.from_id(self.client, net_name)
+            return ContainerNet(Container.from_id(self.client, net_name))
         except APIError:
             raise ConfigurationError(
                 'Service "%s" is trying to use the network of "%s", '

+ 59 - 29
compose/service.py

@@ -105,7 +105,7 @@ class Service(object):
         self.project = project
         self.links = links or []
         self.volumes_from = volumes_from or []
-        self.net = net or None
+        self.net = net or Net(None)
         self.options = options
 
     def containers(self, stopped=False, one_off=False, filters={}):
@@ -489,12 +489,12 @@ class Service(object):
             'options': self.options,
             'image_id': self.image()['Id'],
             'links': [(service.name, alias) for service, alias in self.links],
-            'net': self.get_net_name() or getattr(self.net, 'id', self.net),
+            'net': self.net.id,
             'volumes_from': self.get_volumes_from_names(),
         }
 
     def get_dependency_names(self):
-        net_name = self.get_net_name()
+        net_name = self.net.service_name
         return (self.get_linked_names() +
                 self.get_volumes_from_names() +
                 ([net_name] if net_name else []))
@@ -505,12 +505,6 @@ class Service(object):
     def get_volumes_from_names(self):
         return [s.name for s in self.volumes_from if isinstance(s, Service)]
 
-    def get_net_name(self):
-        if isinstance(self.net, Service):
-            return self.net.name
-        else:
-            return
-
     def get_container_name(self, number, one_off=False):
         # TODO: Implement issue #652 here
         return build_container_name(self.project, self.name, number, one_off)
@@ -562,25 +556,6 @@ class Service(object):
 
         return volumes_from
 
-    def _get_net(self):
-        if not self.net:
-            return None
-
-        if isinstance(self.net, Service):
-            containers = self.net.containers()
-            if len(containers) > 0:
-                net = 'container:' + containers[0].id
-            else:
-                log.warning("Warning: Service %s is trying to use reuse the network stack "
-                            "of another service that is not running." % (self.net.name))
-                net = None
-        elif isinstance(self.net, Container):
-            net = 'container:' + self.net.id
-        else:
-            net = self.net
-
-        return net
-
     def _get_container_create_options(
             self,
             override_options,
@@ -694,7 +669,7 @@ class Service(object):
             binds=options.get('binds'),
             volumes_from=self._get_volumes_from(),
             privileged=privileged,
-            network_mode=self._get_net(),
+            network_mode=self.net.mode,
             devices=devices,
             dns=dns,
             dns_search=dns_search,
@@ -793,6 +768,61 @@ class Service(object):
         stream_output(output, sys.stdout)
 
 
+class Net(object):
+    """A `standard` network mode (ex: host, bridge)"""
+
+    service_name = None
+
+    def __init__(self, net):
+        self.net = net
+
+    @property
+    def id(self):
+        return self.net
+
+    mode = id
+
+
+class ContainerNet(object):
+    """A network mode that uses a containers network stack."""
+
+    service_name = None
+
+    def __init__(self, container):
+        self.container = container
+
+    @property
+    def id(self):
+        return self.container.id
+
+    @property
+    def mode(self):
+        return 'container:' + self.container.id
+
+
+class ServiceNet(object):
+    """A network mode that uses a service's network stack."""
+
+    def __init__(self, service):
+        self.service = service
+
+    @property
+    def id(self):
+        return self.service.name
+
+    service_name = id
+
+    @property
+    def mode(self):
+        containers = self.service.containers()
+        if containers:
+            return 'container:' + containers[0].id
+
+        log.warn("Warning: Service %s is trying to use reuse the network stack "
+                 "of another service that is not running." % (self.id))
+        return None
+
+
 # Names
 
 

+ 3 - 3
tests/unit/project_test.py

@@ -221,7 +221,7 @@ class ProjectTest(unittest.TestCase):
             }
         ], self.mock_client)
         service = project.get_service('test')
-        self.assertEqual(service._get_net(), None)
+        self.assertEqual(service.net.id, None)
         self.assertNotIn('NetworkMode', service._get_container_host_config({}))
 
     def test_use_net_from_container(self):
@@ -236,7 +236,7 @@ class ProjectTest(unittest.TestCase):
             }
         ], self.mock_client)
         service = project.get_service('test')
-        self.assertEqual(service._get_net(), 'container:' + container_id)
+        self.assertEqual(service.net.mode, 'container:' + container_id)
 
     def test_use_net_from_service(self):
         container_name = 'test_aaa_1'
@@ -261,7 +261,7 @@ class ProjectTest(unittest.TestCase):
         ], self.mock_client)
 
         service = project.get_service('test')
-        self.assertEqual(service._get_net(), 'container:' + container_name)
+        self.assertEqual(service.net.mode, 'container:' + container_name)
 
     def test_container_without_name(self):
         self.mock_client.containers.return_value = [

+ 47 - 1
tests/unit/service_test.py

@@ -13,13 +13,16 @@ from compose.const import LABEL_SERVICE
 from compose.container import Container
 from compose.service import build_volume_binding
 from compose.service import ConfigError
+from compose.service import ContainerNet
 from compose.service import get_container_data_volumes
 from compose.service import merge_volume_bindings
 from compose.service import NeedsBuildError
+from compose.service import Net
 from compose.service import NoSuchImageError
 from compose.service import parse_repository_tag
 from compose.service import parse_volume_spec
 from compose.service import Service
+from compose.service import ServiceNet
 
 
 class ServiceTest(unittest.TestCase):
@@ -337,7 +340,7 @@ class ServiceTest(unittest.TestCase):
             'foo',
             image='example.com/foo',
             client=self.mock_client,
-            net=Service('other'),
+            net=ServiceNet(Service('other')),
             links=[(Service('one'), 'one')],
             volumes_from=[Service('two')])
 
@@ -373,6 +376,49 @@ class ServiceTest(unittest.TestCase):
         self.assertEqual(config_dict, expected)
 
 
+class NetTestCase(unittest.TestCase):
+
+    def test_net(self):
+        net = Net('host')
+        self.assertEqual(net.id, 'host')
+        self.assertEqual(net.mode, 'host')
+        self.assertEqual(net.service_name, None)
+
+    def test_net_container(self):
+        container_id = 'abcd'
+        net = ContainerNet(Container(None, {'Id': container_id}))
+        self.assertEqual(net.id, container_id)
+        self.assertEqual(net.mode, 'container:' + container_id)
+        self.assertEqual(net.service_name, None)
+
+    def test_net_service(self):
+        container_id = 'bbbb'
+        service_name = 'web'
+        mock_client = mock.create_autospec(docker.Client)
+        mock_client.containers.return_value = [
+            {'Id': container_id, 'Name': container_id, 'Image': 'abcd'},
+        ]
+
+        service = Service(name=service_name, client=mock_client)
+        net = ServiceNet(service)
+
+        self.assertEqual(net.id, service_name)
+        self.assertEqual(net.mode, 'container:' + container_id)
+        self.assertEqual(net.service_name, service_name)
+
+    def test_net_service_no_containers(self):
+        service_name = 'web'
+        mock_client = mock.create_autospec(docker.Client)
+        mock_client.containers.return_value = []
+
+        service = Service(name=service_name, client=mock_client)
+        net = ServiceNet(service)
+
+        self.assertEqual(net.id, service_name)
+        self.assertEqual(net.mode, None)
+        self.assertEqual(net.service_name, service_name)
+
+
 def mock_get_image(images):
     if images:
         return images[0]