interpolation.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from __future__ import absolute_import
  2. from __future__ import unicode_literals
  3. import logging
  4. from string import Template
  5. import six
  6. from .errors import ConfigurationError
  7. from compose.const import COMPOSEFILE_V2_0 as V2_0
  8. log = logging.getLogger(__name__)
  9. class Interpolator(object):
  10. def __init__(self, templater, mapping):
  11. self.templater = templater
  12. self.mapping = mapping
  13. def interpolate(self, string):
  14. try:
  15. return self.templater(string).substitute(self.mapping)
  16. except ValueError:
  17. raise InvalidInterpolation(string)
  18. def interpolate_environment_variables(version, config, section, environment):
  19. if version <= V2_0:
  20. interpolator = Interpolator(Template, environment)
  21. else:
  22. interpolator = Interpolator(TemplateWithDefaults, environment)
  23. def process_item(name, config_dict):
  24. return dict(
  25. (key, interpolate_value(name, key, val, section, interpolator))
  26. for key, val in (config_dict or {}).items()
  27. )
  28. return dict(
  29. (name, process_item(name, config_dict or {}))
  30. for name, config_dict in config.items()
  31. )
  32. def interpolate_value(name, config_key, value, section, interpolator):
  33. try:
  34. return recursive_interpolate(value, interpolator)
  35. except InvalidInterpolation as e:
  36. raise ConfigurationError(
  37. 'Invalid interpolation format for "{config_key}" option '
  38. 'in {section} "{name}": "{string}"'.format(
  39. config_key=config_key,
  40. name=name,
  41. section=section,
  42. string=e.string))
  43. def recursive_interpolate(obj, interpolator):
  44. if isinstance(obj, six.string_types):
  45. return interpolator.interpolate(obj)
  46. if isinstance(obj, dict):
  47. return dict(
  48. (key, recursive_interpolate(val, interpolator))
  49. for (key, val) in obj.items()
  50. )
  51. if isinstance(obj, list):
  52. return [recursive_interpolate(val, interpolator) for val in obj]
  53. return obj
  54. class TemplateWithDefaults(Template):
  55. idpattern = r'[_a-z][_a-z0-9]*(?::?-[^}]+)?'
  56. # Modified from python2.7/string.py
  57. def substitute(self, mapping):
  58. # Helper function for .sub()
  59. def convert(mo):
  60. # Check the most common path first.
  61. named = mo.group('named') or mo.group('braced')
  62. if named is not None:
  63. if ':-' in named:
  64. var, _, default = named.partition(':-')
  65. return mapping.get(var) or default
  66. if '-' in named:
  67. var, _, default = named.partition('-')
  68. return mapping.get(var, default)
  69. val = mapping[named]
  70. return '%s' % (val,)
  71. if mo.group('escaped') is not None:
  72. return self.delimiter
  73. if mo.group('invalid') is not None:
  74. self._invalid(mo)
  75. raise ValueError('Unrecognized named group in pattern',
  76. self.pattern)
  77. return self.pattern.sub(convert, self.template)
  78. class InvalidInterpolation(Exception):
  79. def __init__(self, string):
  80. self.string = string