Bring back a limited nest_asyncio
Summary:
This brings back a limited subset of `nest_asyncio` to allow re-entrant calls into the event loop, while avoiding monkey-patching the global namespace.

I don't really like this solution, but it seems like the cleanest way to unblock using TorchSnapshot in Bento for now.
A cleaner solution would be to refactor the code to *stop* passing around an event loop and instead rely on a `nested_asyncio_run(Coroutine[T]) -> T` function (which we can implement relatively easily), but that seemed to be a much larger refactor.
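
For illustration only (not part of this diff): a minimal sketch of what such a `nested_asyncio_run` helper could look like, built on top of the `maybe_nested_loop` helper added below. The name and rough signature come from the paragraph above; everything else is an assumption.

import asyncio
from typing import Any, Coroutine, TypeVar

from torchsnapshot.asyncio_utils import maybe_nested_loop

T = TypeVar("T")


def nested_asyncio_run(coro: Coroutine[Any, Any, T]) -> T:
    # Create a fresh loop (re-entrant-patched if another loop is already
    # running in this thread) and drive the coroutine to completion on it.
    loop = maybe_nested_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()

Callers would then pass coroutines around instead of an event loop and call `nested_asyncio_run(coro)` at the synchronous boundary.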

Reviewed By: JKSenthil

Differential Revision: D60845357

fbshipit-source-id: 3873e507ddbfce219233aea4c759c55e0f574bd5
alanhdu authored and facebook-github-bot committed Aug 6, 2024
1 parent ce8d7b6 commit 3f5ad9c
Showing 3 changed files with 169 additions and 13 deletions.
157 changes: 157 additions & 0 deletions torchsnapshot/asyncio_utils.py
@@ -0,0 +1,157 @@
# pyre-unsafe

import asyncio
import functools
import os
import sys
import threading
from contextlib import contextmanager
from heapq import heappop


# copy-pasted from nest-asyncio, but modified to avoid patching the global
# namespace and instead only patching the instance variable
def _patch_loop(loop: asyncio.AbstractEventLoop) -> None:
    def run_forever(self):
        with manage_run(self), manage_asyncgens(self):
            while True:
                self._run_once()
                if self._stopping:
                    break
            self._stopping = False

    def run_until_complete(self, future):
        with manage_run(self):
            f = asyncio.ensure_future(future, loop=self)
            if f is not future:
                f._log_destroy_pending = False
            while not f.done():
                self._run_once()
                if self._stopping:
                    break
            if not f.done():
                raise RuntimeError("Event loop stopped before Future completed.")
            return f.result()

    def _run_once(self):
        """
        Simplified re-implementation of asyncio's _run_once that
        runs handles as they become ready.
        """
        now = self.time()
        ready = self._ready
        scheduled = self._scheduled
        while scheduled and scheduled[0]._cancelled:
            heappop(scheduled)

        timeout = (
            0
            if ready or self._stopping
            else min(max(scheduled[0]._when - now, 0), 86400) if scheduled else None
        )
        event_list = self._selector.select(timeout)
        self._process_events(event_list)

        end_time = self.time() + self._clock_resolution
        while scheduled and scheduled[0]._when < end_time:
            handle = heappop(scheduled)
            ready.append(handle)

        for _ in range(len(ready)):
            if not ready:
                break
            handle = ready.popleft()
            if not handle._cancelled:
                handle._run()
        handle = None

    @contextmanager
    def manage_run(self):
        """Set up the loop for running."""
        self._check_closed()
        old_thread_id = self._thread_id
        old_running_loop = asyncio.events._get_running_loop()
        try:
            self._thread_id = threading.get_ident()
            asyncio.events._set_running_loop(self)
            self._num_runs_pending += 1
            if self._is_proactorloop:
                if self._self_reading_future is None:
                    self.call_soon(self._loop_self_reading)
            yield
        finally:
            self._thread_id = old_thread_id
            asyncio.events._set_running_loop(old_running_loop)
            self._num_runs_pending -= 1
            if self._is_proactorloop:
                if (
                    self._num_runs_pending == 0
                    and self._self_reading_future is not None
                ):
                    ov = self._self_reading_future._ov
                    self._self_reading_future.cancel()
                    if ov is not None:
                        self._proactor._unregister(ov)
                    self._self_reading_future = None

    @contextmanager
    def manage_asyncgens(self):
        old_agen_hooks = sys.get_asyncgen_hooks()
        try:
            self._set_coroutine_origin_tracking(self._debug)
            if self._asyncgens is not None:
                sys.set_asyncgen_hooks(
                    firstiter=self._asyncgen_firstiter_hook,
                    finalizer=self._asyncgen_finalizer_hook,
                )
            yield
        finally:
            self._set_coroutine_origin_tracking(False)
            if self._asyncgens is not None:
                sys.set_asyncgen_hooks(*old_agen_hooks)

    def _check_running(self):
        """Do not throw exception if loop is already running."""
        pass

    # pyre-fixme[8]: Attribute has type `(self: AbstractEventLoop) -> None`; used as
    # `partial[typing.Any]`.
    loop.run_forever = functools.partial(run_forever, loop)
    # pyre-fixme[8]: Attribute has type `(self: AbstractEventLoop, future:
    # Union[Awaitable[Variable[_T]], Generator[typing.Any, None, Variable[_T]]]) ->
    # _T`; used as `partial[typing.Any]`.
    loop.run_until_complete = functools.partial(run_until_complete, loop)
    # pyre-fixme[16]: `AbstractEventLoop` has no attribute `_run_once`.
    loop._run_once = functools.partial(_run_once, loop)
    # pyre-fixme[16]: `AbstractEventLoop` has no attribute `_check_running`.
    loop._check_running = functools.partial(_check_running, loop)
    # pyre-fixme[16]: `AbstractEventLoop` has no attribute `_nest_patched`.
    loop._nest_patched = True
    # pyre-fixme[16]: `AbstractEventLoop` has no attribute `_num_runs_pending`.
    loop._num_runs_pending = 0
    # pyre-fixme[16]: `AbstractEventLoop` has no attribute `_is_proactorloop`.
    loop._is_proactorloop = os.name == "nt" and isinstance(
        loop,
        # pyre-fixme[16]: Module `asyncio` has no attribute `ProactorEventLoop`.
        asyncio.ProactorEventLoop,
    )


# TODO: this is *not* an amazing way to do this
def maybe_nested_loop() -> asyncio.AbstractEventLoop:
    try:
        original = asyncio.get_running_loop()
    except RuntimeError:
        original = None

    loop = asyncio.new_event_loop()
    if original is None:
        return loop
    else:
        # Need to monkey-patch the loop so it can be re-entrant, which makes things
        # work on old versions of Jupyter
        #
        # It would be better if we could refactor the code to rely more on
        # asyncio.run instead of passing the event loop into places, but oh well...
        _patch_loop(loop)
        return loop
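
Not part of the committed file, but a small usage sketch of how `maybe_nested_loop` is meant to behave: outside a running loop it hands back an ordinary event loop, and inside one (e.g. a notebook cell) it hands back a loop patched to tolerate re-entrant `run_until_complete` calls. The `compute`/`blocking_helper` names here are made up for illustration.

import asyncio

from torchsnapshot.asyncio_utils import maybe_nested_loop


async def compute() -> int:
    await asyncio.sleep(0)
    return 42


def blocking_helper() -> int:
    # Works both with and without a loop already running in this thread.
    loop = maybe_nested_loop()
    try:
        return loop.run_until_complete(compute())
    finally:
        loop.close()


async def main() -> None:
    # Re-entrant call: blocking_helper() drives its own patched loop while
    # asyncio.run()'s loop is still running on this thread.
    assert blocking_helper() == 42


asyncio.run(main())
print(blocking_helper())  # also fine when no loop is running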
8 changes: 5 additions & 3 deletions torchsnapshot/io_types.py
@@ -14,6 +14,8 @@
from dataclasses import dataclass, field
from typing import Generic, Optional, Tuple, TypeVar, Union

+from .asyncio_utils import maybe_nested_loop
+

BufferType = Union[bytes, memoryview]

@@ -99,19 +101,19 @@ def sync_write(
        self, write_io: WriteIO, event_loop: Optional[asyncio.AbstractEventLoop] = None
    ) -> None:
        if event_loop is None:
-            event_loop = asyncio.new_event_loop()
+            event_loop = maybe_nested_loop()
        event_loop.run_until_complete(self.write(write_io=write_io))

    def sync_read(
        self, read_io: ReadIO, event_loop: Optional[asyncio.AbstractEventLoop] = None
    ) -> None:
        if event_loop is None:
-            event_loop = asyncio.new_event_loop()
+            event_loop = maybe_nested_loop()
        event_loop.run_until_complete(self.read(read_io=read_io))

    def sync_close(
        self, event_loop: Optional[asyncio.AbstractEventLoop] = None
    ) -> None:
        if event_loop is None:
-            event_loop = asyncio.new_event_loop()
+            event_loop = maybe_nested_loop()
        event_loop.run_until_complete(self.close())
17 changes: 7 additions & 10 deletions torchsnapshot/snapshot.py
@@ -32,18 +32,15 @@
from torch.nn.parallel import DistributedDataParallel as DDP
from torchsnapshot.dtensor_utils import is_sharded

+from .asyncio_utils import maybe_nested_loop
from .batcher import batch_read_requests, batch_write_requests
-
from .dist_store import get_or_create_store, LinearBarrier
-
from .event import Event
from .event_handlers import log_event
-
from .flatten import flatten, inflate
from .io_preparer import prepare_read, prepare_write
from .io_types import ReadIO, ReadReq, StoragePlugin, WriteIO, WriteReq
from .knobs import is_batching_disabled
-
from .manifest import Entry, Manifest, PrimitiveEntry, SnapshotMetadata
from .manifest_ops import get_manifest_for_rank, handle_sharded_tensor_elasticity
from .manifest_utils import is_container_entry
@@ -99,7 +96,7 @@ def __init__(
    @property
    def metadata(self) -> SnapshotMetadata:
        if self._metadata is None:
-            event_loop = asyncio.new_event_loop()
+            event_loop = maybe_nested_loop()
            storage = url_to_storage_plugin_in_event_loop(
                url_path=self.path,
                event_loop=event_loop,
@@ -169,7 +166,7 @@ def take(
        torch._C._log_api_usage_once("torchsnapshot.Snapshot.take")
        cls._validate_app_state(app_state)

-        event_loop = asyncio.new_event_loop()
+        event_loop = maybe_nested_loop()
        pg_wrapper = PGWrapper(pg=pg)

        unique_id = _generate_random_int64()
@@ -263,7 +260,7 @@ def async_take(
        torch._C._log_api_usage_once("torchsnapshot.Snapshot.async_take")
        cls._validate_app_state(app_state)

-        event_loop = asyncio.new_event_loop()
+        event_loop = maybe_nested_loop()
        pg_wrapper = PGWrapper(pg=pg)

        unique_id = _generate_random_int64()
@@ -336,7 +333,7 @@ def restore(self, app_state: AppState, strict: bool = True) -> None:
        torch._C._log_api_usage_once("torchsnapshot.Snapshot.restore")
        self._validate_app_state(app_state)

-        event_loop = asyncio.new_event_loop()
+        event_loop = maybe_nested_loop()
        pg_wrapper = PGWrapper(self.pg)

        unique_id = _generate_random_int64()
@@ -459,7 +456,7 @@ def read_object(
            "Its state won't be changed after load. The loaded object will be returned."
        )

-        event_loop = asyncio.new_event_loop()
+        event_loop = maybe_nested_loop()
        pg_wrapper = PGWrapper(self.pg)
        storage = url_to_storage_plugin_in_event_loop(
            url_path=self.path,
@@ -703,7 +700,7 @@ def get_state_dict_for_key(self, key: str) -> Dict[Any, Any]:
            snapshot = Snapshot.take(path=..., app_state={"stateful_key": module})
            module_state_dict = snapshot.get_state_dict_for_key("stateful_key")
        """
-        event_loop = asyncio.new_event_loop()
+        event_loop = maybe_nested_loop()
        pg = PGWrapper(self.pg)

        manifest, _ = get_manifest_for_rank(metadata=self.metadata, rank=pg.get_rank())
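
A hedged end-to-end sketch (not from this commit) of what the snapshot.py change enables: driving the synchronous `Snapshot` APIs from code that is already inside a running event loop, as in a Jupyter/Bento notebook cell. The tiny model and the "/tmp/snapshot" path are placeholder assumptions.

import asyncio

import torch
from torchsnapshot import Snapshot


async def checkpoint_from_async_context() -> None:
    model = torch.nn.Linear(4, 4)
    # Before this change, the synchronous APIs created a plain event loop, and
    # asyncio would typically refuse to run it while this coroutine's loop is
    # active; maybe_nested_loop() returns a re-entrant-patched loop instead.
    Snapshot.take(path="/tmp/snapshot", app_state={"model": model})


asyncio.run(checkpoint_from_async_context())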