diff options
author | Guido van Rossum <guido@python.org> | 2006-02-28 21:57:43 (GMT) |
---|---|---|
committer | Guido van Rossum <guido@python.org> | 2006-02-28 21:57:43 (GMT) |
commit | 1a5e21e0334a6d4e1c756575023c7157fc9ee306 (patch) | |
tree | d2c1c9383b3c6d8194449ae756e663b0b0ac9e4e /Lib/test/test_contextlib.py | |
parent | 87a8b4fee56b8204ee9f7b0ce2e5db0564e8f86e (diff) | |
download | cpython-1a5e21e0334a6d4e1c756575023c7157fc9ee306.zip cpython-1a5e21e0334a6d4e1c756575023c7157fc9ee306.tar.gz cpython-1a5e21e0334a6d4e1c756575023c7157fc9ee306.tar.bz2 |
Updates to the with-statement:
- New semantics for __exit__() -- it must re-raise the exception
if type is not None; the with-statement itself doesn't do this.
(See the updated PEP for motivation.)
- Added context managers to:
- file
- thread.LockType
- threading.{Lock,RLock,Condition,Semaphore,BoundedSemaphore}
- decimal.Context
- Added contextlib.py, which defines @contextmanager, nested(), closing().
- Unit tests all around; bot no docs yet.
Diffstat (limited to 'Lib/test/test_contextlib.py')
-rw-r--r-- | Lib/test/test_contextlib.py | 240 |
1 files changed, 240 insertions, 0 deletions
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py new file mode 100644 index 0000000..8c8d887 --- /dev/null +++ b/Lib/test/test_contextlib.py @@ -0,0 +1,240 @@ +"""Unit tests for contextlib.py, and other context managers.""" + +from __future__ import with_statement + +import os +import decimal +import tempfile +import unittest +import threading +from contextlib import * # Tests __all__ + +class ContextManagerTestCase(unittest.TestCase): + + def test_contextmanager_plain(self): + state = [] + @contextmanager + def woohoo(): + state.append(1) + yield 42 + state.append(999) + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + self.assertEqual(state, [1, 42, 999]) + + def test_contextmanager_finally(self): + state = [] + @contextmanager + def woohoo(): + state.append(1) + try: + yield 42 + finally: + state.append(999) + try: + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError() + except ZeroDivisionError: + pass + else: + self.fail("Expected ZeroDivisionError") + self.assertEqual(state, [1, 42, 999]) + + def test_contextmanager_except(self): + state = [] + @contextmanager + def woohoo(): + state.append(1) + try: + yield 42 + except ZeroDivisionError, e: + state.append(e.args[0]) + self.assertEqual(state, [1, 42, 999]) + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError(999) + self.assertEqual(state, [1, 42, 999]) + +class NestedTestCase(unittest.TestCase): + + # XXX This needs more work + + def test_nested(self): + @contextmanager + def a(): + yield 1 + @contextmanager + def b(): + yield 2 + @contextmanager + def c(): + yield 3 + with nested(a(), b(), c()) as (x, y, z): + self.assertEqual(x, 1) + self.assertEqual(y, 2) + self.assertEqual(z, 3) + + def test_nested_cleanup(self): + state = [] + @contextmanager + def a(): + state.append(1) + try: + yield 2 + finally: + state.append(3) + @contextmanager + def b(): + state.append(4) + try: + yield 5 + finally: + state.append(6) + try: + with nested(a(), b()) as (x, y): + state.append(x) + state.append(y) + 1/0 + except ZeroDivisionError: + self.assertEqual(state, [1, 4, 2, 5, 6, 3]) + else: + self.fail("Didn't raise ZeroDivisionError") + +class ClosingTestCase(unittest.TestCase): + + # XXX This needs more work + + def test_closing(self): + state = [] + class C: + def close(self): + state.append(1) + x = C() + self.assertEqual(state, []) + with closing(x) as y: + self.assertEqual(x, y) + self.assertEqual(state, [1]) + + def test_closing_error(self): + state = [] + class C: + def close(self): + state.append(1) + x = C() + self.assertEqual(state, []) + try: + with closing(x) as y: + self.assertEqual(x, y) + 1/0 + except ZeroDivisionError: + self.assertEqual(state, [1]) + else: + self.fail("Didn't raise ZeroDivisionError") + +class FileContextTestCase(unittest.TestCase): + + def testWithOpen(self): + tfn = tempfile.mktemp() + try: + f = None + with open(tfn, "w") as f: + self.failIf(f.closed) + f.write("Booh\n") + self.failUnless(f.closed) + f = None + try: + with open(tfn, "r") as f: + self.failIf(f.closed) + self.assertEqual(f.read(), "Booh\n") + 1/0 + except ZeroDivisionError: + self.failUnless(f.closed) + else: + self.fail("Didn't raise ZeroDivisionError") + finally: + try: + os.remove(tfn) + except os.error: + pass + +class LockContextTestCase(unittest.TestCase): + + def boilerPlate(self, lock, locked): + self.failIf(locked()) + with lock: + self.failUnless(locked()) + self.failIf(locked()) + try: + with lock: + self.failUnless(locked()) + 1/0 + except ZeroDivisionError: + self.failIf(locked()) + else: + self.fail("Didn't raise ZeroDivisionError") + + def testWithLock(self): + lock = threading.Lock() + self.boilerPlate(lock, lock.locked) + + def testWithRLock(self): + lock = threading.RLock() + self.boilerPlate(lock, lock._is_owned) + + def testWithCondition(self): + lock = threading.Condition() + def locked(): + return lock._is_owned() + self.boilerPlate(lock, locked) + + def testWithSemaphore(self): + lock = threading.Semaphore() + def locked(): + if lock.acquire(False): + lock.release() + return False + else: + return True + self.boilerPlate(lock, locked) + + def testWithBoundedSemaphore(self): + lock = threading.BoundedSemaphore() + def locked(): + if lock.acquire(False): + lock.release() + return False + else: + return True + self.boilerPlate(lock, locked) + +class DecimalContextTestCase(unittest.TestCase): + + # XXX Somebody should write more thorough tests for this + + def testBasic(self): + ctx = decimal.getcontext() + ctx.prec = save_prec = decimal.ExtendedContext.prec + 5 + with decimal.ExtendedContext: + self.assertEqual(decimal.getcontext().prec, + decimal.ExtendedContext.prec) + self.assertEqual(decimal.getcontext().prec, save_prec) + try: + with decimal.ExtendedContext: + self.assertEqual(decimal.getcontext().prec, + decimal.ExtendedContext.prec) + 1/0 + except ZeroDivisionError: + self.assertEqual(decimal.getcontext().prec, save_prec) + else: + self.fail("Didn't raise ZeroDivisionError") + + +if __name__ == "__main__": + unittest.main() |