diff options
author | Arie Bovenberg <a.c.bovenberg@gmail.com> | 2022-03-19 21:01:17 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-03-19 21:01:17 (GMT) |
commit | 82e9b0bb0ac44d4942b9e01b2cdd2ca85c17e563 (patch) | |
tree | c1cb2d3397dc3b907f8d19c7682a4703c4494d75 /Lib/dataclasses.py | |
parent | 383a3bec74f0bf0c1b1bef9e0048db389c618452 (diff) | |
download | cpython-82e9b0bb0ac44d4942b9e01b2cdd2ca85c17e563.zip cpython-82e9b0bb0ac44d4942b9e01b2cdd2ca85c17e563.tar.gz cpython-82e9b0bb0ac44d4942b9e01b2cdd2ca85c17e563.tar.bz2 |
bpo-46382 dataclass(slots=True) now takes inherited slots into account (GH-31980)
Do not include any members in __slots__ that are already in a base class's __slots__.
Diffstat (limited to 'Lib/dataclasses.py')
-rw-r--r-- | Lib/dataclasses.py | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index b327462..6be7c7b 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -6,6 +6,7 @@ import inspect import keyword import builtins import functools +import itertools import abc import _thread from types import FunctionType, GenericAlias @@ -1122,6 +1123,20 @@ def _dataclass_setstate(self, state): object.__setattr__(self, field.name, value) +def _get_slots(cls): + match cls.__dict__.get('__slots__'): + case None: + return + case str(slot): + yield slot + # Slots may be any iterable, but we cannot handle an iterator + # because it will already be (partially) consumed. + case iterable if not hasattr(iterable, '__next__'): + yield from iterable + case _: + raise TypeError(f"Slots of '{cls.__name__}' cannot be determined") + + def _add_slots(cls, is_frozen): # Need to create a new class, since we can't set __slots__ # after a class has been created. @@ -1133,7 +1148,13 @@ def _add_slots(cls, is_frozen): # Create a new dict for our new class. cls_dict = dict(cls.__dict__) field_names = tuple(f.name for f in fields(cls)) - cls_dict['__slots__'] = field_names + # Make sure slots don't overlap with those in base classes. + inherited_slots = set( + itertools.chain.from_iterable(map(_get_slots, cls.__mro__[1:-1])) + ) + cls_dict["__slots__"] = tuple( + itertools.filterfalse(inherited_slots.__contains__, field_names) + ) for field_name in field_names: # Remove our attributes, if present. They'll still be # available in _MARKER. |