Ver código fonte

Implement environment singleton to be accessed throughout the code

Load and parse environment file from working dir

Signed-off-by: Joffrey F <[email protected]>
Joffrey F 9 anos atrás
pai
commit
c69d8a3bd2

+ 8 - 5
compose/cli/command.py

@@ -21,7 +21,7 @@ log = logging.getLogger(__name__)
 def project_from_options(project_dir, options):
     return get_project(
         project_dir,
-        get_config_path_from_options(options),
+        get_config_path_from_options(project_dir, options),
         project_name=options.get('--project-name'),
         verbose=options.get('--verbose'),
         host=options.get('--host'),
@@ -29,12 +29,13 @@ def project_from_options(project_dir, options):
     )
 
 
-def get_config_path_from_options(options):
+def get_config_path_from_options(base_dir, options):
     file_option = options.get('--file')
     if file_option:
         return file_option
 
-    config_files = os.environ.get('COMPOSE_FILE')
+    environment = config.environment.get_instance(base_dir)
+    config_files = environment.get('COMPOSE_FILE')
     if config_files:
         return config_files.split(os.pathsep)
     return None
@@ -57,8 +58,9 @@ def get_project(project_dir, config_path=None, project_name=None, verbose=False,
     config_details = config.find(project_dir, config_path)
     project_name = get_project_name(config_details.working_dir, project_name)
     config_data = config.load(config_details)
+    environment = config.environment.get_instance(project_dir)
 
-    api_version = os.environ.get(
+    api_version = environment.get(
         'COMPOSE_API_VERSION',
         API_VERSIONS[config_data.version])
     client = get_client(
@@ -73,7 +75,8 @@ def get_project_name(working_dir, project_name=None):
     def normalize_name(name):
         return re.sub(r'[^a-z0-9]', '', name.lower())
 
-    project_name = project_name or os.environ.get('COMPOSE_PROJECT_NAME')
+    environment = config.environment.get_instance(working_dir)
+    project_name = project_name or environment.get('COMPOSE_PROJECT_NAME')
     if project_name:
         return normalize_name(project_name)
 

+ 1 - 1
compose/cli/main.py

@@ -222,7 +222,7 @@ class TopLevelCommand(object):
             --services      Print the service names, one per line.
 
         """
-        config_path = get_config_path_from_options(config_options)
+        config_path = get_config_path_from_options(self.project_dir, config_options)
         compose_config = config.load(config.find(self.project_dir, config_path))
 
         if options['--quiet']:

+ 1 - 0
compose/config/__init__.py

@@ -2,6 +2,7 @@
 from __future__ import absolute_import
 from __future__ import unicode_literals
 
+from . import environment
 from .config import ConfigurationError
 from .config import DOCKER_CONFIG_KEYS
 from .config import find

+ 27 - 13
compose/config/config.py

@@ -17,6 +17,7 @@ from cached_property import cached_property
 from ..const import COMPOSEFILE_V1 as V1
 from ..const import COMPOSEFILE_V2_0 as V2_0
 from ..utils import build_string_dict
+from .environment import Environment
 from .errors import CircularReference
 from .errors import ComposeFileNotFound
 from .errors import ConfigurationError
@@ -211,7 +212,8 @@ def find(base_dir, filenames):
     if filenames == ['-']:
         return ConfigDetails(
             os.getcwd(),
-            [ConfigFile(None, yaml.safe_load(sys.stdin))])
+            [ConfigFile(None, yaml.safe_load(sys.stdin))],
+        )
 
     if filenames:
         filenames = [os.path.join(base_dir, f) for f in filenames]
@@ -221,7 +223,8 @@ def find(base_dir, filenames):
     log.debug("Using configuration files: {}".format(",".join(filenames)))
     return ConfigDetails(
         os.path.dirname(filenames[0]),
-        [ConfigFile.from_filename(f) for f in filenames])
+        [ConfigFile.from_filename(f) for f in filenames],
+    )
 
 
 def validate_config_version(config_files):
@@ -288,6 +291,10 @@ def load(config_details):
     """
     validate_config_version(config_details.config_files)
 
+    # load environment in working dir for later use in interpolation
+    # it is done here to avoid having to pass down working_dir
+    Environment.get_instance(config_details.working_dir)
+
     processed_files = [
         process_config_file(config_file)
         for config_file in config_details.config_files
@@ -302,9 +309,8 @@ def load(config_details):
         config_details.config_files, 'get_networks', 'Network'
     )
     service_dicts = load_services(
-        config_details.working_dir,
-        main_file,
-        [file.get_service_dicts() for file in config_details.config_files])
+        config_details, main_file,
+    )
 
     if main_file.version != V1:
         for service_dict in service_dicts:
@@ -348,14 +354,16 @@ def load_mapping(config_files, get_func, entity_type):
     return mapping
 
 
-def load_services(working_dir, config_file, service_configs):
+def load_services(config_details, config_file):
     def build_service(service_name, service_dict, service_names):
         service_config = ServiceConfig.with_abs_paths(
-            working_dir,
+            config_details.working_dir,
             config_file.filename,
             service_name,
             service_dict)
-        resolver = ServiceExtendsResolver(service_config, config_file)
+        resolver = ServiceExtendsResolver(
+            service_config, config_file
+        )
         service_dict = process_service(resolver.run())
 
         service_config = service_config._replace(config=service_dict)
@@ -383,6 +391,10 @@ def load_services(working_dir, config_file, service_configs):
             for name in all_service_names
         }
 
+    service_configs = [
+        file.get_service_dicts() for file in config_details.config_files
+    ]
+
     service_config = service_configs[0]
     for next_config in service_configs[1:]:
         service_config = merge_services(service_config, next_config)
@@ -462,8 +474,8 @@ class ServiceExtendsResolver(object):
         extends_file = ConfigFile.from_filename(config_path)
         validate_config_version([self.config_file, extends_file])
         extended_file = process_config_file(
-            extends_file,
-            service_name=service_name)
+            extends_file, service_name=service_name
+        )
         service_config = extended_file.get_service(service_name)
 
         return config_path, service_config, service_name
