summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_contextlib.py
diff options
context:
space:
mode:
authorGuido van Rossum <guido@python.org>2006-02-28 21:57:43 (GMT)
committerGuido van Rossum <guido@python.org>2006-02-28 21:57:43 (GMT)
commit1a5e21e0334a6d4e1c756575023c7157fc9ee306 (patch)
treed2c1c9383b3c6d8194449ae756e663b0b0ac9e4e /Lib/test/test_contextlib.py
parent87a8b4fee56b8204ee9f7b0ce2e5db0564e8f86e (diff)
downloadcpython-1a5e21e0334a6d4e1c756575023c7157fc9ee306.zip
cpython-1a5e21e0334a6d4e1c756575023c7157fc9ee306.tar.gz
cpython-1a5e21e0334a6d4e1c756575023c7157fc9ee306.tar.bz2
Updates to the with-statement:
- New semantics for __exit__() -- it must re-raise the exception if type is not None; the with-statement itself doesn't do this. (See the updated PEP for motivation.) - Added context managers to: - file - thread.LockType - threading.{Lock,RLock,Condition,Semaphore,BoundedSemaphore} - decimal.Context - Added contextlib.py, which defines @contextmanager, nested(), closing(). - Unit tests all around; bot no docs yet.
Diffstat (limited to 'Lib/test/test_contextlib.py')
-rw-r--r--Lib/test/test_contextlib.py240
1 files changed, 240 insertions, 0 deletions
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
new file mode 100644
index 0000000..8c8d887
--- /dev/null
+++ b/Lib/test/test_contextlib.py
@@ -0,0 +1,240 @@
+"""Unit tests for contextlib.py, and other context managers."""
+
+from __future__ import with_statement
+
+import os
+import decimal
+import tempfile
+import unittest
+import threading
+from contextlib import * # Tests __all__
+
+class ContextManagerTestCase(unittest.TestCase):
+
+ def test_contextmanager_plain(self):
+ state = []
+ @contextmanager
+ def woohoo():
+ state.append(1)
+ yield 42
+ state.append(999)
+ with woohoo() as x:
+ self.assertEqual(state, [1])
+ self.assertEqual(x, 42)
+ state.append(x)
+ self.assertEqual(state, [1, 42, 999])
+
+ def test_contextmanager_finally(self):
+ state = []
+ @contextmanager
+ def woohoo():
+ state.append(1)
+ try:
+ yield 42
+ finally:
+ state.append(999)
+ try:
+ with woohoo() as x:
+ self.assertEqual(state, [1])
+ self.assertEqual(x, 42)
+ state.append(x)
+ raise ZeroDivisionError()
+ except ZeroDivisionError:
+ pass
+ else:
+ self.fail("Expected ZeroDivisionError")
+ self.assertEqual(state, [1, 42, 999])
+
+ def test_contextmanager_except(self):
+ state = []
+ @contextmanager
+ def woohoo():
+ state.append(1)
+ try:
+ yield 42
+ except ZeroDivisionError, e:
+ state.append(e.args[0])
+ self.assertEqual(state, [1, 42, 999])
+ with woohoo() as x:
+ self.assertEqual(state, [1])
+ self.assertEqual(x, 42)
+ state.append(x)
+ raise ZeroDivisionError(999)
+ self.assertEqual(state, [1, 42, 999])
+
+class NestedTestCase(unittest.TestCase):
+
+ # XXX This needs more work
+
+ def test_nested(self):
+ @contextmanager
+ def a():
+ yield 1
+ @contextmanager
+ def b():
+ yield 2
+ @contextmanager
+ def c():
+ yield 3
+ with nested(a(), b(), c()) as (x, y, z):
+ self.assertEqual(x, 1)
+ self.assertEqual(y, 2)
+ self.assertEqual(z, 3)
+
+ def test_nested_cleanup(self):
+ state = []
+ @contextmanager
+ def a():
+ state.append(1)
+ try:
+ yield 2
+ finally:
+ state.append(3)
+ @contextmanager
+ def b():
+ state.append(4)
+ try:
+ yield 5
+ finally:
+ state.append(6)
+ try:
+ with nested(a(), b()) as (x, y):
+ state.append(x)
+ state.append(y)
+ 1/0
+ except ZeroDivisionError:
+ self.assertEqual(state, [1, 4, 2, 5, 6, 3])
+ else:
+ self.fail("Didn't raise ZeroDivisionError")
+
+class ClosingTestCase(unittest.TestCase):
+
+ # XXX This needs more work
+
+ def test_closing(self):
+ state = []
+ class C:
+ def close(self):
+ state.append(1)
+ x = C()
+ self.assertEqual(state, [])
+ with closing(x) as y:
+ self.assertEqual(x, y)
+ self.assertEqual(state, [1])
+
+ def test_closing_error(self):
+ state = []
+ class C:
+ def close(self):
+ state.append(1)
+ x = C()
+ self.assertEqual(state, [])
+ try:
+ with closing(x) as y:
+ self.assertEqual(x, y)
+ 1/0
+ except ZeroDivisionError:
+ self.assertEqual(state, [1])
+ else:
+ self.fail("Didn't raise ZeroDivisionError")
+
+class FileContextTestCase(unittest.TestCase):
+
+ def testWithOpen(self):
+ tfn = tempfile.mktemp()
+ try:
+ f = None
+ with open(tfn, "w") as f:
+ self.failIf(f.closed)
+ f.write("Booh\n")
+ self.failUnless(f.closed)
+ f = None
+ try:
+ with open(tfn, "r") as f:
+ self.failIf(f.closed)
+ self.assertEqual(f.read(), "Booh\n")
+ 1/0
+ except ZeroDivisionError:
+ self.failUnless(f.closed)
+ else:
+ self.fail("Didn't raise ZeroDivisionError")
+ finally:
+ try:
+ os.remove(tfn)
+ except os.error:
+ pass
+
+class LockContextTestCase(unittest.TestCase):
+
+ def boilerPlate(self, lock, locked):
+ self.failIf(locked())
+ with lock:
+ self.failUnless(locked())
+ self.failIf(locked())
+ try:
+ with lock:
+ self.failUnless(locked())
+ 1/0
+ except ZeroDivisionError:
+ self.failIf(locked())
+ else:
+ self.fail("Didn't raise ZeroDivisionError")
+
+ def testWithLock(self):
+ lock = threading.Lock()
+ self.boilerPlate(lock, lock.locked)
+
+ def testWithRLock(self):
+ lock = threading.RLock()
+ self.boilerPlate(lock, lock._is_owned)
+
+ def testWithCondition(self):
+ lock = threading.Condition()
+ def locked():
+ return lock._is_owned()
+ self.boilerPlate(lock, locked)
+
+ def testWithSemaphore(self):
+ lock = threading.Semaphore()
+ def locked():
+ if lock.acquire(False):
+ lock.release()
+ return False
+ else:
+ return True
+ self.boilerPlate(lock, locked)
+
+ def testWithBoundedSemaphore(self):
+ lock = threading.BoundedSemaphore()
+ def locked():
+ if lock.acquire(False):
+ lock.release()
+ return False
+ else:
+ return True
+ self.boilerPlate(lock, locked)
+
+class DecimalContextTestCase(unittest.TestCase):
+
+ # XXX Somebody should write more thorough tests for this
+
+ def testBasic(self):
+ ctx = decimal.getcontext()
+ ctx.prec = save_prec = decimal.ExtendedContext.prec + 5
+ with decimal.ExtendedContext:
+ self.assertEqual(decimal.getcontext().prec,
+ decimal.ExtendedContext.prec)
+ self.assertEqual(decimal.getcontext().prec, save_prec)
+ try:
+ with decimal.ExtendedContext:
+ self.assertEqual(decimal.getcontext().prec,
+ decimal.ExtendedContext.prec)
+ 1/0
+ except ZeroDivisionError:
+ self.assertEqual(decimal.getcontext().prec, save_prec)
+ else:
+ self.fail("Didn't raise ZeroDivisionError")
+
+
+if __name__ == "__main__":
+ unittest.main()