diff options
Diffstat (limited to 'Lib/test/test_functools.py')
| -rw-r--r-- | Lib/test/test_functools.py | 328 |
1 files changed, 308 insertions, 20 deletions
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index d20bafe..73a77d6 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -1,8 +1,10 @@ import functools +import sys import unittest from test import support from weakref import proxy import pickle +from random import choice @staticmethod def PythonPartial(func, *args, **keywords): @@ -44,9 +46,17 @@ class TestPartial(unittest.TestCase): # attributes should not be writable if not isinstance(self.thetype, type): return - self.assertRaises(TypeError, setattr, p, 'func', map) - self.assertRaises(TypeError, setattr, p, 'args', (1, 2)) - self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2)) + self.assertRaises(AttributeError, setattr, p, 'func', map) + self.assertRaises(AttributeError, setattr, p, 'args', (1, 2)) + self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2)) + + p = self.thetype(hex) + try: + del p.__dict__ + except TypeError: + pass + else: + self.fail('partial object allowed __dict__ to be deleted') def test_argument_checking(self): self.assertRaises(TypeError, self.thetype) # need at least a func arg @@ -122,15 +132,6 @@ class TestPartial(unittest.TestCase): self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0) self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1) - def test_attributes(self): - p = self.thetype(hex) - try: - del p.__dict__ - except TypeError: - pass - else: - self.fail('partial object allowed __dict__ to be deleted') - def test_weakref(self): f = self.thetype(int, base=16) p = proxy(f) @@ -145,6 +146,32 @@ class TestPartial(unittest.TestCase): join = self.thetype(''.join) self.assertEqual(join(data), '0123456789') + def test_repr(self): + args = (object(), object()) + args_repr = ', '.join(repr(a) for a in args) + kwargs = {'a': object(), 'b': object()} + kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items()) + if self.thetype is functools.partial: + name = 'functools.partial' + else: + name = self.thetype.__name__ + + f = self.thetype(capture) + self.assertEqual('{}({!r})'.format(name, capture), + repr(f)) + + f = self.thetype(capture, *args) + self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr), + repr(f)) + + f = self.thetype(capture, **kwargs) + self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr), + repr(f)) + + f = self.thetype(capture, *args, **kwargs) + self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr), + repr(f)) + def test_pickle(self): f = self.thetype(signature, 'asdf', bar=True) f.add_something_to__dict__ = True @@ -162,6 +189,9 @@ class TestPythonPartial(TestPartial): thetype = PythonPartial + # the python version hasn't a nice repr + def test_repr(self): pass + # the python version isn't picklable def test_pickle(self): pass @@ -180,7 +210,7 @@ class TestUpdateWrapper(unittest.TestCase): for key in wrapped_attr: self.assertTrue(wrapped_attr[key] is wrapper_attr[key]) - def test_default_update(self): + def _default_update(self): def f(a:'This is a new annotation'): """This is a test""" pass @@ -188,13 +218,23 @@ class TestUpdateWrapper(unittest.TestCase): def wrapper(b:'This is the prior annotation'): pass functools.update_wrapper(wrapper, f) + return wrapper, f + + def test_default_update(self): + wrapper, f = self._default_update() self.check_wrapper(wrapper, f) + self.assertIs(wrapper.__wrapped__, f) self.assertEqual(wrapper.__name__, 'f') - self.assertEqual(wrapper.__doc__, 'This is a test') self.assertEqual(wrapper.attr, 'This is also a test') self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation') self.assertNotIn('b', wrapper.__annotations__) + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def test_default_update_doc(self): + wrapper, f = self._default_update() + self.assertEqual(wrapper.__doc__, 'This is a test') + def test_no_update(self): def f(): """This is a test""" @@ -226,6 +266,28 @@ class TestUpdateWrapper(unittest.TestCase): self.assertEqual(wrapper.attr, 'This is a different test') self.assertEqual(wrapper.dict_attr, f.dict_attr) + def test_missing_attributes(self): + def f(): + pass + def wrapper(): + pass + wrapper.dict_attr = {} + assign = ('attr',) + update = ('dict_attr',) + # Missing attributes on wrapped object are ignored + functools.update_wrapper(wrapper, f, assign, update) + self.assertNotIn('attr', wrapper.__dict__) + self.assertEqual(wrapper.dict_attr, {}) + # Wrapper must have expected attributes for updating + del wrapper.dict_attr + with self.assertRaises(AttributeError): + functools.update_wrapper(wrapper, f, assign, update) + wrapper.dict_attr = 1 + with self.assertRaises(AttributeError): + functools.update_wrapper(wrapper, f, assign, update) + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") def test_builtin_update(self): # Test for bug #1576241 def wrapper(): @@ -237,7 +299,7 @@ class TestUpdateWrapper(unittest.TestCase): class TestWraps(TestUpdateWrapper): - def test_default_update(self): + def _default_update(self): def f(): """This is a test""" pass @@ -246,10 +308,19 @@ class TestWraps(TestUpdateWrapper): def wrapper(): pass self.check_wrapper(wrapper, f) + return wrapper + + def test_default_update(self): + wrapper = self._default_update() self.assertEqual(wrapper.__name__, 'f') - self.assertEqual(wrapper.__doc__, 'This is a test') self.assertEqual(wrapper.attr, 'This is also a test') + @unittest.skipIf(not sys.flags.optimize <= 1, + "Docstrings are omitted with -O2 and above") + def test_default_update_doc(self): + wrapper = self._default_update() + self.assertEqual(wrapper.__doc__, 'This is a test') + def test_no_update(self): def f(): """This is a test""" @@ -363,18 +434,235 @@ class TestReduce(unittest.TestCase): d = {"one": 1, "two": 2, "three": 3} self.assertEqual(self.func(add, d), "".join(d.keys())) - - +class TestCmpToKey(unittest.TestCase): + def test_cmp_to_key(self): + def mycmp(x, y): + return y - x + self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)), + [4, 3, 2, 1, 0]) + + def test_hash(self): + def mycmp(x, y): + return y - x + key = functools.cmp_to_key(mycmp) + k = key(10) + self.assertRaises(TypeError, hash(k)) + +class TestTotalOrdering(unittest.TestCase): + + def test_total_ordering_lt(self): + @functools.total_ordering + class A: + def __init__(self, value): + self.value = value + def __lt__(self, other): + return self.value < other.value + def __eq__(self, other): + return self.value == other.value + self.assertTrue(A(1) < A(2)) + self.assertTrue(A(2) > A(1)) + self.assertTrue(A(1) <= A(2)) + self.assertTrue(A(2) >= A(1)) + self.assertTrue(A(2) <= A(2)) + self.assertTrue(A(2) >= A(2)) + + def test_total_ordering_le(self): + @functools.total_ordering + class A: + def __init__(self, value): + self.value = value + def __le__(self, other): + return self.value <= other.value + def __eq__(self, other): + return self.value == other.value + self.assertTrue(A(1) < A(2)) + self.assertTrue(A(2) > A(1)) + self.assertTrue(A(1) <= A(2)) + self.assertTrue(A(2) >= A(1)) + self.assertTrue(A(2) <= A(2)) + self.assertTrue(A(2) >= A(2)) + + def test_total_ordering_gt(self): + @functools.total_ordering + class A: + def __init__(self, value): + self.value = value + def __gt__(self, other): + return self.value > other.value + def __eq__(self, other): + return self.value == other.value + self.assertTrue(A(1) < A(2)) + self.assertTrue(A(2) > A(1)) + self.assertTrue(A(1) <= A(2)) + self.assertTrue(A(2) >= A(1)) + self.assertTrue(A(2) <= A(2)) + self.assertTrue(A(2) >= A(2)) + + def test_total_ordering_ge(self): + @functools.total_ordering + class A: + def __init__(self, value): + self.value = value + def __ge__(self, other): + return self.value >= other.value + def __eq__(self, other): + return self.value == other.value + self.assertTrue(A(1) < A(2)) + self.assertTrue(A(2) > A(1)) + self.assertTrue(A(1) <= A(2)) + self.assertTrue(A(2) >= A(1)) + self.assertTrue(A(2) <= A(2)) + self.assertTrue(A(2) >= A(2)) + + def test_total_ordering_no_overwrite(self): + # new methods should not overwrite existing + @functools.total_ordering + class A(int): + pass + self.assertTrue(A(1) < A(2)) + self.assertTrue(A(2) > A(1)) + self.assertTrue(A(1) <= A(2)) + self.assertTrue(A(2) >= A(1)) + self.assertTrue(A(2) <= A(2)) + self.assertTrue(A(2) >= A(2)) + + def test_no_operations_defined(self): + with self.assertRaises(ValueError): + @functools.total_ordering + class A: + pass + + def test_bug_10042(self): + @functools.total_ordering + class TestTO: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, TestTO): + return self.value == other.value + return False + def __lt__(self, other): + if isinstance(other, TestTO): + return self.value < other.value + raise TypeError + with self.assertRaises(TypeError): + TestTO(8) <= () + +class TestLRU(unittest.TestCase): + + def test_lru(self): + def orig(x, y): + return 3*x+y + f = functools.lru_cache(maxsize=20)(orig) + hits, misses, maxsize, currsize = f.cache_info() + self.assertEqual(maxsize, 20) + self.assertEqual(currsize, 0) + self.assertEqual(hits, 0) + self.assertEqual(misses, 0) + + domain = range(5) + for i in range(1000): + x, y = choice(domain), choice(domain) + actual = f(x, y) + expected = orig(x, y) + self.assertEqual(actual, expected) + hits, misses, maxsize, currsize = f.cache_info() + self.assertTrue(hits > misses) + self.assertEqual(hits + misses, 1000) + self.assertEqual(currsize, 20) + + f.cache_clear() # test clearing + hits, misses, maxsize, currsize = f.cache_info() + self.assertEqual(hits, 0) + self.assertEqual(misses, 0) + self.assertEqual(currsize, 0) + f(x, y) + hits, misses, maxsize, currsize = f.cache_info() + self.assertEqual(hits, 0) + self.assertEqual(misses, 1) + self.assertEqual(currsize, 1) + + # Test bypassing the cache + self.assertIs(f.__wrapped__, orig) + f.__wrapped__(x, y) + hits, misses, maxsize, currsize = f.cache_info() + self.assertEqual(hits, 0) + self.assertEqual(misses, 1) + self.assertEqual(currsize, 1) + + # test size zero (which means "never-cache") + @functools.lru_cache(0) + def f(): + nonlocal f_cnt + f_cnt += 1 + return 20 + self.assertEqual(f.cache_info().maxsize, 0) + f_cnt = 0 + for i in range(5): + self.assertEqual(f(), 20) + self.assertEqual(f_cnt, 5) + hits, misses, maxsize, currsize = f.cache_info() + self.assertEqual(hits, 0) + self.assertEqual(misses, 5) + self.assertEqual(currsize, 0) + + # test size one + @functools.lru_cache(1) + def f(): + nonlocal f_cnt + f_cnt += 1 + return 20 + self.assertEqual(f.cache_info().maxsize, 1) + f_cnt = 0 + for i in range(5): + self.assertEqual(f(), 20) + self.assertEqual(f_cnt, 1) + hits, misses, maxsize, currsize = f.cache_info() + self.assertEqual(hits, 4) + self.assertEqual(misses, 1) + self.assertEqual(currsize, 1) + + # test size two + @functools.lru_cache(2) + def f(x): + nonlocal f_cnt + f_cnt += 1 + return x*10 + self.assertEqual(f.cache_info().maxsize, 2) + f_cnt = 0 + for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7: + # * * * * + self.assertEqual(f(x), x*10) + self.assertEqual(f_cnt, 4) + hits, misses, maxsize, currsize = f.cache_info() + self.assertEqual(hits, 12) + self.assertEqual(misses, 4) + self.assertEqual(currsize, 2) + + def test_lru_with_maxsize_none(self): + @functools.lru_cache(maxsize=None) + def fib(n): + if n < 2: + return n + return fib(n-1) + fib(n-2) + self.assertEqual([fib(n) for n in range(16)], + [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) + self.assertEqual(fib.cache_info(), + functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) + fib.cache_clear() + self.assertEqual(fib.cache_info(), + functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) def test_main(verbose=None): - import sys test_classes = ( TestPartial, TestPartialSubclass, TestPythonPartial, TestUpdateWrapper, + TestTotalOrdering, TestWraps, - TestReduce + TestReduce, + TestLRU, ) support.run_unittest(*test_classes) |
