diff options
author | Ron Frederick <ronf@timeheart.net> | 2024-09-26 06:15:08 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-26 06:15:08 (GMT) |
commit | 1229cb8c1412d37cf3206eab407f03e21d602cbd (patch) | |
tree | 9ff9363d0b2c15026c225cb16e6305fe8cd9028c /Lib | |
parent | 46f5cbca4c37c57f718d3de0d7f7ddfc44298535 (diff) | |
download | cpython-1229cb8c1412d37cf3206eab407f03e21d602cbd.zip cpython-1229cb8c1412d37cf3206eab407f03e21d602cbd.tar.gz cpython-1229cb8c1412d37cf3206eab407f03e21d602cbd.tar.bz2 |
gh-120284: Enhance `asyncio.run` to accept awaitable objects (#120566)
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/asyncio/runners.py | 17 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_runners.py | 29 |
2 files changed, 32 insertions, 14 deletions
diff --git a/Lib/asyncio/runners.py b/Lib/asyncio/runners.py index 1b89236..0e63c34 100644 --- a/Lib/asyncio/runners.py +++ b/Lib/asyncio/runners.py @@ -3,6 +3,7 @@ __all__ = ('Runner', 'run') import contextvars import enum import functools +import inspect import threading import signal from . import coroutines @@ -84,10 +85,7 @@ class Runner: return self._loop def run(self, coro, *, context=None): - """Run a coroutine inside the embedded event loop.""" - if not coroutines.iscoroutine(coro): - raise ValueError("a coroutine was expected, got {!r}".format(coro)) - + """Run code in the embedded event loop.""" if events._get_running_loop() is not None: # fail fast with short traceback raise RuntimeError( @@ -95,8 +93,19 @@ class Runner: self._lazy_init() + if not coroutines.iscoroutine(coro): + if inspect.isawaitable(coro): + async def _wrap_awaitable(awaitable): + return await awaitable + + coro = _wrap_awaitable(coro) + else: + raise TypeError('An asyncio.Future, a coroutine or an ' + 'awaitable is required') + if context is None: context = self._context + task = self._loop.create_task(coro, context=context) if (threading.current_thread() is threading.main_thread() diff --git a/Lib/test/test_asyncio/test_runners.py b/Lib/test/test_asyncio/test_runners.py index 266f057..45f70d0 100644 --- a/Lib/test/test_asyncio/test_runners.py +++ b/Lib/test/test_asyncio/test_runners.py @@ -93,8 +93,8 @@ class RunTests(BaseTest): def test_asyncio_run_only_coro(self): for o in {1, lambda: None}: with self.subTest(obj=o), \ - self.assertRaisesRegex(ValueError, - 'a coroutine was expected'): + self.assertRaisesRegex(TypeError, + 'an awaitable is required'): asyncio.run(o) def test_asyncio_run_debug(self): @@ -319,19 +319,28 @@ class RunnerTests(BaseTest): def test_run_non_coro(self): with asyncio.Runner() as runner: with self.assertRaisesRegex( - ValueError, - "a coroutine was expected" + TypeError, + "an awaitable is required" ): runner.run(123) def test_run_future(self): with asyncio.Runner() as runner: - with self.assertRaisesRegex( - ValueError, - "a coroutine was expected" - ): - fut = runner.get_loop().create_future() - runner.run(fut) + fut = runner.get_loop().create_future() + fut.set_result('done') + self.assertEqual('done', runner.run(fut)) + + def test_run_awaitable(self): + class MyAwaitable: + def __await__(self): + return self.run().__await__() + + @staticmethod + async def run(): + return 'done' + + with asyncio.Runner() as runner: + self.assertEqual('done', runner.run(MyAwaitable())) def test_explicit_close(self): runner = asyncio.Runner() |