瀏覽代碼

Manage encoding errors in progress_stream

Signed-off-by: Joffrey F <[email protected]>
Joffrey F 7 年之前
父節點
當前提交
fd94fab264
共有 2 個文件被更改,包括 51 次插入12 次删除
  1. 20 12
      compose/progress_stream.py
  2. 31 0
      tests/unit/progress_stream_test.py

+ 20 - 12
compose/progress_stream.py

@@ -8,6 +8,14 @@ class StreamOutputError(Exception):
     pass
 
 
+def write_to_stream(s, stream):
+    try:
+        stream.write(s)
+    except UnicodeEncodeError:
+        encoding = getattr(stream, 'encoding', 'ascii')
+        stream.write(s.encode(encoding, errors='replace').decode(encoding))
+
+
 def stream_output(output, stream):
     is_terminal = hasattr(stream, 'isatty') and stream.isatty()
     stream = utils.get_output_stream(stream)
@@ -34,18 +42,18 @@ def stream_output(output, stream):
 
         if image_id not in lines:
             lines[image_id] = len(lines)
-            stream.write("\n")
+            write_to_stream("\n", stream)
 
         diff = len(lines) - lines[image_id]
 
         # move cursor up `diff` rows
-        stream.write("%c[%dA" % (27, diff))
+        write_to_stream("%c[%dA" % (27, diff), stream)
 
         print_output_event(event, stream, is_terminal)
 
         if 'id' in event:
             # move cursor back down
-            stream.write("%c[%dB" % (27, diff))
+            write_to_stream("%c[%dB" % (27, diff), stream)
 
         stream.flush()
 
@@ -60,36 +68,36 @@ def print_output_event(event, stream, is_terminal):
 
     if is_terminal and 'stream' not in event:
         # erase current line
-        stream.write("%c[2K\r" % 27)
+        write_to_stream("%c[2K\r" % 27, stream)
         terminator = "\r"
     elif 'progressDetail' in event:
         return
 
     if 'time' in event:
-        stream.write("[%s] " % event['time'])
+        write_to_stream("[%s] " % event['time'], stream)
 
     if 'id' in event:
-        stream.write("%s: " % event['id'])
+        write_to_stream("%s: " % event['id'], stream)
 
     if 'from' in event:
-        stream.write("(from %s) " % event['from'])
+        write_to_stream("(from %s) " % event['from'], stream)
 
     status = event.get('status', '')
 
     if 'progress' in event:
-        stream.write("%s %s%s" % (status, event['progress'], terminator))
+        write_to_stream("%s %s%s" % (status, event['progress'], terminator), stream)
     elif 'progressDetail' in event:
         detail = event['progressDetail']
         total = detail.get('total')
         if 'current' in detail and total:
             percentage = float(detail['current']) / float(total) * 100
-            stream.write('%s (%.1f%%)%s' % (status, percentage, terminator))
+            write_to_stream('%s (%.1f%%)%s' % (status, percentage, terminator), stream)
         else:
-            stream.write('%s%s' % (status, terminator))
+            write_to_stream('%s%s' % (status, terminator), stream)
     elif 'stream' in event:
-        stream.write("%s%s" % (event['stream'], terminator))
+        write_to_stream("%s%s" % (event['stream'], terminator), stream)
     else:
-        stream.write("%s%s\n" % (status, terminator))
+        write_to_stream("%s%s\n" % (status, terminator), stream)
 
 
 def get_digest_from_pull(events):

+ 31 - 0
tests/unit/progress_stream_test.py

@@ -1,6 +1,13 @@
+# ~*~ encoding: utf-8 ~*~
 from __future__ import absolute_import
 from __future__ import unicode_literals
 
+import io
+import os
+import random
+import shutil
+import tempfile
+
 from six import StringIO
 
 from compose import progress_stream
@@ -66,6 +73,30 @@ class ProgressStreamTestCase(unittest.TestCase):
         events = progress_stream.stream_output(events, output)
         assert len(output.getvalue()) > 0
 
+    def test_mismatched_encoding_stream_write(self):
+        tmpdir = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, tmpdir, True)
+
+        def mktempfile(encoding):
+            fname = os.path.join(tmpdir, hex(random.getrandbits(128))[2:-1])
+            return io.open(fname, mode='w+', encoding=encoding)
+
+        text = '就吃饭'
+        with mktempfile(encoding='utf-8') as tf:
+            progress_stream.write_to_stream(text, tf)
+            tf.seek(0)
+            assert tf.read() == text
+
+        with mktempfile(encoding='utf-32') as tf:
+            progress_stream.write_to_stream(text, tf)
+            tf.seek(0)
+            assert tf.read() == text
+
+        with mktempfile(encoding='ascii') as tf:
+            progress_stream.write_to_stream(text, tf)
+            tf.seek(0)
+            assert tf.read() == '???'
+
 
 def test_get_digest_from_push():
     digest = "sha256:abcd"