diff options
-rw-r--r-- | Lib/dataclasses.py | 306 | ||||
-rwxr-xr-x | Lib/test/test_dataclasses.py | 664 | ||||
-rw-r--r-- | Misc/NEWS.d/next/Library/2018-01-27-11-20-16.bpo-32513.ak-iD2.rst | 2 |
3 files changed, 678 insertions, 294 deletions
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 7d30da1..fb279cd 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -18,6 +18,142 @@ __all__ = ['dataclass', 'is_dataclass', ] +# Conditions for adding methods. The boxes indicate what action the +# dataclass decorator takes. For all of these tables, when I talk +# about init=, repr=, eq=, order=, hash=, or frozen=, I'm referring +# to the arguments to the @dataclass decorator. When checking if a +# dunder method already exists, I mean check for an entry in the +# class's __dict__. I never check to see if an attribute is defined +# in a base class. + +# Key: +# +=========+=========================================+ +# + Value | Meaning | +# +=========+=========================================+ +# | <blank> | No action: no method is added. | +# +---------+-----------------------------------------+ +# | add | Generated method is added. | +# +---------+-----------------------------------------+ +# | add* | Generated method is added only if the | +# | | existing attribute is None and if the | +# | | user supplied a __eq__ method in the | +# | | class definition. | +# +---------+-----------------------------------------+ +# | raise | TypeError is raised. | +# +---------+-----------------------------------------+ +# | None | Attribute is set to None. | +# +=========+=========================================+ + +# __init__ +# +# +--- init= parameter +# | +# v | | | +# | no | yes | <--- class has __init__ in __dict__? +# +=======+=======+=======+ +# | False | | | +# +-------+-------+-------+ +# | True | add | | <- the default +# +=======+=======+=======+ + +# __repr__ +# +# +--- repr= parameter +# | +# v | | | +# | no | yes | <--- class has __repr__ in __dict__? +# +=======+=======+=======+ +# | False | | | +# +-------+-------+-------+ +# | True | add | | <- the default +# +=======+=======+=======+ + + +# __setattr__ +# __delattr__ +# +# +--- frozen= parameter +# | +# v | | | +# | no | yes | <--- class has __setattr__ or __delattr__ in __dict__? +# +=======+=======+=======+ +# | False | | | <- the default +# +-------+-------+-------+ +# | True | add | raise | +# +=======+=======+=======+ +# Raise because not adding these methods would break the "frozen-ness" +# of the class. + +# __eq__ +# +# +--- eq= parameter +# | +# v | | | +# | no | yes | <--- class has __eq__ in __dict__? +# +=======+=======+=======+ +# | False | | | +# +-------+-------+-------+ +# | True | add | | <- the default +# +=======+=======+=======+ + +# __lt__ +# __le__ +# __gt__ +# __ge__ +# +# +--- order= parameter +# | +# v | | | +# | no | yes | <--- class has any comparison method in __dict__? +# +=======+=======+=======+ +# | False | | | <- the default +# +-------+-------+-------+ +# | True | add | raise | +# +=======+=======+=======+ +# Raise because to allow this case would interfere with using +# functools.total_ordering. + +# __hash__ + +# +------------------- hash= parameter +# | +----------- eq= parameter +# | | +--- frozen= parameter +# | | | +# v v v | | | +# | no | yes | <--- class has __hash__ in __dict__? +# +=========+=======+=======+========+========+ +# | 1 None | False | False | | | No __eq__, use the base class __hash__ +# +---------+-------+-------+--------+--------+ +# | 2 None | False | True | | | No __eq__, use the base class __hash__ +# +---------+-------+-------+--------+--------+ +# | 3 None | True | False | None | | <-- the default, not hashable +# +---------+-------+-------+--------+--------+ +# | 4 None | True | True | add | add* | Frozen, so hashable +# +---------+-------+-------+--------+--------+ +# | 5 False | False | False | | | +# +---------+-------+-------+--------+--------+ +# | 6 False | False | True | | | +# +---------+-------+-------+--------+--------+ +# | 7 False | True | False | | | +# +---------+-------+-------+--------+--------+ +# | 8 False | True | True | | | +# +---------+-------+-------+--------+--------+ +# | 9 True | False | False | add | add* | Has no __eq__, but hashable +# +---------+-------+-------+--------+--------+ +# |10 True | False | True | add | add* | Has no __eq__, but hashable +# +---------+-------+-------+--------+--------+ +# |11 True | True | False | add | add* | Not frozen, but hashable +# +---------+-------+-------+--------+--------+ +# |12 True | True | True | add | add* | Frozen, so hashable +# +=========+=======+=======+========+========+ +# For boxes that are blank, __hash__ is untouched and therefore +# inherited from the base class. If the base is object, then +# id-based hashing is used. +# Note that a class may have already __hash__=None if it specified an +# __eq__ method in the class body (not one that was created by +# @dataclass). + + # Raised when an attempt is made to modify a frozen class. class FrozenInstanceError(AttributeError): pass @@ -143,13 +279,13 @@ def _tuple_str(obj_name, fields): # return "(self.x,self.y)". # Special case for the 0-tuple. - if len(fields) == 0: + if not fields: return '()' # Note the trailing comma, needed if this turns out to be a 1-tuple. return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)' -def _create_fn(name, args, body, globals=None, locals=None, +def _create_fn(name, args, body, *, globals=None, locals=None, return_type=MISSING): # Note that we mutate locals when exec() is called. Caller beware! if locals is None: @@ -287,7 +423,7 @@ def _init_fn(fields, frozen, has_post_init, self_name): body_lines += [f'{self_name}.{_POST_INIT_NAME}({params_str})'] # If no body lines, use 'pass'. - if len(body_lines) == 0: + if not body_lines: body_lines = ['pass'] locals = {f'_type_{f.name}': f.type for f in fields} @@ -329,32 +465,6 @@ def _cmp_fn(name, op, self_tuple, other_tuple): 'return NotImplemented']) -def _set_eq_fns(cls, fields): - # Create and set the equality comparison methods on cls. - # Pre-compute self_tuple and other_tuple, then re-use them for - # each function. - self_tuple = _tuple_str('self', fields) - other_tuple = _tuple_str('other', fields) - for name, op in [('__eq__', '=='), - ('__ne__', '!='), - ]: - _set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple)) - - -def _set_order_fns(cls, fields): - # Create and set the ordering methods on cls. - # Pre-compute self_tuple and other_tuple, then re-use them for - # each function. - self_tuple = _tuple_str('self', fields) - other_tuple = _tuple_str('other', fields) - for name, op in [('__lt__', '<'), - ('__le__', '<='), - ('__gt__', '>'), - ('__ge__', '>='), - ]: - _set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple)) - - def _hash_fn(fields): self_tuple = _tuple_str('self', fields) return _create_fn('__hash__', @@ -431,20 +541,20 @@ def _find_fields(cls): # a Field(), then it contains additional info beyond (and # possibly including) the actual default value. Pseudo-fields # ClassVars and InitVars are included, despite the fact that - # they're not real fields. That's deal with later. + # they're not real fields. That's dealt with later. annotations = getattr(cls, '__annotations__', {}) - return [_get_field(cls, a_name, a_type) for a_name, a_type in annotations.items()] -def _set_attribute(cls, name, value): - # Raise TypeError if an attribute by this name already exists. +def _set_new_attribute(cls, name, value): + # Never overwrites an existing attribute. Returns True if the + # attribute already exists. if name in cls.__dict__: - raise TypeError(f'Cannot overwrite attribute {name} ' - f'in {cls.__name__}') + return True setattr(cls, name, value) + return False def _process_class(cls, repr, eq, order, hash, init, frozen): @@ -495,6 +605,9 @@ def _process_class(cls, repr, eq, order, hash, init, frozen): # be inherited down. is_frozen = frozen or cls.__setattr__ is _frozen_setattr + # Was this class defined with an __eq__? Used in __hash__ logic. + auto_hash_test= '__eq__' in cls.__dict__ and getattr(cls.__dict__, '__hash__', MISSING) is None + # If we're generating ordering methods, we must be generating # the eq methods. if order and not eq: @@ -505,62 +618,91 @@ def _process_class(cls, repr, eq, order, hash, init, frozen): has_post_init = hasattr(cls, _POST_INIT_NAME) # Include InitVars and regular fields (so, not ClassVars). - _set_attribute(cls, '__init__', - _init_fn(list(filter(lambda f: f._field_type - in (_FIELD, _FIELD_INITVAR), - fields.values())), - is_frozen, - has_post_init, - # The name to use for the "self" param - # in __init__. Use "self" if possible. - '__dataclass_self__' if 'self' in fields - else 'self', - )) + flds = [f for f in fields.values() + if f._field_type in (_FIELD, _FIELD_INITVAR)] + _set_new_attribute(cls, '__init__', + _init_fn(flds, + is_frozen, + has_post_init, + # The name to use for the "self" param + # in __init__. Use "self" if possible. + '__dataclass_self__' if 'self' in fields + else 'self', + )) # Get the fields as a list, and include only real fields. This is # used in all of the following methods. - field_list = list(filter(lambda f: f._field_type is _FIELD, - fields.values())) + field_list = [f for f in fields.values() if f._field_type is _FIELD] if repr: - _set_attribute(cls, '__repr__', - _repr_fn(list(filter(lambda f: f.repr, field_list)))) - - if is_frozen: - _set_attribute(cls, '__setattr__', _frozen_setattr) - _set_attribute(cls, '__delattr__', _frozen_delattr) - - generate_hash = False - if hash is None: - if eq and frozen: - # Generate a hash function. - generate_hash = True - elif eq and not frozen: - # Not hashable. - _set_attribute(cls, '__hash__', None) - elif not eq: - # Otherwise, use the base class definition of hash(). That is, - # don't set anything on this class. - pass - else: - assert "can't get here" - else: - generate_hash = hash - if generate_hash: - _set_attribute(cls, '__hash__', - _hash_fn(list(filter(lambda f: f.compare - if f.hash is None - else f.hash, - field_list)))) + flds = [f for f in field_list if f.repr] + _set_new_attribute(cls, '__repr__', _repr_fn(flds)) if eq: - # Create and __eq__ and __ne__ methods. - _set_eq_fns(cls, list(filter(lambda f: f.compare, field_list))) + # Create _eq__ method. There's no need for a __ne__ method, + # since python will call __eq__ and negate it. + flds = [f for f in field_list if f.compare] + self_tuple = _tuple_str('self', flds) + other_tuple = _tuple_str('other', flds) + _set_new_attribute(cls, '__eq__', + _cmp_fn('__eq__', '==', + self_tuple, other_tuple)) if order: - # Create and __lt__, __le__, __gt__, and __ge__ methods. - # Create and set the comparison functions. - _set_order_fns(cls, list(filter(lambda f: f.compare, field_list))) + # Create and set the ordering methods. + flds = [f for f in field_list if f.compare] + self_tuple = _tuple_str('self', flds) + other_tuple = _tuple_str('other', flds) + for name, op in [('__lt__', '<'), + ('__le__', '<='), + ('__gt__', '>'), + ('__ge__', '>='), + ]: + if _set_new_attribute(cls, name, + _cmp_fn(name, op, self_tuple, other_tuple)): + raise TypeError(f'Cannot overwrite attribute {name} ' + f'in {cls.__name__}. Consider using ' + 'functools.total_ordering') + + if is_frozen: + for name, fn in [('__setattr__', _frozen_setattr), + ('__delattr__', _frozen_delattr)]: + if _set_new_attribute(cls, name, fn): + raise TypeError(f'Cannot overwrite attribute {name} ' + f'in {cls.__name__}') + + # Decide if/how we're going to create a hash function. + # TODO: Move this table to module scope, so it's not recreated + # all the time. + generate_hash = {(None, False, False): ('', ''), + (None, False, True): ('', ''), + (None, True, False): ('none', ''), + (None, True, True): ('fn', 'fn-x'), + (False, False, False): ('', ''), + (False, False, True): ('', ''), + (False, True, False): ('', ''), + (False, True, True): ('', ''), + (True, False, False): ('fn', 'fn-x'), + (True, False, True): ('fn', 'fn-x'), + (True, True, False): ('fn', 'fn-x'), + (True, True, True): ('fn', 'fn-x'), + }[None if hash is None else bool(hash), # Force bool() if not None. + bool(eq), + bool(frozen)]['__hash__' in cls.__dict__] + # No need to call _set_new_attribute here, since we already know if + # we're overwriting a __hash__ or not. + if generate_hash == '': + # Do nothing. + pass + elif generate_hash == 'none': + cls.__hash__ = None + elif generate_hash in ('fn', 'fn-x'): + if generate_hash == 'fn' or auto_hash_test: + flds = [f for f in field_list + if (f.compare if f.hash is None else f.hash)] + cls.__hash__ = _hash_fn(flds) + else: + assert False, f"can't get here: {generate_hash}" if not getattr(cls, '__doc__'): # Create a class doc-string. diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py index 69819ea..53281f9 100755 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses.py @@ -9,6 +9,7 @@ import unittest from unittest.mock import Mock from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar from collections import deque, OrderedDict, namedtuple +from functools import total_ordering # Just any custom exception we can catch. class CustomError(Exception): pass @@ -82,68 +83,12 @@ class TestCase(unittest.TestCase): class C(B): x: int = 0 - def test_overwriting_init(self): - with self.assertRaisesRegex(TypeError, - 'Cannot overwrite attribute __init__ ' - 'in C'): - @dataclass - class C: - x: int - def __init__(self, x): - self.x = 2 * x - - @dataclass(init=False) - class C: - x: int - def __init__(self, x): - self.x = 2 * x - self.assertEqual(C(5).x, 10) - - def test_overwriting_repr(self): - with self.assertRaisesRegex(TypeError, - 'Cannot overwrite attribute __repr__ ' - 'in C'): - @dataclass - class C: - x: int - def __repr__(self): - pass - - @dataclass(repr=False) - class C: - x: int - def __repr__(self): - return 'x' - self.assertEqual(repr(C(0)), 'x') - - def test_overwriting_cmp(self): - with self.assertRaisesRegex(TypeError, - 'Cannot overwrite attribute __eq__ ' - 'in C'): - # This will generate the comparison functions, make sure we can't - # overwrite them. - @dataclass(hash=False, frozen=False) - class C: - x: int - def __eq__(self): - pass - - @dataclass(order=False, eq=False) + def test_overwriting_hash(self): + @dataclass(frozen=True) class C: x: int - def __eq__(self, other): - return True - self.assertEqual(C(0), 'x') - - def test_overwriting_hash(self): - with self.assertRaisesRegex(TypeError, - 'Cannot overwrite attribute __hash__ ' - 'in C'): - @dataclass(frozen=True) - class C: - x: int - def __hash__(self): - pass + def __hash__(self): + pass @dataclass(frozen=True,hash=False) class C: @@ -152,14 +97,11 @@ class TestCase(unittest.TestCase): return 600 self.assertEqual(hash(C(0)), 600) - with self.assertRaisesRegex(TypeError, - 'Cannot overwrite attribute __hash__ ' - 'in C'): - @dataclass(frozen=True) - class C: - x: int - def __hash__(self): - pass + @dataclass(frozen=True) + class C: + x: int + def __hash__(self): + pass @dataclass(frozen=True, hash=False) class C: @@ -168,33 +110,6 @@ class TestCase(unittest.TestCase): return 600 self.assertEqual(hash(C(0)), 600) - def test_overwriting_frozen(self): - # frozen uses __setattr__ and __delattr__ - with self.assertRaisesRegex(TypeError, - 'Cannot overwrite attribute __setattr__ ' - 'in C'): - @dataclass(frozen=True) - class C: - x: int - def __setattr__(self): - pass - - with self.assertRaisesRegex(TypeError, - 'Cannot overwrite attribute __delattr__ ' - 'in C'): - @dataclass(frozen=True) - class C: - x: int - def __delattr__(self): - pass - - @dataclass(frozen=False) - class C: - x: int - def __setattr__(self, name, value): - self.__dict__['x'] = value * 2 - self.assertEqual(C(10).x, 20) - def test_overwrite_fields_in_derived_class(self): # Note that x from C1 replaces x in Base, but the order remains # the same as defined in Base. @@ -239,34 +154,6 @@ class TestCase(unittest.TestCase): first = next(iter(sig.parameters)) self.assertEqual('self', first) - def test_repr(self): - @dataclass - class B: - x: int - - @dataclass - class C(B): - y: int = 10 - - o = C(4) - self.assertEqual(repr(o), 'TestCase.test_repr.<locals>.C(x=4, y=10)') - - @dataclass - class D(C): - x: int = 20 - self.assertEqual(repr(D()), 'TestCase.test_repr.<locals>.D(x=20, y=10)') - - @dataclass - class C: - @dataclass - class D: - i: int - @dataclass - class E: - pass - self.assertEqual(repr(C.D(0)), 'TestCase.test_repr.<locals>.C.D(i=0)') - self.assertEqual(repr(C.E()), 'TestCase.test_repr.<locals>.C.E()') - def test_0_field_compare(self): # Ensure that order=False is the default. @dataclass @@ -420,80 +307,8 @@ class TestCase(unittest.TestCase): self.assertEqual(hash(C(4)), hash((4,))) self.assertEqual(hash(C(42)), hash((42,))) - def test_hash(self): - @dataclass(hash=True) - class C: - x: int - y: str - self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo'))) - - def test_no_hash(self): - @dataclass(hash=None) - class C: - x: int - with self.assertRaisesRegex(TypeError, - "unhashable type: 'C'"): - hash(C(1)) - - def test_hash_rules(self): - # There are 24 cases of: - # hash=True/False/None - # eq=True/False - # order=True/False - # frozen=True/False - for (hash, eq, order, frozen, result ) in [ - (False, False, False, False, 'absent'), - (False, False, False, True, 'absent'), - (False, False, True, False, 'exception'), - (False, False, True, True, 'exception'), - (False, True, False, False, 'absent'), - (False, True, False, True, 'absent'), - (False, True, True, False, 'absent'), - (False, True, True, True, 'absent'), - (True, False, False, False, 'fn'), - (True, False, False, True, 'fn'), - (True, False, True, False, 'exception'), - (True, False, True, True, 'exception'), - (True, True, False, False, 'fn'), - (True, True, False, True, 'fn'), - (True, True, True, False, 'fn'), - (True, True, True, True, 'fn'), - (None, False, False, False, 'absent'), - (None, False, False, True, 'absent'), - (None, False, True, False, 'exception'), - (None, False, True, True, 'exception'), - (None, True, False, False, 'none'), - (None, True, False, True, 'fn'), - (None, True, True, False, 'none'), - (None, True, True, True, 'fn'), - ]: - with self.subTest(hash=hash, eq=eq, order=order, frozen=frozen): - if result == 'exception': - with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'): - @dataclass(hash=hash, eq=eq, order=order, frozen=frozen) - class C: - pass - else: - @dataclass(hash=hash, eq=eq, order=order, frozen=frozen) - class C: - pass - - # See if the result matches what's expected. - if result == 'fn': - # __hash__ contains the function we generated. - self.assertIn('__hash__', C.__dict__) - self.assertIsNotNone(C.__dict__['__hash__']) - elif result == 'absent': - # __hash__ is not present in our class. - self.assertNotIn('__hash__', C.__dict__) - elif result == 'none': - # __hash__ is set to None. - self.assertIn('__hash__', C.__dict__) - self.assertIsNone(C.__dict__['__hash__']) - else: - assert False, f'unknown result {result!r}' - def test_eq_order(self): + # Test combining eq and order. for (eq, order, result ) in [ (False, False, 'neither'), (False, True, 'exception'), @@ -513,21 +328,18 @@ class TestCase(unittest.TestCase): if result == 'neither': self.assertNotIn('__eq__', C.__dict__) - self.assertNotIn('__ne__', C.__dict__) self.assertNotIn('__lt__', C.__dict__) self.assertNotIn('__le__', C.__dict__) self.assertNotIn('__gt__', C.__dict__) self.assertNotIn('__ge__', C.__dict__) elif result == 'both': self.assertIn('__eq__', C.__dict__) - self.assertIn('__ne__', C.__dict__) self.assertIn('__lt__', C.__dict__) self.assertIn('__le__', C.__dict__) self.assertIn('__gt__', C.__dict__) self.assertIn('__ge__', C.__dict__) elif result == 'eq_only': self.assertIn('__eq__', C.__dict__) - self.assertIn('__ne__', C.__dict__) self.assertNotIn('__lt__', C.__dict__) self.assertNotIn('__le__', C.__dict__) self.assertNotIn('__gt__', C.__dict__) @@ -811,19 +623,6 @@ class TestCase(unittest.TestCase): y: int self.assertNotEqual(Point(1, 3), C(1, 3)) - def test_base_has_init(self): - class B: - def __init__(self): - pass - - # Make sure that declaring this class doesn't raise an error. - # The issue is that we can't override __init__ in our class, - # but it should be okay to add __init__ to us if our base has - # an __init__. - @dataclass - class C(B): - x: int = 0 - def test_frozen(self): @dataclass(frozen=True) class C: @@ -2065,6 +1864,7 @@ class TestCase(unittest.TestCase): 'y': int, 'z': 'typing.Any'}) + class TestDocString(unittest.TestCase): def assertDocStrEqual(self, a, b): # Because 3.6 and 3.7 differ in how inspect.signature work @@ -2154,5 +1954,445 @@ class TestDocString(unittest.TestCase): self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)") +class TestInit(unittest.TestCase): + def test_base_has_init(self): + class B: + def __init__(self): + self.z = 100 + pass + + # Make sure that declaring this class doesn't raise an error. + # The issue is that we can't override __init__ in our class, + # but it should be okay to add __init__ to us if our base has + # an __init__. + @dataclass + class C(B): + x: int = 0 + c = C(10) + self.assertEqual(c.x, 10) + self.assertNotIn('z', vars(c)) + + # Make sure that if we don't add an init, the base __init__ + # gets called. + @dataclass(init=False) + class C(B): + x: int = 10 + c = C() + self.assertEqual(c.x, 10) + self.assertEqual(c.z, 100) + + def test_no_init(self): + dataclass(init=False) + class C: + i: int = 0 + self.assertEqual(C().i, 0) + + dataclass(init=False) + class C: + i: int = 2 + def __init__(self): + self.i = 3 + self.assertEqual(C().i, 3) + + def test_overwriting_init(self): + # If the class has __init__, use it no matter the value of + # init=. + + @dataclass + class C: + x: int + def __init__(self, x): + self.x = 2 * x + self.assertEqual(C(3).x, 6) + + @dataclass(init=True) + class C: + x: int + def __init__(self, x): + self.x = 2 * x + self.assertEqual(C(4).x, 8) + + @dataclass(init=False) + class C: + x: int + def __init__(self, x): + self.x = 2 * x + self.assertEqual(C(5).x, 10) + + +class TestRepr(unittest.TestCase): + def test_repr(self): + @dataclass + class B: + x: int + + @dataclass + class C(B): + y: int = 10 + + o = C(4) + self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)') + + @dataclass + class D(C): + x: int = 20 + self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)') + + @dataclass + class C: + @dataclass + class D: + i: int + @dataclass + class E: + pass + self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)') + self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()') + + def test_no_repr(self): + # Test a class with no __repr__ and repr=False. + @dataclass(repr=False) + class C: + x: int + self.assertIn('test_dataclasses.TestRepr.test_no_repr.<locals>.C object at', + repr(C(3))) + + # Test a class with a __repr__ and repr=False. + @dataclass(repr=False) + class C: + x: int + def __repr__(self): + return 'C-class' + self.assertEqual(repr(C(3)), 'C-class') + + def test_overwriting_repr(self): + # If the class has __repr__, use it no matter the value of + # repr=. + + @dataclass + class C: + x: int + def __repr__(self): + return 'x' + self.assertEqual(repr(C(0)), 'x') + + @dataclass(repr=True) + class C: + x: int + def __repr__(self): + return 'x' + self.assertEqual(repr(C(0)), 'x') + + @dataclass(repr=False) + class C: + x: int + def __repr__(self): + return 'x' + self.assertEqual(repr(C(0)), 'x') + + +class TestFrozen(unittest.TestCase): + def test_overwriting_frozen(self): + # frozen uses __setattr__ and __delattr__ + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __setattr__'): + @dataclass(frozen=True) + class C: + x: int + def __setattr__(self): + pass + + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __delattr__'): + @dataclass(frozen=True) + class C: + x: int + def __delattr__(self): + pass + + @dataclass(frozen=False) + class C: + x: int + def __setattr__(self, name, value): + self.__dict__['x'] = value * 2 + self.assertEqual(C(10).x, 20) + + +class TestEq(unittest.TestCase): + def test_no_eq(self): + # Test a class with no __eq__ and eq=False. + @dataclass(eq=False) + class C: + x: int + self.assertNotEqual(C(0), C(0)) + c = C(3) + self.assertEqual(c, c) + + # Test a class with an __eq__ and eq=False. + @dataclass(eq=False) + class C: + x: int + def __eq__(self, other): + return other == 10 + self.assertEqual(C(3), 10) + + def test_overwriting_eq(self): + # If the class has __eq__, use it no matter the value of + # eq=. + + @dataclass + class C: + x: int + def __eq__(self, other): + return other == 3 + self.assertEqual(C(1), 3) + self.assertNotEqual(C(1), 1) + + @dataclass(eq=True) + class C: + x: int + def __eq__(self, other): + return other == 4 + self.assertEqual(C(1), 4) + self.assertNotEqual(C(1), 1) + + @dataclass(eq=False) + class C: + x: int + def __eq__(self, other): + return other == 5 + self.assertEqual(C(1), 5) + self.assertNotEqual(C(1), 1) + + +class TestOrdering(unittest.TestCase): + def test_functools_total_ordering(self): + # Test that functools.total_ordering works with this class. + @total_ordering + @dataclass + class C: + x: int + def __lt__(self, other): + # Perform the test "backward", just to make + # sure this is being called. + return self.x >= other + + self.assertLess(C(0), -1) + self.assertLessEqual(C(0), -1) + self.assertGreater(C(0), 1) + self.assertGreaterEqual(C(0), 1) + + def test_no_order(self): + # Test that no ordering functions are added by default. + @dataclass(order=False) + class C: + x: int + # Make sure no order methods are added. + self.assertNotIn('__le__', C.__dict__) + self.assertNotIn('__lt__', C.__dict__) + self.assertNotIn('__ge__', C.__dict__) + self.assertNotIn('__gt__', C.__dict__) + + # Test that __lt__ is still called + @dataclass(order=False) + class C: + x: int + def __lt__(self, other): + return False + # Make sure other methods aren't added. + self.assertNotIn('__le__', C.__dict__) + self.assertNotIn('__ge__', C.__dict__) + self.assertNotIn('__gt__', C.__dict__) + + def test_overwriting_order(self): + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __lt__' + '.*using functools.total_ordering'): + @dataclass(order=True) + class C: + x: int + def __lt__(self): + pass + + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __le__' + '.*using functools.total_ordering'): + @dataclass(order=True) + class C: + x: int + def __le__(self): + pass + + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __gt__' + '.*using functools.total_ordering'): + @dataclass(order=True) + class C: + x: int + def __gt__(self): + pass + + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __ge__' + '.*using functools.total_ordering'): + @dataclass(order=True) + class C: + x: int + def __ge__(self): + pass + +class TestHash(unittest.TestCase): + def test_hash(self): + @dataclass(hash=True) + class C: + x: int + y: str + self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo'))) + + def test_hash_false(self): + @dataclass(hash=False) + class C: + x: int + y: str + self.assertNotEqual(hash(C(1, 'foo')), hash((1, 'foo'))) + + def test_hash_none(self): + @dataclass(hash=None) + class C: + x: int + with self.assertRaisesRegex(TypeError, + "unhashable type: 'C'"): + hash(C(1)) + + def test_hash_rules(self): + def non_bool(value): + # Map to something else that's True, but not a bool. + if value is None: + return None + if value: + return (3,) + return 0 + + def test(case, hash, eq, frozen, with_hash, result): + with self.subTest(case=case, hash=hash, eq=eq, frozen=frozen): + if with_hash: + @dataclass(hash=hash, eq=eq, frozen=frozen) + class C: + def __hash__(self): + return 0 + else: + @dataclass(hash=hash, eq=eq, frozen=frozen) + class C: + pass + + # See if the result matches what's expected. + if result in ('fn', 'fn-x'): + # __hash__ contains the function we generated. + self.assertIn('__hash__', C.__dict__) + self.assertIsNotNone(C.__dict__['__hash__']) + + if result == 'fn-x': + # This is the "auto-hash test" case. We + # should overwrite __hash__ iff there's an + # __eq__ and if __hash__=None. + + # There are two ways of getting __hash__=None: + # explicitely, and by defining __eq__. If + # __eq__ is defined, python will add __hash__ + # when the class is created. + @dataclass(hash=hash, eq=eq, frozen=frozen) + class C: + def __eq__(self, other): pass + __hash__ = None + + # Hash should be overwritten (non-None). + self.assertIsNotNone(C.__dict__['__hash__']) + + # Same test as above, but we don't provide + # __hash__, it will implicitely set to None. + @dataclass(hash=hash, eq=eq, frozen=frozen) + class C: + def __eq__(self, other): pass + + # Hash should be overwritten (non-None). + self.assertIsNotNone(C.__dict__['__hash__']) + + elif result == '': + # __hash__ is not present in our class. + if not with_hash: + self.assertNotIn('__hash__', C.__dict__) + elif result == 'none': + # __hash__ is set to None. + self.assertIn('__hash__', C.__dict__) + self.assertIsNone(C.__dict__['__hash__']) + else: + assert False, f'unknown result {result!r}' + + # There are 12 cases of: + # hash=True/False/None + # eq=True/False + # frozen=True/False + # And for each of these, a different result if + # __hash__ is defined or not. + for case, (hash, eq, frozen, result_no, result_yes) in enumerate([ + (None, False, False, '', ''), + (None, False, True, '', ''), + (None, True, False, 'none', ''), + (None, True, True, 'fn', 'fn-x'), + (False, False, False, '', ''), + (False, False, True, '', ''), + (False, True, False, '', ''), + (False, True, True, '', ''), + (True, False, False, 'fn', 'fn-x'), + (True, False, True, 'fn', 'fn-x'), + (True, True, False, 'fn', 'fn-x'), + (True, True, True, 'fn', 'fn-x'), + ], 1): + test(case, hash, eq, frozen, False, result_no) + test(case, hash, eq, frozen, True, result_yes) + + # Test non-bool truth values, too. This is just to + # make sure the data-driven table in the decorator + # handles non-bool values. + test(case, non_bool(hash), non_bool(eq), non_bool(frozen), False, result_no) + test(case, non_bool(hash), non_bool(eq), non_bool(frozen), True, result_yes) + + + def test_eq_only(self): + # If a class defines __eq__, __hash__ is automatically added + # and set to None. This is normal Python behavior, not + # related to dataclasses. Make sure we don't interfere with + # that (see bpo=32546). + + @dataclass + class C: + i: int + def __eq__(self, other): + return self.i == other.i + self.assertEqual(C(1), C(1)) + self.assertNotEqual(C(1), C(4)) + + # And make sure things work in this case if we specify + # hash=True. + @dataclass(hash=True) + class C: + i: int + def __eq__(self, other): + return self.i == other.i + self.assertEqual(C(1), C(1.0)) + self.assertEqual(hash(C(1)), hash(C(1.0))) + + # And check that the classes __eq__ is being used, despite + # specifying eq=True. + @dataclass(hash=True, eq=True) + class C: + i: int + def __eq__(self, other): + return self.i == 3 and self.i == other.i + self.assertEqual(C(3), C(3)) + self.assertNotEqual(C(1), C(1)) + self.assertEqual(hash(C(1)), hash(C(1.0))) + + if __name__ == '__main__': unittest.main() diff --git a/Misc/NEWS.d/next/Library/2018-01-27-11-20-16.bpo-32513.ak-iD2.rst b/Misc/NEWS.d/next/Library/2018-01-27-11-20-16.bpo-32513.ak-iD2.rst new file mode 100644 index 0000000..4807241 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-01-27-11-20-16.bpo-32513.ak-iD2.rst @@ -0,0 +1,2 @@ +In dataclasses, allow easier overriding of dunder methods without specifying +decorator parameters. |