summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSerhiy Storchaka <storchaka@gmail.com>2022-05-08 14:49:09 (GMT)
committerGitHub <noreply@github.com>2022-05-08 14:49:09 (GMT)
commit086c6b1b0fe8d47ebd15512d7bdcb64c60a360f0 (patch)
treea7b1eaf75879c3fded1b946b2331f6a45dfc8fc7
parent8f293180791f2836570bdfc29aadba04a538d435 (diff)
downloadcpython-086c6b1b0fe8d47ebd15512d7bdcb64c60a360f0.zip
cpython-086c6b1b0fe8d47ebd15512d7bdcb64c60a360f0.tar.gz
cpython-086c6b1b0fe8d47ebd15512d7bdcb64c60a360f0.tar.bz2
bpo-45046: Support context managers in unittest (GH-28045)
Add methods enterContext() and enterClassContext() in TestCase. Add method enterAsyncContext() in IsolatedAsyncioTestCase. Add function enterModuleContext().
-rw-r--r--Doc/library/unittest.rst42
-rw-r--r--Doc/whatsnew/3.11.rst12
-rw-r--r--Lib/distutils/tests/test_build_ext.py4
-rw-r--r--Lib/test/test__osx_support.py3
-rw-r--r--Lib/test/test_argparse.py6
-rw-r--r--Lib/test/test_getopt.py6
-rw-r--r--Lib/test/test_gettext.py7
-rw-r--r--Lib/test/test_global.py11
-rw-r--r--Lib/test/test_importlib/source/test_finder.py19
-rw-r--r--Lib/test/test_importlib/test_namespace_pkgs.py7
-rw-r--r--Lib/test/test_logging.py4
-rw-r--r--Lib/test/test_nntplib.py3
-rw-r--r--Lib/test/test_peg_generator/test_c_parser.py4
-rw-r--r--Lib/test/test_poll.py3
-rw-r--r--Lib/test/test_posix.py12
-rw-r--r--Lib/test/test_set.py6
-rwxr-xr-xLib/test/test_socket.py4
-rw-r--r--Lib/test/test_ssl.py6
-rw-r--r--Lib/test/test_tempfile.py6
-rw-r--r--Lib/test/test_urllib.py7
-rw-r--r--Lib/unittest/__init__.py5
-rw-r--r--Lib/unittest/async_case.py20
-rw-r--r--Lib/unittest/case.py32
-rw-r--r--Lib/unittest/test/test_async_case.py53
-rw-r--r--Lib/unittest/test/test_runner.py110
-rw-r--r--Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst7
26 files changed, 307 insertions, 92 deletions
diff --git a/Doc/library/unittest.rst b/Doc/library/unittest.rst
index 9b8b75a..f6bcba0 100644
--- a/Doc/library/unittest.rst
+++ b/Doc/library/unittest.rst
@@ -1495,6 +1495,16 @@ Test cases
.. versionadded:: 3.1
+ .. method:: enterContext(cm)
+
+ Enter the supplied :term:`context manager`. If successful, also
+ add its :meth:`~object.__exit__` method as a cleanup function by
+ :meth:`addCleanup` and return the result of the
+ :meth:`~object.__enter__` method.
+
+ .. versionadded:: 3.11
+
+
.. method:: doCleanups()
This method is called unconditionally after :meth:`tearDown`, or
@@ -1510,6 +1520,7 @@ Test cases
.. versionadded:: 3.1
+
.. classmethod:: addClassCleanup(function, /, *args, **kwargs)
Add a function to be called after :meth:`tearDownClass` to cleanup
@@ -1524,6 +1535,16 @@ Test cases
.. versionadded:: 3.8
+ .. classmethod:: enterClassContext(cm)
+
+ Enter the supplied :term:`context manager`. If successful, also
+ add its :meth:`~object.__exit__` method as a cleanup function by
+ :meth:`addClassCleanup` and return the result of the
+ :meth:`~object.__enter__` method.
+
+ .. versionadded:: 3.11
+
+
.. classmethod:: doClassCleanups()
This method is called unconditionally after :meth:`tearDownClass`, or
@@ -1571,6 +1592,16 @@ Test cases
This method accepts a coroutine that can be used as a cleanup function.
+ .. coroutinemethod:: enterAsyncContext(cm)
+
+ Enter the supplied :term:`asynchronous context manager`. If successful,
+ also add its :meth:`~object.__aexit__` method as a cleanup function by
+ :meth:`addAsyncCleanup` and return the result of the
+ :meth:`~object.__aenter__` method.
+
+ .. versionadded:: 3.11
+
+
.. method:: run(result=None)
Sets up a new event loop to run the test, collecting the result into
@@ -2465,6 +2496,16 @@ To add cleanup code that must be run even in the case of an exception, use
.. versionadded:: 3.8
+.. classmethod:: enterModuleContext(cm)
+
+ Enter the supplied :term:`context manager`. If successful, also
+ add its :meth:`~object.__exit__` method as a cleanup function by
+ :func:`addModuleCleanup` and return the result of the
+ :meth:`~object.__enter__` method.
+
+ .. versionadded:: 3.11
+
+
.. function:: doModuleCleanups()
This function is called unconditionally after :func:`tearDownModule`, or
@@ -2480,6 +2521,7 @@ To add cleanup code that must be run even in the case of an exception, use
.. versionadded:: 3.8
+
Signal Handling
---------------
diff --git a/Doc/whatsnew/3.11.rst b/Doc/whatsnew/3.11.rst
index c4e8e6f..defaeeb 100644
--- a/Doc/whatsnew/3.11.rst
+++ b/Doc/whatsnew/3.11.rst
@@ -758,6 +758,18 @@ unicodedata
* The Unicode database has been updated to version 14.0.0. (:issue:`45190`).
+unittest
+--------
+
+* Added methods :meth:`~unittest.TestCase.enterContext` and
+ :meth:`~unittest.TestCase.enterClassContext` of class
+ :class:`~unittest.TestCase`, method
+ :meth:`~unittest.IsolatedAsyncioTestCase.enterAsyncContext` of
+ class :class:`~unittest.IsolatedAsyncioTestCase` and function
+ :func:`unittest.enterModuleContext`.
+ (Contributed by Serhiy Storchaka in :issue:`45046`.)
+
+
venv
----
diff --git a/Lib/distutils/tests/test_build_ext.py b/Lib/distutils/tests/test_build_ext.py
index 031897b..4ebeafe 100644
--- a/Lib/distutils/tests/test_build_ext.py
+++ b/Lib/distutils/tests/test_build_ext.py
@@ -41,9 +41,7 @@ class BuildExtTestCase(TempdirManager,
# bpo-30132: On Windows, a .pdb file may be created in the current
# working directory. Create a temporary working directory to cleanup
# everything at the end of the test.
- change_cwd = os_helper.change_cwd(self.tmp_dir)
- change_cwd.__enter__()
- self.addCleanup(change_cwd.__exit__, None, None, None)
+ self.enterContext(os_helper.change_cwd(self.tmp_dir))
def tearDown(self):
import site
diff --git a/Lib/test/test__osx_support.py b/Lib/test/test__osx_support.py
index 907ae27..4a14cb3 100644
--- a/Lib/test/test__osx_support.py
+++ b/Lib/test/test__osx_support.py
@@ -19,8 +19,7 @@ class Test_OSXSupport(unittest.TestCase):
self.maxDiff = None
self.prog_name = 'bogus_program_xxxx'
self.temp_path_dir = os.path.abspath(os.getcwd())
- self.env = os_helper.EnvironmentVarGuard()
- self.addCleanup(self.env.__exit__)
+ self.env = self.enterContext(os_helper.EnvironmentVarGuard())
for cv in ('CFLAGS', 'LDFLAGS', 'CPPFLAGS',
'BASECFLAGS', 'BLDSHARED', 'LDSHARED', 'CC',
'CXX', 'PY_CFLAGS', 'PY_LDFLAGS', 'PY_CPPFLAGS',
diff --git a/Lib/test/test_argparse.py b/Lib/test/test_argparse.py
index 8509deb..273db45 100644
--- a/Lib/test/test_argparse.py
+++ b/Lib/test/test_argparse.py
@@ -41,9 +41,8 @@ class TestCase(unittest.TestCase):
# The tests assume that line wrapping occurs at 80 columns, but this
# behaviour can be overridden by setting the COLUMNS environment
# variable. To ensure that this width is used, set COLUMNS to 80.
- env = os_helper.EnvironmentVarGuard()
+ env = self.enterContext(os_helper.EnvironmentVarGuard())
env['COLUMNS'] = '80'
- self.addCleanup(env.__exit__)
class TempDirMixin(object):
@@ -3428,9 +3427,8 @@ class TestShortColumns(HelpTestCase):
but we don't want any exceptions thrown in such cases. Only ugly representation.
'''
def setUp(self):
- env = os_helper.EnvironmentVarGuard()
+ env = self.enterContext(os_helper.EnvironmentVarGuard())
env.set("COLUMNS", '15')
- self.addCleanup(env.__exit__)
parser_signature = TestHelpBiggerOptionals.parser_signature
argument_signatures = TestHelpBiggerOptionals.argument_signatures
diff --git a/Lib/test/test_getopt.py b/Lib/test/test_getopt.py
index 9261276..64b9ce0 100644
--- a/Lib/test/test_getopt.py
+++ b/Lib/test/test_getopt.py
@@ -11,14 +11,10 @@ sentinel = object()
class GetoptTests(unittest.TestCase):
def setUp(self):
- self.env = EnvironmentVarGuard()
+ self.env = self.enterContext(EnvironmentVarGuard())
if "POSIXLY_CORRECT" in self.env:
del self.env["POSIXLY_CORRECT"]
- def tearDown(self):
- self.env.__exit__()
- del self.env
-
def assertError(self, *args, **kwargs):
self.assertRaises(getopt.GetoptError, *args, **kwargs)
diff --git a/Lib/test/test_gettext.py b/Lib/test/test_gettext.py
index 467652a..1608d1b 100644
--- a/Lib/test/test_gettext.py
+++ b/Lib/test/test_gettext.py
@@ -117,6 +117,7 @@ MMOFILE = os.path.join(LOCALEDIR, 'metadata.mo')
class GettextBaseTest(unittest.TestCase):
def setUp(self):
+ self.addCleanup(os_helper.rmtree, os.path.split(LOCALEDIR)[0])
if not os.path.isdir(LOCALEDIR):
os.makedirs(LOCALEDIR)
with open(MOFILE, 'wb') as fp:
@@ -129,14 +130,10 @@ class GettextBaseTest(unittest.TestCase):
fp.write(base64.decodebytes(UMO_DATA))
with open(MMOFILE, 'wb') as fp:
fp.write(base64.decodebytes(MMO_DATA))
- self.env = os_helper.EnvironmentVarGuard()
+ self.env = self.enterContext(os_helper.EnvironmentVarGuard())
self.env['LANGUAGE'] = 'xx'
gettext._translations.clear()
- def tearDown(self):
- self.env.__exit__()
- del self.env
- os_helper.rmtree(os.path.split(LOCALEDIR)[0])
GNU_MO_DATA_ISSUE_17898 = b'''\
3hIElQAAAAABAAAAHAAAACQAAAAAAAAAAAAAAAAAAAAsAAAAggAAAC0AAAAAUGx1cmFsLUZvcm1z
diff --git a/Lib/test/test_global.py b/Lib/test/test_global.py
index d0bde3f..f5b38c2 100644
--- a/Lib/test/test_global.py
+++ b/Lib/test/test_global.py
@@ -9,14 +9,9 @@ import warnings
class GlobalTests(unittest.TestCase):
def setUp(self):
- self._warnings_manager = check_warnings()
- self._warnings_manager.__enter__()
+ self.enterContext(check_warnings())
warnings.filterwarnings("error", module="<test string>")
- def tearDown(self):
- self._warnings_manager.__exit__(None, None, None)
-
-
def test1(self):
prog_text_1 = """\
def wrong1():
@@ -54,9 +49,7 @@ x = 2
def setUpModule():
- cm = warnings.catch_warnings()
- cm.__enter__()
- unittest.addModuleCleanup(cm.__exit__, None, None, None)
+ unittest.enterModuleContext(warnings.catch_warnings())
warnings.filterwarnings("error", module="<test string>")
diff --git a/Lib/test/test_importlib/source/test_finder.py b/Lib/test/test_importlib/source/test_finder.py
index 6a23e9d..bed9d56 100644
--- a/Lib/test/test_importlib/source/test_finder.py
+++ b/Lib/test/test_importlib/source/test_finder.py
@@ -157,21 +157,12 @@ class FinderTests(abc.FinderTests):
def test_no_read_directory(self):
# Issue #16730
tempdir = tempfile.TemporaryDirectory()
+ self.enterContext(tempdir)
+ # Since we muck with the permissions, we want to set them back to
+ # their original values to make sure the directory can be properly
+ # cleaned up.
original_mode = os.stat(tempdir.name).st_mode
- def cleanup(tempdir):
- """Cleanup function for the temporary directory.
-
- Since we muck with the permissions, we want to set them back to
- their original values to make sure the directory can be properly
- cleaned up.
-
- """
- os.chmod(tempdir.name, original_mode)
- # If this is not explicitly called then the __del__ method is used,
- # but since already mucking around might as well explicitly clean
- # up.
- tempdir.__exit__(None, None, None)
- self.addCleanup(cleanup, tempdir)
+ self.addCleanup(os.chmod, tempdir.name, original_mode)
os.chmod(tempdir.name, stat.S_IWUSR | stat.S_IXUSR)
finder = self.get_finder(tempdir.name)
found = self._find(finder, 'doesnotexist')
diff --git a/Lib/test/test_importlib/test_namespace_pkgs.py b/Lib/test/test_importlib/test_namespace_pkgs.py
index 2ea41b7..cd08498 100644
--- a/Lib/test/test_importlib/test_namespace_pkgs.py
+++ b/Lib/test/test_importlib/test_namespace_pkgs.py
@@ -65,12 +65,7 @@ class NamespacePackageTest(unittest.TestCase):
self.resolved_paths = [
os.path.join(self.root, path) for path in self.paths
]
- self.ctx = namespace_tree_context(path=self.resolved_paths)
- self.ctx.__enter__()
-
- def tearDown(self):
- # TODO: will we ever want to pass exc_info to __exit__?
- self.ctx.__exit__(None, None, None)
+ self.enterContext(namespace_tree_context(path=self.resolved_paths))
class SingleNamespacePackage(NamespacePackageTest):
diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py
index 5d4dded..e69afae 100644
--- a/Lib/test/test_logging.py
+++ b/Lib/test/test_logging.py
@@ -5650,9 +5650,7 @@ class MiscTestCase(unittest.TestCase):
# why the test does this, but in any case we save the current locale
# first and restore it at the end.
def setUpModule():
- cm = support.run_with_locale('LC_ALL', '')
- cm.__enter__()
- unittest.addModuleCleanup(cm.__exit__, None, None, None)
+ unittest.enterModuleContext(support.run_with_locale('LC_ALL', ''))
if __name__ == "__main__":
diff --git a/Lib/test/test_nntplib.py b/Lib/test/test_nntplib.py
index 9812c05..31a02f8 100644
--- a/Lib/test/test_nntplib.py
+++ b/Lib/test/test_nntplib.py
@@ -1593,8 +1593,7 @@ class LocalServerTests(unittest.TestCase):
self.background.start()
self.addCleanup(self.background.join)
- self.nntp = NNTP(socket_helper.HOST, port, usenetrc=False).__enter__()
- self.addCleanup(self.nntp.__exit__, None, None, None)
+ self.nntp = self.enterContext(NNTP(socket_helper.HOST, port, usenetrc=False))
def run_server(self, sock):
# Could be generalized to handle more commands in separate methods
diff --git a/Lib/test/test_peg_generator/test_c_parser.py b/Lib/test/test_peg_generator/test_c_parser.py
index 13b83a9..d25bc11 100644
--- a/Lib/test/test_peg_generator/test_c_parser.py
+++ b/Lib/test/test_peg_generator/test_c_parser.py
@@ -96,9 +96,7 @@ class TestCParser(unittest.TestCase):
self.skipTest("The %r command is not found" % cmd)
self.old_cwd = os.getcwd()
self.tmp_path = tempfile.mkdtemp(dir=self.tmp_base)
- change_cwd = os_helper.change_cwd(self.tmp_path)
- change_cwd.__enter__()
- self.addCleanup(change_cwd.__exit__, None, None, None)
+ self.enterContext(os_helper.change_cwd(self.tmp_path))
def tearDown(self):
os.chdir(self.old_cwd)
diff --git a/Lib/test/test_poll.py b/Lib/test/test_poll.py
index 7d542b5..02165a0 100644
--- a/Lib/test/test_poll.py
+++ b/Lib/test/test_poll.py
@@ -128,8 +128,7 @@ class PollTests(unittest.TestCase):
cmd = 'for i in 0 1 2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done'
proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
bufsize=0)
- proc.__enter__()
- self.addCleanup(proc.__exit__, None, None, None)
+ self.enterContext(proc)
p = proc.stdout
pollster = select.poll()
pollster.register( p, select.POLLIN )
diff --git a/Lib/test/test_posix.py b/Lib/test/test_posix.py
index f44b8d0..28e5e90 100644
--- a/Lib/test/test_posix.py
+++ b/Lib/test/test_posix.py
@@ -53,19 +53,13 @@ class PosixTester(unittest.TestCase):
def setUp(self):
# create empty file
+ self.addCleanup(os_helper.unlink, os_helper.TESTFN)
with open(os_helper.TESTFN, "wb"):
pass
- self.teardown_files = [ os_helper.TESTFN ]
- self._warnings_manager = warnings_helper.check_warnings()
- self._warnings_manager.__enter__()
+ self.enterContext(warnings_helper.check_warnings())
warnings.filterwarnings('ignore', '.* potential security risk .*',
RuntimeWarning)
- def tearDown(self):
- for teardown_file in self.teardown_files:
- os_helper.unlink(teardown_file)
- self._warnings_manager.__exit__(None, None, None)
-
def testNoArgFunctions(self):
# test posix functions which take no arguments and have
# no side-effects which we need to cleanup (e.g., fork, wait, abort)
@@ -973,8 +967,8 @@ class PosixTester(unittest.TestCase):
self.assertTrue(hasattr(testfn_st, 'st_flags'))
+ self.addCleanup(os_helper.unlink, _DUMMY_SYMLINK)
os.symlink(os_helper.TESTFN, _DUMMY_SYMLINK)
- self.teardown_files.append(_DUMMY_SYMLINK)
dummy_symlink_st = os.lstat(_DUMMY_SYMLINK)
def chflags_nofollow(path, flags):
diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py
index 3b57517..43f23db 100644
--- a/Lib/test/test_set.py
+++ b/Lib/test/test_set.py
@@ -1022,8 +1022,7 @@ class TestBasicOpsBytes(TestBasicOps, unittest.TestCase):
class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
def setUp(self):
- self._warning_filters = warnings_helper.check_warnings()
- self._warning_filters.__enter__()
+ self.enterContext(warnings_helper.check_warnings())
warnings.simplefilter('ignore', BytesWarning)
self.case = "string and bytes set"
self.values = ["a", "b", b"a", b"b"]
@@ -1031,9 +1030,6 @@ class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
self.dup = set(self.values)
self.length = 4
- def tearDown(self):
- self._warning_filters.__exit__(None, None, None)
-
def test_repr(self):
self.check_repr_against_values()
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index 6133637..1aaa9e4 100755
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -338,9 +338,7 @@ class ThreadableTest:
self.server_ready.set()
def _setUp(self):
- self.wait_threads = threading_helper.wait_threads_exit()
- self.wait_threads.__enter__()
- self.addCleanup(self.wait_threads.__exit__, None, None, None)
+ self.enterContext(threading_helper.wait_threads_exit())
self.server_ready = threading.Event()
self.client_ready = threading.Event()
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 0eb8d18..fed7637 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -1999,9 +1999,8 @@ class SimpleBackgroundTests(unittest.TestCase):
self.server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
self.server_context.load_cert_chain(SIGNED_CERTFILE)
server = ThreadedEchoServer(context=self.server_context)
+ self.enterContext(server)
self.server_addr = (HOST, server.port)
- server.__enter__()
- self.addCleanup(server.__exit__, None, None, None)
def test_connect(self):
with test_wrap_socket(socket.socket(socket.AF_INET),
@@ -3713,8 +3712,7 @@ class ThreadedTests(unittest.TestCase):
def test_recv_zero(self):
server = ThreadedEchoServer(CERTFILE)
- server.__enter__()
- self.addCleanup(server.__exit__, None, None)
+ self.enterContext(server)
s = socket.create_connection((HOST, server.port))
self.addCleanup(s.close)
s = test_wrap_socket(s, suppress_ragged_eofs=False)
diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py
index a05f3c8..f056e5c 100644
--- a/Lib/test/test_tempfile.py
+++ b/Lib/test/test_tempfile.py
@@ -90,14 +90,10 @@ class BaseTestCase(unittest.TestCase):
b_check = re.compile(br"^[a-z0-9_-]{8}$")
def setUp(self):
- self._warnings_manager = warnings_helper.check_warnings()
- self._warnings_manager.__enter__()
+ self.enterContext(warnings_helper.check_warnings())
warnings.filterwarnings("ignore", category=RuntimeWarning,
message="mktemp", module=__name__)
- def tearDown(self):
- self._warnings_manager.__exit__(None, None, None)
-
def nameCheck(self, name, dir, pre, suf):
(ndir, nbase) = os.path.split(name)
npre = nbase[:len(pre)]
diff --git a/Lib/test/test_urllib.py b/Lib/test/test_urllib.py
index 82f1d9d..bc6e74c 100644
--- a/Lib/test/test_urllib.py
+++ b/Lib/test/test_urllib.py
@@ -232,17 +232,12 @@ class ProxyTests(unittest.TestCase):
def setUp(self):
# Records changes to env vars
- self.env = os_helper.EnvironmentVarGuard()
+ self.env = self.enterContext(os_helper.EnvironmentVarGuard())
# Delete all proxy related env vars
for k in list(os.environ):
if 'proxy' in k.lower():
self.env.unset(k)
- def tearDown(self):
- # Restore all proxy related env vars
- self.env.__exit__()
- del self.env
-
def test_getproxies_environment_keep_no_proxies(self):
self.env.set('NO_PROXY', 'localhost')
proxies = urllib.request.getproxies_environment()
diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py
index eda951c..005d23f 100644
--- a/Lib/unittest/__init__.py
+++ b/Lib/unittest/__init__.py
@@ -49,7 +49,7 @@ __all__ = ['TestResult', 'TestCase', 'IsolatedAsyncioTestCase', 'TestSuite',
'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless',
'expectedFailure', 'TextTestResult', 'installHandler',
'registerResult', 'removeResult', 'removeHandler',
- 'addModuleCleanup', 'doModuleCleanups']
+ 'addModuleCleanup', 'doModuleCleanups', 'enterModuleContext']
# Expose obsolete functions for backwards compatibility
# bpo-5846: Deprecated in Python 3.11, scheduled for removal in Python 3.13.
@@ -59,7 +59,8 @@ __unittest = True
from .result import TestResult
from .case import (addModuleCleanup, TestCase, FunctionTestCase, SkipTest, skip,
- skipIf, skipUnless, expectedFailure, doModuleCleanups)
+ skipIf, skipUnless, expectedFailure, doModuleCleanups,
+ enterModuleContext)
from .suite import BaseTestSuite, TestSuite
from .loader import TestLoader, defaultTestLoader
from .main import TestProgram, main
diff --git a/Lib/unittest/async_case.py b/Lib/unittest/async_case.py
index 85b938f..a90eed9 100644
--- a/Lib/unittest/async_case.py
+++ b/Lib/unittest/async_case.py
@@ -58,6 +58,26 @@ class IsolatedAsyncioTestCase(TestCase):
# 3. Regular "def func()" that returns awaitable object
self.addCleanup(*(func, *args), **kwargs)
+ async def enterAsyncContext(self, cm):
+ """Enters the supplied asynchronous context manager.
+
+ If successful, also adds its __aexit__ method as a cleanup
+ function and returns the result of the __aenter__ method.
+ """
+ # We look up the special methods on the type to match the with
+ # statement.
+ cls = type(cm)
+ try:
+ enter = cls.__aenter__
+ exit = cls.__aexit__
+ except AttributeError:
+ raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
+ f"not support the asynchronous context manager protocol"
+ ) from None
+ result = await enter(cm)
+ self.addAsyncCleanup(exit, cm, None, None, None)
+ return result
+
def _callSetUp(self):
self._asyncioTestContext.run(self.setUp)
self._callAsync(self.asyncSetUp)
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
index 55770c0..ffc8f19 100644
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -102,12 +102,31 @@ def _id(obj):
return obj
+def _enter_context(cm, addcleanup):
+ # We look up the special methods on the type to match the with
+ # statement.
+ cls = type(cm)
+ try:
+ enter = cls.__enter__
+ exit = cls.__exit__
+ except AttributeError:
+ raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
+ f"not support the context manager protocol") from None
+ result = enter(cm)
+ addcleanup(exit, cm, None, None, None)
+ return result
+
+
_module_cleanups = []
def addModuleCleanup(function, /, *args, **kwargs):
"""Same as addCleanup, except the cleanup items are called even if
setUpModule fails (unlike tearDownModule)."""
_module_cleanups.append((function, args, kwargs))
+def enterModuleContext(cm):
+ """Same as enterContext, but module-wide."""
+ return _enter_context(cm, addModuleCleanup)
+
def doModuleCleanups():
"""Execute all module cleanup functions. Normally called for you after
@@ -426,12 +445,25 @@ class TestCase(object):
Cleanup items are called even if setUp fails (unlike tearDown)."""
self._cleanups.append((function, args, kwargs))
+ def enterContext(self, cm):
+ """Enters the supplied context manager.
+
+ If successful, also adds its __exit__ method as a cleanup
+ function and returns the result of the __enter__ method.
+ """
+ return _enter_context(cm, self.addCleanup)
+
@classmethod
def addClassCleanup(cls, function, /, *args, **kwargs):
"""Same as addCleanup, except the cleanup items are called even if
setUpClass fails (unlike tearDownClass)."""
cls._class_cleanups.append((function, args, kwargs))
+ @classmethod
+ def enterClassContext(cls, cm):
+ """Same as enterContext, but class-wide."""
+ return _enter_context(cm, cls.addClassCleanup)
+
def setUp(self):
"Hook method for setting up the test fixture before exercising it."
pass
diff --git a/Lib/unittest/test/test_async_case.py b/Lib/unittest/test/test_async_case.py
index 1b910a4..beadcac 100644
--- a/Lib/unittest/test/test_async_case.py
+++ b/Lib/unittest/test/test_async_case.py
@@ -14,6 +14,29 @@ def tearDownModule():
asyncio.set_event_loop_policy(None)
+class TestCM:
+ def __init__(self, ordering, enter_result=None):
+ self.ordering = ordering
+ self.enter_result = enter_result
+
+ async def __aenter__(self):
+ self.ordering.append('enter')
+ return self.enter_result
+
+ async def __aexit__(self, *exc_info):
+ self.ordering.append('exit')
+
+
+class LacksEnterAndExit:
+ pass
+class LacksEnter:
+ async def __aexit__(self, *exc_info):
+ pass
+class LacksExit:
+ async def __aenter__(self):
+ pass
+
+
VAR = contextvars.ContextVar('VAR', default=())
@@ -337,6 +360,36 @@ class TestAsyncCase(unittest.TestCase):
output = test.run()
self.assertTrue(cancelled)
+ def test_enterAsyncContext(self):
+ events = []
+
+ class Test(unittest.IsolatedAsyncioTestCase):
+ async def test_func(slf):
+ slf.addAsyncCleanup(events.append, 'cleanup1')
+ cm = TestCM(events, 42)
+ self.assertEqual(await slf.enterAsyncContext(cm), 42)
+ slf.addAsyncCleanup(events.append, 'cleanup2')
+ events.append('test')
+
+ test = Test('test_func')
+ output = test.run()
+ self.assertTrue(output.wasSuccessful(), output)
+ self.assertEqual(events, ['enter', 'test', 'cleanup2', 'exit', 'cleanup1'])
+
+ def test_enterAsyncContext_arg_errors(self):
+ class Test(unittest.IsolatedAsyncioTestCase):
+ async def test_func(slf):
+ with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
+ await slf.enterAsyncContext(LacksEnterAndExit())
+ with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
+ await slf.enterAsyncContext(LacksEnter())
+ with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
+ await slf.enterAsyncContext(LacksExit())
+
+ test = Test('test_func')
+ output = test.run()
+ self.assertTrue(output.wasSuccessful())
+
def test_debug_cleanup_same_loop(self):
class Test(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
diff --git a/Lib/unittest/test/test_runner.py b/Lib/unittest/test/test_runner.py
index 18062ae..d3488b4 100644
--- a/Lib/unittest/test/test_runner.py
+++ b/Lib/unittest/test/test_runner.py
@@ -46,6 +46,29 @@ def cleanup(ordering, blowUp=False):
raise Exception('CleanUpExc')
+class TestCM:
+ def __init__(self, ordering, enter_result=None):
+ self.ordering = ordering
+ self.enter_result = enter_result
+
+ def __enter__(self):
+ self.ordering.append('enter')
+ return self.enter_result
+
+ def __exit__(self, *exc_info):
+ self.ordering.append('exit')
+
+
+class LacksEnterAndExit:
+ pass
+class LacksEnter:
+ def __exit__(self, *exc_info):
+ pass
+class LacksExit:
+ def __enter__(self):
+ pass
+
+
class TestCleanUp(unittest.TestCase):
def testCleanUp(self):
class TestableTest(unittest.TestCase):
@@ -173,6 +196,39 @@ class TestCleanUp(unittest.TestCase):
self.assertEqual(ordering, ['setUp', 'test', 'tearDown', 'cleanup1', 'cleanup2'])
+ def test_enterContext(self):
+ class TestableTest(unittest.TestCase):
+ def testNothing(self):
+ pass
+
+ test = TestableTest('testNothing')
+ cleanups = []
+
+ test.addCleanup(cleanups.append, 'cleanup1')
+ cm = TestCM(cleanups, 42)
+ self.assertEqual(test.enterContext(cm), 42)
+ test.addCleanup(cleanups.append, 'cleanup2')
+
+ self.assertTrue(test.doCleanups())
+ self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1'])
+
+ def test_enterContext_arg_errors(self):
+ class TestableTest(unittest.TestCase):
+ def testNothing(self):
+ pass
+
+ test = TestableTest('testNothing')
+
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ test.enterContext(LacksEnterAndExit())
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ test.enterContext(LacksEnter())
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ test.enterContext(LacksExit())
+
+ self.assertEqual(test._cleanups, [])
+
+
class TestClassCleanup(unittest.TestCase):
def test_addClassCleanUp(self):
class TestableTest(unittest.TestCase):
@@ -451,6 +507,35 @@ class TestClassCleanup(unittest.TestCase):
self.assertEqual(ordering,
['setUpClass', 'test', 'tearDownClass', 'cleanup_good'])
+ def test_enterClassContext(self):
+ class TestableTest(unittest.TestCase):
+ def testNothing(self):
+ pass
+
+ cleanups = []
+
+ TestableTest.addClassCleanup(cleanups.append, 'cleanup1')
+ cm = TestCM(cleanups, 42)
+ self.assertEqual(TestableTest.enterClassContext(cm), 42)
+ TestableTest.addClassCleanup(cleanups.append, 'cleanup2')
+
+ TestableTest.doClassCleanups()
+ self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1'])
+
+ def test_enterClassContext_arg_errors(self):
+ class TestableTest(unittest.TestCase):
+ def testNothing(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ TestableTest.enterClassContext(LacksEnterAndExit())
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ TestableTest.enterClassContext(LacksEnter())
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ TestableTest.enterClassContext(LacksExit())
+
+ self.assertEqual(TestableTest._class_cleanups, [])
+
class TestModuleCleanUp(unittest.TestCase):
def test_add_and_do_ModuleCleanup(self):
@@ -1000,6 +1085,31 @@ class TestModuleCleanUp(unittest.TestCase):
'cleanup2', 'setUp2', 'test2', 'tearDown2',
'cleanup3', 'tearDownModule', 'cleanup1'])
+ def test_enterModuleContext(self):
+ cleanups = []
+
+ unittest.addModuleCleanup(cleanups.append, 'cleanup1')
+ cm = TestCM(cleanups, 42)
+ self.assertEqual(unittest.enterModuleContext(cm), 42)
+ unittest.addModuleCleanup(cleanups.append, 'cleanup2')
+
+ unittest.case.doModuleCleanups()
+ self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1'])
+
+ def test_enterModuleContext_arg_errors(self):
+ class TestableTest(unittest.TestCase):
+ def testNothing(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ unittest.enterModuleContext(LacksEnterAndExit())
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ unittest.enterModuleContext(LacksEnter())
+ with self.assertRaisesRegex(TypeError, 'the context manager'):
+ unittest.enterModuleContext(LacksExit())
+
+ self.assertEqual(unittest.case._module_cleanups, [])
+
class Test_TextTestRunner(unittest.TestCase):
"""Tests for TextTestRunner."""
diff --git a/Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst b/Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst
new file mode 100644
index 0000000..8072afa
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst
@@ -0,0 +1,7 @@
+Add support of context managers in :mod:`unittest`: methods
+:meth:`~unittest.TestCase.enterContext` and
+:meth:`~unittest.TestCase.enterClassContext` of class
+:class:`~unittest.TestCase`, method
+:meth:`~unittest.IsolatedAsyncioTestCase.enterAsyncContext` of class
+:class:`~unittest.IsolatedAsyncioTestCase` and function
+:func:`unittest.enterModuleContext`.