summaryrefslogtreecommitdiffstats
path: root/Lib/asyncio
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/asyncio')
-rw-r--r--Lib/asyncio/__init__.py1
-rw-r--r--Lib/asyncio/base_tasks.py2
-rw-r--r--Lib/asyncio/taskgroups.py235
-rw-r--r--Lib/asyncio/tasks.py16
4 files changed, 252 insertions, 2 deletions
diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py
index 200b14c..db1124c 100644
--- a/Lib/asyncio/__init__.py
+++ b/Lib/asyncio/__init__.py
@@ -17,6 +17,7 @@ from .queues import *
from .streams import *
from .subprocess import *
from .tasks import *
+from .taskgroups import *
from .threads import *
from .transports import *
diff --git a/Lib/asyncio/base_tasks.py b/Lib/asyncio/base_tasks.py
index 09bb171..1d62389 100644
--- a/Lib/asyncio/base_tasks.py
+++ b/Lib/asyncio/base_tasks.py
@@ -8,7 +8,7 @@ from . import coroutines
def _task_repr_info(task):
info = base_futures._future_repr_info(task)
- if task._must_cancel:
+ if task.cancelling() and not task.done():
# replace status
info[0] = 'cancelling'
diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py
new file mode 100644
index 0000000..7182778
--- /dev/null
+++ b/Lib/asyncio/taskgroups.py
@@ -0,0 +1,235 @@
+# Adapted with permission from the EdgeDB project.
+
+
+__all__ = ["TaskGroup"]
+
+import itertools
+import textwrap
+import traceback
+import types
+import weakref
+
+from . import events
+from . import exceptions
+from . import tasks
+
+class TaskGroup:
+
+ def __init__(self, *, name=None):
+ if name is None:
+ self._name = f'tg-{_name_counter()}'
+ else:
+ self._name = str(name)
+
+ self._entered = False
+ self._exiting = False
+ self._aborting = False
+ self._loop = None
+ self._parent_task = None
+ self._parent_cancel_requested = False
+ self._tasks = weakref.WeakSet()
+ self._unfinished_tasks = 0
+ self._errors = []
+ self._base_error = None
+ self._on_completed_fut = None
+
+ def get_name(self):
+ return self._name
+
+ def __repr__(self):
+ msg = f'<TaskGroup {self._name!r}'
+ if self._tasks:
+ msg += f' tasks:{len(self._tasks)}'
+ if self._unfinished_tasks:
+ msg += f' unfinished:{self._unfinished_tasks}'
+ if self._errors:
+ msg += f' errors:{len(self._errors)}'
+ if self._aborting:
+ msg += ' cancelling'
+ elif self._entered:
+ msg += ' entered'
+ msg += '>'
+ return msg
+
+ async def __aenter__(self):
+ if self._entered:
+ raise RuntimeError(
+ f"TaskGroup {self!r} has been already entered")
+ self._entered = True
+
+ if self._loop is None:
+ self._loop = events.get_running_loop()
+
+ self._parent_task = tasks.current_task(self._loop)
+ if self._parent_task is None:
+ raise RuntimeError(
+ f'TaskGroup {self!r} cannot determine the parent task')
+
+ return self
+
+ async def __aexit__(self, et, exc, tb):
+ self._exiting = True
+ propagate_cancellation_error = None
+
+ if (exc is not None and
+ self._is_base_error(exc) and
+ self._base_error is None):
+ self._base_error = exc
+
+ if et is exceptions.CancelledError:
+ if self._parent_cancel_requested:
+ # Only if we did request task to cancel ourselves
+ # we mark it as no longer cancelled.
+ self._parent_task.uncancel()
+ else:
+ propagate_cancellation_error = et
+
+ if et is not None and not self._aborting:
+ # Our parent task is being cancelled:
+ #
+ # async with TaskGroup() as g:
+ # g.create_task(...)
+ # await ... # <- CancelledError
+ #
+ if et is exceptions.CancelledError:
+ propagate_cancellation_error = et
+
+ # or there's an exception in "async with":
+ #
+ # async with TaskGroup() as g:
+ # g.create_task(...)
+ # 1 / 0
+ #
+ self._abort()
+
+ # We use while-loop here because "self._on_completed_fut"
+ # can be cancelled multiple times if our parent task
+ # is being cancelled repeatedly (or even once, when
+ # our own cancellation is already in progress)
+ while self._unfinished_tasks:
+ if self._on_completed_fut is None:
+ self._on_completed_fut = self._loop.create_future()
+
+ try:
+ await self._on_completed_fut
+ except exceptions.CancelledError as ex:
+ if not self._aborting:
+ # Our parent task is being cancelled:
+ #
+ # async def wrapper():
+ # async with TaskGroup() as g:
+ # g.create_task(foo)
+ #
+ # "wrapper" is being cancelled while "foo" is
+ # still running.
+ propagate_cancellation_error = ex
+ self._abort()
+
+ self._on_completed_fut = None
+
+ assert self._unfinished_tasks == 0
+ self._on_completed_fut = None # no longer needed
+
+ if self._base_error is not None:
+ raise self._base_error
+
+ if propagate_cancellation_error is not None:
+ # The wrapping task was cancelled; since we're done with
+ # closing all child tasks, just propagate the cancellation
+ # request now.
+ raise propagate_cancellation_error
+
+ if et is not None and et is not exceptions.CancelledError:
+ self._errors.append(exc)
+
+ if self._errors:
+ # Exceptions are heavy objects that can have object
+ # cycles (bad for GC); let's not keep a reference to
+ # a bunch of them.
+ errors = self._errors
+ self._errors = None
+
+ me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors)
+ raise me from None
+
+ def create_task(self, coro):
+ if not self._entered:
+ raise RuntimeError(f"TaskGroup {self!r} has not been entered")
+ if self._exiting and self._unfinished_tasks == 0:
+ raise RuntimeError(f"TaskGroup {self!r} is finished")
+ task = self._loop.create_task(coro)
+ task.add_done_callback(self._on_task_done)
+ self._unfinished_tasks += 1
+ self._tasks.add(task)
+ return task
+
+ # Since Python 3.8 Tasks propagate all exceptions correctly,
+ # except for KeyboardInterrupt and SystemExit which are
+ # still considered special.
+
+ def _is_base_error(self, exc: BaseException) -> bool:
+ assert isinstance(exc, BaseException)
+ return isinstance(exc, (SystemExit, KeyboardInterrupt))
+
+ def _abort(self):
+ self._aborting = True
+
+ for t in self._tasks:
+ if not t.done():
+ t.cancel()
+
+ def _on_task_done(self, task):
+ self._unfinished_tasks -= 1
+ assert self._unfinished_tasks >= 0
+
+ if self._on_completed_fut is not None and not self._unfinished_tasks:
+ if not self._on_completed_fut.done():
+ self._on_completed_fut.set_result(True)
+
+ if task.cancelled():
+ return
+
+ exc = task.exception()
+ if exc is None:
+ return
+
+ self._errors.append(exc)
+ if self._is_base_error(exc) and self._base_error is None:
+ self._base_error = exc
+
+ if self._parent_task.done():
+ # Not sure if this case is possible, but we want to handle
+ # it anyways.
+ self._loop.call_exception_handler({
+ 'message': f'Task {task!r} has errored out but its parent '
+ f'task {self._parent_task} is already completed',
+ 'exception': exc,
+ 'task': task,
+ })
+ return
+
+ self._abort()
+ if not self._parent_task.cancelling():
+ # If parent task *is not* being cancelled, it means that we want
+ # to manually cancel it to abort whatever is being run right now
+ # in the TaskGroup. But we want to mark parent task as
+ # "not cancelled" later in __aexit__. Example situation that
+ # we need to handle:
+ #
+ # async def foo():
+ # try:
+ # async with TaskGroup() as g:
+ # g.create_task(crash_soon())
+ # await something # <- this needs to be canceled
+ # # by the TaskGroup, e.g.
+ # # foo() needs to be cancelled
+ # except Exception:
+ # # Ignore any exceptions raised in the TaskGroup
+ # pass
+ # await something_else # this line has to be called
+ # # after TaskGroup is finished.
+ self._parent_cancel_requested = True
+ self._parent_task.cancel()
+
+
+_name_counter = itertools.count(1).__next__
diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py
index 2bee5c0..c11d0da 100644
--- a/Lib/asyncio/tasks.py
+++ b/Lib/asyncio/tasks.py
@@ -105,6 +105,7 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
else:
self._name = str(name)
+ self._cancel_requested = False
self._must_cancel = False
self._fut_waiter = None
self._coro = coro
@@ -201,6 +202,9 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
self._log_traceback = False
if self.done():
return False
+ if self._cancel_requested:
+ return False
+ self._cancel_requested = True
if self._fut_waiter is not None:
if self._fut_waiter.cancel(msg=msg):
# Leave self._fut_waiter; it may be a Task that
@@ -212,6 +216,16 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
self._cancel_message = msg
return True
+ def cancelling(self):
+ return self._cancel_requested
+
+ def uncancel(self):
+ if self._cancel_requested:
+ self._cancel_requested = False
+ return True
+ else:
+ return False
+
def __step(self, exc=None):
if self.done():
raise exceptions.InvalidStateError(
@@ -634,7 +648,7 @@ def _ensure_future(coro_or_future, *, loop=None):
loop = events._get_event_loop(stacklevel=4)
try:
return loop.create_task(coro_or_future)
- except RuntimeError:
+ except RuntimeError:
if not called_wrap_awaitable:
coro_or_future.close()
raise