summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorPeter Bierma <zintensitydev@gmail.com>2024-09-26 05:11:17 (GMT)
committerGitHub <noreply@github.com>2024-09-26 05:11:17 (GMT)
commitde929f353c413459834a2a37b2d9b0240673d874 (patch)
tree641ff0ca3ff11a7dd48b6cbf410a8a01cee64926
parentd9296529eb0a65f988e8600d3073977dff0ce5a9 (diff)
downloadcpython-de929f353c413459834a2a37b2d9b0240673d874.zip
cpython-de929f353c413459834a2a37b2d9b0240673d874.tar.gz
cpython-de929f353c413459834a2a37b2d9b0240673d874.tar.bz2
gh-124309: Modernize the `staggered_race` implementation to support eager task factories (#124390)
Co-authored-by: Thomas Grainger <tagrain@gmail.com> Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com> Co-authored-by: Carol Willing <carolcode@willingconsulting.com> Co-authored-by: Kumar Aditya <kumaraditya@python.org>
-rw-r--r--Lib/asyncio/base_events.py2
-rw-r--r--Lib/asyncio/staggered.py79
-rw-r--r--Lib/test/test_asyncio/test_eager_task_factory.py47
-rw-r--r--Lib/test/test_asyncio/test_staggered.py37
-rw-r--r--Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst1
5 files changed, 100 insertions, 66 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 000647f..ffcc017 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -1144,7 +1144,7 @@ class BaseEventLoop(events.AbstractEventLoop):
(functools.partial(self._connect_sock,
exceptions, addrinfo, laddr_infos)
for addrinfo in infos),
- happy_eyeballs_delay, loop=self)
+ happy_eyeballs_delay)
if sock is None:
exceptions = [exc for sub in exceptions for exc in sub]
diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py
index c3a7441..4458d01 100644
--- a/Lib/asyncio/staggered.py
+++ b/Lib/asyncio/staggered.py
@@ -4,13 +4,14 @@ __all__ = 'staggered_race',
import contextlib
-from . import events
-from . import exceptions as exceptions_mod
from . import locks
from . import tasks
+from . import taskgroups
+class _Done(Exception):
+ pass
-async def staggered_race(coro_fns, delay, *, loop=None):
+async def staggered_race(coro_fns, delay):
"""Run coroutines with staggered start times and take the first to finish.
This method takes an iterable of coroutine functions. The first one is
@@ -42,8 +43,6 @@ async def staggered_race(coro_fns, delay, *, loop=None):
delay: amount of time, in seconds, between starting coroutines. If
``None``, the coroutines will run sequentially.
- loop: the event loop to use.
-
Returns:
tuple *(winner_result, winner_index, exceptions)* where
@@ -62,36 +61,11 @@ async def staggered_race(coro_fns, delay, *, loop=None):
"""
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
- loop = loop or events.get_running_loop()
- enum_coro_fns = enumerate(coro_fns)
winner_result = None
winner_index = None
exceptions = []
- running_tasks = []
-
- async def run_one_coro(previous_failed) -> None:
- # Wait for the previous task to finish, or for delay seconds
- if previous_failed is not None:
- with contextlib.suppress(exceptions_mod.TimeoutError):
- # Use asyncio.wait_for() instead of asyncio.wait() here, so
- # that if we get cancelled at this point, Event.wait() is also
- # cancelled, otherwise there will be a "Task destroyed but it is
- # pending" later.
- await tasks.wait_for(previous_failed.wait(), delay)
- # Get the next coroutine to run
- try:
- this_index, coro_fn = next(enum_coro_fns)
- except StopIteration:
- return
- # Start task that will run the next coroutine
- this_failed = locks.Event()
- next_task = loop.create_task(run_one_coro(this_failed))
- running_tasks.append(next_task)
- assert len(running_tasks) == this_index + 2
- # Prepare place to put this coroutine's exceptions if not won
- exceptions.append(None)
- assert len(exceptions) == this_index + 1
+ async def run_one_coro(this_index, coro_fn, this_failed):
try:
result = await coro_fn()
except (SystemExit, KeyboardInterrupt):
@@ -105,34 +79,17 @@ async def staggered_race(coro_fns, delay, *, loop=None):
assert winner_index is None
winner_index = this_index
winner_result = result
- # Cancel all other tasks. We take care to not cancel the current
- # task as well. If we do so, then since there is no `await` after
- # here and CancelledError are usually thrown at one, we will
- # encounter a curious corner case where the current task will end
- # up as done() == True, cancelled() == False, exception() ==
- # asyncio.CancelledError. This behavior is specified in
- # https://bugs.python.org/issue30048
- for i, t in enumerate(running_tasks):
- if i != this_index:
- t.cancel()
-
- first_task = loop.create_task(run_one_coro(None))
- running_tasks.append(first_task)
+ raise _Done
+
try:
- # Wait for a growing list of tasks to all finish: poor man's version of
- # curio's TaskGroup or trio's nursery
- done_count = 0
- while done_count != len(running_tasks):
- done, _ = await tasks.wait(running_tasks)
- done_count = len(done)
- # If run_one_coro raises an unhandled exception, it's probably a
- # programming error, and I want to see it.
- if __debug__:
- for d in done:
- if d.done() and not d.cancelled() and d.exception():
- raise d.exception()
- return winner_result, winner_index, exceptions
- finally:
- # Make sure no tasks are left running if we leave this function
- for t in running_tasks:
- t.cancel()
+ async with taskgroups.TaskGroup() as tg:
+ for this_index, coro_fn in enumerate(coro_fns):
+ this_failed = locks.Event()
+ exceptions.append(None)
+ tg.create_task(run_one_coro(this_index, coro_fn, this_failed))
+ with contextlib.suppress(TimeoutError):
+ await tasks.wait_for(this_failed.wait(), delay)
+ except* _Done:
+ pass
+
+ return winner_result, winner_index, exceptions
diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py
index 0777f39..1579ad1 100644
--- a/Lib/test/test_asyncio/test_eager_task_factory.py
+++ b/Lib/test/test_asyncio/test_eager_task_factory.py
@@ -213,6 +213,53 @@ class EagerTaskFactoryLoopTests:
self.run_coro(run())
+ def test_staggered_race_with_eager_tasks(self):
+ # See https://github.com/python/cpython/issues/124309
+
+ async def fail():
+ await asyncio.sleep(0)
+ raise ValueError("no good")
+
+ async def run():
+ winner, index, excs = await asyncio.staggered.staggered_race(
+ [
+ lambda: asyncio.sleep(2, result="sleep2"),
+ lambda: asyncio.sleep(1, result="sleep1"),
+ lambda: fail()
+ ],
+ delay=0.25
+ )
+ self.assertEqual(winner, 'sleep1')
+ self.assertEqual(index, 1)
+ self.assertIsNone(excs[index])
+ self.assertIsInstance(excs[0], asyncio.CancelledError)
+ self.assertIsInstance(excs[2], ValueError)
+
+ self.run_coro(run())
+
+ def test_staggered_race_with_eager_tasks_no_delay(self):
+ # See https://github.com/python/cpython/issues/124309
+ async def fail():
+ raise ValueError("no good")
+
+ async def run():
+ winner, index, excs = await asyncio.staggered.staggered_race(
+ [
+ lambda: fail(),
+ lambda: asyncio.sleep(1, result="sleep1"),
+ lambda: asyncio.sleep(0, result="sleep0"),
+ ],
+ delay=None
+ )
+ self.assertEqual(winner, 'sleep1')
+ self.assertEqual(index, 1)
+ self.assertIsNone(excs[index])
+ self.assertIsInstance(excs[0], ValueError)
+ self.assertEqual(len(excs), 2)
+
+ self.run_coro(run())
+
+
class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
Task = tasks._PyTask
diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py
index e6e32f7..21a39b3 100644
--- a/Lib/test/test_asyncio/test_staggered.py
+++ b/Lib/test/test_asyncio/test_staggered.py
@@ -82,16 +82,45 @@ class StaggeredTests(unittest.IsolatedAsyncioTestCase):
async def coro(index):
raise ValueError(index)
+ for delay in [None, 0, 0.1, 1]:
+ with self.subTest(delay=delay):
+ winner, index, excs = await staggered_race(
+ [
+ lambda: coro(0),
+ lambda: coro(1),
+ ],
+ delay=delay,
+ )
+
+ self.assertIs(winner, None)
+ self.assertIs(index, None)
+ self.assertEqual(len(excs), 2)
+ self.assertIsInstance(excs[0], ValueError)
+ self.assertIsInstance(excs[1], ValueError)
+
+ async def test_long_delay_early_failure(self):
+ async def coro(index):
+ await asyncio.sleep(0) # Dummy coroutine for the 1 case
+ if index == 0:
+ await asyncio.sleep(0.1) # Dummy coroutine
+ raise ValueError(index)
+
+ return f'Res: {index}'
+
winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
],
- delay=None,
+ delay=10,
)
- self.assertIs(winner, None)
- self.assertIs(index, None)
+ self.assertEqual(winner, 'Res: 1')
+ self.assertEqual(index, 1)
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], ValueError)
- self.assertIsInstance(excs[1], ValueError)
+ self.assertIsNone(excs[1])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst b/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst
new file mode 100644
index 0000000..89610fa
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst
@@ -0,0 +1 @@
+Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_race` with :attr:`asyncio.eager_task_factory`.