Browse Source

Avoid duplicate warnings if an unset env variable is used multiple times

Signed-off-by: Aanand Prasad <[email protected]>
Aanand Prasad 10 years ago
parent
commit
4c65891db1
3 changed files with 62 additions and 31 deletions
  1. 22 16
      compose/config/interpolation.py
  2. 24 0
      tests/unit/config_test.py
  3. 16 15
      tests/unit/interpolation_test.py

+ 22 - 16
compose/config/interpolation.py

@@ -10,13 +10,15 @@ log = logging.getLogger(__name__)
 
 
 def interpolate_environment_variables(config):
+    mapping = BlankDefaultDict(os.environ)
+
     return dict(
-        (service_name, process_service(service_name, service_dict))
+        (service_name, process_service(service_name, service_dict, mapping))
         for (service_name, service_dict) in config.items()
     )
 
 
-def process_service(service_name, service_dict):
+def process_service(service_name, service_dict, mapping):
     if not isinstance(service_dict, dict):
         raise ConfigurationError(
             'Service "%s" doesn\'t have any configuration options. '
@@ -25,14 +27,14 @@ def process_service(service_name, service_dict):
         )
 
     return dict(
-        (key, interpolate_value(service_name, key, val))
+        (key, interpolate_value(service_name, key, val, mapping))
         for (key, val) in service_dict.items()
     )
 
 
-def interpolate_value(service_name, config_key, value):
+def interpolate_value(service_name, config_key, value, mapping):
     try:
-        return recursive_interpolate(value)
+        return recursive_interpolate(value, mapping)
     except InvalidInterpolation as e:
         raise ConfigurationError(
             'Invalid interpolation format for "{config_key}" option '
@@ -45,39 +47,43 @@ def interpolate_value(service_name, config_key, value):
         )
 
 
-def recursive_interpolate(obj):
+def recursive_interpolate(obj, mapping):
     if isinstance(obj, six.string_types):
-        return interpolate(obj, os.environ)
+        return interpolate(obj, mapping)
     elif isinstance(obj, dict):
         return dict(
-            (key, recursive_interpolate(val))
+            (key, recursive_interpolate(val, mapping))
             for (key, val) in obj.items()
         )
     elif isinstance(obj, list):
-        return map(recursive_interpolate, obj)
+        return [recursive_interpolate(val, mapping) for val in obj]
     else:
         return obj
 
 
 def interpolate(string, mapping):
     try:
-        return Template(string).substitute(BlankDefaultDict(mapping))
+        return Template(string).substitute(mapping)
     except ValueError:
         raise InvalidInterpolation(string)
 
 
 class BlankDefaultDict(dict):
-    def __init__(self, mapping):
-        super(BlankDefaultDict, self).__init__(mapping)
+    def __init__(self, *args, **kwargs):
+        super(BlankDefaultDict, self).__init__(*args, **kwargs)
+        self.missing_keys = []
 
     def __getitem__(self, key):
         try:
             return super(BlankDefaultDict, self).__getitem__(key)
         except KeyError:
-            log.warn(
-                "The {} variable is not set. Substituting a blank string."
-                .format(key)
-            )
+            if key not in self.missing_keys:
+                log.warn(
+                    "The {} variable is not set. Substituting a blank string."
+                    .format(key)
+                )
+                self.missing_keys.append(key)
+
             return ""
 
 

+ 24 - 0
tests/unit/config_test.py

@@ -198,6 +198,30 @@ class InterpolationTest(unittest.TestCase):
             }
         ])
 
