diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/asyncio/taskgroups.py | 19 | ||||
-rw-r--r-- | Lib/asyncio/tasks.py | 2 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_taskgroups.py | 66 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_tasks.py | 24 |
4 files changed, 105 insertions, 6 deletions
diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index 57f0123..f2ee964 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -77,12 +77,6 @@ class TaskGroup: propagate_cancellation_error = exc else: propagate_cancellation_error = None - if self._parent_cancel_requested: - # If this flag is set we *must* call uncancel(). - if self._parent_task.uncancel() == 0: - # If there are no pending cancellations left, - # don't propagate CancelledError. - propagate_cancellation_error = None if et is not None: if not self._aborting: @@ -130,6 +124,13 @@ class TaskGroup: if self._base_error is not None: raise self._base_error + if self._parent_cancel_requested: + # If this flag is set we *must* call uncancel(). + if self._parent_task.uncancel() == 0: + # If there are no pending cancellations left, + # don't propagate CancelledError. + propagate_cancellation_error = None + # Propagate CancelledError if there is one, except if there # are other errors -- those have priority. if propagate_cancellation_error is not None and not self._errors: @@ -139,6 +140,12 @@ class TaskGroup: self._errors.append(exc) if self._errors: + # If the parent task is being cancelled from the outside + # of the taskgroup, un-cancel and re-cancel the parent task, + # which will keep the cancel count stable. + if self._parent_task.cancelling(): + self._parent_task.uncancel() + self._parent_task.cancel() # Exceptions are heavy objects that can have object # cycles (bad for GC); let's not keep a reference to # a bunch of them. diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index 7fb697b9..dadcb5b 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -255,6 +255,8 @@ class Task(futures._PyFuture): # Inherit Python Task implementation """ if self._num_cancels_requested > 0: self._num_cancels_requested -= 1 + if self._num_cancels_requested == 0: + self._must_cancel = False return self._num_cancels_requested def __eager_start(self): diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py index 1ec8116..4852536 100644 --- a/Lib/test/test_asyncio/test_taskgroups.py +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -833,6 +833,72 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): loop = asyncio.get_event_loop() loop.run_until_complete(run_coro_after_tg_closes()) + async def test_cancelling_level_preserved(self): + async def raise_after(t, e): + await asyncio.sleep(t) + raise e() + + try: + async with asyncio.TaskGroup() as tg: + tg.create_task(raise_after(0.0, RuntimeError)) + except* RuntimeError: + pass + self.assertEqual(asyncio.current_task().cancelling(), 0) + + async def test_nested_groups_both_cancelled(self): + async def raise_after(t, e): + await asyncio.sleep(t) + raise e() + + try: + async with asyncio.TaskGroup() as outer_tg: + try: + async with asyncio.TaskGroup() as inner_tg: + inner_tg.create_task(raise_after(0, RuntimeError)) + outer_tg.create_task(raise_after(0, ValueError)) + except* RuntimeError: + pass + else: + self.fail("RuntimeError not raised") + self.assertEqual(asyncio.current_task().cancelling(), 1) + except* ValueError: + pass + else: + self.fail("ValueError not raised") + self.assertEqual(asyncio.current_task().cancelling(), 0) + + async def test_error_and_cancel(self): + event = asyncio.Event() + + async def raise_error(): + event.set() + await asyncio.sleep(0) + raise RuntimeError() + + async def inner(): + try: + async with taskgroups.TaskGroup() as tg: + tg.create_task(raise_error()) + await asyncio.sleep(1) + self.fail("Sleep in group should have been cancelled") + except* RuntimeError: + self.assertEqual(asyncio.current_task().cancelling(), 1) + self.assertEqual(asyncio.current_task().cancelling(), 1) + await asyncio.sleep(1) + self.fail("Sleep after group should have been cancelled") + + async def outer(): + t = asyncio.create_task(inner()) + await event.wait() + self.assertEqual(t.cancelling(), 0) + t.cancel() + self.assertEqual(t.cancelling(), 1) + with self.assertRaises(asyncio.CancelledError): + await t + self.assertTrue(t.cancelled()) + + await outer() + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py index bc6d88e..5b09c81 100644 --- a/Lib/test/test_asyncio/test_tasks.py +++ b/Lib/test/test_asyncio/test_tasks.py @@ -684,6 +684,30 @@ class BaseTaskTests: finally: loop.close() + def test_uncancel_resets_must_cancel(self): + + async def coro(): + await fut + return 42 + + loop = asyncio.new_event_loop() + fut = asyncio.Future(loop=loop) + task = self.new_task(loop, coro()) + loop.run_until_complete(asyncio.sleep(0)) # Get task waiting for fut + fut.set_result(None) # Make task runnable + try: + task.cancel() # Enter cancelled state + self.assertEqual(task.cancelling(), 1) + self.assertTrue(task._must_cancel) + + task.uncancel() # Undo cancellation + self.assertEqual(task.cancelling(), 0) + self.assertFalse(task._must_cancel) + finally: + res = loop.run_until_complete(task) + self.assertEqual(res, 42) + loop.close() + def test_cancel(self): def gen(): |