From edbee56d698ebb4489aa68311f44d104a23f5eb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tin=20Tvrtkovi=C4=87?= Date: Sat, 26 Feb 2022 17:18:48 +0100 Subject: Taskgroup tweaks (GH-31559) Now uses .cancel()/.uncancel(), for even fewer broken edge cases. --- Lib/asyncio/taskgroups.py | 50 +++++++++++++++----------------- Lib/test/test_asyncio/test_taskgroups.py | 26 ++++++++++++----- 2 files changed, 42 insertions(+), 34 deletions(-) diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index 756fc55..c3ce94a 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -66,31 +66,28 @@ class TaskGroup: self._base_error is None): self._base_error = exc - if et is exceptions.CancelledError: - if self._parent_cancel_requested: - # Only if we did request task to cancel ourselves - # we mark it as no longer cancelled. - self._parent_task.uncancel() - else: - propagate_cancellation_error = et - - if et is not None and not self._aborting: - # Our parent task is being cancelled: - # - # async with TaskGroup() as g: - # g.create_task(...) - # await ... # <- CancelledError - # + if et is not None: if et is exceptions.CancelledError: - propagate_cancellation_error = et - - # or there's an exception in "async with": - # - # async with TaskGroup() as g: - # g.create_task(...) - # 1 / 0 - # - self._abort() + if self._parent_cancel_requested and not self._parent_task.uncancel(): + # Do nothing, i.e. swallow the error. + pass + else: + propagate_cancellation_error = exc + + if not self._aborting: + # Our parent task is being cancelled: + # + # async with TaskGroup() as g: + # g.create_task(...) + # await ... # <- CancelledError + # + # or there's an exception in "async with": + # + # async with TaskGroup() as g: + # g.create_task(...) + # 1 / 0 + # + self._abort() # We use while-loop here because "self._on_completed_fut" # can be cancelled multiple times if our parent task @@ -118,7 +115,6 @@ class TaskGroup: self._on_completed_fut = None assert self._unfinished_tasks == 0 - self._on_completed_fut = None # no longer needed if self._base_error is not None: raise self._base_error @@ -199,8 +195,7 @@ class TaskGroup: }) return - self._abort() - if not self._parent_task.cancelling(): + if not self._aborting and not self._parent_cancel_requested: # If parent task *is not* being cancelled, it means that we want # to manually cancel it to abort whatever is being run right now # in the TaskGroup. But we want to mark parent task as @@ -219,5 +214,6 @@ class TaskGroup: # pass # await something_else # this line has to be called # # after TaskGroup is finished. + self._abort() self._parent_cancel_requested = True self._parent_task.cancel() diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py index 40774a8..df51528 100644 --- a/Lib/test/test_asyncio/test_taskgroups.py +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -120,7 +120,11 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): self.assertTrue(t2_cancel) self.assertTrue(t2.cancelled()) - async def test_taskgroup_05(self): + async def test_cancel_children_on_child_error(self): + """ + When a child task raises an error, the rest of the children + are cancelled and the errors are gathered into an EG. + """ NUM = 0 t2_cancel = False @@ -165,7 +169,7 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): self.assertTrue(t2_cancel) self.assertTrue(runner_cancel) - async def test_taskgroup_06(self): + async def test_cancellation(self): NUM = 0 @@ -186,10 +190,12 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): await asyncio.sleep(0.1) self.assertFalse(r.done()) - r.cancel() - with self.assertRaises(asyncio.CancelledError): + r.cancel("test") + with self.assertRaises(asyncio.CancelledError) as cm: await r + self.assertEqual(cm.exception.args, ('test',)) + self.assertEqual(NUM, 5) async def test_taskgroup_07(self): @@ -226,7 +232,7 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): self.assertEqual(NUM, 15) - async def test_taskgroup_08(self): + async def test_cancellation_in_body(self): async def foo(): await asyncio.sleep(0.1) @@ -246,10 +252,12 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): await asyncio.sleep(0.1) self.assertFalse(r.done()) - r.cancel() - with self.assertRaises(asyncio.CancelledError): + r.cancel("test") + with self.assertRaises(asyncio.CancelledError) as cm: await r + self.assertEqual(cm.exception.args, ('test',)) + async def test_taskgroup_09(self): t1 = t2 = None @@ -699,3 +707,7 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): async with taskgroups.TaskGroup() as g: t = g.create_task(coro(), name="yolo") self.assertEqual(t.get_name(), "yolo") + + +if __name__ == "__main__": + unittest.main() -- cgit v0.12