summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorJason Zhang <yurenzhang2017@gmail.com>2024-03-06 20:20:26 (GMT)
committerGitHub <noreply@github.com>2024-03-06 20:20:26 (GMT)
commitce0ae1d784871085059a415aa589d9bd16ea8301 (patch)
treeec905cf72732077f8491602fb336b13aea967ae8 /Lib
parent7114cf20c015b99123b32c1ba4f5475b7a6c3a13 (diff)
downloadcpython-ce0ae1d784871085059a415aa589d9bd16ea8301.zip
cpython-ce0ae1d784871085059a415aa589d9bd16ea8301.tar.gz
cpython-ce0ae1d784871085059a415aa589d9bd16ea8301.tar.bz2
gh-115957: Close coroutine if TaskGroup.create_task() raises an error (#116009)
Diffstat (limited to 'Lib')
-rw-r--r--Lib/asyncio/taskgroups.py3
-rw-r--r--Lib/test/test_asyncio/test_taskgroups.py41
2 files changed, 28 insertions, 16 deletions
diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py
index f322b1f..57f0123 100644
--- a/Lib/asyncio/taskgroups.py
+++ b/Lib/asyncio/taskgroups.py
@@ -154,10 +154,13 @@ class TaskGroup:
Similar to `asyncio.create_task`.
"""
if not self._entered:
+ coro.close()
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
if self._exiting and not self._tasks:
+ coro.close()
raise RuntimeError(f"TaskGroup {self!r} is finished")
if self._aborting:
+ coro.close()
raise RuntimeError(f"TaskGroup {self!r} is shutting down")
if context is None:
task = self._loop.create_task(coro, name=name)
diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py
index 7a18362..1ec8116 100644
--- a/Lib/test/test_asyncio/test_taskgroups.py
+++ b/Lib/test/test_asyncio/test_taskgroups.py
@@ -7,6 +7,7 @@ import contextvars
import contextlib
from asyncio import taskgroups
import unittest
+import warnings
from test.test_asyncio.utils import await_without_task
@@ -738,10 +739,7 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
await asyncio.sleep(1)
except asyncio.CancelledError:
with self.assertRaises(RuntimeError):
- g.create_task(c1 := coro1())
- # We still have to await c1 to avoid a warning
- with self.assertRaises(ZeroDivisionError):
- await c1
+ g.create_task(coro1())
with self.assertRaises(ExceptionGroup) as cm:
async with taskgroups.TaskGroup() as g:
@@ -797,22 +795,25 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
pass
async def test_taskgroup_finished(self):
- tg = taskgroups.TaskGroup()
- async with tg:
- pass
- coro = asyncio.sleep(0)
- with self.assertRaisesRegex(RuntimeError, "is finished"):
- tg.create_task(coro)
- # We still have to await coro to avoid a warning
- await coro
+ async def create_task_after_tg_finish():
+ tg = taskgroups.TaskGroup()
+ async with tg:
+ pass
+ coro = asyncio.sleep(0)
+ with self.assertRaisesRegex(RuntimeError, "is finished"):
+ tg.create_task(coro)
+
+ # Make sure the coroutine was closed when submitted to the inactive tg
+ # (if not closed, a RuntimeWarning should have been raised)
+ with warnings.catch_warnings(record=True) as w:
+ await create_task_after_tg_finish()
+ self.assertEqual(len(w), 0)
async def test_taskgroup_not_entered(self):
tg = taskgroups.TaskGroup()
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
- # We still have to await coro to avoid a warning
- await coro
async def test_taskgroup_without_parent_task(self):
tg = taskgroups.TaskGroup()
@@ -821,8 +822,16 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
- # We still have to await coro to avoid a warning
- await coro
+
+ def test_coro_closed_when_tg_closed(self):
+ async def run_coro_after_tg_closes():
+ async with taskgroups.TaskGroup() as tg:
+ pass
+ coro = asyncio.sleep(0)
+ with self.assertRaisesRegex(RuntimeError, "is finished"):
+ tg.create_task(coro)
+ loop = asyncio.get_event_loop()
+ loop.run_until_complete(run_coro_after_tg_closes())
if __name__ == "__main__":