Răsfoiți Sursa

Limit occurrences of creating an environment object.
.env file is always read from the project_dir

Signed-off-by: Joffrey F <[email protected]>

Joffrey F 9 ani în urmă
părinte
comite
36f1b4589c

+ 14 - 9
compose/cli/command.py

@@ -20,22 +20,23 @@ log = logging.getLogger(__name__)
 
 
 def project_from_options(project_dir, options):
+    environment = Environment.from_env_file(project_dir)
     return get_project(
         project_dir,
-        get_config_path_from_options(project_dir, options),
+        get_config_path_from_options(project_dir, options, environment),
         project_name=options.get('--project-name'),
         verbose=options.get('--verbose'),
         host=options.get('--host'),
         tls_config=tls_config_from_options(options),
+        environment=environment
     )
 
 
-def get_config_path_from_options(base_dir, options):
+def get_config_path_from_options(base_dir, options, environment):
     file_option = options.get('--file')
     if file_option:
         return file_option
 
-    environment = Environment.from_env_file(base_dir)
     config_files = environment.get('COMPOSE_FILE')
     if config_files:
         return config_files.split(os.pathsep)
@@ -55,11 +56,14 @@ def get_client(verbose=False, version=None, tls_config=None, host=None):
 
 
 def get_project(project_dir, config_path=None, project_name=None, verbose=False,
-                host=None, tls_config=None):
-    config_details = config.find(project_dir, config_path)
-    project_name = get_project_name(config_details.working_dir, project_name)
+                host=None, tls_config=None, environment=None):
+    if not environment:
+        environment = Environment.from_env_file(project_dir)
+    config_details = config.find(project_dir, config_path, environment)
+    project_name = get_project_name(
+        config_details.working_dir, project_name, environment
+    )
     config_data = config.load(config_details)
