diff options
author | Itamar Ostricher <itamarost@gmail.com> | 2023-05-01 21:10:13 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-01 21:10:13 (GMT) |
commit | a474e04388c2ef6aca75c26cb70a1b6200235feb (patch) | |
tree | 43520d5ad16016620f149dc1e84d4d57e45051d5 /Lib/test/test_asyncio | |
parent | 59bc36aacddd5a3acd32c80c0dfd0726135a7817 (diff) | |
download | cpython-a474e04388c2ef6aca75c26cb70a1b6200235feb.zip cpython-a474e04388c2ef6aca75c26cb70a1b6200235feb.tar.gz cpython-a474e04388c2ef6aca75c26cb70a1b6200235feb.tar.bz2 |
gh-97696: asyncio eager tasks factory (#102853)
Co-authored-by: Jacob Bower <jbower@meta.com>
Co-authored-by: Carol Willing <carolcode@willingconsulting.com>
Diffstat (limited to 'Lib/test/test_asyncio')
-rw-r--r-- | Lib/test/test_asyncio/test_eager_task_factory.py | 344 |
1 files changed, 344 insertions, 0 deletions
diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py new file mode 100644 index 0000000..fe69093 --- /dev/null +++ b/Lib/test/test_asyncio/test_eager_task_factory.py @@ -0,0 +1,344 @@ +"""Tests for base_events.py""" + +import asyncio +import contextvars +import gc +import time +import unittest + +from types import GenericAlias +from unittest import mock +from asyncio import base_events +from asyncio import tasks +from test.test_asyncio import utils as test_utils +from test.test_asyncio.test_tasks import get_innermost_context +from test import support + +MOCK_ANY = mock.ANY + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class EagerTaskFactoryLoopTests: + + Task = None + + def run_coro(self, coro): + """ + Helper method to run the `coro` coroutine in the test event loop. + It helps with making sure the event loop is running before starting + to execute `coro`. This is important for testing the eager step + functionality, since an eager step is taken only if the event loop + is already running. + """ + + async def coro_runner(): + self.assertTrue(asyncio.get_event_loop().is_running()) + return await coro + + return self.loop.run_until_complete(coro) + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + self.eager_task_factory = asyncio.create_eager_task_factory(self.Task) + self.loop.set_task_factory(self.eager_task_factory) + self.set_event_loop(self.loop) + + def test_eager_task_factory_set(self): + self.assertIsNotNone(self.eager_task_factory) + self.assertIs(self.loop.get_task_factory(), self.eager_task_factory) + + async def noop(): pass + + async def run(): + t = self.loop.create_task(noop()) + self.assertIsInstance(t, self.Task) + await t + + self.run_coro(run()) + + def test_await_future_during_eager_step(self): + + async def set_result(fut, val): + fut.set_result(val) + + async def run(): + fut = self.loop.create_future() + t = self.loop.create_task(set_result(fut, 'my message')) + # assert the eager step completed the task + self.assertTrue(t.done()) + return await fut + + self.assertEqual(self.run_coro(run()), 'my message') + + def test_eager_completion(self): + + async def coro(): + return 'hello' + + async def run(): + t = self.loop.create_task(coro()) + # assert the eager step completed the task + self.assertTrue(t.done()) + return await t + + self.assertEqual(self.run_coro(run()), 'hello') + + def test_block_after_eager_step(self): + + async def coro(): + await asyncio.sleep(0.1) + return 'finished after blocking' + + async def run(): + t = self.loop.create_task(coro()) + self.assertFalse(t.done()) + result = await t + self.assertTrue(t.done()) + return result + + self.assertEqual(self.run_coro(run()), 'finished after blocking') + + def test_cancellation_after_eager_completion(self): + + async def coro(): + return 'finished without blocking' + + async def run(): + t = self.loop.create_task(coro()) + t.cancel() + result = await t + # finished task can't be cancelled + self.assertFalse(t.cancelled()) + return result + + self.assertEqual(self.run_coro(run()), 'finished without blocking') + + def test_cancellation_after_eager_step_blocks(self): + + async def coro(): + await asyncio.sleep(0.1) + return 'finished after blocking' + + async def run(): + t = self.loop.create_task(coro()) + t.cancel('cancellation message') + self.assertGreater(t.cancelling(), 0) + result = await t + + with self.assertRaises(asyncio.CancelledError) as cm: + self.run_coro(run()) + + self.assertEqual('cancellation message', cm.exception.args[0]) + + def test_current_task(self): + captured_current_task = None + + async def coro(): + nonlocal captured_current_task + captured_current_task = asyncio.current_task() + # verify the task before and after blocking is identical + await asyncio.sleep(0.1) + self.assertIs(asyncio.current_task(), captured_current_task) + + async def run(): + t = self.loop.create_task(coro()) + self.assertIs(captured_current_task, t) + await t + + self.run_coro(run()) + captured_current_task = None + + def test_all_tasks_with_eager_completion(self): + captured_all_tasks = None + + async def coro(): + nonlocal captured_all_tasks + captured_all_tasks = asyncio.all_tasks() + + async def run(): + t = self.loop.create_task(coro()) + self.assertIn(t, captured_all_tasks) + self.assertNotIn(t, asyncio.all_tasks()) + + self.run_coro(run()) + + def test_all_tasks_with_blocking(self): + captured_eager_all_tasks = None + + async def coro(fut1, fut2): + nonlocal captured_eager_all_tasks + captured_eager_all_tasks = asyncio.all_tasks() + await fut1 + fut2.set_result(None) + + async def run(): + fut1 = self.loop.create_future() + fut2 = self.loop.create_future() + t = self.loop.create_task(coro(fut1, fut2)) + self.assertIn(t, captured_eager_all_tasks) + self.assertIn(t, asyncio.all_tasks()) + fut1.set_result(None) + await fut2 + self.assertNotIn(t, asyncio.all_tasks()) + + self.run_coro(run()) + + def test_context_vars(self): + cv = contextvars.ContextVar('cv', default=0) + + coro_first_step_ran = False + coro_second_step_ran = False + + async def coro(): + nonlocal coro_first_step_ran + nonlocal coro_second_step_ran + self.assertEqual(cv.get(), 1) + cv.set(2) + self.assertEqual(cv.get(), 2) + coro_first_step_ran = True + await asyncio.sleep(0.1) + self.assertEqual(cv.get(), 2) + cv.set(3) + self.assertEqual(cv.get(), 3) + coro_second_step_ran = True + + async def run(): + cv.set(1) + t = self.loop.create_task(coro()) + self.assertTrue(coro_first_step_ran) + self.assertFalse(coro_second_step_ran) + self.assertEqual(cv.get(), 1) + await t + self.assertTrue(coro_second_step_ran) + self.assertEqual(cv.get(), 1) + + self.run_coro(run()) + + +class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase): + Task = tasks._PyTask + + +@unittest.skipUnless(hasattr(tasks, '_CTask'), + 'requires the C _asyncio module') +class CEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase): + Task = getattr(tasks, '_CTask', None) + + +class AsyncTaskCounter: + def __init__(self, loop, *, task_class, eager): + self.suspense_count = 0 + self.task_count = 0 + + def CountingTask(*args, eager_start=False, **kwargs): + if not eager_start: + self.task_count += 1 + kwargs["eager_start"] = eager_start + return task_class(*args, **kwargs) + + if eager: + factory = asyncio.create_eager_task_factory(CountingTask) + else: + def factory(loop, coro, **kwargs): + return CountingTask(coro, loop=loop, **kwargs) + loop.set_task_factory(factory) + + def get(self): + return self.task_count + + +async def awaitable_chain(depth): + if depth == 0: + return 0 + return 1 + await awaitable_chain(depth - 1) + + +async def recursive_taskgroups(width, depth): + if depth == 0: + return + + async with asyncio.TaskGroup() as tg: + futures = [ + tg.create_task(recursive_taskgroups(width, depth - 1)) + for _ in range(width) + ] + + +async def recursive_gather(width, depth): + if depth == 0: + return + + await asyncio.gather( + *[recursive_gather(width, depth - 1) for _ in range(width)] + ) + + +class BaseTaskCountingTests: + + Task = None + eager = None + expected_task_count = None + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + self.counter = AsyncTaskCounter(self.loop, task_class=self.Task, eager=self.eager) + self.set_event_loop(self.loop) + + def test_awaitables_chain(self): + observed_depth = self.loop.run_until_complete(awaitable_chain(100)) + self.assertEqual(observed_depth, 100) + self.assertEqual(self.counter.get(), 0 if self.eager else 1) + + def test_recursive_taskgroups(self): + num_tasks = self.loop.run_until_complete(recursive_taskgroups(5, 4)) + self.assertEqual(self.counter.get(), self.expected_task_count) + + def test_recursive_gather(self): + self.loop.run_until_complete(recursive_gather(5, 4)) + self.assertEqual(self.counter.get(), self.expected_task_count) + + +class BaseNonEagerTaskFactoryTests(BaseTaskCountingTests): + eager = False + expected_task_count = 781 # 1 + 5 + 5^2 + 5^3 + 5^4 + + +class BaseEagerTaskFactoryTests(BaseTaskCountingTests): + eager = True + expected_task_count = 0 + + +class NonEagerTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase): + Task = asyncio.Task + + +class EagerTests(BaseEagerTaskFactoryTests, test_utils.TestCase): + Task = asyncio.Task + + +class NonEagerPyTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase): + Task = tasks._PyTask + + +class EagerPyTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase): + Task = tasks._PyTask + + +@unittest.skipUnless(hasattr(tasks, '_CTask'), + 'requires the C _asyncio module') +class NonEagerCTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase): + Task = getattr(tasks, '_CTask', None) + + +@unittest.skipUnless(hasattr(tasks, '_CTask'), + 'requires the C _asyncio module') +class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase): + Task = getattr(tasks, '_CTask', None) + +if __name__ == '__main__': + unittest.main() |