summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_functools.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_functools.py')
-rw-r--r--Lib/test/test_functools.py694
1 files changed, 614 insertions, 80 deletions
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index db1e934..ab76efb 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -1,57 +1,54 @@
-import functools
import collections
+from itertools import permutations
+import pickle
+from random import choice
import sys
-import unittest
from test import support
+import unittest
from weakref import proxy
-import pickle
-from random import choice
-@staticmethod
-def PythonPartial(func, *args, **keywords):
- 'Pure Python approximation of partial()'
- def newfunc(*fargs, **fkeywords):
- newkeywords = keywords.copy()
- newkeywords.update(fkeywords)
- return func(*(args + fargs), **newkeywords)
- newfunc.func = func
- newfunc.args = args
- newfunc.keywords = keywords
- return newfunc
+import functools
+
+py_functools = support.import_fresh_module('functools', blocked=['_functools'])
+c_functools = support.import_fresh_module('functools', fresh=['_functools'])
+
+decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
+
def capture(*args, **kw):
"""capture all positional and keyword arguments"""
return args, kw
+
def signature(part):
""" return the signature of a partial object """
return (part.func, part.args, part.keywords, part.__dict__)
-class TestPartial(unittest.TestCase):
- thetype = functools.partial
+class TestPartial:
def test_basic_examples(self):
- p = self.thetype(capture, 1, 2, a=10, b=20)
+ p = self.partial(capture, 1, 2, a=10, b=20)
+ self.assertTrue(callable(p))
self.assertEqual(p(3, 4, b=30, c=40),
((1, 2, 3, 4), dict(a=10, b=30, c=40)))
- p = self.thetype(map, lambda x: x*10)
+ p = self.partial(map, lambda x: x*10)
self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])
def test_attributes(self):
- p = self.thetype(capture, 1, 2, a=10, b=20)
+ p = self.partial(capture, 1, 2, a=10, b=20)
# attributes should be readable
self.assertEqual(p.func, capture)
self.assertEqual(p.args, (1, 2))
self.assertEqual(p.keywords, dict(a=10, b=20))
# attributes should not be writable
- if not isinstance(self.thetype, type):
+ if not isinstance(self.partial, type):
return
self.assertRaises(AttributeError, setattr, p, 'func', map)
self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
- p = self.thetype(hex)
+ p = self.partial(hex)
try:
del p.__dict__
except TypeError:
@@ -60,9 +57,9 @@ class TestPartial(unittest.TestCase):
self.fail('partial object allowed __dict__ to be deleted')
def test_argument_checking(self):
- self.assertRaises(TypeError, self.thetype) # need at least a func arg
+ self.assertRaises(TypeError, self.partial) # need at least a func arg
try:
- self.thetype(2)()
+ self.partial(2)()
except TypeError:
pass
else:
@@ -73,7 +70,7 @@ class TestPartial(unittest.TestCase):
def func(a=10, b=20):
return a
d = {'a':3}
- p = self.thetype(func, a=5)
+ p = self.partial(func, a=5)
self.assertEqual(p(**d), 3)
self.assertEqual(d, {'a':3})
p(b=7)
@@ -82,20 +79,20 @@ class TestPartial(unittest.TestCase):
def test_arg_combinations(self):
# exercise special code paths for zero args in either partial
# object or the caller
- p = self.thetype(capture)
+ p = self.partial(capture)
self.assertEqual(p(), ((), {}))
self.assertEqual(p(1,2), ((1,2), {}))
- p = self.thetype(capture, 1, 2)
+ p = self.partial(capture, 1, 2)
self.assertEqual(p(), ((1,2), {}))
self.assertEqual(p(3,4), ((1,2,3,4), {}))
def test_kw_combinations(self):
# exercise special code paths for no keyword args in
# either the partial object or the caller
- p = self.thetype(capture)
+ p = self.partial(capture)
self.assertEqual(p(), ((), {}))
self.assertEqual(p(a=1), ((), {'a':1}))
- p = self.thetype(capture, a=1)
+ p = self.partial(capture, a=1)
self.assertEqual(p(), ((), {'a':1}))
self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
# keyword args in the call override those in the partial object
@@ -104,7 +101,7 @@ class TestPartial(unittest.TestCase):
def test_positional(self):
# make sure positional arguments are captured correctly
for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
- p = self.thetype(capture, *args)
+ p = self.partial(capture, *args)
expected = args + ('x',)
got, empty = p('x')
self.assertTrue(expected == got and empty == {})
@@ -112,14 +109,14 @@ class TestPartial(unittest.TestCase):
def test_keyword(self):
# make sure keyword arguments are captured correctly
for a in ['a', 0, None, 3.5]:
- p = self.thetype(capture, a=a)
+ p = self.partial(capture, a=a)
expected = {'a':a,'x':None}
empty, got = p(x=None)
self.assertTrue(expected == got and empty == ())
def test_no_side_effects(self):
# make sure there are no side effects that affect subsequent calls
- p = self.thetype(capture, 0, a=1)
+ p = self.partial(capture, 0, a=1)
args1, kw1 = p(1, b=2)
self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2})
args2, kw2 = p()
@@ -128,13 +125,13 @@ class TestPartial(unittest.TestCase):
def test_error_propagation(self):
def f(x, y):
x / y
- self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
- self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
- self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
- self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)
+ self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
+ self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
+ self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
+ self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)
def test_weakref(self):
- f = self.thetype(int, base=16)
+ f = self.partial(int, base=16)
p = proxy(f)
self.assertEqual(f.func, p.func)
f = None
@@ -142,39 +139,45 @@ class TestPartial(unittest.TestCase):
def test_with_bound_and_unbound_methods(self):
data = list(map(str, range(10)))
- join = self.thetype(str.join, '')
+ join = self.partial(str.join, '')
self.assertEqual(join(data), '0123456789')
- join = self.thetype(''.join)
+ join = self.partial(''.join)
self.assertEqual(join(data), '0123456789')
+
+@unittest.skipUnless(c_functools, 'requires the C _functools module')
+class TestPartialC(TestPartial, unittest.TestCase):
+ if c_functools:
+ partial = c_functools.partial
+
def test_repr(self):
args = (object(), object())
args_repr = ', '.join(repr(a) for a in args)
kwargs = {'a': object(), 'b': object()}
kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
- if self.thetype is functools.partial:
+ if self.partial is c_functools.partial:
name = 'functools.partial'
else:
- name = self.thetype.__name__
+ name = self.partial.__name__
- f = self.thetype(capture)
+ f = self.partial(capture)
self.assertEqual('{}({!r})'.format(name, capture),
repr(f))
- f = self.thetype(capture, *args)
+ f = self.partial(capture, *args)
self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
repr(f))
- f = self.thetype(capture, **kwargs)
+ f = self.partial(capture, **kwargs)
self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
repr(f))
- f = self.thetype(capture, *args, **kwargs)
+ f = self.partial(capture, *args, **kwargs)
self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
repr(f))
def test_pickle(self):
- f = self.thetype(signature, 'asdf', bar=True)
+ f = self.partial(signature, 'asdf', bar=True)
f.add_something_to__dict__ = True
f_copy = pickle.loads(pickle.dumps(f))
self.assertEqual(signature(f), signature(f_copy))
@@ -193,28 +196,26 @@ class TestPartial(unittest.TestCase):
return {}
raise IndexError
- f = self.thetype(object)
+ f = self.partial(object)
self.assertRaisesRegex(SystemError,
"new style getargs format but argument is not a tuple",
f.__setstate__, BadSequence())
-class PartialSubclass(functools.partial):
- pass
-class TestPartialSubclass(TestPartial):
+class TestPartialPy(TestPartial, unittest.TestCase):
+ partial = staticmethod(py_functools.partial)
- thetype = PartialSubclass
-class TestPythonPartial(TestPartial):
+if c_functools:
+ class PartialSubclass(c_functools.partial):
+ pass
- thetype = PythonPartial
- # the python version hasn't a nice repr
- def test_repr(self): pass
+@unittest.skipUnless(c_functools, 'requires the C _functools module')
+class TestPartialCSubclass(TestPartialC):
+ if c_functools:
+ partial = PartialSubclass
- # the python version isn't picklable
- def test_pickle(self): pass
- def test_setstate_refcount(self): pass
class TestUpdateWrapper(unittest.TestCase):
@@ -223,19 +224,26 @@ class TestUpdateWrapper(unittest.TestCase):
updated=functools.WRAPPER_UPDATES):
# Check attributes were assigned
for name in assigned:
- self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name))
+ self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
# Check attributes were updated
for name in updated:
wrapper_attr = getattr(wrapper, name)
wrapped_attr = getattr(wrapped, name)
for key in wrapped_attr:
- self.assertTrue(wrapped_attr[key] is wrapper_attr[key])
+ if name == "__dict__" and key == "__wrapped__":
+ # __wrapped__ is overwritten by the update code
+ continue
+ self.assertIs(wrapped_attr[key], wrapper_attr[key])
+ # Check __wrapped__
+ self.assertIs(wrapper.__wrapped__, wrapped)
+
def _default_update(self):
def f(a:'This is a new annotation'):
"""This is a test"""
pass
f.attr = 'This is also a test'
+ f.__wrapped__ = "This is a bald faced lie"
def wrapper(b:'This is the prior annotation'):
pass
functools.update_wrapper(wrapper, f)
@@ -322,6 +330,7 @@ class TestUpdateWrapper(unittest.TestCase):
self.assertTrue(wrapper.__doc__.startswith('max('))
self.assertEqual(wrapper.__annotations__, {})
+
class TestWraps(TestUpdateWrapper):
def _default_update(self):
@@ -329,14 +338,15 @@ class TestWraps(TestUpdateWrapper):
"""This is a test"""
pass
f.attr = 'This is also a test'
+ f.__wrapped__ = "This is still a bald faced lie"
@functools.wraps(f)
def wrapper():
pass
- self.check_wrapper(wrapper, f)
return wrapper, f
def test_default_update(self):
wrapper, f = self._default_update()
+ self.check_wrapper(wrapper, f)
self.assertEqual(wrapper.__name__, 'f')
self.assertEqual(wrapper.__qualname__, f.__qualname__)
self.assertEqual(wrapper.attr, 'This is also a test')
@@ -382,6 +392,7 @@ class TestWraps(TestUpdateWrapper):
self.assertEqual(wrapper.attr, 'This is a different test')
self.assertEqual(wrapper.dict_attr, f.dict_attr)
+
class TestReduce(unittest.TestCase):
func = functools.reduce
@@ -462,24 +473,29 @@ class TestReduce(unittest.TestCase):
d = {"one": 1, "two": 2, "three": 3}
self.assertEqual(self.func(add, d), "".join(d.keys()))
-class TestCmpToKey(unittest.TestCase):
+
+class TestCmpToKey:
def test_cmp_to_key(self):
def cmp1(x, y):
return (x > y) - (x < y)
- key = functools.cmp_to_key(cmp1)
+ key = self.cmp_to_key(cmp1)
self.assertEqual(key(3), key(3))
self.assertGreater(key(3), key(1))
+ self.assertGreaterEqual(key(3), key(3))
+
def cmp2(x, y):
return int(x) - int(y)
- key = functools.cmp_to_key(cmp2)
+ key = self.cmp_to_key(cmp2)
self.assertEqual(key(4.0), key('4'))
self.assertLess(key(2), key('35'))
+ self.assertLessEqual(key(2), key('35'))
+ self.assertNotEqual(key(2), key('35'))
def test_cmp_to_key_arguments(self):
def cmp1(x, y):
return (x > y) - (x < y)
- key = functools.cmp_to_key(mycmp=cmp1)
+ key = self.cmp_to_key(mycmp=cmp1)
self.assertEqual(key(obj=3), key(obj=3))
self.assertGreater(key(obj=3), key(obj=1))
with self.assertRaises((TypeError, AttributeError)):
@@ -487,10 +503,10 @@ class TestCmpToKey(unittest.TestCase):
with self.assertRaises((TypeError, AttributeError)):
1 < key(3) # lhs is not a K object
with self.assertRaises(TypeError):
- key = functools.cmp_to_key() # too few args
+ key = self.cmp_to_key() # too few args
with self.assertRaises(TypeError):
- key = functools.cmp_to_key(cmp1, None) # too many args
- key = functools.cmp_to_key(cmp1)
+ key = self.cmp_to_key(cmp1, None) # too many args
+ key = self.cmp_to_key(cmp1)
with self.assertRaises(TypeError):
key() # too few args
with self.assertRaises(TypeError):
@@ -499,7 +515,7 @@ class TestCmpToKey(unittest.TestCase):
def test_bad_cmp(self):
def cmp1(x, y):
raise ZeroDivisionError
- key = functools.cmp_to_key(cmp1)
+ key = self.cmp_to_key(cmp1)
with self.assertRaises(ZeroDivisionError):
key(3) > key(1)
@@ -514,13 +530,13 @@ class TestCmpToKey(unittest.TestCase):
def test_obj_field(self):
def cmp1(x, y):
return (x > y) - (x < y)
- key = functools.cmp_to_key(mycmp=cmp1)
+ key = self.cmp_to_key(mycmp=cmp1)
self.assertEqual(key(50).obj, 50)
def test_sort_int(self):
def mycmp(x, y):
return y - x
- self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
+ self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
[4, 3, 2, 1, 0])
def test_sort_int_str(self):
@@ -528,18 +544,29 @@ class TestCmpToKey(unittest.TestCase):
x, y = int(x), int(y)
return (x > y) - (x < y)
values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
- values = sorted(values, key=functools.cmp_to_key(mycmp))
+ values = sorted(values, key=self.cmp_to_key(mycmp))
self.assertEqual([int(value) for value in values],
[0, 1, 1, 2, 3, 4, 5, 7, 10])
def test_hash(self):
def mycmp(x, y):
return y - x
- key = functools.cmp_to_key(mycmp)
+ key = self.cmp_to_key(mycmp)
k = key(10)
self.assertRaises(TypeError, hash, k)
self.assertNotIsInstance(k, collections.Hashable)
+
+@unittest.skipUnless(c_functools, 'requires the C _functools module')
+class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
+ if c_functools:
+ cmp_to_key = c_functools.cmp_to_key
+
+
+class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
+ cmp_to_key = staticmethod(py_functools.cmp_to_key)
+
+
class TestTotalOrdering(unittest.TestCase):
def test_total_ordering_lt(self):
@@ -640,11 +667,12 @@ class TestTotalOrdering(unittest.TestCase):
with self.assertRaises(TypeError):
TestTO(8) <= ()
+
class TestLRU(unittest.TestCase):
def test_lru(self):
def orig(x, y):
- return 3*x+y
+ return 3 * x + y
f = functools.lru_cache(maxsize=20)(orig)
hits, misses, maxsize, currsize = f.cache_info()
self.assertEqual(maxsize, 20)
@@ -749,7 +777,7 @@ class TestLRU(unittest.TestCase):
# Verify that user_function exceptions get passed through without
# creating a hard-to-read chained exception.
# http://bugs.python.org/issue13177
- for maxsize in (None, 100):
+ for maxsize in (None, 128):
@functools.lru_cache(maxsize)
def func(i):
return 'abc'[i]
@@ -762,7 +790,7 @@ class TestLRU(unittest.TestCase):
func(15)
def test_lru_with_types(self):
- for maxsize in (None, 100):
+ for maxsize in (None, 128):
@functools.lru_cache(maxsize=maxsize, typed=True)
def square(x):
return x * x
@@ -777,6 +805,36 @@ class TestLRU(unittest.TestCase):
self.assertEqual(square.cache_info().hits, 4)
self.assertEqual(square.cache_info().misses, 4)
+ def test_lru_with_keyword_args(self):
+ @functools.lru_cache()
+ def fib(n):
+ if n < 2:
+ return n
+ return fib(n=n-1) + fib(n=n-2)
+ self.assertEqual(
+ [fib(n=number) for number in range(16)],
+ [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
+ )
+ self.assertEqual(fib.cache_info(),
+ functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
+ fib.cache_clear()
+ self.assertEqual(fib.cache_info(),
+ functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
+
+ def test_lru_with_keyword_args_maxsize_none(self):
+ @functools.lru_cache(maxsize=None)
+ def fib(n):
+ if n < 2:
+ return n
+ return fib(n=n-1) + fib(n=n-2)
+ self.assertEqual([fib(n=number) for number in range(16)],
+ [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
+ self.assertEqual(fib.cache_info(),
+ functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
+ fib.cache_clear()
+ self.assertEqual(fib.cache_info(),
+ functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
+
def test_need_for_rlock(self):
# This will deadlock on an LRU cache that uses a regular lock
@@ -802,17 +860,493 @@ class TestLRU(unittest.TestCase):
DoubleEq(2)) # Verify the correct return value
+class TestSingleDispatch(unittest.TestCase):
+ def test_simple_overloads(self):
+ @functools.singledispatch
+ def g(obj):
+ return "base"
+ def g_int(i):
+ return "integer"
+ g.register(int, g_int)
+ self.assertEqual(g("str"), "base")
+ self.assertEqual(g(1), "integer")
+ self.assertEqual(g([1,2,3]), "base")
+
+ def test_mro(self):
+ @functools.singledispatch
+ def g(obj):
+ return "base"
+ class A:
+ pass
+ class C(A):
+ pass
+ class B(A):
+ pass
+ class D(C, B):
+ pass
+ def g_A(a):
+ return "A"
+ def g_B(b):
+ return "B"
+ g.register(A, g_A)
+ g.register(B, g_B)
+ self.assertEqual(g(A()), "A")
+ self.assertEqual(g(B()), "B")
+ self.assertEqual(g(C()), "A")
+ self.assertEqual(g(D()), "B")
+
+ def test_register_decorator(self):
+ @functools.singledispatch
+ def g(obj):
+ return "base"
+ @g.register(int)
+ def g_int(i):
+ return "int %s" % (i,)
+ self.assertEqual(g(""), "base")
+ self.assertEqual(g(12), "int 12")
+ self.assertIs(g.dispatch(int), g_int)
+ self.assertIs(g.dispatch(object), g.dispatch(str))
+ # Note: in the assert above this is not g.
+ # @singledispatch returns the wrapper.
+
+ def test_wrapping_attributes(self):
+ @functools.singledispatch
+ def g(obj):
+ "Simple test"
+ return "Test"
+ self.assertEqual(g.__name__, "g")
+ self.assertEqual(g.__doc__, "Simple test")
+
+ @unittest.skipUnless(decimal, 'requires _decimal')
+ @support.cpython_only
+ def test_c_classes(self):
+ @functools.singledispatch
+ def g(obj):
+ return "base"
+ @g.register(decimal.DecimalException)
+ def _(obj):
+ return obj.args
+ subn = decimal.Subnormal("Exponent < Emin")
+ rnd = decimal.Rounded("Number got rounded")
+ self.assertEqual(g(subn), ("Exponent < Emin",))
+ self.assertEqual(g(rnd), ("Number got rounded",))
+ @g.register(decimal.Subnormal)
+ def _(obj):
+ return "Too small to care."
+ self.assertEqual(g(subn), "Too small to care.")
+ self.assertEqual(g(rnd), ("Number got rounded",))
+
+ def test_compose_mro(self):
+ # None of the examples in this test depend on haystack ordering.
+ c = collections
+ mro = functools._compose_mro
+ bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
+ for haystack in permutations(bases):
+ m = mro(dict, haystack)
+ self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
+ c.Iterable, c.Container, object])
+ bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
+ for haystack in permutations(bases):
+ m = mro(c.ChainMap, haystack)
+ self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
+ c.Sized, c.Iterable, c.Container, object])
+
+ # If there's a generic function with implementations registered for
+ # both Sized and Container, passing a defaultdict to it results in an
+ # ambiguous dispatch which will cause a RuntimeError (see
+ # test_mro_conflicts).
+ bases = [c.Container, c.Sized, str]
+ for haystack in permutations(bases):
+ m = mro(c.defaultdict, [c.Sized, c.Container, str])
+ self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
+ object])
+
+ # MutableSequence below is registered directly on D. In other words, it
+ # preceeds MutableMapping which means single dispatch will always
+ # choose MutableSequence here.
+ class D(c.defaultdict):
+ pass
+ c.MutableSequence.register(D)
+ bases = [c.MutableSequence, c.MutableMapping]
+ for haystack in permutations(bases):
+ m = mro(D, bases)
+ self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
+ c.defaultdict, dict, c.MutableMapping,
+ c.Mapping, c.Sized, c.Iterable, c.Container,
+ object])
+
+ # Container and Callable are registered on different base classes and
+ # a generic function supporting both should always pick the Callable
+ # implementation if a C instance is passed.
+ class C(c.defaultdict):
+ def __call__(self):
+ pass
+ bases = [c.Sized, c.Callable, c.Container, c.Mapping]
+ for haystack in permutations(bases):
+ m = mro(C, haystack)
+ self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
+ c.Sized, c.Iterable, c.Container, object])
+
+ def test_register_abc(self):
+ c = collections
+ d = {"a": "b"}
+ l = [1, 2, 3]
+ s = {object(), None}
+ f = frozenset(s)
+ t = (1, 2, 3)
+ @functools.singledispatch
+ def g(obj):
+ return "base"
+ self.assertEqual(g(d), "base")
+ self.assertEqual(g(l), "base")
+ self.assertEqual(g(s), "base")
+ self.assertEqual(g(f), "base")
+ self.assertEqual(g(t), "base")
+ g.register(c.Sized, lambda obj: "sized")
+ self.assertEqual(g(d), "sized")
+ self.assertEqual(g(l), "sized")
+ self.assertEqual(g(s), "sized")
+ self.assertEqual(g(f), "sized")
+ self.assertEqual(g(t), "sized")
+ g.register(c.MutableMapping, lambda obj: "mutablemapping")
+ self.assertEqual(g(d), "mutablemapping")
+ self.assertEqual(g(l), "sized")
+ self.assertEqual(g(s), "sized")
+ self.assertEqual(g(f), "sized")
+ self.assertEqual(g(t), "sized")
+ g.register(c.ChainMap, lambda obj: "chainmap")
+ self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
+ self.assertEqual(g(l), "sized")
+ self.assertEqual(g(s), "sized")
+ self.assertEqual(g(f), "sized")
+ self.assertEqual(g(t), "sized")
+ g.register(c.MutableSequence, lambda obj: "mutablesequence")
+ self.assertEqual(g(d), "mutablemapping")
+ self.assertEqual(g(l), "mutablesequence")
+ self.assertEqual(g(s), "sized")
+ self.assertEqual(g(f), "sized")
+ self.assertEqual(g(t), "sized")
+ g.register(c.MutableSet, lambda obj: "mutableset")
+ self.assertEqual(g(d), "mutablemapping")
+ self.assertEqual(g(l), "mutablesequence")
+ self.assertEqual(g(s), "mutableset")
+ self.assertEqual(g(f), "sized")
+ self.assertEqual(g(t), "sized")
+ g.register(c.Mapping, lambda obj: "mapping")
+ self.assertEqual(g(d), "mutablemapping") # not specific enough
+ self.assertEqual(g(l), "mutablesequence")
+ self.assertEqual(g(s), "mutableset")
+ self.assertEqual(g(f), "sized")
+ self.assertEqual(g(t), "sized")
+ g.register(c.Sequence, lambda obj: "sequence")
+ self.assertEqual(g(d), "mutablemapping")
+ self.assertEqual(g(l), "mutablesequence")
+ self.assertEqual(g(s), "mutableset")
+ self.assertEqual(g(f), "sized")
+ self.assertEqual(g(t), "sequence")
+ g.register(c.Set, lambda obj: "set")
+ self.assertEqual(g(d), "mutablemapping")
+ self.assertEqual(g(l), "mutablesequence")
+ self.assertEqual(g(s), "mutableset")
+ self.assertEqual(g(f), "set")
+ self.assertEqual(g(t), "sequence")
+ g.register(dict, lambda obj: "dict")
+ self.assertEqual(g(d), "dict")
+ self.assertEqual(g(l), "mutablesequence")
+ self.assertEqual(g(s), "mutableset")
+ self.assertEqual(g(f), "set")
+ self.assertEqual(g(t), "sequence")
+ g.register(list, lambda obj: "list")
+ self.assertEqual(g(d), "dict")
+ self.assertEqual(g(l), "list")
+ self.assertEqual(g(s), "mutableset")
+ self.assertEqual(g(f), "set")
+ self.assertEqual(g(t), "sequence")
+ g.register(set, lambda obj: "concrete-set")
+ self.assertEqual(g(d), "dict")
+ self.assertEqual(g(l), "list")
+ self.assertEqual(g(s), "concrete-set")
+ self.assertEqual(g(f), "set")
+ self.assertEqual(g(t), "sequence")
+ g.register(frozenset, lambda obj: "frozen-set")
+ self.assertEqual(g(d), "dict")
+ self.assertEqual(g(l), "list")
+ self.assertEqual(g(s), "concrete-set")
+ self.assertEqual(g(f), "frozen-set")
+ self.assertEqual(g(t), "sequence")
+ g.register(tuple, lambda obj: "tuple")
+ self.assertEqual(g(d), "dict")
+ self.assertEqual(g(l), "list")
+ self.assertEqual(g(s), "concrete-set")
+ self.assertEqual(g(f), "frozen-set")
+ self.assertEqual(g(t), "tuple")
+
+ def test_c3_abc(self):
+ c = collections
+ mro = functools._c3_mro
+ class A(object):
+ pass
+ class B(A):
+ def __len__(self):
+ return 0 # implies Sized
+ @c.Container.register
+ class C(object):
+ pass
+ class D(object):
+ pass # unrelated
+ class X(D, C, B):
+ def __call__(self):
+ pass # implies Callable
+ expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
+ for abcs in permutations([c.Sized, c.Callable, c.Container]):
+ self.assertEqual(mro(X, abcs=abcs), expected)
+ # unrelated ABCs don't appear in the resulting MRO
+ many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
+ self.assertEqual(mro(X, abcs=many_abcs), expected)
+
+ def test_mro_conflicts(self):
+ c = collections
+ @functools.singledispatch
+ def g(arg):
+ return "base"
+ class O(c.Sized):
+ def __len__(self):
+ return 0
+ o = O()
+ self.assertEqual(g(o), "base")
+ g.register(c.Iterable, lambda arg: "iterable")
+ g.register(c.Container, lambda arg: "container")
+ g.register(c.Sized, lambda arg: "sized")
+ g.register(c.Set, lambda arg: "set")
+ self.assertEqual(g(o), "sized")
+ c.Iterable.register(O)
+ self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
+ c.Container.register(O)
+ self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
+ c.Set.register(O)
+ self.assertEqual(g(o), "set") # because c.Set is a subclass of
+ # c.Sized and c.Container
+ class P:
+ pass
+ p = P()
+ self.assertEqual(g(p), "base")
+ c.Iterable.register(P)
+ self.assertEqual(g(p), "iterable")
+ c.Container.register(P)
+ with self.assertRaises(RuntimeError) as re_one:
+ g(p)
+ self.assertIn(
+ str(re_one.exception),
+ (("Ambiguous dispatch: <class 'collections.abc.Container'> "
+ "or <class 'collections.abc.Iterable'>"),
+ ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
+ "or <class 'collections.abc.Container'>")),
+ )
+ class Q(c.Sized):
+ def __len__(self):
+ return 0
+ q = Q()
+ self.assertEqual(g(q), "sized")
+ c.Iterable.register(Q)
+ self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
+ c.Set.register(Q)
+ self.assertEqual(g(q), "set") # because c.Set is a subclass of
+ # c.Sized and c.Iterable
+ @functools.singledispatch
+ def h(arg):
+ return "base"
+ @h.register(c.Sized)
+ def _(arg):
+ return "sized"
+ @h.register(c.Container)
+ def _(arg):
+ return "container"
+ # Even though Sized and Container are explicit bases of MutableMapping,
+ # this ABC is implicitly registered on defaultdict which makes all of
+ # MutableMapping's bases implicit as well from defaultdict's
+ # perspective.
+ with self.assertRaises(RuntimeError) as re_two:
+ h(c.defaultdict(lambda: 0))
+ self.assertIn(
+ str(re_two.exception),
+ (("Ambiguous dispatch: <class 'collections.abc.Container'> "
+ "or <class 'collections.abc.Sized'>"),
+ ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
+ "or <class 'collections.abc.Container'>")),
+ )
+ class R(c.defaultdict):
+ pass
+ c.MutableSequence.register(R)
+ @functools.singledispatch
+ def i(arg):
+ return "base"
+ @i.register(c.MutableMapping)
+ def _(arg):
+ return "mapping"
+ @i.register(c.MutableSequence)
+ def _(arg):
+ return "sequence"
+ r = R()
+ self.assertEqual(i(r), "sequence")
+ class S:
+ pass
+ class T(S, c.Sized):
+ def __len__(self):
+ return 0
+ t = T()
+ self.assertEqual(h(t), "sized")
+ c.Container.register(T)
+ self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
+ class U:
+ def __len__(self):
+ return 0
+ u = U()
+ self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
+ # from the existence of __len__()
+ c.Container.register(U)
+ # There is no preference for registered versus inferred ABCs.
+ with self.assertRaises(RuntimeError) as re_three:
+ h(u)
+ self.assertIn(
+ str(re_three.exception),
+ (("Ambiguous dispatch: <class 'collections.abc.Container'> "
+ "or <class 'collections.abc.Sized'>"),
+ ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
+ "or <class 'collections.abc.Container'>")),
+ )
+ class V(c.Sized, S):
+ def __len__(self):
+ return 0
+ @functools.singledispatch
+ def j(arg):
+ return "base"
+ @j.register(S)
+ def _(arg):
+ return "s"
+ @j.register(c.Container)
+ def _(arg):
+ return "container"
+ v = V()
+ self.assertEqual(j(v), "s")
+ c.Container.register(V)
+ self.assertEqual(j(v), "container") # because it ends up right after
+ # Sized in the MRO
+
+ def test_cache_invalidation(self):
+ from collections import UserDict
+ class TracingDict(UserDict):
+ def __init__(self, *args, **kwargs):
+ super(TracingDict, self).__init__(*args, **kwargs)
+ self.set_ops = []
+ self.get_ops = []
+ def __getitem__(self, key):
+ result = self.data[key]
+ self.get_ops.append(key)
+ return result
+ def __setitem__(self, key, value):
+ self.set_ops.append(key)
+ self.data[key] = value
+ def clear(self):
+ self.data.clear()
+ _orig_wkd = functools.WeakKeyDictionary
+ td = TracingDict()
+ functools.WeakKeyDictionary = lambda: td
+ c = collections
+ @functools.singledispatch
+ def g(arg):
+ return "base"
+ d = {}
+ l = []
+ self.assertEqual(len(td), 0)
+ self.assertEqual(g(d), "base")
+ self.assertEqual(len(td), 1)
+ self.assertEqual(td.get_ops, [])
+ self.assertEqual(td.set_ops, [dict])
+ self.assertEqual(td.data[dict], g.registry[object])
+ self.assertEqual(g(l), "base")
+ self.assertEqual(len(td), 2)
+ self.assertEqual(td.get_ops, [])
+ self.assertEqual(td.set_ops, [dict, list])
+ self.assertEqual(td.data[dict], g.registry[object])
+ self.assertEqual(td.data[list], g.registry[object])
+ self.assertEqual(td.data[dict], td.data[list])
+ self.assertEqual(g(l), "base")
+ self.assertEqual(g(d), "base")
+ self.assertEqual(td.get_ops, [list, dict])
+ self.assertEqual(td.set_ops, [dict, list])
+ g.register(list, lambda arg: "list")
+ self.assertEqual(td.get_ops, [list, dict])
+ self.assertEqual(len(td), 0)
+ self.assertEqual(g(d), "base")
+ self.assertEqual(len(td), 1)
+ self.assertEqual(td.get_ops, [list, dict])
+ self.assertEqual(td.set_ops, [dict, list, dict])
+ self.assertEqual(td.data[dict],
+ functools._find_impl(dict, g.registry))
+ self.assertEqual(g(l), "list")
+ self.assertEqual(len(td), 2)
+ self.assertEqual(td.get_ops, [list, dict])
+ self.assertEqual(td.set_ops, [dict, list, dict, list])
+ self.assertEqual(td.data[list],
+ functools._find_impl(list, g.registry))
+ class X:
+ pass
+ c.MutableMapping.register(X) # Will not invalidate the cache,
+ # not using ABCs yet.
+ self.assertEqual(g(d), "base")
+ self.assertEqual(g(l), "list")
+ self.assertEqual(td.get_ops, [list, dict, dict, list])
+ self.assertEqual(td.set_ops, [dict, list, dict, list])
+ g.register(c.Sized, lambda arg: "sized")
+ self.assertEqual(len(td), 0)
+ self.assertEqual(g(d), "sized")
+ self.assertEqual(len(td), 1)
+ self.assertEqual(td.get_ops, [list, dict, dict, list])
+ self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
+ self.assertEqual(g(l), "list")
+ self.assertEqual(len(td), 2)
+ self.assertEqual(td.get_ops, [list, dict, dict, list])
+ self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
+ self.assertEqual(g(l), "list")
+ self.assertEqual(g(d), "sized")
+ self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
+ self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
+ g.dispatch(list)
+ g.dispatch(dict)
+ self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
+ list, dict])
+ self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
+ c.MutableSet.register(X) # Will invalidate the cache.
+ self.assertEqual(len(td), 2) # Stale cache.
+ self.assertEqual(g(l), "list")
+ self.assertEqual(len(td), 1)
+ g.register(c.MutableMapping, lambda arg: "mutablemapping")
+ self.assertEqual(len(td), 0)
+ self.assertEqual(g(d), "mutablemapping")
+ self.assertEqual(len(td), 1)
+ self.assertEqual(g(l), "list")
+ self.assertEqual(len(td), 2)
+ g.register(dict, lambda arg: "dict")
+ self.assertEqual(g(d), "dict")
+ self.assertEqual(g(l), "list")
+ g._clear_cache()
+ self.assertEqual(len(td), 0)
+ functools.WeakKeyDictionary = _orig_wkd
+
+
def test_main(verbose=None):
test_classes = (
- TestPartial,
- TestPartialSubclass,
- TestPythonPartial,
+ TestPartialC,
+ TestPartialPy,
+ TestPartialCSubclass,
TestUpdateWrapper,
TestTotalOrdering,
- TestCmpToKey,
+ TestCmpToKeyC,
+ TestCmpToKeyPy,
TestWraps,
TestReduce,
TestLRU,
+ TestSingleDispatch,
)
support.run_unittest(*test_classes)