Explorar o código

Refactor config loading to move version check into ConfigFile.

Adds the cached_property package.

Signed-off-by: Daniel Nephin <[email protected]>
Daniel Nephin %!s(int64=9) %!d(string=hai) anos
pai
achega
c3968a439f
Modificáronse 4 ficheiros con 55 adicións e 50 borrados
  1. 48 44
      compose/config/config.py
  2. 5 6
      compose/config/validation.py
  3. 1 0
      requirements.txt
  4. 1 0
      setup.py

+ 48 - 44
compose/config/config.py

@@ -10,6 +10,7 @@ from collections import namedtuple
 
 import six
 import yaml
+from cached_property import cached_property
 
 from ..const import COMPOSEFILE_VERSIONS
 from .errors import CircularReference
@@ -119,11 +120,23 @@ class ConfigFile(namedtuple('_ConfigFile', 'filename config')):
     def from_filename(cls, filename):
         return cls(filename, load_yaml(filename))
 
-    def get_service_dicts(self, version):
-        return self.config if version == 1 else self.config.get('services', {})
+    @cached_property
+    def version(self):
+        if self.config is None:
+            return 1
+        version = self.config.get('version', 1)
+        if isinstance(version, dict):
+            log.warn("Unexpected type for field 'version', in file {} assuming "
+                     "version is the name of a service, and defaulting to "
+                     "Compose file version 1".format(self.filename))
+            return 1
+        return version
+
+    def get_service_dicts(self):
+        return self.config if self.version == 1 else self.config.get('services', {})
 
-    def get_volumes(self, version):
-        return {} if version == 1 else self.config.get('volumes', {})
+    def get_volumes(self):
+        return {} if self.version == 1 else self.config.get('volumes', {})
 
 
 class Config(namedtuple('_Config', 'version services volumes')):
@@ -168,32 +181,24 @@ def find(base_dir, filenames):
         [ConfigFile.from_filename(f) for f in filenames])
 
 
-def get_config_version(config_details):
-    def get_version(config):
-        if config.config is None:
-            return 1
-        version = config.config.get('version', 1)
-        if isinstance(version, dict):
-            # in that case 'version' is probably a service name, so assume
-            # this is a legacy (version=1) file
-            version = 1
-        return version
-
+def validate_config_version(config_details):
     main_file = config_details.config_files[0]
     validate_top_level_object(main_file)
-    version = get_version(main_file)
     for next_file in config_details.config_files[1:]:
         validate_top_level_object(next_file)
-        next_file_version = get_version(next_file)
 
-        if version != next_file_version and next_file_version is not None:
+        if main_file.version != next_file.version:
             raise ConfigurationError(
-                "Version mismatch: main file {0} specifies version {1} but "
+                "Version mismatch: file {0} specifies version {1} but "
                 "extension file {2} uses version {3}".format(
-                    main_file.filename, version, next_file.filename, next_file_version
-                )
-            )
-    return version
+                    main_file.filename,
+                    main_file.version,
+                    next_file.filename,
+                    next_file.version))
+
+    if main_file.version not in COMPOSEFILE_VERSIONS:
+        raise ConfigurationError(
+            'Invalid Compose file version: {0}'.format(main_file.version))
 
 
 def get_default_config_files(base_dir):
