1
0
Эх сурвалжийг харах

Connect services to networks with the 'networks' key

Signed-off-by: Aanand Prasad <[email protected]>
Aanand Prasad 9 жил өмнө
parent
commit
3eafdbb01b

+ 2 - 1
compose/cli/main.py

@@ -704,7 +704,7 @@ def run_one_off_container(container_options, project, service, options):
         **container_options)
 
     if options['-d']:
-        container.start()
+        service.start_container(container)
         print(container.name)
         return
 
@@ -716,6 +716,7 @@ def run_one_off_container(container_options, project, service, options):
     try:
         try:
             dockerpty.start(project.client, container.id, interactive=not options['-T'])
+            service.connect_container_to_networks(container)
             exit_code = container.wait()
         except signals.ShutdownException:
             project.client.stop(container.id)

+ 7 - 0
compose/config/service_schema_v2.json

@@ -89,6 +89,13 @@
         "mac_address": {"type": "string"},
         "mem_limit": {"type": ["number", "string"]},
         "memswap_limit": {"type": ["number", "string"]},
+
+        "networks": {
+          "type": "array",
+          "items": {"type": "string"},
+          "uniqueItems": true
+        },
+
         "pid": {"type": ["string", "null"]},
 
         "ports": {

+ 1 - 0
compose/network.py

@@ -11,6 +11,7 @@ from .config import ConfigurationError
 log = logging.getLogger(__name__)
 
 
+# TODO: support external networks
 class Network(object):
     def __init__(self, client, project, name, driver=None, driver_opts=None):
         self.client = client

+ 37 - 11
compose/project.py

@@ -58,7 +58,21 @@ class Project(object):
         use_networking = (config_data.version and config_data.version >= 2)
         project = cls(name, [], client, use_networking=use_networking)
 
+        custom_networks = []
+        if config_data.networks:
+            for network_name, data in config_data.networks.items():
+                custom_networks.append(
+                    Network(
+                        client=client, project=name, name=network_name,
+                        driver=data.get('driver'), driver_opts=data.get('driver_opts')
+                    )
+                )
+
         for service_dict in config_data.services:
+            networks = project.get_networks(
+                service_dict,
+                custom_networks + [project.default_network])
+
             links = project.get_links(service_dict)
             volumes_from = get_volumes_from(project, service_dict)
             net = project.get_net(service_dict)
@@ -68,19 +82,15 @@ class Project(object):
                     client=client,
                     project=name,
                     use_networking=use_networking,
+                    networks=networks,
                     links=links,
                     net=net,
                     volumes_from=volumes_from,
                     **service_dict))
 
-        if config_data.networks:
-            for network_name, data in config_data.networks.items():
-                project.networks.append(
-                    Network(
-                        client=client, project=name, name=network_name,
-                        driver=data.get('driver'), driver_opts=data.get('driver_opts')
-                    )
-                )
+        project.networks += custom_networks
+        if project.uses_default_network():
+            project.networks.append(project.default_network)
 
         if config_data.volumes:
             for vol_name, data in config_data.volumes.items():
@@ -154,6 +164,18 @@ class Project(object):
             service.remove_duplicate_containers()
         return services
 
+    def get_networks(self, service_dict, network_definitions):
+        networks = []
+        for name in service_dict.pop('networks', ['default']):
+            matches = [n for n in network_definitions if n.name == name]
+            if matches:
+                networks.append(matches[0].full_name)
+            else:
+                raise ConfigurationError(
+                    'Service "{}" uses an undefined network "{}"'
+                    .format(service_dict['name'], name))
+        return networks
+
     def get_links(self, service_dict):
         links = []
         if 'links' in service_dict:
@@ -172,10 +194,11 @@ class Project(object):
         return links
 
     def get_net(self, service_dict):
+        if self.use_networking:
+            return Net(None)
+
         net = service_dict.pop('net', None)
         if not net:
