summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_asyncio
diff options
context:
space:
mode:
authorAndrew Svetlov <andrew.svetlov@gmail.com>2022-03-14 11:54:13 (GMT)
committerGitHub <noreply@github.com>2022-03-14 11:54:13 (GMT)
commit9523c0d84f351a610dc651b234461eb015fa3b82 (patch)
tree5f6f6bed4353eb9c149f65ab2dc95db12d378db3 /Lib/test/test_asyncio
parent2153daf0a02a598ed5df93f2f224c1ab2a2cca0d (diff)
downloadcpython-9523c0d84f351a610dc651b234461eb015fa3b82.zip
cpython-9523c0d84f351a610dc651b234461eb015fa3b82.tar.gz
cpython-9523c0d84f351a610dc651b234461eb015fa3b82.tar.bz2
bpo-46994: Accept explicit contextvars.Context in asyncio create_task() API (GH-31837)
Diffstat (limited to 'Lib/test/test_asyncio')
-rw-r--r--Lib/test/test_asyncio/test_taskgroups.py18
-rw-r--r--Lib/test/test_asyncio/test_tasks.py88
2 files changed, 104 insertions, 2 deletions
diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py
index df51528..dea5d6d 100644
--- a/Lib/test/test_asyncio/test_taskgroups.py
+++ b/Lib/test/test_asyncio/test_taskgroups.py
@@ -2,6 +2,7 @@
import asyncio
+import contextvars
from asyncio import taskgroups
import unittest
@@ -708,6 +709,23 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
t = g.create_task(coro(), name="yolo")
self.assertEqual(t.get_name(), "yolo")
+ async def test_taskgroup_task_context(self):
+ cvar = contextvars.ContextVar('cvar')
+
+ async def coro(val):
+ await asyncio.sleep(0)
+ cvar.set(val)
+
+ async with taskgroups.TaskGroup() as g:
+ ctx = contextvars.copy_context()
+ self.assertIsNone(ctx.get(cvar))
+ t1 = g.create_task(coro(1), context=ctx)
+ await t1
+ self.assertEqual(1, ctx.get(cvar))
+ t2 = g.create_task(coro(2), context=ctx)
+ await t2
+ self.assertEqual(2, ctx.get(cvar))
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
index 95fabf7..b6ef627 100644
--- a/Lib/test/test_asyncio/test_tasks.py
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -95,8 +95,8 @@ class BaseTaskTests:
Task = None
Future = None
- def new_task(self, loop, coro, name='TestTask'):
- return self.__class__.Task(coro, loop=loop, name=name)
+ def new_task(self, loop, coro, name='TestTask', context=None):
+ return self.__class__.Task(coro, loop=loop, name=name, context=context)
def new_future(self, loop):
return self.__class__.Future(loop=loop)
@@ -2527,6 +2527,90 @@ class BaseTaskTests:
self.assertEqual(cvar.get(), -1)
+ def test_context_4(self):
+ cvar = contextvars.ContextVar('cvar')
+
+ async def coro(val):
+ await asyncio.sleep(0)
+ cvar.set(val)
+
+ async def main():
+ ret = []
+ ctx = contextvars.copy_context()
+ ret.append(ctx.get(cvar))
+ t1 = self.new_task(loop, coro(1), context=ctx)
+ await t1
+ ret.append(ctx.get(cvar))
+ t2 = self.new_task(loop, coro(2), context=ctx)
+ await t2
+ ret.append(ctx.get(cvar))
+ return ret
+
+ loop = asyncio.new_event_loop()
+ try:
+ task = self.new_task(loop, main())
+ ret = loop.run_until_complete(task)
+ finally:
+ loop.close()
+
+ self.assertEqual([None, 1, 2], ret)
+
+ def test_context_5(self):
+ cvar = contextvars.ContextVar('cvar')
+
+ async def coro(val):
+ await asyncio.sleep(0)
+ cvar.set(val)
+
+ async def main():
+ ret = []
+ ctx = contextvars.copy_context()
+ ret.append(ctx.get(cvar))
+ t1 = asyncio.create_task(coro(1), context=ctx)
+ await t1
+ ret.append(ctx.get(cvar))
+ t2 = asyncio.create_task(coro(2), context=ctx)
+ await t2
+ ret.append(ctx.get(cvar))
+ return ret
+
+ loop = asyncio.new_event_loop()
+ try:
+ task = self.new_task(loop, main())
+ ret = loop.run_until_complete(task)
+ finally:
+ loop.close()
+
+ self.assertEqual([None, 1, 2], ret)
+
+ def test_context_6(self):
+ cvar = contextvars.ContextVar('cvar')
+
+ async def coro(val):
+ await asyncio.sleep(0)
+ cvar.set(val)
+
+ async def main():
+ ret = []
+ ctx = contextvars.copy_context()
+ ret.append(ctx.get(cvar))
+ t1 = loop.create_task(coro(1), context=ctx)
+ await t1
+ ret.append(ctx.get(cvar))
+ t2 = loop.create_task(coro(2), context=ctx)
+ await t2
+ ret.append(ctx.get(cvar))
+ return ret
+
+ loop = asyncio.new_event_loop()
+ try:
+ task = loop.create_task(main())
+ ret = loop.run_until_complete(task)
+ finally:
+ loop.close()
+
+ self.assertEqual([None, 1, 2], ret)
+
def test_get_coro(self):
loop = asyncio.new_event_loop()
coro = coroutine_function()