Browse Source

Convert some cli tests to pytest.

Signed-off-by: Daniel Nephin <[email protected]>
Daniel Nephin 9 years ago
parent
commit
886328640f
1 changed files with 35 additions and 32 deletions
  1. 35 32
      tests/unit/cli/main_test.py

+ 35 - 32
tests/unit/cli/main_test.py

@@ -3,6 +3,8 @@ from __future__ import unicode_literals
 
 import logging
 
+import pytest
+
 from compose import container
 from compose.cli.errors import UserError
 from compose.cli.formatter import ConsoleWarningFormatter
@@ -11,7 +13,6 @@ from compose.cli.main import convergence_strategy_from_opts
 from compose.cli.main import setup_console_handler
 from compose.service import ConvergenceStrategy
 from tests import mock
-from tests import unittest
 
 
 def mock_container(service, number):
@@ -22,7 +23,14 @@ def mock_container(service, number):
         name_without_project='{0}_{1}'.format(service, number))
 
 
-class CLIMainTestCase(unittest.TestCase):
[email protected]
+def logging_handler():
+    stream = mock.Mock()
+    stream.isatty.return_value = True
+    return logging.StreamHandler(stream=stream)
+
+
+class TestCLIMainTestCase(object):
 
     def test_build_log_printer(self):
         containers = [
@@ -34,7 +42,7 @@ class CLIMainTestCase(unittest.TestCase):
         ]
         service_names = ['web', 'db']
         log_printer = build_log_printer(containers, service_names, True, False, {'follow': True})
-        self.assertEqual(log_printer.containers, containers[:3])
+        assert log_printer.containers == containers[:3]
 
     def test_build_log_printer_all_services(self):
         containers = [
@@ -44,58 +52,53 @@ class CLIMainTestCase(unittest.TestCase):
         ]
         service_names = []
         log_printer = build_log_printer(containers, service_names, True, False, {'follow': True})
-        self.assertEqual(log_printer.containers, containers)
-
+        assert log_printer.containers == containers
 
-class SetupConsoleHandlerTestCase(unittest.TestCase):
 
-    def setUp(self):
-        self.stream = mock.Mock()
-        self.stream.isatty.return_value = True
-        self.handler = logging.StreamHandler(stream=self.stream)
+class TestSetupConsoleHandlerTestCase(object):
 
-    def test_with_tty_verbose(self):
-        setup_console_handler(self.handler, True)
-        assert type(self.handler.formatter) == ConsoleWarningFormatter
-        assert '%(name)s' in self.handler.formatter._fmt
-        assert '%(funcName)s' in self.handler.formatter._fmt
+    def test_with_tty_verbose(self, logging_handler):
+        setup_console_handler(logging_handler, True)
+        assert type(logging_handler.formatter) == ConsoleWarningFormatter
+        assert '%(name)s' in logging_handler.formatter._fmt
+        assert '%(funcName)s' in logging_handler.formatter._fmt
 
-    def test_with_tty_not_verbose(self):
-        setup_console_handler(self.handler, False)
-        assert type(self.handler.formatter) == ConsoleWarningFormatter
-        assert '%(name)s' not in self.handler.formatter._fmt
-        assert '%(funcName)s' not in self.handler.formatter._fmt
+    def test_with_tty_not_verbose(self, logging_handler):
+        setup_console_handler(logging_handler, False)
+        assert type(logging_handler.formatter) == ConsoleWarningFormatter
+        assert '%(name)s' not in logging_handler.formatter._fmt
+        assert '%(funcName)s' not in logging_handler.formatter._fmt
 
-    def test_with_not_a_tty(self):
-        self.stream.isatty.return_value = False
-        setup_console_handler(self.handler, False)
-        assert type(self.handler.formatter) == logging.Formatter
+    def test_with_not_a_tty(self, logging_handler):
+        logging_handler.stream.isatty.return_value = False
+        setup_console_handler(logging_handler, False)
+        assert type(logging_handler.formatter) == logging.Formatter
 
 
-class ConvergeStrategyFromOptsTestCase(unittest.TestCase):
+class TestConvergeStrategyFromOptsTestCase(object):
 
     def test_invalid_opts(self):
         options = {'--force-recreate': True, '--no-recreate': True}
-        with self.assertRaises(UserError):
+        with pytest.raises(UserError):
             convergence_strategy_from_opts(options)
 
     def test_always(self):
         options = {'--force-recreate': True, '--no-recreate': False}
-        self.assertEqual(
-            convergence_strategy_from_opts(options),
+        assert (
+            convergence_strategy_from_opts(options) ==
             ConvergenceStrategy.always
         )
 
     def test_never(self):
         options = {'--force-recreate': False, '--no-recreate': True}
-        self.assertEqual(
-            convergence_strategy_from_opts(options),
+        assert (
+            convergence_strategy_from_opts(options) ==
             ConvergenceStrategy.never
         )
 
     def test_changed(self):
         options = {'--force-recreate': False, '--no-recreate': False}
-        self.assertEqual(
-            convergence_strategy_from_opts(options),
+        assert (
+            convergence_strategy_from_opts(options) ==
             ConvergenceStrategy.changed
         )