Pārlūkot izejas kodu

Merge pull request #3108 from dnephin/inplace_var_defaults

Support inline default values for interpolation
Joffrey F 9 gadi atpakaļ
vecāks
revīzija
eac01a7cd5

+ 13 - 9
compose/config/config.py

@@ -413,31 +413,35 @@ def load_services(config_details, config_file):
     return build_services(service_config)
 
 
-def interpolate_config_section(filename, config, section, environment):
-    validate_config_section(filename, config, section)
-    return interpolate_environment_variables(config, section, environment)
+def interpolate_config_section(config_file, config, section, environment):
+    validate_config_section(config_file.filename, config, section)
+    return interpolate_environment_variables(
+            config_file.version,
+            config,
+            section,
+            environment)
 
 
 def process_config_file(config_file, environment, service_name=None):
     services = interpolate_config_section(
-        config_file.filename,
+        config_file,
         config_file.get_service_dicts(),
         'service',
-        environment,)
+        environment)
 
     if config_file.version in (V2_0, V2_1):
         processed_config = dict(config_file.config)
         processed_config['services'] = services
         processed_config['volumes'] = interpolate_config_section(
-            config_file.filename,
+            config_file,
             config_file.get_volumes(),
             'volume',
-            environment,)
+            environment)
         processed_config['networks'] = interpolate_config_section(
-            config_file.filename,
+            config_file,
             config_file.get_networks(),
             'network',
-            environment,)
+            environment)
 
     if config_file.version == V1:
         processed_config = services

+ 57 - 17
compose/config/interpolation.py

@@ -7,14 +7,35 @@ from string import Template
 import six
 
 from .errors import ConfigurationError
+from compose.const import COMPOSEFILE_V1 as V1
+from compose.const import COMPOSEFILE_V2_0 as V2_0
+
+
 log = logging.getLogger(__name__)
 
 
-def interpolate_environment_variables(config, section, environment):
+class Interpolator(object):
+
+    def __init__(self, templater, mapping):
+        self.templater = templater
+        self.mapping = mapping
+
+    def interpolate(self, string):
+        try:
+            return self.templater(string).substitute(self.mapping)
+        except ValueError:
+            raise InvalidInterpolation(string)
+
+
+def interpolate_environment_variables(version, config, section, environment):
+    if version in (V2_0, V1):
+        interpolator = Interpolator(Template, environment)
+    else:
+        interpolator = Interpolator(TemplateWithDefaults, environment)
 
     def process_item(name, config_dict):
         return dict(
-            (key, interpolate_value(name, key, val, section, environment))
+            (key, interpolate_value(name, key, val, section, interpolator))
             for key, val in (config_dict or {}).items()
         )
 
@@ -24,9 +45,9 @@ def interpolate_environment_variables(config, section, environment):
     )
 
 
-def interpolate_value(name, config_key, value, section, mapping):
+def interpolate_value(name, config_key, value, section, interpolator):
     try:
-        return recursive_interpolate(value, mapping)
+        return recursive_interpolate(value, interpolator)
     except InvalidInterpolation as e:
         raise ConfigurationError(
             'Invalid interpolation format for "{config_key}" option '
@@ -37,25 +58,44 @@ def interpolate_value(name, config_key, value, section, mapping):
                 string=e.string))
 
 
-def recursive_interpolate(obj, mapping):
+def recursive_interpolate(obj, interpolator):
     if isinstance(obj, six.string_types):
-        return interpolate(obj, mapping)
-    elif isinstance(obj, dict):
+        return interpolator.interpolate(obj)
+    if isinstance(obj, dict):
         return dict(
-            (key, recursive_interpolate(val, mapping))
+            (key, recursive_interpolate(val, interpolator))
             for (key, val) in obj.items()
         )
-    elif isinstance(obj, list):
-        return [recursive_interpolate(val, mapping) for val in obj]
-    else:
-        return obj
+    if isinstance(obj, list):
+        return [recursive_interpolate(val, interpolator) for val in obj]
+    return obj
 
 
-def interpolate(string, mapping):
-    try:
-        return Template(string).substitute(mapping)
-    except ValueError:
-        raise InvalidInterpolation(string)
+class TemplateWithDefaults(Template):
+    idpattern = r'[_a-z][_a-z0-9]*(?::?-[_a-z0-9]+)?'
+
+    # Modified from python2.7/string.py
+    def substitute(self, mapping):
+        # Helper function for .sub()
+        def convert(mo):
+            # Check the most common path first.
+            named = mo.group('named') or mo.group('braced')
+            if named is not None:
+                if ':-' in named:
+                    var, _, default = named.partition(':-')
+                    return mapping.get(var) or default
+                if '-' in named:
+                    var, _, default = named.partition('-')
+                    return mapping.get(var, default)
+                val = mapping[named]
+                return '%s' % (val,)
+            if mo.group('escaped') is not None:
+                return self.delimiter
+            if mo.group('invalid') is not None:
+                self._invalid(mo)
+            raise ValueError('Unrecognized named group in pattern',
+                             self.pattern)
+        return self.pattern.sub(convert, self.template)
 
 
 class InvalidInterpolation(Exception):

+ 60 - 14
tests/unit/config/interpolation_test.py

@@ -1,21 +1,28 @@
 from __future__ import absolute_import
 from __future__ import unicode_literals
 
-import os
-
-import mock
 import pytest
 
 from compose.config.environment import Environment
 from compose.config.interpolation import interpolate_environment_variables
