diff options
author | Guido van Rossum <guido@python.org> | 2006-02-28 21:57:43 (GMT) |
---|---|---|
committer | Guido van Rossum <guido@python.org> | 2006-02-28 21:57:43 (GMT) |
commit | 1a5e21e0334a6d4e1c756575023c7157fc9ee306 (patch) | |
tree | d2c1c9383b3c6d8194449ae756e663b0b0ac9e4e /Lib | |
parent | 87a8b4fee56b8204ee9f7b0ce2e5db0564e8f86e (diff) | |
download | cpython-1a5e21e0334a6d4e1c756575023c7157fc9ee306.zip cpython-1a5e21e0334a6d4e1c756575023c7157fc9ee306.tar.gz cpython-1a5e21e0334a6d4e1c756575023c7157fc9ee306.tar.bz2 |
Updates to the with-statement:
- New semantics for __exit__() -- it must re-raise the exception
if type is not None; the with-statement itself doesn't do this.
(See the updated PEP for motivation.)
- Added context managers to:
- file
- thread.LockType
- threading.{Lock,RLock,Condition,Semaphore,BoundedSemaphore}
- decimal.Context
- Added contextlib.py, which defines @contextmanager, nested(), closing().
- Unit tests all around; bot no docs yet.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/contextlib.py | 138 | ||||
-rw-r--r-- | Lib/decimal.py | 29 | ||||
-rw-r--r-- | Lib/test/contextmanager.py | 33 | ||||
-rw-r--r-- | Lib/test/nested.py | 40 | ||||
-rw-r--r-- | Lib/test/test_contextlib.py | 240 | ||||
-rw-r--r-- | Lib/test/test_with.py | 84 | ||||
-rw-r--r-- | Lib/threading.py | 29 |
7 files changed, 506 insertions, 87 deletions
diff --git a/Lib/contextlib.py b/Lib/contextlib.py new file mode 100644 index 0000000..33d83a6 --- /dev/null +++ b/Lib/contextlib.py @@ -0,0 +1,138 @@ +"""Utilities for with-statement contexts. See PEP 343.""" + +import sys + +__all__ = ["contextmanager", "nested", "closing"] + +class GeneratorContextManager(object): + """Helper for @contextmanager decorator.""" + + def __init__(self, gen): + self.gen = gen + + def __context__(self): + return self + + def __enter__(self): + try: + return self.gen.next() + except StopIteration: + raise RuntimeError("generator didn't yield") + + def __exit__(self, type, value, traceback): + if type is None: + try: + self.gen.next() + except StopIteration: + return + else: + raise RuntimeError("generator didn't stop") + else: + try: + self.gen.throw(type, value, traceback) + except StopIteration: + pass + + +def contextmanager(func): + """@contextmanager decorator. + + Typical usage: + + @contextmanager + def some_generator(<arguments>): + <setup> + try: + yield <value> + finally: + <cleanup> + + This makes this: + + with some_generator(<arguments>) as <variable>: + <body> + + equivalent to this: + + <setup> + try: + <variable> = <value> + <body> + finally: + <cleanup> + + """ + def helper(*args, **kwds): + return GeneratorContextManager(func(*args, **kwds)) + try: + helper.__name__ = func.__name__ + helper.__doc__ = func.__doc__ + except: + pass + return helper + + +@contextmanager +def nested(*contexts): + """Support multiple context managers in a single with-statement. + + Code like this: + + with nested(A, B, C) as (X, Y, Z): + <body> + + is equivalent to this: + + with A as X: + with B as Y: + with C as Z: + <body> + + """ + exits = [] + vars = [] + exc = (None, None, None) + try: + try: + for context in contexts: + mgr = context.__context__() + exit = mgr.__exit__ + enter = mgr.__enter__ + vars.append(enter()) + exits.append(exit) + yield vars + except: + exc = sys.exc_info() + finally: + while exits: + exit = exits.pop() + try: + exit(*exc) + except: + exc = sys.exc_info() + if exc != (None, None, None): + raise + + +@contextmanager +def closing(thing): + """Context manager to automatically close something at the end of a block. + + Code like this: + + with closing(<module>.open(<arguments>)) as f: + <block> + + is equivalent to this: + + f = <module>.open(<arguments>) + try: + <block> + finally: + f.close() + + """ + try: + yield thing + finally: + thing.close() diff --git a/Lib/decimal.py b/Lib/decimal.py index 677d26b..49f8115 100644 --- a/Lib/decimal.py +++ b/Lib/decimal.py @@ -2173,6 +2173,32 @@ for name in rounding_functions: del name, val, globalname, rounding_functions +class ContextManager(object): + """Helper class to simplify Context management. + + Sample usage: + + with decimal.ExtendedContext: + s = ... + return +s # Convert result to normal precision + + with decimal.getcontext() as ctx: + ctx.prec += 2 + s = ... + return +s + + """ + def __init__(self, new_context): + self.new_context = new_context + def __enter__(self): + self.saved_context = getcontext() + setcontext(self.new_context) + return self.new_context + def __exit__(self, t, v, tb): + setcontext(self.saved_context) + if t is not None: + raise t, v, tb + class Context(object): """Contains the context for a Decimal instance. @@ -2224,6 +2250,9 @@ class Context(object): s.append('traps=[' + ', '.join([t.__name__ for t, v in self.traps.items() if v]) + ']') return ', '.join(s) + ')' + def __context__(self): + return ContextManager(self.copy()) + def clear_flags(self): """Reset all flags to zero""" for flag in self.flags: diff --git a/Lib/test/contextmanager.py b/Lib/test/contextmanager.py deleted file mode 100644 index 07fe61c..0000000 --- a/Lib/test/contextmanager.py +++ /dev/null @@ -1,33 +0,0 @@ -class GeneratorContextManager(object): - def __init__(self, gen): - self.gen = gen - - def __context__(self): - return self - - def __enter__(self): - try: - return self.gen.next() - except StopIteration: - raise RuntimeError("generator didn't yield") - - def __exit__(self, type, value, traceback): - if type is None: - try: - self.gen.next() - except StopIteration: - return - else: - raise RuntimeError("generator didn't stop") - else: - try: - self.gen.throw(type, value, traceback) - except (type, StopIteration): - return - else: - raise RuntimeError("generator caught exception") - -def contextmanager(func): - def helper(*args, **kwds): - return GeneratorContextManager(func(*args, **kwds)) - return helper diff --git a/Lib/test/nested.py b/Lib/test/nested.py deleted file mode 100644 index b5030e0..0000000 --- a/Lib/test/nested.py +++ /dev/null @@ -1,40 +0,0 @@ -import sys -from collections import deque - - -class nested(object): - def __init__(self, *contexts): - self.contexts = contexts - self.entered = None - - def __context__(self): - return self - - def __enter__(self): - if self.entered is not None: - raise RuntimeError("Context is not reentrant") - self.entered = deque() - vars = [] - try: - for context in self.contexts: - mgr = context.__context__() - vars.append(mgr.__enter__()) - self.entered.appendleft(mgr) - except: - self.__exit__(*sys.exc_info()) - raise - return vars - - def __exit__(self, *exc_info): - # Behave like nested with statements - # first in, last out - # New exceptions override old ones - ex = exc_info - for mgr in self.entered: - try: - mgr.__exit__(*ex) - except: - ex = sys.exc_info() - self.entered = None - if ex is not exc_info: - raise ex[0], ex[1], ex[2] diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py new file mode 100644 index 0000000..8c8d887 --- /dev/null +++ b/Lib/test/test_contextlib.py @@ -0,0 +1,240 @@ +"""Unit tests for contextlib.py, and other context managers.""" + +from __future__ import with_statement + +import os +import decimal +import tempfile +import unittest +import threading +from contextlib import * # Tests __all__ + +class ContextManagerTestCase(unittest.TestCase): + + def test_contextmanager_plain(self): + state = [] + @contextmanager + def woohoo(): + state.append(1) + yield 42 + state.append(999) + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + self.assertEqual(state, [1, 42, 999]) + + def test_contextmanager_finally(self): + state = [] + @contextmanager + def woohoo(): + state.append(1) + try: + yield 42 + finally: + state.append(999) + try: + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError() + except ZeroDivisionError: + pass + else: + self.fail("Expected ZeroDivisionError") + self.assertEqual(state, [1, 42, 999]) + + def test_contextmanager_except(self): + state = [] + @contextmanager + def woohoo(): + state.append(1) + try: + yield 42 + except ZeroDivisionError, e: + state.append(e.args[0]) + self.assertEqual(state, [1, 42, 999]) + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError(999) + self.assertEqual(state, [1, 42, 999]) + +class NestedTestCase(unittest.TestCase): + + # XXX This needs more work + + def test_nested(self): + @contextmanager + def a(): + yield 1 + @contextmanager + def b(): + yield 2 + @contextmanager + def c(): + yield 3 + with nested(a(), b(), c()) as (x, y, z): + self.assertEqual(x, 1) + self.assertEqual(y, 2) + self.assertEqual(z, 3) + + def test_nested_cleanup(self): + state = [] + @contextmanager + def a(): + state.append(1) + try: + yield 2 + finally: + state.append(3) + @contextmanager + def b(): + state.append(4) + try: + yield 5 + finally: + state.append(6) + try: + with nested(a(), b()) as (x, y): + state.append(x) + state.append(y) + 1/0 + except ZeroDivisionError: + self.assertEqual(state, [1, 4, 2, 5, 6, 3]) + else: + self.fail("Didn't raise ZeroDivisionError") + +class ClosingTestCase(unittest.TestCase): + + # XXX This needs more work + + def test_closing(self): + state = [] + class C: + def close(self): + state.append(1) + x = C() + self.assertEqual(state, []) + with closing(x) as y: + self.assertEqual(x, y) + self.assertEqual(state, [1]) + + def test_closing_error(self): + state = [] + class C: + def close(self): + state.append(1) + x = C() + self.assertEqual(state, []) + try: + with closing(x) as y: + self.assertEqual(x, y) + 1/0 + except ZeroDivisionError: + self.assertEqual(state, [1]) + else: + self.fail("Didn't raise ZeroDivisionError") + +class FileContextTestCase(unittest.TestCase): + + def testWithOpen(self): + tfn = tempfile.mktemp() + try: + f = None + with open(tfn, "w") as f: + self.failIf(f.closed) + f.write("Booh\n") + self.failUnless(f.closed) + f = None + try: + with open(tfn, "r") as f: + self.failIf(f.closed) + self.assertEqual(f.read(), "Booh\n") + 1/0 + except ZeroDivisionError: + self.failUnless(f.closed) + else: + self.fail("Didn't raise ZeroDivisionError") + finally: + try: + os.remove(tfn) + except os.error: + pass + +class LockContextTestCase(unittest.TestCase): + + def boilerPlate(self, lock, locked): + self.failIf(locked()) + with lock: + self.failUnless(locked()) + self.failIf(locked()) + try: + with lock: + self.failUnless(locked()) + 1/0 + except ZeroDivisionError: + self.failIf(locked()) + else: + self.fail("Didn't raise ZeroDivisionError") + + def testWithLock(self): + lock = threading.Lock() + self.boilerPlate(lock, lock.locked) + + def testWithRLock(self): + lock = threading.RLock() + self.boilerPlate(lock, lock._is_owned) + + def testWithCondition(self): + lock = threading.Condition() + def locked(): + return lock._is_owned() + self.boilerPlate(lock, locked) + + def testWithSemaphore(self): + lock = threading.Semaphore() + def locked(): + if lock.acquire(False): + lock.release() + return False + else: + return True + self.boilerPlate(lock, locked) + + def testWithBoundedSemaphore(self): + lock = threading.BoundedSemaphore() + def locked(): + if lock.acquire(False): + lock.release() + return False + else: + return True + self.boilerPlate(lock, locked) + +class DecimalContextTestCase(unittest.TestCase): + + # XXX Somebody should write more thorough tests for this + + def testBasic(self): + ctx = decimal.getcontext() + ctx.prec = save_prec = decimal.ExtendedContext.prec + 5 + with decimal.ExtendedContext: + self.assertEqual(decimal.getcontext().prec, + decimal.ExtendedContext.prec) + self.assertEqual(decimal.getcontext().prec, save_prec) + try: + with decimal.ExtendedContext: + self.assertEqual(decimal.getcontext().prec, + decimal.ExtendedContext.prec) + 1/0 + except ZeroDivisionError: + self.assertEqual(decimal.getcontext().prec, save_prec) + else: + self.fail("Didn't raise ZeroDivisionError") + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_with.py b/Lib/test/test_with.py index ed072c9..36035e3 100644 --- a/Lib/test/test_with.py +++ b/Lib/test/test_with.py @@ -7,9 +7,10 @@ from __future__ import with_statement __author__ = "Mike Bland" __email__ = "mbland at acm dot org" +import sys import unittest -from test.contextmanager import GeneratorContextManager -from test.nested import nested +from collections import deque +from contextlib import GeneratorContextManager, contextmanager from test.test_support import run_unittest @@ -57,9 +58,48 @@ def mock_contextmanager_generator(): mock.stopped = True -class MockNested(nested): +class Nested(object): + + def __init__(self, *contexts): + self.contexts = contexts + self.entered = None + + def __context__(self): + return self + + def __enter__(self): + if self.entered is not None: + raise RuntimeError("Context is not reentrant") + self.entered = deque() + vars = [] + try: + for context in self.contexts: + mgr = context.__context__() + vars.append(mgr.__enter__()) + self.entered.appendleft(mgr) + except: + self.__exit__(*sys.exc_info()) + raise + return vars + + def __exit__(self, *exc_info): + # Behave like nested with statements + # first in, last out + # New exceptions override old ones + ex = exc_info + for mgr in self.entered: + try: + mgr.__exit__(*ex) + except: + ex = sys.exc_info() + self.entered = None + if ex is not exc_info: + raise ex[0], ex[1], ex[2] + + +class MockNested(Nested): def __init__(self, *contexts): - nested.__init__(self, *contexts) + Nested.__init__(self, *contexts) self.context_called = False self.enter_called = False self.exit_called = False @@ -67,16 +107,16 @@ class MockNested(nested): def __context__(self): self.context_called = True - return nested.__context__(self) + return Nested.__context__(self) def __enter__(self): self.enter_called = True - return nested.__enter__(self) + return Nested.__enter__(self) def __exit__(self, *exc_info): self.exit_called = True self.exit_args = exc_info - return nested.__exit__(self, *exc_info) + return Nested.__exit__(self, *exc_info) class FailureTestCase(unittest.TestCase): @@ -294,7 +334,7 @@ class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin): class NestedNonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin): def testSingleArgInlineGeneratorSyntax(self): - with nested(mock_contextmanager_generator()): + with Nested(mock_contextmanager_generator()): pass def testSingleArgUnbound(self): @@ -310,7 +350,7 @@ class NestedNonexceptionalTestCase(unittest.TestCase, m = mock_contextmanager_generator() # This will bind all the arguments to nested() into a single list # assigned to foo. - with nested(m) as foo: + with Nested(m) as foo: self.assertInWithManagerInvariants(m) self.assertAfterWithManagerInvariantsNoError(m) @@ -318,14 +358,13 @@ class NestedNonexceptionalTestCase(unittest.TestCase, m = mock_contextmanager_generator() # This will bind all the arguments to nested() into a single list # assigned to foo. - # FIXME: what should this do: with nested(m) as (foo,): - with nested(m) as (foo): + with Nested(m) as (foo): self.assertInWithManagerInvariants(m) self.assertAfterWithManagerInvariantsNoError(m) def testSingleArgBoundToMultipleElementTupleError(self): def shouldThrowValueError(): - with nested(mock_contextmanager_generator()) as (foo, bar): + with Nested(mock_contextmanager_generator()) as (foo, bar): pass self.assertRaises(ValueError, shouldThrowValueError) @@ -535,7 +574,9 @@ class AssignmentTargetTestCase(unittest.TestCase): class C: def __context__(self): return self def __enter__(self): return 1, 2, 3 - def __exit__(self, *a): pass + def __exit__(self, t, v, tb): + if t is not None: + raise t, v, tb targets = {1: [0, 1, 2]} with C() as (targets[1][0], targets[1][1], targets[1][2]): self.assertEqual(targets, {1: [1, 2, 3]}) @@ -551,11 +592,26 @@ class AssignmentTargetTestCase(unittest.TestCase): self.assertEqual(blah.three, 3) +class ExitSwallowsExceptionTestCase(unittest.TestCase): + + def testExitSwallowsException(self): + class AfricanOrEuropean: + def __context__(self): return self + def __enter__(self): pass + def __exit__(self, t, v, tb): pass + try: + with AfricanOrEuropean(): + 1/0 + except ZeroDivisionError: + self.fail("ZeroDivisionError should have been swallowed") + + def test_main(): run_unittest(FailureTestCase, NonexceptionalTestCase, NestedNonexceptionalTestCase, ExceptionalTestCase, NonLocalFlowControlTestCase, - AssignmentTargetTestCase) + AssignmentTargetTestCase, + ExitSwallowsExceptionTestCase) if __name__ == '__main__': diff --git a/Lib/threading.py b/Lib/threading.py index 9cc108e..5b485d5 100644 --- a/Lib/threading.py +++ b/Lib/threading.py @@ -90,6 +90,9 @@ class _RLock(_Verbose): self.__owner and self.__owner.getName(), self.__count) + def __context__(self): + return self + def acquire(self, blocking=1): me = currentThread() if self.__owner is me: @@ -108,6 +111,8 @@ class _RLock(_Verbose): self._note("%s.acquire(%s): failure", self, blocking) return rc + __enter__ = acquire + def release(self): me = currentThread() assert self.__owner is me, "release() of un-acquire()d lock" @@ -121,6 +126,11 @@ class _RLock(_Verbose): if __debug__: self._note("%s.release(): non-final release", self) + def __exit__(self, t, v, tb): + self.release() + if t is not None: + raise t, v, tb + # Internal methods used by condition variables def _acquire_restore(self, (count, owner)): @@ -156,6 +166,7 @@ class _Condition(_Verbose): self.__lock = lock # Export the lock's acquire() and release() methods self.acquire = lock.acquire + self.__enter__ = self.acquire self.release = lock.release # If the lock defines _release_save() and/or _acquire_restore(), # these override the default implementations (which just call @@ -174,6 +185,14 @@ class _Condition(_Verbose): pass self.__waiters = [] + def __context__(self): + return self + + def __exit__(self, t, v, tb): + self.release() + if t is not None: + raise t, v, tb + def __repr__(self): return "<Condition(%s, %d)>" % (self.__lock, len(self.__waiters)) @@ -267,6 +286,9 @@ class _Semaphore(_Verbose): self.__cond = Condition(Lock()) self.__value = value + def __context__(self): + return self + def acquire(self, blocking=1): rc = False self.__cond.acquire() @@ -286,6 +308,8 @@ class _Semaphore(_Verbose): self.__cond.release() return rc + __enter__ = acquire + def release(self): self.__cond.acquire() self.__value = self.__value + 1 @@ -295,6 +319,11 @@ class _Semaphore(_Verbose): self.__cond.notify() self.__cond.release() + def __exit__(self, t, v, tb): + self.release() + if t is not None: + raise t, v, tb + def BoundedSemaphore(*args, **kwargs): return _BoundedSemaphore(*args, **kwargs) |