diff --git a/docs/source/examples/single_use_workers.py b/docs/source/examples/single_use_workers.py index 6b8d53c6..5a379073 100644 --- a/docs/source/examples/single_use_workers.py +++ b/docs/source/examples/single_use_workers.py @@ -31,10 +31,10 @@ async def amain(): nursery.start_soon(ctx.run_sync, worker, i) print("dual use worker behavior:") - async with trio_parallel.open_worker_context(retire=after_dual_use) as ctx: + async with trio_parallel.cache_scope(retire=after_dual_use): async with trio.open_nursery() as nursery: for i in range(10): - nursery.start_soon(ctx.run_sync, worker, i) + nursery.start_soon(trio_parallel.run_sync, worker, i) print("default behavior:") async with trio.open_nursery() as nursery: diff --git a/docs/source/reference.rst b/docs/source/reference.rst index 11daedb5..5f3516b0 100644 --- a/docs/source/reference.rst +++ b/docs/source/reference.rst @@ -138,6 +138,15 @@ lifetime is required in a subset of your application. .. autoclass:: WorkerContext() :members: +Alternatively, you can implicitly override the default context of :func:`run_sync` +in any subset of the task tree using `cache_scope()`. This async context manager +sets an internal TreeVar_ so that the current task and all nested subtasks operate +using an internal, isolated `WorkerContext`, without having to manually pass a +context object around. + +.. autofunction:: cache_scope + :async-with: + One typical use case for configuring workers is to set a policy for taking a worker out of service. For this, use the ``retire`` argument. This example shows how to build (trivial) stateless and stateful worker retirement policies. @@ -145,11 +154,11 @@ build (trivial) stateless and stateful worker retirement policies. .. literalinclude:: examples/single_use_workers.py A more realistic use-case might examine the worker process's memory usage (e.g. with -`psutil `_) and retire if usage is too high. +psutil_) and retire if usage is too high. If you are retiring workers frequently, like in the single-use case, a large amount -of process startup overhead will be incurred with the default worker type. If your -platform supports it, an alternate `WorkerType` might cut that overhead down. +of process startup overhead will be incurred with the default "spawn" worker type. +If your platform supports it, an alternate `WorkerType` might cut that overhead down. .. autoclass:: WorkerType() @@ -161,4 +170,6 @@ You probably won't use these... but create an issue if you do and need help! .. autofunction:: default_context_statistics .. _cloudpickle: https://github.com/cloudpipe/cloudpickle +.. _psutil: https://psutil.readthedocs.io/en/latest/ .. _service: https://github.com/richardsheridan/trio-parallel/issues/348 +.. _TreeVar: https://tricycle.readthedocs.io/en/latest/reference.html#tricycle.TreeVar diff --git a/newsfragments/455.feature.rst b/newsfragments/455.feature.rst new file mode 100644 index 00000000..73c77742 --- /dev/null +++ b/newsfragments/455.feature.rst @@ -0,0 +1,3 @@ +Add `cache_scope()`, an async context manager that can override the behavior of +`trio_parallel.run_sync()` in a subtree of your Trio tasks with an implicit +`WorkerContext`. diff --git a/pyproject.toml b/pyproject.toml index 0ce7b807..3673a48a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "attrs >= 17.3.0", "cffi; os_name == 'nt' and implementation_name != 'pypy'", "tblib", + "tricycle >= 0.3.0" ] requires-python = ">=3.7" dynamic = ["version"] diff --git a/requirements/coverage.txt b/requirements/coverage.txt index d361c9d3..2a0f418b 100644 --- a/requirements/coverage.txt +++ b/requirements/coverage.txt @@ -5,5 +5,5 @@ # # pip-compile-multi # -coverage[toml]==7.6.3 +coverage[toml]==7.6.9 # via -r requirements\coverage.in diff --git a/requirements/docs.txt b/requirements/docs.txt index 4d33e1d8..58140558 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -10,11 +10,11 @@ alabaster==1.0.0 # via sphinx babel==2.16.0 # via sphinx -certifi==2024.8.30 +certifi==2024.12.14 # via requests -charset-normalizer==3.4.0 +charset-normalizer==3.4.1 # via requests -click==8.1.7 +click==8.1.8 # via towncrier colorama==0.4.6 # via @@ -26,13 +26,13 @@ docutils==0.21.2 # sphinx-rtd-theme imagesize==1.4.1 # via sphinx -jinja2==3.1.4 +jinja2==3.1.5 # via # sphinx # towncrier markupsafe==3.0.2 # via jinja2 -packaging==24.1 +packaging==24.2 # via sphinx pygments==2.18.0 # via sphinx @@ -46,7 +46,7 @@ sphinx==8.1.3 # sphinx-rtd-theme # sphinxcontrib-jquery # sphinxcontrib-trio -sphinx-rtd-theme==3.0.1 +sphinx-rtd-theme==3.0.2 # via -r requirements\docs.in sphinxcontrib-applehelp==2.0.0 # via sphinx @@ -66,5 +66,5 @@ sphinxcontrib-trio==1.1.2 # via -r requirements\docs.in towncrier==24.8.0 # via -r requirements\docs.in -urllib3==2.2.3 +urllib3==2.3.0 # via requests diff --git a/requirements/install.in b/requirements/install.in index 734b6685..fd5e71c3 100644 --- a/requirements/install.in +++ b/requirements/install.in @@ -1,6 +1,7 @@ -# keep me in sync with setup.cfg! +# keep me in sync with pyproject.toml! trio >= 0.18.0 outcome attrs >= 17.3.0 cffi; os_name == 'nt' and implementation_name != 'pypy' -tblib \ No newline at end of file +tblib +tricycle >= 0.3.0 diff --git a/requirements/install.txt b/requirements/install.txt index 8ffed149..757ee7c2 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -1,11 +1,11 @@ -# SHA1:63e73656a123a857e40bd1da6cce5906958410ed +# SHA1:ab818e2cd7a4dab60404b5f2cc3e40669cdb1c52 # # This file is autogenerated by pip-compile-multi # To update, run: # # pip-compile-multi # -attrs==24.2.0 +attrs==24.3.0 # via # -r requirements\install.in # outcome @@ -28,5 +28,9 @@ sortedcontainers==2.4.0 # via trio tblib==3.0.0 # via -r requirements\install.in -trio==0.27.0 +tricycle==0.4.1 # via -r requirements\install.in +trio==0.27.0 + # via + # -r requirements\install.in + # tricycle diff --git a/requirements/lint.txt b/requirements/lint.txt index 7a6cb3b1..b2f31306 100644 --- a/requirements/lint.txt +++ b/requirements/lint.txt @@ -7,7 +7,7 @@ # black==24.10.0 # via -r requirements\lint.in -click==8.1.7 +click==8.1.8 # via black colorama==0.4.6 # via click @@ -15,15 +15,15 @@ flake8==7.1.1 # via # -r requirements\lint.in # flake8-async -flake8-async==24.9.5 +flake8-async==24.11.4 # via -r requirements\lint.in -libcst==1.5.0 +libcst==1.5.1 # via flake8-async mccabe==0.7.0 # via flake8 mypy-extensions==1.0.0 # via black -packaging==24.1 +packaging==24.2 # via black pathspec==0.12.1 # via black diff --git a/requirements/test.txt b/requirements/test.txt index da91ad80..2b9655e5 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -8,7 +8,7 @@ -r install.txt colorama==0.4.6 # via pytest -coverage[toml]==7.6.3 +coverage[toml]==7.6.9 # via # -r requirements\test.in # pytest-cov @@ -16,17 +16,17 @@ execnet==2.1.1 # via pytest-xdist iniconfig==2.0.0 # via pytest -packaging==24.1 +packaging==24.2 # via pytest pluggy==1.5.0 # via pytest -pytest==8.3.3 +pytest==8.3.4 # via # -r requirements\test.in # pytest-cov # pytest-trio # pytest-xdist -pytest-cov==5.0.0 +pytest-cov==6.0.0 # via -r requirements\test.in pytest-trio==0.8.0 # via -r requirements\test.in diff --git a/trio_parallel/__init__.py b/trio_parallel/__init__.py index 9b437753..a4fd4403 100644 --- a/trio_parallel/__init__.py +++ b/trio_parallel/__init__.py @@ -3,6 +3,7 @@ from ._impl import ( run_sync, open_worker_context, + cache_scope, WorkerContext, WorkerType, current_default_worker_limiter, diff --git a/trio_parallel/_impl.py b/trio_parallel/_impl.py index d052f883..b4f46c0f 100644 --- a/trio_parallel/_impl.py +++ b/trio_parallel/_impl.py @@ -9,6 +9,7 @@ from typing import Type, Callable, Any, TypeVar import attr +import tricycle import trio from ._proc import WORKER_PROC_MAP @@ -270,42 +271,44 @@ def configure_default_context( DEFAULT_CONTEXT = ctx +CACHE_SCOPE_TREEVAR = tricycle.TreeVar("tp_cache_scope") + if sys.platform == "win32": - DEFAULT_CONTEXT_RUNVAR = trio.lowlevel.RunVar("win32_ctx") + DEFAULT_CONTEXT_RUNVAR = trio.lowlevel.RunVar("tp_win32_ctx") DEFAULT_CONTEXT_PARAMS = {} - # TODO: intelligently test ki protection here such that CI fails if the - # decorators disappear - - @trio.lowlevel.enable_ki_protection def get_default_context(): + try: + return CACHE_SCOPE_TREEVAR.get() + except LookupError: + pass try: ctx = DEFAULT_CONTEXT_RUNVAR.get() except LookupError: ctx = WorkerContext._create(**DEFAULT_CONTEXT_PARAMS) - DEFAULT_CONTEXT_RUNVAR.set(ctx) - # KeyboardInterrupt here could leak the context trio.lowlevel.spawn_system_task(close_at_run_end, ctx) + # set ctx last so as not to leak on KeyboardInterrupt + DEFAULT_CONTEXT_RUNVAR.set(ctx) return ctx - @trio.lowlevel.enable_ki_protection async def close_at_run_end(ctx): try: await trio.sleep_forever() finally: - # KeyboardInterrupt here could leak the context await ctx._aclose() # noqa: ASYNC102 else: def get_default_context(): - return DEFAULT_CONTEXT + try: + return CACHE_SCOPE_TREEVAR.get() + except (LookupError, RuntimeError): + return DEFAULT_CONTEXT - def graceful_default_shutdown(ctx): + @atexit.register + def graceful_default_shutdown(ctx=DEFAULT_CONTEXT): ctx._worker_cache.shutdown(ctx.grace_period) - atexit.register(graceful_default_shutdown, DEFAULT_CONTEXT) - def default_context_statistics(): """Return the statistics corresponding to the default context. @@ -377,6 +380,66 @@ async def open_worker_context( await ctx._aclose() # noqa: ASYNC102 +@asynccontextmanager +@trio.lowlevel.enable_ki_protection +async def cache_scope( + idle_timeout=DEFAULT_CONTEXT.idle_timeout, + init=DEFAULT_CONTEXT.init, + retire=DEFAULT_CONTEXT.retire, + grace_period=DEFAULT_CONTEXT.grace_period, + worker_type=WorkerType.SPAWN, +): + """ + Override the configuration of `trio_parallel.run_sync()` in this task and all + subtasks. + + The internal `WorkerContext` is passed implicitly down the task tree and can + be overridden by nested scopes. Explicit `WorkerContext` objects from + `open_worker_context` will not be overridden. + + Args: + idle_timeout (float): The time in seconds an idle worker will + wait for a CPU-bound job before shutting down and releasing its own + resources. Pass `math.inf` to wait forever. MUST be non-negative. + init (Callable[[], bool]): + An object to call within the worker before waiting for jobs. + This is suitable for initializing worker state so that such stateful logic + does not need to be included in functions passed to + :func:`WorkerContext.run_sync`. MUST be callable without arguments. + retire (Callable[[], bool]): + An object to call within the worker after executing a CPU-bound job. + The return value indicates whether worker should be retired (shut down.) + By default, workers are never retired. + The process-global environment is stable between calls. Among other things, + that means that storing state in global variables works. + MUST be callable without arguments. + grace_period (float): The time in seconds to wait in ``__aexit__`` for workers to + exit before issuing SIGKILL/TerminateProcess and raising `BrokenWorkerError`. + Pass `math.inf` to wait forever. MUST be non-negative. + worker_type (WorkerType): The kind of worker to create, see :class:`WorkerType`. + + Raises: + ValueError | TypeError: if an invalid value is passed for an argument, such as a + negative timeout. + BrokenWorkerError: if a worker does not shut down cleanly when exiting the scope. + + .. warning:: + + The callables passed to retire MUST not raise! Doing so will result in a + :class:`BrokenWorkerError` at an indeterminate future + :func:`WorkerContext.run_sync` call. + """ + + ctx = WorkerContext._create(idle_timeout, init, retire, grace_period, worker_type) + + try: + token = CACHE_SCOPE_TREEVAR.set(ctx) + yield + finally: + CACHE_SCOPE_TREEVAR.reset(token) + await ctx._aclose() # noqa: ASYNC102 + + async def run_sync( sync_fn: Callable[..., T], *args, diff --git a/trio_parallel/_tests/test_impl.py b/trio_parallel/_tests/test_impl.py index 691049c7..c0bf2b6b 100644 --- a/trio_parallel/_tests/test_impl.py +++ b/trio_parallel/_tests/test_impl.py @@ -1,5 +1,6 @@ """ Tests of public API with mocked-out workers ("collaboration" tests)""" +import os import sys from typing import Callable, Optional @@ -185,3 +186,71 @@ async def test_worker_returning_none_can_be_cancelled(): def test_cannot_instantiate_WorkerContext(): with pytest.raises(TypeError): _impl.WorkerContext() + + +async def _assert_worker_pid(pid, matches): + comparison = pid == await run_sync(os.getpid) + assert comparison == matches + + +async def test_cache_scope_overrides_run_sync(): + pid = await run_sync(os.getpid) + + async with _impl.cache_scope(): + await _assert_worker_pid(pid, False) + + +async def test_cache_scope_overrides_nursery_task(): + pid = await run_sync(os.getpid) + + async def check_both_sides_of_task_status_started(pid, task_status): + await _assert_worker_pid(pid, True) + task_status.started() + await _assert_worker_pid(pid, True) + + async with trio.open_nursery() as nursery: + async with _impl.cache_scope(): + nursery.start_soon(_assert_worker_pid, pid, True) + + async with trio.open_nursery() as nursery: + async with _impl.cache_scope(): + await nursery.start(check_both_sides_of_task_status_started, pid) + + +async def test_cache_scope_follows_task_tree_discipline(): + shared_nursery: Optional[trio.Nursery] = None + + async def make_a_cache_scope_around_nursery(task_status): + nonlocal shared_nursery + async with _impl.cache_scope(), trio.open_nursery() as shared_nursery: + await _assert_worker_pid(pid, False) + task_status.started() + await e.wait() + await _assert_worker_pid(pid, True) + + async def assert_elsewhere_in_task_tree(): + await _assert_worker_pid(pid, False) + e.set() + + pid = await run_sync(os.getpid) + e = trio.Event() + async with trio.open_nursery() as nursery: + await nursery.start(make_a_cache_scope_around_nursery) + # this line tests the main difference from contextvars vs treevars + shared_nursery.start_soon(assert_elsewhere_in_task_tree) + + +async def test_cache_scope_overrides_nested(): + pid1 = await run_sync(os.getpid) + async with _impl.cache_scope(): + pid2 = await run_sync(os.getpid) + async with _impl.cache_scope(): + await _assert_worker_pid(pid1, False) + await _assert_worker_pid(pid2, False) + + +async def test_cache_scope_doesnt_override_explicit_context(): + async with _impl.open_worker_context() as ctx: + pid = await ctx.run_sync(os.getpid) + async with _impl.cache_scope(): + assert pid == await ctx.run_sync(os.getpid)