|  | @@ -7,14 +7,35 @@ from string import Template
 | 
											
												
													
														|  |  import six
 |  |  import six
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  from .errors import ConfigurationError
 |  |  from .errors import ConfigurationError
 | 
											
												
													
														|  | 
 |  | +from compose.const import COMPOSEFILE_V1 as V1
 | 
											
												
													
														|  | 
 |  | +from compose.const import COMPOSEFILE_V2_0 as V2_0
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |  log = logging.getLogger(__name__)
 |  |  log = logging.getLogger(__name__)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -def interpolate_environment_variables(config, section, environment):
 |  | 
 | 
											
												
													
														|  | 
 |  | +class Interpolator(object):
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, templater, mapping):
 | 
											
												
													
														|  | 
 |  | +        self.templater = templater
 | 
											
												
													
														|  | 
 |  | +        self.mapping = mapping
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def interpolate(self, string):
 | 
											
												
													
														|  | 
 |  | +        try:
 | 
											
												
													
														|  | 
 |  | +            return self.templater(string).substitute(self.mapping)
 | 
											
												
													
														|  | 
 |  | +        except ValueError:
 | 
											
												
													
														|  | 
 |  | +            raise InvalidInterpolation(string)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +def interpolate_environment_variables(version, config, section, environment):
 | 
											
												
													
														|  | 
 |  | +    if version in (V2_0, V1):
 | 
											
												
													
														|  | 
 |  | +        interpolator = Interpolator(Template, environment)
 | 
											
												
													
														|  | 
 |  | +    else:
 | 
											
												
													
														|  | 
 |  | +        interpolator = Interpolator(TemplateWithDefaults, environment)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |      def process_item(name, config_dict):
 |  |      def process_item(name, config_dict):
 | 
											
												
													
														|  |          return dict(
 |  |          return dict(
 | 
											
												
													
														|  | -            (key, interpolate_value(name, key, val, section, environment))
 |  | 
 | 
											
												
													
														|  | 
 |  | +            (key, interpolate_value(name, key, val, section, interpolator))
 | 
											
												
													
														|  |              for key, val in (config_dict or {}).items()
 |  |              for key, val in (config_dict or {}).items()
 | 
											
												
													
														|  |          )
 |  |          )
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -24,9 +45,9 @@ def interpolate_environment_variables(config, section, environment):
 | 
											
												
													
														|  |      )
 |  |      )
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -def interpolate_value(name, config_key, value, section, mapping):
 |  | 
 | 
											
												
													
														|  | 
 |  | +def interpolate_value(name, config_key, value, section, interpolator):
 | 
											
												
													
														|  |      try:
 |  |      try:
 | 
											
												
													
														|  | -        return recursive_interpolate(value, mapping)
 |  | 
 | 
											
												
													
														|  | 
 |  | +        return recursive_interpolate(value, interpolator)
 | 
											
												
													
														|  |      except InvalidInterpolation as e:
 |  |      except InvalidInterpolation as e:
 | 
											
												
													
														|  |          raise ConfigurationError(
 |  |          raise ConfigurationError(
 | 
											
												
													
														|  |              'Invalid interpolation format for "{config_key}" option '
 |  |              'Invalid interpolation format for "{config_key}" option '
 | 
											
										
											
												
													
														|  | @@ -37,25 +58,44 @@ def interpolate_value(name, config_key, value, section, mapping):
 | 
											
												
													
														|  |                  string=e.string))
 |  |                  string=e.string))
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -def recursive_interpolate(obj, mapping):
 |  | 
 | 
											
												
													
														|  | 
 |  | +def recursive_interpolate(obj, interpolator):
 | 
											
												
													
														|  |      if isinstance(obj, six.string_types):
 |  |      if isinstance(obj, six.string_types):
 | 
											
												
													
														|  | -        return interpolate(obj, mapping)
 |  | 
 | 
											
												
													
														|  | -    elif isinstance(obj, dict):
 |  | 
 | 
											
												
													
														|  | 
 |  | +        return interpolator.interpolate(obj)
 | 
											
												
													
														|  | 
 |  | +    if isinstance(obj, dict):
 | 
											
												
													
														|  |          return dict(
 |  |          return dict(
 | 
											
												
													
														|  | -            (key, recursive_interpolate(val, mapping))
 |  | 
 | 
											
												
													
														|  | 
 |  | +            (key, recursive_interpolate(val, interpolator))
 | 
											
												
													
														|  |              for (key, val) in obj.items()
 |  |              for (key, val) in obj.items()
 | 
											
												
													
														|  |          )
 |  |          )
 | 
											
												
													
														|  | -    elif isinstance(obj, list):
 |  | 
 | 
											
												
													
														|  | -        return [recursive_interpolate(val, mapping) for val in obj]
 |  | 
 | 
											
												
													
														|  | -    else:
 |  | 
 | 
											
												
													
														|  | -        return obj
 |  | 
 | 
											
												
													
														|  | 
 |  | +    if isinstance(obj, list):
 | 
											
												
													
														|  | 
 |  | +        return [recursive_interpolate(val, interpolator) for val in obj]
 | 
											
												
													
														|  | 
 |  | +    return obj
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -def interpolate(string, mapping):
 |  | 
 | 
											
												
													
														|  | -    try:
 |  | 
 | 
											
												
													
														|  | -        return Template(string).substitute(mapping)
 |  | 
 | 
											
												
													
														|  | -    except ValueError:
 |  | 
 | 
											
												
													
														|  | -        raise InvalidInterpolation(string)
 |  | 
 | 
											
												
													
														|  | 
 |  | +class TemplateWithDefaults(Template):
 | 
											
												
													
														|  | 
 |  | +    idpattern = r'[_a-z][_a-z0-9]*(?::?-[_a-z0-9]+)?'
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    # Modified from python2.7/string.py
 | 
											
												
													
														|  | 
 |  | +    def substitute(self, mapping):
 | 
											
												
													
														|  | 
 |  | +        # Helper function for .sub()
 | 
											
												
													
														|  | 
 |  | +        def convert(mo):
 | 
											
												
													
														|  | 
 |  | +            # Check the most common path first.
 | 
											
												
													
														|  | 
 |  | +            named = mo.group('named') or mo.group('braced')
 | 
											
												
													
														|  | 
 |  | +            if named is not None:
 | 
											
												
													
														|  | 
 |  | +                if ':-' in named:
 | 
											
												
													
														|  | 
 |  | +                    var, _, default = named.partition(':-')
 | 
											
												
													
														|  | 
 |  | +                    return mapping.get(var) or default
 | 
											
												
													
														|  | 
 |  | +                if '-' in named:
 | 
											
												
													
														|  | 
 |  | +                    var, _, default = named.partition('-')
 | 
											
												
													
														|  | 
 |  | +                    return mapping.get(var, default)
 | 
											
												
													
														|  | 
 |  | +                val = mapping[named]
 | 
											
												
													
														|  | 
 |  | +                return '%s' % (val,)
 | 
											
												
													
														|  | 
 |  | +            if mo.group('escaped') is not None:
 | 
											
												
													
														|  | 
 |  | +                return self.delimiter
 | 
											
												
													
														|  | 
 |  | +            if mo.group('invalid') is not None:
 | 
											
												
													
														|  | 
 |  | +                self._invalid(mo)
 | 
											
												
													
														|  | 
 |  | +            raise ValueError('Unrecognized named group in pattern',
 | 
											
												
													
														|  | 
 |  | +                             self.pattern)
 | 
											
												
													
														|  | 
 |  | +        return self.pattern.sub(convert, self.template)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  class InvalidInterpolation(Exception):
 |  |  class InvalidInterpolation(Exception):
 |