Răsfoiți Sursa

Merge multi-value options when extending

Closes #1143.

Signed-off-by: Aanand Prasad <[email protected]>
Aanand Prasad 10 ani în urmă
părinte
comite
907918b492
2 a modificat fișierele cu 90 adăugiri și 11 ștergeri
  1. 26 4
      compose/config.py
  2. 64 7
      tests/unit/config_test.py

+ 26 - 4
compose/config.py

@@ -195,10 +195,23 @@ def merge_service_dicts(base, override):
     if 'build' in override and 'image' in d:
         del d['image']
 
-    for k in ALLOWED_KEYS:
-        if k not in ['environment', 'volumes']:
-            if k in override:
-                d[k] = override[k]
+    list_keys = ['ports', 'expose', 'external_links']
+
+    for key in list_keys:
+        if key in base or key in override:
+            d[key] = base.get(key, []) + override.get(key, [])
+
+    list_or_string_keys = ['dns', 'dns_search']
+
+    for key in list_or_string_keys:
+        if key in base or key in override:
+            d[key] = to_list(base.get(key)) + to_list(override.get(key))
+
+    already_merged_keys = ['environment', 'volumes'] + list_keys + list_or_string_keys
+
+    for k in set(ALLOWED_KEYS) - set(already_merged_keys):
+        if k in override:
+            d[k] = override[k]
 
     return d
 
@@ -354,6 +367,15 @@ def expand_path(working_dir, path):
     return os.path.abspath(os.path.join(working_dir, path))
 
 
+def to_list(value):
+    if value is None:
+        return []
+    elif isinstance(value, six.string_types):
+        return [value]
+    else:
+        return value
+
+
 def get_service_name_from_net(net_config):
     if not net_config:
         return

+ 64 - 7
tests/unit/config_test.py

@@ -40,40 +40,40 @@ class ConfigTest(unittest.TestCase):
         config.make_service_dict('foo', {'ports': ['8000']})
 
 
-class MergeTest(unittest.TestCase):
-    def test_merge_volumes_empty(self):
+class MergeVolumesTest(unittest.TestCase):
+    def test_empty(self):
         service_dict = config.merge_service_dicts({}, {})
         self.assertNotIn('volumes', service_dict)
 
-    def test_merge_volumes_no_override(self):
+    def test_no_override(self):
         service_dict = config.merge_service_dicts(
             {'volumes': ['/foo:/code', '/data']},
             {},
         )
         self.assertEqual(set(service_dict['volumes']), set(['/foo:/code', '/data']))
 
-    def test_merge_volumes_no_base(self):
+    def test_no_base(self):
         service_dict = config.merge_service_dicts(
             {},
             {'volumes': ['/bar:/code']},
         )
         self.assertEqual(set(service_dict['volumes']), set(['/bar:/code']))
 
-    def test_merge_volumes_override_explicit_path(self):
+    def test_override_explicit_path(self):
         service_dict = config.merge_service_dicts(
             {'volumes': ['/foo:/code', '/data']},
             {'volumes': ['/bar:/code']},
         )
         self.assertEqual(set(service_dict['volumes']), set(['/bar:/code', '/data']))
 
-    def test_merge_volumes_add_explicit_path(self):
+    def test_add_explicit_path(self):
         service_dict = config.merge_service_dicts(
             {'volumes': ['/foo:/code', '/data']},
             {'volumes': ['/bar:/code', '/quux:/data']},
         )
         self.assertEqual(set(service_dict['volumes']), set(['/bar:/code', '/quux:/data']))
 
-    def test_merge_volumes_remove_explicit_path(self):
+    def test_remove_explicit_path(self):
         service_dict = config.merge_service_dicts(
             {'volumes': ['/foo:/code', '/quux:/data']},
             {'volumes': ['/bar:/code', '/data']},
@@ -114,6 +114,63 @@ class MergeTest(unittest.TestCase):
         )
 
 
+class MergeListsTest(unittest.TestCase):
+    def test_empty(self):
+        service_dict = config.merge_service_dicts({}, {})
+        self.assertNotIn('ports', service_dict)
+
+    def test_no_override(self):
+        service_dict = config.merge_service_dicts(
+            {'ports': ['10:8000', '9000']},
+            {},
+        )
+        self.assertEqual(set(service_dict['ports']), set(['10:8000', '9000']))
+
+    def test_no_base(self):
+        service_dict = config.merge_service_dicts(
+            {},
+            {'ports': ['10:8000', '9000']},
+        )
+        self.assertEqual(set(service_dict['ports']), set(['10:8000', '9000']))
+
+    def test_add_item(self):
+        service_dict = config.merge_service_dicts(
+            {'ports': ['10:8000', '9000']},
+            {'ports': ['20:8000']},
+        )
+        self.assertEqual(set(service_dict['ports']), set(['10:8000', '9000', '20:8000']))
+
+
+class MergeStringsOrListsTest(unittest.TestCase):
+    def test_no_override(self):
+        service_dict = config.merge_service_dicts(
+            {'dns': '8.8.8.8'},
+            {},
+        )
+        self.assertEqual(set(service_dict['dns']), set(['8.8.8.8']))
+
+    def test_no_base(self):
+        service_dict = config.merge_service_dicts(
+            {},
+            {'dns': '8.8.8.8'},
+        )
+        self.assertEqual(set(service_dict['dns']), set(['8.8.8.8']))
+
+    def test_add_string(self):
+        service_dict = config.merge_service_dicts(
+            {'dns': ['8.8.8.8']},
+            {'dns': '9.9.9.9'},
+        )
+        self.assertEqual(set(service_dict['dns']), set(['8.8.8.8', '9.9.9.9']))
+
+    def test_add_list(self):
+        service_dict = config.merge_service_dicts(
+            {'dns': '8.8.8.8'},
+            {'dns': ['9.9.9.9']},
+        )
+        self.assertEqual(set(service_dict['dns']), set(['8.8.8.8', '9.9.9.9']))
+
+
 class EnvTest(unittest.TestCase):
     def test_parse_environment_as_list(self):
         environment = [