diff --git a/changes/58.feature.md b/changes/58.feature.md new file mode 100644 index 0000000..87cbbc6 --- /dev/null +++ b/changes/58.feature.md @@ -0,0 +1 @@ +Add `TaskContext` and `TaskScope` to reuse the core concepts and logics for structured concurrency in `Supervisor` and other coroutine aggregation utilities like `as_completed_safe()`, `gather_safe()`, and `race()`. diff --git a/docs/aiotools.taskcontext.rst b/docs/aiotools.taskcontext.rst new file mode 100644 index 0000000..3e6450d --- /dev/null +++ b/docs/aiotools.taskcontext.rst @@ -0,0 +1,6 @@ +Task Context +============ + +.. automodule:: aiotools.taskcontext + :members: + :undoc-members: diff --git a/docs/aiotools.taskscope.rst b/docs/aiotools.taskscope.rst new file mode 100644 index 0000000..a5f8a62 --- /dev/null +++ b/docs/aiotools.taskscope.rst @@ -0,0 +1,6 @@ +Task Scope +========== + +.. automodule:: aiotools.taskscope + :members: + :undoc-members: diff --git a/docs/index.rst b/docs/index.rst index fc63071..67b4f0e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,6 +17,8 @@ aiotools Documentation aiotools.func aiotools.iter aiotools.server + aiotools.taskcontext + aiotools.taskscope aiotools.supervisor aiotools.taskgroup aiotools.timer diff --git a/src/aiotools/__init__.py b/src/aiotools/__init__.py index 79e9b02..9b55395 100644 --- a/src/aiotools/__init__.py +++ b/src/aiotools/__init__.py @@ -1,7 +1,18 @@ from importlib.metadata import version # Import submodules only when installed properly -from . import context, func, server, taskgroup, timeouts, timer, utils +from . import ( + context, + func, + server, + supervisor, + taskcontext, + taskgroup, + taskscope, + timeouts, + timer, + utils, +) from . import defer as _defer from . import fork as _fork from . import iter as _iter @@ -13,7 +24,10 @@ *func.__all__, *_iter.__all__, *server.__all__, + *taskcontext.__all__, + *taskscope.__all__, *taskgroup.__all__, + *supervisor.__all__, *timeouts.__all__, *timer.__all__, *utils.__all__, @@ -29,7 +43,10 @@ from .func import * # noqa from .iter import * # noqa from .server import * # noqa +from .supervisor import * # noqa +from .taskcontext import * # noqa from .taskgroup import * # noqa +from .taskscope import * # noqa from .timeouts import * # noqa from .timer import * # noqa from .utils import * # noqa diff --git a/src/aiotools/supervisor.py b/src/aiotools/supervisor.py index 7eb253d..56a47e7 100644 --- a/src/aiotools/supervisor.py +++ b/src/aiotools/supervisor.py @@ -1,10 +1,12 @@ -from asyncio import events, exceptions, tasks +__all__ = ["Supervisor"] + +import contextvars from typing import Optional -__all__ = ["Supervisor"] +from .taskscope import TaskScope -class Supervisor: +class Supervisor(TaskScope): """ Supervisor is a primitive structure to provide a long-lived context manager scope for an indefinite set of subtasks. During its lifetime, it is free to spawn new @@ -20,208 +22,25 @@ class Supervisor: from its subtasks. Instead, the callers must use additional task-done callbacks to process subtask results and exceptions. - Supervisor provides the same analogy to Kotlin's ``SupervisorScope`` and - Javascript's ``Promise.allSettled()``, while :class:`asyncio.TaskGroup` provides - the same analogy to Kotlin's ``CoroutineScope`` and Javascript's - ``Promise.all()``. + Supervisor provides the same analogy to Kotlin's |SupervisorScope|_ and + Javascript's |Promise.allSettled()|_, while :class:`asyncio.TaskGroup` provides + the same analogy to Kotlin's |CoroutineScope|_ and Javascript's |Promise.all()|_. + + .. |SupervisorScope| replace:: ``SupervisorScope`` + .. |CoroutineScope| replace:: ``CoroutineScope`` + .. _SupervisorScope: https://kotlinlang.org/api/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/supervisor-scope.html + .. _CoroutineScope: https://kotlinlang.org/api/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-coroutine-scope/ + .. |Promise.allSettled()| replace:: ``Promise.allSettled()`` + .. |Promise.all()| replace:: ``Promise.all()`` + .. _Promise.allSettled(): https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Promise/allSettled + .. _Promise.all(): https://developer.mozilla.org/ko/docs/Web/JavaScript/Reference/Global_Objects/Promise/all The original implementation is based on DontPanicO's pull request (https://github.com/achimnol/cpython/pull/31) and :class:`PersistentTaskGroup`, but it is modified *not* to store unhandled subtask exceptions. .. versionadded:: 2.0 - """ - - def __init__(self): - self._entered = False - self._exiting = False - self._aborting = False - self._loop = None - self._parent_task = None - self._parent_cancel_requested = False - self._tasks = set() - self._base_error = None - self._on_completed_fut = None - - def __repr__(self): - info = [""] - if self._tasks: - info.append(f"tasks={len(self._tasks)}") - if self._aborting: - info.append("cancelling") - elif self._entered: - info.append("entered") - - info_str = " ".join(info) - return f"" - - async def __aenter__(self): - if self._entered: - raise RuntimeError(f"Supervisor {self!r} has been already entered") - self._entered = True - - if self._loop is None: - self._loop = events.get_running_loop() - - self._parent_task = tasks.current_task(self._loop) - if self._parent_task is None: - raise RuntimeError(f"Supervisor {self!r} cannot determine the parent task") - - return self - - async def __aexit__(self, et, exc, tb): - self._exiting = True - - if exc is not None and self._is_base_error(exc) and self._base_error is None: - # SystemExit or KeyboardInterrupt in "async with" - # so we cancel other tasks. - self._base_error = exc - self._abort() - - propagate_cancellation_error = exc if et is exceptions.CancelledError else None - if self._parent_cancel_requested: - assert self._parent_task is not None - # If this flag is set we *must* call uncancel(). - if self._parent_task.uncancel() == 0: - # If there are no pending cancellations left, - # don't propagate CancelledError. - propagate_cancellation_error = None - - prop_ex = await self._wait_completion() - assert not self._tasks - if prop_ex is not None: - propagate_cancellation_error = prop_ex - - if self._base_error is not None: - raise self._base_error - - # Propagate CancelledError if there is one, except if there - # are other errors -- those have priority. - if propagate_cancellation_error: - raise propagate_cancellation_error - - # In the original version, it raises BaseExceptionGroup - # if there are collected errors in self._errors. - # This part is deliberately removed to prevent memory leak - # due to accumulating error objects for an indefinite length of time. - - def create_task(self, coro, *, name=None, context=None): - if not self._entered: - raise RuntimeError(f"Supervisor {self!r} has not been entered") - if self._exiting and not self._tasks: - raise RuntimeError(f"Supervisor {self!r} is finished") - if self._aborting: - raise RuntimeError(f"Supervisor {self!r} is shutting down") - if context is None: - task = self._loop.create_task(coro) - else: - task = self._loop.create_task(coro, context=context) - tasks._set_task_name(task, name) - task.add_done_callback(self._on_task_done) - self._tasks.add(task) - return task - - # Since Python 3.8 Tasks propagate all exceptions correctly, - # except for KeyboardInterrupt and SystemExit which are - # still considered special. - - def _is_base_error(self, exc: BaseException) -> bool: - assert isinstance(exc, BaseException) - return isinstance(exc, (SystemExit, KeyboardInterrupt)) - - def _abort(self, msg: Optional[str] = None): - self._aborting = True - - for t in self._tasks: - if not t.done(): - t.cancel(msg=msg) - - async def _wait_completion(self): - # We use while-loop here because "self._on_completed_fut" - # can be cancelled multiple times if our parent task - # is being cancelled repeatedly (or even once, when - # our own cancellation is already in progress) - propagate_cancellation_error = None - while self._tasks: - if self._on_completed_fut is None: - self._on_completed_fut = self._loop.create_future() - - try: - await self._on_completed_fut - except exceptions.CancelledError as ex: - if not self._aborting: - # Our parent task is being cancelled: - # - # async def wrapper(): - # async with TaskGroup() as g: - # g.create_task(foo) - # - # "wrapper" is being cancelled while "foo" is - # still running. - propagate_cancellation_error = ex - self._abort(msg=ex.args[0] if ex.args else None) - self._on_completed_fut = None - - return propagate_cancellation_error - - async def shutdown(self) -> None: - self._abort(msg="supervisor.shutdown") - await self._wait_completion() - - def _on_task_done(self, task): - self._tasks.discard(task) - - if self._on_completed_fut is not None and not self._tasks: - if not self._on_completed_fut.done(): - self._on_completed_fut.set_result(True) - - if task.cancelled(): - return - - exc = task.exception() - if exc is None: - return - - _is_base_error = self._is_base_error(exc) - if _is_base_error and self._base_error is None: - self._base_error = exc - - assert self._parent_task is not None - if self._parent_task.done(): - # Not sure if this case is possible, but we want to handle - # it anyways. - self._loop.call_exception_handler( - { - "message": ( - f"Task {task!r} has errored out but its parent " - f"task {self._parent_task} is already completed" - ), - "exception": exc, - "task": task, - } - ) - return + """ # noqa: E501 - if _is_base_error and not self._aborting and not self._parent_cancel_requested: - # For base SystemExit and KeyboardInterrupt ONLY, if parent task - # *is not* being cancelled, it means that we want to manually cancel - # it to abort whatever is being run right now in the Supervisor. - # But we want to mark parent task as "not cancelled" later in __aexit__. - # Example situation that we need to handle: - # - # async def foo(): - # try: - # async with Supervisor() as s: - # s.create_task(crash_soon()) - # await something # <- this needs to be waited - # # by the Supervisor, unless - # # crash_soon() raises either - # # SystemExit or KeyboardInterrupt - # except Exception: - # # Ignore any exceptions raised in the Supervisor - # pass - # await something_else # this line has to be called - # # after TaskGroup is finished. - self._abort() - self._parent_cancel_requested = True - self._parent_task.cancel() + def __init__(self, context: Optional[contextvars.Context] = None) -> None: + super().__init__(delegate_errors=None, context=context) diff --git a/src/aiotools/taskcontext.py b/src/aiotools/taskcontext.py new file mode 100644 index 0000000..82f9f2b --- /dev/null +++ b/src/aiotools/taskcontext.py @@ -0,0 +1,205 @@ +__all__ = ["ErrorArg", "ErrorCallback", "TaskContext"] + +import asyncio +import contextvars +import enum +import warnings +from contextvars import Context +from typing import ( + Any, + Callable, + Coroutine, + Generator, + Optional, + TypeAlias, + TypedDict, + TypeVar, +) + + +class DefaultErrorHandler(enum.Enum): + TOKEN = 0 + + +class ErrorArg(TypedDict): + message: str + exception: BaseException + task: asyncio.Task[Any] + + +ErrorCallback: TypeAlias = Callable[[ErrorArg], None] +T = TypeVar("T") + + +class TaskContext: + """ + TaskContext keeps the references to the child tasks during its lifetime, + so that they can be terminated safely when shutdown is explicitly requested. + + This is the loosest form of child task managers among TaskScope, TaskGroup, and + Supervisor, as it does not enforce structured concurrency but just provides a + reference set to child tasks. + + You may replace existing patterns using :class:`weakref.WeakSet` to keep track + of child tasks for a long-running server application with TaskContext. + + If ``delegate_errors`` is not set (the default behavior), it will run + :meth:`loop.call_exception_handler() ` + with the context argument consisting of the ``message``, ``task`` (the child task + that raised the exception), and ``exception`` (the exception object) fields. + + If it is set *None*, it will silently ignore the exception. + + If it is set as a callable function, it will invoke the specified callback + function using the context argument same to that used when calling + :meth:`loop.call_exception_handler() `. + + If you provide ``context``, it will be passed to :meth:`create_task()` by default. + + .. versionadded:: 2.1 + """ + + _tasks: set[asyncio.Task[Any]] + _parent_task: Optional[asyncio.Task[Any]] + _loop: Optional[asyncio.AbstractEventLoop] + + def __init__( + self, + delegate_errors: Optional[ + ErrorCallback | DefaultErrorHandler + ] = DefaultErrorHandler.TOKEN, + context: Optional[contextvars.Context] = None, + ) -> None: + self._loop = None + self._tasks = set() + self._parent_task = None + self._delegate_errors = delegate_errors + self._default_context = context + # status flags + self._entered = False + self._exited = False + self._aborting = False + + def __del__(self) -> None: + loc: str = "" # TODO: implement + if self._entered and not self._exited: + warnings.warn( + f"TaskContext initialized at {loc} is not properly " + "terminated until it is garbage-collected.", + category=ResourceWarning, + ) + + def __repr__(self) -> str: + info = [""] + if self._tasks: + info.append(f"tasks={len(self._tasks)}") + if self._aborting: + info.append("cancelling") + elif self._entered: + info.append("entered") + + info_str = " ".join(info) + return f"<{type(self).__name__} {info_str}>" + + async def shutdown(self) -> None: + """ + Triggers cancellation and waits for completion. + """ + self._aborting = True + try: + for t in {*self._tasks}: + if not t.done(): + t.cancel() + try: + await t + except asyncio.CancelledError: + pass + finally: + self._exited = True + + def create_task( + self, + coro: Generator[None, None, T] | Coroutine[Any, None, T], + *, + name: Optional[str] = None, + context: Optional[Context] = None, + ) -> asyncio.Task[T]: + """ + Create a new task in this scope and return it. + Similar to :func:`asyncio.create_task()`. + """ + if not self._entered: + self._entered = True + self._loop = asyncio.get_running_loop() + parent_task = asyncio.current_task() + assert parent_task is not None + if self._parent_task is not None and parent_task is not self._parent_task: + raise RuntimeError( + "{type(self).__name__} must be used within a single parent task." + ) + self._parent_task = parent_task + return self._create_task(coro, name=name, context=context) + + def _create_task( + self, + coro: Generator[None, None, T] | Coroutine[Any, None, T], + *, + name: Optional[str] = None, + context: Optional[Context] = None, + ) -> asyncio.Task[T]: + assert self._loop is not None + if self._aborting: + raise RuntimeError(f"{type(self).__name__} {self!r} is shutting down") + task: asyncio.Task[T] = self._loop.create_task( + coro, + name=name, + context=self._default_context if context is None else context, + ) + # optimization: Immediately call the done callback if the task is + # already done (e.g. if the coro was able to complete eagerly), + # and skip scheduling a done callback + if task.done(): + self._on_task_done(task) + else: + self._tasks.add(task) + task.add_done_callback(self._on_task_done) + return task + + def _on_task_done(self, task: asyncio.Task[Any]) -> None: + self._tasks.discard(task) + if task.cancelled(): + return + exc = task.exception() + if exc is None: + return + self._handle_task_exception(task) + + def _handle_task_exception(self, task: asyncio.Task[Any]) -> None: + assert self._loop is not None + exc = task.exception() + assert exc is not None + match self._delegate_errors: + case None: + pass # deliberately set to ignore errors + case func if callable(func): + func( + { + "message": ( + f"Task {task!r} has errored inside the parent " + f"task {self._parent_task}" + ), + "exception": exc, + "task": task, + } + ) + case DefaultErrorHandler(): + self._loop.call_exception_handler( + { + "message": ( + f"Task {task!r} has errored inside the parent " + f"task {self._parent_task}" + ), + "exception": exc, + "task": task, + } + ) diff --git a/src/aiotools/taskscope.py b/src/aiotools/taskscope.py new file mode 100644 index 0000000..cef7445 --- /dev/null +++ b/src/aiotools/taskscope.py @@ -0,0 +1,258 @@ +__all__ = ("TaskScope",) + +import asyncio +import contextvars +from asyncio import events, exceptions, tasks +from contextvars import Context +from typing import ( + Any, + Coroutine, + Generator, + Optional, + Self, + TypeGuard, + TypeVar, +) + +from .taskcontext import DefaultErrorHandler, ErrorCallback, TaskContext + +T = TypeVar("T") + + +class TaskScope(TaskContext): + """ + TaskScope is an asynchronous context manager which implements structured + concurrency, i.e., "scoped" cancellation, over a set of child tasks. + It terminates when all child tasks make conclusion (either results or exceptions). + + TaskScope subclasses TaskContext, but it mandates use of ``async with`` blocks + to clarify which task is the parent of the child tasks spawned via + :meth:`create_task()`. + + The key difference to :class:`asyncio.TaskGroup` is that it allows + customization of the exception handling logic for unhandled child + task exceptions, instead cancelling all pending child tasks upon any + unhandled child task exceptions. + + Refer :class:`TaskContext` for the descriptions about the constructor arguments. + + Based on this customizability, :class:`Supervisor` is a mere alias of TaskScope + with ``delegate_errors=None``. + + .. versionadded:: 2.1 + """ + + _tasks: set[asyncio.Task[Any]] + _on_completed_fut: Optional[asyncio.Future] + _base_error: Optional[BaseException] + + def __init__( + self, + delegate_errors: Optional[ + ErrorCallback | DefaultErrorHandler + ] = DefaultErrorHandler.TOKEN, + context: Optional[contextvars.Context] = None, + ) -> None: + super().__init__(delegate_errors=delegate_errors, context=context) + # status flags + self._entered = False + self._exiting = False + self._aborting = False + self._parent_cancel_requested = False + # taskscope-specifics + self._base_error = None + self._on_completed_fut = None + self._has_errors = False + + async def __aenter__(self) -> Self: + if self._entered: + raise RuntimeError( + f"{type(self).__name__} {self!r} has been already entered" + ) + self._entered = True + + if self._loop is None: + self._loop = events.get_running_loop() + + self._parent_task = tasks.current_task(self._loop) + if self._parent_task is None: + raise RuntimeError( + f"{type(self).__name__} {self!r} cannot determine the parent task" + ) + + return self + + async def __aexit__(self, et, exc, tb) -> Optional[bool]: + assert self._loop is not None + assert self._parent_task is not None + self._exiting = True + + if exc is not None and self._is_base_error(exc) and self._base_error is None: + self._base_error = exc + + propagate_cancellation_error = exc if et is exceptions.CancelledError else None + if self._parent_cancel_requested: + # If this flag is set we *must* call uncancel(). + if self._parent_task.uncancel() == 0: + # If there are no pending cancellations left, + # don't propagate CancelledError. + propagate_cancellation_error = None + + if et is not None: + if not self._aborting: + # Our parent task is being cancelled: + # + # async with TaskGroup() as g: + # g.create_task(...) + # await ... # <- CancelledError + # + # or there's an exception in "async with": + # + # async with TaskGroup() as g: + # g.create_task(...) + # 1 / 0 + # + self.abort() + + prop_ex = await self._wait_completion() + assert not self._tasks + if prop_ex is not None: + propagate_cancellation_error = prop_ex + self._exited = True + + if self._base_error is not None: + raise self._base_error + + # Propagate CancelledError as the child exceptions are handled separately. + if propagate_cancellation_error: + raise propagate_cancellation_error + + if et is not None and et is not exceptions.CancelledError: + self._has_errors = True + return None + + async def _wait_completion(self): + # We use while-loop here because "self._on_completed_fut" + # can be cancelled multiple times if our parent task + # is being cancelled repeatedly (or even once, when + # our own cancellation is already in progress) + assert self._loop is not None + propagate_cancellation_error = None + while self._tasks: + if self._on_completed_fut is None: + self._on_completed_fut = self._loop.create_future() + try: + await self._on_completed_fut + except exceptions.CancelledError as ex: + if not self._aborting: + # Our parent task is being cancelled: + # + # async def wrapper(): + # async with TaskGroup() as g: + # g.create_task(foo) + # + # "wrapper" is being cancelled while "foo" is + # still running. + propagate_cancellation_error = ex + self.abort(msg=ex.args[0] if ex.args else None) + self._on_completed_fut = None + return propagate_cancellation_error + + async def shutdown(self) -> None: + """ + Triggers cancellation and waits for completion. + """ + self.abort(r"{self!r} is shutdown") + await self._wait_completion() + + def create_task( + self, + coro: Generator[None, None, T] | Coroutine[Any, None, T], + *, + name: Optional[str] = None, + context: Optional[Context] = None, + ) -> tasks.Task[T]: + """ + Create a new task in this scope and return it. + Similar to :func:`asyncio.create_task()`. + """ + if not self._entered: + raise RuntimeError(f"{type(self).__name__} {self!r} has not been entered") + if self._exiting and not self._tasks: + raise RuntimeError(f"{type(self).__name__} {self!r} is finished") + return self._create_task(coro, name=name, context=context) + + # Since Python 3.8 Tasks propagate all exceptions correctly, + # except for KeyboardInterrupt and SystemExit which are + # still considered special. + + def _is_base_error(self, exc: BaseException) -> TypeGuard[BaseException]: + assert isinstance(exc, BaseException) + return isinstance(exc, (SystemExit, KeyboardInterrupt)) + + def abort(self, msg: Optional[str] = None) -> None: + """ + Triggers cancellation and immediately returns *without* waiting for completion. + """ + self._aborting = True + for t in self._tasks: + if not t.done(): + t.cancel(msg=msg) + + def _on_task_done(self, task: asyncio.Task[Any]) -> None: + assert self._loop is not None + assert self._parent_task is not None + self._tasks.discard(task) + if self._on_completed_fut is not None and not self._tasks: + if not self._on_completed_fut.done(): + self._on_completed_fut.set_result(True) + if task.cancelled(): + return + exc = task.exception() + if exc is None: + return + + self._has_errors = True + self._handle_task_exception(task) + + is_base_error = self._is_base_error(exc) + if is_base_error and self._base_error is None: + self._base_error = exc + + if self._parent_task.done(): + # Not sure if this case is possible, but we want to handle + # it anyways. + self._loop.call_exception_handler( + { + "message": ( + f"Task {task!r} has errored out but its parent " + f"task {self._parent_task} is already completed" + ), + "exception": exc, + "task": task, + } + ) + return + + if is_base_error: + # If parent task *is not* being cancelled, it means that we want + # to manually cancel it to abort whatever is being run right now + # in the TaskGroup. But we want to mark parent task as + # "not cancelled" later in __aexit__. Example situation that + # we need to handle: + # + # async def foo(): + # try: + # async with TaskGroup() as g: + # g.create_task(crash_soon()) + # await something # <- this needs to be canceled + # # by the TaskGroup, e.g. + # # foo() needs to be cancelled + # except Exception: + # # Ignore any exceptions raised in the TaskGroup + # pass + # await something_else # this line has to be called + # # after TaskGroup is finished. + self.abort() + self._parent_cancel_requested = True + self._parent_task.cancel() diff --git a/src/aiotools/utils.py b/src/aiotools/utils.py index 469269b..1f37eaf 100644 --- a/src/aiotools/utils.py +++ b/src/aiotools/utils.py @@ -11,6 +11,8 @@ Any, AsyncGenerator, Awaitable, + Coroutine, + Generator, Iterable, List, Optional, @@ -22,6 +24,7 @@ from .supervisor import Supervisor __all__ = ( + "cancel_and_wait", "as_completed_safe", "gather_safe", "race", @@ -30,8 +33,28 @@ T = TypeVar("T") +async def cancel_and_wait( + task: asyncio.Task[Any], + msg: str | None = None, +) -> None: + """ + See the discussion in https://github.com/python/cpython/issues/103486 + """ + task.cancel(msg) + try: + await task + except asyncio.CancelledError: + parent_task = asyncio.current_task() + if parent_task is not None and parent_task.cancelling() == 0: + raise + else: + return # this is the only non-exceptional return + else: + raise RuntimeError("Cancelled task did not end with an exception") + + async def as_completed_safe( - coros: Iterable[Awaitable[T]], + coros: Iterable[Coroutine[Any, None, T] | Generator[None, None, T]], *, context: Optional[Context] = None, ) -> AsyncGenerator[Awaitable[T], None]: @@ -54,9 +77,9 @@ async def as_completed_safe( def result_callback(t: asyncio.Task[Any]) -> None: q.put_nowait(t) - async with Supervisor() as supervisor: + async with Supervisor(context=context) as supervisor: for coro in coros: - t = supervisor.create_task(coro, context=context) + t = supervisor.create_task(coro) t.add_done_callback(result_callback) remaining += 1 while remaining: @@ -78,7 +101,7 @@ def result_callback(t: asyncio.Task[Any]) -> None: async def gather_safe( - coros: Iterable[Awaitable[T]], + coros: Iterable[Coroutine[Any, None, T] | Generator[None, None, T]], *, context: Optional[Context] = None, ) -> List[T | Exception]: @@ -97,16 +120,17 @@ async def gather_safe( .. versionadded:: 2.0 """ tasks = [] - async with Supervisor() as supervisor: + t: asyncio.Task[T] + async with Supervisor(context=context) as supervisor: for coro in coros: - t = supervisor.create_task(coro, context=context) + t = supervisor.create_task(coro) tasks.append(t) # To ensure safety, the Python version must be 3.7 or higher. return await asyncio.gather(*tasks, return_exceptions=True) async def race( - coros: Iterable[Awaitable[T]], + coros: Iterable[Coroutine[Any, None, T] | Generator[None, None, T]], *, continue_on_error: bool = False, context: Optional[Context] = None, diff --git a/tests/test_supervisor.py b/tests/test_supervisor.py index ad1e8b8..f351ae0 100644 --- a/tests/test_supervisor.py +++ b/tests/test_supervisor.py @@ -40,7 +40,35 @@ async def test_supervisor_partial_failure(): @pytest.mark.asyncio -async def test_supervisor_timeout(): +async def test_supervisor_timeout_before_failure(): + results = [] + errors = [] + cancelled = 0 + tasks = [] + with VirtualClock().patch_loop(): + with pytest.raises(TimeoutError): + async with ( + asyncio.timeout(0.15), + Supervisor() as supervisor, + ): + tasks.append(supervisor.create_task(do_job(0.1, 1))) + # timeout here + tasks.append(supervisor.create_task(fail_job(0.2))) + tasks.append(supervisor.create_task(do_job(0.3, 3))) + for t in tasks: + try: + results.append(t.result()) + except asyncio.CancelledError: + cancelled += 1 + except Exception as e: + errors.append(e) + assert results == [1] + assert len(errors) == 0 + assert cancelled == 2 + + +@pytest.mark.asyncio +async def test_supervisor_timeout_after_failure(): results = [] errors = [] cancelled = 0 @@ -53,6 +81,7 @@ async def test_supervisor_timeout(): ): tasks.append(supervisor.create_task(do_job(0.1, 1))) tasks.append(supervisor.create_task(fail_job(0.2))) + # timeout here tasks.append(supervisor.create_task(do_job(0.3, 3))) for t in tasks: try: