summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_asyncio
diff options
context:
space:
mode:
authorItamar Ostricher <itamarost@gmail.com>2023-05-01 21:10:13 (GMT)
committerGitHub <noreply@github.com>2023-05-01 21:10:13 (GMT)
commita474e04388c2ef6aca75c26cb70a1b6200235feb (patch)
tree43520d5ad16016620f149dc1e84d4d57e45051d5 /Lib/test/test_asyncio
parent59bc36aacddd5a3acd32c80c0dfd0726135a7817 (diff)
downloadcpython-a474e04388c2ef6aca75c26cb70a1b6200235feb.zip
cpython-a474e04388c2ef6aca75c26cb70a1b6200235feb.tar.gz
cpython-a474e04388c2ef6aca75c26cb70a1b6200235feb.tar.bz2
gh-97696: asyncio eager tasks factory (#102853)
Co-authored-by: Jacob Bower <jbower@meta.com> Co-authored-by: Carol Willing <carolcode@willingconsulting.com>
Diffstat (limited to 'Lib/test/test_asyncio')
-rw-r--r--Lib/test/test_asyncio/test_eager_task_factory.py344
1 files changed, 344 insertions, 0 deletions
diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py
new file mode 100644
index 0000000..fe69093
--- /dev/null
+++ b/Lib/test/test_asyncio/test_eager_task_factory.py
@@ -0,0 +1,344 @@
+"""Tests for base_events.py"""
+
+import asyncio
+import contextvars
+import gc
+import time
+import unittest
+
+from types import GenericAlias
+from unittest import mock
+from asyncio import base_events
+from asyncio import tasks
+from test.test_asyncio import utils as test_utils
+from test.test_asyncio.test_tasks import get_innermost_context
+from test import support
+
+MOCK_ANY = mock.ANY
+
+
+def tearDownModule():
+ asyncio.set_event_loop_policy(None)
+
+
+class EagerTaskFactoryLoopTests:
+
+ Task = None
+
+ def run_coro(self, coro):
+ """
+ Helper method to run the `coro` coroutine in the test event loop.
+ It helps with making sure the event loop is running before starting
+ to execute `coro`. This is important for testing the eager step
+ functionality, since an eager step is taken only if the event loop
+ is already running.
+ """
+
+ async def coro_runner():
+ self.assertTrue(asyncio.get_event_loop().is_running())
+ return await coro
+
+ return self.loop.run_until_complete(coro)
+
+ def setUp(self):
+ super().setUp()
+ self.loop = asyncio.new_event_loop()
+ self.eager_task_factory = asyncio.create_eager_task_factory(self.Task)
+ self.loop.set_task_factory(self.eager_task_factory)
+ self.set_event_loop(self.loop)
+
+ def test_eager_task_factory_set(self):
+ self.assertIsNotNone(self.eager_task_factory)
+ self.assertIs(self.loop.get_task_factory(), self.eager_task_factory)
+
+ async def noop(): pass
+
+ async def run():
+ t = self.loop.create_task(noop())
+ self.assertIsInstance(t, self.Task)
+ await t
+
+ self.run_coro(run())
+
+ def test_await_future_during_eager_step(self):
+
+ async def set_result(fut, val):
+ fut.set_result(val)
+
+ async def run():
+ fut = self.loop.create_future()
+ t = self.loop.create_task(set_result(fut, 'my message'))
+ # assert the eager step completed the task
+ self.assertTrue(t.done())
+ return await fut
+
+ self.assertEqual(self.run_coro(run()), 'my message')
+
+ def test_eager_completion(self):
+
+ async def coro():
+ return 'hello'
+
+ async def run():
+ t = self.loop.create_task(coro())
+ # assert the eager step completed the task
+ self.assertTrue(t.done())
+ return await t
+
+ self.assertEqual(self.run_coro(run()), 'hello')
+
+ def test_block_after_eager_step(self):
+
+ async def coro():
+ await asyncio.sleep(0.1)
+ return 'finished after blocking'
+
+ async def run():
+ t = self.loop.create_task(coro())
+ self.assertFalse(t.done())
+ result = await t
+ self.assertTrue(t.done())
+ return result
+
+ self.assertEqual(self.run_coro(run()), 'finished after blocking')
+
+ def test_cancellation_after_eager_completion(self):
+
+ async def coro():
+ return 'finished without blocking'
+
+ async def run():
+ t = self.loop.create_task(coro())
+ t.cancel()
+ result = await t
+ # finished task can't be cancelled
+ self.assertFalse(t.cancelled())
+ return result
+
+ self.assertEqual(self.run_coro(run()), 'finished without blocking')
+
+ def test_cancellation_after_eager_step_blocks(self):
+
+ async def coro():
+ await asyncio.sleep(0.1)
+ return 'finished after blocking'
+
+ async def run():
+ t = self.loop.create_task(coro())
+ t.cancel('cancellation message')
+ self.assertGreater(t.cancelling(), 0)
+ result = await t
+
+ with self.assertRaises(asyncio.CancelledError) as cm:
+ self.run_coro(run())
+
+ self.assertEqual('cancellation message', cm.exception.args[0])
+
+ def test_current_task(self):
+ captured_current_task = None
+
+ async def coro():
+ nonlocal captured_current_task
+ captured_current_task = asyncio.current_task()
+ # verify the task before and after blocking is identical
+ await asyncio.sleep(0.1)
+ self.assertIs(asyncio.current_task(), captured_current_task)
+
+ async def run():
+ t = self.loop.create_task(coro())
+ self.assertIs(captured_current_task, t)
+ await t
+
+ self.run_coro(run())
+ captured_current_task = None
+
+ def test_all_tasks_with_eager_completion(self):
+ captured_all_tasks = None
+
+ async def coro():
+ nonlocal captured_all_tasks
+ captured_all_tasks = asyncio.all_tasks()
+
+ async def run():
+ t = self.loop.create_task(coro())
+ self.assertIn(t, captured_all_tasks)
+ self.assertNotIn(t, asyncio.all_tasks())
+
+ self.run_coro(run())
+
+ def test_all_tasks_with_blocking(self):
+ captured_eager_all_tasks = None
+
+ async def coro(fut1, fut2):
+ nonlocal captured_eager_all_tasks
+ captured_eager_all_tasks = asyncio.all_tasks()
+ await fut1
+ fut2.set_result(None)
+
+ async def run():
+ fut1 = self.loop.create_future()
+ fut2 = self.loop.create_future()
+ t = self.loop.create_task(coro(fut1, fut2))
+ self.assertIn(t, captured_eager_all_tasks)
+ self.assertIn(t, asyncio.all_tasks())
+ fut1.set_result(None)
+ await fut2
+ self.assertNotIn(t, asyncio.all_tasks())
+
+ self.run_coro(run())
+
+ def test_context_vars(self):
+ cv = contextvars.ContextVar('cv', default=0)
+
+ coro_first_step_ran = False
+ coro_second_step_ran = False
+
+ async def coro():
+ nonlocal coro_first_step_ran
+ nonlocal coro_second_step_ran
+ self.assertEqual(cv.get(), 1)
+ cv.set(2)
+ self.assertEqual(cv.get(), 2)
+ coro_first_step_ran = True
+ await asyncio.sleep(0.1)
+ self.assertEqual(cv.get(), 2)
+ cv.set(3)
+ self.assertEqual(cv.get(), 3)
+ coro_second_step_ran = True
+
+ async def run():
+ cv.set(1)
+ t = self.loop.create_task(coro())
+ self.assertTrue(coro_first_step_ran)
+ self.assertFalse(coro_second_step_ran)
+ self.assertEqual(cv.get(), 1)
+ await t
+ self.assertTrue(coro_second_step_ran)
+ self.assertEqual(cv.get(), 1)
+
+ self.run_coro(run())
+
+
+class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
+ Task = tasks._PyTask
+
+
+@unittest.skipUnless(hasattr(tasks, '_CTask'),
+ 'requires the C _asyncio module')
+class CEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
+ Task = getattr(tasks, '_CTask', None)
+
+
+class AsyncTaskCounter:
+ def __init__(self, loop, *, task_class, eager):
+ self.suspense_count = 0
+ self.task_count = 0
+
+ def CountingTask(*args, eager_start=False, **kwargs):
+ if not eager_start:
+ self.task_count += 1
+ kwargs["eager_start"] = eager_start
+ return task_class(*args, **kwargs)
+
+ if eager:
+ factory = asyncio.create_eager_task_factory(CountingTask)
+ else:
+ def factory(loop, coro, **kwargs):
+ return CountingTask(coro, loop=loop, **kwargs)
+ loop.set_task_factory(factory)
+
+ def get(self):
+ return self.task_count
+
+
+async def awaitable_chain(depth):
+ if depth == 0:
+ return 0
+ return 1 + await awaitable_chain(depth - 1)
+
+
+async def recursive_taskgroups(width, depth):
+ if depth == 0:
+ return
+
+ async with asyncio.TaskGroup() as tg:
+ futures = [
+ tg.create_task(recursive_taskgroups(width, depth - 1))
+ for _ in range(width)
+ ]
+
+
+async def recursive_gather(width, depth):
+ if depth == 0:
+ return
+
+ await asyncio.gather(
+ *[recursive_gather(width, depth - 1) for _ in range(width)]
+ )
+
+
+class BaseTaskCountingTests:
+
+ Task = None
+ eager = None
+ expected_task_count = None
+
+ def setUp(self):
+ super().setUp()
+ self.loop = asyncio.new_event_loop()
+ self.counter = AsyncTaskCounter(self.loop, task_class=self.Task, eager=self.eager)
+ self.set_event_loop(self.loop)
+
+ def test_awaitables_chain(self):
+ observed_depth = self.loop.run_until_complete(awaitable_chain(100))
+ self.assertEqual(observed_depth, 100)
+ self.assertEqual(self.counter.get(), 0 if self.eager else 1)
+
+ def test_recursive_taskgroups(self):
+ num_tasks = self.loop.run_until_complete(recursive_taskgroups(5, 4))
+ self.assertEqual(self.counter.get(), self.expected_task_count)
+
+ def test_recursive_gather(self):
+ self.loop.run_until_complete(recursive_gather(5, 4))
+ self.assertEqual(self.counter.get(), self.expected_task_count)
+
+
+class BaseNonEagerTaskFactoryTests(BaseTaskCountingTests):
+ eager = False
+ expected_task_count = 781 # 1 + 5 + 5^2 + 5^3 + 5^4
+
+
+class BaseEagerTaskFactoryTests(BaseTaskCountingTests):
+ eager = True
+ expected_task_count = 0
+
+
+class NonEagerTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
+ Task = asyncio.Task
+
+
+class EagerTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
+ Task = asyncio.Task
+
+
+class NonEagerPyTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
+ Task = tasks._PyTask
+
+
+class EagerPyTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
+ Task = tasks._PyTask
+
+
+@unittest.skipUnless(hasattr(tasks, '_CTask'),
+ 'requires the C _asyncio module')
+class NonEagerCTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
+ Task = getattr(tasks, '_CTask', None)
+
+
+@unittest.skipUnless(hasattr(tasks, '_CTask'),
+ 'requires the C _asyncio module')
+class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
+ Task = getattr(tasks, '_CTask', None)
+
+if __name__ == '__main__':
+ unittest.main()