瀏覽代碼

Handle volume driver change error in config.

Assume version=1 if file is empty in get_config_version
Empty files are invalid anyway, so this simplifies the algorithm
somewhat.
https://github.com/docker/compose/pull/2421#discussion_r47223144

Don't leak version considerations in interpolation/service validation

Signed-off-by: Joffrey F <[email protected]>
Joffrey F 10 年之前
父節點
當前提交
a7689f3da8

+ 16 - 6
compose/config/config.py

@@ -118,6 +118,9 @@ class ConfigFile(namedtuple('_ConfigFile', 'filename config')):
     def from_filename(cls, filename):
     def from_filename(cls, filename):
         return cls(filename, load_yaml(filename))
         return cls(filename, load_yaml(filename))
 
 
+    def get_service_dicts(self, version):
+        return self.config if version == 1 else self.config.get('services', {})
+
 
 
 class Config(namedtuple('_Config', 'version services volumes')):
 class Config(namedtuple('_Config', 'version services volumes')):
     """
     """
@@ -164,9 +167,11 @@ def find(base_dir, filenames):
 def get_config_version(config_details):
 def get_config_version(config_details):
     def get_version(config):
     def get_version(config):
         if config.config is None:
         if config.config is None:
-            return None
+            return 1
         version = config.config.get('version', 1)
         version = config.config.get('version', 1)
         if isinstance(version, dict):
         if isinstance(version, dict):
+            # in that case 'version' is probably a service name, so assume
+            # this is a legacy (version=1) file
             version = 1
             version = 1
         return version
         return version
 
 
@@ -176,9 +181,6 @@ def get_config_version(config_details):
     for next_file in config_details.config_files[1:]:
     for next_file in config_details.config_files[1:]:
         validate_top_level_object(next_file)
         validate_top_level_object(next_file)
         next_file_version = get_version(next_file)
         next_file_version = get_version(next_file)
-        if version is None:
-            version = next_file_version
-            continue
 
 
         if version != next_file_version and next_file_version is not None:
         if version != next_file_version and next_file_version is not None:
             raise ConfigurationError(
             raise ConfigurationError(
@@ -316,8 +318,16 @@ def load_services(working_dir, config_files, version):
 
 
 
 
 def process_config_file(config_file, version, service_name=None):
 def process_config_file(config_file, version, service_name=None):
-    validate_top_level_service_objects(config_file, version)
-    processed_config = interpolate_environment_variables(config_file.config, version)
+    service_dicts = config_file.get_service_dicts(version)
+    validate_top_level_service_objects(
+        config_file.filename, service_dicts
+    )
+    interpolated_config = interpolate_environment_variables(service_dicts)
+    if version == 2:
+        processed_config = dict(config_file.config)
+        processed_config.update({'services': interpolated_config})
+    if version == 1:
+        processed_config = interpolated_config
     validate_against_fields_schema(
     validate_against_fields_schema(
         processed_config, config_file.filename, version
         processed_config, config_file.filename, version
     )
     )

+ 2 - 8
compose/config/interpolation.py

@@ -8,19 +8,13 @@ from .errors import ConfigurationError
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 
 
 
 
-def interpolate_environment_variables(config, version):
+def interpolate_environment_variables(service_dicts):
     mapping = BlankDefaultDict(os.environ)
     mapping = BlankDefaultDict(os.environ)
-    service_dicts = config if version == 1 else config.get('services', {})
 
 
-    interpolated = dict(
+    return dict(
         (service_name, process_service(service_name, service_dict, mapping))
         (service_name, process_service(service_name, service_dict, mapping))
         for (service_name, service_dict) in service_dicts.items()
         for (service_name, service_dict) in service_dicts.items()
     )
     )
-    if version == 1:
-        return interpolated
-    result = dict(config)
-    result.update({'services': interpolated})
-    return result
 
 
 
 
 def process_service(service_name, service_dict, mapping):
 def process_service(service_name, service_dict, mapping):

+ 5 - 5
compose/config/validation.py

@@ -74,19 +74,18 @@ def format_boolean_in_environment(instance):
     return True
     return True
 
 
 
 
-def validate_top_level_service_objects(config_file, version):
+def validate_top_level_service_objects(filename, service_dicts):
     """Perform some high level validation of the service name and value.
     """Perform some high level validation of the service name and value.
 
 
     This validation must happen before interpolation, which must happen
     This validation must happen before interpolation, which must happen
     before the rest of validation, which is why it's separate from the
     before the rest of validation, which is why it's separate from the
     rest of the service validation.
     rest of the service validation.
     """
     """
-    service_dicts = config_file.config if version == 1 else config_file.config.get('services', {})
     for service_name, service_dict in service_dicts.items():
     for service_name, service_dict in service_dicts.items():
         if not isinstance(service_name, six.string_types):
         if not isinstance(service_name, six.string_types):
             raise ConfigurationError(
             raise ConfigurationError(
                 "In file '{}' service name: {} needs to be a string, eg '{}'".format(
                 "In file '{}' service name: {} needs to be a string, eg '{}'".format(
-                    config_file.filename,
+                    filename,
                     service_name,
                     service_name,
                     service_name))
                     service_name))
 
 
@@ -95,8 +94,9 @@ def validate_top_level_service_objects(config_file, version):
                 "In file '{}' service '{}' doesn\'t have any configuration options. "
                 "In file '{}' service '{}' doesn\'t have any configuration options. "
                 "All top level keys in your docker-compose.yml must map "
                 "All top level keys in your docker-compose.yml must map "
                 "to a dictionary of configuration options.".format(
                 "to a dictionary of configuration options.".format(
-                    config_file.filename,
-                    service_name))
+                    filename, service_name
+                )
+            )
 
 
 
 
 def validate_top_level_object(config_file):
 def validate_top_level_object(config_file):

