|
@@ -26,6 +26,7 @@ from .sort_services import get_service_name_from_network_mode
|
|
|
from .sort_services import sort_service_dicts
|
|
|
from .types import parse_extra_hosts
|
|
|
from .types import parse_restart_spec
|
|
|
+from .types import ServiceLink
|
|
|
from .types import VolumeFromSpec
|
|
|
from .types import VolumeSpec
|
|
|
from .validation import match_named_volumes
|
|
@@ -641,51 +642,79 @@ def merge_service_dicts_from_files(base, override, version):
|
|
|
return new_service
|
|
|
|
|
|
|
|
|
-def merge_service_dicts(base, override, version):
|
|
|
- d = {}
|
|
|
+class MergeDict(dict):
|
|
|
+ """A dict-like object responsible for merging two dicts into one."""
|
|
|
+
|
|
|
+ def __init__(self, base, override):
|
|
|
+ self.base = base
|
|
|
+ self.override = override
|
|
|
+
|
|
|
+ def needs_merge(self, field):
|
|
|
+ return field in self.base or field in self.override
|
|
|
+
|
|
|
+ def merge_field(self, field, merge_func, default=None):
|
|
|
+ if not self.needs_merge(field):
|
|
|
+ return
|
|
|
+
|
|
|
+ self[field] = merge_func(
|
|
|
+ self.base.get(field, default),
|
|
|
+ self.override.get(field, default))
|
|
|
+
|
|
|
+ def merge_mapping(self, field, parse_func):
|
|
|
+ if not self.needs_merge(field):
|
|
|
+ return
|
|
|
+
|
|
|
+ self[field] = parse_func(self.base.get(field))
|
|
|
+ self[field].update(parse_func(self.override.get(field)))
|
|
|
|
|
|
- def merge_field(field, merge_func, default=None):
|
|
|
- if field in base or field in override:
|
|
|
- d[field] = merge_func(
|
|
|
- base.get(field, default),
|
|
|
- override.get(field, default))
|
|
|
+ def merge_sequence(self, field, parse_func):
|
|
|
+ def parse_sequence_func(seq):
|
|
|
+ return to_mapping((parse_func(item) for item in seq), 'merge_field')
|
|
|
|
|
|
- def merge_mapping(mapping, parse_func):
|
|
|
- if mapping in base or mapping in override:
|
|
|
- merged = parse_func(base.get(mapping, None))
|
|
|
- merged.update(parse_func(override.get(mapping, None)))
|
|
|
- d[mapping] = merged
|
|
|
+ if not self.needs_merge(field):
|
|
|
+ return
|
|
|
|
|
|
- merge_mapping('environment', parse_environment)
|
|
|
- merge_mapping('labels', parse_labels)
|
|
|
- merge_mapping('ulimits', parse_ulimits)
|
|
|
+ merged = parse_sequence_func(self.base.get(field, []))
|
|
|
+ merged.update(parse_sequence_func(self.override.get(field, [])))
|
|
|
+ self[field] = [item.repr() for item in merged.values()]
|
|
|
+
|
|
|
+ def merge_scalar(self, field):
|
|
|
+ if self.needs_merge(field):
|
|
|
+ self[field] = self.override.get(field, self.base.get(field))
|
|
|
+
|
|
|
+
|
|
|
+def merge_service_dicts(base, override, version):
|
|
|
+ md = MergeDict(base, override)
|
|
|
+
|
|
|
+ md.merge_mapping('environment', parse_environment)
|
|
|
+ md.merge_mapping('labels', parse_labels)
|
|
|
+ md.merge_mapping('ulimits', parse_ulimits)
|
|
|
+ md.merge_sequence('links', ServiceLink.parse)
|
|
|
|
|
|
for field in ['volumes', 'devices']:
|
|
|
- merge_field(field, merge_path_mappings)
|
|
|
+ md.merge_field(field, merge_path_mappings)
|
|
|
|
|
|
for field in [
|
|
|
'depends_on',
|
|
|
'expose',
|
|
|
'external_links',
|
|
|
- 'links',
|
|
|
'ports',
|
|
|
'volumes_from',
|
|
|
]:
|
|
|
- merge_field(field, operator.add, default=[])
|
|
|
+ md.merge_field(field, operator.add, default=[])
|
|
|
|
|
|
for field in ['dns', 'dns_search', 'env_file']:
|
|
|
- merge_field(field, merge_list_or_string)
|
|
|
+ md.merge_field(field, merge_list_or_string)
|
|
|
|
|
|
- for field in set(ALLOWED_KEYS) - set(d):
|
|
|
- if field in base or field in override:
|
|
|
- d[field] = override.get(field, base.get(field))
|
|
|
+ for field in set(ALLOWED_KEYS) - set(md):
|
|
|
+ md.merge_scalar(field)
|
|
|
|
|
|
if version == V1:
|
|
|
- legacy_v1_merge_image_or_build(d, base, override)
|
|
|
+ legacy_v1_merge_image_or_build(md, base, override)
|
|
|
else:
|
|
|
- merge_build(d, base, override)
|
|
|
+ merge_build(md, base, override)
|
|
|
|
|
|
- return d
|
|
|
+ return dict(md)
|
|
|
|
|
|
|
|
|
def merge_build(output, base, override):
|
|
@@ -919,6 +948,10 @@ def to_list(value):
|
|
|
return value
|
|
|
|
|
|
|
|
|
+def to_mapping(sequence, key_field):
|
|
|
+ return {getattr(item, key_field): item for item in sequence}
|
|
|
+
|
|
|
+
|
|
|
def has_uppercase(name):
|
|
|
return any(char in string.ascii_uppercase for char in name)
|
|
|
|