diff options
-rw-r--r-- | Lib/test/test_dataclasses.py | 22 | ||||
-rw-r--r-- | Lib/typing.py | 36 | ||||
-rw-r--r-- | Misc/NEWS.d/next/Library/2021-09-02-12-42-25.bpo-45081.tOjJ1k.rst | 2 |
3 files changed, 47 insertions, 13 deletions
diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py index 8e645ae..33c9fcd 100644 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses.py @@ -10,7 +10,7 @@ import inspect import builtins import unittest from unittest.mock import Mock -from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional +from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol from typing import get_type_hints from collections import deque, OrderedDict, namedtuple from functools import total_ordering @@ -2150,6 +2150,26 @@ class TestInit(unittest.TestCase): self.x = 2 * x self.assertEqual(C(5).x, 10) + def test_inherit_from_protocol(self): + # Dataclasses inheriting from protocol should preserve their own `__init__`. + # See bpo-45081. + + class P(Protocol): + a: int + + @dataclass + class C(P): + a: int + + self.assertEqual(C(5).a, 5) + + @dataclass + class D(P): + def __init__(self, a): + self.a = a * 2 + + self.assertEqual(D(5).a, 10) + class TestRepr(unittest.TestCase): def test_repr(self): diff --git a/Lib/typing.py b/Lib/typing.py index 35c57c2..892f1b3 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -1400,8 +1400,29 @@ def _is_callable_members_only(cls): return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls)) -def _no_init(self, *args, **kwargs): - raise TypeError('Protocols cannot be instantiated') +def _no_init_or_replace_init(self, *args, **kwargs): + cls = type(self) + + if cls._is_protocol: + raise TypeError('Protocols cannot be instantiated') + + # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`. + # The first instantiation of the subclass will call `_no_init_or_replace_init` which + # searches for a proper new `__init__` in the MRO. The new `__init__` + # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent + # instantiation of the protocol subclass will thus use the new + # `__init__` and no longer call `_no_init_or_replace_init`. + for base in cls.__mro__: + init = base.__dict__.get('__init__', _no_init_or_replace_init) + if init is not _no_init_or_replace_init: + cls.__init__ = init + break + else: + # should not happen + cls.__init__ = object.__init__ + + cls.__init__(self, *args, **kwargs) + def _caller(depth=1, default='__main__'): try: @@ -1541,15 +1562,6 @@ class Protocol(Generic, metaclass=_ProtocolMeta): # We have nothing more to do for non-protocols... if not cls._is_protocol: - if cls.__init__ == _no_init: - for base in cls.__mro__: - init = base.__dict__.get('__init__', _no_init) - if init != _no_init: - cls.__init__ = init - break - else: - # should not happen - cls.__init__ = object.__init__ return # ... otherwise check consistency of bases, and prohibit instantiation. @@ -1560,7 +1572,7 @@ class Protocol(Generic, metaclass=_ProtocolMeta): issubclass(base, Generic) and base._is_protocol): raise TypeError('Protocols can only inherit from other' ' protocols, got %r' % base) - cls.__init__ = _no_init + cls.__init__ = _no_init_or_replace_init class _AnnotatedAlias(_GenericAlias, _root=True): diff --git a/Misc/NEWS.d/next/Library/2021-09-02-12-42-25.bpo-45081.tOjJ1k.rst b/Misc/NEWS.d/next/Library/2021-09-02-12-42-25.bpo-45081.tOjJ1k.rst new file mode 100644 index 0000000..86d7182 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-09-02-12-42-25.bpo-45081.tOjJ1k.rst @@ -0,0 +1,2 @@ +Fix issue when dataclasses that inherit from ``typing.Protocol`` subclasses +have wrong ``__init__``. Patch provided by Yurii Karabas. |