summaryrefslogtreecommitdiffstats
path: root/Lib/unittest
diff options
context:
space:
mode:
authorSerhiy Storchaka <storchaka@gmail.com>2022-05-08 14:49:09 (GMT)
committerGitHub <noreply@github.com>2022-05-08 14:49:09 (GMT)
commit086c6b1b0fe8d47ebd15512d7bdcb64c60a360f0 (patch)
treea7b1eaf75879c3fded1b946b2331f6a45dfc8fc7 /Lib/unittest
parent8f293180791f2836570bdfc29aadba04a538d435 (diff)
downloadcpython-086c6b1b0fe8d47ebd15512d7bdcb64c60a360f0.zip
cpython-086c6b1b0fe8d47ebd15512d7bdcb64c60a360f0.tar.gz
cpython-086c6b1b0fe8d47ebd15512d7bdcb64c60a360f0.tar.bz2
bpo-45046: Support context managers in unittest (GH-28045)
Add methods enterContext() and enterClassContext() in TestCase. Add method enterAsyncContext() in IsolatedAsyncioTestCase. Add function enterModuleContext().
Diffstat (limited to 'Lib/unittest')
-rw-r--r--Lib/unittest/__init__.py5
-rw-r--r--Lib/unittest/async_case.py20
-rw-r--r--Lib/unittest/case.py32
-rw-r--r--Lib/unittest/test/test_async_case.py53
-rw-r--r--Lib/unittest/test/test_runner.py110
5 files changed, 218 insertions, 2 deletions
diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py
index eda951c..005d23f 100644
--- a/Lib/unittest/__init__.py
+++ b/Lib/unittest/__init__.py
@@ -49,7 +49,7 @@ __all__ = ['TestResult', 'TestCase', 'IsolatedAsyncioTestCase', 'TestSuite',
'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless',
'expectedFailure', 'TextTestResult', 'installHandler',
'registerResult', 'removeResult', 'removeHandler',
- 'addModuleCleanup', 'doModuleCleanups']
+ 'addModuleCleanup', 'doModuleCleanups', 'enterModuleContext']
# Expose obsolete functions for backwards compatibility
# bpo-5846: Deprecated in Python 3.11, scheduled for removal in Python 3.13.
@@ -59,7 +59,8 @@ __unittest = True
from .result import TestResult
from .case import (addModuleCleanup, TestCase, FunctionTestCase, SkipTest, skip,
- skipIf, skipUnless, expectedFailure, doModuleCleanups)
+ skipIf, skipUnless, expectedFailure, doModuleCleanups,
+ enterModuleContext)
from .suite import BaseTestSuite, TestSuite
from .loader import TestLoader, defaultTestLoader
from .main import TestProgram, main
diff --git a/Lib/unittest/async_case.py b/Lib/unittest/async_case.py
index 85b938f..a90eed9 100644
--- a/Lib/unittest/async_case.py
+++ b/Lib/unittest/async_case.py
@@ -58,6 +58,26 @@ class IsolatedAsyncioTestCase(TestCase):
# 3. Regular "def func()" that returns awaitable object
self.addCleanup(*(func, *args), **kwargs)
+ async def enterAsyncContext(self, cm):
+ """Enters the supplied asynchronous context manager.
+
+ If successful, also adds its __aexit__ method as a cleanup
+ function and returns the result of the __aenter__ method.
+ """
+ # We look up the special methods on the type to match the with
+ # statement.
+ cls = type(cm)
+ try:
+ enter = cls.__aenter__
+ exit = cls.__aexit__
+ except AttributeError:
+ raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
+ f"not support the asynchronous context manager protocol"
+ ) from None
+ result = await enter(cm)
+ self.addAsyncCleanup(exit, cm, None, None, None)
+ return result
+
def _callSetUp(self):
self._asyncioTestContext.run(self.setUp)
self._callAsync(self.asyncSetUp)
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
index 55770c0..ffc8f19 100644
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -102,12 +102,31 @@ def _id(obj):
return obj
+def _enter_context(cm, addcleanup):
+ # We look up the special methods on the type to match the with
+ # statement.
+ cls = type(cm)
+ try:
+ enter = cls.__enter__
+ exit = cls.__exit__
+ except AttributeError:
+ raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
+ f"not support the context manager protocol") from None
+ result = enter(cm)
+ addcleanup(exit, cm, None, None, None)
+ return result
+
+
_module_cleanups = []
def addModuleCleanup(function, /, *args, **kwargs):
"""Same as addCleanup, except the cleanup items are called even if
setUpModule fails (unlike tearDownModule)."""
_module_cleanups.append((function, args, kwargs))
+def enterModuleContext(cm):
+ """Same as enterContext, but module-wide."""
+ return _enter_context(cm, addModuleCleanup)
+
def doModuleCleanups():
"""Execute all module cleanup functions. Normally called for you after
@@ -426,12 +445,25 @@ class TestCase(object):
Cleanup items are called even if setUp fails (unlike tearDown)."""
self._cleanups.append((function, args, kwargs))
+ def enterContext(self, cm):
+ """Enters the supplied context manager.
+
+ If successful, also adds its __exit__ method as a cleanup
+ function and returns the result of the __enter__ method.
+ """
+ return _enter_context(cm, self.addCleanup)
+
@classmethod
def addClassCleanup(cls, function, /, *args, **kwargs):
"""Same as addCleanup, except the cleanup items are called even if
setUpClass fails (unlike tearDownClass)."""
cls._class_cleanups.append((function, args, kwargs))
+ @classmethod
+ def enterClassContext(cls, cm):
+ """Same as enterContext, but class-wide."""
+ return _enter_context(cm, cls.addClassCleanup)
+
def setUp(self):
"Hook method for setting up the test fixture before exercising it."
pass
diff --git a/Lib/unittest/test/test_async_case.py b/Lib/unittest/test/test_async_case.py
index 1b910a4..beadcac 100644
--- a/Lib/unittest/test/test_async_case.py
+++ b/Lib/unittest/test/test_async_case.py
@@ -14,6 +14,29 @@ def tearDownModule():
asyncio.set_event_loop_policy(None)
+class TestCM:
+ def __init__(self, ordering, enter_result=None):
+ self.ordering = ordering
+ self.enter_result = enter_result
+
+ async def __aenter__(self):
+ self.ordering.append('enter')
+ return self.enter_result
+
+ async def __aexit__(self, *exc_info):
+ self.ordering.append('exit')
+
+
+class LacksEnterAndExit:
+ pass
+class LacksEnter:
+ async def __aexit__(self, *exc_info):
+ pass
+class LacksExit:
+ async def __aenter__(self):
+ pass
+
+
VAR = contextvars.ContextVar('VAR', default=())
@@ -337,6 +360,36 @@ class TestAsyncCase(unittest.TestCase):
output = test.run()
self.assertTrue(cancelled)
+ def test_enterAsyncContext(self):
+ events = []
+
+ class Test(unittest.IsolatedAsyncioTestCase):
+ async def test_func(slf):
+ slf.addAsyncCleanup(events.append, 'cleanup1')
+ cm = TestCM(events, 42)
+ self.assertEqual(await slf.enterAsyncContext(cm), 42)
+ slf.addAsyncCleanup(events.append, 'cleanup2')
+ events.append('test')
+
+ test = Test('test_func')
+ output = test.run()
+ self.assertTrue(output.wasSuccessful(), output)
+ self.assertEqual(events, ['enter', 'test', 'cleanup2', 'exit', 'cleanup1'])
+
+ def test_enterAsyncContext_arg_errors(self):
+ class Test(unittest.IsolatedAsyncioTestCase):
+ async def test_func(slf):
+ with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
+ await slf.enterAsyncContext(LacksEnterAndExit())
+ with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
+ await slf.enterAsyncContext(LacksEnter())
+ with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
+ await slf.enterAsyncContext(LacksExit())
+
+ test = Test('test_func')
+ output = test.run()
+ self.assertTrue(output.wasSuccessful())
+
def test_debug_cleanup_same_loop(self):
class Test(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
diff --git a/Lib/unittest/test/test_runner.py b/Lib/unittest/test/test_runner.py
index 18062ae..d3488b4 100644
--- a/Lib/unittest/test/test_runner.py
+++ b/Lib/unittest/test/test_runner.py
@@ -46,6 +46,29 @@ def cleanup(ordering, blowUp=False):
raise Exception('CleanUpExc')
+class TestCM:
+ def __init__(self, ordering, enter_result=None):
+ self.ordering = ordering
+ self.enter_result = enter_result
+
+ def __enter__(self):
+ self.ordering.append('enter')
+ return self.enter_result
+
+ def __exit__(self, *exc_info):
+ self.ordering.append('exit')
+
+
+class LacksEnterAndExit:
+ pass
+class LacksEnter:
+ def __exit__(self, *exc_info):
+ pass
+class LacksExit:
+ def __enter__(self):
+ pass
+
+
class TestCleanUp(unittest.TestCase):
def testCleanUp(self):
class TestableTest(unittest.TestCase):
@@ -173,6 +196,39 @@ class TestCleanUp(unittest.TestCase):
self.assertEqual(ordering, ['setUp', 'test', 'tearDown', 'cleanup1', 'cleanup2'])
+ def test_enterContext(self):
+ class TestableTest(unittest.TestCase):
+ def testNothing(self):
+ pass
+
+ test = TestableTest('testNothing')
+ cleanups = []
+
+ test.addCleanup(cleanups.append, 'cleanup1')
+ cm = TestCM(cleanups, 42)
+ self.assertEqual(test.enterContext(cm), 42)
+ test.addCleanup(cleanups.append, 'cleanup2')
+
+ self.assertTrue(test.doCleanups())
+ self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1'])
+
+ def test_enterContext_arg_errors(self):
+ class TestableTest(unittest.TestCase):
+ def testNothing(self):
+ pass
+
+ test = TestableTest('testNothing')
+
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ test.enterContext(LacksEnterAndExit())
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ test.enterContext(LacksEnter())
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ test.enterContext(LacksExit())
+
+ self.assertEqual(test._cleanups, [])
+
+
class TestClassCleanup(unittest.TestCase):
def test_addClassCleanUp(self):
class TestableTest(unittest.TestCase):
@@ -451,6 +507,35 @@ class TestClassCleanup(unittest.TestCase):
self.assertEqual(ordering,
['setUpClass', 'test', 'tearDownClass', 'cleanup_good'])
+ def test_enterClassContext(self):
+ class TestableTest(unittest.TestCase):
+ def testNothing(self):
+ pass
+
+ cleanups = []
+
+ TestableTest.addClassCleanup(cleanups.append, 'cleanup1')
+ cm = TestCM(cleanups, 42)
+ self.assertEqual(TestableTest.enterClassContext(cm), 42)
+ TestableTest.addClassCleanup(cleanups.append, 'cleanup2')
+
+ TestableTest.doClassCleanups()
+ self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1'])
+
+ def test_enterClassContext_arg_errors(self):
+ class TestableTest(unittest.TestCase):
+ def testNothing(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ TestableTest.enterClassContext(LacksEnterAndExit())
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ TestableTest.enterClassContext(LacksEnter())
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ TestableTest.enterClassContext(LacksExit())
+
+ self.assertEqual(TestableTest._class_cleanups, [])
+
class TestModuleCleanUp(unittest.TestCase):
def test_add_and_do_ModuleCleanup(self):
@@ -1000,6 +1085,31 @@ class TestModuleCleanUp(unittest.TestCase):
'cleanup2', 'setUp2', 'test2', 'tearDown2',
'cleanup3', 'tearDownModule', 'cleanup1'])
+ def test_enterModuleContext(self):
+ cleanups = []
+
+ unittest.addModuleCleanup(cleanups.append, 'cleanup1')
+ cm = TestCM(cleanups, 42)
+ self.assertEqual(unittest.enterModuleContext(cm), 42)
+ unittest.addModuleCleanup(cleanups.append, 'cleanup2')
+
+ unittest.case.doModuleCleanups()
+ self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1'])
+
+ def test_enterModuleContext_arg_errors(self):
+ class TestableTest(unittest.TestCase):
+ def testNothing(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ unittest.enterModuleContext(LacksEnterAndExit())
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ unittest.enterModuleContext(LacksEnter())
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ unittest.enterModuleContext(LacksExit())
+
+ self.assertEqual(unittest.case._module_cleanups, [])
+
class Test_TextTestRunner(unittest.TestCase):
"""Tests for TextTestRunner."""