sandbox.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. import argparse
  2. import asyncio
  3. import json
  4. import logging
  5. import os
  6. import signal
  7. import sys
  8. from asyncio import Queue
  9. from datetime import datetime, timezone
  10. from typing import Annotated, List, Union
  11. import tornado.escape
  12. import tornado.ioloop
  13. import tornado.web
  14. from annotated_types import Gt
  15. from jupyter_client.asynchronous.client import AsyncKernelClient
  16. from jupyter_client.manager import AsyncKernelManager
  17. from pydantic import BaseModel
  18. # Shell Jupyter message types
  19. JupyterMessageTypeExecuteRequest = "execute_request"
  20. JupyterMessageTypeExecuteReply = "execute_reply"
  21. # IOPub Jupyter message types
  22. JupyterMessageTypeStream = "stream"
  23. JupyterMessageTypeDisplayData = "display_data"
  24. JupyterMessageTypeExecuteResult = "execute_result"
  25. JupyterMessageTypeError = "error"
  26. JupyterMessageTypeStatus = "status"
  27. # Supported Jupyter message types (IOPub only)
  28. JupyterSupportedMessageTypes = [
  29. JupyterMessageTypeStream,
  30. JupyterMessageTypeDisplayData,
  31. JupyterMessageTypeExecuteResult,
  32. JupyterMessageTypeError,
  33. JupyterMessageTypeStatus,
  34. ]
  35. # Kernel execution states
  36. JupyterExecutionStateBusy = "busy"
  37. JupyterExecutionStateIdle = "idle"
  38. JupyterExecutionStateStarting = "starting"
  39. # Saturn execution event types
  40. ExecutionEventTypeStream = "stream"
  41. ExecutionEventTypeDisplayData = "display_data"
  42. ExecutionEventTypeError = "error"
  43. # Saturn execution statuses
  44. ExecutionStatusOK = "ok"
  45. ExecutionStatusTimeout = "timeout"
  46. class ExecutionEventStream(BaseModel):
  47. stream: str
  48. text: str
  49. class ExecutionEventDisplayData(BaseModel):
  50. variants: dict
  51. class ExecutionEventError(BaseModel):
  52. ename: str
  53. evalue: str
  54. traceback: list[str]
  55. class ExecutionEvent(BaseModel):
  56. type: str
  57. timestamp: str # RFC3339
  58. data: Union[
  59. ExecutionEventStream,
  60. ExecutionEventDisplayData,
  61. ExecutionEventError,
  62. ]
  63. class ExecuteRequest(BaseModel):
  64. code: str
  65. timeout_secs: Annotated[int, Gt(0)]
  66. class ExecuteResponse(BaseModel):
  67. status: str
  68. events: List[ExecutionEvent]
  69. class PingResponse(BaseModel):
  70. last_activity: str # RFC3339
  71. class Error(BaseModel):
  72. error: str
  73. def datetime_to_rfc3339(dt: datetime) -> str:
  74. """Convert a datetime to an RFC3339 formatted string."""
  75. return dt.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
  76. def rfc3339_to_datetime(date_string: str) -> datetime:
  77. """Convert an RFC3339 formatted string to a datetime."""
  78. return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%S.%fZ").replace(
  79. tzinfo=timezone.utc
  80. )
  81. logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
  82. async def async_create_kernel(kernel_name: str):
  83. logging.info(f"Starting kernel for spec '{kernel_name}'")
  84. km = AsyncKernelManager(kernel_name=kernel_name)
  85. await km.start_kernel()
  86. client: AsyncKernelClient = km.client()
  87. client.start_channels()
  88. await client.wait_for_ready()
  89. logging.info("Kernel started")
  90. return km, client
  91. msg_id_to_queue: dict[str, Queue] = {}
  92. async def async_msg_producer(km: AsyncKernelManager, kc: AsyncKernelClient):
  93. try:
  94. while True:
  95. logging.info("Waiting for message...")
  96. msg = await kc.get_iopub_msg()
  97. log_jupyter_kernel_message(msg)
  98. parent_msg_id = msg["parent_header"].get("msg_id")
  99. if parent_msg_id in msg_id_to_queue:
  100. await msg_id_to_queue[parent_msg_id].put(msg)
  101. except Exception as e:
  102. logging.error(f"Error in message producer: {e}")
  103. await async_shutdown(km)
  104. async def async_shutdown(km: AsyncKernelManager):
  105. logging.info("Shutting down kernel...")
  106. await km.shutdown_kernel()
  107. logging.info("Kernel shut down")
  108. sys.exit(0)
  109. class State:
  110. def __init__(self, kernel_client: AsyncKernelClient):
  111. self.last_activity = datetime.now()
  112. self.kernel_client = kernel_client
  113. def reset_last_activity(self):
  114. self.last_activity = datetime.now()
  115. class MainHandler(tornado.web.RequestHandler):
  116. def initialize(self, state: State):
  117. self.state = state
  118. async def get(self):
  119. try:
  120. is_alive = await client.is_alive()
  121. if not is_alive:
  122. raise Exception("kernel is not alive")
  123. self.write(
  124. PingResponse(
  125. last_activity=datetime_to_rfc3339(self.state.last_activity)
  126. ).model_dump_json()
  127. )
  128. except Exception as e:
  129. self.set_status(500)
  130. self.write(Error(error=str(e)).model_dump_json())
  131. return
  132. def serializer(o):
  133. if isinstance(o, datetime):
  134. return o.isoformat()
  135. raise TypeError("Type not serializable")
  136. def log_jupyter_kernel_message(msg):
  137. m = json.dumps(msg, default=serializer)
  138. logging.info(f"Jupyter: {m}")
  139. class ExecuteHandler(tornado.web.RequestHandler):
  140. def initialize(self, state: State):
  141. self.state = state
  142. async def post(self):
  143. parent_msg_id = None
  144. res: ExecuteResponse = ExecuteResponse(status=ExecutionStatusOK, events=[])
  145. try:
  146. logging.info(f"Execute request: {self.request.body}")
  147. self.state.reset_last_activity()
  148. req = ExecuteRequest.model_validate_json(self.request.body)
  149. local_queue = Queue()
  150. parent_msg_id = self.state.kernel_client.execute(req.code)
  151. msg_id_to_queue[parent_msg_id] = local_queue
  152. # Use the timeout logic on message processing
  153. try:
  154. await asyncio.wait_for(
  155. self.process_messages(parent_msg_id, local_queue, res),
  156. timeout=req.timeout_secs,
  157. )
  158. except asyncio.TimeoutError:
  159. logging.info(f"Timeout after {req.timeout_secs}s")
  160. res.status = ExecutionStatusTimeout
  161. return self.write(res.model_dump_json())
  162. self.state.reset_last_activity()
  163. self.write(res.model_dump_json())
  164. except Exception as e:
  165. self.set_status(500)
  166. self.write(Error(error=str(e)).model_dump_json())
  167. finally:
  168. # Cleanup after processing all messages
  169. if parent_msg_id is not None and parent_msg_id in msg_id_to_queue:
  170. del msg_id_to_queue[parent_msg_id]
  171. logging.info(f"Execute response: {res.model_dump_json()}")
  172. async def process_messages(self, parent_msg_id, queue, res):
  173. while True:
  174. msg = await queue.get()
  175. if msg["msg_type"] not in JupyterSupportedMessageTypes:
  176. continue
  177. elif msg["msg_type"] == JupyterMessageTypeStatus:
  178. if msg["content"]["execution_state"] == JupyterExecutionStateIdle:
  179. break
  180. elif msg["msg_type"] == JupyterMessageTypeStream:
  181. res.events.append(
  182. ExecutionEvent(
  183. type=ExecutionEventTypeStream,
  184. timestamp=datetime_to_rfc3339(datetime.now()),
  185. data=ExecutionEventStream(
  186. stream=msg["content"]["name"],
  187. text=msg["content"]["text"],
  188. ),
  189. )
  190. )
  191. elif msg["msg_type"] == JupyterMessageTypeDisplayData:
  192. res.events.append(
  193. ExecutionEvent(
  194. type=ExecutionEventTypeDisplayData,
  195. timestamp=datetime_to_rfc3339(datetime.now()),
  196. data=ExecutionEventDisplayData(variants=msg["content"]["data"]),
  197. )
  198. )
  199. elif msg["msg_type"] == JupyterMessageTypeError:
  200. res.events.append(
  201. ExecutionEvent(
  202. type=ExecutionEventTypeError,
  203. timestamp=datetime_to_rfc3339(datetime.now()),
  204. data=ExecutionEventError(
  205. ename=msg["content"]["ename"],
  206. evalue=msg["content"]["evalue"],
  207. traceback=msg["content"]["traceback"],
  208. ),
  209. )
  210. )
  211. elif msg["msg_type"] == JupyterMessageTypeExecuteResult:
  212. res.events.append(
  213. ExecutionEvent(
  214. type=ExecutionEventTypeDisplayData,
  215. timestamp=datetime_to_rfc3339(datetime.now()),
  216. data=ExecutionEventDisplayData(variants=msg["content"]["data"]),
  217. )
  218. )
  219. @tornado.web.stream_request_body
  220. class FileUploadHandler(tornado.web.RequestHandler):
  221. def initialize(self, state: State):
  222. self.state = state
  223. self.file_obj = None
  224. async def prepare(self):
  225. if self.request.method != "POST":
  226. self.set_status(404)
  227. self.finish()
  228. return
  229. path = self.path_args[0]
  230. full_path = os.path.join("/", path)
  231. os.makedirs(os.path.dirname(full_path), exist_ok=True)
  232. self.file_obj = open(full_path, "wb")
  233. content_length = int(self.request.headers.get("Content-Length", 0))
  234. logging.info(f"File upload: '{path}' (Content-Length: {content_length})")
  235. def data_received(self, chunk):
  236. if self.file_obj:
  237. self.file_obj.write(chunk)
  238. async def post(self, path):
  239. self.state.reset_last_activity()
  240. if self.file_obj:
  241. self.file_obj.close()
  242. self.set_status(201)
  243. class FileDownloadHandler(tornado.web.RequestHandler):
  244. def initialize(self, state: State):
  245. self.state = state
  246. async def get(self, path):
  247. self.state.reset_last_activity()
  248. full_path = os.path.join("/", path)
  249. if not os.path.exists(full_path):
  250. self.set_status(404)
  251. self.write(Error(error="file not found").model_dump_json())
  252. return
  253. content_length = os.path.getsize(full_path)
  254. logging.info(f"File download: '{path}' (Content-Length: {content_length})")
  255. # Set appropriate headers for file download
  256. self.set_header("Content-Length", content_length)
  257. self.set_header("Content-Type", "application/octet-stream")
  258. self.set_header(
  259. "Content-Disposition",
  260. f"attachment; filename*=UTF-8''{tornado.escape.url_escape(os.path.basename(full_path))}",
  261. )
  262. # Stream the file to the client
  263. with open(full_path, "rb") as f:
  264. while True:
  265. chunk = f.read(64 * 1024)
  266. if not chunk:
  267. break
  268. try:
  269. self.write(chunk)
  270. await self.flush()
  271. except tornado.iostream.StreamClosedError:
  272. return
  273. def shutdown(ioloop: tornado.ioloop.IOLoop, km):
  274. logging.info("Shutting down server...")
  275. ioloop.add_callback_from_signal(lambda: async_shutdown(km))
  276. if __name__ == "__main__":
  277. p = argparse.ArgumentParser()
  278. p.add_argument("--port", type=int, default=80)
  279. p.add_argument("--kernel-name", type=str, default="python3")
  280. args = p.parse_args()
  281. km, client = asyncio.run(async_create_kernel(args.kernel_name))
  282. state = State(client)
  283. application = tornado.web.Application(
  284. [
  285. (r"/", MainHandler, {"state": state}),
  286. (r"/execute", ExecuteHandler, {"state": state}),
  287. (r"/files/upload/-/(.*)", FileUploadHandler, {"state": state}),
  288. (r"/files/download/-/(.*)", FileDownloadHandler, {"state": state}),
  289. ]
  290. )
  291. application.listen(args.port)
  292. logging.info(f"Server started at http://localhost:{args.port}")
  293. ioloop = tornado.ioloop.IOLoop.current()
  294. signal.signal(signal.SIGINT, lambda sig, frame: shutdown(ioloop, km))
  295. signal.signal(signal.SIGTERM, lambda sig, frame: shutdown(ioloop, km))
  296. ioloop.add_callback(async_msg_producer, km, client)
  297. tornado.ioloop.IOLoop.current().start()