소스 검색

Avoid encoding crash in log_api_error

Signed-off-by: Joffrey F <[email protected]>
Joffrey F 8 년 전
부모
커밋
c3bcd59aeb
2개의 변경된 파일24개의 추가작업 그리고 3개의 파일을 삭제
  1. 8 3
      compose/cli/errors.py
  2. 16 0
      tests/unit/cli/errors_test.py

+ 8 - 3
compose/cli/errors.py

@@ -7,6 +7,7 @@ import socket
 from distutils.spawn import find_executable
 from textwrap import dedent
 
+import six
 from docker.errors import APIError
 from requests.exceptions import ConnectionError as RequestsConnectionError
 from requests.exceptions import ReadTimeout
@@ -68,14 +69,18 @@ def log_timeout_error(timeout):
 
 
 def log_api_error(e, client_version):
-    if b'client is newer than server' not in e.explanation:
-        log.error(e.explanation)
+    explanation = e.explanation
+    if isinstance(explanation, six.binary_type):
+        explanation = explanation.decode('utf-8')
+
+    if 'client is newer than server' not in explanation:
+        log.error(explanation)
         return
 
     version = API_VERSION_TO_ENGINE_VERSION.get(client_version)
     if not version:
         # They've set a custom API version
-        log.error(e.explanation)
+        log.error(explanation)
         return
 
     log.error(

+ 16 - 0
tests/unit/cli/errors_test.py

@@ -42,10 +42,26 @@ class TestHandleConnectionErrors(object):
         _, args, _ = mock_logging.error.mock_calls[0]
         assert "Docker Engine of version 1.10.0 or greater" in args[0]
 
+    def test_api_error_version_mismatch_unicode_explanation(self, mock_logging):
+        with pytest.raises(errors.ConnectionError):
+            with handle_connection_errors(mock.Mock(api_version='1.22')):
+                raise APIError(None, None, u"client is newer than server")
+
+        _, args, _ = mock_logging.error.mock_calls[0]
+        assert "Docker Engine of version 1.10.0 or greater" in args[0]
+
     def test_api_error_version_other(self, mock_logging):
         msg = b"Something broke!"
         with pytest.raises(errors.ConnectionError):
             with handle_connection_errors(mock.Mock(api_version='1.22')):
                 raise APIError(None, None, msg)
 
+        mock_logging.error.assert_called_once_with(msg.decode('utf-8'))
+
+    def test_api_error_version_other_unicode_explanation(self, mock_logging):
+        msg = u"Something broke!"
+        with pytest.raises(errors.ConnectionError):
+            with handle_connection_errors(mock.Mock(api_version='1.22')):
+                raise APIError(None, None, msg)
+
         mock_logging.error.assert_called_once_with(msg)