summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/asyncio/taskgroups.py19
-rw-r--r--Lib/asyncio/tasks.py2
-rw-r--r--Lib/test/test_asyncio/test_taskgroups.py66
-rw-r--r--Lib/test/test_asyncio/test_tasks.py24
4 files changed, 105 insertions, 6 deletions
diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py
index 57f0123..f2ee964 100644
--- a/Lib/asyncio/taskgroups.py
+++ b/Lib/asyncio/taskgroups.py
@@ -77,12 +77,6 @@ class TaskGroup:
propagate_cancellation_error = exc
else:
propagate_cancellation_error = None
- if self._parent_cancel_requested:
- # If this flag is set we *must* call uncancel().
- if self._parent_task.uncancel() == 0:
- # If there are no pending cancellations left,
- # don't propagate CancelledError.
- propagate_cancellation_error = None
if et is not None:
if not self._aborting:
@@ -130,6 +124,13 @@ class TaskGroup:
if self._base_error is not None:
raise self._base_error
+ if self._parent_cancel_requested:
+ # If this flag is set we *must* call uncancel().
+ if self._parent_task.uncancel() == 0:
+ # If there are no pending cancellations left,
+ # don't propagate CancelledError.
+ propagate_cancellation_error = None
+
# Propagate CancelledError if there is one, except if there
# are other errors -- those have priority.
if propagate_cancellation_error is not None and not self._errors:
@@ -139,6 +140,12 @@ class TaskGroup:
self._errors.append(exc)
if self._errors:
+ # If the parent task is being cancelled from the outside
+ # of the taskgroup, un-cancel and re-cancel the parent task,
+ # which will keep the cancel count stable.
+ if self._parent_task.cancelling():
+ self._parent_task.uncancel()
+ self._parent_task.cancel()
# Exceptions are heavy objects that can have object
# cycles (bad for GC); let's not keep a reference to
# a bunch of them.
diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py
index 7fb697b9..dadcb5b 100644
--- a/Lib/asyncio/tasks.py
+++ b/Lib/asyncio/tasks.py
@@ -255,6 +255,8 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
"""
if self._num_cancels_requested > 0:
self._num_cancels_requested -= 1
+ if self._num_cancels_requested == 0:
+ self._must_cancel = False
return self._num_cancels_requested
def __eager_start(self):
diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py
index 1ec8116..4852536 100644
--- a/Lib/test/test_asyncio/test_taskgroups.py
+++ b/Lib/test/test_asyncio/test_taskgroups.py
@@ -833,6 +833,72 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
loop = asyncio.get_event_loop()
loop.run_until_complete(run_coro_after_tg_closes())
+ async def test_cancelling_level_preserved(self):
+ async def raise_after(t, e):
+ await asyncio.sleep(t)
+ raise e()
+
+ try:
+ async with asyncio.TaskGroup() as tg:
+ tg.create_task(raise_after(0.0, RuntimeError))
+ except* RuntimeError:
+ pass
+ self.assertEqual(asyncio.current_task().cancelling(), 0)
+
+ async def test_nested_groups_both_cancelled(self):
+ async def raise_after(t, e):
+ await asyncio.sleep(t)
+ raise e()
+
+ try:
+ async with asyncio.TaskGroup() as outer_tg:
+ try:
+ async with asyncio.TaskGroup() as inner_tg:
+ inner_tg.create_task(raise_after(0, RuntimeError))
+ outer_tg.create_task(raise_after(0, ValueError))
+ except* RuntimeError:
+ pass
+ else:
+ self.fail("RuntimeError not raised")
+ self.assertEqual(asyncio.current_task().cancelling(), 1)
+ except* ValueError:
+ pass
+ else:
+ self.fail("ValueError not raised")
+ self.assertEqual(asyncio.current_task().cancelling(), 0)
+
+ async def test_error_and_cancel(self):
+ event = asyncio.Event()
+
+ async def raise_error():
+ event.set()
+ await asyncio.sleep(0)
+ raise RuntimeError()
+
+ async def inner():
+ try:
+ async with taskgroups.TaskGroup() as tg:
+ tg.create_task(raise_error())
+ await asyncio.sleep(1)
+ self.fail("Sleep in group should have been cancelled")
+ except* RuntimeError:
+ self.assertEqual(asyncio.current_task().cancelling(), 1)
+ self.assertEqual(asyncio.current_task().cancelling(), 1)
+ await asyncio.sleep(1)
+ self.fail("Sleep after group should have been cancelled")
+
+ async def outer():
+ t = asyncio.create_task(inner())
+ await event.wait()
+ self.assertEqual(t.cancelling(), 0)
+ t.cancel()
+ self.assertEqual(t.cancelling(), 1)
+ with self.assertRaises(asyncio.CancelledError):
+ await t
+ self.assertTrue(t.cancelled())
+
+ await outer()
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
index bc6d88e..5b09c81 100644
--- a/Lib/test/test_asyncio/test_tasks.py
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -684,6 +684,30 @@ class BaseTaskTests:
finally:
loop.close()
+ def test_uncancel_resets_must_cancel(self):
+
+ async def coro():
+ await fut
+ return 42
+
+ loop = asyncio.new_event_loop()
+ fut = asyncio.Future(loop=loop)
+ task = self.new_task(loop, coro())
+ loop.run_until_complete(asyncio.sleep(0)) # Get task waiting for fut
+ fut.set_result(None) # Make task runnable
+ try:
+ task.cancel() # Enter cancelled state
+ self.assertEqual(task.cancelling(), 1)
+ self.assertTrue(task._must_cancel)
+
+ task.uncancel() # Undo cancellation
+ self.assertEqual(task.cancelling(), 0)
+ self.assertFalse(task._must_cancel)
+ finally:
+ res = loop.run_until_complete(task)
+ self.assertEqual(res, 42)
+ loop.close()
+
def test_cancel(self):
def gen():