Optionally prevent child tasks from being cancelled in asyncio.TaskGroup #101581


DontPanicO opened this issue Feb 5, 2023 · 11 comments
Labels
stdlib Python modules in the Lib dir topic-asyncio type-feature A feature request or enhancement


@DontPanicO

DontPanicO commented Feb 5, 2023

Feature or enhancement

Add a flag to asyncio.TaskGroup to control whether, if a child task crashes, the other child tasks should be cancelled or not.

Pitch

Currently asyncio.TaskGroup always cancels all child tasks if one fails. Even though that's the most common use case, I'd like to be able to switch off this behaviour, so that one failure does not cancel the remaining child tasks, and an ExceptionGroup with all raised exceptions is raised only after the other tasks have completed.

  • Exceptions raised in the async with block should still cause child tasks to be cancelled.
  • SystemExit and KeyboardInterrupt raised in a child task still cancel other tasks.

Example usage:

import asyncio

async def zero_div():
    1 / 0

async def wait_some(d):
    await asyncio.sleep(d)
    print(f'finished after {d}')

async def main():
    async with asyncio.TaskGroup(abort_on_first_exception=False) as tg:
        tg.create_task(wait_some(2))
        tg.create_task(wait_some(3))
        tg.create_task(zero_div())

asyncio.run(main())
# Should print out:
# finished after 2
# finished after 3
# and then raise an `ExceptionGroup` with a `ZeroDivisionError`

Looking at the asyncio.TaskGroup source code, it seems this could be achieved by adding the abort_on_first_exception flag (or whatever it ends up being named) to the __init__ method and then modifying _on_task_done as follows:

class TaskGroup:

    def __init__(self, abort_on_first_exception=True):
        ...
        self._abort_on_first_exception = abort_on_first_exception

    ...

    def _on_task_done(self, task):
        ...
        self._errors.append(exc)
        is_base_error = self._is_base_error(exc)
        if is_base_error and self._base_error is None:
            self._base_error = exc

        if self._parent_task.done():
            # Not sure if this case is possible, but we want to handle
            # it anyways.
            self._loop.call_exception_handler({
                'message': f'Task {task!r} has errored out but its parent '
                           f'task {self._parent_task} is already completed',
                'exception': exc,
                'task': task,
            })
            return

        if not self._aborting and not self._parent_cancel_requested:
            # comment skipped for brevity
            if is_base_error or self._abort_on_first_exception:
                self._abort()
                self._parent_cancel_requested = True
                self._parent_task.cancel()

If it's reasonable to have it in asyncio, I'll be glad to submit a PR for it.

Previous discussion

Post on discuss

@DontPanicO DontPanicO added the type-feature A feature request or enhancement label Feb 5, 2023
@github-project-automation github-project-automation bot moved this to Todo in asyncio Feb 5, 2023
@gvanrossum
Member

Yeah, I think this is reasonable. We need to bikeshed on the name of the flag, and I would like to ensure that when the awaiting task (i.e., the TaskGroup itself) is cancelled, all awaited tasks (i.e., all running tasks that were created by the TaskGroup) are cancelled. There are tests for this; look for r.cancel() in test_taskgroups.py. We should probably clone some of these and run them with the new flag as well.

@DontPanicO
Author

We need to bikeshed on the name of the flag

Yes, it's quite hard to figure out a name that's descriptive enough but reasonably short. Maybe abort_on_child_exception?

I would like to ensure that when the awaiting task (i.e., the TaskGroup itself) is cancelled, all awaited tasks (i.e., all running tasks that were created by the TaskGroup) are cancelled. There are tests for this, look for r.cancel() in test_taskgroups.py, we should probably clone some of these and run them with the new flag as well.

In the local clone of my fork I've already defined new tests for the intended new behaviour, and I've copied the ones you pointed out (setting the new flag to prevent child task cancellation).

Should I open a PR now, or would it be better to wait until the flag name has been settled?

@gvanrossum
Member

abort_on_child_exception still feels too long. Maybe shield_errors=True (default False)? Or propagate_errors=False (default True)? Or spread_errors=True (default False)? (You see a pattern emerging: I'd like a verb and a noun, and the noun should probably be "errors". :-)

@DontPanicO
Author

DontPanicO commented Feb 6, 2023

I consider shield_errors (default False) to be the closest. propagate_errors may be misleading since exceptions still propagate after other child tasks have been completed.
I'd like to follow this one, if there are no objections to shield_errors or better proposals.

@arhadthedev
Member

Maybe postpone_errors (default False)?

@gvanrossum
Member

Oh! defer_errors=True (default False).

@DontPanicO
Author

Oh! defer_errors=True (default False).

Fine with me. We all agree on defer_errors=True (default False)?

@kumaraditya303
Contributor

Oh! defer_errors=True (default False).

SGTM

@gvanrossum
Member

So, since there is still more discussion on Discourse, I'd like to get a review from @achimnol, who (I hope) is collecting all the various possible solutions, e.g. Supervisor (also yours) and PersistentTaskGroup; there's a somewhat recent (post-US-PyCon) writeup in Discourse: https://discuss.python.org/t/asyncio-tasks-and-exception-handling-recommended-idioms/23806/9

@sscherfke

Any updates on this? :-)

@Jonathan-Landeed

Jonathan-Landeed commented Mar 24, 2025

This is by far the cleanest solution I've found:

