summaryrefslogtreecommitdiffstats
path: root/Lib/dataclasses.py
diff options
context:
space:
mode:
authorArie Bovenberg <a.c.bovenberg@gmail.com>2022-03-19 21:01:17 (GMT)
committerGitHub <noreply@github.com>2022-03-19 21:01:17 (GMT)
commit82e9b0bb0ac44d4942b9e01b2cdd2ca85c17e563 (patch)
treec1cb2d3397dc3b907f8d19c7682a4703c4494d75 /Lib/dataclasses.py
parent383a3bec74f0bf0c1b1bef9e0048db389c618452 (diff)
downloadcpython-82e9b0bb0ac44d4942b9e01b2cdd2ca85c17e563.zip
cpython-82e9b0bb0ac44d4942b9e01b2cdd2ca85c17e563.tar.gz
cpython-82e9b0bb0ac44d4942b9e01b2cdd2ca85c17e563.tar.bz2
bpo-46382 dataclass(slots=True) now takes inherited slots into account (GH-31980)
Do not include any members in __slots__ that are already in a base class's __slots__.
Diffstat (limited to 'Lib/dataclasses.py')
-rw-r--r--Lib/dataclasses.py23
1 files changed, 22 insertions, 1 deletions
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py
index b327462..6be7c7b 100644
--- a/Lib/dataclasses.py
+++ b/Lib/dataclasses.py
@@ -6,6 +6,7 @@ import inspect
import keyword
import builtins
import functools
+import itertools
import abc
import _thread
from types import FunctionType, GenericAlias
@@ -1122,6 +1123,20 @@ def _dataclass_setstate(self, state):
object.__setattr__(self, field.name, value)
+def _get_slots(cls):
+ match cls.__dict__.get('__slots__'):
+ case None:
+ return
+ case str(slot):
+ yield slot
+ # Slots may be any iterable, but we cannot handle an iterator
+ # because it will already be (partially) consumed.
+ case iterable if not hasattr(iterable, '__next__'):
+ yield from iterable
+ case _:
+ raise TypeError(f"Slots of '{cls.__name__}' cannot be determined")
+
+
def _add_slots(cls, is_frozen):
# Need to create a new class, since we can't set __slots__
# after a class has been created.
@@ -1133,7 +1148,13 @@ def _add_slots(cls, is_frozen):
# Create a new dict for our new class.
cls_dict = dict(cls.__dict__)
field_names = tuple(f.name for f in fields(cls))
- cls_dict['__slots__'] = field_names
+ # Make sure slots don't overlap with those in base classes.
+ inherited_slots = set(
+ itertools.chain.from_iterable(map(_get_slots, cls.__mro__[1:-1]))
+ )
+ cls_dict["__slots__"] = tuple(
+ itertools.filterfalse(inherited_slots.__contains__, field_names)
+ )
for field_name in field_names:
# Remove our attributes, if present. They'll still be
# available in _MARKER.