@@ -476,7 +488,8 @@ class ServiceExtendsResolver(object):
                 service_name,
                 service_dict),
             self.config_file,
-            already_seen=self.already_seen + [self.signature])
+            already_seen=self.already_seen + [self.signature],
+        )
 
         service_config = resolver.run()
         other_service_dict = process_service(service_config)
@@ -824,10 +837,11 @@ def parse_ulimits(ulimits):
 
 
 def resolve_env_var(key, val):
+    environment = Environment.get_instance()
     if val is not None:
         return key, val
-    elif key in os.environ:
-        return key, os.environ[key]
+    elif key in environment:
+        return key, environment[key]
     else:
         return key, None
 

+ 69 - 0
compose/config/environment.py

@@ -0,0 +1,69 @@
+from __future__ import absolute_import
+from __future__ import unicode_literals
+
+import logging
+import os
+
+from .errors import ConfigurationError
+
+log = logging.getLogger(__name__)
+
+
+class BlankDefaultDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(BlankDefaultDict, self).__init__(*args, **kwargs)
+        self.missing_keys = []
+
+    def __getitem__(self, key):
+        try:
+            return super(BlankDefaultDict, self).__getitem__(key)
+        except KeyError:
+            if key not in self.missing_keys:
+                log.warn(
+                    "The {} variable is not set. Defaulting to a blank string."
+                    .format(key)
+                )
+                self.missing_keys.append(key)
+
+            return ""
+
+
+class Environment(BlankDefaultDict):
+    __instance = None
+
+    @classmethod
+    def get_instance(cls, base_dir='.'):
+        if cls.__instance:
+            return cls.__instance
+
+        instance = cls(base_dir)
+        cls.__instance = instance
+        return instance
+
+    @classmethod
+    def reset(cls):
+        cls.__instance = None
+
+    def __init__(self, base_dir):
+        super(Environment, self).__init__()
+        self.load_environment_file(os.path.join(base_dir, '.env'))
+        self.update(os.environ)
+
+    def load_environment_file(self, path):
+        if not os.path.exists(path):
+            return
+        mapping = {}
+        with open(path, 'r') as f:
+            for line in f.readlines():
+                line = line.strip()
+                if '=' not in line:
+                    raise ConfigurationError(
+                        'Invalid environment variable mapping in env file. '
+                        'Missing "=" in "{0}"'.format(line)
+                    )
+                mapping.__setitem__(*line.split('=', 1))
+        self.update(mapping)
+
+
+def get_instance(base_dir=None):
+    return Environment.get_instance(base_dir)