-            if self.use_networking:
-                return Net(self.default_network.full_name)
             return Net(None)
 
         net_name = get_service_name_from_net(net)
@@ -282,6 +305,9 @@ class Project(object):
             volume.remove()
 
     def initialize_networks(self):
+        if not self.use_networking:
+            return
+
         networks = self.networks
         if self.uses_default_network():
             networks.append(self.default_network)
@@ -291,7 +317,7 @@ class Project(object):
 
     def uses_default_network(self):
         return any(
-            service.net.mode == self.default_network.full_name
+            self.default_network.full_name in service.networks
             for service in self.services
         )
 

+ 15 - 4
compose/service.py

@@ -116,6 +116,7 @@ class Service(object):
         links=None,
         volumes_from=None,
         net=None,
+        networks=None,
         **options
     ):
         self.name = name
@@ -125,6 +126,7 @@ class Service(object):
         self.links = links or []
         self.volumes_from = volumes_from or []
         self.net = net or Net(None)
+        self.networks = networks or []
         self.options = options
 
     def containers(self, stopped=False, one_off=False, filters={}):
@@ -175,7 +177,7 @@ class Service(object):
 
         def create_and_start(service, number):
             container = service.create_container(number=number, quiet=True)
-            container.start()
+            service.start_container(container)
             return container
 
         running_containers = self.containers(stopped=False)
@@ -348,7 +350,7 @@ class Service(object):
                 container.attach_log_stream()
 
             if start:
-                container.start()
+                self.start_container(container)
 
             return [container]
 
@@ -406,7 +408,7 @@ class Service(object):
         if attach_logs:
             new_container.attach_log_stream()
         if start_new_container:
-            new_container.start()
+            self.start_container(new_container)
         container.remove()
         return new_container
 
@@ -415,9 +417,18 @@ class Service(object):
             log.info("Starting %s" % container.name)
             if attach_logs:
                 container.attach_log_stream()
-            container.start()
+            return self.start_container(container)
+
+    def start_container(self, container):
+        container.start()
+        self.connect_container_to_networks(container)
         return container
 
+    def connect_container_to_networks(self, container):
+        for network in self.networks:
+            log.debug('Connecting "{}" to "{}"'.format(container.name, network))
+            self.client.connect_container_to_network(container.id, network)
+
     def remove_duplicate_containers(self, timeout=DEFAULT_TIMEOUT):
         for c in self.duplicate_containers():
             log.info('Removing %s' % c.name)

+ 54 - 14
tests/acceptance/cli_test.py

@@ -103,8 +103,15 @@ class CLITestCase(DockerClientTestCase):
         if self.base_dir:
             self.project.kill()
             self.project.remove_stopped()
+
             for container in self.project.containers(stopped=True, one_off=True):
                 container.remove(force=True)
+
+            networks = self.client.networks()
+            for n in networks:
+                if n['Name'].startswith('{}_'.format(self.project.name)):
+                    self.client.remove_network(n['Name'])
+
         super(CLITestCase, self).tearDown()
 
     @property
@@ -357,12 +364,11 @@ class CLITestCase(DockerClientTestCase):
         services = self.project.get_services()
 
         networks = self.client.networks(names=[self.project.default_network.full_name])
-        for n in networks:
-            self.addCleanup(self.client.remove_network, n['Id'])
         self.assertEqual(len(networks), 1)
         self.assertEqual(networks[0]['Driver'], 'bridge')
 
         network = self.client.inspect_network(networks[0]['Id'])
+        # print self.project.services[0].containers()[0].get('NetworkSettings')
         self.assertEqual(len(network['Containers']), len(services))
 
         for service in services:
@@ -374,14 +380,52 @@ class CLITestCase(DockerClientTestCase):
         self.base_dir = 'tests/fixtures/networks'
         self.dispatch(['up', '-d'], None)
 
