summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_functools.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_functools.py')
-rw-r--r--Lib/test/test_functools.py328
1 files changed, 308 insertions, 20 deletions
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index d20bafe..73a77d6 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -1,8 +1,10 @@
import functools
+import sys
import unittest
from test import support
from weakref import proxy
import pickle
+from random import choice
@staticmethod
def PythonPartial(func, *args, **keywords):
@@ -44,9 +46,17 @@ class TestPartial(unittest.TestCase):
# attributes should not be writable
if not isinstance(self.thetype, type):
return
- self.assertRaises(TypeError, setattr, p, 'func', map)
- self.assertRaises(TypeError, setattr, p, 'args', (1, 2))
- self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2))
+ self.assertRaises(AttributeError, setattr, p, 'func', map)
+ self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
+ self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
+
+ p = self.thetype(hex)
+ try:
+ del p.__dict__
+ except TypeError:
+ pass
+ else:
+ self.fail('partial object allowed __dict__ to be deleted')
def test_argument_checking(self):
self.assertRaises(TypeError, self.thetype) # need at least a func arg
@@ -122,15 +132,6 @@ class TestPartial(unittest.TestCase):
self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)
- def test_attributes(self):
- p = self.thetype(hex)
- try:
- del p.__dict__
- except TypeError:
- pass
- else:
- self.fail('partial object allowed __dict__ to be deleted')
-
def test_weakref(self):
f = self.thetype(int, base=16)
p = proxy(f)
@@ -145,6 +146,32 @@ class TestPartial(unittest.TestCase):
join = self.thetype(''.join)
self.assertEqual(join(data), '0123456789')
+ def test_repr(self):
+ args = (object(), object())
+ args_repr = ', '.join(repr(a) for a in args)
+ kwargs = {'a': object(), 'b': object()}
+ kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
+ if self.thetype is functools.partial:
+ name = 'functools.partial'
+ else:
+ name = self.thetype.__name__
+
+ f = self.thetype(capture)
+ self.assertEqual('{}({!r})'.format(name, capture),
+ repr(f))
+
+ f = self.thetype(capture, *args)
+ self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
+ repr(f))
+
+ f = self.thetype(capture, **kwargs)
+ self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
+ repr(f))
+
+ f = self.thetype(capture, *args, **kwargs)
+ self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
+ repr(f))
+
def test_pickle(self):
f = self.thetype(signature, 'asdf', bar=True)
f.add_something_to__dict__ = True
@@ -162,6 +189,9 @@ class TestPythonPartial(TestPartial):
thetype = PythonPartial
+ # the python version hasn't a nice repr
+ def test_repr(self): pass
+
# the python version isn't picklable
def test_pickle(self): pass
@@ -180,7 +210,7 @@ class TestUpdateWrapper(unittest.TestCase):
for key in wrapped_attr:
self.assertTrue(wrapped_attr[key] is wrapper_attr[key])
- def test_default_update(self):
+ def _default_update(self):
def f(a:'This is a new annotation'):
"""This is a test"""
pass
@@ -188,13 +218,23 @@ class TestUpdateWrapper(unittest.TestCase):
def wrapper(b:'This is the prior annotation'):
pass
functools.update_wrapper(wrapper, f)
+ return wrapper, f
+
+ def test_default_update(self):
+ wrapper, f = self._default_update()
self.check_wrapper(wrapper, f)
+ self.assertIs(wrapper.__wrapped__, f)
self.assertEqual(wrapper.__name__, 'f')
- self.assertEqual(wrapper.__doc__, 'This is a test')
self.assertEqual(wrapper.attr, 'This is also a test')
self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
self.assertNotIn('b', wrapper.__annotations__)
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
+ def test_default_update_doc(self):
+ wrapper, f = self._default_update()
+ self.assertEqual(wrapper.__doc__, 'This is a test')
+
def test_no_update(self):
def f():
"""This is a test"""
@@ -226,6 +266,28 @@ class TestUpdateWrapper(unittest.TestCase):
self.assertEqual(wrapper.attr, 'This is a different test')
self.assertEqual(wrapper.dict_attr, f.dict_attr)
+ def test_missing_attributes(self):
+ def f():
+ pass
+ def wrapper():
+ pass
+ wrapper.dict_attr = {}
+ assign = ('attr',)
+ update = ('dict_attr',)
+ # Missing attributes on wrapped object are ignored
+ functools.update_wrapper(wrapper, f, assign, update)
+ self.assertNotIn('attr', wrapper.__dict__)
+ self.assertEqual(wrapper.dict_attr, {})
+ # Wrapper must have expected attributes for updating
+ del wrapper.dict_attr
+ with self.assertRaises(AttributeError):
+ functools.update_wrapper(wrapper, f, assign, update)
+ wrapper.dict_attr = 1
+ with self.assertRaises(AttributeError):
+ functools.update_wrapper(wrapper, f, assign, update)
+
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
def test_builtin_update(self):
# Test for bug #1576241
def wrapper():
@@ -237,7 +299,7 @@ class TestUpdateWrapper(unittest.TestCase):
class TestWraps(TestUpdateWrapper):
- def test_default_update(self):
+ def _default_update(self):
def f():
"""This is a test"""
pass
@@ -246,10 +308,19 @@ class TestWraps(TestUpdateWrapper):
def wrapper():
pass
self.check_wrapper(wrapper, f)
+ return wrapper
+
+ def test_default_update(self):
+ wrapper = self._default_update()
self.assertEqual(wrapper.__name__, 'f')
- self.assertEqual(wrapper.__doc__, 'This is a test')
self.assertEqual(wrapper.attr, 'This is also a test')
+ @unittest.skipIf(not sys.flags.optimize <= 1,
+ "Docstrings are omitted with -O2 and above")
+ def test_default_update_doc(self):
+ wrapper = self._default_update()
+ self.assertEqual(wrapper.__doc__, 'This is a test')
+
def test_no_update(self):
def f():
"""This is a test"""
@@ -363,18 +434,235 @@ class TestReduce(unittest.TestCase):
d = {"one": 1, "two": 2, "three": 3}
self.assertEqual(self.func(add, d), "".join(d.keys()))
-
-
+class TestCmpToKey(unittest.TestCase):
+ def test_cmp_to_key(self):
+ def mycmp(x, y):
+ return y - x
+ self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
+ [4, 3, 2, 1, 0])
+
+ def test_hash(self):
+ def mycmp(x, y):
+ return y - x
+ key = functools.cmp_to_key(mycmp)
+ k = key(10)
+ self.assertRaises(TypeError, hash(k))
+
+class TestTotalOrdering(unittest.TestCase):
+
+ def test_total_ordering_lt(self):
+ @functools.total_ordering
+ class A:
+ def __init__(self, value):
+ self.value = value
+ def __lt__(self, other):
+ return self.value < other.value
+ def __eq__(self, other):
+ return self.value == other.value
+ self.assertTrue(A(1) < A(2))
+ self.assertTrue(A(2) > A(1))
+ self.assertTrue(A(1) <= A(2))
+ self.assertTrue(A(2) >= A(1))
+ self.assertTrue(A(2) <= A(2))
+ self.assertTrue(A(2) >= A(2))
+
+ def test_total_ordering_le(self):
+ @functools.total_ordering
+ class A:
+ def __init__(self, value):
+ self.value = value
+ def __le__(self, other):
+ return self.value <= other.value
+ def __eq__(self, other):
+ return self.value == other.value
+ self.assertTrue(A(1) < A(2))
+ self.assertTrue(A(2) > A(1))
+ self.assertTrue(A(1) <= A(2))
+ self.assertTrue(A(2) >= A(1))
+ self.assertTrue(A(2) <= A(2))
+ self.assertTrue(A(2) >= A(2))
+
+ def test_total_ordering_gt(self):
+ @functools.total_ordering
+ class A:
+ def __init__(self, value):
+ self.value = value
+ def __gt__(self, other):
+ return self.value > other.value
+ def __eq__(self, other):
+ return self.value == other.value
+ self.assertTrue(A(1) < A(2))
+ self.assertTrue(A(2) > A(1))
+ self.assertTrue(A(1) <= A(2))
+ self.assertTrue(A(2) >= A(1))
+ self.assertTrue(A(2) <= A(2))
+ self.assertTrue(A(2) >= A(2))
+
+ def test_total_ordering_ge(self):
+ @functools.total_ordering
+ class A:
+ def __init__(self, value):
+ self.value = value
+ def __ge__(self, other):
+ return self.value >= other.value
+ def __eq__(self, other):
+ return self.value == other.value
+ self.assertTrue(A(1) < A(2))
+ self.assertTrue(A(2) > A(1))
+ self.assertTrue(A(1) <= A(2))
+ self.assertTrue(A(2) >= A(1))
+ self.assertTrue(A(2) <= A(2))
+ self.assertTrue(A(2) >= A(2))
+
+ def test_total_ordering_no_overwrite(self):
+ # new methods should not overwrite existing
+ @functools.total_ordering
+ class A(int):
+ pass
+ self.assertTrue(A(1) < A(2))
+ self.assertTrue(A(2) > A(1))
+ self.assertTrue(A(1) <= A(2))
+ self.assertTrue(A(2) >= A(1))
+ self.assertTrue(A(2) <= A(2))
+ self.assertTrue(A(2) >= A(2))
+
+ def test_no_operations_defined(self):
+ with self.assertRaises(ValueError):
+ @functools.total_ordering
+ class A:
+ pass
+
+ def test_bug_10042(self):
+ @functools.total_ordering
+ class TestTO:
+ def __init__(self, value):
+ self.value = value
+ def __eq__(self, other):
+ if isinstance(other, TestTO):
+ return self.value == other.value
+ return False
+ def __lt__(self, other):
+ if isinstance(other, TestTO):
+ return self.value < other.value
+ raise TypeError
+ with self.assertRaises(TypeError):
+ TestTO(8) <= ()
+
+class TestLRU(unittest.TestCase):
+
+ def test_lru(self):
+ def orig(x, y):
+ return 3*x+y
+ f = functools.lru_cache(maxsize=20)(orig)
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertEqual(maxsize, 20)
+ self.assertEqual(currsize, 0)
+ self.assertEqual(hits, 0)
+ self.assertEqual(misses, 0)
+
+ domain = range(5)
+ for i in range(1000):
+ x, y = choice(domain), choice(domain)
+ actual = f(x, y)
+ expected = orig(x, y)
+ self.assertEqual(actual, expected)
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertTrue(hits > misses)
+ self.assertEqual(hits + misses, 1000)
+ self.assertEqual(currsize, 20)
+
+ f.cache_clear() # test clearing
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertEqual(hits, 0)
+ self.assertEqual(misses, 0)
+ self.assertEqual(currsize, 0)
+ f(x, y)
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertEqual(hits, 0)
+ self.assertEqual(misses, 1)
+ self.assertEqual(currsize, 1)
+
+ # Test bypassing the cache
+ self.assertIs(f.__wrapped__, orig)
+ f.__wrapped__(x, y)
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertEqual(hits, 0)
+ self.assertEqual(misses, 1)
+ self.assertEqual(currsize, 1)
+
+ # test size zero (which means "never-cache")
+ @functools.lru_cache(0)
+ def f():
+ nonlocal f_cnt
+ f_cnt += 1
+ return 20
+ self.assertEqual(f.cache_info().maxsize, 0)
+ f_cnt = 0
+ for i in range(5):
+ self.assertEqual(f(), 20)
+ self.assertEqual(f_cnt, 5)
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertEqual(hits, 0)
+ self.assertEqual(misses, 5)
+ self.assertEqual(currsize, 0)
+
+ # test size one
+ @functools.lru_cache(1)
+ def f():
+ nonlocal f_cnt
+ f_cnt += 1
+ return 20
+ self.assertEqual(f.cache_info().maxsize, 1)
+ f_cnt = 0
+ for i in range(5):
+ self.assertEqual(f(), 20)
+ self.assertEqual(f_cnt, 1)
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertEqual(hits, 4)
+ self.assertEqual(misses, 1)
+ self.assertEqual(currsize, 1)
+
+ # test size two
+ @functools.lru_cache(2)
+ def f(x):
+ nonlocal f_cnt
+ f_cnt += 1
+ return x*10
+ self.assertEqual(f.cache_info().maxsize, 2)
+ f_cnt = 0
+ for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
+ # * * * *
+ self.assertEqual(f(x), x*10)
+ self.assertEqual(f_cnt, 4)
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertEqual(hits, 12)
+ self.assertEqual(misses, 4)
+ self.assertEqual(currsize, 2)
+
+ def test_lru_with_maxsize_none(self):
+ @functools.lru_cache(maxsize=None)
+ def fib(n):
+ if n < 2:
+ return n
+ return fib(n-1) + fib(n-2)
+ self.assertEqual([fib(n) for n in range(16)],
+ [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
+ self.assertEqual(fib.cache_info(),
+ functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
+ fib.cache_clear()
+ self.assertEqual(fib.cache_info(),
+ functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
def test_main(verbose=None):
- import sys
test_classes = (
TestPartial,
TestPartialSubclass,
TestPythonPartial,
TestUpdateWrapper,
+ TestTotalOrdering,
TestWraps,
- TestReduce
+ TestReduce,
+ TestLRU,
)
support.run_unittest(*test_classes)