Browse Source

Make volumes_from and net containers first class dependencies and
assure that starting order is correct. Added supporting unit and
integration tests as well.

Signed-off-by: Gil Clark <[email protected]>

Gil Clark 10 years ago
parent
commit
95f4e2c7c3

+ 3 - 3
compose/cli/main.py

@@ -292,7 +292,7 @@ class TopLevelCommand(Command):
             if len(deps) > 0:
                 project.up(
                     service_names=deps,
-                    start_links=True,
+                    start_deps=True,
                     recreate=False,
                     insecure_registry=insecure_registry,
                     detach=options['-d']
@@ -430,13 +430,13 @@ class TopLevelCommand(Command):
 
         monochrome = options['--no-color']
 
-        start_links = not options['--no-deps']
+        start_deps = not options['--no-deps']
         recreate = not options['--no-recreate']
         service_names = options['SERVICE']
 
         project.up(
             service_names=service_names,
-            start_links=start_links,
+            start_deps=start_deps,
             recreate=recreate,
             insecure_registry=insecure_registry,
             detach=options['-d'],

+ 67 - 21
compose/project.py

@@ -10,6 +10,17 @@ from docker.errors import APIError
 log = logging.getLogger(__name__)
 
 
+def get_service_name_from_net(net_config):
+    if not net_config:
+        return
+
+    if not net_config.startswith('container:'):
+        return
+
+    _, net_name = net_config.split(':', 1)
+    return net_name
+
+
 def sort_service_dicts(services):
     # Topological sort (Cormen/Tarjan algorithm).
     unmarked = services[:]
@@ -19,6 +30,15 @@ def sort_service_dicts(services):
     def get_service_names(links):
         return [link.split(':')[0] for link in links]
 
+    def get_service_dependents(service_dict, services):
+        name = service_dict['name']
+        return [
+            service for service in services
+            if (name in get_service_names(service.get('links', [])) or
+                name in service.get('volumes_from', []) or
+                name == get_service_name_from_net(service.get('net')))
+        ]
+
     def visit(n):
         if n['name'] in temporary_marked:
             if n['name'] in get_service_names(n.get('links', [])):
@@ -29,8 +49,7 @@ def sort_service_dicts(services):
                 raise DependencyError('Circular import between %s' % ' and '.join(temporary_marked))
         if n in unmarked:
             temporary_marked.add(n['name'])
-            dependents = [m for m in services if (n['name'] in get_service_names(m.get('links', []))) or (n['name'] in m.get('volumes_from', []))]
-            for m in dependents:
+            for m in get_service_dependents(n, services):
                 visit(m)
             temporary_marked.remove(n['name'])
             unmarked.remove(n)
@@ -60,8 +79,10 @@ class Project(object):
         for service_dict in sort_service_dicts(service_dicts):
             links = project.get_links(service_dict)
             volumes_from = project.get_volumes_from(service_dict)
+            net = project.get_net(service_dict)
 
-            project.services.append(Service(client=client, project=name, links=links, volumes_from=volumes_from, **service_dict))
+            project.services.append(Service(client=client, project=name, links=links, net=net,
+                                            volumes_from=volumes_from, **service_dict))
         return project
 
     @classmethod
@@ -85,31 +106,31 @@ class Project(object):
 
         raise NoSuchService(name)
 
-    def get_services(self, service_names=None, include_links=False):
+    def get_services(self, service_names=None, include_deps=False):
         """
         Returns a list of this project's services filtered
         by the provided list of names, or all services if service_names is None
         or [].
 
-        If include_links is specified, returns a list including the links for
+        If include_deps is specified, returns a list including the dependencies for
         service_names, in order of dependency.
 
         Preserves the original order of self.services where possible,
-        reordering as needed to resolve links.
+        reordering as needed to resolve dependencies.
 
         Raises NoSuchService if any of the named services do not exist.
         """
         if service_names is None or len(service_names) == 0:
             return self.get_services(
                 service_names=[s.name for s in self.services],
-                include_links=include_links
+                include_deps=include_deps
             )
         else:
             unsorted = [self.get_service(name) for name in service_names]
             services = [s for s in self.services if s in unsorted]
 
-            if include_links:
-                services = reduce(self._inject_links, services, [])
+            if include_deps:
+                services = reduce(self._inject_deps, services, [])
 
             uniques = []
             [uniques.append(s) for s in services if s not in uniques]
@@ -146,6 +167,28 @@ class Project(object):
             del service_dict['volumes_from']
         return volumes_from
 
+    def get_net(self, service_dict):
+        if 'net' in service_dict:
+            net_name = get_service_name_from_net(service_dict.get('net'))
+
+            if net_name:
+                try:
+                    net = self.get_service(net_name)
+                except NoSuchService:
+                    try:
+                        net = Container.from_id(self.client, net_name)
+                    except APIError:
+                        raise ConfigurationError('Serivce "%s" is trying to use the network of "%s", which is not the name of a service or container.' % (service_dict['name'], net_name))
+            else:
+                net = service_dict['net']
+
+            del service_dict['net']
+
+        else:
+            net = 'bridge'
+
+        return net
+
     def start(self, service_names=None, **options):
         for service in self.get_services(service_names):
             service.start(**options)
@@ -171,13 +214,13 @@ class Project(object):
 
     def up(self,
            service_names=None,
-           start_links=True,
+           start_deps=True,
            recreate=True,
            insecure_registry=False,
            detach=False,
            do_build=True):
         running_containers = []
-        for service in self.get_services(service_names, include_links=start_links):
+        for service in self.get_services(service_names, include_deps=start_deps):
             if recreate:
                 for (_, container) in service.recreate_containers(
                         insecure_registry=insecure_registry,
@@ -194,7 +237,7 @@ class Project(object):
         return running_containers
 
     def pull(self, service_names=None, insecure_registry=False):
-        for service in self.get_services(service_names, include_links=True):
+        for service in self.get_services(service_names, include_deps=True):
             service.pull(insecure_registry=insecure_registry)
 
     def remove_stopped(self, service_names=None, **options):
@@ -207,19 +250,22 @@ class Project(object):
                 for service in self.get_services(service_names)
                 if service.has_container(container, one_off=one_off)]
 
-    def _inject_links(self, acc, service):
-        linked_names = service.get_linked_names()
+    def _inject_deps(self, acc, service):
+        net_name = service.get_net_name()
+        dep_names = (service.get_linked_names() +
+                     service.get_volumes_from_names() +
+                     ([net_name] if net_name else []))
 
-        if len(linked_names) > 0:
-            linked_services = self.get_services(
-                service_names=linked_names,
-                include_links=True
+        if len(dep_names) > 0:
+            dep_services = self.get_services(
+                service_names=list(set(dep_names)),
+                include_deps=True
             )
         else:
-            linked_services = []
+            dep_services = []
 
-        linked_services.append(service)
-        return acc + linked_services
+        dep_services.append(service)
+        return acc + dep_services
 
 
 class NoSuchService(Exception):

+ 31 - 4
compose/service.py

@@ -88,7 +88,7 @@ ServiceName = namedtuple('ServiceName', 'project service number')
 
 
 class Service(object):
-    def __init__(self, name, client=None, project='default', links=None, external_links=None, volumes_from=None, **options):
+    def __init__(self, name, client=None, project='default', links=None, external_links=None, volumes_from=None, net=None, **options):
         if not re.match('^%s+$' % VALID_NAME_CHARS, name):
             raise ConfigError('Invalid service name "%s" - only %s are allowed' % (name, VALID_NAME_CHARS))
         if not re.match('^%s+$' % VALID_NAME_CHARS, project):
@@ -116,6 +116,7 @@ class Service(object):
         self.links = links or []
         self.external_links = external_links or []
         self.volumes_from = volumes_from or []
+        self.net = net or None
         self.options = options
 
     def containers(self, stopped=False, one_off=False):
@@ -320,7 +321,6 @@ class Service(object):
             if ':' in volume)
 
         privileged = options.get('privileged', False)
-        net = options.get('net', 'bridge')
         dns = options.get('dns', None)
         dns_search = options.get('dns_search', None)
         cap_add = options.get('cap_add', None)
@@ -334,7 +334,7 @@ class Service(object):
             binds=volume_bindings,
             volumes_from=self._get_volumes_from(intermediate_container),
             privileged=privileged,
-            network_mode=net,
+            network_mode=self._get_net(),
             dns=dns,
             dns_search=dns_search,
             restart_policy=restart,
@@ -364,6 +364,15 @@ class Service(object):
     def get_linked_names(self):
         return [s.name for (s, _) in self.links]
 
+    def get_volumes_from_names(self):
+        return [s.name for s in self.volumes_from if isinstance(s, Service)]
+
+    def get_net_name(self):
+        if isinstance(self.net, Service):
+            return self.net.name
+        else:
+            return
+
     def _next_container_name(self, all_containers, one_off=False):
         bits = [self.project, self.name]
         if one_off:
@@ -399,7 +408,6 @@ class Service(object):
         for volume_source in self.volumes_from:
             if isinstance(volume_source, Service):
                 containers = volume_source.containers(stopped=True)
-
                 if not containers:
                     volumes_from.append(volume_source.create_container().id)
                 else:
@@ -413,6 +421,25 @@ class Service(object):
 
         return volumes_from
 
+    def _get_net(self):
+        if not self.net:
+            return "bridge"
+
+        if isinstance(self.net, Service):
+            containers = self.net.containers()
+            if len(containers) > 0:
+                net = 'container:' + containers[0].id
+            else:
+                log.warning("Warning: Service %s is trying to use reuse the network stack "
+                            "of another service that is not running." % (self.net.name))
+                net = None
+        elif isinstance(self.net, Container):
+            net = 'container:' + self.net.id
+        else:
+            net = self.net
+
+        return net
+
     def _get_container_create_options(self, override_options, one_off=False):
         container_options = dict(
             (k, self.options[k])

+ 133 - 10
tests/integration/project_test.py

@@ -44,6 +44,63 @@ class ProjectTest(DockerClientTestCase):
         db = project.get_service('db')
         self.assertEqual(db.volumes_from, [data_container])
 
+        project.kill()
+        project.remove_stopped()
+
+    def test_net_from_service(self):
+        project = Project.from_config(
+            name='composetest',
+            config={
+                'net': {
+                    'image': 'busybox:latest',
+                    'command': ["/bin/sleep", "300"]
+                },
+                'web': {
+                    'image': 'busybox:latest',
+                    'net': 'container:net',
+                    'command': ["/bin/sleep", "300"]
+                },  
+            },
+            client=self.client,
+        )
+
+        project.up()
+
+        web = project.get_service('web')
+        net = project.get_service('net')
+        self.assertEqual(web._get_net(), 'container:'+net.containers()[0].id)
+
+        project.kill()
+        project.remove_stopped()
+
+    def test_net_from_container(self):
+        net_container = Container.create(
+            self.client,
+            image='busybox:latest',
+            name='composetest_net_container',
+            command='/bin/sleep 300'
+        )
+        net_container.start()
+
+        project = Project.from_config(
+            name='composetest',
+            config={
+                'web': {
+                    'image': 'busybox:latest',
+                    'net': 'container:composetest_net_container'
+                },
+            },
+            client=self.client,
+        )
+
+        project.up()
+
+        web = project.get_service('web')
+        self.assertEqual(web._get_net(), 'container:'+net_container.id)
+
+        project.kill()
+        project.remove_stopped()
+
     def test_start_stop_kill_remove(self):
         web = self.create_service('web')
         db = self.create_service('db')
@@ -199,20 +256,86 @@ class ProjectTest(DockerClientTestCase):
         project.kill()
         project.remove_stopped()
 
-    def test_project_up_with_no_deps(self):
-        console = self.create_service('console')
-        db = self.create_service('db', volumes=['/var/db'])
-        web = self.create_service('web', links=[(db, 'db')])
+    def test_project_up_starts_depends(self):
+        project = Project.from_config(
+            name='composetest',
+            config={
+                'console': {
+                    'image': 'busybox:latest',
+                    'command': ["/bin/sleep", "300"],
+                },
+                'net' : {
+                    'image': 'busybox:latest',
+                    'command': ["/bin/sleep", "300"]
+                },
+                'app': {
+                    'image': 'busybox:latest',
+                    'command': ["/bin/sleep", "300"],
+                    'net': 'container:net'
+                },
+                'web': {
+                    'image': 'busybox:latest',
+                    'command': ["/bin/sleep", "300"],
+                    'net': 'container:net',
+                    'links': ['app']
+                },
+            },
+            client=self.client,
+        )
+        project.start()
+        self.assertEqual(len(project.containers()), 0)
 
-        project = Project('composetest', [web, db, console], self.client)
+        project.up(['web'])
+        self.assertEqual(len(project.containers()), 3)
+        self.assertEqual(len(project.get_service('web').containers()), 1)
+        self.assertEqual(len(project.get_service('app').containers()), 1)
+        self.assertEqual(len(project.get_service('net').containers()), 1)
+        self.assertEqual(len(project.get_service('console').containers()), 0)
+
+        project.kill()
+        project.remove_stopped()
+
+    def test_project_up_with_no_deps(self):
+        project = Project.from_config(
+            name='composetest',
+            config={
+                'console': {
+                    'image': 'busybox:latest',
+                    'command': ["/bin/sleep", "300"],
+                },
+                'net' : {
+                    'image': 'busybox:latest',
+                    'command': ["/bin/sleep", "300"]
+                },
+                'vol': {
+                    'image': 'busybox:latest',
+                    'command': ["/bin/sleep", "300"],
+                    'volumes': ["/tmp"]
+                },
+                'app': {
+                    'image': 'busybox:latest',
+                    'command': ["/bin/sleep", "300"],
+                    'net': 'container:net'
+                },
+                'web': {
+                    'image': 'busybox:latest',
+                    'command': ["/bin/sleep", "300"],
+                    'net': 'container:net',
+                    'links': ['app'],
+                    'volumes_from': ['vol']
+                },
+            },
+            client=self.client,
+        )
         project.start()
         self.assertEqual(len(project.containers()), 0)
 
-        project.up(['web'], start_links=False)
-        self.assertEqual(len(project.containers()), 1)
-        self.assertEqual(len(web.containers()), 1)
-        self.assertEqual(len(db.containers()), 0)
-        self.assertEqual(len(console.containers()), 0)
+        project.up(['web'], start_deps=False)
+        self.assertEqual(len(project.containers(stopped=True)), 2)
+        self.assertEqual(len(project.get_service('web').containers()), 1)
+        self.assertEqual(len(project.get_service('vol').containers(stopped=True)), 1)
+        self.assertEqual(len(project.get_service('net').containers()), 0)
+        self.assertEqual(len(project.get_service('console').containers()), 0)
 
         project.kill()
         project.remove_stopped()

+ 105 - 2
tests/unit/project_test.py

@@ -2,6 +2,10 @@ from __future__ import unicode_literals
 from .. import unittest
 from compose.service import Service
 from compose.project import Project, ConfigurationError
+from compose.container import Container
+
+import mock
+import docker
 
 class ProjectTest(unittest.TestCase):
     def test_from_dict(self):
@@ -120,7 +124,7 @@ class ProjectTest(unittest.TestCase):
         )
         project = Project('test', [web, db, cache, console], None)
         self.assertEqual(
-            project.get_services(['console'], include_links=True),
+            project.get_services(['console'], include_deps=True),
             [db, web, console]
         )
 
@@ -136,6 +140,105 @@ class ProjectTest(unittest.TestCase):
         )
         project = Project('test', [web, db], None)
         self.assertEqual(
-            project.get_services(['web', 'db'], include_links=True),
+            project.get_services(['web', 'db'], include_deps=True),
             [db, web]
         )
+
+    def test_use_volumes_from_container(self):
+        container_id = 'aabbccddee'
+        container_dict = dict(Name='aaa', Id=container_id)
+        mock_client = mock.create_autospec(docker.Client)
+        mock_client.inspect_container.return_value = container_dict
+        project = Project.from_dicts('test', [
+            {
+                'name': 'test',
+                'image': 'busybox:latest',
+                'volumes_from': ['aaa']
+            }
+        ], mock_client)
+        self.assertEqual(project.get_service('test')._get_volumes_from(), [container_id])
+
+    def test_use_volumes_from_service_no_container(self):
+        container_name = 'test_vol_1'
+        mock_client = mock.create_autospec(docker.Client)
+        mock_client.containers.return_value = [
+            {
+                "Name": container_name,
+                "Names": [container_name],
+                "Id": container_name,
+                "Image": 'busybox:latest'
+            }
+        ]
+        project = Project.from_dicts('test', [
+            {
+                'name': 'vol',
+                'image': 'busybox:latest'
+            },
+            {
+                'name': 'test',
+                'image': 'busybox:latest',
+                'volumes_from': ['vol']
+            }
+        ], mock_client)
+        self.assertEqual(project.get_service('test')._get_volumes_from(), [container_name])
+
+    @mock.patch.object(Service, 'containers')
+    def test_use_volumes_from_service_container(self, mock_return):
+        container_ids = ['aabbccddee', '12345']
+        mock_return.return_value = [
+            mock.Mock(id=container_id, spec=Container)
+            for container_id in container_ids]
+
+        project = Project.from_dicts('test', [
+            {
+                'name': 'vol',
+                'image': 'busybox:latest'
+            },
+            {
+                'name': 'test',
+                'image': 'busybox:latest',
+                'volumes_from': ['vol']
+            }
+        ], None)
+        self.assertEqual(project.get_service('test')._get_volumes_from(), container_ids)
+
+    def test_use_net_from_container(self):
+        container_id = 'aabbccddee'
+        container_dict = dict(Name='aaa', Id=container_id)
+        mock_client = mock.create_autospec(docker.Client)
+        mock_client.inspect_container.return_value = container_dict
+        project = Project.from_dicts('test', [
+            {
+                'name': 'test',
+                'image': 'busybox:latest',
+                'net': 'container:aaa'
+            }
+        ], mock_client)
+        service = project.get_service('test')
+        self.assertEqual(service._get_net(), 'container:'+container_id)
+
+    def test_use_net_from_service(self):
+        container_name = 'test_aaa_1'
+        mock_client = mock.create_autospec(docker.Client)
+        mock_client.containers.return_value = [
+            {
+                "Name": container_name,
+                "Names": [container_name],
+                "Id": container_name,
+                "Image": 'busybox:latest'
+            }
+        ]
+        project = Project.from_dicts('test', [
+            {
+                'name': 'aaa',
+                'image': 'busybox:latest'
+            },
+            {
+                'name': 'test',
+                'image': 'busybox:latest',
+                'net': 'container:aaa'
+            }
+        ], mock_client)
+
+        service = project.get_service('test')
+        self.assertEqual(service._get_net(), 'container:'+container_name)

+ 89 - 0
tests/unit/sort_service_test.py

@@ -65,6 +65,95 @@ class SortServiceTest(unittest.TestCase):
         self.assertEqual(sorted_services[1]['name'], 'parent')
         self.assertEqual(sorted_services[2]['name'], 'grandparent')
 
+    def test_sort_service_dicts_4(self):
+        services = [
+            {
+                'name': 'child'
+            },
+            {
+                'name': 'parent',
+                'volumes_from': ['child']
+            },
+            {
+                'links': ['parent'],
+                'name': 'grandparent'
+            },
+        ]
+
+        sorted_services = sort_service_dicts(services)
+        self.assertEqual(len(sorted_services), 3)
+        self.assertEqual(sorted_services[0]['name'], 'child')
+        self.assertEqual(sorted_services[1]['name'], 'parent')
+        self.assertEqual(sorted_services[2]['name'], 'grandparent')
+
+    def test_sort_service_dicts_5(self):
+        services = [
+            {
+                'links': ['parent'],
+                'name': 'grandparent'
+            },
+            {
+                'name': 'parent',
+                'net': 'container:child'
+            },
+            {
+                'name': 'child'
+            }
+        ]
+
+        sorted_services = sort_service_dicts(services)
+        self.assertEqual(len(sorted_services), 3)
+        self.assertEqual(sorted_services[0]['name'], 'child')
+        self.assertEqual(sorted_services[1]['name'], 'parent')
+        self.assertEqual(sorted_services[2]['name'], 'grandparent')
+
+    def test_sort_service_dicts_6(self):
+        services = [
+            {
+                'links': ['parent'],
+                'name': 'grandparent'
+            },
+            {
+                'name': 'parent',
+                'volumes_from': ['child']
+            },
+            {
+                'name': 'child'
+            }
+        ]
+
+        sorted_services = sort_service_dicts(services)
+        self.assertEqual(len(sorted_services), 3)
+        self.assertEqual(sorted_services[0]['name'], 'child')
+        self.assertEqual(sorted_services[1]['name'], 'parent')
+        self.assertEqual(sorted_services[2]['name'], 'grandparent')
+
+    def test_sort_service_dicts_7(self):
+        services = [
+            {
+                'net': 'container:three',
+                'name': 'four'
+            },
+            {
+                'links': ['two'],
+                'name': 'three'
+            },
+            {
+                'name': 'two',
+                'volumes_from': ['one']
+            },
+            {
+                'name': 'one'
+            }
+        ]
+
+        sorted_services = sort_service_dicts(services)
+        self.assertEqual(len(sorted_services), 4)
+        self.assertEqual(sorted_services[0]['name'], 'one')
+        self.assertEqual(sorted_services[1]['name'], 'two')
+        self.assertEqual(sorted_services[2]['name'], 'three')
+        self.assertEqual(sorted_services[3]['name'], 'four')
+
     def test_sort_service_dicts_circular_imports(self):
         services = [
             {