diff options
author | Łukasz Langa <lukasz@langa.pl> | 2013-07-01 14:00:38 (GMT) |
---|---|---|
committer | Łukasz Langa <lukasz@langa.pl> | 2013-07-01 14:00:38 (GMT) |
commit | 3720c77e307781b1e9f459a8e7844fceacef5cba (patch) | |
tree | 9a0016fbad100b07b248ad5215a51a2eab218eaf /Lib | |
parent | 04926aeb2f88c39a25505e4a0474c6fb735e0f46 (diff) | |
download | cpython-3720c77e307781b1e9f459a8e7844fceacef5cba.zip cpython-3720c77e307781b1e9f459a8e7844fceacef5cba.tar.gz cpython-3720c77e307781b1e9f459a8e7844fceacef5cba.tar.bz2 |
Issue #18244: Adopt C3-based linearization in functools.singledispatch for improved ABC support
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/functools.py | 178 | ||||
-rw-r--r-- | Lib/test/test_functools.py | 174 |
2 files changed, 288 insertions, 64 deletions
diff --git a/Lib/functools.py b/Lib/functools.py index 9403e8e..95c1a41 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -365,46 +365,138 @@ def lru_cache(maxsize=128, typed=False): ### singledispatch() - single-dispatch generic function decorator ################################################################################ -def _compose_mro(cls, haystack): - """Calculates the MRO for a given class `cls`, including relevant abstract - base classes from `haystack`. +def _c3_merge(sequences): + """Merges MROs in *sequences* to a single MRO using the C3 algorithm. + + Adapted from http://www.python.org/download/releases/2.3/mro/. """ - bases = set(cls.__mro__) - mro = list(cls.__mro__) - for needle in haystack: - if (needle in bases or not hasattr(needle, '__mro__') - or not issubclass(cls, needle)): - continue # either present in the __mro__ already or unrelated - for index, base in enumerate(mro): - if not issubclass(base, needle): + result = [] + while True: + sequences = [s for s in sequences if s] # purge empty sequences + if not sequences: + return result + for s1 in sequences: # find merge candidates among seq heads + candidate = s1[0] + for s2 in sequences: + if candidate in s2[1:]: + candidate = None + break # reject the current head, it appears later + else: break - if base in bases and not issubclass(needle, base): - # Conflict resolution: put classes present in __mro__ and their - # subclasses first. See test_mro_conflicts() in test_functools.py - # for examples. - index += 1 - mro.insert(index, needle) - return mro + if not candidate: + raise RuntimeError("Inconsistent hierarchy") + result.append(candidate) + # remove the chosen candidate + for seq in sequences: + if seq[0] == candidate: + del seq[0] + +def _c3_mro(cls, abcs=None): + """Computes the method resolution order using extended C3 linearization. + + If no *abcs* are given, the algorithm works exactly like the built-in C3 + linearization used for method resolution. + + If given, *abcs* is a list of abstract base classes that should be inserted + into the resulting MRO. Unrelated ABCs are ignored and don't end up in the + result. The algorithm inserts ABCs where their functionality is introduced, + i.e. issubclass(cls, abc) returns True for the class itself but returns + False for all its direct base classes. Implicit ABCs for a given class + (either registered or inferred from the presence of a special method like + __len__) are inserted directly after the last ABC explicitly listed in the + MRO of said class. If two implicit ABCs end up next to each other in the + resulting MRO, their ordering depends on the order of types in *abcs*. + + """ + for i, base in enumerate(reversed(cls.__bases__)): + if hasattr(base, '__abstractmethods__'): + boundary = len(cls.__bases__) - i + break # Bases up to the last explicit ABC are considered first. + else: + boundary = 0 + abcs = list(abcs) if abcs else [] + explicit_bases = list(cls.__bases__[:boundary]) + abstract_bases = [] + other_bases = list(cls.__bases__[boundary:]) + for base in abcs: + if issubclass(cls, base) and not any( + issubclass(b, base) for b in cls.__bases__ + ): + # If *cls* is the class that introduces behaviour described by + # an ABC *base*, insert said ABC to its MRO. + abstract_bases.append(base) + for base in abstract_bases: + abcs.remove(base) + explicit_c3_mros = [_c3_mro(base, abcs=abcs) for base in explicit_bases] + abstract_c3_mros = [_c3_mro(base, abcs=abcs) for base in abstract_bases] + other_c3_mros = [_c3_mro(base, abcs=abcs) for base in other_bases] + return _c3_merge( + [[cls]] + + explicit_c3_mros + abstract_c3_mros + other_c3_mros + + [explicit_bases] + [abstract_bases] + [other_bases] + ) + +def _compose_mro(cls, types): + """Calculates the method resolution order for a given class *cls*. + + Includes relevant abstract base classes (with their respective bases) from + the *types* iterable. Uses a modified C3 linearization algorithm. + + """ + bases = set(cls.__mro__) + # Remove entries which are already present in the __mro__ or unrelated. + def is_related(typ): + return (typ not in bases and hasattr(typ, '__mro__') + and issubclass(cls, typ)) + types = [n for n in types if is_related(n)] + # Remove entries which are strict bases of other entries (they will end up + # in the MRO anyway. + def is_strict_base(typ): + for other in types: + if typ != other and typ in other.__mro__: + return True + return False + types = [n for n in types if not is_strict_base(n)] + # Subclasses of the ABCs in *types* which are also implemented by + # *cls* can be used to stabilize ABC ordering. + type_set = set(types) + mro = [] + for typ in types: + found = [] + for sub in typ.__subclasses__(): + if sub not in bases and issubclass(cls, sub): + found.append([s for s in sub.__mro__ if s in type_set]) + if not found: + mro.append(typ) + continue + # Favor subclasses with the biggest number of useful bases + found.sort(key=len, reverse=True) + for sub in found: + for subcls in sub: + if subcls not in mro: + mro.append(subcls) + return _c3_mro(cls, abcs=mro) def _find_impl(cls, registry): - """Returns the best matching implementation for the given class `cls` in - `registry`. Where there is no registered implementation for a specific - type, its method resolution order is used to find a more generic - implementation. + """Returns the best matching implementation from *registry* for type *cls*. + + Where there is no registered implementation for a specific type, its method + resolution order is used to find a more generic implementation. - Note: if `registry` does not contain an implementation for the base - `object` type, this function may return None. + Note: if *registry* does not contain an implementation for the base + *object* type, this function may return None. """ mro = _compose_mro(cls, registry.keys()) match = None for t in mro: if match is not None: - # If `match` is an ABC but there is another unrelated, equally - # matching ABC. Refuse the temptation to guess. - if (t in registry and not issubclass(match, t) - and match not in cls.__mro__): + # If *match* is an implicit ABC but there is another unrelated, + # equally matching implicit ABC, refuse the temptation to guess. + if (t in registry and t not in cls.__mro__ + and match not in cls.__mro__ + and not issubclass(match, t)): raise RuntimeError("Ambiguous dispatch: {} or {}".format( match, t)) break @@ -418,19 +510,19 @@ def singledispatch(func): Transforms a function into a generic function, which can have different behaviours depending upon the type of its first argument. The decorated function acts as the default implementation, and additional - implementations can be registered using the 'register()' attribute of - the generic function. + implementations can be registered using the register() attribute of the + generic function. """ registry = {} dispatch_cache = WeakKeyDictionary() cache_token = None - def dispatch(typ): - """generic_func.dispatch(type) -> <function implementation> + def dispatch(cls): + """generic_func.dispatch(cls) -> <function implementation> Runs the dispatch algorithm to return the best available implementation - for the given `type` registered on `generic_func`. + for the given *cls* registered on *generic_func*. """ nonlocal cache_token @@ -440,26 +532,26 @@ def singledispatch(func): dispatch_cache.clear() cache_token = current_token try: - impl = dispatch_cache[typ] + impl = dispatch_cache[cls] except KeyError: try: - impl = registry[typ] + impl = registry[cls] except KeyError: - impl = _find_impl(typ, registry) - dispatch_cache[typ] = impl + impl = _find_impl(cls, registry) + dispatch_cache[cls] = impl return impl - def register(typ, func=None): - """generic_func.register(type, func) -> func + def register(cls, func=None): + """generic_func.register(cls, func) -> func - Registers a new implementation for the given `type` on a `generic_func`. + Registers a new implementation for the given *cls* on a *generic_func*. """ nonlocal cache_token if func is None: - return lambda f: register(typ, f) - registry[typ] = func - if cache_token is None and hasattr(typ, '__abstractmethods__'): + return lambda f: register(cls, f) + registry[cls] = func + if cache_token is None and hasattr(cls, '__abstractmethods__'): cache_token = get_cache_token() dispatch_cache.clear() return func diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 49c807d..99dccb0 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -929,22 +929,55 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g(rnd), ("Number got rounded",)) def test_compose_mro(self): + # None of the examples in this test depend on haystack ordering. c = collections mro = functools._compose_mro bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] for haystack in permutations(bases): m = mro(dict, haystack) - self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, object]) + self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized, + c.Iterable, c.Container, object]) bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict] for haystack in permutations(bases): m = mro(c.ChainMap, haystack) self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping, c.Sized, c.Iterable, c.Container, object]) - # Note: The MRO order below depends on haystack ordering. - m = mro(c.defaultdict, [c.Sized, c.Container, str]) - self.assertEqual(m, [c.defaultdict, dict, c.Container, c.Sized, object]) - m = mro(c.defaultdict, [c.Container, c.Sized, str]) - self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, object]) + + # If there's a generic function with implementations registered for + # both Sized and Container, passing a defaultdict to it results in an + # ambiguous dispatch which will cause a RuntimeError (see + # test_mro_conflicts). + bases = [c.Container, c.Sized, str] + for haystack in permutations(bases): + m = mro(c.defaultdict, [c.Sized, c.Container, str]) + self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, + object]) + + # MutableSequence below is registered directly on D. In other words, it + # preceeds MutableMapping which means single dispatch will always + # choose MutableSequence here. + class D(c.defaultdict): + pass + c.MutableSequence.register(D) + bases = [c.MutableSequence, c.MutableMapping] + for haystack in permutations(bases): + m = mro(D, bases) + self.assertEqual(m, [D, c.MutableSequence, c.Sequence, + c.defaultdict, dict, c.MutableMapping, + c.Mapping, c.Sized, c.Iterable, c.Container, + object]) + + # Container and Callable are registered on different base classes and + # a generic function supporting both should always pick the Callable + # implementation if a C instance is passed. + class C(c.defaultdict): + def __call__(self): + pass + bases = [c.Sized, c.Callable, c.Container, c.Mapping] + for haystack in permutations(bases): + m = mro(C, haystack) + self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping, + c.Sized, c.Iterable, c.Container, object]) def test_register_abc(self): c = collections @@ -1040,17 +1073,37 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g(f), "frozen-set") self.assertEqual(g(t), "tuple") - def test_mro_conflicts(self): + def test_c3_abc(self): c = collections + mro = functools._c3_mro + class A(object): + pass + class B(A): + def __len__(self): + return 0 # implies Sized + @c.Container.register + class C(object): + pass + class D(object): + pass # unrelated + class X(D, C, B): + def __call__(self): + pass # implies Callable + expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object] + for abcs in permutations([c.Sized, c.Callable, c.Container]): + self.assertEqual(mro(X, abcs=abcs), expected) + # unrelated ABCs don't appear in the resulting MRO + many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable] + self.assertEqual(mro(X, abcs=many_abcs), expected) + def test_mro_conflicts(self): + c = collections @functools.singledispatch def g(arg): return "base" - class O(c.Sized): def __len__(self): return 0 - o = O() self.assertEqual(g(o), "base") g.register(c.Iterable, lambda arg: "iterable") @@ -1062,35 +1115,114 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ c.Container.register(O) self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ - + c.Set.register(O) + self.assertEqual(g(o), "set") # because c.Set is a subclass of + # c.Sized and c.Container class P: pass - p = P() self.assertEqual(g(p), "base") c.Iterable.register(P) self.assertEqual(g(p), "iterable") c.Container.register(P) - with self.assertRaises(RuntimeError) as re: + with self.assertRaises(RuntimeError) as re_one: g(p) - self.assertEqual( - str(re), - ("Ambiguous dispatch: <class 'collections.abc.Container'> " - "or <class 'collections.abc.Iterable'>"), - ) - + self.assertIn( + str(re_one.exception), + (("Ambiguous dispatch: <class 'collections.abc.Container'> " + "or <class 'collections.abc.Iterable'>"), + ("Ambiguous dispatch: <class 'collections.abc.Iterable'> " + "or <class 'collections.abc.Container'>")), + ) class Q(c.Sized): def __len__(self): return 0 - q = Q() self.assertEqual(g(q), "sized") c.Iterable.register(Q) self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ c.Set.register(Q) self.assertEqual(g(q), "set") # because c.Set is a subclass of - # c.Sized which is explicitly in - # __mro__ + # c.Sized and c.Iterable + @functools.singledispatch + def h(arg): + return "base" + @h.register(c.Sized) + def _(arg): + return "sized" + @h.register(c.Container) + def _(arg): + return "container" + # Even though Sized and Container are explicit bases of MutableMapping, + # this ABC is implicitly registered on defaultdict which makes all of + # MutableMapping's bases implicit as well from defaultdict's + # perspective. + with self.assertRaises(RuntimeError) as re_two: + h(c.defaultdict(lambda: 0)) + self.assertIn( + str(re_two.exception), + (("Ambiguous dispatch: <class 'collections.abc.Container'> " + "or <class 'collections.abc.Sized'>"), + ("Ambiguous dispatch: <class 'collections.abc.Sized'> " + "or <class 'collections.abc.Container'>")), + ) + class R(c.defaultdict): + pass + c.MutableSequence.register(R) + @functools.singledispatch + def i(arg): + return "base" + @i.register(c.MutableMapping) + def _(arg): + return "mapping" + @i.register(c.MutableSequence) + def _(arg): + return "sequence" + r = R() + self.assertEqual(i(r), "sequence") + class S: + pass + class T(S, c.Sized): + def __len__(self): + return 0 + t = T() + self.assertEqual(h(t), "sized") + c.Container.register(T) + self.assertEqual(h(t), "sized") # because it's explicitly in the MRO + class U: + def __len__(self): + return 0 + u = U() + self.assertEqual(h(u), "sized") # implicit Sized subclass inferred + # from the existence of __len__() + c.Container.register(U) + # There is no preference for registered versus inferred ABCs. + with self.assertRaises(RuntimeError) as re_three: + h(u) + self.assertIn( + str(re_three.exception), + (("Ambiguous dispatch: <class 'collections.abc.Container'> " + "or <class 'collections.abc.Sized'>"), + ("Ambiguous dispatch: <class 'collections.abc.Sized'> " + "or <class 'collections.abc.Container'>")), + ) + class V(c.Sized, S): + def __len__(self): + return 0 + @functools.singledispatch + def j(arg): + return "base" + @j.register(S) + def _(arg): + return "s" + @j.register(c.Container) + def _(arg): + return "container" + v = V() + self.assertEqual(j(v), "s") + c.Container.register(V) + self.assertEqual(j(v), "container") # because it ends up right after + # Sized in the MRO def test_cache_invalidation(self): from collections import UserDict |