summary refs log tree commit diff stats
path: root/Lib/test/test_functools.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_functools.py')
-rw-r--r--  Lib/test/test_functools.py  292
1 files changed, 247 insertions(+), 45 deletions(-)
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index c0d24d8c..d822b2d 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -1,5 +1,6 @@
import abc
import collections
+import copy
from itertools import permutations
import pickle
from random import choice
@@ -7,6 +8,10 @@ import sys
from test import support
import unittest
from weakref import proxy
+try:
+ import threading
+except ImportError:
+ threading = None
import functools
@@ -133,6 +138,25 @@ class TestPartial:
join = self.partial(''.join)
self.assertEqual(join(data), '0123456789')
+ def test_nested_optimization(self):
+ partial = self.partial
+ inner = partial(signature, 'asdf')
+ nested = partial(inner, bar=True)
+ flat = partial(signature, 'asdf', bar=True)
+ self.assertEqual(signature(nested), signature(flat))
+
+ def test_nested_partial_with_attribute(self):
+ # see issue 25137
+ partial = self.partial
+
+ def foo(bar):
+ return bar
+
+ p = partial(foo, 'first')
+ p2 = partial(p, 'second')
+ p2.new_attr = 'spam'
+ self.assertEqual(p2.new_attr, 'spam')
+
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
@@ -224,6 +248,9 @@ class TestPartialCSubclass(TestPartialC):
if c_functools:
partial = PartialSubclass
+ # partial subclasses are not optimized for nested calls
+ test_nested_optimization = None
+
class TestPartialMethod(unittest.TestCase):
@@ -884,12 +911,30 @@ class TestTotalOrdering(unittest.TestCase):
with self.assertRaises(TypeError):
a <= b
-class TestLRU(unittest.TestCase):
+ def test_pickle(self):
+ for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
+ for name in '__lt__', '__gt__', '__le__', '__ge__':
+ with self.subTest(method=name, proto=proto):
+ method = getattr(Orderable_LT, name)
+ method_copy = pickle.loads(pickle.dumps(method, proto))
+ self.assertIs(method_copy, method)
+
+@functools.total_ordering
+class Orderable_LT:
+ def __init__(self, value):
+ self.value = value
+ def __lt__(self, other):
+ return self.value < other.value
+ def __eq__(self, other):
+ return self.value == other.value
+
+
+class TestLRU:
def test_lru(self):
def orig(x, y):
return 3 * x + y
- f = functools.lru_cache(maxsize=20)(orig)
+ f = self.module.lru_cache(maxsize=20)(orig)
hits, misses, maxsize, currsize = f.cache_info()
self.assertEqual(maxsize, 20)
self.assertEqual(currsize, 0)
@@ -927,7 +972,7 @@ class TestLRU(unittest.TestCase):
self.assertEqual(currsize, 1)
# test size zero (which means "never-cache")
- @functools.lru_cache(0)
+ @self.module.lru_cache(0)
def f():
nonlocal f_cnt
f_cnt += 1
@@ -943,7 +988,7 @@ class TestLRU(unittest.TestCase):
self.assertEqual(currsize, 0)
# test size one
- @functools.lru_cache(1)
+ @self.module.lru_cache(1)
def f():
nonlocal f_cnt
f_cnt += 1
@@ -959,7 +1004,7 @@ class TestLRU(unittest.TestCase):
self.assertEqual(currsize, 1)
# test size two
- @functools.lru_cache(2)
+ @self.module.lru_cache(2)
def f(x):
nonlocal f_cnt
f_cnt += 1
@@ -976,7 +1021,7 @@ class TestLRU(unittest.TestCase):
self.assertEqual(currsize, 2)
def test_lru_with_maxsize_none(self):
- @functools.lru_cache(maxsize=None)
+ @self.module.lru_cache(maxsize=None)
def fib(n):
if n < 2:
return n
@@ -984,17 +1029,26 @@ class TestLRU(unittest.TestCase):
self.assertEqual([fib(n) for n in range(16)],
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
self.assertEqual(fib.cache_info(),
- functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
+ self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
fib.cache_clear()
self.assertEqual(fib.cache_info(),
- functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
+ self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
+
+ def test_lru_with_maxsize_negative(self):
+ @self.module.lru_cache(maxsize=-10)
+ def eq(n):
+ return n
+ for i in (0, 1):
+ self.assertEqual([eq(n) for n in range(150)], list(range(150)))
+ self.assertEqual(eq.cache_info(),
+ self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
def test_lru_with_exceptions(self):
# Verify that user_function exceptions get passed through without
# creating a hard-to-read chained exception.
# http://bugs.python.org/issue13177
for maxsize in (None, 128):
- @functools.lru_cache(maxsize)
+ @self.module.lru_cache(maxsize)
def func(i):
return 'abc'[i]
self.assertEqual(func(0), 'a')
@@ -1007,7 +1061,7 @@ class TestLRU(unittest.TestCase):
def test_lru_with_types(self):
for maxsize in (None, 128):
- @functools.lru_cache(maxsize=maxsize, typed=True)
+ @self.module.lru_cache(maxsize=maxsize, typed=True)
def square(x):
return x * x
self.assertEqual(square(3), 9)
@@ -1022,7 +1076,7 @@ class TestLRU(unittest.TestCase):
self.assertEqual(square.cache_info().misses, 4)
def test_lru_with_keyword_args(self):
- @functools.lru_cache()
+ @self.module.lru_cache()
def fib(n):
if n < 2:
return n
@@ -1032,13 +1086,13 @@ class TestLRU(unittest.TestCase):
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
)
self.assertEqual(fib.cache_info(),
- functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
+ self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
fib.cache_clear()
self.assertEqual(fib.cache_info(),
- functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
+ self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
def test_lru_with_keyword_args_maxsize_none(self):
- @functools.lru_cache(maxsize=None)
+ @self.module.lru_cache(maxsize=None)
def fib(n):
if n < 2:
return n
@@ -1046,15 +1100,100 @@ class TestLRU(unittest.TestCase):
self.assertEqual([fib(n=number) for number in range(16)],
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
self.assertEqual(fib.cache_info(),
- functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
+ self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
fib.cache_clear()
self.assertEqual(fib.cache_info(),
- functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
+ self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
+
+ def test_lru_cache_decoration(self):
+ def f(zomg: 'zomg_annotation'):
+ """f doc string"""
+ return 42
+ g = self.module.lru_cache()(f)
+ for attr in self.module.WRAPPER_ASSIGNMENTS:
+ self.assertEqual(getattr(g, attr), getattr(f, attr))
+
+ @unittest.skipUnless(threading, 'This test requires threading.')
+ def test_lru_cache_threaded(self):
+ n, m = 5, 11
+ def orig(x, y):
+ return 3 * x + y
+ f = self.module.lru_cache(maxsize=n*m)(orig)
+ hits, misses, maxsize, currsize = f.cache_info()
+ self.assertEqual(currsize, 0)
+
+ start = threading.Event()
+ def full(k):
+ start.wait(10)
+ for _ in range(m):
+ self.assertEqual(f(k, 0), orig(k, 0))
+
+ def clear():
+ start.wait(10)
+ for _ in range(2*m):
+ f.cache_clear()
+
+ orig_si = sys.getswitchinterval()
+ sys.setswitchinterval(1e-6)
+ try:
+ # create n threads in order to fill cache
+ threads = [threading.Thread(target=full, args=[k])
+ for k in range(n)]
+ with support.start_threads(threads):
+ start.set()
+
+ hits, misses, maxsize, currsize = f.cache_info()
+ if self.module is py_functools:
+ # XXX: Why can be not equal?
+ self.assertLessEqual(misses, n)
+ self.assertLessEqual(hits, m*n - misses)
+ else:
+ self.assertEqual(misses, n)
+ self.assertEqual(hits, m*n - misses)
+ self.assertEqual(currsize, n)
+
+ # create n threads in order to fill cache and 1 to clear it
+ threads = [threading.Thread(target=clear)]
+ threads += [threading.Thread(target=full, args=[k])
+ for k in range(n)]
+ start.clear()
+ with support.start_threads(threads):
+ start.set()
+ finally:
+ sys.setswitchinterval(orig_si)
+
+ @unittest.skipUnless(threading, 'This test requires threading.')
+ def test_lru_cache_threaded2(self):
+ # Simultaneous call with the same arguments
+ n, m = 5, 7
+ start = threading.Barrier(n+1)
+ pause = threading.Barrier(n+1)
+ stop = threading.Barrier(n+1)
+ @self.module.lru_cache(maxsize=m*n)
+ def f(x):
+ pause.wait(10)
+ return 3 * x
+ self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
+ def test():
+ for i in range(m):
+ start.wait(10)
+ self.assertEqual(f(i), 3 * i)
+ stop.wait(10)
+ threads = [threading.Thread(target=test) for k in range(n)]
+ with support.start_threads(threads):
+ for i in range(m):
+ start.wait(10)
+ stop.reset()
+ pause.wait(10)
+ start.reset()
+ stop.wait(10)
+ pause.reset()
+ self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
def test_need_for_rlock(self):
# This will deadlock on an LRU cache that uses a regular lock
- @functools.lru_cache(maxsize=10)
+ @self.module.lru_cache(maxsize=10)
def test_func(x):
'Used to demonstrate a reentrant lru_cache call within a single thread'
return x
@@ -1082,6 +1221,96 @@ class TestLRU(unittest.TestCase):
def f():
pass
+ def test_lru_method(self):
+ class X(int):
+ f_cnt = 0
+ @self.module.lru_cache(2)
+ def f(self, x):
+ self.f_cnt += 1
+ return x*10+self
+ a = X(5)
+ b = X(5)
+ c = X(7)
+ self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
+
+ for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
+ self.assertEqual(a.f(x), x*10 + 5)
+ self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
+ self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
+
+ for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
+ self.assertEqual(b.f(x), x*10 + 5)
+ self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
+ self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
+
+ for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
+ self.assertEqual(c.f(x), x*10 + 7)
+ self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
+ self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
+
+ self.assertEqual(a.f.cache_info(), X.f.cache_info())
+ self.assertEqual(b.f.cache_info(), X.f.cache_info())
+ self.assertEqual(c.f.cache_info(), X.f.cache_info())
+
+ def test_pickle(self):
+ cls = self.__class__
+ for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(proto=proto, func=f):
+ f_copy = pickle.loads(pickle.dumps(f, proto))
+ self.assertIs(f_copy, f)
+
+ def test_copy(self):
+ cls = self.__class__
+ for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
+ with self.subTest(func=f):
+ f_copy = copy.copy(f)
+ self.assertIs(f_copy, f)
+
+ def test_deepcopy(self):
+ cls = self.__class__
+ for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
+ with self.subTest(func=f):
+ f_copy = copy.deepcopy(f)
+ self.assertIs(f_copy, f)
+
+
+@py_functools.lru_cache()
+def py_cached_func(x, y):
+ return 3 * x + y
+
+@c_functools.lru_cache()
+def c_cached_func(x, y):
+ return 3 * x + y
+
+
+class TestLRUPy(TestLRU, unittest.TestCase):
+ module = py_functools
+ cached_func = py_cached_func,
+
+ @module.lru_cache()
+ def cached_meth(self, x, y):
+ return 3 * x + y
+
+ @staticmethod
+ @module.lru_cache()
+ def cached_staticmeth(x, y):
+ return 3 * x + y
+
+
+class TestLRUC(TestLRU, unittest.TestCase):
+ module = c_functools
+ cached_func = c_cached_func,
+
+ @module.lru_cache()
+ def cached_meth(self, x, y):
+ return 3 * x + y
+
+ @staticmethod
+ @module.lru_cache()
+ def cached_staticmeth(x, y):
+ return 3 * x + y
+
class TestSingleDispatch(unittest.TestCase):
def test_simple_overloads(self):
@@ -1576,32 +1805,5 @@ class TestSingleDispatch(unittest.TestCase):
functools.WeakKeyDictionary = _orig_wkd
-def test_main(verbose=None):
- test_classes = (
- TestPartialC,
- TestPartialPy,
- TestPartialCSubclass,
- TestPartialMethod,
- TestUpdateWrapper,
- TestTotalOrdering,
- TestCmpToKeyC,
- TestCmpToKeyPy,
- TestWraps,
- TestReduce,
- TestLRU,
- TestSingleDispatch,
- )
- support.run_unittest(*test_classes)
-
- # verify reference counting
- if verbose and hasattr(sys, "gettotalrefcount"):
- import gc
- counts = [None] * 5
- for i in range(len(counts)):
- support.run_unittest(*test_classes)
- gc.collect()
- counts[i] = sys.gettotalrefcount()
- print(counts)
-
if __name__ == '__main__':
- test_main(verbose=True)
+ unittest.main()