diff options
author | Guido van Rossum <guido@python.org> | 2006-02-28 21:57:43 (GMT) |
---|---|---|
committer | Guido van Rossum <guido@python.org> | 2006-02-28 21:57:43 (GMT) |
commit | 1a5e21e0334a6d4e1c756575023c7157fc9ee306 (patch) | |
tree | d2c1c9383b3c6d8194449ae756e663b0b0ac9e4e /Lib/test | |
parent | 87a8b4fee56b8204ee9f7b0ce2e5db0564e8f86e (diff) | |
download | cpython-1a5e21e0334a6d4e1c756575023c7157fc9ee306.zip cpython-1a5e21e0334a6d4e1c756575023c7157fc9ee306.tar.gz cpython-1a5e21e0334a6d4e1c756575023c7157fc9ee306.tar.bz2 |
Updates to the with-statement:
- New semantics for __exit__() -- it must re-raise the exception
if type is not None; the with-statement itself doesn't do this.
(See the updated PEP for motivation.)
- Added context managers to:
- file
- thread.LockType
- threading.{Lock,RLock,Condition,Semaphore,BoundedSemaphore}
- decimal.Context
- Added contextlib.py, which defines @contextmanager, nested(), closing().
- Unit tests all around; bot no docs yet.
Diffstat (limited to 'Lib/test')
-rw-r--r-- | Lib/test/contextmanager.py | 33 | ||||
-rw-r--r-- | Lib/test/nested.py | 40 | ||||
-rw-r--r-- | Lib/test/test_contextlib.py | 240 | ||||
-rw-r--r-- | Lib/test/test_with.py | 84 |
4 files changed, 310 insertions, 87 deletions
diff --git a/Lib/test/contextmanager.py b/Lib/test/contextmanager.py deleted file mode 100644 index 07fe61c..0000000 --- a/Lib/test/contextmanager.py +++ /dev/null @@ -1,33 +0,0 @@ -class GeneratorContextManager(object): - def __init__(self, gen): - self.gen = gen - - def __context__(self): - return self - - def __enter__(self): - try: - return self.gen.next() - except StopIteration: - raise RuntimeError("generator didn't yield") - - def __exit__(self, type, value, traceback): - if type is None: - try: - self.gen.next() - except StopIteration: - return - else: - raise RuntimeError("generator didn't stop") - else: - try: - self.gen.throw(type, value, traceback) - except (type, StopIteration): - return - else: - raise RuntimeError("generator caught exception") - -def contextmanager(func): - def helper(*args, **kwds): - return GeneratorContextManager(func(*args, **kwds)) - return helper diff --git a/Lib/test/nested.py b/Lib/test/nested.py deleted file mode 100644 index b5030e0..0000000 --- a/Lib/test/nested.py +++ /dev/null @@ -1,40 +0,0 @@ -import sys -from collections import deque - - -class nested(object): - def __init__(self, *contexts): - self.contexts = contexts - self.entered = None - - def __context__(self): - return self - - def __enter__(self): - if self.entered is not None: - raise RuntimeError("Context is not reentrant") - self.entered = deque() - vars = [] - try: - for context in self.contexts: - mgr = context.__context__() - vars.append(mgr.__enter__()) - self.entered.appendleft(mgr) - except: - self.__exit__(*sys.exc_info()) - raise - return vars - - def __exit__(self, *exc_info): - # Behave like nested with statements - # first in, last out - # New exceptions override old ones - ex = exc_info - for mgr in self.entered: - try: - mgr.__exit__(*ex) - except: - ex = sys.exc_info() - self.entered = None - if ex is not exc_info: - raise ex[0], ex[1], ex[2] diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py new file mode 100644 index 0000000..8c8d887 --- /dev/null +++ b/Lib/test/test_contextlib.py @@ -0,0 +1,240 @@ +"""Unit tests for contextlib.py, and other context managers.""" + +from __future__ import with_statement + +import os +import decimal +import tempfile +import unittest +import threading +from contextlib import * # Tests __all__ + +class ContextManagerTestCase(unittest.TestCase): + + def test_contextmanager_plain(self): + state = [] + @contextmanager + def woohoo(): + state.append(1) + yield 42 + state.append(999) + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + self.assertEqual(state, [1, 42, 999]) + + def test_contextmanager_finally(self): + state = [] + @contextmanager + def woohoo(): + state.append(1) + try: + yield 42 + finally: + state.append(999) + try: + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError() + except ZeroDivisionError: + pass + else: + self.fail("Expected ZeroDivisionError") + self.assertEqual(state, [1, 42, 999]) + + def test_contextmanager_except(self): + state = [] + @contextmanager + def woohoo(): + state.append(1) + try: + yield 42 + except ZeroDivisionError, e: + state.append(e.args[0]) + self.assertEqual(state, [1, 42, 999]) + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError(999) + self.assertEqual(state, [1, 42, 999]) + +class NestedTestCase(unittest.TestCase): + + # XXX This needs more work + + def test_nested(self): + @contextmanager + def a(): + yield 1 + @contextmanager + def b(): + yield 2 + @contextmanager + def c(): + yield 3 + with nested(a(), b(), c()) as (x, y, z): + self.assertEqual(x, 1) + self.assertEqual(y, 2) + self.assertEqual(z, 3) + + def test_nested_cleanup(self): + state = [] + @contextmanager + def a(): + state.append(1) + try: + yield 2 + finally: + state.append(3) + @contextmanager + def b(): + state.append(4) + try: + yield 5 + finally: + state.append(6) + try: + with nested(a(), b()) as (x, y): + state.append(x) + state.append(y) + 1/0 + except ZeroDivisionError: + self.assertEqual(state, [1, 4, 2, 5, 6, 3]) + else: + self.fail("Didn't raise ZeroDivisionError") + +class ClosingTestCase(unittest.TestCase): + + # XXX This needs more work + + def test_closing(self): + state = [] + class C: + def close(self): + state.append(1) + x = C() + self.assertEqual(state, []) + with closing(x) as y: + self.assertEqual(x, y) + self.assertEqual(state, [1]) + + def test_closing_error(self): + state = [] + class C: + def close(self): + state.append(1) + x = C() + self.assertEqual(state, []) + try: + with closing(x) as y: + self.assertEqual(x, y) + 1/0 + except ZeroDivisionError: + self.assertEqual(state, [1]) + else: + self.fail("Didn't raise ZeroDivisionError") + +class FileContextTestCase(unittest.TestCase): + + def testWithOpen(self): + tfn = tempfile.mktemp() + try: + f = None + with open(tfn, "w") as f: + self.failIf(f.closed) + f.write("Booh\n") + self.failUnless(f.closed) + f = None + try: + with open(tfn, "r") as f: + self.failIf(f.closed) + self.assertEqual(f.read(), "Booh\n") + 1/0 + except ZeroDivisionError: + self.failUnless(f.closed) + else: + self.fail("Didn't raise ZeroDivisionError") + finally: + try: + os.remove(tfn) + except os.error: + pass + +class LockContextTestCase(unittest.TestCase): + + def boilerPlate(self, lock, locked): + self.failIf(locked()) + with lock: + self.failUnless(locked()) + self.failIf(locked()) + try: + with lock: + self.failUnless(locked()) + 1/0 + except ZeroDivisionError: + self.failIf(locked()) + else: + self.fail("Didn't raise ZeroDivisionError") + + def testWithLock(self): + lock = threading.Lock() + self.boilerPlate(lock, lock.locked) + + def testWithRLock(self): + lock = threading.RLock() + self.boilerPlate(lock, lock._is_owned) + + def testWithCondition(self): + lock = threading.Condition() + def locked(): + return lock._is_owned() + self.boilerPlate(lock, locked) + + def testWithSemaphore(self): + lock = threading.Semaphore() + def locked(): + if lock.acquire(False): + lock.release() + return False + else: + return True + self.boilerPlate(lock, locked) + + def testWithBoundedSemaphore(self): + lock = threading.BoundedSemaphore() + def locked(): + if lock.acquire(False): + lock.release() + return False + else: + return True + self.boilerPlate(lock, locked) + +class DecimalContextTestCase(unittest.TestCase): + + # XXX Somebody should write more thorough tests for this + + def testBasic(self): + ctx = decimal.getcontext() + ctx.prec = save_prec = decimal.ExtendedContext.prec + 5 + with decimal.ExtendedContext: + self.assertEqual(decimal.getcontext().prec, + decimal.ExtendedContext.prec) + self.assertEqual(decimal.getcontext().prec, save_prec) + try: + with decimal.ExtendedContext: + self.assertEqual(decimal.getcontext().prec, + decimal.ExtendedContext.prec) + 1/0 + except ZeroDivisionError: + self.assertEqual(decimal.getcontext().prec, save_prec) + else: + self.fail("Didn't raise ZeroDivisionError") + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_with.py b/Lib/test/test_with.py index ed072c9..36035e3 100644 --- a/Lib/test/test_with.py +++ b/Lib/test/test_with.py @@ -7,9 +7,10 @@ from __future__ import with_statement __author__ = "Mike Bland" __email__ = "mbland at acm dot org" +import sys import unittest -from test.contextmanager import GeneratorContextManager -from test.nested import nested +from collections import deque +from contextlib import GeneratorContextManager, contextmanager from test.test_support import run_unittest @@ -57,9 +58,48 @@ def mock_contextmanager_generator(): mock.stopped = True -class MockNested(nested): +class Nested(object): + + def __init__(self, *contexts): + self.contexts = contexts + self.entered = None + + def __context__(self): + return self + + def __enter__(self): + if self.entered is not None: + raise RuntimeError("Context is not reentrant") + self.entered = deque() + vars = [] + try: + for context in self.contexts: + mgr = context.__context__() + vars.append(mgr.__enter__()) + self.entered.appendleft(mgr) + except: + self.__exit__(*sys.exc_info()) + raise + return vars + + def __exit__(self, *exc_info): + # Behave like nested with statements + # first in, last out + # New exceptions override old ones + ex = exc_info + for mgr in self.entered: + try: + mgr.__exit__(*ex) + except: + ex = sys.exc_info() + self.entered = None + if ex is not exc_info: + raise ex[0], ex[1], ex[2] + + +class MockNested(Nested): def __init__(self, *contexts): - nested.__init__(self, *contexts) + Nested.__init__(self, *contexts) self.context_called = False self.enter_called = False self.exit_called = False @@ -67,16 +107,16 @@ class MockNested(nested): def __context__(self): self.context_called = True - return nested.__context__(self) + return Nested.__context__(self) def __enter__(self): self.enter_called = True - return nested.__enter__(self) + return Nested.__enter__(self) def __exit__(self, *exc_info): self.exit_called = True self.exit_args = exc_info - return nested.__exit__(self, *exc_info) + return Nested.__exit__(self, *exc_info) class FailureTestCase(unittest.TestCase): @@ -294,7 +334,7 @@ class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin): class NestedNonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin): def testSingleArgInlineGeneratorSyntax(self): - with nested(mock_contextmanager_generator()): + with Nested(mock_contextmanager_generator()): pass def testSingleArgUnbound(self): @@ -310,7 +350,7 @@ class NestedNonexceptionalTestCase(unittest.TestCase, m = mock_contextmanager_generator() # This will bind all the arguments to nested() into a single list # assigned to foo. - with nested(m) as foo: + with Nested(m) as foo: self.assertInWithManagerInvariants(m) self.assertAfterWithManagerInvariantsNoError(m) @@ -318,14 +358,13 @@ class NestedNonexceptionalTestCase(unittest.TestCase, m = mock_contextmanager_generator() # This will bind all the arguments to nested() into a single list # assigned to foo. - # FIXME: what should this do: with nested(m) as (foo,): - with nested(m) as (foo): + with Nested(m) as (foo): self.assertInWithManagerInvariants(m) self.assertAfterWithManagerInvariantsNoError(m) def testSingleArgBoundToMultipleElementTupleError(self): def shouldThrowValueError(): - with nested(mock_contextmanager_generator()) as (foo, bar): + with Nested(mock_contextmanager_generator()) as (foo, bar): pass self.assertRaises(ValueError, shouldThrowValueError) @@ -535,7 +574,9 @@ class AssignmentTargetTestCase(unittest.TestCase): class C: def __context__(self): return self def __enter__(self): return 1, 2, 3 - def __exit__(self, *a): pass + def __exit__(self, t, v, tb): + if t is not None: + raise t, v, tb targets = {1: [0, 1, 2]} with C() as (targets[1][0], targets[1][1], targets[1][2]): self.assertEqual(targets, {1: [1, 2, 3]}) @@ -551,11 +592,26 @@ class AssignmentTargetTestCase(unittest.TestCase): self.assertEqual(blah.three, 3) +class ExitSwallowsExceptionTestCase(unittest.TestCase): + + def testExitSwallowsException(self): + class AfricanOrEuropean: + def __context__(self): return self + def __enter__(self): pass + def __exit__(self, t, v, tb): pass + try: + with AfricanOrEuropean(): + 1/0 + except ZeroDivisionError: + self.fail("ZeroDivisionError should have been swallowed") + + def test_main(): run_unittest(FailureTestCase, NonexceptionalTestCase, NestedNonexceptionalTestCase, ExceptionalTestCase, NonLocalFlowControlTestCase, - AssignmentTargetTestCase) + AssignmentTargetTestCase, + ExitSwallowsExceptionTestCase) if __name__ == '__main__': |