+ 2 - 21
compose/config/interpolation.py

@@ -2,17 +2,17 @@ from __future__ import absolute_import
 from __future__ import unicode_literals
 
 import logging
-import os
 from string import Template
 
 import six
 
+from .environment import Environment
 from .errors import ConfigurationError
 log = logging.getLogger(__name__)
 
 
 def interpolate_environment_variables(config, section):
-    mapping = BlankDefaultDict(os.environ)
+    mapping = Environment.get_instance()
 
     def process_item(name, config_dict):
         return dict(
@@ -60,25 +60,6 @@ def interpolate(string, mapping):
         raise InvalidInterpolation(string)
 
 
-class BlankDefaultDict(dict):
-    def __init__(self, *args, **kwargs):
-        super(BlankDefaultDict, self).__init__(*args, **kwargs)
-        self.missing_keys = []
-
-    def __getitem__(self, key):
-        try:
-            return super(BlankDefaultDict, self).__getitem__(key)
-        except KeyError:
-            if key not in self.missing_keys:
-                log.warn(
-                    "The {} variable is not set. Defaulting to a blank string."
-                    .format(key)
-                )
-                self.missing_keys.append(key)
-
-            return ""
-
-
 class InvalidInterpolation(Exception):
     def __init__(self, string):
         self.string = string

+ 2 - 2
tests/acceptance/cli_test.py

@@ -15,7 +15,7 @@ from operator import attrgetter
 import yaml
 from docker import errors
 
-from .. import mock
+from ..helpers import clear_environment
 from compose.cli.command import get_project
 from compose.container import Container
 from compose.project import OneOffFilter
@@ -1452,7 +1452,7 @@ class CLITestCase(DockerClientTestCase):
         self.assertEqual(len(containers), 1)
         self.assertIn("FOO=1", containers[0].get('Config.Env'))
 
-    @mock.patch.dict(os.environ)
+    @clear_environment
     def test_home_and_env_var_in_volume_path(self):
         os.environ['VOLUME_NAME'] = 'my-volume'
         os.environ['HOME'] = '/tmp/home-dir'

+ 13 - 0
tests/helpers.py

@@ -1,9 +1,14 @@
 from __future__ import absolute_import
 from __future__ import unicode_literals
 
+import functools
+import os
+
+from . import mock
 from compose.config.config import ConfigDetails
 from compose.config.config import ConfigFile
 from compose.config.config import load
+from compose.config.environment import Environment
 
 
 def build_config(contents, **kwargs):
@@ -14,3 +19,11 @@ def build_config_details(contents, working_dir='working_dir', filename='filename
     return ConfigDetails(
         working_dir,
         [ConfigFile(filename, contents)])
+
+
+def clear_environment(f):
+    @functools.wraps(f)
+    def wrapper(self, *args, **kwargs):
+        Environment.reset()
+        with mock.patch.dict(os.environ):
+            f(self, *args, **kwargs)

+ 2 - 1
tests/integration/service_test.py

@@ -12,6 +12,7 @@ from six import StringIO
 from six import text_type
 
 from .. import mock
+from ..helpers import clear_environment
 from .testcases import DockerClientTestCase
 from .testcases import get_links
 from .testcases import pull_busybox
@@ -912,7 +913,7 @@ class ServiceTest(DockerClientTestCase):
         }.items():
             self.assertEqual(env[k], v)
 
-    @mock.patch.dict(os.environ)
+    @clear_environment
     def test_resolve_env(self):
         os.environ['FILE_DEF'] = 'E1'
         os.environ['FILE_DEF_EMPTY'] = 'E2'

+ 4 - 3
tests/unit/cli_test.py

@@ -11,6 +11,7 @@ import pytest
 from .. import mock
 from .. import unittest
 from ..helpers import build_config
+from ..helpers import clear_environment
 from compose.cli.command import get_project
 from compose.cli.command import get_project_name
 from compose.cli.docopt_command import NoSuchCommand
@@ -43,11 +44,11 @@ class CLITestCase(unittest.TestCase):
         project_name = get_project_name(None, project_name=name)
         self.assertEquals('explicitprojectname', project_name)
 
+    @clear_environment
     def test_project_name_from_environment_new_var(self):
         name = 'namefromenv'
-        with mock.patch.dict(os.environ):
-            os.environ['COMPOSE_PROJECT_NAME'] = name
-            project_name = get_project_name(None)
+        os.environ['COMPOSE_PROJECT_NAME'] = name
+        project_name = get_project_name(None)
         self.assertEquals(project_name, name)
 
     def test_project_name_with_empty_environment_var(self):

+ 19 - 17
tests/unit/config/config_test.py

@@ -23,6 +23,7 @@ from compose.config.types import VolumeSpec
 from compose.const import IS_WINDOWS_PLATFORM
 from tests import mock
 from tests import unittest
+from tests.helpers import clear_environment
 
 DEFAULT_VERSION = V2_0
 
@@ -1581,7 +1582,7 @@ class PortsTest(unittest.TestCase):
 
 
 class InterpolationTest(unittest.TestCase):
-    @mock.patch.dict(os.environ)
+    @clear_environment
     def test_config_file_with_environment_variable(self):
         os.environ.update(
             IMAGE="busybox",
@@ -1604,7 +1605,7 @@ class InterpolationTest(unittest.TestCase):
             }
         ])
 
