summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGuido van Rossum <guido@python.org>2006-03-01 17:10:01 (GMT)
committerGuido van Rossum <guido@python.org>2006-03-01 17:10:01 (GMT)
commita9f068726fb4cf3693bd70b4b98bd0deaba45443 (patch)
tree1f34c66af4697944ee39965dd03813fb3b28a4ce
parent6db0e00d571781806cb850088365730fa64e80a6 (diff)
downloadcpython-a9f068726fb4cf3693bd70b4b98bd0deaba45443.zip
cpython-a9f068726fb4cf3693bd70b4b98bd0deaba45443.tar.gz
cpython-a9f068726fb4cf3693bd70b4b98bd0deaba45443.tar.bz2
Fix a bug in nested() - if one of the sub-context-managers swallows the
exception, it should not be propagated up. With unit tests.
-rw-r--r--Lib/contextlib.py5
-rw-r--r--Lib/test/test_contextlib.py54
2 files changed, 58 insertions, 1 deletions
diff --git a/Lib/contextlib.py b/Lib/contextlib.py
index 33d83a6..33c302d 100644
--- a/Lib/contextlib.py
+++ b/Lib/contextlib.py
@@ -91,7 +91,6 @@ def nested(*contexts):
"""
exits = []
vars = []
- exc = (None, None, None)
try:
try:
for context in contexts:
@@ -103,6 +102,8 @@ def nested(*contexts):
yield vars
except:
exc = sys.exc_info()
+ else:
+ exc = (None, None, None)
finally:
while exits:
exit = exits.pop()
@@ -110,6 +111,8 @@ def nested(*contexts):
exit(*exc)
except:
exc = sys.exc_info()
+ else:
+ exc = (None, None, None)
if exc != (None, None, None):
raise
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
index 8c8d887..f8db88c 100644
--- a/Lib/test/test_contextlib.py
+++ b/Lib/test/test_contextlib.py
@@ -107,6 +107,60 @@ class NestedTestCase(unittest.TestCase):
else:
self.fail("Didn't raise ZeroDivisionError")
+ def test_nested_b_swallows(self):
+ @contextmanager
+ def a():
+ yield
+ @contextmanager
+ def b():
+ try:
+ yield
+ except:
+ # Swallow the exception
+ pass
+ try:
+ with nested(a(), b()):
+ 1/0
+ except ZeroDivisionError:
+ self.fail("Didn't swallow ZeroDivisionError")
+
+ def test_nested_break(self):
+ @contextmanager
+ def a():
+ yield
+ state = 0
+ while True:
+ state += 1
+ with nested(a(), a()):
+ break
+ state += 10
+ self.assertEqual(state, 1)
+
+ def test_nested_continue(self):
+ @contextmanager
+ def a():
+ yield
+ state = 0
+ while state < 3:
+ state += 1
+ with nested(a(), a()):
+ continue
+ state += 10
+ self.assertEqual(state, 3)
+
+ def test_nested_return(self):
+ @contextmanager
+ def a():
+ try:
+ yield
+ except:
+ pass
+ def foo():
+ with nested(a(), a()):
+ return 1
+ return 10
+ self.assertEqual(foo(), 1)
+
class ClosingTestCase(unittest.TestCase):
# XXX This needs more work