diff options
Diffstat (limited to 'Lib/test/support.py')
-rw-r--r-- | Lib/test/support.py | 38 |
1 files changed, 18 insertions, 20 deletions
diff --git a/Lib/test/support.py b/Lib/test/support.py index cab366b..0cc8c31 100644 --- a/Lib/test/support.py +++ b/Lib/test/support.py @@ -37,6 +37,7 @@ __all__ = [ "findfile", "sortdict", "check_syntax_error", "open_urlresource", "check_warnings", "CleanImport", "EnvironmentVarGuard", "TransientResource", "captured_output", "captured_stdout", + "captured_stdin", "captured_stderr", "time_out", "socket_peer_reset", "ioerror_peer_reset", "run_with_locale", 'temp_umask', "transient_internet", "set_memlimit", "bigmemtest", "bigaddrspacetest", "BasicTestRunner", @@ -92,19 +93,15 @@ def import_module(name, deprecated=False): def _save_and_remove_module(name, orig_modules): """Helper function to save and remove a module from sys.modules - Return True if the module was in sys.modules, False otherwise. Raise ImportError if the module can't be imported.""" - saved = True - try: - orig_modules[name] = sys.modules[name] - except KeyError: - # try to import the module and raise an error if it can't be imported + # try to import the module and raise an error if it can't be imported + if name not in sys.modules: __import__(name) - saved = False - else: del sys.modules[name] - return saved - + for modname in list(sys.modules): + if modname == name or modname.startswith(name + '.'): + orig_modules[modname] = sys.modules[modname] + del sys.modules[modname] def _save_and_block_module(name, orig_modules): """Helper function to save and block a module in sys.modules @@ -132,8 +129,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False): If deprecated is True, any module or package deprecation messages will be suppressed.""" - # NOTE: test_heapq and test_warnings include extra sanity checks to make - # sure that this utility function is working as expected + # NOTE: test_heapq, test_json and test_warnings include extra sanity checks + # to make sure that this utility function is working as expected with _ignore_deprecated_imports(deprecated): # Keep track of modules saved for later restoration as well # as those which just need a blocking entry removed @@ -895,14 +892,8 @@ def transient_internet(resource_name, *, timeout=30.0, errnos=()): @contextlib.contextmanager def captured_output(stream_name): - """Run the 'with' statement body using a StringIO object in place of a - specific attribute on the sys module. - Example use (with 'stream_name=stdout'):: - - with captured_stdout() as s: - print("hello") - assert s.getvalue() == "hello" - """ + """Return a context manager used by captured_stdout/stdin/stderr + that temporarily replaces the sys stream *stream_name* with a StringIO.""" import io orig_stdout = getattr(sys, stream_name) setattr(sys, stream_name, io.StringIO()) @@ -912,6 +903,12 @@ def captured_output(stream_name): setattr(sys, stream_name, orig_stdout) def captured_stdout(): + """Capture the output of sys.stdout: + + with captured_stdout() as s: + print("hello") + self.assertEqual(s.getvalue(), "hello") + """ return captured_output("stdout") def captured_stderr(): @@ -920,6 +917,7 @@ def captured_stderr(): def captured_stdin(): return captured_output("stdin") + def gc_collect(): """Force as many objects as possible to be collected. |