diff --git a/sniffio/__init__.py b/sniffio/__init__.py index fb3364d..1bf6630 100644 --- a/sniffio/__init__.py +++ b/sniffio/__init__.py @@ -2,7 +2,7 @@ __all__ = [ "current_async_library", "AsyncLibraryNotFoundError", - "current_async_library_cvar" + "current_async_library_cvar", "hooks" ] from ._version import __version__ @@ -12,4 +12,5 @@ AsyncLibraryNotFoundError, current_async_library_cvar, thread_local, + hooks, ) diff --git a/sniffio/_impl.py b/sniffio/_impl.py index c1a7bbf..2f5886c 100644 --- a/sniffio/_impl.py +++ b/sniffio/_impl.py @@ -1,3 +1,4 @@ +from functools import partial from contextvars import ContextVar from typing import Optional import sys @@ -22,10 +23,68 @@ class AsyncLibraryNotFoundError(RuntimeError): pass +def _guessed_mode() -> str: + # special support for trio-asyncio + value = thread_local.name + if value is not None: + return value + + value = current_async_library_cvar.get() + if value is not None: + return value + + # Need to sniff for asyncio + if "asyncio" in sys.modules: + import asyncio + try: + current_task = asyncio.current_task # type: ignore[attr-defined] + except AttributeError: + current_task = asyncio.Task.current_task # type: ignore[attr-defined] + try: + if current_task() is not None: + return "asyncio" + except RuntimeError: + pass + + # Sniff for curio (for now) + if 'curio' in sys.modules: + from curio.meta import curio_running + if curio_running(): + return 'curio' + + raise AsyncLibraryNotFoundError( + "unknown async library, or not in async context" + ) + + +def _noop_hook(v: str) -> str: + return v + + +_NO_HOOK = object() + +# this is publicly mutable, if an async framework wants to implement complex +# async gen hook behaviour it can set +# sniffio.hooks[__package__] = detect_me. As long as it does so before +# defining its async gen finalizer function it is free from race conditions +hooks = { + # could be trio-asyncio or trio-guest mode + # once trio and trio-asyncio and sniffio align trio should set + # sniffio.hooks['trio'] = detect_trio + "trio": _guessed_mode, + # pre-cache some well-known well behaved asyncgen_finalizer modules + # and so it saves a trip around _is_asyncio(finalizer) when we + # know asyncio is asyncio and curio is curio + "asyncio.base_events": partial(_noop_hook, "asyncio"), + "curio.meta": partial(_noop_hook, "curio"), + _NO_HOOK: _guessed_mode, # no hooks installed, fallback +} + + def current_async_library() -> str: """Detect which async library is currently running. - The following libraries are currently supported: + The following libraries are currently special-cased: ================ =========== ============================ Library Requires Magic string @@ -63,33 +122,38 @@ async def generic_sleep(seconds): raise RuntimeError(f"Unsupported library {library!r}") """ - value = thread_local.name - if value is not None: - return value - - value = current_async_library_cvar.get() - if value is not None: - return value - - # Need to sniff for asyncio + finalizer = sys.get_asyncgen_hooks().finalizer + finalizer_module = getattr(finalizer, "__module__", _NO_HOOK) + if finalizer_module is None: # finalizer is old cython function + if "uvloop" in sys.modules and _is_asyncio(finalizer): + return "asyncio" + + try: + hook = hooks[finalizer_module] + except KeyError: + pass + else: + return hook() + + # special case asyncio - when implementing an asyncio event loop + # you have to implement _asyncgen_finalizer_hook in your own module + if _is_asyncio(finalizer): # eg qasync _SelectorEventLoop + hooks[finalizer_module] = partial(_noop_hook, "asyncio") + return "asyncio" + + # when implementing a twisted reactor you'd need to rely on hooks defined in + # twisted.internet.defer + assert type(finalizer_module) is str + sniffio_name = finalizer_module.rpartition(".")[0] + hooks[finalizer_module] = partial(_noop_hook, sniffio_name) + return sniffio_name + + +def _is_asyncio(finalizer): if "asyncio" in sys.modules: import asyncio try: - current_task = asyncio.current_task # type: ignore[attr-defined] - except AttributeError: - current_task = asyncio.Task.current_task # type: ignore[attr-defined] - try: - if current_task() is not None: - return "asyncio" + return finalizer == asyncio.get_running_loop()._asyncgen_finalizer_hook except RuntimeError: - pass - - # Sniff for curio (for now) - if 'curio' in sys.modules: - from curio.meta import curio_running - if curio_running(): - return 'curio' - - raise AsyncLibraryNotFoundError( - "unknown async library, or not in async context" - ) + return False + return False diff --git a/sniffio/_tests/test_sniffio.py b/sniffio/_tests/test_sniffio.py index 984c8c0..0192334 100644 --- a/sniffio/_tests/test_sniffio.py +++ b/sniffio/_tests/test_sniffio.py @@ -58,6 +58,28 @@ async def this_is_asyncio(): current_async_library() +def test_uvloop(): + import uvloop + + with pytest.raises(AsyncLibraryNotFoundError): + current_async_library() + + ran = [] + + async def this_is_asyncio(): + assert current_async_library() == "asyncio" + # Call it a second time to exercise the caching logic + assert current_async_library() == "asyncio" + ran.append(True) + + loop = uvloop.new_event_loop() + loop.run_until_complete(this_is_asyncio()) + assert ran == [True] + + with pytest.raises(AsyncLibraryNotFoundError): + current_async_library() + + # https://github.com/dabeaz/curio/pull/354 @pytest.mark.skipif( os.name == "nt" and sys.version_info >= (3, 9), @@ -82,3 +104,48 @@ async def this_is_curio(): with pytest.raises(AsyncLibraryNotFoundError): current_async_library() + + +def test_asyncio_in_curio(): + import curio + import asyncio + + async def this_is_asyncio(): + return current_async_library() + + async def this_is_curio(): + return current_async_library(), asyncio.run(this_is_asyncio()) + + assert curio.run(this_is_curio) == ("curio", "asyncio") + + +def test_curio_in_asyncio(): + import asyncio + import curio + + async def this_is_curio(): + return current_async_library() + + async def this_is_asyncio(): + return current_async_library(), curio.run(this_is_curio) + + assert asyncio.run(this_is_asyncio()) == ("asyncio", "curio") + + + +@pytest.mark.skipif(sys.version_info < (3, 9), reason='to_thread requires 3.9') +def test_curio_in_asyncio_to_thread(): + import curio + import sniffio + import asyncio + + async def current_framework(): + return sniffio.current_async_library() + + + async def amain(): + sniffio.current_async_library() + return await asyncio.to_thread(curio.run, current_framework) + + + assert asyncio.run(amain()) == "curio" diff --git a/test-requirements.txt b/test-requirements.txt index 6742196..67eed4a 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,3 +1,4 @@ pytest pytest-cov curio +uvloop