summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/contextlib.py126
-rw-r--r--Lib/test/test_contextlib.py123
2 files changed, 244 insertions, 5 deletions
diff --git a/Lib/contextlib.py b/Lib/contextlib.py
index 2f8f00d..ead1155 100644
--- a/Lib/contextlib.py
+++ b/Lib/contextlib.py
@@ -1,9 +1,10 @@
"""Utilities for with-statement contexts. See PEP 343."""
import sys
+from collections import deque
from functools import wraps
-__all__ = ["contextmanager", "closing", "ContextDecorator"]
+__all__ = ["contextmanager", "closing", "ContextDecorator", "ExitStack"]
class ContextDecorator(object):
@@ -12,12 +13,12 @@ class ContextDecorator(object):
def _recreate_cm(self):
"""Return a recreated instance of self.
- Allows otherwise one-shot context managers like
+ Allows an otherwise one-shot context manager like
_GeneratorContextManager to support use as
- decorators via implicit recreation.
+ a decorator via implicit recreation.
- Note: this is a private interface just for _GCM in 3.2 but will be
- renamed and documented for third party use in 3.3
+ This is a private interface just for _GeneratorContextManager.
+ See issue #11647 for details.
"""
return self
@@ -138,3 +139,118 @@ class closing(object):
return self.thing
def __exit__(self, *exc_info):
self.thing.close()
+
+
+# Inspired by discussions on http://bugs.python.org/issue13585
+class ExitStack(object):
+ """Context manager for dynamic management of a stack of exit callbacks
+
+ For example:
+
+ with ExitStack() as stack:
+ files = [stack.enter_context(open(fname)) for fname in filenames]
+ # All opened files will automatically be closed at the end of
+ # the with statement, even if attempts to open files later
+ # in the list throw an exception
+
+ """
+ def __init__(self):
+ self._exit_callbacks = deque()
+
+ def pop_all(self):
+ """Preserve the context stack by transferring it to a new instance"""
+ new_stack = type(self)()
+ new_stack._exit_callbacks = self._exit_callbacks
+ self._exit_callbacks = deque()
+ return new_stack
+
+ def _push_cm_exit(self, cm, cm_exit):
+ """Helper to correctly register callbacks to __exit__ methods"""
+ def _exit_wrapper(*exc_details):
+ return cm_exit(cm, *exc_details)
+ _exit_wrapper.__self__ = cm
+ self.push(_exit_wrapper)
+
+ def push(self, exit):
+ """Registers a callback with the standard __exit__ method signature
+
+ Can suppress exceptions the same way __exit__ methods can.
+
+ Also accepts any object with an __exit__ method (registering a call
+ to the method instead of the object itself)
+ """
+ # We use an unbound method rather than a bound method to follow
+ # the standard lookup behaviour for special methods
+ _cb_type = type(exit)
+ try:
+ exit_method = _cb_type.__exit__
+ except AttributeError:
+ # Not a context manager, so assume its a callable
+ self._exit_callbacks.append(exit)
+ else:
+ self._push_cm_exit(exit, exit_method)
+ return exit # Allow use as a decorator
+
+ def callback(self, callback, *args, **kwds):
+ """Registers an arbitrary callback and arguments.
+
+ Cannot suppress exceptions.
+ """
+ def _exit_wrapper(exc_type, exc, tb):
+ callback(*args, **kwds)
+ # We changed the signature, so using @wraps is not appropriate, but
+ # setting __wrapped__ may still help with introspection
+ _exit_wrapper.__wrapped__ = callback
+ self.push(_exit_wrapper)
+ return callback # Allow use as a decorator
+
+ def enter_context(self, cm):
+ """Enters the supplied context manager
+
+ If successful, also pushes its __exit__ method as a callback and
+ returns the result of the __enter__ method.
+ """
+ # We look up the special methods on the type to match the with statement
+ _cm_type = type(cm)
+ _exit = _cm_type.__exit__
+ result = _cm_type.__enter__(cm)
+ self._push_cm_exit(cm, _exit)
+ return result
+
+ def close(self):
+ """Immediately unwind the context stack"""
+ self.__exit__(None, None, None)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *exc_details):
+ if not self._exit_callbacks:
+ return
+ # This looks complicated, but it is really just
+ # setting up a chain of try-expect statements to ensure
+ # that outer callbacks still get invoked even if an
+ # inner one throws an exception
+ def _invoke_next_callback(exc_details):
+ # Callbacks are removed from the list in FIFO order
+ # but the recursion means they're invoked in LIFO order
+ cb = self._exit_callbacks.popleft()
+ if not self._exit_callbacks:
+ # Innermost callback is invoked directly
+ return cb(*exc_details)
+ # More callbacks left, so descend another level in the stack
+ try:
+ suppress_exc = _invoke_next_callback(exc_details)
+ except:
+ suppress_exc = cb(*sys.exc_info())
+ # Check if this cb suppressed the inner exception
+ if not suppress_exc:
+ raise
+ else:
+ # Check if inner cb suppressed the original exception
+ if suppress_exc:
+ exc_details = (None, None, None)
+ suppress_exc = cb(*exc_details) or suppress_exc
+ return suppress_exc
+ # Kick off the recursive chain
+ return _invoke_next_callback(exc_details)
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
index 6e38305..8bed88e 100644
--- a/Lib/test/test_contextlib.py
+++ b/Lib/test/test_contextlib.py
@@ -370,6 +370,129 @@ class TestContextDecorator(unittest.TestCase):
self.assertEqual(state, [1, 'something else', 999])
+class TestExitStack(unittest.TestCase):
+
+ def test_no_resources(self):
+ with ExitStack():
+ pass
+
+ def test_callback(self):
+ expected = [
+ ((), {}),
+ ((1,), {}),
+ ((1,2), {}),
+ ((), dict(example=1)),
+ ((1,), dict(example=1)),
+ ((1,2), dict(example=1)),
+ ]
+ result = []
+ def _exit(*args, **kwds):
+ """Test metadata propagation"""
+ result.append((args, kwds))
+ with ExitStack() as stack:
+ for args, kwds in reversed(expected):
+ if args and kwds:
+ f = stack.callback(_exit, *args, **kwds)
+ elif args:
+ f = stack.callback(_exit, *args)
+ elif kwds:
+ f = stack.callback(_exit, **kwds)
+ else:
+ f = stack.callback(_exit)
+ self.assertIs(f, _exit)
+ for wrapper in stack._exit_callbacks:
+ self.assertIs(wrapper.__wrapped__, _exit)
+ self.assertNotEqual(wrapper.__name__, _exit.__name__)
+ self.assertIsNone(wrapper.__doc__, _exit.__doc__)
+ self.assertEqual(result, expected)
+
+ def test_push(self):
+ exc_raised = ZeroDivisionError
+ def _expect_exc(exc_type, exc, exc_tb):
+ self.assertIs(exc_type, exc_raised)
+ def _suppress_exc(*exc_details):
+ return True
+ def _expect_ok(exc_type, exc, exc_tb):
+ self.assertIsNone(exc_type)
+ self.assertIsNone(exc)
+ self.assertIsNone(exc_tb)
+ class ExitCM(object):
+ def __init__(self, check_exc):
+ self.check_exc = check_exc
+ def __enter__(self):
+ self.fail("Should not be called!")
+ def __exit__(self, *exc_details):
+ self.check_exc(*exc_details)
+ with ExitStack() as stack:
+ stack.push(_expect_ok)
+ self.assertIs(stack._exit_callbacks[-1], _expect_ok)
+ cm = ExitCM(_expect_ok)
+ stack.push(cm)
+ self.assertIs(stack._exit_callbacks[-1].__self__, cm)
+ stack.push(_suppress_exc)
+ self.assertIs(stack._exit_callbacks[-1], _suppress_exc)
+ cm = ExitCM(_expect_exc)
+ stack.push(cm)
+ self.assertIs(stack._exit_callbacks[-1].__self__, cm)
+ stack.push(_expect_exc)
+ self.assertIs(stack._exit_callbacks[-1], _expect_exc)
+ stack.push(_expect_exc)
+ self.assertIs(stack._exit_callbacks[-1], _expect_exc)
+ 1/0
+
+ def test_enter_context(self):
+ class TestCM(object):
+ def __enter__(self):
+ result.append(1)
+ def __exit__(self, *exc_details):
+ result.append(3)
+
+ result = []
+ cm = TestCM()
+ with ExitStack() as stack:
+ @stack.callback # Registered first => cleaned up last
+ def _exit():
+ result.append(4)
+ self.assertIsNotNone(_exit)
+ stack.enter_context(cm)
+ self.assertIs(stack._exit_callbacks[-1].__self__, cm)
+ result.append(2)
+ self.assertEqual(result, [1, 2, 3, 4])
+
+ def test_close(self):
+ result = []
+ with ExitStack() as stack:
+ @stack.callback
+ def _exit():
+ result.append(1)
+ self.assertIsNotNone(_exit)
+ stack.close()
+ result.append(2)
+ self.assertEqual(result, [1, 2])
+
+ def test_pop_all(self):
+ result = []
+ with ExitStack() as stack:
+ @stack.callback
+ def _exit():
+ result.append(3)
+ self.assertIsNotNone(_exit)
+ new_stack = stack.pop_all()
+ result.append(1)
+ result.append(2)
+ new_stack.close()
+ self.assertEqual(result, [1, 2, 3])
+
+ def test_instance_bypass(self):
+ class Example(object): pass
+ cm = Example()
+ cm.__exit__ = object()
+ stack = ExitStack()
+ self.assertRaises(AttributeError, stack.enter_context, cm)
+ stack.push(cm)
+ self.assertIs(stack._exit_callbacks[-1], cm)
+
+
# This is needed to make the test actually run under regrtest.py!
def test_main():
support.run_unittest(__name__)