diff options
Diffstat (limited to 'Lib/test/test_with.py')
-rw-r--r-- | Lib/test/test_with.py | 84 |
1 files changed, 70 insertions, 14 deletions
diff --git a/Lib/test/test_with.py b/Lib/test/test_with.py index ed072c9..36035e3 100644 --- a/Lib/test/test_with.py +++ b/Lib/test/test_with.py @@ -7,9 +7,10 @@ from __future__ import with_statement __author__ = "Mike Bland" __email__ = "mbland at acm dot org" +import sys import unittest -from test.contextmanager import GeneratorContextManager -from test.nested import nested +from collections import deque +from contextlib import GeneratorContextManager, contextmanager from test.test_support import run_unittest @@ -57,9 +58,48 @@ def mock_contextmanager_generator(): mock.stopped = True -class MockNested(nested): +class Nested(object): + + def __init__(self, *contexts): + self.contexts = contexts + self.entered = None + + def __context__(self): + return self + + def __enter__(self): + if self.entered is not None: + raise RuntimeError("Context is not reentrant") + self.entered = deque() + vars = [] + try: + for context in self.contexts: + mgr = context.__context__() + vars.append(mgr.__enter__()) + self.entered.appendleft(mgr) + except: + self.__exit__(*sys.exc_info()) + raise + return vars + + def __exit__(self, *exc_info): + # Behave like nested with statements + # first in, last out + # New exceptions override old ones + ex = exc_info + for mgr in self.entered: + try: + mgr.__exit__(*ex) + except: + ex = sys.exc_info() + self.entered = None + if ex is not exc_info: + raise ex[0], ex[1], ex[2] + + +class MockNested(Nested): def __init__(self, *contexts): - nested.__init__(self, *contexts) + Nested.__init__(self, *contexts) self.context_called = False self.enter_called = False self.exit_called = False @@ -67,16 +107,16 @@ class MockNested(nested): def __context__(self): self.context_called = True - return nested.__context__(self) + return Nested.__context__(self) def __enter__(self): self.enter_called = True - return nested.__enter__(self) + return Nested.__enter__(self) def __exit__(self, *exc_info): self.exit_called = True self.exit_args = exc_info - return nested.__exit__(self, *exc_info) + return Nested.__exit__(self, *exc_info) class FailureTestCase(unittest.TestCase): @@ -294,7 +334,7 @@ class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin): class NestedNonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin): def testSingleArgInlineGeneratorSyntax(self): - with nested(mock_contextmanager_generator()): + with Nested(mock_contextmanager_generator()): pass def testSingleArgUnbound(self): @@ -310,7 +350,7 @@ class NestedNonexceptionalTestCase(unittest.TestCase, m = mock_contextmanager_generator() # This will bind all the arguments to nested() into a single list # assigned to foo. - with nested(m) as foo: + with Nested(m) as foo: self.assertInWithManagerInvariants(m) self.assertAfterWithManagerInvariantsNoError(m) @@ -318,14 +358,13 @@ class NestedNonexceptionalTestCase(unittest.TestCase, m = mock_contextmanager_generator() # This will bind all the arguments to nested() into a single list # assigned to foo. - # FIXME: what should this do: with nested(m) as (foo,): - with nested(m) as (foo): + with Nested(m) as (foo): self.assertInWithManagerInvariants(m) self.assertAfterWithManagerInvariantsNoError(m) def testSingleArgBoundToMultipleElementTupleError(self): def shouldThrowValueError(): - with nested(mock_contextmanager_generator()) as (foo, bar): + with Nested(mock_contextmanager_generator()) as (foo, bar): pass self.assertRaises(ValueError, shouldThrowValueError) @@ -535,7 +574,9 @@ class AssignmentTargetTestCase(unittest.TestCase): class C: def __context__(self): return self def __enter__(self): return 1, 2, 3 - def __exit__(self, *a): pass + def __exit__(self, t, v, tb): + if t is not None: + raise t, v, tb targets = {1: [0, 1, 2]} with C() as (targets[1][0], targets[1][1], targets[1][2]): self.assertEqual(targets, {1: [1, 2, 3]}) @@ -551,11 +592,26 @@ class AssignmentTargetTestCase(unittest.TestCase): self.assertEqual(blah.three, 3) +class ExitSwallowsExceptionTestCase(unittest.TestCase): + + def testExitSwallowsException(self): + class AfricanOrEuropean: + def __context__(self): return self + def __enter__(self): pass + def __exit__(self, t, v, tb): pass + try: + with AfricanOrEuropean(): + 1/0 + except ZeroDivisionError: + self.fail("ZeroDivisionError should have been swallowed") + + def test_main(): run_unittest(FailureTestCase, NonexceptionalTestCase, NestedNonexceptionalTestCase, ExceptionalTestCase, NonLocalFlowControlTestCase, - AssignmentTargetTestCase) + AssignmentTargetTestCase, + ExitSwallowsExceptionTestCase) if __name__ == '__main__': |