Explorar o código

Refactor config loading for handling volumes_from in v2.

Signed-off-by: Daniel Nephin <[email protected]>
Daniel Nephin %!s(int64=10) %!d(string=hai) anos
pai
achega
de949284f5
Modificáronse 1 ficheiros con 25 adicións e 35 borrados
  1. 25 35
      compose/config/config.py

+ 25 - 35
compose/config/config.py

@@ -122,6 +122,9 @@ class ConfigFile(namedtuple('_ConfigFile', 'filename config')):
     def get_service_dicts(self, version):
         return self.config if version == 1 else self.config.get('services', {})
 
+    def get_volumes(self, version):
+        return {} if version == 1 else self.config.get('volumes', {})
+
 
 class Config(namedtuple('_Config', 'version services volumes')):
     """
@@ -243,41 +246,29 @@ def load(config_details):
     if version not in COMPOSEFILE_VERSIONS:
         raise ConfigurationError('Invalid config version provided: {0}'.format(version))
 
-    processed_files = []
-    for config_file in config_details.config_files:
-        processed_files.append(
-            process_config_file(config_file, version=version)
-        )
+    processed_files = [
+        process_config_file(config_file, version=version)
+        for config_file in config_details.config_files
+    ]
     config_details = config_details._replace(config_files=processed_files)
 
-    if version == 1:
-        service_dicts = load_services(
-            config_details.working_dir, config_details.config_files,
-            version
-        )
-        volumes = {}
-    elif version == 2:
-        config_files = [
-            ConfigFile(f.filename, f.config.get('services', {}))
-            for f in config_details.config_files
-        ]
-        service_dicts = load_services(
-            config_details.working_dir, config_files, version
-        )
-        volumes = load_volumes(config_details.config_files)
-
+    volumes = load_volumes(config_details.config_files)
+    service_dicts = load_services(
+        config_details.working_dir,
+        config_details.config_files[0].filename,
+        [file.get_service_dicts(version) for file in config_details.config_files],
+        version)
     return Config(version, service_dicts, volumes)
 
 
 def load_volumes(config_files):
     volumes = {}
     for config_file in config_files:
-        for name, volume_config in config_file.config.get('volumes', {}).items():
-            if volume_config is None:
-                volumes.update({name: {}})
+        for name, volume_config in config_file.get_volumes().items():
+            volumes[name] = volume_config or {}
+            if not volume_config:
                 continue
 
-            volumes.update({name: volume_config})
             external = volume_config.get('external')
             if external:
                 if len(volume_config.keys()) > 1:
@@ -296,8 +287,8 @@ def load_volumes(config_files):
     return volumes
 
 
-def load_services(working_dir, config_files, version):
-    def build_service(filename, service_name, service_dict):
+def load_services(working_dir, filename, service_configs, version):
+    def build_service(service_name, service_dict):
         service_config = ServiceConfig.with_abs_paths(
             working_dir,
             filename,
@@ -314,10 +305,10 @@ def load_services(working_dir, config_files, version):
         service_dict['name'] = service_config.name
         return service_dict
 
-    def build_services(config_file):
+    def build_services(service_config):
         return sort_service_dicts([
-            build_service(config_file.filename, name, service_dict)
-            for name, service_dict in config_file.config.items()
+            build_service(name, service_dict)
+            for name, service_dict in service_config.items()
         ])
 
     def merge_services(base, override):
@@ -330,12 +321,11 @@ def load_services(working_dir, config_files, version):
             for name in all_service_names
         }
 
-    config_file = config_files[0]
-    for next_file in config_files[1:]:
-        config = merge_services(config_file.config, next_file.config)
-        config_file = config_file._replace(config=config)
+    service_config = service_configs[0]
+    for next_config in service_configs[1:]:
+        service_config = merge_services(service_config, next_config)
 
-    return build_services(config_file)
+    return build_services(service_config)
 
 
 def process_config_file(config_file, version, service_name=None):