diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/contextlib.py | 12 | ||||
-rw-r--r-- | Lib/test/test_contextlib.py | 24 |
2 files changed, 21 insertions, 15 deletions
diff --git a/Lib/contextlib.py b/Lib/contextlib.py index fb89118..d3219f6 100644 --- a/Lib/contextlib.py +++ b/Lib/contextlib.py @@ -166,20 +166,16 @@ class redirect_stdout: def __init__(self, new_target): self._new_target = new_target - self._old_target = self._sentinel = object() + # We use a list of old targets to make this CM re-entrant + self._old_targets = [] def __enter__(self): - if self._old_target is not self._sentinel: - raise RuntimeError("Cannot reenter {!r}".format(self)) - self._old_target = sys.stdout + self._old_targets.append(sys.stdout) sys.stdout = self._new_target return self._new_target def __exit__(self, exctype, excinst, exctb): - restore_stdout = self._old_target - self._old_target = self._sentinel - sys.stdout = restore_stdout - + sys.stdout = self._old_targets.pop() class suppress: diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py index 916ac80..b8770c8 100644 --- a/Lib/test/test_contextlib.py +++ b/Lib/test/test_contextlib.py @@ -666,11 +666,18 @@ class TestRedirectStdout(unittest.TestCase): obj = redirect_stdout(None) self.assertEqual(obj.__doc__, cm_docstring) + def test_no_redirect_in_init(self): + orig_stdout = sys.stdout + redirect_stdout(None) + self.assertIs(sys.stdout, orig_stdout) + def test_redirect_to_string_io(self): f = io.StringIO() msg = "Consider an API like help(), which prints directly to stdout" + orig_stdout = sys.stdout with redirect_stdout(f): print(msg) + self.assertIs(sys.stdout, orig_stdout) s = f.getvalue().strip() self.assertEqual(s, msg) @@ -682,23 +689,26 @@ class TestRedirectStdout(unittest.TestCase): def test_cm_is_reusable(self): f = io.StringIO() write_to_f = redirect_stdout(f) + orig_stdout = sys.stdout with write_to_f: print("Hello", end=" ") with write_to_f: print("World!") + self.assertIs(sys.stdout, orig_stdout) s = f.getvalue() self.assertEqual(s, "Hello World!\n") - # If this is ever made reentrant, update the reusable-but-not-reentrant - # example at the end of the contextlib docs accordingly. - def test_nested_reentry_fails(self): + def test_cm_is_reentrant(self): f = io.StringIO() write_to_f = redirect_stdout(f) - with self.assertRaisesRegex(RuntimeError, "Cannot reenter"): + orig_stdout = sys.stdout + with write_to_f: + print("Hello", end=" ") with write_to_f: - print("Hello", end=" ") - with write_to_f: - print("World!") + print("World!") + self.assertIs(sys.stdout, orig_stdout) + s = f.getvalue() + self.assertEqual(s, "Hello World!\n") class TestSuppress(unittest.TestCase): |