-    @mock.patch.dict(os.environ)
+    @clear_environment
     def test_unset_variable_produces_warning(self):
         os.environ.pop('FOO', None)
         os.environ.pop('BAR', None)
@@ -1628,7 +1629,7 @@ class InterpolationTest(unittest.TestCase):
             self.assertIn('BAR', warnings[0])
             self.assertIn('FOO', warnings[1])
 
-    @mock.patch.dict(os.environ)
+    @clear_environment
     def test_invalid_interpolation(self):
         with self.assertRaises(config.ConfigurationError) as cm:
             config.load(
@@ -1667,7 +1668,7 @@ class VolumeConfigTest(unittest.TestCase):
         d = make_service_dict('foo', {'build': '.', 'volumes': ['/data']}, working_dir='.')
         self.assertEqual(d['volumes'], ['/data'])
 
-    @mock.patch.dict(os.environ)
+    @clear_environment
     def test_volume_binding_with_environment_variable(self):
         os.environ['VOLUME_PATH'] = '/host/path'
 
@@ -1681,7 +1682,7 @@ class VolumeConfigTest(unittest.TestCase):
         self.assertEqual(d['volumes'], [VolumeSpec.parse('/host/path:/container/path')])
 
     @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason='posix paths')
-    @mock.patch.dict(os.environ)
+    @clear_environment
     def test_volume_binding_with_home(self):
         os.environ['HOME'] = '/home/user'
         d = make_service_dict('foo', {'build': '.', 'volumes': ['~:/container/path']}, working_dir='.')
@@ -1739,7 +1740,7 @@ class VolumeConfigTest(unittest.TestCase):
             working_dir='c:\\Users\\me\\myproject')
         self.assertEqual(d['volumes'], ['c:\\Users\\me\\otherproject:/data'])
 
