浏览代码

Merge pull request #442 from dnephin/fix_create_container_volumes

Additional validation for container volumes and ports.
Aanand Prasad 11 年之前
父节点
当前提交
7ad91f3f00
共有 4 个文件被更改,包括 130 次插入55 次删除
  1. 6 2
      docs/yml.md
  2. 1 1
      fig/cli/formatter.py
  3. 55 49
      fig/service.py
  4. 68 3
      tests/unit/service_test.py

+ 6 - 2
docs/yml.md

@@ -74,14 +74,18 @@ expose:
 
 
 ### volumes
 ### volumes
 
 
-Mount paths as volumes, optionally specifying a path on the host machine (`HOST:CONTAINER`).
+Mount paths as volumes, optionally specifying a path on the host machine
+(`HOST:CONTAINER`), or an access mode (`HOST:CONTAINER:ro`).
 
 
-Note: Mapping local volumes is currently unsupported on boot2docker. We recommend you use [docker-osx](https://github.com/noplay/docker-osx) if want to map local volumes.
+Note for fig on OSX: Mapping local volumes is currently unsupported on
+boot2docker. We recommend you use [docker-osx](https://github.com/noplay/docker-osx)
+if want to map local volumes on OSX.
 
 
 ```
 ```
 volumes:
 volumes:
  - /var/lib/mysql
  - /var/lib/mysql
  - cache/:/tmp/cache
  - cache/:/tmp/cache
+ - ~/configs:/etc/configs/:ro
 ```
 ```
 
 
 ### volumes_from
 ### volumes_from

+ 1 - 1
fig/cli/formatter.py

@@ -9,7 +9,7 @@ def get_tty_width():
     if len(tty_size) != 2:
     if len(tty_size) != 2:
         return 80
         return 80
     _, width = tty_size
     _, width = tty_size
-    return width
+    return int(width)
 
 
 
 
 class Formatter(object):
 class Formatter(object):

+ 55 - 49
fig/service.py

@@ -1,5 +1,6 @@
 from __future__ import unicode_literals
 from __future__ import unicode_literals
 from __future__ import absolute_import
 from __future__ import absolute_import
+from collections import namedtuple
 from .packages.docker.errors import APIError
 from .packages.docker.errors import APIError
 import logging
 import logging
 import re
 import re
@@ -39,6 +40,9 @@ class ConfigError(ValueError):
     pass
     pass
 
 
 
 
+VolumeSpec = namedtuple('VolumeSpec', 'external internal mode')
+
+
 class Service(object):
 class Service(object):
     def __init__(self, name, client=None, project='default', links=None, volumes_from=None, **options):
     def __init__(self, name, client=None, project='default', links=None, volumes_from=None, **options):
         if not re.match('^%s+$' % VALID_NAME_CHARS, name):
         if not re.match('^%s+$' % VALID_NAME_CHARS, name):
@@ -214,37 +218,22 @@ class Service(object):
             return self.start_container(container, **options)
             return self.start_container(container, **options)
 
 
     def start_container(self, container=None, intermediate_container=None, **override_options):
     def start_container(self, container=None, intermediate_container=None, **override_options):
-        if container is None:
-            container = self.create_container(**override_options)
-
-        options = self.options.copy()
-        options.update(override_options)
-
-        port_bindings = {}
+        container = container or self.create_container(**override_options)
+        options = dict(self.options, **override_options)
+        ports = dict(split_port(port) for port in options.get('ports') or [])
 
 
-        if options.get('ports', None) is not None:
-            for port in options['ports']:
-                internal_port, external_port = split_port(port)
-                port_bindings[internal_port] = external_port
-
-        volume_bindings = {}
-
-        if options.get('volumes', None) is not None:
-            for volume in options['volumes']:
-                if ':' in volume:
-                    external_dir, internal_dir = volume.split(':')
-                    volume_bindings[os.path.abspath(external_dir)] = {
-                        'bind': internal_dir,
-                        'ro': False,
-                    }
+        volume_bindings = dict(
+            build_volume_binding(parse_volume_spec(volume))
+            for volume in options.get('volumes') or []
+            if ':' in volume)
 
 
         privileged = options.get('privileged', False)
         privileged = options.get('privileged', False)
         net = options.get('net', 'bridge')
         net = options.get('net', 'bridge')
         dns = options.get('dns', None)
         dns = options.get('dns', None)
 
 
         container.start(
         container.start(
-            links=self._get_links(link_to_self=override_options.get('one_off', False)),
-            port_bindings=port_bindings,
+            links=self._get_links(link_to_self=options.get('one_off', False)),
+            port_bindings=ports,
             binds=volume_bindings,
             binds=volume_bindings,
             volumes_from=self._get_volumes_from(intermediate_container),
             volumes_from=self._get_volumes_from(intermediate_container),
             privileged=privileged,
             privileged=privileged,
@@ -256,7 +245,7 @@ class Service(object):
     def start_or_create_containers(self):
     def start_or_create_containers(self):
         containers = self.containers(stopped=True)
         containers = self.containers(stopped=True)
 
 
-        if len(containers) == 0:
+        if not containers:
             log.info("Creating %s..." % self.next_container_name())
             log.info("Creating %s..." % self.next_container_name())
             new_container = self.create_container()
             new_container = self.create_container()
             return [self.start_container(new_container)]
             return [self.start_container(new_container)]
@@ -338,7 +327,9 @@ class Service(object):
             container_options['ports'] = ports
             container_options['ports'] = ports
 
 
         if 'volumes' in container_options:
         if 'volumes' in container_options:
-            container_options['volumes'] = dict((split_volume(v)[1], {}) for v in container_options['volumes'])
+            container_options['volumes'] = dict(
+                (parse_volume_spec(v).internal, {})
+                for v in container_options['volumes'])
 
 
         if 'environment' in container_options:
         if 'environment' in container_options:
             if isinstance(container_options['environment'], list):
             if isinstance(container_options['environment'], list):
@@ -433,32 +424,47 @@ def get_container_name(container):
             return name[1:]
             return name[1:]
 
 
 
 
-def split_volume(v):
-    """
-    If v is of the format EXTERNAL:INTERNAL, returns (EXTERNAL, INTERNAL).
-    If v is of the format INTERNAL, returns (None, INTERNAL).
-    """
-    if ':' in v:
-        return v.split(':', 1)
-    else:
-        return (None, v)
+def parse_volume_spec(volume_config):
+    parts = volume_config.split(':')
+    if len(parts) > 3:
+        raise ConfigError("Volume %s has incorrect format, should be "
+                          "external:internal[:mode]" % volume_config)
+
+    if len(parts) == 1:
+        return VolumeSpec(None, parts[0], 'rw')
+
+    if len(parts) == 2:
+        parts.append('rw')
+
+    external, internal, mode = parts
+    if mode not in ('rw', 'ro'):
+        raise ConfigError("Volume %s has invalid mode (%s), should be "
+                          "one of: rw, ro." % (volume_config, mode))
+
+    return VolumeSpec(external, internal, mode)
+
+
+def build_volume_binding(volume_spec):
+    internal = {'bind': volume_spec.internal, 'ro': volume_spec.mode == 'ro'}
+    external = os.path.expanduser(volume_spec.external)
+    return os.path.abspath(os.path.expandvars(external)), internal
 
 
 
 
 def split_port(port):
 def split_port(port):
-    port = str(port)
-    external_ip = None
-    if ':' in port:
-        external_port, internal_port = port.rsplit(':', 1)
-        if ':' in external_port:
-            external_ip, external_port = external_port.split(':', 1)
-    else:
-        external_port, internal_port = (None, port)
-    if external_ip:
-        if external_port:
-            external_port = (external_ip, external_port)
-        else:
-            external_port = (external_ip,)
-    return internal_port, external_port
+    parts = str(port).split(':')
+    if not 1 <= len(parts) <= 3:
+        raise ConfigError('Invalid port "%s", should be '
+                          '[[remote_ip:]remote_port:]port[/protocol]' % port)
+
+    if len(parts) == 1:
+        internal_port, = parts
+        return internal_port, None
+    if len(parts) == 2:
+        external_port, internal_port = parts
+        return internal_port, external_port
+
+    external_ip, external_port, internal_port = parts
+    return internal_port, (external_ip, external_port or None)
 
 
 
 
 def split_env(env):
 def split_env(env):

+ 68 - 3
tests/unit/service_test.py

@@ -1,8 +1,18 @@
 from __future__ import unicode_literals
 from __future__ import unicode_literals
 from __future__ import absolute_import
 from __future__ import absolute_import
+import os
+
 from .. import unittest
 from .. import unittest
+import mock
+
 from fig import Service
 from fig import Service
-from fig.service import ConfigError, split_port
+from fig.service import (
+    ConfigError,
+    split_port,
+    parse_volume_spec,
+    build_volume_binding,
+)
+
 
 
 class ServiceTest(unittest.TestCase):
 class ServiceTest(unittest.TestCase):
     def test_name_validations(self):
     def test_name_validations(self):
@@ -28,23 +38,35 @@ class ServiceTest(unittest.TestCase):
         self.assertRaises(ConfigError, lambda: Service(name='foo', port=['8000']))
         self.assertRaises(ConfigError, lambda: Service(name='foo', port=['8000']))
         Service(name='foo', ports=['8000'])
         Service(name='foo', ports=['8000'])
 
 
-    def test_split_port(self):
+    def test_split_port_with_host_ip(self):
         internal_port, external_port = split_port("127.0.0.1:1000:2000")
         internal_port, external_port = split_port("127.0.0.1:1000:2000")
         self.assertEqual(internal_port, "2000")
         self.assertEqual(internal_port, "2000")
         self.assertEqual(external_port, ("127.0.0.1", "1000"))
         self.assertEqual(external_port, ("127.0.0.1", "1000"))
 
 
+    def test_split_port_with_protocol(self):
         internal_port, external_port = split_port("127.0.0.1:1000:2000/udp")
         internal_port, external_port = split_port("127.0.0.1:1000:2000/udp")
         self.assertEqual(internal_port, "2000/udp")
         self.assertEqual(internal_port, "2000/udp")
         self.assertEqual(external_port, ("127.0.0.1", "1000"))
         self.assertEqual(external_port, ("127.0.0.1", "1000"))
 
 
+    def test_split_port_with_host_ip_no_port(self):
         internal_port, external_port = split_port("127.0.0.1::2000")
         internal_port, external_port = split_port("127.0.0.1::2000")
         self.assertEqual(internal_port, "2000")
         self.assertEqual(internal_port, "2000")
-        self.assertEqual(external_port, ("127.0.0.1",))
+        self.assertEqual(external_port, ("127.0.0.1", None))
 
 
+    def test_split_port_with_host_port(self):
         internal_port, external_port = split_port("1000:2000")
         internal_port, external_port = split_port("1000:2000")
         self.assertEqual(internal_port, "2000")
         self.assertEqual(internal_port, "2000")
         self.assertEqual(external_port, "1000")
         self.assertEqual(external_port, "1000")
 
 
+    def test_split_port_no_host_port(self):
+        internal_port, external_port = split_port("2000")
+        self.assertEqual(internal_port, "2000")
+        self.assertEqual(external_port, None)
+
+    def test_split_port_invalid(self):
+        with self.assertRaises(ConfigError):
+            split_port("0.0.0.0:1000:2000:tcp")
+
     def test_split_domainname_none(self):
     def test_split_domainname_none(self):
         service = Service('foo',
         service = Service('foo',
                 hostname = 'name',
                 hostname = 'name',
@@ -82,3 +104,46 @@ class ServiceTest(unittest.TestCase):
         opts = service._get_container_create_options({})
         opts = service._get_container_create_options({})
         self.assertEqual(opts['hostname'], 'name.sub', 'hostname')
         self.assertEqual(opts['hostname'], 'name.sub', 'hostname')
         self.assertEqual(opts['domainname'], 'domain.tld', 'domainname')
         self.assertEqual(opts['domainname'], 'domain.tld', 'domainname')
+
+
+class ServiceVolumesTest(unittest.TestCase):
+
+    def test_parse_volume_spec_only_one_path(self):
+        spec = parse_volume_spec('/the/volume')
+        self.assertEqual(spec, (None, '/the/volume', 'rw'))
+
+    def test_parse_volume_spec_internal_and_external(self):
+        spec = parse_volume_spec('external:interval')
+        self.assertEqual(spec, ('external', 'interval', 'rw'))
+
+    def test_parse_volume_spec_with_mode(self):
+        spec = parse_volume_spec('external:interval:ro')
+        self.assertEqual(spec, ('external', 'interval', 'ro'))
+
+    def test_parse_volume_spec_too_many_parts(self):
+        with self.assertRaises(ConfigError):
+            parse_volume_spec('one:two:three:four')
+
+    def test_parse_volume_bad_mode(self):
+        with self.assertRaises(ConfigError):
+            parse_volume_spec('one:two:notrw')
+
+    def test_build_volume_binding(self):
+        binding = build_volume_binding(parse_volume_spec('/outside:/inside'))
+        self.assertEqual(
+            binding,
+            ('/outside', dict(bind='/inside', ro=False)))
+
+    @mock.patch.dict(os.environ)
+    def test_build_volume_binding_with_environ(self):
+        os.environ['VOLUME_PATH'] = '/opt'
+        binding = build_volume_binding(parse_volume_spec('${VOLUME_PATH}:/opt'))
+        self.assertEqual(binding, ('/opt', dict(bind='/opt', ro=False)))
+
+    @mock.patch.dict(os.environ)
+    def test_building_volume_binding_with_home(self):
+        os.environ['HOME'] = '/home/user'
+        binding = build_volume_binding(parse_volume_spec('~:/home/user'))
+        self.assertEqual(
+            binding,
+            ('/home/user', dict(bind='/home/user', ro=False)))