diff options
author | Łukasz Langa <lukasz@langa.pl> | 2013-06-05 10:20:24 (GMT) |
---|---|---|
committer | Łukasz Langa <lukasz@langa.pl> | 2013-06-05 10:20:24 (GMT) |
commit | 6f69251980f385beac9d6f3e8ff4775bd37a1779 (patch) | |
tree | 0e0a667c980df9384ccc984292af51a9b70c6598 /Lib/test/test_functools.py | |
parent | 072318b178f9824de5e0672218495f699dbdce44 (diff) | |
download | cpython-6f69251980f385beac9d6f3e8ff4775bd37a1779.zip cpython-6f69251980f385beac9d6f3e8ff4775bd37a1779.tar.gz cpython-6f69251980f385beac9d6f3e8ff4775bd37a1779.tar.bz2 |
Add reference implementation for PEP 443
PEP accepted: http://mail.python.org/pipermail/python-dev/2013-June/126734.html
Diffstat (limited to 'Lib/test/test_functools.py')
-rw-r--r-- | Lib/test/test_functools.py | 372 |
1 files changed, 369 insertions, 3 deletions
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 9b3c31e..a6b1e03 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -1,24 +1,30 @@ import collections +from itertools import permutations +import pickle +from random import choice import sys -import unittest from test import support +import unittest from weakref import proxy -import pickle -from random import choice import functools py_functools = support.import_fresh_module('functools', blocked=['_functools']) c_functools = support.import_fresh_module('functools', fresh=['_functools']) +decimal = support.import_fresh_module('decimal', fresh=['_decimal']) + + def capture(*args, **kw): """capture all positional and keyword arguments""" return args, kw + def signature(part): """ return the signature of a partial object """ return (part.func, part.args, part.keywords, part.__dict__) + class TestPartial: def test_basic_examples(self): @@ -138,6 +144,7 @@ class TestPartial: join = self.partial(''.join) self.assertEqual(join(data), '0123456789') + @unittest.skipUnless(c_functools, 'requires the C _functools module') class TestPartialC(TestPartial, unittest.TestCase): if c_functools: @@ -194,18 +201,22 @@ class TestPartialC(TestPartial, unittest.TestCase): "new style getargs format but argument is not a tuple", f.__setstate__, BadSequence()) + class TestPartialPy(TestPartial, unittest.TestCase): partial = staticmethod(py_functools.partial) + if c_functools: class PartialSubclass(c_functools.partial): pass + @unittest.skipUnless(c_functools, 'requires the C _functools module') class TestPartialCSubclass(TestPartialC): if c_functools: partial = PartialSubclass + class TestUpdateWrapper(unittest.TestCase): def check_wrapper(self, wrapper, wrapped, @@ -312,6 +323,7 @@ class TestUpdateWrapper(unittest.TestCase): self.assertTrue(wrapper.__doc__.startswith('max(')) self.assertEqual(wrapper.__annotations__, {}) + class TestWraps(TestUpdateWrapper): def _default_update(self): @@ -372,6 +384,7 @@ class TestWraps(TestUpdateWrapper): self.assertEqual(wrapper.attr, 'This is a different test') self.assertEqual(wrapper.dict_attr, f.dict_attr) + class TestReduce(unittest.TestCase): func = functools.reduce @@ -452,6 +465,7 @@ class TestReduce(unittest.TestCase): d = {"one": 1, "two": 2, "three": 3} self.assertEqual(self.func(add, d), "".join(d.keys())) + class TestCmpToKey: def test_cmp_to_key(self): @@ -534,14 +548,17 @@ class TestCmpToKey: self.assertRaises(TypeError, hash, k) self.assertNotIsInstance(k, collections.Hashable) + @unittest.skipUnless(c_functools, 'requires the C _functools module') class TestCmpToKeyC(TestCmpToKey, unittest.TestCase): if c_functools: cmp_to_key = c_functools.cmp_to_key + class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase): cmp_to_key = staticmethod(py_functools.cmp_to_key) + class TestTotalOrdering(unittest.TestCase): def test_total_ordering_lt(self): @@ -642,6 +659,7 @@ class TestTotalOrdering(unittest.TestCase): with self.assertRaises(TypeError): TestTO(8) <= () + class TestLRU(unittest.TestCase): def test_lru(self): @@ -834,6 +852,353 @@ class TestLRU(unittest.TestCase): DoubleEq(2)) # Verify the correct return value +class TestSingleDispatch(unittest.TestCase): + def test_simple_overloads(self): + @functools.singledispatch + def g(obj): + return "base" + def g_int(i): + return "integer" + g.register(int, g_int) + self.assertEqual(g("str"), "base") + self.assertEqual(g(1), "integer") + self.assertEqual(g([1,2,3]), "base") + + def test_mro(self): + @functools.singledispatch + def g(obj): + return "base" + class C: + pass + class D(C): + pass + def g_C(c): + return "C" + g.register(C, g_C) + self.assertEqual(g(C()), "C") + self.assertEqual(g(D()), "C") + + def test_classic_classes(self): + @functools.singledispatch + def g(obj): + return "base" + class C: + pass + class D(C): + pass + def g_C(c): + return "C" + g.register(C, g_C) + self.assertEqual(g(C()), "C") + self.assertEqual(g(D()), "C") + + def test_register_decorator(self): + @functools.singledispatch + def g(obj): + return "base" + @g.register(int) + def g_int(i): + return "int %s" % (i,) + self.assertEqual(g(""), "base") + self.assertEqual(g(12), "int 12") + self.assertIs(g.dispatch(int), g_int) + self.assertIs(g.dispatch(object), g.dispatch(str)) + # Note: in the assert above this is not g. + # @singledispatch returns the wrapper. + + def test_wrapping_attributes(self): + @functools.singledispatch + def g(obj): + "Simple test" + return "Test" + self.assertEqual(g.__name__, "g") + self.assertEqual(g.__doc__, "Simple test") + + @unittest.skipUnless(decimal, 'requires _decimal') + @support.cpython_only + def test_c_classes(self): + @functools.singledispatch + def g(obj): + return "base" + @g.register(decimal.DecimalException) + def _(obj): + return obj.args + subn = decimal.Subnormal("Exponent < Emin") + rnd = decimal.Rounded("Number got rounded") + self.assertEqual(g(subn), ("Exponent < Emin",)) + self.assertEqual(g(rnd), ("Number got rounded",)) + @g.register(decimal.Subnormal) + def _(obj): + return "Too small to care." + self.assertEqual(g(subn), "Too small to care.") + self.assertEqual(g(rnd), ("Number got rounded",)) + + def test_compose_mro(self): + c = collections + mro = functools._compose_mro + bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] + for haystack in permutations(bases): + m = mro(dict, haystack) + self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, object]) + bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict] + for haystack in permutations(bases): + m = mro(c.ChainMap, haystack) + self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping, + c.Sized, c.Iterable, c.Container, object]) + # Note: The MRO order below depends on haystack ordering. + m = mro(c.defaultdict, [c.Sized, c.Container, str]) + self.assertEqual(m, [c.defaultdict, dict, c.Container, c.Sized, object]) + m = mro(c.defaultdict, [c.Container, c.Sized, str]) + self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, object]) + + def test_register_abc(self): + c = collections + d = {"a": "b"} + l = [1, 2, 3] + s = {object(), None} + f = frozenset(s) + t = (1, 2, 3) + @functools.singledispatch + def g(obj): + return "base" + self.assertEqual(g(d), "base") + self.assertEqual(g(l), "base") + self.assertEqual(g(s), "base") + self.assertEqual(g(f), "base") + self.assertEqual(g(t), "base") + g.register(c.Sized, lambda obj: "sized") + self.assertEqual(g(d), "sized") + self.assertEqual(g(l), "sized") + self.assertEqual(g(s), "sized") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + g.register(c.MutableMapping, lambda obj: "mutablemapping") + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "sized") + self.assertEqual(g(s), "sized") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + g.register(c.ChainMap, lambda obj: "chainmap") + self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered + self.assertEqual(g(l), "sized") + self.assertEqual(g(s), "sized") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + g.register(c.MutableSequence, lambda obj: "mutablesequence") + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "sized") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + g.register(c.MutableSet, lambda obj: "mutableset") + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + g.register(c.Mapping, lambda obj: "mapping") + self.assertEqual(g(d), "mutablemapping") # not specific enough + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sized") + g.register(c.Sequence, lambda obj: "sequence") + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "sized") + self.assertEqual(g(t), "sequence") + g.register(c.Set, lambda obj: "set") + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "set") + self.assertEqual(g(t), "sequence") + g.register(dict, lambda obj: "dict") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "mutablesequence") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "set") + self.assertEqual(g(t), "sequence") + g.register(list, lambda obj: "list") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "list") + self.assertEqual(g(s), "mutableset") + self.assertEqual(g(f), "set") + self.assertEqual(g(t), "sequence") + g.register(set, lambda obj: "concrete-set") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "list") + self.assertEqual(g(s), "concrete-set") + self.assertEqual(g(f), "set") + self.assertEqual(g(t), "sequence") + g.register(frozenset, lambda obj: "frozen-set") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "list") + self.assertEqual(g(s), "concrete-set") + self.assertEqual(g(f), "frozen-set") + self.assertEqual(g(t), "sequence") + g.register(tuple, lambda obj: "tuple") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "list") + self.assertEqual(g(s), "concrete-set") + self.assertEqual(g(f), "frozen-set") + self.assertEqual(g(t), "tuple") + + def test_mro_conflicts(self): + c = collections + + @functools.singledispatch + def g(arg): + return "base" + + class O(c.Sized): + def __len__(self): + return 0 + + o = O() + self.assertEqual(g(o), "base") + g.register(c.Iterable, lambda arg: "iterable") + g.register(c.Container, lambda arg: "container") + g.register(c.Sized, lambda arg: "sized") + g.register(c.Set, lambda arg: "set") + self.assertEqual(g(o), "sized") + c.Iterable.register(O) + self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ + c.Container.register(O) + self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ + + class P: + pass + + p = P() + self.assertEqual(g(p), "base") + c.Iterable.register(P) + self.assertEqual(g(p), "iterable") + c.Container.register(P) + with self.assertRaises(RuntimeError) as re: + g(p) + self.assertEqual( + str(re), + ("Ambiguous dispatch: <class 'collections.abc.Container'> " + "or <class 'collections.abc.Iterable'>"), + ) + + class Q(c.Sized): + def __len__(self): + return 0 + + q = Q() + self.assertEqual(g(q), "sized") + c.Iterable.register(Q) + self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ + c.Set.register(Q) + self.assertEqual(g(q), "set") # because c.Set is a subclass of + # c.Sized which is explicitly in + # __mro__ + + def test_cache_invalidation(self): + from collections import UserDict + class TracingDict(UserDict): + def __init__(self, *args, **kwargs): + super(TracingDict, self).__init__(*args, **kwargs) + self.set_ops = [] + self.get_ops = [] + def __getitem__(self, key): + result = self.data[key] + self.get_ops.append(key) + return result + def __setitem__(self, key, value): + self.set_ops.append(key) + self.data[key] = value + def clear(self): + self.data.clear() + _orig_wkd = functools.WeakKeyDictionary + td = TracingDict() + functools.WeakKeyDictionary = lambda: td + c = collections + @functools.singledispatch + def g(arg): + return "base" + d = {} + l = [] + self.assertEqual(len(td), 0) + self.assertEqual(g(d), "base") + self.assertEqual(len(td), 1) + self.assertEqual(td.get_ops, []) + self.assertEqual(td.set_ops, [dict]) + self.assertEqual(td.data[dict], g.registry[object]) + self.assertEqual(g(l), "base") + self.assertEqual(len(td), 2) + self.assertEqual(td.get_ops, []) + self.assertEqual(td.set_ops, [dict, list]) + self.assertEqual(td.data[dict], g.registry[object]) + self.assertEqual(td.data[list], g.registry[object]) + self.assertEqual(td.data[dict], td.data[list]) + self.assertEqual(g(l), "base") + self.assertEqual(g(d), "base") + self.assertEqual(td.get_ops, [list, dict]) + self.assertEqual(td.set_ops, [dict, list]) + g.register(list, lambda arg: "list") + self.assertEqual(td.get_ops, [list, dict]) + self.assertEqual(len(td), 0) + self.assertEqual(g(d), "base") + self.assertEqual(len(td), 1) + self.assertEqual(td.get_ops, [list, dict]) + self.assertEqual(td.set_ops, [dict, list, dict]) + self.assertEqual(td.data[dict], + functools._find_impl(dict, g.registry)) + self.assertEqual(g(l), "list") + self.assertEqual(len(td), 2) + self.assertEqual(td.get_ops, [list, dict]) + self.assertEqual(td.set_ops, [dict, list, dict, list]) + self.assertEqual(td.data[list], + functools._find_impl(list, g.registry)) + class X: + pass + c.MutableMapping.register(X) # Will not invalidate the cache, + # not using ABCs yet. + self.assertEqual(g(d), "base") + self.assertEqual(g(l), "list") + self.assertEqual(td.get_ops, [list, dict, dict, list]) + self.assertEqual(td.set_ops, [dict, list, dict, list]) + g.register(c.Sized, lambda arg: "sized") + self.assertEqual(len(td), 0) + self.assertEqual(g(d), "sized") + self.assertEqual(len(td), 1) + self.assertEqual(td.get_ops, [list, dict, dict, list]) + self.assertEqual(td.set_ops, [dict, list, dict, list, dict]) + self.assertEqual(g(l), "list") + self.assertEqual(len(td), 2) + self.assertEqual(td.get_ops, [list, dict, dict, list]) + self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) + self.assertEqual(g(l), "list") + self.assertEqual(g(d), "sized") + self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) + self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) + g.dispatch(list) + g.dispatch(dict) + self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict, + list, dict]) + self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) + c.MutableSet.register(X) # Will invalidate the cache. + self.assertEqual(len(td), 2) # Stale cache. + self.assertEqual(g(l), "list") + self.assertEqual(len(td), 1) + g.register(c.MutableMapping, lambda arg: "mutablemapping") + self.assertEqual(len(td), 0) + self.assertEqual(g(d), "mutablemapping") + self.assertEqual(len(td), 1) + self.assertEqual(g(l), "list") + self.assertEqual(len(td), 2) + g.register(dict, lambda arg: "dict") + self.assertEqual(g(d), "dict") + self.assertEqual(g(l), "list") + g._clear_cache() + self.assertEqual(len(td), 0) + functools.WeakKeyDictionary = _orig_wkd + + def test_main(verbose=None): test_classes = ( TestPartialC, @@ -846,6 +1211,7 @@ def test_main(verbose=None): TestWraps, TestReduce, TestLRU, + TestSingleDispatch, ) support.run_unittest(*test_classes) |