summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/test/test_support.py2
-rw-r--r--Lib/unittest/__init__.py3
-rw-r--r--Lib/unittest/async_case.py158
-rw-r--r--Lib/unittest/case.py20
-rw-r--r--Lib/unittest/test/test_async_case.py195
5 files changed, 372 insertions, 6 deletions
diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py
index cb664ba..8f0746a 100644
--- a/Lib/test/test_support.py
+++ b/Lib/test/test_support.py
@@ -403,7 +403,7 @@ class TestSupport(unittest.TestCase):
("unittest.result", "unittest.case",
"unittest.suite", "unittest.loader",
"unittest.main", "unittest.runner",
- "unittest.signals"),
+ "unittest.signals", "unittest.async_case"),
extra=extra,
blacklist=blacklist)
diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py
index 5ff1bf3..ace3a6f 100644
--- a/Lib/unittest/__init__.py
+++ b/Lib/unittest/__init__.py
@@ -44,7 +44,7 @@ AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE,
SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
"""
-__all__ = ['TestResult', 'TestCase', 'TestSuite',
+__all__ = ['TestResult', 'TestCase', 'IsolatedAsyncioTestCase', 'TestSuite',
'TextTestRunner', 'TestLoader', 'FunctionTestCase', 'main',
'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless',
'expectedFailure', 'TextTestResult', 'installHandler',
@@ -57,6 +57,7 @@ __all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases'])
__unittest = True
from .result import TestResult
+from .async_case import IsolatedAsyncioTestCase
from .case import (addModuleCleanup, TestCase, FunctionTestCase, SkipTest, skip,
skipIf, skipUnless, expectedFailure)
from .suite import BaseTestSuite, TestSuite
diff --git a/Lib/unittest/async_case.py b/Lib/unittest/async_case.py
new file mode 100644
index 0000000..a3c8bfb
--- /dev/null
+++ b/Lib/unittest/async_case.py
@@ -0,0 +1,158 @@
+import asyncio
+import inspect
+
+from .case import TestCase
+
+
+
+class IsolatedAsyncioTestCase(TestCase):
+ # Names intentionally have a long prefix
+ # to reduce a chance of clashing with user-defined attributes
+ # from inherited test case
+ #
+ # The class doesn't call loop.run_until_complete(self.setUp()) and family
+ # but uses a different approach:
+ # 1. create a long-running task that reads self.setUp()
+ # awaitable from queue along with a future
+ # 2. await the awaitable object passing in and set the result
+ # into the future object
+ # 3. Outer code puts the awaitable and the future object into a queue
+ # with waiting for the future
+ # The trick is necessary because every run_until_complete() call
+ # creates a new task with embedded ContextVar context.
+ # To share contextvars between setUp(), test and tearDown() we need to execute
+ # them inside the same task.
+
+ # Note: the test case modifies event loop policy if the policy was not instantiated
+ # yet.
+ # asyncio.get_event_loop_policy() creates a default policy on demand but never
+ # returns None
+ # I believe this is not an issue in user level tests but python itself for testing
+ # should reset a policy in every test module
+ # by calling asyncio.set_event_loop_policy(None) in tearDownModule()
+
+ def __init__(self, methodName='runTest'):
+ super().__init__(methodName)
+ self._asyncioTestLoop = None
+ self._asyncioCallsQueue = None
+
+ async def asyncSetUp(self):
+ pass
+
+ async def asyncTearDown(self):
+ pass
+
+ def addAsyncCleanup(self, func, /, *args, **kwargs):
+ # A trivial trampoline to addCleanup()
+ # the function exists because it has a different semantics
+ # and signature:
+ # addCleanup() accepts regular functions
+ # but addAsyncCleanup() accepts coroutines
+ #
+ # We intentionally don't add inspect.iscoroutinefunction() check
+ # for func argument because there is no way
+ # to check for async function reliably:
+ # 1. It can be "async def func()" iself
+ # 2. Class can implement "async def __call__()" method
+ # 3. Regular "def func()" that returns awaitable object
+ self.addCleanup(*(func, *args), **kwargs)
+
+ def _callSetUp(self):
+ self.setUp()
+ self._callAsync(self.asyncSetUp)
+
+ def _callTestMethod(self, method):
+ self._callMaybeAsync(method)
+
+ def _callTearDown(self):
+ self._callAsync(self.asyncTearDown)
+ self.tearDown()
+
+ def _callCleanup(self, function, *args, **kwargs):
+ self._callMaybeAsync(function, *args, **kwargs)
+
+ def _callAsync(self, func, /, *args, **kwargs):
+ assert self._asyncioTestLoop is not None
+ ret = func(*args, **kwargs)
+ assert inspect.isawaitable(ret)
+ fut = self._asyncioTestLoop.create_future()
+ self._asyncioCallsQueue.put_nowait((fut, ret))
+ return self._asyncioTestLoop.run_until_complete(fut)
+
+ def _callMaybeAsync(self, func, /, *args, **kwargs):
+ assert self._asyncioTestLoop is not None
+ ret = func(*args, **kwargs)
+ if inspect.isawaitable(ret):
+ fut = self._asyncioTestLoop.create_future()
+ self._asyncioCallsQueue.put_nowait((fut, ret))
+ return self._asyncioTestLoop.run_until_complete(fut)
+ else:
+ return ret
+
+ async def _asyncioLoopRunner(self):
+ queue = self._asyncioCallsQueue
+ while True:
+ query = await queue.get()
+ queue.task_done()
+ if query is None:
+ return
+ fut, awaitable = query
+ try:
+ ret = await awaitable
+ if not fut.cancelled():
+ fut.set_result(ret)
+ except asyncio.CancelledError:
+ raise
+ except Exception as ex:
+ if not fut.cancelled():
+ fut.set_exception(ex)
+
+ def _setupAsyncioLoop(self):
+ assert self._asyncioTestLoop is None
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ loop.set_debug(True)
+ self._asyncioTestLoop = loop
+ self._asyncioCallsQueue = asyncio.Queue(loop=loop)
+ self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner())
+
+ def _tearDownAsyncioLoop(self):
+ assert self._asyncioTestLoop is not None
+ loop = self._asyncioTestLoop
+ self._asyncioTestLoop = None
+ self._asyncioCallsQueue.put_nowait(None)
+ loop.run_until_complete(self._asyncioCallsQueue.join())
+
+ try:
+ # cancel all tasks
+ to_cancel = asyncio.all_tasks(loop)
+ if not to_cancel:
+ return
+
+ for task in to_cancel:
+ task.cancel()
+
+ loop.run_until_complete(
+ asyncio.gather(*to_cancel, loop=loop, return_exceptions=True))
+
+ for task in to_cancel:
+ if task.cancelled():
+ continue
+ if task.exception() is not None:
+ loop.call_exception_handler({
+ 'message': 'unhandled exception during test shutdown',
+ 'exception': task.exception(),
+ 'task': task,
+ })
+ # shutdown asyncgens
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ finally:
+ asyncio.set_event_loop(None)
+ loop.close()
+
+ def run(self, result=None):
+ self._setupAsyncioLoop()
+ try:
+ return super().run(result)
+ finally:
+ self._tearDownAsyncioLoop()
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
index 8e01c3d..7b1e869 100644
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -645,6 +645,18 @@ class TestCase(object):
else:
addUnexpectedSuccess(self)
+ def _callSetUp(self):
+ self.setUp()
+
+ def _callTestMethod(self, method):
+ method()
+
+ def _callTearDown(self):
+ self.tearDown()
+
+ def _callCleanup(self, function, /, *args, **kwargs):
+ function(*args, **kwargs)
+
def run(self, result=None):
orig_result = result
if result is None:
@@ -676,14 +688,14 @@ class TestCase(object):
self._outcome = outcome
with outcome.testPartExecutor(self):
- self.setUp()
+ self._callSetUp()
if outcome.success:
outcome.expecting_failure = expecting_failure
with outcome.testPartExecutor(self, isTest=True):
- testMethod()
+ self._callTestMethod(testMethod)
outcome.expecting_failure = False
with outcome.testPartExecutor(self):
- self.tearDown()
+ self._callTearDown()
self.doCleanups()
for test, reason in outcome.skipped:
@@ -721,7 +733,7 @@ class TestCase(object):
while self._cleanups:
function, args, kwargs = self._cleanups.pop()
with outcome.testPartExecutor(self):
- function(*args, **kwargs)
+ self._callCleanup(function, *args, **kwargs)
# return this for backwards compatibility
# even though we no longer use it internally
diff --git a/Lib/unittest/test/test_async_case.py b/Lib/unittest/test/test_async_case.py
new file mode 100644
index 0000000..2db441d
--- /dev/null
+++ b/Lib/unittest/test/test_async_case.py
@@ -0,0 +1,195 @@
+import asyncio
+import unittest
+
+
+def tearDownModule():
+ asyncio.set_event_loop_policy(None)
+
+
+class TestAsyncCase(unittest.TestCase):
+ def test_full_cycle(self):
+ events = []
+
+ class Test(unittest.IsolatedAsyncioTestCase):
+ def setUp(self):
+ self.assertEqual(events, [])
+ events.append('setUp')
+
+ async def asyncSetUp(self):
+ self.assertEqual(events, ['setUp'])
+ events.append('asyncSetUp')
+
+ async def test_func(self):
+ self.assertEqual(events, ['setUp',
+ 'asyncSetUp'])
+ events.append('test')
+ self.addAsyncCleanup(self.on_cleanup)
+
+ async def asyncTearDown(self):
+ self.assertEqual(events, ['setUp',
+ 'asyncSetUp',
+ 'test'])
+ events.append('asyncTearDown')
+
+ def tearDown(self):
+ self.assertEqual(events, ['setUp',
+ 'asyncSetUp',
+ 'test',
+ 'asyncTearDown'])
+ events.append('tearDown')
+
+ async def on_cleanup(self):
+ self.assertEqual(events, ['setUp',
+ 'asyncSetUp',
+ 'test',
+ 'asyncTearDown',
+ 'tearDown'])
+ events.append('cleanup')
+
+ test = Test("test_func")
+ test.run()
+ self.assertEqual(events, ['setUp',
+ 'asyncSetUp',
+ 'test',
+ 'asyncTearDown',
+ 'tearDown',
+ 'cleanup'])
+
+ def test_exception_in_setup(self):
+ events = []
+
+ class Test(unittest.IsolatedAsyncioTestCase):
+ async def asyncSetUp(self):
+ events.append('asyncSetUp')
+ raise Exception()
+
+ async def test_func(self):
+ events.append('test')
+ self.addAsyncCleanup(self.on_cleanup)
+
+ async def asyncTearDown(self):
+ events.append('asyncTearDown')
+
+ async def on_cleanup(self):
+ events.append('cleanup')
+
+
+ test = Test("test_func")
+ test.run()
+ self.assertEqual(events, ['asyncSetUp'])
+
+ def test_exception_in_test(self):
+ events = []
+
+ class Test(unittest.IsolatedAsyncioTestCase):
+ async def asyncSetUp(self):
+ events.append('asyncSetUp')
+
+ async def test_func(self):
+ events.append('test')
+ raise Exception()
+ self.addAsyncCleanup(self.on_cleanup)
+
+ async def asyncTearDown(self):
+ events.append('asyncTearDown')
+
+ async def on_cleanup(self):
+ events.append('cleanup')
+
+ test = Test("test_func")
+ test.run()
+ self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown'])
+
+ def test_exception_in_test_after_adding_cleanup(self):
+ events = []
+
+ class Test(unittest.IsolatedAsyncioTestCase):
+ async def asyncSetUp(self):
+ events.append('asyncSetUp')
+
+ async def test_func(self):
+ events.append('test')
+ self.addAsyncCleanup(self.on_cleanup)
+ raise Exception()
+
+ async def asyncTearDown(self):
+ events.append('asyncTearDown')
+
+ async def on_cleanup(self):
+ events.append('cleanup')
+
+ test = Test("test_func")
+ test.run()
+ self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup'])
+
+ def test_exception_in_tear_down(self):
+ events = []
+
+ class Test(unittest.IsolatedAsyncioTestCase):
+ async def asyncSetUp(self):
+ events.append('asyncSetUp')
+
+ async def test_func(self):
+ events.append('test')
+ self.addAsyncCleanup(self.on_cleanup)
+
+ async def asyncTearDown(self):
+ events.append('asyncTearDown')
+ raise Exception()
+
+ async def on_cleanup(self):
+ events.append('cleanup')
+
+ test = Test("test_func")
+ test.run()
+ self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup'])
+
+
+ def test_exception_in_tear_clean_up(self):
+ events = []
+
+ class Test(unittest.IsolatedAsyncioTestCase):
+ async def asyncSetUp(self):
+ events.append('asyncSetUp')
+
+ async def test_func(self):
+ events.append('test')
+ self.addAsyncCleanup(self.on_cleanup)
+
+ async def asyncTearDown(self):
+ events.append('asyncTearDown')
+
+ async def on_cleanup(self):
+ events.append('cleanup')
+ raise Exception()
+
+ test = Test("test_func")
+ test.run()
+ self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup'])
+
+ def test_cleanups_interleave_order(self):
+ events = []
+
+ class Test(unittest.IsolatedAsyncioTestCase):
+ async def test_func(self):
+ self.addAsyncCleanup(self.on_sync_cleanup, 1)
+ self.addAsyncCleanup(self.on_async_cleanup, 2)
+ self.addAsyncCleanup(self.on_sync_cleanup, 3)
+ self.addAsyncCleanup(self.on_async_cleanup, 4)
+
+ async def on_sync_cleanup(self, val):
+ events.append(f'sync_cleanup {val}')
+
+ async def on_async_cleanup(self, val):
+ events.append(f'async_cleanup {val}')
+
+ test = Test("test_func")
+ test.run()
+ self.assertEqual(events, ['async_cleanup 4',
+ 'sync_cleanup 3',
+ 'async_cleanup 2',
+ 'sync_cleanup 1'])
+
+
+if __name__ == "__main__":
+ unittest.main()