diff options
author | Guido van Rossum <guido@python.org> | 2006-03-01 17:10:01 (GMT) |
---|---|---|
committer | Guido van Rossum <guido@python.org> | 2006-03-01 17:10:01 (GMT) |
commit | a9f068726fb4cf3693bd70b4b98bd0deaba45443 (patch) | |
tree | 1f34c66af4697944ee39965dd03813fb3b28a4ce | |
parent | 6db0e00d571781806cb850088365730fa64e80a6 (diff) | |
download | cpython-a9f068726fb4cf3693bd70b4b98bd0deaba45443.zip cpython-a9f068726fb4cf3693bd70b4b98bd0deaba45443.tar.gz cpython-a9f068726fb4cf3693bd70b4b98bd0deaba45443.tar.bz2 |
Fix a bug in nested() - if one of the sub-context-managers swallows the
exception, it should not be propagated up. With unit tests.
-rw-r--r-- | Lib/contextlib.py | 5 | ||||
-rw-r--r-- | Lib/test/test_contextlib.py | 54 |
2 files changed, 58 insertions, 1 deletions
diff --git a/Lib/contextlib.py b/Lib/contextlib.py index 33d83a6..33c302d 100644 --- a/Lib/contextlib.py +++ b/Lib/contextlib.py @@ -91,7 +91,6 @@ def nested(*contexts): """ exits = [] vars = [] - exc = (None, None, None) try: try: for context in contexts: @@ -103,6 +102,8 @@ def nested(*contexts): yield vars except: exc = sys.exc_info() + else: + exc = (None, None, None) finally: while exits: exit = exits.pop() @@ -110,6 +111,8 @@ def nested(*contexts): exit(*exc) except: exc = sys.exc_info() + else: + exc = (None, None, None) if exc != (None, None, None): raise diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py index 8c8d887..f8db88c 100644 --- a/Lib/test/test_contextlib.py +++ b/Lib/test/test_contextlib.py @@ -107,6 +107,60 @@ class NestedTestCase(unittest.TestCase): else: self.fail("Didn't raise ZeroDivisionError") + def test_nested_b_swallows(self): + @contextmanager + def a(): + yield + @contextmanager + def b(): + try: + yield + except: + # Swallow the exception + pass + try: + with nested(a(), b()): + 1/0 + except ZeroDivisionError: + self.fail("Didn't swallow ZeroDivisionError") + + def test_nested_break(self): + @contextmanager + def a(): + yield + state = 0 + while True: + state += 1 + with nested(a(), a()): + break + state += 10 + self.assertEqual(state, 1) + + def test_nested_continue(self): + @contextmanager + def a(): + yield + state = 0 + while state < 3: + state += 1 + with nested(a(), a()): + continue + state += 10 + self.assertEqual(state, 3) + + def test_nested_return(self): + @contextmanager + def a(): + try: + yield + except: + pass + def foo(): + with nested(a(), a()): + return 1 + return 10 + self.assertEqual(foo(), 1) + class ClosingTestCase(unittest.TestCase): # XXX This needs more work |