summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_functools.py
diff options
context:
space:
mode:
authorNick Coghlan <ncoghlan@gmail.com>2006-06-08 13:54:49 (GMT)
committerNick Coghlan <ncoghlan@gmail.com>2006-06-08 13:54:49 (GMT)
commit676725db928dae9f963b3a8389c0a9124e1eacbd (patch)
tree69e56c6aa31a3998aea76f7aa3aa4ae76f803d3a /Lib/test/test_functools.py
parent98251f8a2f169d5fd1b6ae0fc9c020d00ec74df5 (diff)
downloadcpython-676725db928dae9f963b3a8389c0a9124e1eacbd.zip
cpython-676725db928dae9f963b3a8389c0a9124e1eacbd.tar.gz
cpython-676725db928dae9f963b3a8389c0a9124e1eacbd.tar.bz2
Add functools.update_wrapper() and functools.wraps() as described in PEP 356
Diffstat (limited to 'Lib/test/test_functools.py')
-rw-r--r--Lib/test/test_functools.py109
1 files changed, 109 insertions, 0 deletions
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index 609e8f4..8dc185b 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -152,6 +152,113 @@ class TestPythonPartial(TestPartial):
thetype = PythonPartial
+class TestUpdateWrapper(unittest.TestCase):
+
+ def check_wrapper(self, wrapper, wrapped,
+ assigned=functools.WRAPPER_ASSIGNMENTS,
+ updated=functools.WRAPPER_UPDATES):
+ # Check attributes were assigned
+ for name in assigned:
+ self.failUnless(getattr(wrapper, name) is getattr(wrapped, name))
+ # Check attributes were updated
+ for name in updated:
+ wrapper_attr = getattr(wrapper, name)
+ wrapped_attr = getattr(wrapped, name)
+ for key in wrapped_attr:
+ self.failUnless(wrapped_attr[key] is wrapper_attr[key])
+
+ def test_default_update(self):
+ def f():
+ """This is a test"""
+ pass
+ f.attr = 'This is also a test'
+ def wrapper():
+ pass
+ functools.update_wrapper(wrapper, f)
+ self.check_wrapper(wrapper, f)
+ self.assertEqual(wrapper.__name__, 'f')
+ self.assertEqual(wrapper.__doc__, 'This is a test')
+ self.assertEqual(wrapper.attr, 'This is also a test')
+
+ def test_no_update(self):
+ def f():
+ """This is a test"""
+ pass
+ f.attr = 'This is also a test'
+ def wrapper():
+ pass
+ functools.update_wrapper(wrapper, f, (), ())
+ self.check_wrapper(wrapper, f, (), ())
+ self.assertEqual(wrapper.__name__, 'wrapper')
+ self.assertEqual(wrapper.__doc__, None)
+ self.failIf(hasattr(wrapper, 'attr'))
+
+ def test_selective_update(self):
+ def f():
+ pass
+ f.attr = 'This is a different test'
+ f.dict_attr = dict(a=1, b=2, c=3)
+ def wrapper():
+ pass
+ wrapper.dict_attr = {}
+ assign = ('attr',)
+ update = ('dict_attr',)
+ functools.update_wrapper(wrapper, f, assign, update)
+ self.check_wrapper(wrapper, f, assign, update)
+ self.assertEqual(wrapper.__name__, 'wrapper')
+ self.assertEqual(wrapper.__doc__, None)
+ self.assertEqual(wrapper.attr, 'This is a different test')
+ self.assertEqual(wrapper.dict_attr, f.dict_attr)
+
+
+class TestWraps(TestUpdateWrapper):
+
+ def test_default_update(self):
+ def f():
+ """This is a test"""
+ pass
+ f.attr = 'This is also a test'
+ @functools.wraps(f)
+ def wrapper():
+ pass
+ self.check_wrapper(wrapper, f)
+ self.assertEqual(wrapper.__name__, 'f')
+ self.assertEqual(wrapper.__doc__, 'This is a test')
+ self.assertEqual(wrapper.attr, 'This is also a test')
+
+ def test_no_update(self):
+ def f():
+ """This is a test"""
+ pass
+ f.attr = 'This is also a test'
+ @functools.wraps(f, (), ())
+ def wrapper():
+ pass
+ self.check_wrapper(wrapper, f, (), ())
+ self.assertEqual(wrapper.__name__, 'wrapper')
+ self.assertEqual(wrapper.__doc__, None)
+ self.failIf(hasattr(wrapper, 'attr'))
+
+ def test_selective_update(self):
+ def f():
+ pass
+ f.attr = 'This is a different test'
+ f.dict_attr = dict(a=1, b=2, c=3)
+ def add_dict_attr(f):
+ f.dict_attr = {}
+ return f
+ assign = ('attr',)
+ update = ('dict_attr',)
+ @functools.wraps(f, assign, update)
+ @add_dict_attr
+ def wrapper():
+ pass
+ self.check_wrapper(wrapper, f, assign, update)
+ self.assertEqual(wrapper.__name__, 'wrapper')
+ self.assertEqual(wrapper.__doc__, None)
+ self.assertEqual(wrapper.attr, 'This is a different test')
+ self.assertEqual(wrapper.dict_attr, f.dict_attr)
+
def test_main(verbose=None):
@@ -160,6 +267,8 @@ def test_main(verbose=None):
TestPartial,
TestPartialSubclass,
TestPythonPartial,
+ TestUpdateWrapper,
+ TestWraps
)
test_support.run_unittest(*test_classes)