-        networks = self.client.networks(names=[
-            '{}_{}'.format(self.project.name, n)
-            for n in ['foo', 'bar']])
+        back_name = '{}_back'.format(self.project.name)
+        front_name = '{}_front'.format(self.project.name)
+
+        networks = [
+            n for n in self.client.networks()
+            if n['Name'].startswith('{}_'.format(self.project.name))
+        ]
+
+        # Two networks were created: back and front
+        assert sorted(n['Name'] for n in networks) == [back_name, front_name]
 
-        self.assertEqual(len(networks), 2)
+        back_network = [n for n in networks if n['Name'] == back_name][0]
+        front_network = [n for n in networks if n['Name'] == front_name][0]
 
-        for net in networks:
-            self.assertEqual(net['Driver'], 'bridge')
+        web_container = self.project.get_service('web').containers()[0]
+        app_container = self.project.get_service('app').containers()[0]
+        db_container = self.project.get_service('db').containers()[0]
+
+        # db and app joined the back network
+        assert sorted(back_network['Containers']) == sorted([db_container.id, app_container.id])
+
+        # web and app joined the front network
+        assert sorted(front_network['Containers']) == sorted([web_container.id, app_container.id])
+
+    def test_up_missing_network(self):
+        self.base_dir = 'tests/fixtures/networks'
+
+        result = self.dispatch(
+            ['-f', 'missing-network.yml', 'up', '-d'],
+            returncode=1)
+
+        assert 'Service "web" uses an undefined network "foo"' in result.stderr
+
+    def test_up_no_services(self):
+        self.base_dir = 'tests/fixtures/no-services'
+        self.dispatch(['up', '-d'], None)
+
+        network_names = [
+            n['Name'] for n in self.client.networks()
+            if n['Name'].startswith('{}_'.format(self.project.name))
+        ]
+
+        assert sorted(network_names) == [
+            '{}_{}'.format(self.project.name, name)
+            for name in ['bar', 'foo']
+        ]
 
     def test_up_with_links_is_invalid(self):
         self.base_dir = 'tests/fixtures/v2-simple'
@@ -400,9 +444,7 @@ class CLITestCase(DockerClientTestCase):
 
         # No network was created
         networks = self.client.networks(names=[self.project.default_network.full_name])
-        for n in networks:
-            self.addCleanup(self.client.remove_network, n['Id'])
-        self.assertEqual(len(networks), 0)
+        assert networks == []
 
         web = self.project.get_service('web')
         db = self.project.get_service('db')
@@ -731,8 +773,6 @@ class CLITestCase(DockerClientTestCase):
         service = self.project.get_service('simple')
         container, = service.containers(stopped=True, one_off=True)
         networks = self.client.networks(names=[self.project.default_network.full_name])
-        for n in networks:
-            self.addCleanup(self.client.remove_network, n['Id'])
         self.assertEqual(len(networks), 1)
         self.assertEqual(container.human_readable_command, u'true')
 
@@ -890,7 +930,7 @@ class CLITestCase(DockerClientTestCase):
     def test_restart(self):
         service = self.project.get_service('simple')
         container = service.create_container()
-        container.start()
+        service.start_container(container)
         started_at = container.dictionary['State']['StartedAt']
         self.dispatch(['restart', '-t', '1'], None)
         container.inspect()

+ 16 - 4
tests/fixtures/networks/docker-compose.yml

@@ -1,7 +1,19 @@
 version: 2
 
-networks:
-  foo:
-    driver:
+services:
+  web:
+    image: busybox
+    command: top
+    networks: ["front"]
+  app:
+    image: busybox
+    command: top
+    networks: ["front", "back"]
+  db:
+    image: busybox
+    command: top
+    networks: ["back"]
 
-  bar: {}
+networks:
+  front: {}
+  back: {}

+ 10 - 0
tests/fixtures/networks/missing-network.yml

@@ -0,0 +1,10 @@
+version: 2
+
+services:
+  web:
+    image: busybox
+    command: top
+    networks: ["foo"]
+
+networks:
+  bar: {}

+ 5 - 0
tests/fixtures/no-services/docker-compose.yml

