Browse Source

Merge pull request #1521 from dano/validate-service-names

Validate that service names passed to Project.containers aren't bogus.
Aanand Prasad 10 years ago
parent
commit
bc14c473c9
2 changed files with 17 additions and 0 deletions
  1. 12 0
      compose/project.py
  2. 5 0
      tests/integration/cli_test.py

+ 12 - 0
compose/project.py

@@ -99,6 +99,16 @@ class Project(object):
 
         raise NoSuchService(name)
 
+    def validate_service_names(self, service_names):
+        """
+        Validate that the given list of service names only contains valid
+        services. Raises NoSuchService if one of the names is invalid.
+        """
+        valid_names = self.service_names
+        for name in service_names:
+            if name not in valid_names:
+                raise NoSuchService(name)
+
     def get_services(self, service_names=None, include_deps=False):
         """
         Returns a list of this project's services filtered
@@ -276,6 +286,8 @@ class Project(object):
             service.remove_stopped(**options)
 
     def containers(self, service_names=None, stopped=False, one_off=False):
+        if service_names:
+            self.validate_service_names(service_names)
         containers = [
             Container.from_ps(self.client, container)
             for container in self.client.containers(

+ 5 - 0
tests/integration/cli_test.py

@@ -9,6 +9,7 @@ from mock import patch
 
 from .testcases import DockerClientTestCase
 from compose.cli.main import TopLevelCommand
+from compose.project import NoSuchService
 
 
 class CLITestCase(DockerClientTestCase):
@@ -362,6 +363,10 @@ class CLITestCase(DockerClientTestCase):
         self.assertEqual(len(service.containers(stopped=True)), 1)
         self.assertFalse(service.containers(stopped=True)[0].is_running)
 
+    def test_logs_invalid_service_name(self):
+        with self.assertRaises(NoSuchService):
+            self.command.dispatch(['logs', 'madeupname'], None)
+
     def test_kill(self):
         self.command.dispatch(['up', '-d'], None)
         service = self.project.get_service('simple')