Răsfoiți Sursa

Validate that an extended config file has the same version as the base.

Signed-off-by: Daniel Nephin <[email protected]>
Daniel Nephin 9 ani în urmă
părinte
comite
89e31f7a8d
2 a modificat fișierele cu 50 adăugiri și 27 ștergeri
  1. 20 20
      compose/config/config.py
  2. 30 7
      tests/unit/config/config_test.py

+ 20 - 20
compose/config/config.py

@@ -188,10 +188,10 @@ def find(base_dir, filenames):
         [ConfigFile.from_filename(f) for f in filenames])
 
 
-def validate_config_version(config_details):
-    main_file = config_details.config_files[0]
+def validate_config_version(config_files):
+    main_file = config_files[0]
     validate_top_level_object(main_file)
-    for next_file in config_details.config_files[1:]:
+    for next_file in config_files[1:]:
         validate_top_level_object(next_file)
 
         if main_file.version != next_file.version:
@@ -254,7 +254,7 @@ def load(config_details):
 
     Return a fully interpolated, extended and validated configuration.
     """
-    validate_config_version(config_details)
+    validate_config_version(config_details.config_files)
 
     processed_files = [
         process_config_file(config_file)
@@ -267,9 +267,8 @@ def load(config_details):
     networks = load_mapping(config_details.config_files, 'get_networks', 'Network')
     service_dicts = load_services(
         config_details.working_dir,
-        main_file.filename,
-        [file.get_service_dicts() for file in config_details.config_files],
-        main_file.version)
+        main_file,
+        [file.get_service_dicts() for file in config_details.config_files])
     return Config(main_file.version, service_dicts, volumes, networks)
 
 
@@ -303,21 +302,21 @@ def load_mapping(config_files, get_func, entity_type):
     return mapping
 
 
-def load_services(working_dir, filename, service_configs, version):
+def load_services(working_dir, config_file, service_configs):
     def build_service(service_name, service_dict, service_names):
         service_config = ServiceConfig.with_abs_paths(
             working_dir,
-            filename,
+            config_file.filename,
             service_name,
             service_dict)
-        resolver = ServiceExtendsResolver(service_config, version)
+        resolver = ServiceExtendsResolver(service_config, config_file)
         service_dict = process_service(resolver.run())
 
-        validate_service(service_dict, service_config.name, version)
+        validate_service(service_dict, service_config.name, config_file.version)
         service_dict = finalize_service(
             service_config._replace(config=service_dict),
             service_names,
-            version)
+            config_file.version)
         return service_dict
 
     def build_services(service_config):
@@ -333,7 +332,7 @@ def load_services(working_dir, filename, service_configs, version):
             name: merge_service_dicts_from_files(
                 base.get(name, {}),
                 override.get(name, {}),
-                version)
+                config_file.version)
             for name in all_service_names
         }
 
@@ -373,11 +372,11 @@ def process_config_file(config_file, service_name=None):
 
 
 class ServiceExtendsResolver(object):
-    def __init__(self, service_config, version, already_seen=None):
+    def __init__(self, service_config, config_file, already_seen=None):
         self.service_config = service_config
         self.working_dir = service_config.working_dir
         self.already_seen = already_seen or []
-        self.version = version
+        self.config_file = config_file
 
     @property
     def signature(self):
@@ -404,8 +403,10 @@ class ServiceExtendsResolver(object):
         config_path = self.get_extended_config_path(extends)
         service_name = extends['service']
 
+        extends_file = ConfigFile.from_filename(config_path)
+        validate_config_version([self.config_file, extends_file])
         extended_file = process_config_file(
-            ConfigFile.from_filename(config_path),
+            extends_file,
             service_name=service_name)
         service_config = extended_file.config[service_name]
         return config_path, service_config, service_name
@@ -417,7 +418,7 @@ class ServiceExtendsResolver(object):
                 extended_config_path,
                 service_name,
                 service_dict),
-            self.version,
+            self.config_file,
             already_seen=self.already_seen + [self.signature])
 
         service_config = resolver.run()
@@ -425,13 +426,12 @@ class ServiceExtendsResolver(object):
         validate_extended_service_dict(
             other_service_dict,
             extended_config_path,
-            service_name,
-        )
+            service_name)
 
         return merge_service_dicts(
             other_service_dict,
             self.service_config.config,
-            self.version)
+            self.config_file.version)
 
     def get_extended_config_path(self, extends_options):
         """Service we are extending either has a value for 'file' set, which we

+ 30 - 7
tests/unit/config/config_test.py

@@ -25,14 +25,15 @@ V1 = 1
 
 
 def make_service_dict(name, service_dict, working_dir, filename=None):
+    """Test helper function to construct a ServiceExtendsResolver
     """
-    Test helper function to construct a ServiceExtendsResolver
-    """
-    resolver = config.ServiceExtendsResolver(config.ServiceConfig(
-        working_dir=working_dir,
-        filename=filename,
-        name=name,
-        config=service_dict), version=1)
+    resolver = config.ServiceExtendsResolver(
+        config.ServiceConfig(
+            working_dir=working_dir,
+            filename=filename,
+            name=name,
+            config=service_dict),
+        config.ConfigFile(filename=filename, config={}))
     return config.process_service(resolver.run())
 
 
@@ -1888,6 +1889,28 @@ class ExtendsTest(unittest.TestCase):
 
         assert config == expected
 
+    def test_extends_with_mixed_versions_is_error(self):
+        tmpdir = py.test.ensuretemp('test_extends_with_mixed_version')
+        self.addCleanup(tmpdir.remove)
+        tmpdir.join('docker-compose.yml').write("""
+            version: 2
+            services:
+              web:
+                extends:
+                  file: base.yml
+                  service: base
+                image: busybox
+        """)
+        tmpdir.join('base.yml').write("""
+            base:
+              volumes: ['/foo']
+              ports: ['3000:3000']
+        """)
+
+        with pytest.raises(ConfigurationError) as exc:
+            load_from_filename(str(tmpdir.join('docker-compose.yml')))
+        assert 'Version mismatch' in exc.exconly()
+
 
 @pytest.mark.xfail(IS_WINDOWS_PLATFORM, reason='paths use slash')
 class ExpandPathTest(unittest.TestCase):