-    @mock.patch.dict(os.environ)
+    @clear_environment
     def test_home_directory_with_driver_does_not_expand(self):
         os.environ['NAME'] = 'surprise!'
         d = make_service_dict('foo', {
@@ -2025,7 +2026,7 @@ class EnvTest(unittest.TestCase):
     def test_parse_environment_empty(self):
         self.assertEqual(config.parse_environment(None), {})
 
-    @mock.patch.dict(os.environ)
+    @clear_environment
     def test_resolve_environment(self):
         os.environ['FILE_DEF'] = 'E1'
         os.environ['FILE_DEF_EMPTY'] = 'E2'
@@ -2072,7 +2073,7 @@ class EnvTest(unittest.TestCase):
         assert 'Couldn\'t find env file' in exc.exconly()
         assert 'nonexistent.env' in exc.exconly()
 
-    @mock.patch.dict(os.environ)
+    @clear_environment
     def test_resolve_environment_from_env_file_with_empty_values(self):
         os.environ['FILE_DEF'] = 'E1'
         os.environ['FILE_DEF_EMPTY'] = 'E2'
@@ -2087,7 +2088,7 @@ class EnvTest(unittest.TestCase):
             },
         )
 
-    @mock.patch.dict(os.environ)
+    @clear_environment
     def test_resolve_build_args(self):
         os.environ['env_arg'] = 'value2'
 
@@ -2106,7 +2107,7 @@ class EnvTest(unittest.TestCase):
         )
 
     @pytest.mark.xfail(IS_WINDOWS_PLATFORM, reason='paths use slash')
-    @mock.patch.dict(os.environ)
+    @clear_environment
     def test_resolve_path(self):
         os.environ['HOSTENV'] = '/tmp'
         os.environ['CONTAINERENV'] = '/host/tmp'
@@ -2393,7 +2394,7 @@ class ExtendsTest(unittest.TestCase):
         assert 'net: container' in excinfo.exconly()
         assert 'cannot be extended' in excinfo.exconly()
 
-    @mock.patch.dict(os.environ)
+    @clear_environment
     def test_load_config_runs_interpolation_in_extended_service(self):
         os.environ.update(HOSTNAME_VALUE="penguin")
         expected_interpolated_value = "host-penguin"
@@ -2465,6 +2466,7 @@ class ExtendsTest(unittest.TestCase):
             },
         ]))
 
+    @clear_environment
     def test_extends_with_environment_and_env_files(self):
         tmpdir = py.test.ensuretemp('test_extends_with_environment')
         self.addCleanup(tmpdir.remove)
@@ -2520,12 +2522,12 @@ class ExtendsTest(unittest.TestCase):
                 },
             },
         ]
-        with mock.patch.dict(os.environ):
-            os.environ['SECRET'] = 'secret'
-            os.environ['THING'] = 'thing'
-            os.environ['COMMON_ENV_FILE'] = 'secret'
-            os.environ['TOP_ENV_FILE'] = 'secret'
-            config = load_from_filename(str(tmpdir.join('docker-compose.yml')))
+
+        os.environ['SECRET'] = 'secret'
+        os.environ['THING'] = 'thing'
+        os.environ['COMMON_ENV_FILE'] = 'secret'
+        os.environ['TOP_ENV_FILE'] = 'secret'
+        config = load_from_filename(str(tmpdir.join('docker-compose.yml')))
 
         assert config == expected
 

+ 2 - 0
tests/unit/config/interpolation_test.py

@@ -6,12 +6,14 @@ import os
 import mock
 import pytest
 
+from compose.config.environment import Environment
 from compose.config.interpolation import interpolate_environment_variables
 
 
 @pytest.yield_fixture
 def mock_env():
     with mock.patch.dict(os.environ):
+        Environment.reset()
         os.environ['USER'] = 'jenny'
         os.environ['FOO'] = 'bar'
         yield

+ 1 - 1
tests/unit/interpolation_test.py

@@ -3,7 +3,7 @@ from __future__ import unicode_literals
 
 import unittest
 
-from compose.config.interpolation import BlankDefaultDict as bddict
+from compose.config.environment import BlankDefaultDict as bddict
 from compose.config.interpolation import interpolate
 from compose.config.interpolation import InvalidInterpolation