summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/dataclasses.py120
-rwxr-xr-xLib/test/test_dataclasses.py99
-rw-r--r--Misc/NEWS.d/next/Library/2018-03-18-17-38-48.bpo-32953.t8WAWN.rst4
3 files changed, 168 insertions, 55 deletions
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py
index b55a497..8ab04dd 100644
--- a/Lib/dataclasses.py
+++ b/Lib/dataclasses.py
@@ -171,7 +171,11 @@ _FIELD_INITVAR = object() # Not a field, but an InitVar.
# The name of an attribute on the class where we store the Field
# objects. Also used to check if a class is a Data Class.
-_MARKER = '__dataclass_fields__'
+_FIELDS = '__dataclass_fields__'
+
+# The name of an attribute on the class that stores the parameters to
+# @dataclass.
+_PARAMS = '__dataclass_params__'
# The name of the function, that if it exists, is called at the end of
# __init__.
@@ -192,7 +196,7 @@ class InitVar(metaclass=_InitVarMeta):
# name and type are filled in after the fact, not in __init__. They're
# not known at the time this class is instantiated, but it's
# convenient if they're available later.
-# When cls._MARKER is filled in with a list of Field objects, the name
+# When cls._FIELDS is filled in with a list of Field objects, the name
# and type fields will have been populated.
class Field:
__slots__ = ('name',
@@ -236,6 +240,32 @@ class Field:
')')
+class _DataclassParams:
+ __slots__ = ('init',
+ 'repr',
+ 'eq',
+ 'order',
+ 'unsafe_hash',
+ 'frozen',
+ )
+ def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
+ self.init = init
+ self.repr = repr
+ self.eq = eq
+ self.order = order
+ self.unsafe_hash = unsafe_hash
+ self.frozen = frozen
+
+ def __repr__(self):
+ return ('_DataclassParams('
+ f'init={self.init},'
+ f'repr={self.repr},'
+ f'eq={self.eq},'
+ f'order={self.order},'
+ f'unsafe_hash={self.unsafe_hash},'
+ f'frozen={self.frozen}'
+ ')')
+
# This function is used instead of exposing Field creation directly,
# so that a type checker can be told (via overloads) that this is a
# function whose type depends on its parameters.
@@ -285,6 +315,7 @@ def _create_fn(name, args, body, *, globals=None, locals=None,
args = ','.join(args)
body = '\n'.join(f' {b}' for b in body)
+ # Compute the text of the entire function.
txt = f'def {name}({args}){return_annotation}:\n{body}'
exec(txt, globals, locals)
@@ -432,12 +463,29 @@ def _repr_fn(fields):
')"'])
-def _frozen_setattr(self, name, value):
- raise FrozenInstanceError(f'cannot assign to field {name!r}')
-
-
-def _frozen_delattr(self, name):
- raise FrozenInstanceError(f'cannot delete field {name!r}')
+def _frozen_get_del_attr(cls, fields):
+ # XXX: globals is modified on the first call to _create_fn, then the
+ # modified version is used in the second call. Is this okay?
+ globals = {'cls': cls,
+ 'FrozenInstanceError': FrozenInstanceError}
+ if fields:
+ fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
+ else:
+ # Special case for the zero-length tuple.
+ fields_str = '()'
+ return (_create_fn('__setattr__',
+ ('self', 'name', 'value'),
+ (f'if type(self) is cls or name in {fields_str}:',
+ ' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
+ f'super(cls, self).__setattr__(name, value)'),
+ globals=globals),
+ _create_fn('__delattr__',
+ ('self', 'name'),
+ (f'if type(self) is cls or name in {fields_str}:',
+ ' raise FrozenInstanceError(f"cannot delete field {name!r}")',
+ f'super(cls, self).__delattr__(name)'),
+ globals=globals),
+ )
def _cmp_fn(name, op, self_tuple, other_tuple):
@@ -583,23 +631,32 @@ _hash_action = {(False, False, False, False): (''),
# version of this table.
-def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
+def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
# Now that dicts retain insertion order, there's no reason to use
# an ordered dict. I am leveraging that ordering here, because
# derived class fields overwrite base class fields, but the order
# is defined by the base class, which is found first.
fields = {}
+ setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
+ unsafe_hash, frozen))
+
# Find our base classes in reverse MRO order, and exclude
# ourselves. In reversed order so that more derived classes
# override earlier field definitions in base classes.
+ # As long as we're iterating over them, see if any are frozen.
+ any_frozen_base = False
+ has_dataclass_bases = False
for b in cls.__mro__[-1:0:-1]:
# Only process classes that have been processed by our
- # decorator. That is, they have a _MARKER attribute.
- base_fields = getattr(b, _MARKER, None)
+ # decorator. That is, they have a _FIELDS attribute.
+ base_fields = getattr(b, _FIELDS, None)
if base_fields:
+ has_dataclass_bases = True
for f in base_fields.values():
fields[f.name] = f
+ if getattr(b, _PARAMS).frozen:
+ any_frozen_base = True
# Now find fields in our class. While doing so, validate some
# things, and set the default values (as class attributes)
@@ -623,20 +680,21 @@ def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
else:
setattr(cls, f.name, f.default)
- # We're inheriting from a frozen dataclass, but we're not frozen.
- if cls.__setattr__ is _frozen_setattr and not frozen:
- raise TypeError('cannot inherit non-frozen dataclass from a '
- 'frozen one')
+ # Check rules that apply if we are derived from any dataclasses.
+ if has_dataclass_bases:
+ # Raise an exception if any of our bases are frozen, but we're not.
+ if any_frozen_base and not frozen:
+ raise TypeError('cannot inherit non-frozen dataclass from a '
+ 'frozen one')
- # We're inheriting from a non-frozen dataclass, but we're frozen.
- if (hasattr(cls, _MARKER) and cls.__setattr__ is not _frozen_setattr
- and frozen):
- raise TypeError('cannot inherit frozen dataclass from a '
- 'non-frozen one')
+ # Raise an exception if we're frozen, but none of our bases are.
+ if not any_frozen_base and frozen:
+ raise TypeError('cannot inherit frozen dataclass from a '
+ 'non-frozen one')
- # Remember all of the fields on our class (including bases). This
+ # Remember all of the fields on our class (including bases). This also
# marks this class as being a dataclass.
- setattr(cls, _MARKER, fields)
+ setattr(cls, _FIELDS, fields)
# Was this class defined with an explicit __hash__? Note that if
# __eq__ is defined in this class, then python will automatically
@@ -704,10 +762,10 @@ def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
'functools.total_ordering')
if frozen:
- for name, fn in [('__setattr__', _frozen_setattr),
- ('__delattr__', _frozen_delattr)]:
- if _set_new_attribute(cls, name, fn):
- raise TypeError(f'Cannot overwrite attribute {name} '
+ # XXX: Which fields are frozen? InitVar? ClassVar? hashed-only?
+ for fn in _frozen_get_del_attr(cls, field_list):
+ if _set_new_attribute(cls, fn.__name__, fn):
+ raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
f'in class {cls.__name__}')
# Decide if/how we're going to create a hash function.
@@ -759,7 +817,7 @@ def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
"""
def wrap(cls):
- return _process_class(cls, repr, eq, order, unsafe_hash, init, frozen)
+ return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen)
# See if we're being called as @dataclass or @dataclass().
if _cls is None:
@@ -779,7 +837,7 @@ def fields(class_or_instance):
# Might it be worth caching this, per class?
try:
- fields = getattr(class_or_instance, _MARKER)
+ fields = getattr(class_or_instance, _FIELDS)
except AttributeError:
raise TypeError('must be called with a dataclass type or instance')
@@ -790,13 +848,13 @@ def fields(class_or_instance):
def _is_dataclass_instance(obj):
"""Returns True if obj is an instance of a dataclass."""
- return not isinstance(obj, type) and hasattr(obj, _MARKER)
+ return not isinstance(obj, type) and hasattr(obj, _FIELDS)
def is_dataclass(obj):
"""Returns True if obj is a dataclass or an instance of a
dataclass."""
- return hasattr(obj, _MARKER)
+ return hasattr(obj, _FIELDS)
def asdict(obj, *, dict_factory=dict):
@@ -953,7 +1011,7 @@ def replace(obj, **changes):
# It's an error to have init=False fields in 'changes'.
# If a field is not in 'changes', read its value from the provided obj.
- for f in getattr(obj, _MARKER).values():
+ for f in getattr(obj, _FIELDS).values():
if not f.init:
# Error if this field is specified in changes.
if f.name in changes:
diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py
index 46d485c..3e67263 100755
--- a/Lib/test/test_dataclasses.py
+++ b/Lib/test/test_dataclasses.py
@@ -2476,41 +2476,92 @@ class TestFrozen(unittest.TestCase):
d = D(0, 10)
with self.assertRaises(FrozenInstanceError):
d.i = 5
+ with self.assertRaises(FrozenInstanceError):
+ d.j = 6
self.assertEqual(d.i, 0)
+ self.assertEqual(d.j, 10)
+
+ # Test both ways: with an intermediate normal (non-dataclass)
+ # class and without an intermediate class.
+ def test_inherit_nonfrozen_from_frozen(self):
+ for intermediate_class in [True, False]:
+ with self.subTest(intermediate_class=intermediate_class):
+ @dataclass(frozen=True)
+ class C:
+ i: int
- def test_inherit_from_nonfrozen_from_frozen(self):
- @dataclass(frozen=True)
- class C:
- i: int
+ if intermediate_class:
+ class I(C): pass
+ else:
+ I = C
- with self.assertRaisesRegex(TypeError,
- 'cannot inherit non-frozen dataclass from a frozen one'):
- @dataclass
- class D(C):
- pass
+ with self.assertRaisesRegex(TypeError,
+ 'cannot inherit non-frozen dataclass from a frozen one'):
+ @dataclass
+ class D(I):
+ pass
- def test_inherit_from_frozen_from_nonfrozen(self):
- @dataclass
- class C:
- i: int
+ def test_inherit_frozen_from_nonfrozen(self):
+ for intermediate_class in [True, False]:
+ with self.subTest(intermediate_class=intermediate_class):
+ @dataclass
+ class C:
+ i: int
- with self.assertRaisesRegex(TypeError,
- 'cannot inherit frozen dataclass from a non-frozen one'):
- @dataclass(frozen=True)
- class D(C):
- pass
+ if intermediate_class:
+ class I(C): pass
+ else:
+ I = C
+
+ with self.assertRaisesRegex(TypeError,
+ 'cannot inherit frozen dataclass from a non-frozen one'):
+ @dataclass(frozen=True)
+ class D(I):
+ pass
def test_inherit_from_normal_class(self):
- class C:
- pass
+ for intermediate_class in [True, False]:
+ with self.subTest(intermediate_class=intermediate_class):
+ class C:
+ pass
+
+ if intermediate_class:
+ class I(C): pass
+ else:
+ I = C
+
+ @dataclass(frozen=True)
+ class D(I):
+ i: int
+
+ d = D(10)
+ with self.assertRaises(FrozenInstanceError):
+ d.i = 5
+
+ def test_non_frozen_normal_derived(self):
+ # See bpo-32953.
@dataclass(frozen=True)
- class D(C):
- i: int
+ class D:
+ x: int
+ y: int = 10
- d = D(10)
+ class S(D):
+ pass
+
+ s = S(3)
+ self.assertEqual(s.x, 3)
+ self.assertEqual(s.y, 10)
+ s.cached = True
+
+ # But can't change the frozen attributes.
with self.assertRaises(FrozenInstanceError):
- d.i = 5
+ s.x = 5
+ with self.assertRaises(FrozenInstanceError):
+ s.y = 5
+ self.assertEqual(s.x, 3)
+ self.assertEqual(s.y, 10)
+ self.assertEqual(s.cached, True)
if __name__ == '__main__':
diff --git a/Misc/NEWS.d/next/Library/2018-03-18-17-38-48.bpo-32953.t8WAWN.rst b/Misc/NEWS.d/next/Library/2018-03-18-17-38-48.bpo-32953.t8WAWN.rst
new file mode 100644
index 0000000..fbea34a
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2018-03-18-17-38-48.bpo-32953.t8WAWN.rst
@@ -0,0 +1,4 @@
+If a non-dataclass inherits from a frozen dataclass, allow attributes to be
+added to the derived class. Only attributes from from the frozen dataclass
+cannot be assigned to. Require all dataclasses in a hierarchy to be either
+all frozen or all non-frozen.