summaryrefslogtreecommitdiffstats
path: root/Lib/test/support.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/support.py')
-rw-r--r--Lib/test/support.py38
1 files changed, 18 insertions, 20 deletions
diff --git a/Lib/test/support.py b/Lib/test/support.py
index cab366b..0cc8c31 100644
--- a/Lib/test/support.py
+++ b/Lib/test/support.py
@@ -37,6 +37,7 @@ __all__ = [
"findfile", "sortdict", "check_syntax_error", "open_urlresource",
"check_warnings", "CleanImport", "EnvironmentVarGuard",
"TransientResource", "captured_output", "captured_stdout",
+ "captured_stdin", "captured_stderr",
"time_out", "socket_peer_reset", "ioerror_peer_reset",
"run_with_locale", 'temp_umask', "transient_internet",
"set_memlimit", "bigmemtest", "bigaddrspacetest", "BasicTestRunner",
@@ -92,19 +93,15 @@ def import_module(name, deprecated=False):
def _save_and_remove_module(name, orig_modules):
"""Helper function to save and remove a module from sys.modules
- Return True if the module was in sys.modules, False otherwise.
Raise ImportError if the module can't be imported."""
- saved = True
- try:
- orig_modules[name] = sys.modules[name]
- except KeyError:
- # try to import the module and raise an error if it can't be imported
+ # try to import the module and raise an error if it can't be imported
+ if name not in sys.modules:
__import__(name)
- saved = False
- else:
del sys.modules[name]
- return saved
-
+ for modname in list(sys.modules):
+ if modname == name or modname.startswith(name + '.'):
+ orig_modules[modname] = sys.modules[modname]
+ del sys.modules[modname]
def _save_and_block_module(name, orig_modules):
"""Helper function to save and block a module in sys.modules
@@ -132,8 +129,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
If deprecated is True, any module or package deprecation messages
will be suppressed."""
- # NOTE: test_heapq and test_warnings include extra sanity checks to make
- # sure that this utility function is working as expected
+ # NOTE: test_heapq, test_json and test_warnings include extra sanity checks
+ # to make sure that this utility function is working as expected
with _ignore_deprecated_imports(deprecated):
# Keep track of modules saved for later restoration as well
# as those which just need a blocking entry removed
@@ -895,14 +892,8 @@ def transient_internet(resource_name, *, timeout=30.0, errnos=()):
@contextlib.contextmanager
def captured_output(stream_name):
- """Run the 'with' statement body using a StringIO object in place of a
- specific attribute on the sys module.
- Example use (with 'stream_name=stdout')::
-
- with captured_stdout() as s:
- print("hello")
- assert s.getvalue() == "hello"
- """
+ """Return a context manager used by captured_stdout/stdin/stderr
+ that temporarily replaces the sys stream *stream_name* with a StringIO."""
import io
orig_stdout = getattr(sys, stream_name)
setattr(sys, stream_name, io.StringIO())
@@ -912,6 +903,12 @@ def captured_output(stream_name):
setattr(sys, stream_name, orig_stdout)
def captured_stdout():
+ """Capture the output of sys.stdout:
+
+ with captured_stdout() as s:
+ print("hello")
+ self.assertEqual(s.getvalue(), "hello")
+ """
return captured_output("stdout")
def captured_stderr():
@@ -920,6 +917,7 @@ def captured_stderr():
def captured_stdin():
return captured_output("stdin")
+
def gc_collect():
"""Force as many objects as possible to be collected.