summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Doc/library/asyncio-eventloop.rst11
-rw-r--r--Doc/library/asyncio-task.rst9
-rw-r--r--Lib/asyncio/base_events.py11
-rw-r--r--Lib/asyncio/events.py2
-rw-r--r--Lib/asyncio/taskgroups.py7
-rw-r--r--Lib/asyncio/tasks.py16
-rw-r--r--Lib/test/test_asyncio/test_taskgroups.py18
-rw-r--r--Lib/test/test_asyncio/test_tasks.py88
-rw-r--r--Lib/unittest/async_case.py55
-rw-r--r--Lib/unittest/test/test_async_case.py18
-rw-r--r--Misc/NEWS.d/next/Library/2022-03-12-12-34-13.bpo-46994.d7hPdz.rst2
-rw-r--r--Modules/_asynciomodule.c16
-rw-r--r--Modules/clinic/_asynciomodule.c.h21
13 files changed, 209 insertions, 65 deletions
diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst
index 4776853..4f0f8c0 100644
--- a/Doc/library/asyncio-eventloop.rst
+++ b/Doc/library/asyncio-eventloop.rst
@@ -330,7 +330,7 @@ Creating Futures and Tasks
.. versionadded:: 3.5.2
-.. method:: loop.create_task(coro, *, name=None)
+.. method:: loop.create_task(coro, *, name=None, context=None)
Schedule the execution of a :ref:`coroutine`.
Return a :class:`Task` object.
@@ -342,9 +342,16 @@ Creating Futures and Tasks
If the *name* argument is provided and not ``None``, it is set as
the name of the task using :meth:`Task.set_name`.
+ An optional keyword-only *context* argument allows specifying a
+ custom :class:`contextvars.Context` for the *coro* to run in.
+ The current context copy is created when no *context* is provided.
+
.. versionchanged:: 3.8
Added the *name* parameter.
+ .. versionchanged:: 3.11
+ Added the *context* parameter.
+
.. method:: loop.set_task_factory(factory)
Set a task factory that will be used by
@@ -352,7 +359,7 @@ Creating Futures and Tasks
If *factory* is ``None`` the default task factory will be set.
Otherwise, *factory* must be a *callable* with the signature matching
- ``(loop, coro)``, where *loop* is a reference to the active
+ ``(loop, coro, context=None)``, where *loop* is a reference to the active
event loop, and *coro* is a coroutine object. The callable
must return a :class:`asyncio.Future`-compatible object.
diff --git a/Doc/library/asyncio-task.rst b/Doc/library/asyncio-task.rst
index b30b289..faf5910 100644
--- a/Doc/library/asyncio-task.rst
+++ b/Doc/library/asyncio-task.rst
@@ -244,7 +244,7 @@ Running an asyncio Program
Creating Tasks
==============
-.. function:: create_task(coro, *, name=None)
+.. function:: create_task(coro, *, name=None, context=None)
Wrap the *coro* :ref:`coroutine <coroutine>` into a :class:`Task`
and schedule its execution. Return the Task object.
@@ -252,6 +252,10 @@ Creating Tasks
If *name* is not ``None``, it is set as the name of the task using
:meth:`Task.set_name`.
+ An optional keyword-only *context* argument allows specifying a
+ custom :class:`contextvars.Context` for the *coro* to run in.
+ The current context copy is created when no *context* is provided.
+
The task is executed in the loop returned by :func:`get_running_loop`,
:exc:`RuntimeError` is raised if there is no running loop in
current thread.
@@ -281,6 +285,9 @@ Creating Tasks
.. versionchanged:: 3.8
Added the *name* parameter.
+ .. versionchanged:: 3.11
+ Added the *context* parameter.
+
Sleeping
========
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 51c4e66..5eea165 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -426,18 +426,23 @@ class BaseEventLoop(events.AbstractEventLoop):
"""Create a Future object attached to the loop."""
return futures.Future(loop=self)
- def create_task(self, coro, *, name=None):
+ def create_task(self, coro, *, name=None, context=None):
"""Schedule a coroutine object.
Return a task object.
"""
self._check_closed()
if self._task_factory is None:
- task = tasks.Task(coro, loop=self, name=name)
+ task = tasks.Task(coro, loop=self, name=name, context=context)
if task._source_traceback:
del task._source_traceback[-1]
else:
- task = self._task_factory(self, coro)
+ if context is None:
+ # Use legacy API if context is not needed
+ task = self._task_factory(self, coro)
+ else:
+ task = self._task_factory(self, coro, context=context)
+
tasks._set_task_name(task, name)
return task
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
index e682a19..0d26ea5 100644
--- a/Lib/asyncio/events.py
+++ b/Lib/asyncio/events.py
@@ -274,7 +274,7 @@ class AbstractEventLoop:
# Method scheduling a coroutine object: create a task.
- def create_task(self, coro, *, name=None):
+ def create_task(self, coro, *, name=None, context=None):
raise NotImplementedError
# Methods for interacting with threads.
diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py
index c3ce94a..6af21f3 100644
--- a/Lib/asyncio/taskgroups.py
+++ b/Lib/asyncio/taskgroups.py
@@ -138,12 +138,15 @@ class TaskGroup:
me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors)
raise me from None
- def create_task(self, coro, *, name=None):
+ def create_task(self, coro, *, name=None, context=None):
if not self._entered:
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
if self._exiting and self._unfinished_tasks == 0:
raise RuntimeError(f"TaskGroup {self!r} is finished")
- task = self._loop.create_task(coro)
+ if context is None:
+ task = self._loop.create_task(coro)
+ else:
+ task = self._loop.create_task(coro, context=context)
tasks._set_task_name(task, name)
task.add_done_callback(self._on_task_done)
self._unfinished_tasks += 1
diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py
index e604298..b4f1eed 100644
--- a/Lib/asyncio/tasks.py
+++ b/Lib/asyncio/tasks.py
@@ -93,7 +93,7 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
# status is still pending
_log_destroy_pending = True
- def __init__(self, coro, *, loop=None, name=None):
+ def __init__(self, coro, *, loop=None, name=None, context=None):
super().__init__(loop=loop)
if self._source_traceback:
del self._source_traceback[-1]
@@ -112,7 +112,10 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
self._must_cancel = False
self._fut_waiter = None
self._coro = coro
- self._context = contextvars.copy_context()
+ if context is None:
+ self._context = contextvars.copy_context()
+ else:
+ self._context = context
self._loop.call_soon(self.__step, context=self._context)
_register_task(self)
@@ -360,13 +363,18 @@ else:
Task = _CTask = _asyncio.Task
-def create_task(coro, *, name=None):
+def create_task(coro, *, name=None, context=None):
"""Schedule the execution of a coroutine object in a spawn task.
Return a Task object.
"""
loop = events.get_running_loop()
- task = loop.create_task(coro)
+ if context is None:
+ # Use legacy API if context is not needed
+ task = loop.create_task(coro)
+ else:
+ task = loop.create_task(coro, context=context)
+
_set_task_name(task, name)
return task
diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py
index df51528..dea5d6d 100644
--- a/Lib/test/test_asyncio/test_taskgroups.py
+++ b/Lib/test/test_asyncio/test_taskgroups.py
@@ -2,6 +2,7 @@
import asyncio
+import contextvars
from asyncio import taskgroups
import unittest
@@ -708,6 +709,23 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
t = g.create_task(coro(), name="yolo")
self.assertEqual(t.get_name(), "yolo")
+ async def test_taskgroup_task_context(self):
+ cvar = contextvars.ContextVar('cvar')
+
+ async def coro(val):
+ await asyncio.sleep(0)
+ cvar.set(val)
+
+ async with taskgroups.TaskGroup() as g:
+ ctx = contextvars.copy_context()
+ self.assertIsNone(ctx.get(cvar))
+ t1 = g.create_task(coro(1), context=ctx)
+ await t1
+ self.assertEqual(1, ctx.get(cvar))
+ t2 = g.create_task(coro(2), context=ctx)
+ await t2
+ self.assertEqual(2, ctx.get(cvar))
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
index 95fabf7..b6ef627 100644
--- a/Lib/test/test_asyncio/test_tasks.py
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -95,8 +95,8 @@ class BaseTaskTests:
Task = None
Future = None
- def new_task(self, loop, coro, name='TestTask'):
- return self.__class__.Task(coro, loop=loop, name=name)
+ def new_task(self, loop, coro, name='TestTask', context=None):
+ return self.__class__.Task(coro, loop=loop, name=name, context=context)
def new_future(self, loop):
return self.__class__.Future(loop=loop)
@@ -2527,6 +2527,90 @@ class BaseTaskTests:
self.assertEqual(cvar.get(), -1)
+ def test_context_4(self):
+ cvar = contextvars.ContextVar('cvar')
+
+ async def coro(val):
+ await asyncio.sleep(0)
+ cvar.set(val)
+
+ async def main():
+ ret = []
+ ctx = contextvars.copy_context()
+ ret.append(ctx.get(cvar))
+ t1 = self.new_task(loop, coro(1), context=ctx)
+ await t1
+ ret.append(ctx.get(cvar))
+ t2 = self.new_task(loop, coro(2), context=ctx)
+ await t2
+ ret.append(ctx.get(cvar))
+ return ret
+
+ loop = asyncio.new_event_loop()
+ try:
+ task = self.new_task(loop, main())
+ ret = loop.run_until_complete(task)
+ finally:
+ loop.close()
+
+ self.assertEqual([None, 1, 2], ret)
+
+ def test_context_5(self):
+ cvar = contextvars.ContextVar('cvar')
+
+ async def coro(val):
+ await asyncio.sleep(0)
+ cvar.set(val)
+
+ async def main():
+ ret = []
+ ctx = contextvars.copy_context()
+ ret.append(ctx.get(cvar))
+ t1 = asyncio.create_task(coro(1), context=ctx)
+ await t1
+ ret.append(ctx.get(cvar))
+ t2 = asyncio.create_task(coro(2), context=ctx)
+ await t2
+ ret.append(ctx.get(cvar))
+ return ret
+
+ loop = asyncio.new_event_loop()
+ try:
+ task = self.new_task(loop, main())
+ ret = loop.run_until_complete(task)
+ finally:
+ loop.close()
+
+ self.assertEqual([None, 1, 2], ret)
+
+ def test_context_6(self):
+ cvar = contextvars.ContextVar('cvar')
+
+ async def coro(val):
+ await asyncio.sleep(0)
+ cvar.set(val)
+
+ async def main():
+ ret = []
+ ctx = contextvars.copy_context()
+ ret.append(ctx.get(cvar))
+ t1 = loop.create_task(coro(1), context=ctx)
+ await t1
+ ret.append(ctx.get(cvar))
+ t2 = loop.create_task(coro(2), context=ctx)
+ await t2
+ ret.append(ctx.get(cvar))
+ return ret
+
+ loop = asyncio.new_event_loop()
+ try:
+ task = loop.create_task(main())
+ ret = loop.run_until_complete(task)
+ finally:
+ loop.close()
+
+ self.assertEqual([None, 1, 2], ret)
+
def test_get_coro(self):
loop = asyncio.new_event_loop()
coro = coroutine_function()
diff --git a/Lib/unittest/async_case.py b/Lib/unittest/async_case.py
index 3c57bb5..25adc3d 100644
--- a/Lib/unittest/async_case.py
+++ b/Lib/unittest/async_case.py
@@ -1,4 +1,5 @@
import asyncio
+import contextvars
import inspect
import warnings
@@ -34,7 +35,7 @@ class IsolatedAsyncioTestCase(TestCase):
def __init__(self, methodName='runTest'):
super().__init__(methodName)
self._asyncioTestLoop = None
- self._asyncioCallsQueue = None
+ self._asyncioTestContext = contextvars.copy_context()
async def asyncSetUp(self):
pass
@@ -58,7 +59,7 @@ class IsolatedAsyncioTestCase(TestCase):
self.addCleanup(*(func, *args), **kwargs)
def _callSetUp(self):
- self.setUp()
+ self._asyncioTestContext.run(self.setUp)
self._callAsync(self.asyncSetUp)
def _callTestMethod(self, method):
@@ -68,47 +69,30 @@ class IsolatedAsyncioTestCase(TestCase):
def _callTearDown(self):
self._callAsync(self.asyncTearDown)
- self.tearDown()
+ self._asyncioTestContext.run(self.tearDown)
def _callCleanup(self, function, *args, **kwargs):
self._callMaybeAsync(function, *args, **kwargs)
def _callAsync(self, func, /, *args, **kwargs):
assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
- ret = func(*args, **kwargs)
- assert inspect.isawaitable(ret), f'{func!r} returned non-awaitable'
- fut = self._asyncioTestLoop.create_future()
- self._asyncioCallsQueue.put_nowait((fut, ret))
- return self._asyncioTestLoop.run_until_complete(fut)
+ assert inspect.iscoroutinefunction(func), f'{func!r} is not an async function'
+ task = self._asyncioTestLoop.create_task(
+ func(*args, **kwargs),
+ context=self._asyncioTestContext,
+ )
+ return self._asyncioTestLoop.run_until_complete(task)
def _callMaybeAsync(self, func, /, *args, **kwargs):
assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
- ret = func(*args, **kwargs)
- if inspect.isawaitable(ret):
- fut = self._asyncioTestLoop.create_future()
- self._asyncioCallsQueue.put_nowait((fut, ret))
- return self._asyncioTestLoop.run_until_complete(fut)
+ if inspect.iscoroutinefunction(func):
+ task = self._asyncioTestLoop.create_task(
+ func(*args, **kwargs),
+ context=self._asyncioTestContext,
+ )
+ return self._asyncioTestLoop.run_until_complete(task)
else:
- return ret
-
- async def _asyncioLoopRunner(self, fut):
- self._asyncioCallsQueue = queue = asyncio.Queue()
- fut.set_result(None)
- while True:
- query = await queue.get()
- queue.task_done()
- if query is None:
- return
- fut, awaitable = query
- try:
- ret = await awaitable
- if not fut.cancelled():
- fut.set_result(ret)
- except (SystemExit, KeyboardInterrupt):
- raise
- except (BaseException, asyncio.CancelledError) as ex:
- if not fut.cancelled():
- fut.set_exception(ex)
+ return self._asyncioTestContext.run(func, *args, **kwargs)
def _setupAsyncioLoop(self):
assert self._asyncioTestLoop is None, 'asyncio test loop already initialized'
@@ -116,16 +100,11 @@ class IsolatedAsyncioTestCase(TestCase):
asyncio.set_event_loop(loop)
loop.set_debug(True)
self._asyncioTestLoop = loop
- fut = loop.create_future()
- self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut))
- loop.run_until_complete(fut)
def _tearDownAsyncioLoop(self):
assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
loop = self._asyncioTestLoop
self._asyncioTestLoop = None
- self._asyncioCallsQueue.put_nowait(None)
- loop.run_until_complete(self._asyncioCallsQueue.join())
try:
# cancel all tasks
diff --git a/Lib/unittest/test/test_async_case.py b/Lib/unittest/test/test_async_case.py
index 3717486..7dc8a6b 100644
--- a/Lib/unittest/test/test_async_case.py
+++ b/Lib/unittest/test/test_async_case.py
@@ -1,4 +1,5 @@
import asyncio
+import contextvars
import unittest
from test import support
@@ -11,6 +12,9 @@ def tearDownModule():
asyncio.set_event_loop_policy(None)
+VAR = contextvars.ContextVar('VAR', default=())
+
+
class TestAsyncCase(unittest.TestCase):
maxDiff = None
@@ -24,22 +28,26 @@ class TestAsyncCase(unittest.TestCase):
def setUp(self):
self.assertEqual(events, [])
events.append('setUp')
+ VAR.set(VAR.get() + ('setUp',))
async def asyncSetUp(self):
self.assertEqual(events, ['setUp'])
events.append('asyncSetUp')
+ VAR.set(VAR.get() + ('asyncSetUp',))
self.addAsyncCleanup(self.on_cleanup1)
async def test_func(self):
self.assertEqual(events, ['setUp',
'asyncSetUp'])
events.append('test')
+ VAR.set(VAR.get() + ('test',))
self.addAsyncCleanup(self.on_cleanup2)
async def asyncTearDown(self):
self.assertEqual(events, ['setUp',
'asyncSetUp',
'test'])
+ VAR.set(VAR.get() + ('asyncTearDown',))
events.append('asyncTearDown')
def tearDown(self):
@@ -48,6 +56,7 @@ class TestAsyncCase(unittest.TestCase):
'test',
'asyncTearDown'])
events.append('tearDown')
+ VAR.set(VAR.get() + ('tearDown',))
async def on_cleanup1(self):
self.assertEqual(events, ['setUp',
@@ -57,6 +66,9 @@ class TestAsyncCase(unittest.TestCase):
'tearDown',
'cleanup2'])
events.append('cleanup1')
+ VAR.set(VAR.get() + ('cleanup1',))
+ nonlocal cvar
+ cvar = VAR.get()
async def on_cleanup2(self):
self.assertEqual(events, ['setUp',
@@ -65,8 +77,10 @@ class TestAsyncCase(unittest.TestCase):
'asyncTearDown',
'tearDown'])
events.append('cleanup2')
+ VAR.set(VAR.get() + ('cleanup2',))
events = []
+ cvar = ()
test = Test("test_func")
result = test.run()
self.assertEqual(result.errors, [])
@@ -74,13 +88,17 @@ class TestAsyncCase(unittest.TestCase):
expected = ['setUp', 'asyncSetUp', 'test',
'asyncTearDown', 'tearDown', 'cleanup2', 'cleanup1']
self.assertEqual(events, expected)
+ self.assertEqual(cvar, tuple(expected))
events = []
+ cvar = ()
test = Test("test_func")
test.debug()
self.assertEqual(events, expected)
+ self.assertEqual(cvar, tuple(expected))
test.doCleanups()
self.assertEqual(events, expected)
+ self.assertEqual(cvar, tuple(expected))
def test_exception_in_setup(self):
class Test(unittest.IsolatedAsyncioTestCase):
diff --git a/Misc/NEWS.d/next/Library/2022-03-12-12-34-13.bpo-46994.d7hPdz.rst b/Misc/NEWS.d/next/Library/2022-03-12-12-34-13.bpo-46994.d7hPdz.rst
new file mode 100644
index 0000000..765936f
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2022-03-12-12-34-13.bpo-46994.d7hPdz.rst
@@ -0,0 +1,2 @@
+Accept explicit contextvars.Context in :func:`asyncio.create_task` and
+:meth:`asyncio.loop.create_task`.
diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c
index 2a6c0b3..4b12744 100644
--- a/Modules/_asynciomodule.c
+++ b/Modules/_asynciomodule.c
@@ -2003,14 +2003,16 @@ _asyncio.Task.__init__
*
loop: object = None
name: object = None
+ context: object = None
A coroutine wrapped in a Future.
[clinic start generated code]*/
static int
_asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop,
- PyObject *name)
-/*[clinic end generated code: output=88b12b83d570df50 input=352a3137fe60091d]*/
+ PyObject *name, PyObject *context)
+/*[clinic end generated code: output=49ac96fe33d0e5c7 input=924522490c8ce825]*/
+
{
if (future_init((FutureObj*)self, loop)) {
return -1;
@@ -2028,9 +2030,13 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop,
return -1;
}
- Py_XSETREF(self->task_context, PyContext_CopyCurrent());
- if (self->task_context == NULL) {
- return -1;
+ if (context == Py_None) {
+ Py_XSETREF(self->task_context, PyContext_CopyCurrent());
+ if (self->task_context == NULL) {
+ return -1;
+ }
+ } else {
+ self->task_context = Py_NewRef(context);
}
Py_CLEAR(self->task_fut_waiter);
diff --git a/Modules/clinic/_asynciomodule.c.h b/Modules/clinic/_asynciomodule.c.h
index 2b84ef0..4a90dfa 100644
--- a/Modules/clinic/_asynciomodule.c.h
+++ b/Modules/clinic/_asynciomodule.c.h
@@ -310,28 +310,29 @@ _asyncio_Future__repr_info(FutureObj *self, PyObject *Py_UNUSED(ignored))
}
PyDoc_STRVAR(_asyncio_Task___init____doc__,
-"Task(coro, *, loop=None, name=None)\n"
+"Task(coro, *, loop=None, name=None, context=None)\n"
"--\n"
"\n"
"A coroutine wrapped in a Future.");
static int
_asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop,
- PyObject *name);
+ PyObject *name, PyObject *context);
static int
_asyncio_Task___init__(PyObject *self, PyObject *args, PyObject *kwargs)
{
int return_value = -1;
- static const char * const _keywords[] = {"coro", "loop", "name", NULL};
+ static const char * const _keywords[] = {"coro", "loop", "name", "context", NULL};
static _PyArg_Parser _parser = {NULL, _keywords, "Task", 0};
- PyObject *argsbuf[3];
+ PyObject *argsbuf[4];
PyObject * const *fastargs;
Py_ssize_t nargs = PyTuple_GET_SIZE(args);
Py_ssize_t noptargs = nargs + (kwargs ? PyDict_GET_SIZE(kwargs) : 0) - 1;
PyObject *coro;
PyObject *loop = Py_None;
PyObject *name = Py_None;
+ PyObject *context = Py_None;
fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 1, 1, 0, argsbuf);
if (!fastargs) {
@@ -347,9 +348,15 @@ _asyncio_Task___init__(PyObject *self, PyObject *args, PyObject *kwargs)
goto skip_optional_kwonly;
}
}
- name = fastargs[2];
+ if (fastargs[2]) {
+ name = fastargs[2];
+ if (!--noptargs) {
+ goto skip_optional_kwonly;
+ }
+ }
+ context = fastargs[3];
skip_optional_kwonly:
- return_value = _asyncio_Task___init___impl((TaskObj *)self, coro, loop, name);
+ return_value = _asyncio_Task___init___impl((TaskObj *)self, coro, loop, name, context);
exit:
return return_value;
@@ -917,4 +924,4 @@ _asyncio__leave_task(PyObject *module, PyObject *const *args, Py_ssize_t nargs,
exit:
return return_value;
}
-/*[clinic end generated code: output=344927e9b6016ad7 input=a9049054013a1b77]*/
+/*[clinic end generated code: output=540ed3caf5a4d57d input=a9049054013a1b77]*/