-    environment = Environment.from_env_file(project_dir)
 
     api_version = environment.get(
         'COMPOSE_API_VERSION',
@@ -72,11 +76,12 @@ def get_project(project_dir, config_path=None, project_name=None, verbose=False,
     return Project.from_config(project_name, config_data, client)
 
 
-def get_project_name(working_dir, project_name=None):
+def get_project_name(working_dir, project_name=None, environment=None):
     def normalize_name(name):
         return re.sub(r'[^a-z0-9]', '', name.lower())
 
-    environment = Environment.from_env_file(working_dir)
+    if not environment:
+        environment = Environment.from_env_file(working_dir)
     project_name = project_name or environment.get('COMPOSE_PROJECT_NAME')
     if project_name:
         return normalize_name(project_name)

+ 8 - 2
compose/cli/main.py

@@ -17,6 +17,7 @@ from .. import __version__
 from ..config import config
 from ..config import ConfigurationError
 from ..config import parse_environment
+from ..config.environment import Environment
 from ..config.serialize import serialize_config
 from ..const import DEFAULT_TIMEOUT
 from ..const import IS_WINDOWS_PLATFORM
@@ -222,8 +223,13 @@ class TopLevelCommand(object):
             --services      Print the service names, one per line.
 
         """
-        config_path = get_config_path_from_options(self.project_dir, config_options)
-        compose_config = config.load(config.find(self.project_dir, config_path))
+        environment = Environment.from_env_file(self.project_dir)
+        config_path = get_config_path_from_options(
+            self.project_dir, config_options, environment
+        )
+        compose_config = config.load(
+            config.find(self.project_dir, config_path, environment)
+        )
 
         if options['--quiet']:
             return

+ 7 - 7
compose/config/config.py

@@ -124,13 +124,11 @@ class ConfigDetails(namedtuple('_ConfigDetails', 'working_dir config_files envir
     :param environment: computed environment values for this project
     :type  environment: :class:`environment.Environment`
      """
-
-    def __new__(cls, working_dir, config_files):
+    def __new__(cls, working_dir, config_files, environment=None):
+        if environment is None:
+            environment = Environment.from_env_file(working_dir)
         return super(ConfigDetails, cls).__new__(
-            cls,
-            working_dir,
-            config_files,
-            Environment.from_env_file(working_dir),
+            cls, working_dir, config_files, environment
         )
 
 
@@ -219,11 +217,12 @@ class ServiceConfig(namedtuple('_ServiceConfig', 'working_dir filename name conf
             config)
 
 
-def find(base_dir, filenames):
+def find(base_dir, filenames, environment):
     if filenames == ['-']:
         return ConfigDetails(
             os.getcwd(),
             [ConfigFile(None, yaml.safe_load(sys.stdin))],
+            environment
         )
 
     if filenames:
@@ -235,6 +234,7 @@ def find(base_dir, filenames):
     return ConfigDetails(
         os.path.dirname(filenames[0]),
         [ConfigFile.from_filename(f) for f in filenames],
+        environment
     )
 
 

+ 2 - 1
tests/helpers.py

@@ -13,4 +13,5 @@ def build_config(contents, **kwargs):
 def build_config_details(contents, working_dir='working_dir', filename='filename.yml'):
     return ConfigDetails(
         working_dir,
-        [ConfigFile(filename, contents)])
+        [ConfigFile(filename, contents)],
+    )

+ 15 - 5
tests/unit/cli/command_test.py

@@ -6,6 +6,7 @@ import os
 import pytest
 
 from compose.cli.command import get_config_path_from_options
+from compose.config.environment import Environment
 from compose.const import IS_WINDOWS_PLATFORM
 from tests import mock
 
@@ -15,24 +16,33 @@ class TestGetConfigPathFromOptions(object):
     def test_path_from_options(self):
         paths = ['one.yml', 'two.yml']
         opts = {'--file': paths}
-        assert get_config_path_from_options('.', opts) == paths
+        environment = Environment.from_env_file('.')
+        assert get_config_path_from_options('.', opts, environment) == paths
 
     def test_single_path_from_env(self):
         with mock.patch.dict(os.environ):
             os.environ['COMPOSE_FILE'] = 'one.yml'
-            assert get_config_path_from_options('.', {}) == ['one.yml']
+            environment = Environment.from_env_file('.')
+            assert get_config_path_from_options('.', {}, environment) == ['one.yml']
 
     @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason='posix separator')
     def test_multiple_path_from_env(self):
         with mock.patch.dict(os.environ):
             os.environ['COMPOSE_FILE'] = 'one.yml:two.yml'
-            assert get_config_path_from_options('.', {}) == ['one.yml', 'two.yml']
+            environment = Environment.from_env_file('.')
+            assert get_config_path_from_options(
+                '.', {}, environment
+            ) == ['one.yml', 'two.yml']
 
     @pytest.mark.skipif(not IS_WINDOWS_PLATFORM, reason='windows separator')
     def test_multiple_path_from_env_windows(self):
         with mock.patch.dict(os.environ):
             os.environ['COMPOSE_FILE'] = 'one.yml;two.yml'
-            assert get_config_path_from_options('.', {}) == ['one.yml', 'two.yml']
+            environment = Environment.from_env_file('.')
+            assert get_config_path_from_options(
+                '.', {}, environment
+            ) == ['one.yml', 'two.yml']
 
     def test_no_path(self):
-        assert not get_config_path_from_options('.', {})
+        environment = Environment.from_env_file('.')
+        assert not get_config_path_from_options('.', {}, environment)

+ 11 - 3
tests/unit/config/config_test.py

@@ -1584,8 +1584,11 @@ class PortsTest(unittest.TestCase):
 class InterpolationTest(unittest.TestCase):
     @mock.patch.dict(os.environ)
     def test_config_file_with_environment_file(self):
+        project_dir = 'tests/fixtures/default-env-file'
         service_dicts = config.load(
-            config.find('tests/fixtures/default-env-file', None)
+            config.find(
+                project_dir, None, Environment.from_env_file(project_dir)
+            )
         ).services
 
         self.assertEqual(service_dicts[0], {
@@ -1597,6 +1600,7 @@ class InterpolationTest(unittest.TestCase):
 
     @mock.patch.dict(os.environ)
     def test_config_file_with_environment_variable(self):
+        project_dir = 'tests/fixtures/environment-interpolation'
         os.environ.update(
             IMAGE="busybox",
             HOST_PORT="80",
@@ -1604,7 +1608,9 @@ class InterpolationTest(unittest.TestCase):
         )
 
         service_dicts = config.load(
-            config.find('tests/fixtures/environment-interpolation', None),
+            config.find(
+                project_dir, None, Environment.from_env_file(project_dir)
+            )
         ).services
 
         self.assertEqual(service_dicts, [
@@ -2149,7 +2155,9 @@ class EnvTest(unittest.TestCase):
 
 
 def load_from_filename(filename):
-    return config.load(config.find('.', [filename])).services
+    return config.load(
+        config.find('.', [filename], Environment.from_env_file('.'))
+    ).services
 
 
 class ExtendsTest(unittest.TestCase):