diff options
author | Andrew Svetlov <andrew.svetlov@gmail.com> | 2022-03-14 11:54:13 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-03-14 11:54:13 (GMT) |
commit | 9523c0d84f351a610dc651b234461eb015fa3b82 (patch) | |
tree | 5f6f6bed4353eb9c149f65ab2dc95db12d378db3 /Lib/test/test_asyncio | |
parent | 2153daf0a02a598ed5df93f2f224c1ab2a2cca0d (diff) | |
download | cpython-9523c0d84f351a610dc651b234461eb015fa3b82.zip cpython-9523c0d84f351a610dc651b234461eb015fa3b82.tar.gz cpython-9523c0d84f351a610dc651b234461eb015fa3b82.tar.bz2 |
bpo-46994: Accept explicit contextvars.Context in asyncio create_task() API (GH-31837)
Diffstat (limited to 'Lib/test/test_asyncio')
-rw-r--r-- | Lib/test/test_asyncio/test_taskgroups.py | 18 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_tasks.py | 88 |
2 files changed, 104 insertions, 2 deletions
diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py index df51528..dea5d6d 100644 --- a/Lib/test/test_asyncio/test_taskgroups.py +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -2,6 +2,7 @@ import asyncio +import contextvars from asyncio import taskgroups import unittest @@ -708,6 +709,23 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase): t = g.create_task(coro(), name="yolo") self.assertEqual(t.get_name(), "yolo") + async def test_taskgroup_task_context(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async with taskgroups.TaskGroup() as g: + ctx = contextvars.copy_context() + self.assertIsNone(ctx.get(cvar)) + t1 = g.create_task(coro(1), context=ctx) + await t1 + self.assertEqual(1, ctx.get(cvar)) + t2 = g.create_task(coro(2), context=ctx) + await t2 + self.assertEqual(2, ctx.get(cvar)) + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py index 95fabf7..b6ef627 100644 --- a/Lib/test/test_asyncio/test_tasks.py +++ b/Lib/test/test_asyncio/test_tasks.py @@ -95,8 +95,8 @@ class BaseTaskTests: Task = None Future = None - def new_task(self, loop, coro, name='TestTask'): - return self.__class__.Task(coro, loop=loop, name=name) + def new_task(self, loop, coro, name='TestTask', context=None): + return self.__class__.Task(coro, loop=loop, name=name, context=context) def new_future(self, loop): return self.__class__.Future(loop=loop) @@ -2527,6 +2527,90 @@ class BaseTaskTests: self.assertEqual(cvar.get(), -1) + def test_context_4(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async def main(): + ret = [] + ctx = contextvars.copy_context() + ret.append(ctx.get(cvar)) + t1 = self.new_task(loop, coro(1), context=ctx) + await t1 + ret.append(ctx.get(cvar)) + t2 = self.new_task(loop, coro(2), context=ctx) + await t2 + ret.append(ctx.get(cvar)) + return ret + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main()) + ret = loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual([None, 1, 2], ret) + + def test_context_5(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async def main(): + ret = [] + ctx = contextvars.copy_context() + ret.append(ctx.get(cvar)) + t1 = asyncio.create_task(coro(1), context=ctx) + await t1 + ret.append(ctx.get(cvar)) + t2 = asyncio.create_task(coro(2), context=ctx) + await t2 + ret.append(ctx.get(cvar)) + return ret + + loop = asyncio.new_event_loop() + try: + task = self.new_task(loop, main()) + ret = loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual([None, 1, 2], ret) + + def test_context_6(self): + cvar = contextvars.ContextVar('cvar') + + async def coro(val): + await asyncio.sleep(0) + cvar.set(val) + + async def main(): + ret = [] + ctx = contextvars.copy_context() + ret.append(ctx.get(cvar)) + t1 = loop.create_task(coro(1), context=ctx) + await t1 + ret.append(ctx.get(cvar)) + t2 = loop.create_task(coro(2), context=ctx) + await t2 + ret.append(ctx.get(cvar)) + return ret + + loop = asyncio.new_event_loop() + try: + task = loop.create_task(main()) + ret = loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual([None, 1, 2], ret) + def test_get_coro(self): loop = asyncio.new_event_loop() coro = coroutine_function() |