sandbox_tests.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
import atexit
import os
import shutil
import tempfile
import unittest

import requests

from sandbox import (
    Error,
    ExecuteResponse,
    ExecutionEventTypeDisplayData,
    ExecutionEventTypeError,
    ExecutionEventTypeStream,
    ExecutionStatusOK,
    ExecutionStatusTimeout,
)
  15. # We'll create a temporary directory for the tests to avoid any side effects.
  16. temp_dir = tempfile.mkdtemp()
  17. BASE_URL = "http://localhost:8888/"
  18. def url(path: str) -> str:
  19. return BASE_URL + path
  20. class TestExecuteHandler(unittest.TestCase):
  21. def must_bind_with_execute_response(self, r: requests.Response) -> ExecuteResponse:
  22. self.assertEqual(r.status_code, 200)
  23. return ExecuteResponse.model_validate_json(r.content)
  24. def must_bind_with_error(self, r: requests.Response) -> Error:
  25. return Error.model_validate_json(r.content)
  26. def test_execute_hello(self):
  27. r = requests.post(
  28. url("execute"), json={"code": "print('hello')", "timeout_secs": 10}
  29. )
  30. res = self.must_bind_with_execute_response(r)
  31. self.assertEqual(len(res.events), 1)
  32. self.assertEqual(res.events[0].type, ExecutionEventTypeStream)
  33. self.assertEqual(res.events[0].data.stream, "stdout") # type: ignore
  34. self.assertEqual(res.events[0].data.text, "hello\n") # type: ignore
  35. def test_execute_timeout(self):
  36. r = requests.post(
  37. url("execute"),
  38. json={"code": "import time\ntime.sleep(5)", "timeout_secs": 1},
  39. )
  40. res = self.must_bind_with_execute_response(r)
  41. self.assertEqual(len(res.events), 0)
  42. self.assertEqual(res.status, ExecutionStatusTimeout)
  43. def test_execute_syntax_error(self):
  44. r = requests.post(
  45. url("execute"), json={"code": "print('hello'", "timeout_secs": 10}
  46. )
  47. err = self.must_bind_with_execute_response(r)
  48. self.assertEqual(err.status, ExecutionStatusOK)
  49. self.assertEqual(len(err.events), 1)
  50. self.assertEqual(err.events[0].type, ExecutionEventTypeError)
  51. self.assertEqual(err.events[0].data.ename, "SyntaxError") # type: ignore
  52. self.assertIsNotNone(err.events[0].data.evalue) # type: ignore
  53. self.assertGreater(len(err.events[0].data.traceback), 0) # type: ignore
  54. def test_execute_invalid_timeout(self):
  55. r = requests.post(
  56. url("execute"),
  57. json={"code": "print('hello')", "timeout_secs": -1},
  58. )
  59. self.must_bind_with_error(r)
  60. def test_execute_display_data(self):
  61. code = """import matplotlib.pyplot as plt
  62. plt.plot([1, 2, 3, 4])
  63. plt.ylabel('some numbers')
  64. plt.show()"""
  65. r = requests.post(url("execute"), json={"code": code, "timeout_secs": 10})
  66. res = self.must_bind_with_execute_response(r)
  67. self.assertEqual(res.status, ExecutionStatusOK)
  68. self.assertEqual(len(res.events), 1)
  69. self.assertEqual(res.events[0].type, ExecutionEventTypeDisplayData)
  70. self.assertIsNotNone(res.events[0].data.variants["image/png"]) # type: ignore
  71. self.assertIsNotNone(res.events[0].data.variants["text/plain"]) # type: ignore
  72. def test_execute_pil_image(self):
  73. code = """from PIL import Image
  74. img = Image.new('RGB', (60, 30), color = 'red')
  75. # Override the show method of the Image class
  76. def new_show(self, *args, **kwargs):
  77. display(self)
  78. Image.Image.show = new_show
  79. img.show()"""
  80. r = requests.post(url("execute"), json={"code": code, "timeout_secs": 10})
  81. res = self.must_bind_with_execute_response(r)
  82. self.assertEqual(res.status, ExecutionStatusOK)
  83. self.assertEqual(len(res.events), 1)
  84. self.assertEqual(res.events[0].type, ExecutionEventTypeDisplayData)
  85. self.assertIsNotNone(res.events[0].data.variants["image/png"]) # type: ignore
  86. self.assertIsNotNone(res.events[0].data.variants["text/plain"]) # type: ignore
  87. class FileUploadHandlerTest(unittest.TestCase):
  88. @classmethod
  89. def setUpClass(cls):
  90. cls.temp_dir = tempfile.mkdtemp()
  91. cls.BASE_URL = f"http://localhost:8888/files/upload/-{cls.temp_dir}/"
  92. def test_upload_file(self):
  93. file_path = os.path.join(self.temp_dir, "test.txt")
  94. large_binary_file = os.urandom(1024 * 1024 * 10) # 10 MB
  95. r = requests.post(self.BASE_URL + "test.txt", data=large_binary_file)
  96. self.assertEqual(r.status_code, 201)
  97. self.assertTrue(os.path.exists(file_path))
  98. with open(file_path, "rb") as f:
  99. self.assertEqual(f.read(), large_binary_file)
  100. def test_upload_existing_file(self):
  101. file_path = os.path.join(self.temp_dir, "existing.txt")
  102. with open(file_path, "wb") as f:
  103. f.write(b"exists")
  104. with open(file_path, "rb") as f:
  105. r = requests.post(self.BASE_URL + "existing.txt", data=f.read())
  106. self.assertEqual(r.status_code, 409)
  107. error = Error.model_validate_json(r.content)
  108. self.assertEqual(error.error, "file already exists")
  109. def test_directory_creation(self):
  110. file_path = os.path.join(self.temp_dir, "newdir", "test.txt")
  111. os.makedirs(os.path.dirname(file_path), exist_ok=True)
  112. r = requests.post(self.BASE_URL + "newdir/test.txt", data=b"test content")
  113. self.assertEqual(r.status_code, 201)
  114. self.assertTrue(os.path.exists(file_path))
  115. with open(file_path, "rb") as f:
  116. self.assertEqual(f.read(), b"test content")
  117. @classmethod
  118. def tearDownClass(cls):
  119. # Clean up the temp_dir after all tests
  120. if os.path.exists(cls.temp_dir):
  121. shutil.rmtree(cls.temp_dir)
  122. if __name__ == "__main__":
  123. unittest.main()