diff options
author | Eric V. Smith <ericvsmith@users.noreply.github.com> | 2018-01-28 00:07:40 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-01-28 00:07:40 (GMT) |
commit | ea8fc52e75363276db23c6a8d7a689f79efce4f9 (patch) | |
tree | ca662ba631df1f6e6e32b5b0d95a6b5458d5699c /Lib/dataclasses.py | |
parent | 2a2247ce5e1984eb2f2c41b269b38dbb795a60cf (diff) | |
download | cpython-ea8fc52e75363276db23c6a8d7a689f79efce4f9.zip cpython-ea8fc52e75363276db23c6a8d7a689f79efce4f9.tar.gz cpython-ea8fc52e75363276db23c6a8d7a689f79efce4f9.tar.bz2 |
bpo-32513: Make it easier to override dunders in dataclasses. (GH-5366)
Class authors no longer need to specify repr=False if they want to provide a custom __repr__ for dataclasses. The same applies to the other dunder methods that the dataclass decorator adds. If dataclass finds that a dunder method is defined in the class, it will not overwrite it.
Diffstat (limited to 'Lib/dataclasses.py')
-rw-r--r-- | Lib/dataclasses.py | 306 |
1 files changed, 224 insertions, 82 deletions
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 7d30da1..fb279cd 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -18,6 +18,142 @@ __all__ = ['dataclass', 'is_dataclass', ] +# Conditions for adding methods. The boxes indicate what action the +# dataclass decorator takes. For all of these tables, when I talk +# about init=, repr=, eq=, order=, hash=, or frozen=, I'm referring +# to the arguments to the @dataclass decorator. When checking if a +# dunder method already exists, I mean check for an entry in the +# class's __dict__. I never check to see if an attribute is defined +# in a base class. + +# Key: +# +=========+=========================================+ +# + Value | Meaning | +# +=========+=========================================+ +# | <blank> | No action: no method is added. | +# +---------+-----------------------------------------+ +# | add | Generated method is added. | +# +---------+-----------------------------------------+ +# | add* | Generated method is added only if the | +# | | existing attribute is None and if the | +# | | user supplied a __eq__ method in the | +# | | class definition. | +# +---------+-----------------------------------------+ +# | raise | TypeError is raised. | +# +---------+-----------------------------------------+ +# | None | Attribute is set to None. | +# +=========+=========================================+ + +# __init__ +# +# +--- init= parameter +# | +# v | | | +# | no | yes | <--- class has __init__ in __dict__? +# +=======+=======+=======+ +# | False | | | +# +-------+-------+-------+ +# | True | add | | <- the default +# +=======+=======+=======+ + +# __repr__ +# +# +--- repr= parameter +# | +# v | | | +# | no | yes | <--- class has __repr__ in __dict__? 
+# +=======+=======+=======+ +# | False | | | +# +-------+-------+-------+ +# | True | add | | <- the default +# +=======+=======+=======+ + + +# __setattr__ +# __delattr__ +# +# +--- frozen= parameter +# | +# v | | | +# | no | yes | <--- class has __setattr__ or __delattr__ in __dict__? +# +=======+=======+=======+ +# | False | | | <- the default +# +-------+-------+-------+ +# | True | add | raise | +# +=======+=======+=======+ +# Raise because not adding these methods would break the "frozen-ness" +# of the class. + +# __eq__ +# +# +--- eq= parameter +# | +# v | | | +# | no | yes | <--- class has __eq__ in __dict__? +# +=======+=======+=======+ +# | False | | | +# +-------+-------+-------+ +# | True | add | | <- the default +# +=======+=======+=======+ + +# __lt__ +# __le__ +# __gt__ +# __ge__ +# +# +--- order= parameter +# | +# v | | | +# | no | yes | <--- class has any comparison method in __dict__? +# +=======+=======+=======+ +# | False | | | <- the default +# +-------+-------+-------+ +# | True | add | raise | +# +=======+=======+=======+ +# Raise because to allow this case would interfere with using +# functools.total_ordering. + +# __hash__ + +# +------------------- hash= parameter +# | +----------- eq= parameter +# | | +--- frozen= parameter +# | | | +# v v v | | | +# | no | yes | <--- class has __hash__ in __dict__? 
+# +=========+=======+=======+========+========+ +# | 1 None | False | False | | | No __eq__, use the base class __hash__ +# +---------+-------+-------+--------+--------+ +# | 2 None | False | True | | | No __eq__, use the base class __hash__ +# +---------+-------+-------+--------+--------+ +# | 3 None | True | False | None | | <-- the default, not hashable +# +---------+-------+-------+--------+--------+ +# | 4 None | True | True | add | add* | Frozen, so hashable +# +---------+-------+-------+--------+--------+ +# | 5 False | False | False | | | +# +---------+-------+-------+--------+--------+ +# | 6 False | False | True | | | +# +---------+-------+-------+--------+--------+ +# | 7 False | True | False | | | +# +---------+-------+-------+--------+--------+ +# | 8 False | True | True | | | +# +---------+-------+-------+--------+--------+ +# | 9 True | False | False | add | add* | Has no __eq__, but hashable +# +---------+-------+-------+--------+--------+ +# |10 True | False | True | add | add* | Has no __eq__, but hashable +# +---------+-------+-------+--------+--------+ +# |11 True | True | False | add | add* | Not frozen, but hashable +# +---------+-------+-------+--------+--------+ +# |12 True | True | True | add | add* | Frozen, so hashable +# +=========+=======+=======+========+========+ +# For boxes that are blank, __hash__ is untouched and therefore +# inherited from the base class. If the base is object, then +# id-based hashing is used. +# Note that a class may have already __hash__=None if it specified an +# __eq__ method in the class body (not one that was created by +# @dataclass). + + # Raised when an attempt is made to modify a frozen class. class FrozenInstanceError(AttributeError): pass @@ -143,13 +279,13 @@ def _tuple_str(obj_name, fields): # return "(self.x,self.y)". # Special case for the 0-tuple. - if len(fields) == 0: + if not fields: return '()' # Note the trailing comma, needed if this turns out to be a 1-tuple. 
return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)' -def _create_fn(name, args, body, globals=None, locals=None, +def _create_fn(name, args, body, *, globals=None, locals=None, return_type=MISSING): # Note that we mutate locals when exec() is called. Caller beware! if locals is None: @@ -287,7 +423,7 @@ def _init_fn(fields, frozen, has_post_init, self_name): body_lines += [f'{self_name}.{_POST_INIT_NAME}({params_str})'] # If no body lines, use 'pass'. - if len(body_lines) == 0: + if not body_lines: body_lines = ['pass'] locals = {f'_type_{f.name}': f.type for f in fields} @@ -329,32 +465,6 @@ def _cmp_fn(name, op, self_tuple, other_tuple): 'return NotImplemented']) -def _set_eq_fns(cls, fields): - # Create and set the equality comparison methods on cls. - # Pre-compute self_tuple and other_tuple, then re-use them for - # each function. - self_tuple = _tuple_str('self', fields) - other_tuple = _tuple_str('other', fields) - for name, op in [('__eq__', '=='), - ('__ne__', '!='), - ]: - _set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple)) - - -def _set_order_fns(cls, fields): - # Create and set the ordering methods on cls. - # Pre-compute self_tuple and other_tuple, then re-use them for - # each function. - self_tuple = _tuple_str('self', fields) - other_tuple = _tuple_str('other', fields) - for name, op in [('__lt__', '<'), - ('__le__', '<='), - ('__gt__', '>'), - ('__ge__', '>='), - ]: - _set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple)) - - def _hash_fn(fields): self_tuple = _tuple_str('self', fields) return _create_fn('__hash__', @@ -431,20 +541,20 @@ def _find_fields(cls): # a Field(), then it contains additional info beyond (and # possibly including) the actual default value. Pseudo-fields # ClassVars and InitVars are included, despite the fact that - # they're not real fields. That's deal with later. + # they're not real fields. That's dealt with later. 
annotations = getattr(cls, '__annotations__', {}) - return [_get_field(cls, a_name, a_type) for a_name, a_type in annotations.items()] -def _set_attribute(cls, name, value): - # Raise TypeError if an attribute by this name already exists. +def _set_new_attribute(cls, name, value): + # Never overwrites an existing attribute. Returns True if the + # attribute already exists. if name in cls.__dict__: - raise TypeError(f'Cannot overwrite attribute {name} ' - f'in {cls.__name__}') + return True setattr(cls, name, value) + return False def _process_class(cls, repr, eq, order, hash, init, frozen): @@ -495,6 +605,9 @@ def _process_class(cls, repr, eq, order, hash, init, frozen): # be inherited down. is_frozen = frozen or cls.__setattr__ is _frozen_setattr + # Was this class defined with an __eq__? Used in __hash__ logic. + auto_hash_test= '__eq__' in cls.__dict__ and getattr(cls.__dict__, '__hash__', MISSING) is None + # If we're generating ordering methods, we must be generating # the eq methods. if order and not eq: @@ -505,62 +618,91 @@ def _process_class(cls, repr, eq, order, hash, init, frozen): has_post_init = hasattr(cls, _POST_INIT_NAME) # Include InitVars and regular fields (so, not ClassVars). - _set_attribute(cls, '__init__', - _init_fn(list(filter(lambda f: f._field_type - in (_FIELD, _FIELD_INITVAR), - fields.values())), - is_frozen, - has_post_init, - # The name to use for the "self" param - # in __init__. Use "self" if possible. - '__dataclass_self__' if 'self' in fields - else 'self', - )) + flds = [f for f in fields.values() + if f._field_type in (_FIELD, _FIELD_INITVAR)] + _set_new_attribute(cls, '__init__', + _init_fn(flds, + is_frozen, + has_post_init, + # The name to use for the "self" param + # in __init__. Use "self" if possible. + '__dataclass_self__' if 'self' in fields + else 'self', + )) # Get the fields as a list, and include only real fields. This is # used in all of the following methods. 
- field_list = list(filter(lambda f: f._field_type is _FIELD, - fields.values())) + field_list = [f for f in fields.values() if f._field_type is _FIELD] if repr: - _set_attribute(cls, '__repr__', - _repr_fn(list(filter(lambda f: f.repr, field_list)))) - - if is_frozen: - _set_attribute(cls, '__setattr__', _frozen_setattr) - _set_attribute(cls, '__delattr__', _frozen_delattr) - - generate_hash = False - if hash is None: - if eq and frozen: - # Generate a hash function. - generate_hash = True - elif eq and not frozen: - # Not hashable. - _set_attribute(cls, '__hash__', None) - elif not eq: - # Otherwise, use the base class definition of hash(). That is, - # don't set anything on this class. - pass - else: - assert "can't get here" - else: - generate_hash = hash - if generate_hash: - _set_attribute(cls, '__hash__', - _hash_fn(list(filter(lambda f: f.compare - if f.hash is None - else f.hash, - field_list)))) + flds = [f for f in field_list if f.repr] + _set_new_attribute(cls, '__repr__', _repr_fn(flds)) if eq: - # Create and __eq__ and __ne__ methods. - _set_eq_fns(cls, list(filter(lambda f: f.compare, field_list))) + # Create _eq__ method. There's no need for a __ne__ method, + # since python will call __eq__ and negate it. + flds = [f for f in field_list if f.compare] + self_tuple = _tuple_str('self', flds) + other_tuple = _tuple_str('other', flds) + _set_new_attribute(cls, '__eq__', + _cmp_fn('__eq__', '==', + self_tuple, other_tuple)) if order: - # Create and __lt__, __le__, __gt__, and __ge__ methods. - # Create and set the comparison functions. - _set_order_fns(cls, list(filter(lambda f: f.compare, field_list))) + # Create and set the ordering methods. 
+ flds = [f for f in field_list if f.compare] + self_tuple = _tuple_str('self', flds) + other_tuple = _tuple_str('other', flds) + for name, op in [('__lt__', '<'), + ('__le__', '<='), + ('__gt__', '>'), + ('__ge__', '>='), + ]: + if _set_new_attribute(cls, name, + _cmp_fn(name, op, self_tuple, other_tuple)): + raise TypeError(f'Cannot overwrite attribute {name} ' + f'in {cls.__name__}. Consider using ' + 'functools.total_ordering') + + if is_frozen: + for name, fn in [('__setattr__', _frozen_setattr), + ('__delattr__', _frozen_delattr)]: + if _set_new_attribute(cls, name, fn): + raise TypeError(f'Cannot overwrite attribute {name} ' + f'in {cls.__name__}') + + # Decide if/how we're going to create a hash function. + # TODO: Move this table to module scope, so it's not recreated + # all the time. + generate_hash = {(None, False, False): ('', ''), + (None, False, True): ('', ''), + (None, True, False): ('none', ''), + (None, True, True): ('fn', 'fn-x'), + (False, False, False): ('', ''), + (False, False, True): ('', ''), + (False, True, False): ('', ''), + (False, True, True): ('', ''), + (True, False, False): ('fn', 'fn-x'), + (True, False, True): ('fn', 'fn-x'), + (True, True, False): ('fn', 'fn-x'), + (True, True, True): ('fn', 'fn-x'), + }[None if hash is None else bool(hash), # Force bool() if not None. + bool(eq), + bool(frozen)]['__hash__' in cls.__dict__] + # No need to call _set_new_attribute here, since we already know if + # we're overwriting a __hash__ or not. + if generate_hash == '': + # Do nothing. + pass + elif generate_hash == 'none': + cls.__hash__ = None + elif generate_hash in ('fn', 'fn-x'): + if generate_hash == 'fn' or auto_hash_test: + flds = [f for f in field_list + if (f.compare if f.hash is None else f.hash)] + cls.__hash__ = _hash_fn(flds) + else: + assert False, f"can't get here: {generate_hash}" if not getattr(cls, '__doc__'): # Create a class doc-string. |