summaryrefslogtreecommitdiffstats
path: root/Lib/asyncio/taskgroups.py
diff options
context:
space:
mode:
authorYury Selivanov <yury@edgedb.com>2022-05-27 22:20:21 (GMT)
committerGitHub <noreply@github.com>2022-05-27 22:20:21 (GMT)
commite6a57678cafe18ca132ee9510252168fcc392a8d (patch)
tree1171a3a061301a6c143b61eaf774264fa988a38f /Lib/asyncio/taskgroups.py
parent70cfe56cafb2b549983f63d5d1a54654fe63c15c (diff)
downloadcpython-e6a57678cafe18ca132ee9510252168fcc392a8d.zip
cpython-e6a57678cafe18ca132ee9510252168fcc392a8d.tar.gz
cpython-e6a57678cafe18ca132ee9510252168fcc392a8d.tar.bz2
gh-93297: Make asyncio task groups prevent child tasks from being GCed (#93299)
Diffstat (limited to 'Lib/asyncio/taskgroups.py')
-rw-r--r--Lib/asyncio/taskgroups.py19
1 files changed, 6 insertions, 13 deletions
diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py
index 6af21f3..9e0610d 100644
--- a/Lib/asyncio/taskgroups.py
+++ b/Lib/asyncio/taskgroups.py
@@ -3,8 +3,6 @@
__all__ = ["TaskGroup"]
-import weakref
-
from . import events
from . import exceptions
from . import tasks
@@ -19,8 +17,7 @@ class TaskGroup:
self._loop = None
self._parent_task = None
self._parent_cancel_requested = False
- self._tasks = weakref.WeakSet()
- self._unfinished_tasks = 0
+ self._tasks = set()
self._errors = []
self._base_error = None
self._on_completed_fut = None
@@ -29,8 +26,6 @@ class TaskGroup:
info = ['']
if self._tasks:
info.append(f'tasks={len(self._tasks)}')
- if self._unfinished_tasks:
- info.append(f'unfinished={self._unfinished_tasks}')
if self._errors:
info.append(f'errors={len(self._errors)}')
if self._aborting:
@@ -93,7 +88,7 @@ class TaskGroup:
# can be cancelled multiple times if our parent task
# is being cancelled repeatedly (or even once, when
# our own cancellation is already in progress)
- while self._unfinished_tasks:
+ while self._tasks:
if self._on_completed_fut is None:
self._on_completed_fut = self._loop.create_future()
@@ -114,7 +109,7 @@ class TaskGroup:
self._on_completed_fut = None
- assert self._unfinished_tasks == 0
+ assert not self._tasks
if self._base_error is not None:
raise self._base_error
@@ -141,7 +136,7 @@ class TaskGroup:
def create_task(self, coro, *, name=None, context=None):
if not self._entered:
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
- if self._exiting and self._unfinished_tasks == 0:
+ if self._exiting and not self._tasks:
raise RuntimeError(f"TaskGroup {self!r} is finished")
if context is None:
task = self._loop.create_task(coro)
@@ -149,7 +144,6 @@ class TaskGroup:
task = self._loop.create_task(coro, context=context)
tasks._set_task_name(task, name)
task.add_done_callback(self._on_task_done)
- self._unfinished_tasks += 1
self._tasks.add(task)
return task
@@ -169,10 +163,9 @@ class TaskGroup:
t.cancel()
def _on_task_done(self, task):
- self._unfinished_tasks -= 1
- assert self._unfinished_tasks >= 0
+ self._tasks.discard(task)
- if self._on_completed_fut is not None and not self._unfinished_tasks:
+ if self._on_completed_fut is not None and not self._tasks:
if not self._on_completed_fut.done():
self._on_completed_fut.set_result(True)