summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/contextlib.py12
-rw-r--r--Lib/test/test_contextlib.py24
2 files changed, 21 insertions, 15 deletions
diff --git a/Lib/contextlib.py b/Lib/contextlib.py
index fb89118..d3219f6 100644
--- a/Lib/contextlib.py
+++ b/Lib/contextlib.py
@@ -166,20 +166,16 @@ class redirect_stdout:
def __init__(self, new_target):
self._new_target = new_target
- self._old_target = self._sentinel = object()
+ # We use a list of old targets to make this CM re-entrant
+ self._old_targets = []
def __enter__(self):
- if self._old_target is not self._sentinel:
- raise RuntimeError("Cannot reenter {!r}".format(self))
- self._old_target = sys.stdout
+ self._old_targets.append(sys.stdout)
sys.stdout = self._new_target
return self._new_target
def __exit__(self, exctype, excinst, exctb):
- restore_stdout = self._old_target
- self._old_target = self._sentinel
- sys.stdout = restore_stdout
-
+ sys.stdout = self._old_targets.pop()
class suppress:
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
index 916ac80..b8770c8 100644
--- a/Lib/test/test_contextlib.py
+++ b/Lib/test/test_contextlib.py
@@ -666,11 +666,18 @@ class TestRedirectStdout(unittest.TestCase):
obj = redirect_stdout(None)
self.assertEqual(obj.__doc__, cm_docstring)
+ def test_no_redirect_in_init(self):
+ orig_stdout = sys.stdout
+ redirect_stdout(None)
+ self.assertIs(sys.stdout, orig_stdout)
+
def test_redirect_to_string_io(self):
f = io.StringIO()
msg = "Consider an API like help(), which prints directly to stdout"
+ orig_stdout = sys.stdout
with redirect_stdout(f):
print(msg)
+ self.assertIs(sys.stdout, orig_stdout)
s = f.getvalue().strip()
self.assertEqual(s, msg)
@@ -682,23 +689,26 @@ class TestRedirectStdout(unittest.TestCase):
def test_cm_is_reusable(self):
f = io.StringIO()
write_to_f = redirect_stdout(f)
+ orig_stdout = sys.stdout
with write_to_f:
print("Hello", end=" ")
with write_to_f:
print("World!")
+ self.assertIs(sys.stdout, orig_stdout)
s = f.getvalue()
self.assertEqual(s, "Hello World!\n")
- # If this is ever made reentrant, update the reusable-but-not-reentrant
- # example at the end of the contextlib docs accordingly.
- def test_nested_reentry_fails(self):
+ def test_cm_is_reentrant(self):
f = io.StringIO()
write_to_f = redirect_stdout(f)
- with self.assertRaisesRegex(RuntimeError, "Cannot reenter"):
+ orig_stdout = sys.stdout
+ with write_to_f:
+ print("Hello", end=" ")
with write_to_f:
- print("Hello", end=" ")
- with write_to_f:
- print("World!")
+ print("World!")
+ self.assertIs(sys.stdout, orig_stdout)
+ s = f.getvalue()
+ self.assertEqual(s, "Hello World!\n")
class TestSuppress(unittest.TestCase):