+from compose.config.interpolation import Interpolator
+from compose.config.interpolation import InvalidInterpolation
+from compose.config.interpolation import TemplateWithDefaults
 
 
[email protected]_fixture
[email protected]
 def mock_env():
-    with mock.patch.dict(os.environ):
-        os.environ['USER'] = 'jenny'
-        os.environ['FOO'] = 'bar'
-        yield
+    return Environment({'USER': 'jenny', 'FOO': 'bar'})
+
+
[email protected]
+def variable_mapping():
+    return Environment({'FOO': 'first', 'BAR': ''})
+
+
[email protected]
+def defaults_interpolator(variable_mapping):
+    return Interpolator(TemplateWithDefaults, variable_mapping).interpolate
 
 
 def test_interpolate_environment_variables_in_services(mock_env):
@@ -43,9 +50,8 @@ def test_interpolate_environment_variables_in_services(mock_env):
             }
         }
     }
-    assert interpolate_environment_variables(
-        services, 'service', Environment.from_env_file(None)
-    ) == expected
+    value = interpolate_environment_variables("2.0", services, 'service', mock_env)
+    assert value == expected
 
 
 def test_interpolate_environment_variables_in_volumes(mock_env):
@@ -69,6 +75,46 @@ def test_interpolate_environment_variables_in_volumes(mock_env):
         },
         'other': {},
     }
-    assert interpolate_environment_variables(
-        volumes, 'volume', Environment.from_env_file(None)
-    ) == expected
+    value = interpolate_environment_variables("2.0",  volumes, 'volume', mock_env)
+    assert value == expected
+
+
+def test_escaped_interpolation(defaults_interpolator):
+    assert defaults_interpolator('$${foo}') == '${foo}'
+
+
+def test_invalid_interpolation(defaults_interpolator):
+    with pytest.raises(InvalidInterpolation):
+        defaults_interpolator('${')
+    with pytest.raises(InvalidInterpolation):
+        defaults_interpolator('$}')
+    with pytest.raises(InvalidInterpolation):
+        defaults_interpolator('${}')
+    with pytest.raises(InvalidInterpolation):
+        defaults_interpolator('${ }')
+    with pytest.raises(InvalidInterpolation):
+        defaults_interpolator('${ foo}')
+    with pytest.raises(InvalidInterpolation):
+        defaults_interpolator('${foo }')
+    with pytest.raises(InvalidInterpolation):
+        defaults_interpolator('${foo!}')
+
+
+def test_interpolate_missing_no_default(defaults_interpolator):
+    assert defaults_interpolator("This ${missing} var") == "This  var"
+    assert defaults_interpolator("This ${BAR} var") == "This  var"
+
+
+def test_interpolate_with_value(defaults_interpolator):
+    assert defaults_interpolator("This $FOO var") == "This first var"
+    assert defaults_interpolator("This ${FOO} var") == "This first var"
+
+
+def test_interpolate_missing_with_default(defaults_interpolator):
+    assert defaults_interpolator("ok ${missing:-def}") == "ok def"
+    assert defaults_interpolator("ok ${missing-def}") == "ok def"
+
+
+def test_interpolate_with_empty_and_default_value(defaults_interpolator):
+    assert defaults_interpolator("ok ${BAR:-def}") == "ok def"
+    assert defaults_interpolator("ok ${BAR-def}") == "ok "

+ 0 - 36
tests/unit/interpolation_test.py

@@ -1,36 +0,0 @@
-from __future__ import absolute_import
-from __future__ import unicode_literals
-
-import unittest
-
-from compose.config.environment import Environment as bddict
-from compose.config.interpolation import interpolate
-from compose.config.interpolation import InvalidInterpolation
-
-
-class InterpolationTest(unittest.TestCase):
-    def test_valid_interpolations(self):
-        self.assertEqual(interpolate('$foo', bddict(foo='hi')), 'hi')
-        self.assertEqual(interpolate('${foo}', bddict(foo='hi')), 'hi')
-
-        self.assertEqual(interpolate('${subject} love you', bddict(subject='i')), 'i love you')
-        self.assertEqual(interpolate('i ${verb} you', bddict(verb='love')), 'i love you')
-        self.assertEqual(interpolate('i love ${object}', bddict(object='you')), 'i love you')
-
-    def test_empty_value(self):
-        self.assertEqual(interpolate('${foo}', bddict(foo='')), '')
-
-    def test_unset_value(self):
-        self.assertEqual(interpolate('${foo}', bddict()), '')
-
-    def test_escaped_interpolation(self):
-        self.assertEqual(interpolate('$${foo}', bddict(foo='hi')), '${foo}')
-
-    def test_invalid_strings(self):
-        self.assertRaises(InvalidInterpolation, lambda: interpolate('${', bddict()))
-        self.assertRaises(InvalidInterpolation, lambda: interpolate('$}', bddict()))
-        self.assertRaises(InvalidInterpolation, lambda: interpolate('${}', bddict()))
-        self.assertRaises(InvalidInterpolation, lambda: interpolate('${ }', bddict()))
-        self.assertRaises(InvalidInterpolation, lambda: interpolate('${ foo}', bddict()))
-        self.assertRaises(InvalidInterpolation, lambda: interpolate('${foo }', bddict()))
-        self.assertRaises(InvalidInterpolation, lambda: interpolate('${foo!}', bddict()))