@@ -0,0 +1,5 @@
+version: 2
+
+networks:
+  foo: {}
+  bar: {}

+ 2 - 2
tests/integration/resilience_test.py

@@ -17,7 +17,7 @@ class ResilienceTest(DockerClientTestCase):
         self.project = Project('composetest', [self.db], self.client)
 
         container = self.db.create_container()
-        container.start()
+        self.db.start_container(container)
         self.host_path = container.get_mount('/var/db')['Source']
 
     def test_successful_recreate(self):
@@ -35,7 +35,7 @@ class ResilienceTest(DockerClientTestCase):
         self.assertEqual(container.get_mount('/var/db')['Source'], self.host_path)
 
     def test_start_failure(self):
-        with mock.patch('compose.container.Container.start', crash):
+        with mock.patch('compose.service.Service.start_container', crash):
             with self.assertRaises(Crash):
                 self.project.up(strategy=ConvergenceStrategy.always)
 

+ 13 - 20
tests/integration/service_test.py

@@ -32,14 +32,7 @@ from compose.service import Service
 
 def create_and_start_container(service, **override_options):
     container = service.create_container(**override_options)
-    container.start()
-    return container
-
-
-def remove_stopped(service):
-    containers = [c for c in service.containers(stopped=True) if not c.is_running]
-    for container in containers:
-        container.remove()
+    return service.start_container(container)
 
 
 class ServiceTest(DockerClientTestCase):
@@ -88,19 +81,19 @@ class ServiceTest(DockerClientTestCase):
     def test_create_container_with_unspecified_volume(self):
         service = self.create_service('db', volumes=[VolumeSpec.parse('/var/db')])
         container = service.create_container()
-        container.start()
+        service.start_container(container)
         assert container.get_mount('/var/db')
 
     def test_create_container_with_volume_driver(self):
         service = self.create_service('db', volume_driver='foodriver')
         container = service.create_container()
-        container.start()
+        service.start_container(container)
         self.assertEqual('foodriver', container.get('HostConfig.VolumeDriver'))
 
     def test_create_container_with_cpu_shares(self):
         service = self.create_service('db', cpu_shares=73)
         container = service.create_container()
-        container.start()
+        service.start_container(container)
         self.assertEqual(container.get('HostConfig.CpuShares'), 73)
 
     def test_create_container_with_cpu_quota(self):
@@ -113,7 +106,7 @@ class ServiceTest(DockerClientTestCase):
         extra_hosts = ['somehost:162.242.195.82', 'otherhost:50.31.209.229']
         service = self.create_service('db', extra_hosts=extra_hosts)
         container = service.create_container()
-        container.start()
+        service.start_container(container)
         self.assertEqual(set(container.get('HostConfig.ExtraHosts')), set(extra_hosts))
 
     def test_create_container_with_extra_hosts_dicts(self):
@@ -121,33 +114,33 @@ class ServiceTest(DockerClientTestCase):
         extra_hosts_list = ['somehost:162.242.195.82', 'otherhost:50.31.209.229']
         service = self.create_service('db', extra_hosts=extra_hosts)
         container = service.create_container()
-        container.start()
+        service.start_container(container)
         self.assertEqual(set(container.get('HostConfig.ExtraHosts')), set(extra_hosts_list))
 
     def test_create_container_with_cpu_set(self):
         service = self.create_service('db', cpuset='0')
         container = service.create_container()
-        container.start()
+        service.start_container(container)
         self.assertEqual(container.get('HostConfig.CpusetCpus'), '0')
 
     def test_create_container_with_read_only_root_fs(self):
         read_only = True
         service = self.create_service('db', read_only=read_only)
         container = service.create_container()
-        container.start()
+        service.start_container(container)
         self.assertEqual(container.get('HostConfig.ReadonlyRootfs'), read_only, container.get('HostConfig'))
 
     def test_create_container_with_security_opt(self):
         security_opt = ['label:disable']
         service = self.create_service('db', security_opt=security_opt)
         container = service.create_container()
