diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/functools.py | 21 | ||||
-rw-r--r-- | Lib/test/test_functools.py | 68 |
2 files changed, 87 insertions, 2 deletions
diff --git a/Lib/functools.py b/Lib/functools.py index 8518450..2a8a69b 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -928,11 +928,14 @@ class singledispatchmethod: """ def __init__(self, func): + import weakref # see comment in singledispatch function if not callable(func) and not hasattr(func, "__get__"): raise TypeError(f"{func!r} is not callable or a descriptor") self.dispatcher = singledispatch(func) self.func = func + self._method_cache = weakref.WeakKeyDictionary() + self._all_weakrefable_instances = True def register(self, cls, method=None): """generic_method.register(cls, func) -> func @@ -942,13 +945,27 @@ class singledispatchmethod: return self.dispatcher.register(cls, func=method) def __get__(self, obj, cls=None): + if self._all_weakrefable_instances: + try: + _method = self._method_cache[obj] + except TypeError: + self._all_weakrefable_instances = False + except KeyError: + pass + else: + return _method + + dispatch = self.dispatcher.dispatch def _method(*args, **kwargs): - method = self.dispatcher.dispatch(args[0].__class__) - return method.__get__(obj, cls)(*args, **kwargs) + return dispatch(args[0].__class__).__get__(obj, cls)(*args, **kwargs) _method.__isabstractmethod__ = self.__isabstractmethod__ _method.register = self.register update_wrapper(_method, self.func) + + if self._all_weakrefable_instances: + self._method_cache[obj] = _method + return _method @property diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index c4eca0f..50770f0 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -2474,6 +2474,74 @@ class TestSingleDispatch(unittest.TestCase): self.assertTrue(A.t('')) self.assertEqual(A.t(0.0), 0.0) + def test_slotted_class(self): + class Slot: + __slots__ = ('a', 'b') + @functools.singledispatchmethod + def go(self, item, arg): + pass + + @go.register + def _(self, item: int, arg): + return item + arg + + s = Slot() + self.assertEqual(s.go(1, 1), 2) + + def test_classmethod_slotted_class(self): + class Slot: + __slots__ = ('a', 'b') + @functools.singledispatchmethod + @classmethod + def go(cls, item, arg): + pass + + @go.register + @classmethod + def _(cls, item: int, arg): + return item + arg + + s = Slot() + self.assertEqual(s.go(1, 1), 2) + self.assertEqual(Slot.go(1, 1), 2) + + def test_staticmethod_slotted_class(self): + class A: + __slots__ = ['a'] + @functools.singledispatchmethod + @staticmethod + def t(arg): + return arg + @t.register(int) + @staticmethod + def _(arg): + return isinstance(arg, int) + @t.register(str) + @staticmethod + def _(arg): + return isinstance(arg, str) + a = A() + + self.assertTrue(A.t(0)) + self.assertTrue(A.t('')) + self.assertEqual(A.t(0.0), 0.0) + self.assertTrue(a.t(0)) + self.assertTrue(a.t('')) + self.assertEqual(a.t(0.0), 0.0) + + def test_assignment_behavior(self): + # see gh-106448 + class A: + @functools.singledispatchmethod + def t(arg): + return arg + + a = A() + a.t.foo = 'bar' + a2 = A() + with self.assertRaises(AttributeError): + a2.t.foo + def test_classmethod_register(self): class A: def __init__(self, arg): |