interpolation.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from __future__ import absolute_import
  2. from __future__ import unicode_literals
  3. import logging
  4. from string import Template
  5. import six
  6. from .errors import ConfigurationError
  7. from compose.const import COMPOSEFILE_V1 as V1
  8. from compose.const import COMPOSEFILE_V2_0 as V2_0
  9. log = logging.getLogger(__name__)
  10. class Interpolator(object):
  11. def __init__(self, templater, mapping):
  12. self.templater = templater
  13. self.mapping = mapping
  14. def interpolate(self, string):
  15. try:
  16. return self.templater(string).substitute(self.mapping)
  17. except ValueError:
  18. raise InvalidInterpolation(string)
  19. def interpolate_environment_variables(version, config, section, environment):
  20. if version in (V2_0, V1):
  21. interpolator = Interpolator(Template, environment)
  22. else:
  23. interpolator = Interpolator(TemplateWithDefaults, environment)
  24. def process_item(name, config_dict):
  25. return dict(
  26. (key, interpolate_value(name, key, val, section, interpolator))
  27. for key, val in (config_dict or {}).items()
  28. )
  29. return dict(
  30. (name, process_item(name, config_dict or {}))
  31. for name, config_dict in config.items()
  32. )
  33. def interpolate_value(name, config_key, value, section, interpolator):
  34. try:
  35. return recursive_interpolate(value, interpolator)
  36. except InvalidInterpolation as e:
  37. raise ConfigurationError(
  38. 'Invalid interpolation format for "{config_key}" option '
  39. 'in {section} "{name}": "{string}"'.format(
  40. config_key=config_key,
  41. name=name,
  42. section=section,
  43. string=e.string))
  44. def recursive_interpolate(obj, interpolator):
  45. if isinstance(obj, six.string_types):
  46. return interpolator.interpolate(obj)
  47. if isinstance(obj, dict):
  48. return dict(
  49. (key, recursive_interpolate(val, interpolator))
  50. for (key, val) in obj.items()
  51. )
  52. if isinstance(obj, list):
  53. return [recursive_interpolate(val, interpolator) for val in obj]
  54. return obj
  55. class TemplateWithDefaults(Template):
  56. idpattern = r'[_a-z][_a-z0-9]*(?::?-[^}]+)?'
  57. # Modified from python2.7/string.py
  58. def substitute(self, mapping):
  59. # Helper function for .sub()
  60. def convert(mo):
  61. # Check the most common path first.
  62. named = mo.group('named') or mo.group('braced')
  63. if named is not None:
  64. if ':-' in named:
  65. var, _, default = named.partition(':-')
  66. return mapping.get(var) or default
  67. if '-' in named:
  68. var, _, default = named.partition('-')
  69. return mapping.get(var, default)
  70. val = mapping[named]
  71. return '%s' % (val,)
  72. if mo.group('escaped') is not None:
  73. return self.delimiter
  74. if mo.group('invalid') is not None:
  75. self._invalid(mo)
  76. raise ValueError('Unrecognized named group in pattern',
  77. self.pattern)
  78. return self.pattern.sub(convert, self.template)
  79. class InvalidInterpolation(Exception):
  80. def __init__(self, string):
  81. self.string = string