-        container.start()
+        service.start_container(container)
         self.assertEqual(set(container.get('HostConfig.SecurityOpt')), set(security_opt))
 
     def test_create_container_with_mac_address(self):
         service = self.create_service('db', mac_address='02:42:ac:11:65:43')
         container = service.create_container()
-        container.start()
+        service.start_container(container)
         self.assertEqual(container.inspect()['Config']['MacAddress'], '02:42:ac:11:65:43')
 
     def test_create_container_with_specified_volume(self):
@@ -158,7 +151,7 @@ class ServiceTest(DockerClientTestCase):
             'db',
             volumes=[VolumeSpec(host_path, container_path, 'rw')])
         container = service.create_container()
-        container.start()
+        service.start_container(container)
         assert container.get_mount(container_path)
 
         # Match the last component ("host-path"), because boot2docker symlinks /tmp
@@ -229,7 +222,7 @@ class ServiceTest(DockerClientTestCase):
             ]
         )
         host_container = host_service.create_container()
-        host_container.start()
+        host_service.start_container(host_container)
         self.assertIn(volume_container_1.id + ':rw',
                       host_container.get('HostConfig.VolumesFrom'))
         self.assertIn(volume_container_2.id + ':rw',
@@ -248,7 +241,7 @@ class ServiceTest(DockerClientTestCase):
         self.assertEqual(old_container.get('Config.Cmd'), ['-d', '1'])
         self.assertIn('FOO=1', old_container.get('Config.Env'))
         self.assertEqual(old_container.name, 'composetest_db_1')
-        old_container.start()
+        service.start_container(old_container)
         old_container.inspect()  # reload volume data
         volume_path = old_container.get_mount('/etc')['Source']
 

+ 33 - 22
tests/unit/project_test.py

@@ -12,8 +12,6 @@ from compose.config.types import VolumeFromSpec
 from compose.const import LABEL_SERVICE
 from compose.container import Container
 from compose.project import Project
-from compose.service import ContainerNet
-from compose.service import Net
 from compose.service import Service
 
 
@@ -412,29 +410,42 @@ class ProjectTest(unittest.TestCase):
         self.assertEqual(service.net.mode, 'container:' + container_name)
 
     def test_uses_default_network_true(self):
-        web = Service('web', project='test', image="alpine", net=Net('test_default'))
-        db = Service('web', project='test', image="alpine", net=Net('other'))
-        project = Project('test', [web, db], None)
-        assert project.uses_default_network()
+        project = Project.from_config(
+            name='test',
+            client=self.mock_client,
+            config_data=Config(
+                version=2,
+                services=[
+                    {
+                        'name': 'foo',
+                        'image': 'busybox:latest'
+                    },
+                ],
+                networks=None,
+                volumes=None,
+            ),
+        )
 
-    def test_uses_default_network_custom_name(self):
-        web = Service('web', project='test', image="alpine", net=Net('other'))
-        project = Project('test', [web], None)
-        assert not project.uses_default_network()
+        assert project.uses_default_network()
 
-    def test_uses_default_network_host(self):
-        web = Service('web', project='test', image="alpine", net=Net('host'))
-        project = Project('test', [web], None)
-        assert not project.uses_default_network()
+    def test_uses_default_network_false(self):
+        project = Project.from_config(
+            name='test',
+            client=self.mock_client,
+            config_data=Config(
+                version=2,
+                services=[
+                    {
+                        'name': 'foo',
+                        'image': 'busybox:latest',
+                        'networks': ['custom']
+                    },
+                ],
+                networks={'custom': {}},
+                volumes=None,
+            ),
+        )
 
-    def test_uses_default_network_container(self):
-        container = mock.Mock(id='test')
-        web = Service(
-            'web',
-            project='test',
-            image="alpine",
-            net=ContainerNet(container))
-        project = Project('test', [web], None)
         assert not project.uses_default_network()
 
     def test_container_without_name(self):