diff --git a/README.rst b/README.rst index ff9fbed..0cfcf5d 100644 --- a/README.rst +++ b/README.rst @@ -5,15 +5,15 @@ Introduction By design asyncio `does not allow `_ its event loop to be nested. This presents a practical problem: -When in an environment where the event loop is -already running it's impossible to run tasks and wait -for the result. Trying to do so will give the error +in an environment where the event loop is +already running, it's impossible to run tasks and wait +for the result. Attempting to do so will lead to a "``RuntimeError: This event loop is already running``". -The issue pops up in various environments, such as web servers, +This issue pops up in various environments, including web servers, GUI applications and in Jupyter notebooks. -This module patches asyncio to allow nested use of ``asyncio.run`` and +This module patches asyncio to enable nested usage of ``asyncio.run`` and ``loop.run_until_complete``. Installation @@ -30,15 +30,20 @@ Usage .. code-block:: python - import nest_asyncio - nest_asyncio.apply() + from nest_asyncio import NestedAsyncIO + with NestedAsyncIO(): + ... -Optionally the specific loop that needs patching can be given -as argument to ``apply``, otherwise the current event loop is used. -An event loop can be patched whether it is already running -or not. Only event loops from asyncio can be patched; -Loops from other projects, such as uvloop or quamash, -generally can't be patched. + +Wrap any code requiring nested runs with a ``NestedAsyncIO`` +context manager or manually call ``apply`` and ``revert`` on +demand. Optionally, a specific loop may be supplied as an +as argument to ``apply`` or the constructor if you do not +wish to patch the the current event loop. An event loop +may be patched regardless of its state, running +or stopped. Note that this packages is limited to ``asyncio`` +event loops: general loops from other projects, such as +``uvloop`` or ``quamash``, cannot be patched. .. |Build| image:: https://github.com/erdewit/nest_asyncio/actions/workflows/test.yml/badge.svg?branche=master diff --git a/nest_asyncio.py b/nest_asyncio.py index 1cb5c25..3d3537e 100644 --- a/nest_asyncio.py +++ b/nest_asyncio.py @@ -9,211 +9,310 @@ from heapq import heappop -def apply(loop=None): - """Patch asyncio to make its event loop reentrant.""" - _patch_asyncio() - _patch_policy() - _patch_tornado() - - loop = loop or asyncio.get_event_loop() - _patch_loop(loop) - - -def _patch_asyncio(): - """Patch asyncio module to use pure Python tasks and futures.""" - - def run(main, *, debug=False): - loop = asyncio.get_event_loop() - loop.set_debug(debug) - task = asyncio.ensure_future(main) - try: - return loop.run_until_complete(task) - finally: - if not task.done(): - task.cancel() - with suppress(asyncio.CancelledError): - loop.run_until_complete(task) - - def _get_event_loop(stacklevel=3): - loop = events._get_running_loop() - if loop is None: - loop = events.get_event_loop_policy().get_event_loop() - return loop - - # Use module level _current_tasks, all_tasks and patch run method. - if hasattr(asyncio, '_nest_patched'): - return - if sys.version_info >= (3, 6, 0): - asyncio.Task = asyncio.tasks._CTask = asyncio.tasks.Task = \ - asyncio.tasks._PyTask - asyncio.Future = asyncio.futures._CFuture = asyncio.futures.Future = \ - asyncio.futures._PyFuture - if sys.version_info < (3, 7, 0): - asyncio.tasks._current_tasks = asyncio.tasks.Task._current_tasks - asyncio.all_tasks = asyncio.tasks.Task.all_tasks - if sys.version_info >= (3, 9, 0): - events._get_event_loop = events.get_event_loop = \ - asyncio.get_event_loop = _get_event_loop - asyncio.run = run - asyncio._nest_patched = True - - -def _patch_policy(): - """Patch the policy to always return a patched loop.""" - - def get_event_loop(self): - if self._local._loop is None: - loop = self.new_event_loop() - _patch_loop(loop) - self.set_event_loop(loop) - return self._local._loop - - policy = events.get_event_loop_policy() - policy.__class__.get_event_loop = get_event_loop - - -def _patch_loop(loop): - """Patch loop to make it reentrant.""" - - def run_forever(self): - with manage_run(self), manage_asyncgens(self): - while True: - self._run_once() - if self._stopping: - break - self._stopping = False - - def run_until_complete(self, future): - with manage_run(self): - f = asyncio.ensure_future(future, loop=self) - if f is not future: - f._log_destroy_pending = False - while not f.done(): - self._run_once() - if self._stopping: +class NestedAsyncIO: + __slots__ = [ + "_loop", + "orig_run", + "orig_tasks", + "orig_futures", + "orig_loop_attrs", + "policy_get_loop", + "orig_get_loops", + "orig_tc", + "patched" + ] + _instance = None + _initialized = False + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, loop=None): + if not self._initialized: + self._loop = loop + self.orig_run = None + self.orig_tasks = [] + self.orig_futures = [] + self.orig_loop_attrs = {} + self.policy_get_loop = None + self.orig_get_loops = {} + self.orig_tc = None + self.patched = False + self.__class__._initialized = True + + def __enter__(self): + self.apply(self._loop) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.revert() + + def apply(self, loop=None): + """Patch asyncio to make its event loop reentrant.""" + if not self.patched: + self.patch_asyncio() + self.patch_policy() + self.patch_tornado() + + loop = loop or asyncio.get_event_loop() + self.patch_loop(loop) + self.patched = True + + def revert(self): + if self.patched: + for loop in self.orig_loop_attrs: + self.unpatch_loop(loop) + self.unpatch_tornado() + self.unpatch_policy() + self.unpatch_asyncio() + self.patched = False + + def patch_asyncio(self): + """Patch asyncio module to use pure Python tasks and futures.""" + + def run(main, *, debug=False): + loop = asyncio.get_event_loop() + loop.set_debug(debug) + task = asyncio.ensure_future(main) + try: + return loop.run_until_complete(task) + finally: + if not task.done(): + task.cancel() + with suppress(asyncio.CancelledError): + loop.run_until_complete(task) + + def _get_event_loop(stacklevel=3): + return (events._get_running_loop() + or events.get_event_loop_policy().get_event_loop()) + + # Use module level _current_tasks, all_tasks and patch run method. + if getattr(asyncio, '_nest_patched', False): + return + if sys.version_info >= (3, 6, 0): + self.orig_tasks = [asyncio.Task, asyncio.tasks._CTask, + asyncio.tasks.Task] + asyncio.Task = asyncio.tasks._CTask = \ + asyncio.tasks.Task = asyncio.tasks._PyTask + self.orig_futures = [asyncio.Future, asyncio.futures._CFuture, + asyncio.futures.Future] + asyncio.Future = asyncio.futures._CFuture = \ + asyncio.futures.Future = asyncio.futures._PyFuture + if sys.version_info < (3, 7, 0): + asyncio.tasks._current_tasks = asyncio.tasks.Task._current_tasks + asyncio.all_tasks = asyncio.tasks.Task.all_tasks + elif sys.version_info >= (3, 9, 0): + self.orig_get_loops = \ + {"events__get_event_loop": events._get_event_loop, + "events_get_event_loop": events.get_event_loop, + "asyncio_get_event_loop": asyncio.get_event_loop} + events._get_event_loop = events.get_event_loop = \ + asyncio.get_event_loop = _get_event_loop + self.orig_run = asyncio.run + asyncio.run = run + asyncio._nest_patched = True + + def unpatch_asyncio(self): + if self.orig_run: + asyncio.run = self.orig_run + asyncio._nest_patched = False + if sys.version_info >= (3, 6, 0): + (asyncio.Task, asyncio.tasks._CTask, + asyncio.tasks.Task) = self.orig_tasks + (asyncio.Future, asyncio.futures._CFuture, + asyncio.futures.Future) = self.orig_futures + if sys.version_info >= (3, 9, 0): + for key, value in self.orig_get_loops.items(): + setattr(asyncio if key.startswith('asyncio') + else events, key.split('_')[-1], value) + + def patch_policy(self): + """Patch the policy to always return a patched loop.""" + + def get_event_loop(this): + if this._local._loop is None: + loop = this.new_event_loop() + self.patch_loop(loop) + this.set_event_loop(loop) + return this._local._loop + + cls = events.get_event_loop_policy().__class__ + self.policy_get_loop = cls.get_event_loop + cls.get_event_loop = get_event_loop + + def unpatch_policy(self): + cls = events.get_event_loop_policy().__class__ + orig = self.policy_get_loop + if orig: + cls.get_event_loop = orig + + def patch_loop(self, loop): + """Patch loop to make it reentrant.""" + + def run_forever(this): + with manage_run(this), manage_asyncgens(this): + while True: + this._run_once() + if this._stopping: + break + this._stopping = False + + def run_until_complete(this, future): + with manage_run(this): + f = asyncio.ensure_future(future, loop=this) + if f is not future: + f._log_destroy_pending = False + while not f.done(): + this._run_once() + if this._stopping: + break + if not f.done(): + raise RuntimeError( + 'Event loop stopped before Future completed.') + return f.result() + + def _run_once(this): + """ + Simplified re-implementation of asyncio's _run_once that + runs handles as they become ready. + """ + ready = this._ready + scheduled = this._scheduled + while scheduled and scheduled[0]._cancelled: + heappop(scheduled) + + timeout = ( + 0 if ready or this._stopping + else min(max( + scheduled[0]._when - this.time(), 0), 86400) if scheduled + else None) + event_list = this._selector.select(timeout) + this._process_events(event_list) + + end_time = this.time() + this._clock_resolution + while scheduled and scheduled[0]._when < end_time: + handle = heappop(scheduled) + ready.append(handle) + + for _ in range(len(ready)): + if not ready: break - if not f.done(): - raise RuntimeError( - 'Event loop stopped before Future completed.') - return f.result() + handle = ready.popleft() + if not handle._cancelled: + # preempt the current task so that that checks in + # Task.__step do not raise + curr_task = curr_tasks.pop(this, None) + + try: + handle._run() + finally: + # restore the current task + if curr_task is not None: + curr_tasks[this] = curr_task + + handle = None - def _run_once(self): + @contextmanager + def manage_run(this): + """Set up the loop for running.""" + this._check_closed() + old_thread_id = this._thread_id + old_running_loop = events._get_running_loop() + try: + this._thread_id = threading.get_ident() + events._set_running_loop(this) + this._num_runs_pending += 1 + if this._is_proactorloop: + if this._self_reading_future is None: + this.call_soon(this._loop_self_reading) + yield + finally: + this._thread_id = old_thread_id + events._set_running_loop(old_running_loop) + this._num_runs_pending -= 1 + if this._is_proactorloop: + if (this._num_runs_pending == 0 + and this._self_reading_future is not None): + ov = this._self_reading_future._ov + this._self_reading_future.cancel() + if ov is not None: + this._proactor._unregister(ov) + this._self_reading_future = None + + @contextmanager + def manage_asyncgens(this): + if not hasattr(sys, 'get_asyncgen_hooks'): + # Python version is too old. + return + old_agen_hooks = sys.get_asyncgen_hooks() + try: + this._set_coroutine_origin_tracking(this._debug) + if this._asyncgens is not None: + sys.set_asyncgen_hooks( + firstiter=this._asyncgen_firstiter_hook, + finalizer=this._asyncgen_finalizer_hook) + yield + finally: + this._set_coroutine_origin_tracking(False) + if this._asyncgens is not None: + sys.set_asyncgen_hooks(*old_agen_hooks) + + def _check_running(this): + """Do not throw exception if loop is already running.""" + pass + + if getattr(loop, '_nest_patched', False): + return + if not isinstance(loop, asyncio.BaseEventLoop): + raise ValueError('Can\'t patch loop of type %s' % type(loop)) + cls = loop.__class__ + self.orig_loop_attrs[cls] = {} + self.orig_loop_attrs[cls]["run_forever"] = cls.run_forever + cls.run_forever = run_forever + self.orig_loop_attrs[cls]["run_until_complete"] = \ + cls.run_until_complete + cls.run_until_complete = run_until_complete + self.orig_loop_attrs[cls]["_run_once"] = cls._run_once + cls._run_once = _run_once + self.orig_loop_attrs[cls]["_check_running"] = cls._check_running + cls._check_running = _check_running + self.orig_loop_attrs[cls]["_check_runnung"] = cls._check_running + cls._check_runnung = _check_running # typo in Python 3.7 source + cls._num_runs_pending = 1 if loop.is_running() else 0 + cls._is_proactorloop = (os.name == 'nt' + and issubclass(cls, + asyncio.ProactorEventLoop)) + if sys.version_info < (3, 7, 0): + cls._set_coroutine_origin_tracking = cls._set_coroutine_wrapper + curr_tasks = asyncio.tasks._current_tasks \ + if sys.version_info >= (3, 7, 0) else asyncio.Task._current_tasks + cls._nest_patched = True + + def unpatch_loop(self, loop): + loop._nest_patched = False + if self.orig_loop_attrs[loop]: + for key, value in self.orig_loop_attrs[loop].items(): + setattr(loop, key, value) + + for attr in ['_num_runs_pending', '_is_proactorloop']: + if hasattr(loop, attr): + delattr(loop, attr) + + def patch_tornado(self): """ - Simplified re-implementation of asyncio's _run_once that - runs handles as they become ready. + If tornado is imported before nest_asyncio, make tornado aware of + the pure-Python asyncio Future. """ - ready = self._ready - scheduled = self._scheduled - while scheduled and scheduled[0]._cancelled: - heappop(scheduled) - - timeout = ( - 0 if ready or self._stopping - else min(max( - scheduled[0]._when - self.time(), 0), 86400) if scheduled - else None) - event_list = self._selector.select(timeout) - self._process_events(event_list) - - end_time = self.time() + self._clock_resolution - while scheduled and scheduled[0]._when < end_time: - handle = heappop(scheduled) - ready.append(handle) - - for _ in range(len(ready)): - if not ready: - break - handle = ready.popleft() - if not handle._cancelled: - # preempt the current task so that that checks in - # Task.__step do not raise - curr_task = curr_tasks.pop(self, None) - - try: - handle._run() - finally: - # restore the current task - if curr_task is not None: - curr_tasks[self] = curr_task - - handle = None - - @contextmanager - def manage_run(self): - """Set up the loop for running.""" - self._check_closed() - old_thread_id = self._thread_id - old_running_loop = events._get_running_loop() - try: - self._thread_id = threading.get_ident() - events._set_running_loop(self) - self._num_runs_pending += 1 - if self._is_proactorloop: - if self._self_reading_future is None: - self.call_soon(self._loop_self_reading) - yield - finally: - self._thread_id = old_thread_id - events._set_running_loop(old_running_loop) - self._num_runs_pending -= 1 - if self._is_proactorloop: - if (self._num_runs_pending == 0 - and self._self_reading_future is not None): - ov = self._self_reading_future._ov - self._self_reading_future.cancel() - if ov is not None: - self._proactor._unregister(ov) - self._self_reading_future = None - - @contextmanager - def manage_asyncgens(self): - if not hasattr(sys, 'get_asyncgen_hooks'): - # Python version is too old. - return - old_agen_hooks = sys.get_asyncgen_hooks() - try: - self._set_coroutine_origin_tracking(self._debug) - if self._asyncgens is not None: - sys.set_asyncgen_hooks( - firstiter=self._asyncgen_firstiter_hook, - finalizer=self._asyncgen_finalizer_hook) - yield - finally: - self._set_coroutine_origin_tracking(False) - if self._asyncgens is not None: - sys.set_asyncgen_hooks(*old_agen_hooks) - - def _check_running(self): - """Do not throw exception if loop is already running.""" - pass - - if hasattr(loop, '_nest_patched'): - return - if not isinstance(loop, asyncio.BaseEventLoop): - raise ValueError('Can\'t patch loop of type %s' % type(loop)) - cls = loop.__class__ - cls.run_forever = run_forever - cls.run_until_complete = run_until_complete - cls._run_once = _run_once - cls._check_running = _check_running - cls._check_runnung = _check_running # typo in Python 3.7 source - cls._num_runs_pending = 1 if loop.is_running() else 0 - cls._is_proactorloop = ( - os.name == 'nt' and issubclass(cls, asyncio.ProactorEventLoop)) - if sys.version_info < (3, 7, 0): - cls._set_coroutine_origin_tracking = cls._set_coroutine_wrapper - curr_tasks = asyncio.tasks._current_tasks \ - if sys.version_info >= (3, 7, 0) else asyncio.Task._current_tasks - cls._nest_patched = True - - -def _patch_tornado(): - """ - If tornado is imported before nest_asyncio, make tornado aware of - the pure-Python asyncio Future. - """ - if 'tornado' in sys.modules: - import tornado.concurrent as tc # type: ignore - tc.Future = asyncio.Future - if asyncio.Future not in tc.FUTURES: - tc.FUTURES += (asyncio.Future,) + if 'tornado' in sys.modules: + import tornado.concurrent as tc # type: ignore + self.orig_tc = tc.Future + tc.Future = asyncio.Future + if asyncio.Future not in tc.FUTURES: + tc.FUTURES += (asyncio.Future,) + + def unpatch_tornado(self): + if self.orig_tc: + import tornado.concurrent as tc + tc.Future = self.orig_tc diff --git a/tests/nest_test.py b/tests/nest_test.py index 076dbfe..1493198 100644 --- a/tests/nest_test.py +++ b/tests/nest_test.py @@ -2,7 +2,7 @@ import sys import unittest -import nest_asyncio +from nest_asyncio import NestedAsyncIO def exception_handler(loop, context): @@ -12,12 +12,13 @@ def exception_handler(loop, context): class NestTest(unittest.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() - nest_asyncio.apply(self.loop) + NestedAsyncIO().apply(self.loop) asyncio.set_event_loop(self.loop) self.loop.set_debug(True) self.loop.set_exception_handler(exception_handler) def tearDown(self): + NestedAsyncIO().revert() self.assertIsNone(asyncio._get_running_loop()) self.loop.close() del self.loop @@ -27,7 +28,6 @@ async def coro(self): return 42 def test_nesting(self): - async def f1(): result = self.loop.run_until_complete(self.coro()) self.assertEqual(result, await self.coro()) @@ -42,7 +42,6 @@ async def f2(): self.assertEqual(result, 42) def test_ensure_future_with_run_until_complete(self): - async def f(): task = asyncio.ensure_future(self.coro()) return self.loop.run_until_complete(task) @@ -51,7 +50,6 @@ async def f(): self.assertEqual(result, 42) def test_ensure_future_with_run_until_complete_with_wait(self): - async def f(): task = asyncio.ensure_future(self.coro()) done, pending = self.loop.run_until_complete( @@ -63,7 +61,6 @@ async def f(): self.assertEqual(result, 42) def test_timeout(self): - async def f1(): await asyncio.sleep(0.1) @@ -74,7 +71,6 @@ async def f2(): self.loop.run_until_complete(f2()) def test_two_run_until_completes_in_one_outer_loop(self): - async def f1(): self.loop.run_until_complete(asyncio.sleep(0.02)) return 4