summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/asyncio/__init__.py1
-rw-r--r--Lib/asyncio/base_tasks.py2
-rw-r--r--Lib/asyncio/taskgroups.py235
-rw-r--r--Lib/asyncio/tasks.py16
-rw-r--r--Lib/test/test_asyncio/test_taskgroups.py694
-rw-r--r--Lib/test/test_asyncio/test_tasks.py45
-rw-r--r--Misc/NEWS.d/next/Library/2022-02-14-21-21-49.bpo-46752.m6ldTm.rst2
-rw-r--r--Modules/_asynciomodule.c59
-rw-r--r--Modules/clinic/_asynciomodule.c.h49
9 files changed, 1100 insertions, 3 deletions
diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py
index 200b14c..db1124c 100644
--- a/Lib/asyncio/__init__.py
+++ b/Lib/asyncio/__init__.py
@@ -17,6 +17,7 @@ from .queues import *
from .streams import *
from .subprocess import *
from .tasks import *
+from .taskgroups import *
from .threads import *
from .transports import *
diff --git a/Lib/asyncio/base_tasks.py b/Lib/asyncio/base_tasks.py
index 09bb171..1d62389 100644
--- a/Lib/asyncio/base_tasks.py
+++ b/Lib/asyncio/base_tasks.py
@@ -8,7 +8,7 @@ from . import coroutines
def _task_repr_info(task):
info = base_futures._future_repr_info(task)
- if task._must_cancel:
+ if task.cancelling() and not task.done():
# replace status
info[0] = 'cancelling'
diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py
new file mode 100644
index 0000000..7182778
--- /dev/null
+++ b/Lib/asyncio/taskgroups.py
@@ -0,0 +1,235 @@
+# Adapted with permission from the EdgeDB project.
+
+
+__all__ = ["TaskGroup"]
+
+import itertools
+import textwrap
+import traceback
+import types
+import weakref
+
+from . import events
+from . import exceptions
+from . import tasks
+
+class TaskGroup:
+
+ def __init__(self, *, name=None):
+ if name is None:
+ self._name = f'tg-{_name_counter()}'
+ else:
+ self._name = str(name)
+
+ self._entered = False
+ self._exiting = False
+ self._aborting = False
+ self._loop = None
+ self._parent_task = None
+ self._parent_cancel_requested = False
+ self._tasks = weakref.WeakSet()
+ self._unfinished_tasks = 0
+ self._errors = []
+ self._base_error = None
+ self._on_completed_fut = None
+
+ def get_name(self):
+ return self._name
+
+ def __repr__(self):
+ msg = f'<TaskGroup {self._name!r}'
+ if self._tasks:
+ msg += f' tasks:{len(self._tasks)}'
+ if self._unfinished_tasks:
+ msg += f' unfinished:{self._unfinished_tasks}'
+ if self._errors:
+ msg += f' errors:{len(self._errors)}'
+ if self._aborting:
+ msg += ' cancelling'
+ elif self._entered:
+ msg += ' entered'
+ msg += '>'
+ return msg
+
+ async def __aenter__(self):
+ if self._entered:
+ raise RuntimeError(
+ f"TaskGroup {self!r} has been already entered")
+ self._entered = True
+
+ if self._loop is None:
+ self._loop = events.get_running_loop()
+
+ self._parent_task = tasks.current_task(self._loop)
+ if self._parent_task is None:
+ raise RuntimeError(
+ f'TaskGroup {self!r} cannot determine the parent task')
+
+ return self
+
+ async def __aexit__(self, et, exc, tb):
+ self._exiting = True
+ propagate_cancellation_error = None
+
+ if (exc is not None and
+ self._is_base_error(exc) and
+ self._base_error is None):
+ self._base_error = exc
+
+ if et is exceptions.CancelledError:
+ if self._parent_cancel_requested:
+ # Only if we did request task to cancel ourselves
+ # we mark it as no longer cancelled.
+ self._parent_task.uncancel()
+ else:
+ propagate_cancellation_error = et
+
+ if et is not None and not self._aborting:
+ # Our parent task is being cancelled:
+ #
+ # async with TaskGroup() as g:
+ # g.create_task(...)
+ # await ... # <- CancelledError
+ #
+ if et is exceptions.CancelledError:
+ propagate_cancellation_error = et
+
+ # or there's an exception in "async with":
+ #
+ # async with TaskGroup() as g:
+ # g.create_task(...)
+ # 1 / 0
+ #
+ self._abort()
+
+ # We use while-loop here because "self._on_completed_fut"
+ # can be cancelled multiple times if our parent task
+ # is being cancelled repeatedly (or even once, when
+ # our own cancellation is already in progress)
+ while self._unfinished_tasks:
+ if self._on_completed_fut is None:
+ self._on_completed_fut = self._loop.create_future()
+
+ try:
+ await self._on_completed_fut
+ except exceptions.CancelledError as ex:
+ if not self._aborting:
+ # Our parent task is being cancelled:
+ #
+ # async def wrapper():
+ # async with TaskGroup() as g:
+ # g.create_task(foo)
+ #
+ # "wrapper" is being cancelled while "foo" is
+ # still running.
+ propagate_cancellation_error = ex
+ self._abort()
+
+ self._on_completed_fut = None
+
+ assert self._unfinished_tasks == 0
+ self._on_completed_fut = None # no longer needed
+
+ if self._base_error is not None:
+ raise self._base_error
+
+ if propagate_cancellation_error is not None:
+ # The wrapping task was cancelled; since we're done with
+ # closing all child tasks, just propagate the cancellation
+ # request now.
+ raise propagate_cancellation_error
+
+ if et is not None and et is not exceptions.CancelledError:
+ self._errors.append(exc)
+
+ if self._errors:
+ # Exceptions are heavy objects that can have object
+ # cycles (bad for GC); let's not keep a reference to
+ # a bunch of them.
+ errors = self._errors
+ self._errors = None
+
+ me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors)
+ raise me from None
+
+ def create_task(self, coro):
+ if not self._entered:
+ raise RuntimeError(f"TaskGroup {self!r} has not been entered")
+ if self._exiting and self._unfinished_tasks == 0:
+ raise RuntimeError(f"TaskGroup {self!r} is finished")
+ task = self._loop.create_task(coro)
+ task.add_done_callback(self._on_task_done)
+ self._unfinished_tasks += 1
+ self._tasks.add(task)
+ return task
+
+ # Since Python 3.8 Tasks propagate all exceptions correctly,
+ # except for KeyboardInterrupt and SystemExit which are
+ # still considered special.
+
+ def _is_base_error(self, exc: BaseException) -> bool:
+ assert isinstance(exc, BaseException)
+ return isinstance(exc, (SystemExit, KeyboardInterrupt))
+
+ def _abort(self):
+ self._aborting = True
+
+ for t in self._tasks:
+ if not t.done():
+ t.cancel()
+
+ def _on_task_done(self, task):
+ self._unfinished_tasks -= 1
+ assert self._unfinished_tasks >= 0
+
+ if self._on_completed_fut is not None and not self._unfinished_tasks:
+ if not self._on_completed_fut.done():
+ self._on_completed_fut.set_result(True)
+
+ if task.cancelled():
+ return
+
+ exc = task.exception()
+ if exc is None:
+ return
+
+ self._errors.append(exc)
+ if self._is_base_error(exc) and self._base_error is None:
+ self._base_error = exc
+
+ if self._parent_task.done():
+ # Not sure if this case is possible, but we want to handle
+ # it anyways.
+ self._loop.call_exception_handler({
+ 'message': f'Task {task!r} has errored out but its parent '
+ f'task {self._parent_task} is already completed',
+ 'exception': exc,
+ 'task': task,
+ })
+ return
+
+ self._abort()
+ if not self._parent_task.cancelling():
+ # If parent task *is not* being cancelled, it means that we want
+ # to manually cancel it to abort whatever is being run right now
+ # in the TaskGroup. But we want to mark parent task as
+ # "not cancelled" later in __aexit__. Example situation that
+ # we need to handle:
+ #
+ # async def foo():
+ # try:
+ # async with TaskGroup() as g:
+ # g.create_task(crash_soon())
+ # await something # <- this needs to be canceled
+ # # by the TaskGroup, e.g.
+ # # foo() needs to be cancelled
+ # except Exception:
+ # # Ignore any exceptions raised in the TaskGroup
+ # pass
+ # await something_else # this line has to be called
+ # # after TaskGroup is finished.
+ self._parent_cancel_requested = True
+ self._parent_task.cancel()
+
+
+_name_counter = itertools.count(1).__next__
diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py
index 2bee5c0..c11d0da 100644
--- a/Lib/asyncio/tasks.py
+++ b/Lib/asyncio/tasks.py
@@ -105,6 +105,7 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
else:
self._name = str(name)
+ self._cancel_requested = False
self._must_cancel = False
self._fut_waiter = None
self._coro = coro
@@ -201,6 +202,9 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
self._log_traceback = False
if self.done():
return False
+ if self._cancel_requested:
+ return False
+ self._cancel_requested = True
if self._fut_waiter is not None:
if self._fut_waiter.cancel(msg=msg):
# Leave self._fut_waiter; it may be a Task that
@@ -212,6 +216,16 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
self._cancel_message = msg
return True
+ def cancelling(self):
+ return self._cancel_requested
+
+ def uncancel(self):
+ if self._cancel_requested:
+ self._cancel_requested = False
+ return True
+ else:
+ return False
+
def __step(self, exc=None):
if self.done():
raise exceptions.InvalidStateError(
@@ -634,7 +648,7 @@ def _ensure_future(coro_or_future, *, loop=None):
loop = events._get_event_loop(stacklevel=4)
try:
return loop.create_task(coro_or_future)
- except RuntimeError:
+ except RuntimeError:
if not called_wrap_awaitable:
coro_or_future.close()
raise
diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py
new file mode 100644
index 0000000..ea6ee2e
--- /dev/null
+++ b/Lib/test/test_asyncio/test_taskgroups.py
@@ -0,0 +1,694 @@
+# Adapted with permission from the EdgeDB project.
+
+
+import asyncio
+
+from asyncio import taskgroups
+import unittest
+
+
+# To prevent a warning "test altered the execution environment"
+def tearDownModule():
+ asyncio.set_event_loop_policy(None)
+
+
+class MyExc(Exception):
+ pass
+
+
+class MyBaseExc(BaseException):
+ pass
+
+
+def get_error_types(eg):
+ return {type(exc) for exc in eg.exceptions}
+
+
+class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
+
+ async def test_taskgroup_01(self):
+
+ async def foo1():
+ await asyncio.sleep(0.1)
+ return 42
+
+ async def foo2():
+ await asyncio.sleep(0.2)
+ return 11
+
+ async with taskgroups.TaskGroup() as g:
+ t1 = g.create_task(foo1())
+ t2 = g.create_task(foo2())
+
+ self.assertEqual(t1.result(), 42)
+ self.assertEqual(t2.result(), 11)
+
+ async def test_taskgroup_02(self):
+
+ async def foo1():
+ await asyncio.sleep(0.1)
+ return 42
+
+ async def foo2():
+ await asyncio.sleep(0.2)
+ return 11
+
+ async with taskgroups.TaskGroup() as g:
+ t1 = g.create_task(foo1())
+ await asyncio.sleep(0.15)
+ t2 = g.create_task(foo2())
+
+ self.assertEqual(t1.result(), 42)
+ self.assertEqual(t2.result(), 11)
+
+ async def test_taskgroup_03(self):
+
+ async def foo1():
+ await asyncio.sleep(1)
+ return 42
+
+ async def foo2():
+ await asyncio.sleep(0.2)
+ return 11
+
+ async with taskgroups.TaskGroup() as g:
+ t1 = g.create_task(foo1())
+ await asyncio.sleep(0.15)
+ # cancel t1 explicitly, i.e. everything should continue
+ # working as expected.
+ t1.cancel()
+
+ t2 = g.create_task(foo2())
+
+ self.assertTrue(t1.cancelled())
+ self.assertEqual(t2.result(), 11)
+
+ async def test_taskgroup_04(self):
+
+ NUM = 0
+ t2_cancel = False
+ t2 = None
+
+ async def foo1():
+ await asyncio.sleep(0.1)
+ 1 / 0
+
+ async def foo2():
+ nonlocal NUM, t2_cancel
+ try:
+ await asyncio.sleep(1)
+ except asyncio.CancelledError:
+ t2_cancel = True
+ raise
+ NUM += 1
+
+ async def runner():
+ nonlocal NUM, t2
+
+ async with taskgroups.TaskGroup() as g:
+ g.create_task(foo1())
+ t2 = g.create_task(foo2())
+
+ NUM += 10
+
+ with self.assertRaises(ExceptionGroup) as cm:
+ await asyncio.create_task(runner())
+
+ self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
+
+ self.assertEqual(NUM, 0)
+ self.assertTrue(t2_cancel)
+ self.assertTrue(t2.cancelled())
+
+ async def test_taskgroup_05(self):
+
+ NUM = 0
+ t2_cancel = False
+ runner_cancel = False
+
+ async def foo1():
+ await asyncio.sleep(0.1)
+ 1 / 0
+
+ async def foo2():
+ nonlocal NUM, t2_cancel
+ try:
+ await asyncio.sleep(5)
+ except asyncio.CancelledError:
+ t2_cancel = True
+ raise
+ NUM += 1
+
+ async def runner():
+ nonlocal NUM, runner_cancel
+
+ async with taskgroups.TaskGroup() as g:
+ g.create_task(foo1())
+ g.create_task(foo1())
+ g.create_task(foo1())
+ g.create_task(foo2())
+ try:
+ await asyncio.sleep(10)
+ except asyncio.CancelledError:
+ runner_cancel = True
+ raise
+
+ NUM += 10
+
+ # The 3 foo1 sub tasks can be racy when the host is busy - if the
+ # cancellation happens in the middle, we'll see partial sub errors here
+ with self.assertRaises(ExceptionGroup) as cm:
+ await asyncio.create_task(runner())
+
+ self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
+ self.assertEqual(NUM, 0)
+ self.assertTrue(t2_cancel)
+ self.assertTrue(runner_cancel)
+
+ async def test_taskgroup_06(self):
+
+ NUM = 0
+
+ async def foo():
+ nonlocal NUM
+ try:
+ await asyncio.sleep(5)
+ except asyncio.CancelledError:
+ NUM += 1
+ raise
+
+ async def runner():
+ async with taskgroups.TaskGroup() as g:
+ for _ in range(5):
+ g.create_task(foo())
+
+ r = asyncio.create_task(runner())
+ await asyncio.sleep(0.1)
+
+ self.assertFalse(r.done())
+ r.cancel()
+ with self.assertRaises(asyncio.CancelledError):
+ await r
+
+ self.assertEqual(NUM, 5)
+
+ async def test_taskgroup_07(self):
+
+ NUM = 0
+
+ async def foo():
+ nonlocal NUM
+ try:
+ await asyncio.sleep(5)
+ except asyncio.CancelledError:
+ NUM += 1
+ raise
+
+ async def runner():
+ nonlocal NUM
+ async with taskgroups.TaskGroup() as g:
+ for _ in range(5):
+ g.create_task(foo())
+
+ try:
+ await asyncio.sleep(10)
+ except asyncio.CancelledError:
+ NUM += 10
+ raise
+
+ r = asyncio.create_task(runner())
+ await asyncio.sleep(0.1)
+
+ self.assertFalse(r.done())
+ r.cancel()
+ with self.assertRaises(asyncio.CancelledError):
+ await r
+
+ self.assertEqual(NUM, 15)
+
+ async def test_taskgroup_08(self):
+
+ async def foo():
+ await asyncio.sleep(0.1)
+ 1 / 0
+
+ async def runner():
+ async with taskgroups.TaskGroup() as g:
+ for _ in range(5):
+ g.create_task(foo())
+
+ try:
+ await asyncio.sleep(10)
+ except asyncio.CancelledError:
+ raise
+
+ r = asyncio.create_task(runner())
+ await asyncio.sleep(0.1)
+
+ self.assertFalse(r.done())
+ r.cancel()
+ with self.assertRaises(asyncio.CancelledError):
+ await r
+
+ async def test_taskgroup_09(self):
+
+ t1 = t2 = None
+
+ async def foo1():
+ await asyncio.sleep(1)
+ return 42
+
+ async def foo2():
+ await asyncio.sleep(2)
+ return 11
+
+ async def runner():
+ nonlocal t1, t2
+ async with taskgroups.TaskGroup() as g:
+ t1 = g.create_task(foo1())
+ t2 = g.create_task(foo2())
+ await asyncio.sleep(0.1)
+ 1 / 0
+
+ try:
+ await runner()
+ except ExceptionGroup as t:
+ self.assertEqual(get_error_types(t), {ZeroDivisionError})
+ else:
+ self.fail('ExceptionGroup was not raised')
+
+ self.assertTrue(t1.cancelled())
+ self.assertTrue(t2.cancelled())
+
+ async def test_taskgroup_10(self):
+
+ t1 = t2 = None
+
+ async def foo1():
+ await asyncio.sleep(1)
+ return 42
+
+ async def foo2():
+ await asyncio.sleep(2)
+ return 11
+
+ async def runner():
+ nonlocal t1, t2
+ async with taskgroups.TaskGroup() as g:
+ t1 = g.create_task(foo1())
+ t2 = g.create_task(foo2())
+ 1 / 0
+
+ try:
+ await runner()
+ except ExceptionGroup as t:
+ self.assertEqual(get_error_types(t), {ZeroDivisionError})
+ else:
+ self.fail('ExceptionGroup was not raised')
+
+ self.assertTrue(t1.cancelled())
+ self.assertTrue(t2.cancelled())
+
+ async def test_taskgroup_11(self):
+
+ async def foo():
+ await asyncio.sleep(0.1)
+ 1 / 0
+
+ async def runner():
+ async with taskgroups.TaskGroup():
+ async with taskgroups.TaskGroup() as g2:
+ for _ in range(5):
+ g2.create_task(foo())
+
+ try:
+ await asyncio.sleep(10)
+ except asyncio.CancelledError:
+ raise
+
+ r = asyncio.create_task(runner())
+ await asyncio.sleep(0.1)
+
+ self.assertFalse(r.done())
+ r.cancel()
+ with self.assertRaises(asyncio.CancelledError):
+ await r
+
+ async def test_taskgroup_12(self):
+
+ async def foo():
+ await asyncio.sleep(0.1)
+ 1 / 0
+
+ async def runner():
+ async with taskgroups.TaskGroup() as g1:
+ g1.create_task(asyncio.sleep(10))
+
+ async with taskgroups.TaskGroup() as g2:
+ for _ in range(5):
+ g2.create_task(foo())
+
+ try:
+ await asyncio.sleep(10)
+ except asyncio.CancelledError:
+ raise
+
+ r = asyncio.create_task(runner())
+ await asyncio.sleep(0.1)
+
+ self.assertFalse(r.done())
+ r.cancel()
+ with self.assertRaises(asyncio.CancelledError):
+ await r
+
+ async def test_taskgroup_13(self):
+
+ async def crash_after(t):
+ await asyncio.sleep(t)
+ raise ValueError(t)
+
+ async def runner():
+ async with taskgroups.TaskGroup(name='g1') as g1:
+ g1.create_task(crash_after(0.1))
+
+ async with taskgroups.TaskGroup(name='g2') as g2:
+ g2.create_task(crash_after(0.2))
+
+ r = asyncio.create_task(runner())
+ with self.assertRaises(ExceptionGroup) as cm:
+ await r
+
+ self.assertEqual(get_error_types(cm.exception), {ValueError})
+
+ async def test_taskgroup_14(self):
+
+ async def crash_after(t):
+ await asyncio.sleep(t)
+ raise ValueError(t)
+
+ async def runner():
+ async with taskgroups.TaskGroup(name='g1') as g1:
+ g1.create_task(crash_after(10))
+
+ async with taskgroups.TaskGroup(name='g2') as g2:
+ g2.create_task(crash_after(0.1))
+
+ r = asyncio.create_task(runner())
+ with self.assertRaises(ExceptionGroup) as cm:
+ await r
+
+ self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
+ self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError})
+
+ async def test_taskgroup_15(self):
+
+ async def crash_soon():
+ await asyncio.sleep(0.3)
+ 1 / 0
+
+ async def runner():
+ async with taskgroups.TaskGroup(name='g1') as g1:
+ g1.create_task(crash_soon())
+ try:
+ await asyncio.sleep(10)
+ except asyncio.CancelledError:
+ await asyncio.sleep(0.5)
+ raise
+
+ r = asyncio.create_task(runner())
+ await asyncio.sleep(0.1)
+
+ self.assertFalse(r.done())
+ r.cancel()
+ with self.assertRaises(asyncio.CancelledError):
+ await r
+
+ async def test_taskgroup_16(self):
+
+ async def crash_soon():
+ await asyncio.sleep(0.3)
+ 1 / 0
+
+ async def nested_runner():
+ async with taskgroups.TaskGroup(name='g1') as g1:
+ g1.create_task(crash_soon())
+ try:
+ await asyncio.sleep(10)
+ except asyncio.CancelledError:
+ await asyncio.sleep(0.5)
+ raise
+
+ async def runner():
+ t = asyncio.create_task(nested_runner())
+ await t
+
+ r = asyncio.create_task(runner())
+ await asyncio.sleep(0.1)
+
+ self.assertFalse(r.done())
+ r.cancel()
+ with self.assertRaises(asyncio.CancelledError):
+ await r
+
+ async def test_taskgroup_17(self):
+ NUM = 0
+
+ async def runner():
+ nonlocal NUM
+ async with taskgroups.TaskGroup():
+ try:
+ await asyncio.sleep(10)
+ except asyncio.CancelledError:
+ NUM += 10
+ raise
+
+ r = asyncio.create_task(runner())
+ await asyncio.sleep(0.1)
+
+ self.assertFalse(r.done())
+ r.cancel()
+ with self.assertRaises(asyncio.CancelledError):
+ await r
+
+ self.assertEqual(NUM, 10)
+
+ async def test_taskgroup_18(self):
+ NUM = 0
+
+ async def runner():
+ nonlocal NUM
+ async with taskgroups.TaskGroup():
+ try:
+ await asyncio.sleep(10)
+ except asyncio.CancelledError:
+ NUM += 10
+ # This isn't a good idea, but we have to support
+ # this weird case.
+ raise MyExc
+
+ r = asyncio.create_task(runner())
+ await asyncio.sleep(0.1)
+
+ self.assertFalse(r.done())
+ r.cancel()
+
+ try:
+ await r
+ except ExceptionGroup as t:
+ self.assertEqual(get_error_types(t),{MyExc})
+ else:
+ self.fail('ExceptionGroup was not raised')
+
+ self.assertEqual(NUM, 10)
+
+ async def test_taskgroup_19(self):
+ async def crash_soon():
+ await asyncio.sleep(0.1)
+ 1 / 0
+
+ async def nested():
+ try:
+ await asyncio.sleep(10)
+ finally:
+ raise MyExc
+
+ async def runner():
+ async with taskgroups.TaskGroup() as g:
+ g.create_task(crash_soon())
+ await nested()
+
+ r = asyncio.create_task(runner())
+ try:
+ await r
+ except ExceptionGroup as t:
+ self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError})
+ else:
+ self.fail('TasgGroupError was not raised')
+
+ async def test_taskgroup_20(self):
+ async def crash_soon():
+ await asyncio.sleep(0.1)
+ 1 / 0
+
+ async def nested():
+ try:
+ await asyncio.sleep(10)
+ finally:
+ raise KeyboardInterrupt
+
+ async def runner():
+ async with taskgroups.TaskGroup() as g:
+ g.create_task(crash_soon())
+ await nested()
+
+ with self.assertRaises(KeyboardInterrupt):
+ await runner()
+
+ async def test_taskgroup_20a(self):
+ async def crash_soon():
+ await asyncio.sleep(0.1)
+ 1 / 0
+
+ async def nested():
+ try:
+ await asyncio.sleep(10)
+ finally:
+ raise MyBaseExc
+
+ async def runner():
+ async with taskgroups.TaskGroup() as g:
+ g.create_task(crash_soon())
+ await nested()
+
+ with self.assertRaises(BaseExceptionGroup) as cm:
+ await runner()
+
+ self.assertEqual(
+ get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError}
+ )
+
+ async def _test_taskgroup_21(self):
+ # This test doesn't work as asyncio, currently, doesn't
+ # correctly propagate KeyboardInterrupt (or SystemExit) --
+ # those cause the event loop itself to crash.
+ # (Compare to the previous (passing) test -- that one raises
+ # a plain exception but raises KeyboardInterrupt in nested();
+ # this test does it the other way around.)
+
+ async def crash_soon():
+ await asyncio.sleep(0.1)
+ raise KeyboardInterrupt
+
+ async def nested():
+ try:
+ await asyncio.sleep(10)
+ finally:
+ raise TypeError
+
+ async def runner():
+ async with taskgroups.TaskGroup() as g:
+ g.create_task(crash_soon())
+ await nested()
+
+ with self.assertRaises(KeyboardInterrupt):
+ await runner()
+
+ async def test_taskgroup_21a(self):
+
+ async def crash_soon():
+ await asyncio.sleep(0.1)
+ raise MyBaseExc
+
+ async def nested():
+ try:
+ await asyncio.sleep(10)
+ finally:
+ raise TypeError
+
+ async def runner():
+ async with taskgroups.TaskGroup() as g:
+ g.create_task(crash_soon())
+ await nested()
+
+ with self.assertRaises(BaseExceptionGroup) as cm:
+ await runner()
+
+ self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError})
+
+ async def test_taskgroup_22(self):
+
+ async def foo1():
+ await asyncio.sleep(1)
+ return 42
+
+ async def foo2():
+ await asyncio.sleep(2)
+ return 11
+
+ async def runner():
+ async with taskgroups.TaskGroup() as g:
+ g.create_task(foo1())
+ g.create_task(foo2())
+
+ r = asyncio.create_task(runner())
+ await asyncio.sleep(0.05)
+ r.cancel()
+
+ with self.assertRaises(asyncio.CancelledError):
+ await r
+
+ async def test_taskgroup_23(self):
+
+ async def do_job(delay):
+ await asyncio.sleep(delay)
+
+ async with taskgroups.TaskGroup() as g:
+ for count in range(10):
+ await asyncio.sleep(0.1)
+ g.create_task(do_job(0.3))
+ if count == 5:
+ self.assertLess(len(g._tasks), 5)
+ await asyncio.sleep(1.35)
+ self.assertEqual(len(g._tasks), 0)
+
+ async def test_taskgroup_24(self):
+
+ async def root(g):
+ await asyncio.sleep(0.1)
+ g.create_task(coro1(0.1))
+ g.create_task(coro1(0.2))
+
+ async def coro1(delay):
+ await asyncio.sleep(delay)
+
+ async def runner():
+ async with taskgroups.TaskGroup() as g:
+ g.create_task(root(g))
+
+ await runner()
+
+ async def test_taskgroup_25(self):
+ nhydras = 0
+
+ async def hydra(g):
+ nonlocal nhydras
+ nhydras += 1
+ await asyncio.sleep(0.01)
+ g.create_task(hydra(g))
+ g.create_task(hydra(g))
+
+ async def hercules():
+ while nhydras < 10:
+ await asyncio.sleep(0.015)
+ 1 / 0
+
+ async def runner():
+ async with taskgroups.TaskGroup() as g:
+ g.create_task(hydra(g))
+ g.create_task(hercules())
+
+ with self.assertRaises(ExceptionGroup) as cm:
+ await runner()
+
+ self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
+ self.assertGreaterEqual(nhydras, 10)
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
index 8c4dcea..fe6bfb3 100644
--- a/Lib/test/test_asyncio/test_tasks.py
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -496,6 +496,51 @@ class BaseTaskTests:
# This also distinguishes from the initial has_cycle=None.
self.assertEqual(has_cycle, False)
+
+ def test_cancelling(self):
+ loop = asyncio.new_event_loop()
+
+ async def task():
+ await asyncio.sleep(10)
+
+ try:
+ t = self.new_task(loop, task())
+ self.assertFalse(t.cancelling())
+ self.assertNotIn(" cancelling ", repr(t))
+ self.assertTrue(t.cancel())
+ self.assertTrue(t.cancelling())
+ self.assertIn(" cancelling ", repr(t))
+ self.assertFalse(t.cancel())
+
+ with self.assertRaises(asyncio.CancelledError):
+ loop.run_until_complete(t)
+ finally:
+ loop.close()
+
+ def test_uncancel(self):
+ loop = asyncio.new_event_loop()
+
+ async def task():
+ try:
+ await asyncio.sleep(10)
+ except asyncio.CancelledError:
+ asyncio.current_task().uncancel()
+ await asyncio.sleep(10)
+
+ try:
+ t = self.new_task(loop, task())
+ loop.run_until_complete(asyncio.sleep(0.01))
+ self.assertTrue(t.cancel()) # Cancel first sleep
+ self.assertIn(" cancelling ", repr(t))
+ loop.run_until_complete(asyncio.sleep(0.01))
+ self.assertNotIn(" cancelling ", repr(t)) # after .uncancel()
+ self.assertTrue(t.cancel()) # Cancel second sleep
+
+ with self.assertRaises(asyncio.CancelledError):
+ loop.run_until_complete(t)
+ finally:
+ loop.close()
+
def test_cancel(self):
def gen():
diff --git a/Misc/NEWS.d/next/Library/2022-02-14-21-21-49.bpo-46752.m6ldTm.rst b/Misc/NEWS.d/next/Library/2022-02-14-21-21-49.bpo-46752.m6ldTm.rst
new file mode 100644
index 0000000..f460600
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2022-02-14-21-21-49.bpo-46752.m6ldTm.rst
@@ -0,0 +1,2 @@
+Add task groups to asyncio (structured concurrency, inspired by Trio's nurseries).
+This also introduces a change to task cancellation, where a cancelled task can't be cancelled again until it calls .uncancel().
diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c
index 72dbdb8..6725e2e 100644
--- a/Modules/_asynciomodule.c
+++ b/Modules/_asynciomodule.c
@@ -91,6 +91,7 @@ typedef struct {
PyObject *task_context;
int task_must_cancel;
int task_log_destroy_pending;
+ int task_cancel_requested;
} TaskObj;
typedef struct {
@@ -2039,6 +2040,7 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop,
Py_CLEAR(self->task_fut_waiter);
self->task_must_cancel = 0;
self->task_log_destroy_pending = 1;
+ self->task_cancel_requested = 0;
Py_INCREF(coro);
Py_XSETREF(self->task_coro, coro);
@@ -2205,6 +2207,11 @@ _asyncio_Task_cancel_impl(TaskObj *self, PyObject *msg)
Py_RETURN_FALSE;
}
+ if (self->task_cancel_requested) {
+ Py_RETURN_FALSE;
+ }
+ self->task_cancel_requested = 1;
+
if (self->task_fut_waiter) {
PyObject *res;
int is_true;
@@ -2233,6 +2240,56 @@ _asyncio_Task_cancel_impl(TaskObj *self, PyObject *msg)
}
/*[clinic input]
+_asyncio.Task.cancelling
+
+Return True if the task is in the process of being cancelled.
+
+This is set once .cancel() is called
+and remains set until .uncancel() is called.
+
+As long as this flag is set, further .cancel() calls will be ignored,
+until .uncancel() is called to reset it.
+[clinic start generated code]*/
+
+static PyObject *
+_asyncio_Task_cancelling_impl(TaskObj *self)
+/*[clinic end generated code: output=803b3af96f917d7e input=c50e50f9c3ca4676]*/
+/*[clinic end generated code]*/
+{
+ if (self->task_cancel_requested) {
+ Py_RETURN_TRUE;
+ }
+ else {
+ Py_RETURN_FALSE;
+ }
+}
+
+/*[clinic input]
+_asyncio.Task.uncancel
+
+Reset the flag returned by cancelling().
+
+This should be used by tasks that catch CancelledError
+and wish to continue indefinitely until they are cancelled again.
+
+Returns the previous value of the flag.
+[clinic start generated code]*/
+
+static PyObject *
+_asyncio_Task_uncancel_impl(TaskObj *self)
+/*[clinic end generated code: output=58184d236a817d3c input=5db95e28fcb6f7cd]*/
+/*[clinic end generated code]*/
+{
+ if (self->task_cancel_requested) {
+ self->task_cancel_requested = 0;
+ Py_RETURN_TRUE;
+ }
+ else {
+ Py_RETURN_FALSE;
+ }
+}
+
+/*[clinic input]
_asyncio.Task.get_stack
*
@@ -2455,6 +2512,8 @@ static PyMethodDef TaskType_methods[] = {
_ASYNCIO_TASK_SET_RESULT_METHODDEF
_ASYNCIO_TASK_SET_EXCEPTION_METHODDEF
_ASYNCIO_TASK_CANCEL_METHODDEF
+ _ASYNCIO_TASK_CANCELLING_METHODDEF
+ _ASYNCIO_TASK_UNCANCEL_METHODDEF
_ASYNCIO_TASK_GET_STACK_METHODDEF
_ASYNCIO_TASK_PRINT_STACK_METHODDEF
_ASYNCIO_TASK__MAKE_CANCELLED_ERROR_METHODDEF
diff --git a/Modules/clinic/_asynciomodule.c.h b/Modules/clinic/_asynciomodule.c.h
index c472e65..5648e14 100644
--- a/Modules/clinic/_asynciomodule.c.h
+++ b/Modules/clinic/_asynciomodule.c.h
@@ -447,6 +447,53 @@ exit:
return return_value;
}
+PyDoc_STRVAR(_asyncio_Task_cancelling__doc__,
+"cancelling($self, /)\n"
+"--\n"
+"\n"
+"Return True if the task is in the process of being cancelled.\n"
+"\n"
+"This is set once .cancel() is called\n"
+"and remains set until .uncancel() is called.\n"
+"\n"
+"As long as this flag is set, further .cancel() calls will be ignored,\n"
+"until .uncancel() is called to reset it.");
+
+#define _ASYNCIO_TASK_CANCELLING_METHODDEF \
+ {"cancelling", (PyCFunction)_asyncio_Task_cancelling, METH_NOARGS, _asyncio_Task_cancelling__doc__},
+
+static PyObject *
+_asyncio_Task_cancelling_impl(TaskObj *self);
+
+static PyObject *
+_asyncio_Task_cancelling(TaskObj *self, PyObject *Py_UNUSED(ignored))
+{
+ return _asyncio_Task_cancelling_impl(self);
+}
+
+PyDoc_STRVAR(_asyncio_Task_uncancel__doc__,
+"uncancel($self, /)\n"
+"--\n"
+"\n"
+"Reset the flag returned by cancelling().\n"
+"\n"
+"This should be used by tasks that catch CancelledError\n"
+"and wish to continue indefinitely until they are cancelled again.\n"
+"\n"
+"Returns the previous value of the flag.");
+
+#define _ASYNCIO_TASK_UNCANCEL_METHODDEF \
+ {"uncancel", (PyCFunction)_asyncio_Task_uncancel, METH_NOARGS, _asyncio_Task_uncancel__doc__},
+
+static PyObject *
+_asyncio_Task_uncancel_impl(TaskObj *self);
+
+static PyObject *
+_asyncio_Task_uncancel(TaskObj *self, PyObject *Py_UNUSED(ignored))
+{
+ return _asyncio_Task_uncancel_impl(self);
+}
+
PyDoc_STRVAR(_asyncio_Task_get_stack__doc__,
"get_stack($self, /, *, limit=None)\n"
"--\n"
@@ -871,4 +918,4 @@ _asyncio__leave_task(PyObject *module, PyObject *const *args, Py_ssize_t nargs,
exit:
return return_value;
}
-/*[clinic end generated code: output=0d127162ac92e0c0 input=a9049054013a1b77]*/
+/*[clinic end generated code: output=c02708a9d6a774cc input=a9049054013a1b77]*/