Browse Source

Merge pull request #4541 from shin-/4502-expanded-port-syntax

Add support for expanded port syntax in 3.1 format
Joffrey F 8 years ago
parent
commit
0167aba2b7

+ 35 - 1
compose/config/config.py

@@ -35,6 +35,7 @@ from .sort_services import sort_service_dicts
 from .types import parse_extra_hosts
 from .types import parse_restart_spec
 from .types import ServiceLink
+from .types import ServicePort
 from .types import VolumeFromSpec
 from .types import VolumeSpec
 from .validation import match_named_volumes
@@ -685,10 +686,25 @@ def process_service(service_config):
             service_dict[field] = to_list(service_dict[field])
 
     service_dict = process_healthcheck(service_dict, service_config.name)
+    service_dict = process_ports(service_dict)
 
     return service_dict
 
 
+def process_ports(service_dict):
+    if 'ports' not in service_dict:
+        return service_dict
+
+    ports = []
+    for port_definition in service_dict['ports']:
+        if isinstance(port_definition, ServicePort):
+            ports.append(port_definition)
+        else:
+            ports.extend(ServicePort.parse(port_definition))
+    service_dict['ports'] = ports
+    return service_dict
+
+
 def process_depends_on(service_dict):
     if 'depends_on' in service_dict and not isinstance(service_dict['depends_on'], dict):
         service_dict['depends_on'] = dict([
@@ -866,7 +882,7 @@ def merge_service_dicts(base, override, version):
         md.merge_field(field, merge_path_mappings)
 
     for field in [
-        'ports', 'cap_add', 'cap_drop', 'expose', 'external_links',
+        'cap_add', 'cap_drop', 'expose', 'external_links',
         'security_opt', 'volumes_from',
     ]:
         md.merge_field(field, merge_unique_items_lists, default=[])
@@ -875,6 +891,7 @@ def merge_service_dicts(base, override, version):
         md.merge_field(field, merge_list_or_string)
 
     md.merge_field('logging', merge_logging, default={})
+    merge_ports(md, base, override)
 
     for field in set(ALLOWED_KEYS) - set(md):
         md.merge_scalar(field)
@@ -893,6 +910,23 @@ def merge_unique_items_lists(base, override):
     return sorted(set().union(base, override))
 
 
+def merge_ports(md, base, override):
+    def parse_sequence_func(seq):
+        acc = []
+        for item in seq:
+            acc.extend(ServicePort.parse(item))
+        return to_mapping(acc, 'merge_field')
+
+    field = 'ports'
+
+    if not md.needs_merge(field):
+        return
+
+    merged = parse_sequence_func(md.base.get(field, []))
+    merged.update(parse_sequence_func(md.override.get(field, [])))
+    md[field] = [item for item in sorted(merged.values())]
+
+
 def merge_build(output, base, override):
     def to_dict(service):
         build_config = service.get('build', {})

+ 15 - 2
compose/config/config_schema_v3.1.json

@@ -168,8 +168,21 @@
         "ports": {
           "type": "array",
           "items": {
-            "type": ["string", "number"],
-            "format": "ports"
+            "oneOf": [
+              {"type": "number", "format": "ports"},
+              {"type": "string", "format": "ports"},
+              {
+                "type": "object",
+                "properties": {
+                  "mode": {"type": "string"},
+                  "target": {"type": "integer"},
+                  "published": {"type": "integer"},
+                  "protocol": {"type": "string"}
+                },
+                "required": ["target"],
+                "additionalProperties": false
+              }
+            ]
           },
           "uniqueItems": true
         },

+ 12 - 2
compose/config/serialize.py

@@ -7,6 +7,7 @@ import yaml
 from compose.config import types
 from compose.config.config import V1
 from compose.config.config import V2_1
+from compose.config.config import V3_1
 
 
 def serialize_config_type(dumper, data):
@@ -14,8 +15,14 @@ def serialize_config_type(dumper, data):
     return representer(data.repr())
 
 
+def serialize_dict_type(dumper, data):
+    return dumper.represent_dict(data.repr())
+
+
 yaml.SafeDumper.add_representer(types.VolumeFromSpec, serialize_config_type)
 yaml.SafeDumper.add_representer(types.VolumeSpec, serialize_config_type)
+yaml.SafeDumper.add_representer(types.ServiceSecret, serialize_dict_type)
+yaml.SafeDumper.add_representer(types.ServicePort, serialize_dict_type)
 
 
 def denormalize_config(config):
@@ -102,7 +109,10 @@ def denormalize_service_dict(service_dict, version):
                 service_dict['healthcheck']['timeout']
             )
 
-    if 'secrets' in service_dict:
-        service_dict['secrets'] = map(lambda s: s.repr(), service_dict['secrets'])
+    if 'ports' in service_dict and version != V3_1:
+        service_dict['ports'] = map(
+            lambda p: p.legacy_repr() if isinstance(p, types.ServicePort) else p,
+            service_dict['ports']
+        )
 
     return service_dict

+ 59 - 0
compose/config/types.py

@@ -9,6 +9,7 @@ import re
 from collections import namedtuple
 
 import six
+from docker.utils.ports import build_port_bindings
 
 from ..const import COMPOSEFILE_V1 as V1
 from .errors import ConfigurationError
@@ -259,3 +260,61 @@ class ServiceSecret(namedtuple('_ServiceSecret', 'source target uid gid mode')):
         return dict(
             [(k, v) for k, v in self._asdict().items() if v is not None]
         )
+
+
+class ServicePort(namedtuple('_ServicePort', 'target published protocol mode external_ip')):
+
+    @classmethod
+    def parse(cls, spec):
+        if not isinstance(spec, dict):
+            result = []
+            for k, v in build_port_bindings([spec]).items():
+                if '/' in k:
+                    target, proto = k.split('/', 1)
+                else:
+                    target, proto = (k, None)
+                for pub in v:
+                    if pub is None:
+                        result.append(
+                            cls(target, None, proto, None, None)
+                        )
+                    elif isinstance(pub, tuple):
+                        result.append(
+                            cls(target, pub[1], proto, None, pub[0])
+                        )
+                    else:
+                        result.append(
+                            cls(target, pub, proto, None, None)
+                        )
+            return result
+
+        return [cls(
+            spec.get('target'),
+            spec.get('published'),
+            spec.get('protocol'),
+            spec.get('mode'),
+            None
+        )]
+
+    @property
+    def merge_field(self):
+        return (self.target, self.published)
+
+    def repr(self):
+        return dict(
+            [(k, v) for k, v in self._asdict().items() if v is not None]
+        )
+
+    def legacy_repr(self):
+        return normalize_port_dict(self.repr())
+
+
+def normalize_port_dict(port):
+    return '{external_ip}{has_ext_ip}{published}{is_pub}{target}/{protocol}'.format(
+        published=port.get('published', ''),
+        is_pub=(':' if port.get('published') else ''),
+        target=port.get('target'),
+        protocol=port.get('protocol', 'tcp'),
+        external_ip=port.get('external_ip', ''),
+        has_ext_ip=(':' if port.get('external_ip') else ''),
+    )

+ 20 - 5
compose/service.py

@@ -22,6 +22,7 @@ from . import const
 from . import progress_stream
 from .config import DOCKER_CONFIG_KEYS
 from .config import merge_environment
+from .config.types import ServicePort
 from .config.types import VolumeSpec
 from .const import DEFAULT_TIMEOUT
 from .const import IS_WINDOWS_PLATFORM
@@ -696,7 +697,7 @@ class Service(object):
 
         if 'ports' in container_options or 'expose' in self.options:
             container_options['ports'] = build_container_ports(
-                container_options,
+                formatted_ports(container_options.get('ports', [])),
                 self.options)
 
         container_options['environment'] = merge_environment(
@@ -750,7 +751,9 @@ class Service(object):
 
         host_config = self.client.create_host_config(
             links=self._get_links(link_to_self=one_off),
-            port_bindings=build_port_bindings(options.get('ports') or []),
+            port_bindings=build_port_bindings(
+                formatted_ports(options.get('ports', []))
+            ),
             binds=options.get('binds'),
             volumes_from=self._get_volumes_from(),
             privileged=options.get('privileged', False),
@@ -880,7 +883,10 @@ class Service(object):
 
     def specifies_host_port(self):
         def has_host_port(binding):
-            _, external_bindings = split_port(binding)
+            if isinstance(binding, dict):
+                external_bindings = binding.get('published')
+            else:
+                _, external_bindings = split_port(binding)
 
             # there are no external bindings
             if external_bindings is None:
@@ -1225,12 +1231,21 @@ def format_environment(environment):
         return '{key}={value}'.format(key=key, value=value)
     return [format_env(*item) for item in environment.items()]
 
+
 # Ports
+def formatted_ports(ports):
+    result = []
+    for port in ports:
+        if isinstance(port, ServicePort):
+            result.append(port.legacy_repr())
+        else:
+            result.append(port)
+    return result
 
 
-def build_container_ports(container_options, options):
+def build_container_ports(container_ports, options):
     ports = []
-    all_ports = container_options.get('ports', []) + options.get('expose', [])
+    all_ports = container_ports + options.get('expose', [])
     for port_range in all_ports:
         internal_range, _ = split_port(port_range)
         for port in internal_range:

+ 13 - 0
tests/acceptance/cli_test.py

@@ -1808,6 +1808,19 @@ class CLITestCase(DockerClientTestCase):
         self.assertEqual(get_port(3001), "0.0.0.0:49152")
         self.assertEqual(get_port(3002), "0.0.0.0:49153")
 
+    def test_expanded_port(self):
+        self.base_dir = 'tests/fixtures/ports-composefile'
+        self.dispatch(['-f', 'expanded-notation.yml', 'up', '-d'])
+        container = self.project.get_service('simple').get_container()
+
+        def get_port(number):
+            result = self.dispatch(['port', 'simple', str(number)])
+            return result.stdout.rstrip()
+
+        self.assertEqual(get_port(3000), container.get_local_port(3000))
+        self.assertEqual(get_port(3001), "0.0.0.0:49152")
+        self.assertEqual(get_port(3002), "0.0.0.0:49153")
+
     def test_port_with_scale(self):
         self.base_dir = 'tests/fixtures/ports-composefile-scale'
         self.dispatch(['scale', 'simple=2'], None)

+ 15 - 0
tests/fixtures/ports-composefile/expanded-notation.yml

@@ -0,0 +1,15 @@
+version: '3.1'
+services:
+    simple:
+      image: busybox:latest
+      command: top
+      ports:
+        - target: 3000
+        - target: 3001
+          published: 49152
+        - target: 3002
+          published: 49153
+          protocol: tcp
+        - target: 3003
+          published: 49154
+          protocol: udp

+ 74 - 19
tests/unit/config/config_test.py

@@ -10,6 +10,7 @@ from operator import itemgetter
 
 import py
 import pytest
+import yaml
 
 from ...helpers import build_config_details
 from compose.config import config
@@ -25,6 +26,7 @@ from compose.config.environment import Environment
 from compose.config.errors import ConfigurationError
 from compose.config.errors import VERSION_EXPLANATION
 from compose.config.serialize import denormalize_service_dict
+from compose.config.serialize import serialize_config
 from compose.config.serialize import serialize_ns_time_value
 from compose.config.types import VolumeSpec
 from compose.const import IS_WINDOWS_PLATFORM
@@ -1794,6 +1796,30 @@ class ConfigTest(unittest.TestCase):
             }
         }
 
+    def test_merge_mixed_ports(self):
+        base = {
+            'image': 'busybox:latest',
+            'command': 'top',
+            'ports': [
+                {
+                    'target': '1245',
+                    'published': '1245',
+                    'protocol': 'tcp',
+                }
+            ]
+        }
+
+        override = {
+            'ports': ['1245:1245/udp']
+        }
+
+        actual = config.merge_service_dicts(base, override, V3_1)
+        assert actual == {
+            'image': 'busybox:latest',
+            'command': 'top',
+            'ports': [types.ServicePort('1245', '1245', 'udp', None, None)]
+        }
+
     def test_merge_depends_on_no_override(self):
         base = {
             'image': 'busybox',
@@ -2269,7 +2295,10 @@ class InterpolationTest(unittest.TestCase):
         self.assertEqual(service_dicts[0], {
             'name': 'web',
             'image': 'alpine:latest',
-            'ports': ['5643', '9999'],
+            'ports': [
+                types.ServicePort.parse('5643')[0],
+                types.ServicePort.parse('9999')[0]
+            ],
             'command': 'true'
         })
 
@@ -2292,7 +2321,7 @@ class InterpolationTest(unittest.TestCase):
             {
                 'name': 'web',
                 'image': 'busybox',
-                'ports': ['80:8000'],
+                'ports': types.ServicePort.parse('80:8000'),
                 'labels': {'mylabel': 'myvalue'},
                 'hostname': 'host-',
                 'command': '${ESCAPED}',
@@ -2576,13 +2605,37 @@ class MergePortsTest(unittest.TestCase, MergeListsTest):
     base_config = ['10:8000', '9000']
     override_config = ['20:8000']
 
+    def merged_config(self):
+        return self.convert(self.base_config) | self.convert(self.override_config)
+
+    def convert(self, port_config):
+        return set(config.merge_service_dicts(
+            {self.config_name: port_config},
+            {self.config_name: []},
+            DEFAULT_VERSION
+        )[self.config_name])
+
     def test_duplicate_port_mappings(self):
         service_dict = config.merge_service_dicts(
             {self.config_name: self.base_config},
             {self.config_name: self.base_config},
             DEFAULT_VERSION
         )
-        assert set(service_dict[self.config_name]) == set(self.base_config)
+        assert set(service_dict[self.config_name]) == self.convert(self.base_config)
+
+    def test_no_override(self):
+        service_dict = config.merge_service_dicts(
+            {self.config_name: self.base_config},
+            {},
+            DEFAULT_VERSION)
+        assert set(service_dict[self.config_name]) == self.convert(self.base_config)
+
+    def test_no_base(self):
+        service_dict = config.merge_service_dicts(
+            {},
+            {self.config_name: self.base_config},
+            DEFAULT_VERSION)
+        assert set(service_dict[self.config_name]) == self.convert(self.base_config)
 
 
 class MergeNetworksTest(unittest.TestCase, MergeListsTest):
@@ -3610,23 +3663,25 @@ class SerializeTest(unittest.TestCase):
         assert denormalized_service['healthcheck']['interval'] == '100s'
         assert denormalized_service['healthcheck']['timeout'] == '30s'
 
-    def test_denormalize_secrets(self):
+    def test_serialize_secrets(self):
         service_dict = {
-            'name': 'web',
             'image': 'example/web',
             'secrets': [
-                types.ServiceSecret('one', None, None, None, None),
-                types.ServiceSecret('source', 'target', '100', '200', 0o777),
-            ],
+                {'source': 'one'},
+                {
+                    'source': 'source',
+                    'target': 'target',
+                    'uid': '100',
+                    'gid': '200',
+                    'mode': 0o777,
+                }
+            ]
         }
-        denormalized_service = denormalize_service_dict(service_dict, V3_1)
-        assert secret_sort(denormalized_service['secrets']) == secret_sort([
-            {'source': 'one'},
-            {
-                'source': 'source',
-                'target': 'target',
-                'uid': '100',
-                'gid': '200',
-                'mode': 0o777,
-            },
-        ])
+        config_dict = config.load(build_config_details({
+            'version': '3.1',
+            'services': {'web': service_dict}
+        }))
+
+        serialized_config = yaml.load(serialize_config(config_dict))
+        serialized_service = serialized_config['services']['web']
+        assert secret_sort(serialized_service['secrets']) == secret_sort(service_dict['secrets'])

+ 44 - 0
tests/unit/config/types_test.py

@@ -7,6 +7,7 @@ from compose.config.config import V1
 from compose.config.config import V2_0
 from compose.config.errors import ConfigurationError
 from compose.config.types import parse_extra_hosts
+from compose.config.types import ServicePort
 from compose.config.types import VolumeFromSpec
 from compose.config.types import VolumeSpec
 
@@ -41,6 +42,49 @@ def test_parse_extra_hosts_dict():
     }
 
 
+class TestServicePort(object):
+    def test_parse_dict(self):
+        data = {
+            'target': 8000,
+            'published': 8000,
+            'protocol': 'udp',
+            'mode': 'global',
+        }
+        ports = ServicePort.parse(data)
+        assert len(ports) == 1
+        assert ports[0].repr() == data
+
+    def test_parse_simple_target_port(self):
+        ports = ServicePort.parse(8000)
+        assert len(ports) == 1
+        assert ports[0].target == '8000'
+
+    def test_parse_complete_port_definition(self):
+        port_def = '1.1.1.1:3000:3000/udp'
+        ports = ServicePort.parse(port_def)
+        assert len(ports) == 1
+        assert ports[0].repr() == {
+            'target': '3000',
+            'published': '3000',
+            'external_ip': '1.1.1.1',
+            'protocol': 'udp',
+        }
+        assert ports[0].legacy_repr() == port_def
+
+    def test_parse_port_range(self):
+        ports = ServicePort.parse('25000-25001:4000-4001')
+        assert len(ports) == 2
+        reprs = [p.repr() for p in ports]
+        assert {
+            'target': '4000',
+            'published': '25000'
+        } in reprs
+        assert {
+            'target': '4001',
+            'published': '25001'
+        } in reprs
+
+
 class TestVolumeSpec(object):
 
     def test_parse_volume_spec_only_one_path(self):

+ 21 - 0
tests/unit/service_test.py

@@ -7,6 +7,7 @@ from docker.errors import APIError
 
 from .. import mock
 from .. import unittest
+from compose.config.types import ServicePort
 from compose.config.types import VolumeFromSpec
 from compose.config.types import VolumeSpec
 from compose.const import LABEL_CONFIG_HASH
@@ -19,6 +20,7 @@ from compose.service import build_ulimits
 from compose.service import build_volume_binding
 from compose.service import BuildAction
 from compose.service import ContainerNetworkMode
+from compose.service import formatted_ports
 from compose.service import get_container_data_volumes
 from compose.service import ImageType
 from compose.service import merge_volume_bindings
@@ -778,6 +780,25 @@ class NetTestCase(unittest.TestCase):
         self.assertEqual(network_mode.service_name, service_name)
 
 
+class ServicePortsTest(unittest.TestCase):
+    def test_formatted_ports(self):
+        ports = [
+            '3000',
+            '0.0.0.0:4025-4030:23000-23005',
+            ServicePort(6000, None, None, None, None),
+            ServicePort(8080, 8080, None, None, None),
+            ServicePort('20000', '20000', 'udp', 'ingress', None),
+            ServicePort(30000, '30000', 'tcp', None, '127.0.0.1'),
+        ]
+        formatted = formatted_ports(ports)
+        assert ports[0] in formatted
+        assert ports[1] in formatted
+        assert '6000/tcp' in formatted
+        assert '8080:8080/tcp' in formatted
+        assert '20000:20000/udp' in formatted
+        assert '127.0.0.1:30000:30000/tcp' in formatted
+
+
 def build_mount(destination, source, mode='rw'):
     return {'Source': source, 'Destination': destination, 'Mode': mode}