diff options
author | Jelle Zijlstra <jelle.zijlstra@gmail.com> | 2022-04-16 16:01:43 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-16 16:01:43 (GMT) |
commit | 055760ed9e745a3104acbfa8a3b76eb26a72590d (patch) | |
tree | 8e846d627b70ca2e746aa1c6a8f6ded2988c086b /Lib | |
parent | 9300b6d72948b94c0924a75ea14c6298156522d0 (diff) | |
download | cpython-055760ed9e745a3104acbfa8a3b76eb26a72590d.zip cpython-055760ed9e745a3104acbfa8a3b76eb26a72590d.tar.gz cpython-055760ed9e745a3104acbfa8a3b76eb26a72590d.tar.bz2 |
gh-89263: Add typing.get_overloads (GH-31716)
Based on suggestions by Guido van Rossum, Spencer Brown, and Alex Waygood.
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Guido van Rossum <gvanrossum@gmail.com>
Co-authored-by: Ken Jin <kenjin4096@gmail.com>
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/test/test_typing.py | 72 | ||||
-rw-r--r-- | Lib/typing.py | 34 |
2 files changed, 102 insertions, 4 deletions
diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index ffd0592..d480847 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -1,5 +1,6 @@ import contextlib import collections +from collections import defaultdict from functools import lru_cache import inspect import pickle @@ -7,9 +8,11 @@ import re import sys import warnings from unittest import TestCase, main, skipUnless, skip +from unittest.mock import patch from copy import copy, deepcopy from typing import Any, NoReturn, Never, assert_never +from typing import overload, get_overloads, clear_overloads from typing import TypeVar, TypeVarTuple, Unpack, AnyStr from typing import T, KT, VT # Not in __all__. from typing import Union, Optional, Literal @@ -3890,11 +3893,22 @@ class ForwardRefTests(BaseTestCase): self.assertEqual("x" | X, Union["x", X]) +@lru_cache() +def cached_func(x, y): + return 3 * x + y + + +class MethodHolder: + @classmethod + def clsmethod(cls): ... + @staticmethod + def stmethod(): ... + def method(self): ... + + class OverloadTests(BaseTestCase): def test_overload_fails(self): - from typing import overload - with self.assertRaises(RuntimeError): @overload @@ -3904,8 +3918,6 @@ class OverloadTests(BaseTestCase): blah() def test_overload_succeeds(self): - from typing import overload - @overload def blah(): pass @@ -3915,6 +3927,58 @@ class OverloadTests(BaseTestCase): blah() + def set_up_overloads(self): + def blah(): + pass + + overload1 = blah + overload(blah) + + def blah(): + pass + + overload2 = blah + overload(blah) + + def blah(): + pass + + return blah, [overload1, overload2] + + # Make sure we don't clear the global overload registry + @patch("typing._overload_registry", + defaultdict(lambda: defaultdict(dict))) + def test_overload_registry(self): + # The registry starts out empty + self.assertEqual(typing._overload_registry, {}) + + impl, overloads = self.set_up_overloads() + self.assertNotEqual(typing._overload_registry, {}) + self.assertEqual(list(get_overloads(impl)), overloads) + + def some_other_func(): pass + overload(some_other_func) + other_overload = some_other_func + def some_other_func(): pass + self.assertEqual(list(get_overloads(some_other_func)), [other_overload]) + + # Make sure that after we clear all overloads, the registry is + # completely empty. + clear_overloads() + self.assertEqual(typing._overload_registry, {}) + self.assertEqual(get_overloads(impl), []) + + # Querying a function with no overloads shouldn't change the registry. + def the_only_one(): pass + self.assertEqual(get_overloads(the_only_one), []) + self.assertEqual(typing._overload_registry, {}) + + def test_overload_registry_repeated(self): + for _ in range(2): + impl, overloads = self.set_up_overloads() + + self.assertEqual(list(get_overloads(impl)), overloads) + # Definitions needed for features introduced in Python 3.6 diff --git a/Lib/typing.py b/Lib/typing.py index b26adc6..3e0fbdb 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -21,6 +21,7 @@ At large scale, the structure of the module is following: from abc import abstractmethod, ABCMeta import collections +from collections import defaultdict import collections.abc import contextlib import functools @@ -121,9 +122,11 @@ __all__ = [ 'assert_type', 'assert_never', 'cast', + 'clear_overloads', 'final', 'get_args', 'get_origin', + 'get_overloads', 'get_type_hints', 'is_typeddict', 'LiteralString', @@ -2450,6 +2453,10 @@ def _overload_dummy(*args, **kwds): "by an implementation that is not @overload-ed.") +# {module: {qualname: {firstlineno: func}}} +_overload_registry = defaultdict(functools.partial(defaultdict, dict)) + + def overload(func): """Decorator for overloaded functions/methods. @@ -2475,10 +2482,37 @@ def overload(func): def utf8(value: str) -> bytes: ... def utf8(value): # implementation goes here + + The overloads for a function can be retrieved at runtime using the + get_overloads() function. """ + # classmethod and staticmethod + f = getattr(func, "__func__", func) + try: + _overload_registry[f.__module__][f.__qualname__][f.__code__.co_firstlineno] = func + except AttributeError: + # Not a normal function; ignore. + pass return _overload_dummy +def get_overloads(func): + """Return all defined overloads for *func* as a sequence.""" + # classmethod and staticmethod + f = getattr(func, "__func__", func) + if f.__module__ not in _overload_registry: + return [] + mod_dict = _overload_registry[f.__module__] + if f.__qualname__ not in mod_dict: + return [] + return list(mod_dict[f.__qualname__].values()) + + +def clear_overloads(): + """Clear all overloads in the registry.""" + _overload_registry.clear() + + def final(f): """A decorator to indicate final methods and final classes. |