summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorRon Frederick <ronf@timeheart.net>2024-09-26 06:15:08 (GMT)
committerGitHub <noreply@github.com>2024-09-26 06:15:08 (GMT)
commit1229cb8c1412d37cf3206eab407f03e21d602cbd (patch)
tree9ff9363d0b2c15026c225cb16e6305fe8cd9028c /Lib
parent46f5cbca4c37c57f718d3de0d7f7ddfc44298535 (diff)
downloadcpython-1229cb8c1412d37cf3206eab407f03e21d602cbd.zip
cpython-1229cb8c1412d37cf3206eab407f03e21d602cbd.tar.gz
cpython-1229cb8c1412d37cf3206eab407f03e21d602cbd.tar.bz2
gh-120284: Enhance `asyncio.run` to accept awaitable objects (#120566)
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
Diffstat (limited to 'Lib')
-rw-r--r--Lib/asyncio/runners.py17
-rw-r--r--Lib/test/test_asyncio/test_runners.py29
2 files changed, 32 insertions, 14 deletions
diff --git a/Lib/asyncio/runners.py b/Lib/asyncio/runners.py
index 1b89236..0e63c34 100644
--- a/Lib/asyncio/runners.py
+++ b/Lib/asyncio/runners.py
@@ -3,6 +3,7 @@ __all__ = ('Runner', 'run')
import contextvars
import enum
import functools
+import inspect
import threading
import signal
from . import coroutines
@@ -84,10 +85,7 @@ class Runner:
return self._loop
def run(self, coro, *, context=None):
- """Run a coroutine inside the embedded event loop."""
- if not coroutines.iscoroutine(coro):
- raise ValueError("a coroutine was expected, got {!r}".format(coro))
-
+ """Run code in the embedded event loop."""
if events._get_running_loop() is not None:
# fail fast with short traceback
raise RuntimeError(
@@ -95,8 +93,19 @@ class Runner:
self._lazy_init()
+ if not coroutines.iscoroutine(coro):
+ if inspect.isawaitable(coro):
+ async def _wrap_awaitable(awaitable):
+ return await awaitable
+
+ coro = _wrap_awaitable(coro)
+ else:
+ raise TypeError('An asyncio.Future, a coroutine or an '
+ 'awaitable is required')
+
if context is None:
context = self._context
+
task = self._loop.create_task(coro, context=context)
if (threading.current_thread() is threading.main_thread()
diff --git a/Lib/test/test_asyncio/test_runners.py b/Lib/test/test_asyncio/test_runners.py
index 266f057..45f70d0 100644
--- a/Lib/test/test_asyncio/test_runners.py
+++ b/Lib/test/test_asyncio/test_runners.py
@@ -93,8 +93,8 @@ class RunTests(BaseTest):
def test_asyncio_run_only_coro(self):
for o in {1, lambda: None}:
with self.subTest(obj=o), \
- self.assertRaisesRegex(ValueError,
- 'a coroutine was expected'):
+ self.assertRaisesRegex(TypeError,
+ 'an awaitable is required'):
asyncio.run(o)
def test_asyncio_run_debug(self):
@@ -319,19 +319,28 @@ class RunnerTests(BaseTest):
def test_run_non_coro(self):
with asyncio.Runner() as runner:
with self.assertRaisesRegex(
- ValueError,
- "a coroutine was expected"
+ TypeError,
+ "an awaitable is required"
):
runner.run(123)
def test_run_future(self):
with asyncio.Runner() as runner:
- with self.assertRaisesRegex(
- ValueError,
- "a coroutine was expected"
- ):
- fut = runner.get_loop().create_future()
- runner.run(fut)
+ fut = runner.get_loop().create_future()
+ fut.set_result('done')
+ self.assertEqual('done', runner.run(fut))
+
+ def test_run_awaitable(self):
+ class MyAwaitable:
+ def __await__(self):
+ return self.run().__await__()
+
+ @staticmethod
+ async def run():
+ return 'done'
+
+ with asyncio.Runner() as runner:
+ self.assertEqual('done', runner.run(MyAwaitable()))
def test_explicit_close(self):
runner = asyncio.Runner()