parallel_test.py

import unittest
from threading import Lock

from docker.errors import APIError

from compose.parallel import GlobalLimit
from compose.parallel import parallel_execute
from compose.parallel import parallel_execute_iter
from compose.parallel import ParallelStreamWriter
from compose.parallel import UpstreamError

web = 'web'
db = 'db'
data_volume = 'data_volume'
cache = 'cache'

objects = [web, db, data_volume, cache]

deps = {
    web: [db, cache],
    db: [data_volume],
    data_volume: [],
    cache: [],
}


def get_deps(obj):
    return [(dep, None) for dep in deps[obj]]
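
# Dependency graph exercised by the tests below:
#
#     web -> db -> data_volume
#     web -> cache
#
# get_deps returns each dependency paired with None; the second element of
# the pair is unused by these tests.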


class ParallelTest(unittest.TestCase):

    def test_parallel_execute(self):
        results, errors = parallel_execute(
            objects=[1, 2, 3, 4, 5],
            func=lambda x: x * 2,
            get_name=str,
            msg="Doubling",
        )

        assert sorted(results) == [2, 4, 6, 8, 10]
        assert errors == {}

    def test_parallel_execute_with_limit(self):
        limit = 1
        tasks = 20
        lock = Lock()

        def f(obj):
            locked = lock.acquire(False)
            # with limit=1 only one task runs at a time, so the lock must
            # always be free when we try to take it
            assert locked
            lock.release()
            return None

        results, errors = parallel_execute(
            objects=list(range(tasks)),
            func=f,
            get_name=str,
            msg="Testing",
            limit=limit,
        )

        assert results == tasks * [None]
        assert errors == {}

    def test_parallel_execute_with_global_limit(self):
        GlobalLimit.set_global_limit(1)
        self.addCleanup(GlobalLimit.set_global_limit, None)
        tasks = 20
        lock = Lock()

        def f(obj):
            locked = lock.acquire(False)
            # the global limit of 1 allows only one task at a time, so the
            # lock must always be free when we try to take it
            assert locked
            lock.release()
            return None

        results, errors = parallel_execute(
            objects=list(range(tasks)),
            func=f,
            get_name=str,
            msg="Testing",
        )

        assert results == tasks * [None]
        assert errors == {}
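
    # Note that, unlike test_parallel_execute_with_limit, no limit kwarg is
    # passed to parallel_execute here: serialization comes entirely from
    # GlobalLimit.set_global_limit(1), which applies across calls.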

    def test_parallel_execute_with_deps(self):
        log = []

        def process(x):
            log.append(x)

        parallel_execute(
            objects=objects,
            func=process,
            get_name=lambda obj: obj,
            msg="Processing",
            get_deps=get_deps,
        )

        assert sorted(log) == sorted(objects)

        # each dependency must be processed before its dependents
        assert log.index(data_volume) < log.index(db)
        assert log.index(db) < log.index(web)
        assert log.index(cache) < log.index(web)

    def test_parallel_execute_with_upstream_errors(self):
        log = []

        def process(x):
            if x is data_volume:
                raise APIError(None, None, "Something went wrong")
            log.append(x)

        parallel_execute(
            objects=objects,
            func=process,
            get_name=lambda obj: obj,
            msg="Processing",
            get_deps=get_deps,
        )

        # cache is the only object processed successfully: data_volume fails
        # outright, and db and web are skipped because a dependency failed
        assert log == [cache]

        events = [
            (obj, result, type(exception))
            for obj, result, exception
            in parallel_execute_iter(objects, process, get_deps, None)
        ]

        assert (cache, None, type(None)) in events
        assert (data_volume, None, APIError) in events
        assert (db, None, UpstreamError) in events
        assert (web, None, UpstreamError) in events
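

# The remaining tests are plain pytest functions rather than ParallelTest
# methods: they rely on pytest's capsys fixture to capture the progress
# output that parallel_execute writes to stderr.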
def test_parallel_execute_alignment(capsys):
    # reset the shared writer so this test starts with fresh output
    ParallelStreamWriter.instance = None
    results, errors = parallel_execute(
        objects=["short", "a very long name"],
        func=lambda x: x,
        get_name=str,
        msg="Aligning",
    )
    assert errors == {}

    # both status lines should have their '...' column aligned
    _, err = capsys.readouterr()
    a, b = err.split('\n')[:2]
    assert a.index('...') == b.index('...')


def test_parallel_execute_ansi(capsys):
    ParallelStreamWriter.instance = None
    ParallelStreamWriter.set_noansi(value=False)
    results, errors = parallel_execute(
        objects=["something", "something more"],
        func=lambda x: x,
        get_name=str,
        msg="Control characters",
    )
    assert errors == {}

    # with ANSI enabled, the output should contain escape sequences
    _, err = capsys.readouterr()
    assert "\x1b" in err


def test_parallel_execute_noansi(capsys):
    ParallelStreamWriter.instance = None
    ParallelStreamWriter.set_noansi()
    results, errors = parallel_execute(
        objects=["something", "something more"],
        func=lambda x: x,
        get_name=str,
        msg="Control characters",
    )
    assert errors == {}

    # with ANSI disabled, no escape sequences should be emitted
    _, err = capsys.readouterr()
    assert "\x1b" not in err