diff options
-rw-r--r-- | Lib/asyncio/taskgroups.py | 22 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_taskgroups.py | 19 |
2 files changed, 17 insertions, 24 deletions
diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index 7182778..57b0eaf 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -3,10 +3,6 @@ __all__ = ["TaskGroup"] -import itertools -import textwrap -import traceback -import types import weakref from . import events @@ -15,12 +11,7 @@ from . import tasks class TaskGroup: - def __init__(self, *, name=None): - if name is None: - self._name = f'tg-{_name_counter()}' - else: - self._name = str(name) - + def __init__(self): self._entered = False self._exiting = False self._aborting = False @@ -33,11 +24,8 @@ class TaskGroup: self._base_error = None self._on_completed_fut = None - def get_name(self): - return self._name - def __repr__(self): - msg = f'<TaskGroup {self._name!r}' + msg = f'<TaskGroup' if self._tasks: msg += f' tasks:{len(self._tasks)}' if self._unfinished_tasks: @@ -152,12 +140,13 @@ class TaskGroup: me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors) raise me from None - def create_task(self, coro): + def create_task(self, coro, *, name=None): if not self._entered: raise RuntimeError(f"TaskGroup {self!r} has not been entered") if self._exiting and self._unfinished_tasks == 0: raise RuntimeError(f"TaskGroup {self!r} is finished") task = self._loop.create_task(coro) + tasks._set_task_name(task, name) task.add_done_callback(self._on_task_done) self._unfinished_tasks += 1 self._tasks.add(task) @@ -230,6 +219,3 @@ class TaskGroup: # # after TaskGroup is finished. self._parent_cancel_requested = True self._parent_task.cancel() - - -_name_counter = itertools.count(1).__next__ diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py index ea6ee2e..aab1fd1 100644 --- a/Lib/test/test_asyncio/test_taskgroups.py +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -368,10 +368,10 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): raise ValueError(t) async def runner(): - async with taskgroups.TaskGroup(name='g1') as g1: + async with taskgroups.TaskGroup() as g1: g1.create_task(crash_after(0.1)) - async with taskgroups.TaskGroup(name='g2') as g2: + async with taskgroups.TaskGroup() as g2: g2.create_task(crash_after(0.2)) r = asyncio.create_task(runner()) @@ -387,10 +387,10 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): raise ValueError(t) async def runner(): - async with taskgroups.TaskGroup(name='g1') as g1: + async with taskgroups.TaskGroup() as g1: g1.create_task(crash_after(10)) - async with taskgroups.TaskGroup(name='g2') as g2: + async with taskgroups.TaskGroup() as g2: g2.create_task(crash_after(0.1)) r = asyncio.create_task(runner()) @@ -407,7 +407,7 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): 1 / 0 async def runner(): - async with taskgroups.TaskGroup(name='g1') as g1: + async with taskgroups.TaskGroup() as g1: g1.create_task(crash_soon()) try: await asyncio.sleep(10) @@ -430,7 +430,7 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): 1 / 0 async def nested_runner(): - async with taskgroups.TaskGroup(name='g1') as g1: + async with taskgroups.TaskGroup() as g1: g1.create_task(crash_soon()) try: await asyncio.sleep(10) @@ -692,3 +692,10 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) self.assertGreaterEqual(nhydras, 10) + + async def test_taskgroup_task_name(self): + async def coro(): + await asyncio.sleep(0) + async with taskgroups.TaskGroup() as g: + t = g.create_task(coro(), name="yolo") + self.assertEqual(t.get_name(), "yolo") |