parallel_test.py

from __future__ import absolute_import
from __future__ import unicode_literals

import six
from docker.errors import APIError

from compose.parallel import parallel_execute

web = 'web'
db = 'db'
data_volume = 'data_volume'
cache = 'cache'

objects = [web, db, data_volume, cache]

# Dependency graph: web depends on db and cache, db depends on data_volume.
deps = {
    web: [db, cache],
    db: [data_volume],
    data_volume: [],
    cache: [],
}


def get_deps(obj):
    return deps[obj]


def test_parallel_execute():
    results = parallel_execute(
        objects=[1, 2, 3, 4, 5],
        func=lambda x: x * 2,
        get_name=six.text_type,
        msg="Doubling",
    )

    assert sorted(results) == [2, 4, 6, 8, 10]


def test_parallel_execute_with_deps():
    log = []

    def process(x):
        log.append(x)

    parallel_execute(
        objects=objects,
        func=process,
        get_name=lambda obj: obj,
        msg="Processing",
        get_deps=get_deps,
    )

    # Every object is processed, and each dependency runs before its dependents.
    assert sorted(log) == sorted(objects)
    assert log.index(data_volume) < log.index(db)
    assert log.index(db) < log.index(web)
    assert log.index(cache) < log.index(web)


def test_parallel_execute_with_upstream_errors():
    log = []

    def process(x):
        if x is data_volume:
            raise APIError(None, None, "Something went wrong")
        log.append(x)

    parallel_execute(
        objects=objects,
        func=process,
        get_name=lambda obj: obj,
        msg="Processing",
        get_deps=get_deps,
    )

    # data_volume fails, so its downstream dependents db and web are skipped;
    # only cache, which has no failed dependencies, gets processed.
    assert log == [cache]
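

# --- Illustrative sketch (an assumption, not part of the original file) ---
# A minimal, hypothetical executor with roughly the interface the tests above
# exercise: it runs objects serially in dependency order, propagates failures
# to downstream dependents, and returns the collected results. It only
# illustrates the behaviour being asserted; it is NOT the real
# compose.parallel.parallel_execute implementation, and the name
# simple_dependency_execute is invented here. It reuses the APIError import
# from the top of this file.
def simple_dependency_execute(objects, func, get_name, msg, get_deps=None):
    get_deps = get_deps or (lambda obj: [])
    done, failed, results = set(), set(), []
    pending = list(objects)
    while pending:
        progressed = False
        for obj in list(pending):
            obj_deps = get_deps(obj)
            if any(dep in failed for dep in obj_deps):
                failed.add(obj)      # an upstream dependency failed: skip
            elif all(dep in done for dep in obj_deps):
                try:
                    results.append(func(obj))
                    done.add(obj)
                except APIError:
                    failed.add(obj)
            else:
                continue             # dependencies not finished yet
            pending.remove(obj)
            progressed = True
        if not progressed:           # unresolved or cyclic dependencies
            break
    return results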