+ 12 - 0
compose/project.py

@@ -236,6 +236,18 @@ class Project(object):
             raise ConfigurationError(
             raise ConfigurationError(
                 'Volume %s specifies nonexistent driver %s' % (volume.name, volume.driver)
                 'Volume %s specifies nonexistent driver %s' % (volume.name, volume.driver)
             )
             )
+        except APIError as e:
+            if 'Choose a different volume name' in str(e):
+                raise ConfigurationError(
+                    'Configuration for volume {0} specifies driver {1}, but '
+                    'a volume with the same name uses a different driver '
+                    '({3}). If you wish to use the new configuration, please '
+                    'remove the existing volume "{2}" first:\n'
+                    '$ docker volume rm {2}'.format(
+                        volume.name, volume.driver, volume.full_name,
+                        volume.inspect()['Driver']
+                    )
+                )
 
 
     def restart(self, service_names=None, **options):
     def restart(self, service_names=None, **options):
         containers = self.containers(service_names, stopped=True)
         containers = self.containers(service_names, stopped=True)

+ 36 - 2
tests/integration/project_test.py

@@ -579,11 +579,11 @@ class ProjectTest(DockerClientTestCase):
         vol_name = '{0:x}'.format(random.getrandbits(32))
         vol_name = '{0:x}'.format(random.getrandbits(32))
 
 
         config_data = config.Config(
         config_data = config.Config(
-            2, [{
+            version=2, services=[{
                 'name': 'web',
                 'name': 'web',
                 'image': 'busybox:latest',
                 'image': 'busybox:latest',
                 'command': 'top'
                 'command': 'top'
-            }], {vol_name: {'driver': 'foobar'}}
+            }], volumes={vol_name: {'driver': 'foobar'}}
         )
         )
 
 
         project = Project.from_config(
         project = Project.from_config(
@@ -592,3 +592,37 @@ class ProjectTest(DockerClientTestCase):
         )
         )
         with self.assertRaises(config.ConfigurationError):
         with self.assertRaises(config.ConfigurationError):
             project.initialize_volumes()
             project.initialize_volumes()
+
+    def test_project_up_updated_driver(self):
+        vol_name = '{0:x}'.format(random.getrandbits(32))
+        full_vol_name = 'composetest_{0}'.format(vol_name)
+
+        config_data = config.Config(
+            version=2, services=[{
+                'name': 'web',
+                'image': 'busybox:latest',
+                'command': 'top'
+            }], volumes={vol_name: {'driver': 'local'}}
+        )
+        project = Project.from_config(
+            name='composetest',
+            config_data=config_data, client=self.client
+        )
+        project.initialize_volumes()
+
+        volume_data = self.client.inspect_volume(full_vol_name)
+        self.assertEqual(volume_data['Name'], full_vol_name)
+        self.assertEqual(volume_data['Driver'], 'local')
+
+        config_data = config_data._replace(
+            volumes={vol_name: {'driver': 'smb'}}
+        )
+        project = Project.from_config(
+            name='composetest',
+            config_data=config_data, client=self.client
+        )
+        with self.assertRaises(config.ConfigurationError) as e:
+            project.initialize_volumes()
+        assert 'Configuration for volume {0} specifies driver smb'.format(
+            vol_name
+        ) in str(e.exception)

+ 23 - 0
tests/unit/config/config_test.py

@@ -286,6 +286,18 @@ class ConfigTest(unittest.TestCase):
         error_msg = "Top level object in 'override.yml' needs to be an object"
         error_msg = "Top level object in 'override.yml' needs to be an object"
         assert error_msg in exc.exconly()
         assert error_msg in exc.exconly()
 
 
+    def test_load_with_multiple_files_and_empty_override_v2(self):
+        base_file = config.ConfigFile(
+            'base.yml',
+            {'version': 2, 'services': {'web': {'image': 'example/web'}}})
+        override_file = config.ConfigFile('override.yml', None)
+        details = config.ConfigDetails('.', [base_file, override_file])
+
+        with pytest.raises(ConfigurationError) as exc:
+            config.load(details)
+        error_msg = "Top level object in 'override.yml' needs to be an object"
+        assert error_msg in exc.exconly()
+
     def test_load_with_multiple_files_and_empty_base(self):
     def test_load_with_multiple_files_and_empty_base(self):
         base_file = config.ConfigFile('base.yml', None)
         base_file = config.ConfigFile('base.yml', None)
         override_file = config.ConfigFile(
         override_file = config.ConfigFile(
@@ -297,6 +309,17 @@ class ConfigTest(unittest.TestCase):
             config.load(details)
             config.load(details)
         assert "Top level object in 'base.yml' needs to be an object" in exc.exconly()
         assert "Top level object in 'base.yml' needs to be an object" in exc.exconly()
 
 
+    def test_load_with_multiple_files_and_empty_base_v2(self):
+        base_file = config.ConfigFile('base.yml', None)
+        override_file = config.ConfigFile(
+            'override.tml',
+            {'version': 2, 'services': {'web': {'image': 'example/web'}}}
+        )
+        details = config.ConfigDetails('.', [base_file, override_file])
+        with pytest.raises(ConfigurationError) as exc:
+            config.load(details)
+        assert "Top level object in 'base.yml' needs to be an object" in exc.exconly()
+
     def test_load_with_multiple_files_and_extends_in_override_file(self):
     def test_load_with_multiple_files_and_extends_in_override_file(self):
         base_file = config.ConfigFile(
         base_file = config.ConfigFile(
             'base.yaml',
             'base.yaml',