@@ -242,23 +247,22 @@ def load(config_details):
 
     Return a fully interpolated, extended and validated configuration.
     """
-    version = get_config_version(config_details)
-    if version not in COMPOSEFILE_VERSIONS:
-        raise ConfigurationError('Invalid config version provided: {0}'.format(version))
+    validate_config_version(config_details)
 
     processed_files = [
-        process_config_file(config_file, version=version)
+        process_config_file(config_file)
         for config_file in config_details.config_files
     ]
     config_details = config_details._replace(config_files=processed_files)
 
+    main_file = config_details.config_files[0]
     volumes = load_volumes(config_details.config_files)
     service_dicts = load_services(
         config_details.working_dir,
-        config_details.config_files[0].filename,
-        [file.get_service_dicts(version) for file in config_details.config_files],
-        version)
-    return Config(version, service_dicts, volumes)
+        main_file.filename,
+        [file.get_service_dicts() for file in config_details.config_files],
+        main_file.version)
+    return Config(main_file.version, service_dicts, volumes)
 
 
 def load_volumes(config_files):
@@ -328,27 +332,28 @@ def load_services(working_dir, filename, service_configs, version):
     return build_services(service_config)
 
 
-def process_config_file(config_file, version, service_name=None):
-    service_dicts = config_file.get_service_dicts(version)
-    validate_top_level_service_objects(
-        config_file.filename, service_dicts
-    )
+def process_config_file(config_file, service_name=None):
+    service_dicts = config_file.get_service_dicts()
+    validate_top_level_service_objects(config_file.filename, service_dicts)
+
+    # TODO: interpolate config in volumes/network sections as well
     interpolated_config = interpolate_environment_variables(service_dicts)
-    if version == 2:
+
+    if config_file.version == 2:
         processed_config = dict(config_file.config)
         processed_config.update({'services': interpolated_config})
-    if version == 1:
+    if config_file.version == 1:
         processed_config = interpolated_config
-    validate_against_fields_schema(
-        processed_config, config_file.filename, version
-    )
+
+    config_file = config_file._replace(config=processed_config)
+    validate_against_fields_schema(config_file)
 
     if service_name and service_name not in processed_config:
         raise ConfigurationError(
             "Cannot extend service '{}' in {}: Service not found".format(
                 service_name, config_file.filename))
 
-    return config_file._replace(config=processed_config)
+    return config_file
 
 
 class ServiceExtendsResolver(object):
@@ -385,8 +390,7 @@ class ServiceExtendsResolver(object):
 
         extended_file = process_config_file(
             ConfigFile.from_filename(config_path),
-            version=self.version, service_name=service_name
-        )
+            service_name=service_name)
         service_config = extended_file.config[service_name]
         return config_path, service_config, service_name
 

+ 5 - 6
compose/config/validation.py

@@ -105,8 +105,7 @@ def validate_top_level_service_objects(filename, service_dicts):
 def validate_top_level_object(config_file):
     if not isinstance(config_file.config, dict):
         raise ConfigurationError(
-            "Top level object in '{}' needs to be an object not '{}'. Check "
-            "that you have defined a service at the top level.".format(
+            "Top level object in '{}' needs to be an object not '{}'.".format(
                 config_file.filename,
                 type(config_file.config)))
 
@@ -291,13 +290,13 @@ def process_errors(errors, service_name=None):
     return '\n'.join(format_error_message(error, service_name) for error in errors)
 
 
-def validate_against_fields_schema(config, filename, version):
-    schema_filename = "fields_schema_v{0}.json".format(version)
+def validate_against_fields_schema(config_file):
+    schema_filename = "fields_schema_v{0}.json".format(config_file.version)
     _validate_against_schema(
-        config,
+        config_file.config,
         schema_filename,
         format_checker=["ports", "expose", "bool-value-in-mapping"],
-        filename=filename)
+        filename=config_file.filename)
 
 
 def validate_against_service_schema(config, service_name, version):

+ 1 - 0
requirements.txt

@@ -1,4 +1,5 @@
 PyYAML==3.11
+cached-property==1.2.0
 dockerpty==0.3.4
 docopt==0.6.1
 enum34==1.0.4

+ 1 - 0
setup.py

@@ -28,6 +28,7 @@ def find_version(*file_paths):
 
 
 install_requires = [
+    'cached-property >= 1.2.0',
     'docopt >= 0.6.1, < 0.7',
     'PyYAML >= 3.10, < 4',
     'requests >= 2.6.1, < 2.8',