index 1633478d1c8..2f0afbbcd82 100644
--- a/Lib/asyncio/taskgroups.py
+++ b/Lib/asyncio/taskgroups.py
@@ -26,7 +26,7 @@ class TaskGroup:
     a task will cancel all remaining tasks and wait for them to exit.
     The exceptions are then combined and raised as an `ExceptionGroup`.
     """
-    def __init__(self):
+    def __init__(self, *, defer_errors: bool = False):
         self._entered = False
         self._exiting = False
         self._aborting = False
@@ -37,6 +37,7 @@ def __init__(self):
         self._errors = []
         self._base_error = None
         self._on_completed_fut = None
+        self._defer_errors = defer_errors

     def __repr__(self):
         info = ['']
@@ -259,7 +260,11 @@ def _on_task_done(self, task):
             })
             return

-        if not self._aborting and not self._parent_cancel_requested:
+        if (
+            not self._aborting
+            and not self._parent_cancel_requested
+            and not self._defer_errors
+        ):
             # If parent task *is not* being cancelled, it means that we want
             # to manually cancel it to abort whatever is being run right now
             # in the TaskGroup.  But we want to mark parent task as

It will be possible to implement with aiojobs.Scheduler, but apparently the current version has a bug: #104091 (comment)

I don't think aiotools.Supervisor will solve this. The exception_handler isn't meant to be a reliable callback for a task exception. I had the same problem with aiojobs.Scheduler.

I tried very hard to get something working with aiotools, but the exception_handler doesn't count as retrieving the exceptions, so it gives "Future exception was never retrieved".

import asyncio
from types import TracebackType

import aiotools

MULTIPLE_EXCEPTIONS_MESSAGE = "Multiple exceptions occurred"

class WaitTaskGroup(aiotools.PersistentTaskGroup):
    _exceptions: list[BaseException]

    def __init__(
        self,
        *,
        name: str | None = None,
    ) -> None:
        async def exception_handler(
            exc_type: type[BaseException],  # noqa: ARG001
            exc_obj: BaseException,
            exc_tb: TracebackType,  # noqa: ARG001
        ) -> None:
            self._exceptions.append(exc_obj)

        super().__init__(name=name, exception_handler=exception_handler)

    async def __aenter__(self) -> "WaitTaskGroup":
        self._exceptions = []
        return await super().__aenter__()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_obj: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        ret = await super().__aexit__(exc_type, exc_obj, exc_tb)
        # exception_handler is awaited in parent class
        if self._exceptions:
            for t in self._tasks:
                # this doesn't fix "Future exception was never retrieved"
                try:
                    await t
                except BaseException:
                    pass
            raise BaseExceptionGroup(
                MULTIPLE_EXCEPTIONS_MESSAGE,
                self._exceptions,
            )
        return ret

You can see how annoying it is to deal with the unretrieved exceptions here:

import asyncio
from types import TracebackType

import aiotools

MULTIPLE_EXCEPTIONS_MESSAGE = "Multiple exceptions occurred"


async def get_string() -> str:
    return "hello"


async def get_int() -> int:
    return 1


# async def raise_exception() -> typing.Never: # also works
async def raise_exception(s: str, *, fail: bool = True) -> None:
    if fail:
        raise Exception(s)


async def slow_print(s: str) -> None:
    await asyncio.sleep(1)
    print(s)


async def test_persistent_task_group(
    *,
    fail: bool = True,
) -> tuple[str, int, None, None]:
    exceptions = []

    async def exception_handler(
        exc_type: type[BaseException],  # noqa: ARG001
        exc_obj: BaseException,
        exc_tb: TracebackType,  # noqa: ARG001
    ) -> None:
        exceptions.append(exc_obj)

    async with aiotools.PersistentTaskGroup(exception_handler=exception_handler) as tg:
        t1 = tg.create_task(get_string())
        t2 = tg.create_task(get_int())
        t3 = tg.create_task(raise_exception("PersistentTaskGroup", fail=fail))
        t4 = tg.create_task(slow_print("PersistentTaskGroup"))
        t5 = tg.create_task(raise_exception("PersistentTaskGroup2", fail=fail))

    if exceptions:
        for t in (t1, t2, t3, t4, t5):
            # Tasks need to be awaited before ExceptionGroup is raised
            # see: https://github.com/aio-libs/aiojobs/blob/master/aiojobs/_scheduler.py#L241
            try:
                await t
            except BaseException:
                pass

        raise BaseExceptionGroup(
            MULTIPLE_EXCEPTIONS_MESSAGE,
            exceptions,
        )

    return t1.result(), t2.result(), t3.result(), t4.result()

I tried to raise the exceptions from asyncio.gather(*task_tuple, return_exceptions=True) as a BaseExceptionGroup, and to preserve the type information with a TypeVarTuple, but the tuple construction loses the type info and we don't have type mapping:

from typing import cast

# PEP 695 syntax; [*Ts] declares the TypeVarTuple inline
def cast_not_exception[*Ts](*args: tuple[*Ts | BaseException]) -> tuple[*Ts]:
    return cast("tuple[*Ts]", args)

Overriding _abort is suggested here, but it also affects cancellations coming from the parent, not just those triggered by sibling failures.

import asyncio

class UnabortableTaskGroup(asyncio.TaskGroup):
    def _abort(self):
        # swallow abort requests, so failures never cancel the other tasks
        pass


async def raise_exception() -> None:
    await asyncio.sleep(0.1)
    raise Exception("cancel")


async def slow_print(s: str) -> None:
    await asyncio.sleep(1)
    print(s)


async def test_my_task_group() -> None:
    # a plain asyncio.TaskGroup, for comparison
    async with asyncio.TaskGroup() as tg:
        tg.create_task(slow_print("MyTaskGroup"))


async def test_unabortable_task_group() -> None:
    async with UnabortableTaskGroup() as tg:
        tg.create_task(slow_print("UnabortableTaskGroup"))


async def test_cancel() -> None:
    async with asyncio.TaskGroup() as tg:
        # UnabortableTaskGroup is the only one that prints
        tg.create_task(test_my_task_group())
        tg.create_task(test_unabortable_task_group())
        tg.create_task(raise_exception())
