From 2df96d0247cbf512c24e4d463369398e602e689e Mon Sep 17 00:00:00 2001 From: DontPanicO Date: Tue, 7 Feb 2023 17:13:39 +0000 Subject: [PATCH 1/4] add `defer_errors` flag to `asyncio.TaskGroup` `asyncio.TaskGroup` cancels all child tasks when one raises. Setting `defer_errors=True` will prevent this --- Lib/asyncio/taskgroups.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index 911419e1769c17..b00ba87ad988e1 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -11,7 +11,7 @@ class TaskGroup: - def __init__(self): + def __init__(self, defer_errors=False): self._entered = False self._exiting = False self._aborting = False @@ -22,6 +22,7 @@ def __init__(self): self._errors = [] self._base_error = None self._on_completed_fut = None + self._defer_errors = defer_errors def __repr__(self): info = [''] @@ -180,7 +181,8 @@ def _on_task_done(self, task): return self._errors.append(exc) - if self._is_base_error(exc) and self._base_error is None: + is_base_error = self._is_base_error(exc) + if is_base_error and self._base_error is None: self._base_error = exc if self._parent_task.done(): @@ -213,6 +215,7 @@ def _on_task_done(self, task): # pass # await something_else # this line has to be called # # after TaskGroup is finished. - self._abort() - self._parent_cancel_requested = True - self._parent_task.cancel() + if not self._defer_errors or is_base_error: + self._abort() + self._parent_cancel_requested = True + self._parent_task.cancel() From 90e58aacadd68f102d42a07246f6968aa10723dc Mon Sep 17 00:00:00 2001 From: DontPanicO Date: Tue, 7 Feb 2023 17:14:38 +0000 Subject: [PATCH 2/4] tests for `asyncio.TaskGroup` with `defer_errors=True` --- Lib/test/test_asyncio/test_taskgroups.py | 434 +++++++++++++++++++++++ 1 file changed, 434 insertions(+) diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py index 6a0231f2859a62..55357a803c7d3f 100644 --- a/Lib/test/test_asyncio/test_taskgroups.py +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -779,6 +779,440 @@ async def main(): await asyncio.create_task(main()) + async def test_children_complete_on_child_error(self): + async def zero_division(): + 1 / 0 + + async def foo1(): + await asyncio.sleep(0.1) + return 42 + + async def foo2(): + await asyncio.sleep(0.2) + return 11 + + with self.assertRaises(ExceptionGroup) as cm: + async with taskgroups.TaskGroup(defer_errors=True) as g: + t1 = g.create_task(foo1()) + t2 = g.create_task(foo2()) + g.create_task(zero_division()) + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + self.assertEqual(t1.result(), 42) + self.assertEqual(t2.result(), 11) + + async def test_inner_complete_on_child_error(self): + async def zero_division(): + 1 / 0 + + async def foo1(): + await asyncio.sleep(0.1) + return 42 + + async def foo2(): + await asyncio.sleep(0.2) + return 11 + + with self.assertRaises(ExceptionGroup) as cm: + async with taskgroups.TaskGroup(defer_errors=True) as g: + t1 = g.create_task(foo1()) + g.create_task(zero_division()) + r1 = await foo2() + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + self.assertEqual(t1.result(), 42) + self.assertEqual(r1, 11) + + async def test_children_exceptions_propagate(self): + async def zero_division(): + 1 / 0 + + async def value_error(): + await asyncio.sleep(0.2) + raise ValueError + + async def foo1(): + await asyncio.sleep(0.4) + return 42 + + with self.assertRaises(ExceptionGroup) as cm: + async with taskgroups.TaskGroup(defer_errors=True) as g: + g.create_task(zero_division()) + g.create_task(value_error()) + t1 = g.create_task(foo1()) + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError, ValueError}) + + self.assertEqual(t1.result(), 42) + + async def test_children_cancel_on_inner_failure(self): + async def zero_division(): + 1 / 0 + + async def foo1(): + await asyncio.sleep(0.2) + return 42 + + with self.assertRaises(ExceptionGroup) as cm: + async with taskgroups.TaskGroup(defer_errors=True) as g: + t1 = g.create_task(foo1()) + await zero_division() + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + self.assertTrue(t1.cancelled()) + + async def test_cancellation_01(self): + + NUM = 0 + + async def foo(): + nonlocal NUM + try: + await asyncio.sleep(5) + except asyncio.CancelledError: + NUM += 1 + raise + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g: + for _ in range(5): + g.create_task(foo()) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError) as cm: + await r + + self.assertEqual(NUM, 5) + + async def test_taskgroup_35(self): + + NUM = 0 + + async def foo(): + nonlocal NUM + try: + await asyncio.sleep(5) + except asyncio.CancelledError: + NUM += 1 + raise + + async def runner(): + nonlocal NUM + async with taskgroups.TaskGroup(defer_errors=True) as g: + for _ in range(5): + g.create_task(foo()) + + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + NUM += 10 + raise + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError): + await r + + self.assertEqual(NUM, 15) + + async def test_taskgroup_36(self): + + async def foo(): + try: + await asyncio.sleep(10) + finally: + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g: + for _ in range(5): + g.create_task(foo()) + + await asyncio.sleep(10) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + async def test_taskgroup_37(self): + + async def foo(): + try: + await asyncio.sleep(10) + finally: + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True): + async with taskgroups.TaskGroup(defer_errors=True) as g2: + for _ in range(5): + g2.create_task(foo()) + + await asyncio.sleep(10) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + + self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) + self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError}) + + async def test_taskgroup_37a(self): + + async def foo(): + try: + await asyncio.sleep(10) + finally: + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup(): + async with taskgroups.TaskGroup(defer_errors=True) as g2: + for _ in range(5): + g2.create_task(foo()) + + await asyncio.sleep(10) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + + self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) + self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError}) + + async def test_taskgroup_38(self): + + async def foo(): + try: + await asyncio.sleep(10) + finally: + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g1: + g1.create_task(asyncio.sleep(10)) + + async with taskgroups.TaskGroup(defer_errors=True) as g2: + for _ in range(5): + g2.create_task(foo()) + + await asyncio.sleep(10) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + + self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) + self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError}) + + async def test_taskgroup_39(self): + + async def crash_soon(): + await asyncio.sleep(0.3) + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g1: + g1.create_task(crash_soon()) + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + await asyncio.sleep(0.5) + raise + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + async def test_taskgroup_40(self): + + async def crash_soon(): + await asyncio.sleep(0.3) + 1 / 0 + + async def nested_runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g1: + g1.create_task(crash_soon()) + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + await asyncio.sleep(0.5) + raise + + async def runner(): + t = asyncio.create_task(nested_runner()) + await t + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + async def test_taskgroup_41(self): + + NUM = 0 + + async def runner(): + nonlocal NUM + async with taskgroups.TaskGroup(defer_errors=True): + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + NUM += 10 + raise + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError): + await r + + self.assertEqual(NUM, 10) + + async def test_taskgroup_42(self): + + NUM = 0 + + async def runner(): + nonlocal NUM + async with taskgroups.TaskGroup(defer_errors=True): + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + NUM += 10 + # This isn't a good idea, but we have to support + # this weird case. + raise MyExc + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + + try: + await r + except ExceptionGroup as t: + self.assertEqual(get_error_types(t),{MyExc}) + else: + self.fail('ExceptionGroup was not raised') + + self.assertEqual(NUM, 10) + + async def test_taskgroup_43(self): + + async def crash_soon(): + await asyncio.sleep(0.1) + 1 / 0 + + async def nested(): + try: + await asyncio.sleep(10) + finally: + raise MyExc + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g: + g.create_task(crash_soon()) + await nested() + + r = asyncio.create_task(runner()) + try: + await r + except ExceptionGroup as t: + self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError}) + else: + self.fail('TasgGroupError was not raised') + + async def test_taskgroup_44(self): + + async def foo1(): + await asyncio.sleep(1) + return 42 + + async def foo2(): + await asyncio.sleep(2) + return 11 + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g: + g.create_task(foo1()) + g.create_task(foo2()) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.05) + r.cancel() + + with self.assertRaises(asyncio.CancelledError): + await r + + async def test_taskgroup_45(self): + + NUM = 0 + + async def foo1(): + nonlocal NUM + await asyncio.sleep(0.2) + NUM += 1 + + async def foo2(): + nonlocal NUM + await asyncio.sleep(0.3) + NUM += 2 + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g: + g.create_task(foo1()) + g.create_task(foo2()) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.05) + r.cancel() + + with self.assertRaises(asyncio.CancelledError): + await r + + self.assertEqual(NUM, 0) + + if __name__ == "__main__": unittest.main() From 0b08887d175156a8d2c23859c5b51da5621f259b Mon Sep 17 00:00:00 2001 From: DontPanicO Date: Tue, 7 Feb 2023 18:10:20 +0000 Subject: [PATCH 3/4] Add `NEWS.d` entry --- .../next/Library/2023-02-07-18-08-32.gh-issue-101581.lxw8WY.rst | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 Misc/NEWS.d/next/Library/2023-02-07-18-08-32.gh-issue-101581.lxw8WY.rst diff --git a/Misc/NEWS.d/next/Library/2023-02-07-18-08-32.gh-issue-101581.lxw8WY.rst b/Misc/NEWS.d/next/Library/2023-02-07-18-08-32.gh-issue-101581.lxw8WY.rst new file mode 100644 index 00000000000000..e179ea9f680129 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2023-02-07-18-08-32.gh-issue-101581.lxw8WY.rst @@ -0,0 +1,2 @@ +Adds ``defer_errors`` flag to :class:`asyncio.TaskGroup` to optionally +prevent child tasks from being cancelled. From 481d37269efc91943329d9a6ec9e2466d799f7d4 Mon Sep 17 00:00:00 2001 From: DontPanicO Date: Fri, 17 Mar 2023 11:18:52 +0100 Subject: [PATCH 4/4] Update `asyncio.TaskGroup` docstring Add explanation about `defer_errors=True` bheaviour in `asyncio.TaskGroup` docstring --- Lib/asyncio/taskgroups.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index 4a93e8cc54a63f..4d20598388b151 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -23,6 +23,8 @@ class TaskGroup: Any exceptions other than `asyncio.CancelledError` raised within a task will cancel all remaining tasks and wait for them to exit. + You can prevent this behavior by passing `defer_errors=True` to + the constructor. The exceptions are then combined and raised as an `ExceptionGroup`. """ def __init__(self, defer_errors=False):