+    @mock.patch.dict(os.environ)
+    def test_unset_variable_produces_warning(self):
+        os.environ.pop('FOO', None)
+        os.environ.pop('BAR', None)
+        config_details = config.ConfigDetails(
+            config={
+                'web': {
+                    'image': '${FOO}',
+                    'command': '${BAR}',
+                    'entrypoint': '${BAR}',
+                },
+            },
+            working_dir='.',
+            filename=None,
+        )
+
+        with mock.patch('compose.config.interpolation.log') as log:
+            config.load(config_details)
+
+            self.assertEqual(2, log.warn.call_count)
+            warnings = sorted(args[0][0] for args in log.warn.call_args_list)
+            self.assertIn('BAR', warnings[0])
+            self.assertIn('FOO', warnings[1])
+
     @mock.patch.dict(os.environ)
     def test_invalid_interpolation(self):
         with self.assertRaises(config.ConfigurationError) as cm:

+ 16 - 15
tests/unit/interpolation_test.py

@@ -1,31 +1,32 @@
 import unittest
 
 from compose.config.interpolation import interpolate, InvalidInterpolation
+from compose.config.interpolation import BlankDefaultDict as bddict
 
 
 class InterpolationTest(unittest.TestCase):
     def test_valid_interpolations(self):
-        self.assertEqual(interpolate('$foo', dict(foo='hi')), 'hi')
-        self.assertEqual(interpolate('${foo}', dict(foo='hi')), 'hi')
+        self.assertEqual(interpolate('$foo', bddict(foo='hi')), 'hi')
+        self.assertEqual(interpolate('${foo}', bddict(foo='hi')), 'hi')
 
-        self.assertEqual(interpolate('${subject} love you', dict(subject='i')), 'i love you')
-        self.assertEqual(interpolate('i ${verb} you', dict(verb='love')), 'i love you')
-        self.assertEqual(interpolate('i love ${object}', dict(object='you')), 'i love you')
+        self.assertEqual(interpolate('${subject} love you', bddict(subject='i')), 'i love you')
+        self.assertEqual(interpolate('i ${verb} you', bddict(verb='love')), 'i love you')
+        self.assertEqual(interpolate('i love ${object}', bddict(object='you')), 'i love you')
 
     def test_empty_value(self):
-        self.assertEqual(interpolate('${foo}', dict(foo='')), '')
+        self.assertEqual(interpolate('${foo}', bddict(foo='')), '')
 
     def test_unset_value(self):
-        self.assertEqual(interpolate('${foo}', dict()), '')
+        self.assertEqual(interpolate('${foo}', bddict()), '')
 
     def test_escaped_interpolation(self):
-        self.assertEqual(interpolate('$${foo}', dict(foo='hi')), '${foo}')
+        self.assertEqual(interpolate('$${foo}', bddict(foo='hi')), '${foo}')
 
     def test_invalid_strings(self):
-        self.assertRaises(InvalidInterpolation, lambda: interpolate('${', dict()))
-        self.assertRaises(InvalidInterpolation, lambda: interpolate('$}', dict()))
-        self.assertRaises(InvalidInterpolation, lambda: interpolate('${}', dict()))
-        self.assertRaises(InvalidInterpolation, lambda: interpolate('${ }', dict()))
-        self.assertRaises(InvalidInterpolation, lambda: interpolate('${ foo}', dict()))
-        self.assertRaises(InvalidInterpolation, lambda: interpolate('${foo }', dict()))
-        self.assertRaises(InvalidInterpolation, lambda: interpolate('${foo!}', dict()))
+        self.assertRaises(InvalidInterpolation, lambda: interpolate('${', bddict()))
+        self.assertRaises(InvalidInterpolation, lambda: interpolate('$}', bddict()))
+        self.assertRaises(InvalidInterpolation, lambda: interpolate('${}', bddict()))
+        self.assertRaises(InvalidInterpolation, lambda: interpolate('${ }', bddict()))
+        self.assertRaises(InvalidInterpolation, lambda: interpolate('${ foo}', bddict()))
+        self.assertRaises(InvalidInterpolation, lambda: interpolate('${foo }', bddict()))
+        self.assertRaises(InvalidInterpolation, lambda: interpolate('${foo!}', bddict()))