summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/dataclasses.py306
-rwxr-xr-xLib/test/test_dataclasses.py664
-rw-r--r--Misc/NEWS.d/next/Library/2018-01-27-11-20-16.bpo-32513.ak-iD2.rst2
3 files changed, 678 insertions, 294 deletions
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py
index 7d30da1..fb279cd 100644
--- a/Lib/dataclasses.py
+++ b/Lib/dataclasses.py
@@ -18,6 +18,142 @@ __all__ = ['dataclass',
'is_dataclass',
]
+# Conditions for adding methods. The boxes indicate what action the
+# dataclass decorator takes. For all of these tables, when I talk
+# about init=, repr=, eq=, order=, hash=, or frozen=, I'm referring
+# to the arguments to the @dataclass decorator. When checking if a
+# dunder method already exists, I mean check for an entry in the
+# class's __dict__. I never check to see if an attribute is defined
+# in a base class.
+
+# Key:
+# +=========+=========================================+
+# + Value | Meaning |
+# +=========+=========================================+
+# | <blank> | No action: no method is added. |
+# +---------+-----------------------------------------+
+# | add | Generated method is added. |
+# +---------+-----------------------------------------+
+# | add* | Generated method is added only if the |
+# | | existing attribute is None and if the |
+# | | user supplied a __eq__ method in the |
+# | | class definition. |
+# +---------+-----------------------------------------+
+# | raise | TypeError is raised. |
+# +---------+-----------------------------------------+
+# | None | Attribute is set to None. |
+# +=========+=========================================+
+
+# __init__
+#
+# +--- init= parameter
+# |
+# v | | |
+# | no | yes | <--- class has __init__ in __dict__?
+# +=======+=======+=======+
+# | False | | |
+# +-------+-------+-------+
+# | True | add | | <- the default
+# +=======+=======+=======+
+
+# __repr__
+#
+# +--- repr= parameter
+# |
+# v | | |
+# | no | yes | <--- class has __repr__ in __dict__?
+# +=======+=======+=======+
+# | False | | |
+# +-------+-------+-------+
+# | True | add | | <- the default
+# +=======+=======+=======+
+
+
+# __setattr__
+# __delattr__
+#
+# +--- frozen= parameter
+# |
+# v | | |
+# | no | yes | <--- class has __setattr__ or __delattr__ in __dict__?
+# +=======+=======+=======+
+# | False | | | <- the default
+# +-------+-------+-------+
+# | True | add | raise |
+# +=======+=======+=======+
+# Raise because not adding these methods would break the "frozen-ness"
+# of the class.
+
+# __eq__
+#
+# +--- eq= parameter
+# |
+# v | | |
+# | no | yes | <--- class has __eq__ in __dict__?
+# +=======+=======+=======+
+# | False | | |
+# +-------+-------+-------+
+# | True | add | | <- the default
+# +=======+=======+=======+
+
+# __lt__
+# __le__
+# __gt__
+# __ge__
+#
+# +--- order= parameter
+# |
+# v | | |
+# | no | yes | <--- class has any comparison method in __dict__?
+# +=======+=======+=======+
+# | False | | | <- the default
+# +-------+-------+-------+
+# | True | add | raise |
+# +=======+=======+=======+
+# Raise because to allow this case would interfere with using
+# functools.total_ordering.
+
+# __hash__
+
+# +------------------- hash= parameter
+# | +----------- eq= parameter
+# | | +--- frozen= parameter
+# | | |
+# v v v | | |
+# | no | yes | <--- class has __hash__ in __dict__?
+# +=========+=======+=======+========+========+
+# | 1 None | False | False | | | No __eq__, use the base class __hash__
+# +---------+-------+-------+--------+--------+
+# | 2 None | False | True | | | No __eq__, use the base class __hash__
+# +---------+-------+-------+--------+--------+
+# | 3 None | True | False | None | | <-- the default, not hashable
+# +---------+-------+-------+--------+--------+
+# | 4 None | True | True | add | add* | Frozen, so hashable
+# +---------+-------+-------+--------+--------+
+# | 5 False | False | False | | |
+# +---------+-------+-------+--------+--------+
+# | 6 False | False | True | | |
+# +---------+-------+-------+--------+--------+
+# | 7 False | True | False | | |
+# +---------+-------+-------+--------+--------+
+# | 8 False | True | True | | |
+# +---------+-------+-------+--------+--------+
+# | 9 True | False | False | add | add* | Has no __eq__, but hashable
+# +---------+-------+-------+--------+--------+
+# |10 True | False | True | add | add* | Has no __eq__, but hashable
+# +---------+-------+-------+--------+--------+
+# |11 True | True | False | add | add* | Not frozen, but hashable
+# +---------+-------+-------+--------+--------+
+# |12 True | True | True | add | add* | Frozen, so hashable
+# +=========+=======+=======+========+========+
+# For boxes that are blank, __hash__ is untouched and therefore
+# inherited from the base class. If the base is object, then
+# id-based hashing is used.
+# Note that a class may have already __hash__=None if it specified an
+# __eq__ method in the class body (not one that was created by
+# @dataclass).
+
+
# Raised when an attempt is made to modify a frozen class.
class FrozenInstanceError(AttributeError): pass
@@ -143,13 +279,13 @@ def _tuple_str(obj_name, fields):
# return "(self.x,self.y)".
# Special case for the 0-tuple.
- if len(fields) == 0:
+ if not fields:
return '()'
# Note the trailing comma, needed if this turns out to be a 1-tuple.
return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'
-def _create_fn(name, args, body, globals=None, locals=None,
+def _create_fn(name, args, body, *, globals=None, locals=None,
return_type=MISSING):
# Note that we mutate locals when exec() is called. Caller beware!
if locals is None:
@@ -287,7 +423,7 @@ def _init_fn(fields, frozen, has_post_init, self_name):
body_lines += [f'{self_name}.{_POST_INIT_NAME}({params_str})']
# If no body lines, use 'pass'.
- if len(body_lines) == 0:
+ if not body_lines:
body_lines = ['pass']
locals = {f'_type_{f.name}': f.type for f in fields}
@@ -329,32 +465,6 @@ def _cmp_fn(name, op, self_tuple, other_tuple):
'return NotImplemented'])
-def _set_eq_fns(cls, fields):
- # Create and set the equality comparison methods on cls.
- # Pre-compute self_tuple and other_tuple, then re-use them for
- # each function.
- self_tuple = _tuple_str('self', fields)
- other_tuple = _tuple_str('other', fields)
- for name, op in [('__eq__', '=='),
- ('__ne__', '!='),
- ]:
- _set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple))
-
-
-def _set_order_fns(cls, fields):
- # Create and set the ordering methods on cls.
- # Pre-compute self_tuple and other_tuple, then re-use them for
- # each function.
- self_tuple = _tuple_str('self', fields)
- other_tuple = _tuple_str('other', fields)
- for name, op in [('__lt__', '<'),
- ('__le__', '<='),
- ('__gt__', '>'),
- ('__ge__', '>='),
- ]:
- _set_attribute(cls, name, _cmp_fn(name, op, self_tuple, other_tuple))
-
-
def _hash_fn(fields):
self_tuple = _tuple_str('self', fields)
return _create_fn('__hash__',
@@ -431,20 +541,20 @@ def _find_fields(cls):
# a Field(), then it contains additional info beyond (and
# possibly including) the actual default value. Pseudo-fields
# ClassVars and InitVars are included, despite the fact that
- # they're not real fields. That's deal with later.
+ # they're not real fields. That's dealt with later.
annotations = getattr(cls, '__annotations__', {})
-
return [_get_field(cls, a_name, a_type)
for a_name, a_type in annotations.items()]
-def _set_attribute(cls, name, value):
- # Raise TypeError if an attribute by this name already exists.
+def _set_new_attribute(cls, name, value):
+ # Never overwrites an existing attribute. Returns True if the
+ # attribute already exists.
if name in cls.__dict__:
- raise TypeError(f'Cannot overwrite attribute {name} '
- f'in {cls.__name__}')
+ return True
setattr(cls, name, value)
+ return False
def _process_class(cls, repr, eq, order, hash, init, frozen):
@@ -495,6 +605,9 @@ def _process_class(cls, repr, eq, order, hash, init, frozen):
# be inherited down.
is_frozen = frozen or cls.__setattr__ is _frozen_setattr
+ # Was this class defined with an __eq__? Used in __hash__ logic.
+ auto_hash_test= '__eq__' in cls.__dict__ and getattr(cls.__dict__, '__hash__', MISSING) is None
+
# If we're generating ordering methods, we must be generating
# the eq methods.
if order and not eq:
@@ -505,62 +618,91 @@ def _process_class(cls, repr, eq, order, hash, init, frozen):
has_post_init = hasattr(cls, _POST_INIT_NAME)
# Include InitVars and regular fields (so, not ClassVars).
- _set_attribute(cls, '__init__',
- _init_fn(list(filter(lambda f: f._field_type
- in (_FIELD, _FIELD_INITVAR),
- fields.values())),
- is_frozen,
- has_post_init,
- # The name to use for the "self" param
- # in __init__. Use "self" if possible.
- '__dataclass_self__' if 'self' in fields
- else 'self',
- ))
+ flds = [f for f in fields.values()
+ if f._field_type in (_FIELD, _FIELD_INITVAR)]
+ _set_new_attribute(cls, '__init__',
+ _init_fn(flds,
+ is_frozen,
+ has_post_init,
+ # The name to use for the "self" param
+ # in __init__. Use "self" if possible.
+ '__dataclass_self__' if 'self' in fields
+ else 'self',
+ ))
# Get the fields as a list, and include only real fields. This is
# used in all of the following methods.
- field_list = list(filter(lambda f: f._field_type is _FIELD,
- fields.values()))
+ field_list = [f for f in fields.values() if f._field_type is _FIELD]
if repr:
- _set_attribute(cls, '__repr__',
- _repr_fn(list(filter(lambda f: f.repr, field_list))))
-
- if is_frozen:
- _set_attribute(cls, '__setattr__', _frozen_setattr)
- _set_attribute(cls, '__delattr__', _frozen_delattr)
-
- generate_hash = False
- if hash is None:
- if eq and frozen:
- # Generate a hash function.
- generate_hash = True
- elif eq and not frozen:
- # Not hashable.
- _set_attribute(cls, '__hash__', None)
- elif not eq:
- # Otherwise, use the base class definition of hash(). That is,
- # don't set anything on this class.
- pass
- else:
- assert "can't get here"
- else:
- generate_hash = hash
- if generate_hash:
- _set_attribute(cls, '__hash__',
- _hash_fn(list(filter(lambda f: f.compare
- if f.hash is None
- else f.hash,
- field_list))))
+ flds = [f for f in field_list if f.repr]
+ _set_new_attribute(cls, '__repr__', _repr_fn(flds))
if eq:
- # Create and __eq__ and __ne__ methods.
- _set_eq_fns(cls, list(filter(lambda f: f.compare, field_list)))
+ # Create _eq__ method. There's no need for a __ne__ method,
+ # since python will call __eq__ and negate it.
+ flds = [f for f in field_list if f.compare]
+ self_tuple = _tuple_str('self', flds)
+ other_tuple = _tuple_str('other', flds)
+ _set_new_attribute(cls, '__eq__',
+ _cmp_fn('__eq__', '==',
+ self_tuple, other_tuple))
if order:
- # Create and __lt__, __le__, __gt__, and __ge__ methods.
- # Create and set the comparison functions.
- _set_order_fns(cls, list(filter(lambda f: f.compare, field_list)))
+ # Create and set the ordering methods.
+ flds = [f for f in field_list if f.compare]
+ self_tuple = _tuple_str('self', flds)
+ other_tuple = _tuple_str('other', flds)
+ for name, op in [('__lt__', '<'),
+ ('__le__', '<='),
+ ('__gt__', '>'),
+ ('__ge__', '>='),
+ ]:
+ if _set_new_attribute(cls, name,
+ _cmp_fn(name, op, self_tuple, other_tuple)):
+ raise TypeError(f'Cannot overwrite attribute {name} '
+ f'in {cls.__name__}. Consider using '
+ 'functools.total_ordering')
+
+ if is_frozen:
+ for name, fn in [('__setattr__', _frozen_setattr),
+ ('__delattr__', _frozen_delattr)]:
+ if _set_new_attribute(cls, name, fn):
+ raise TypeError(f'Cannot overwrite attribute {name} '
+ f'in {cls.__name__}')
+
+ # Decide if/how we're going to create a hash function.
+ # TODO: Move this table to module scope, so it's not recreated
+ # all the time.
+ generate_hash = {(None, False, False): ('', ''),
+ (None, False, True): ('', ''),
+ (None, True, False): ('none', ''),
+ (None, True, True): ('fn', 'fn-x'),
+ (False, False, False): ('', ''),
+ (False, False, True): ('', ''),
+ (False, True, False): ('', ''),
+ (False, True, True): ('', ''),
+ (True, False, False): ('fn', 'fn-x'),
+ (True, False, True): ('fn', 'fn-x'),
+ (True, True, False): ('fn', 'fn-x'),
+ (True, True, True): ('fn', 'fn-x'),
+ }[None if hash is None else bool(hash), # Force bool() if not None.
+ bool(eq),
+ bool(frozen)]['__hash__' in cls.__dict__]
+ # No need to call _set_new_attribute here, since we already know if
+ # we're overwriting a __hash__ or not.
+ if generate_hash == '':
+ # Do nothing.
+ pass
+ elif generate_hash == 'none':
+ cls.__hash__ = None
+ elif generate_hash in ('fn', 'fn-x'):
+ if generate_hash == 'fn' or auto_hash_test:
+ flds = [f for f in field_list
+ if (f.compare if f.hash is None else f.hash)]
+ cls.__hash__ = _hash_fn(flds)
+ else:
+ assert False, f"can't get here: {generate_hash}"
if not getattr(cls, '__doc__'):
# Create a class doc-string.
diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py
index 69819ea..53281f9 100755
--- a/Lib/test/test_dataclasses.py
+++ b/Lib/test/test_dataclasses.py
@@ -9,6 +9,7 @@ import unittest
from unittest.mock import Mock
from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
from collections import deque, OrderedDict, namedtuple
+from functools import total_ordering
# Just any custom exception we can catch.
class CustomError(Exception): pass
@@ -82,68 +83,12 @@ class TestCase(unittest.TestCase):
class C(B):
x: int = 0
- def test_overwriting_init(self):
- with self.assertRaisesRegex(TypeError,
- 'Cannot overwrite attribute __init__ '
- 'in C'):
- @dataclass
- class C:
- x: int
- def __init__(self, x):
- self.x = 2 * x
-
- @dataclass(init=False)
- class C:
- x: int
- def __init__(self, x):
- self.x = 2 * x
- self.assertEqual(C(5).x, 10)
-
- def test_overwriting_repr(self):
- with self.assertRaisesRegex(TypeError,
- 'Cannot overwrite attribute __repr__ '
- 'in C'):
- @dataclass
- class C:
- x: int
- def __repr__(self):
- pass
-
- @dataclass(repr=False)
- class C:
- x: int
- def __repr__(self):
- return 'x'
- self.assertEqual(repr(C(0)), 'x')
-
- def test_overwriting_cmp(self):
- with self.assertRaisesRegex(TypeError,
- 'Cannot overwrite attribute __eq__ '
- 'in C'):
- # This will generate the comparison functions, make sure we can't
- # overwrite them.
- @dataclass(hash=False, frozen=False)
- class C:
- x: int
- def __eq__(self):
- pass
-
- @dataclass(order=False, eq=False)
+ def test_overwriting_hash(self):
+ @dataclass(frozen=True)
class C:
x: int
- def __eq__(self, other):
- return True
- self.assertEqual(C(0), 'x')
-
- def test_overwriting_hash(self):
- with self.assertRaisesRegex(TypeError,
- 'Cannot overwrite attribute __hash__ '
- 'in C'):
- @dataclass(frozen=True)
- class C:
- x: int
- def __hash__(self):
- pass
+ def __hash__(self):
+ pass
@dataclass(frozen=True,hash=False)
class C:
@@ -152,14 +97,11 @@ class TestCase(unittest.TestCase):
return 600
self.assertEqual(hash(C(0)), 600)
- with self.assertRaisesRegex(TypeError,
- 'Cannot overwrite attribute __hash__ '
- 'in C'):
- @dataclass(frozen=True)
- class C:
- x: int
- def __hash__(self):
- pass
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ def __hash__(self):
+ pass
@dataclass(frozen=True, hash=False)
class C:
@@ -168,33 +110,6 @@ class TestCase(unittest.TestCase):
return 600
self.assertEqual(hash(C(0)), 600)
- def test_overwriting_frozen(self):
- # frozen uses __setattr__ and __delattr__
- with self.assertRaisesRegex(TypeError,
- 'Cannot overwrite attribute __setattr__ '
- 'in C'):
- @dataclass(frozen=True)
- class C:
- x: int
- def __setattr__(self):
- pass
-
- with self.assertRaisesRegex(TypeError,
- 'Cannot overwrite attribute __delattr__ '
- 'in C'):
- @dataclass(frozen=True)
- class C:
- x: int
- def __delattr__(self):
- pass
-
- @dataclass(frozen=False)
- class C:
- x: int
- def __setattr__(self, name, value):
- self.__dict__['x'] = value * 2
- self.assertEqual(C(10).x, 20)
-
def test_overwrite_fields_in_derived_class(self):
# Note that x from C1 replaces x in Base, but the order remains
# the same as defined in Base.
@@ -239,34 +154,6 @@ class TestCase(unittest.TestCase):
first = next(iter(sig.parameters))
self.assertEqual('self', first)
- def test_repr(self):
- @dataclass
- class B:
- x: int
-
- @dataclass
- class C(B):
- y: int = 10
-
- o = C(4)
- self.assertEqual(repr(o), 'TestCase.test_repr.<locals>.C(x=4, y=10)')
-
- @dataclass
- class D(C):
- x: int = 20
- self.assertEqual(repr(D()), 'TestCase.test_repr.<locals>.D(x=20, y=10)')
-
- @dataclass
- class C:
- @dataclass
- class D:
- i: int
- @dataclass
- class E:
- pass
- self.assertEqual(repr(C.D(0)), 'TestCase.test_repr.<locals>.C.D(i=0)')
- self.assertEqual(repr(C.E()), 'TestCase.test_repr.<locals>.C.E()')
-
def test_0_field_compare(self):
# Ensure that order=False is the default.
@dataclass
@@ -420,80 +307,8 @@ class TestCase(unittest.TestCase):
self.assertEqual(hash(C(4)), hash((4,)))
self.assertEqual(hash(C(42)), hash((42,)))
- def test_hash(self):
- @dataclass(hash=True)
- class C:
- x: int
- y: str
- self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
-
- def test_no_hash(self):
- @dataclass(hash=None)
- class C:
- x: int
- with self.assertRaisesRegex(TypeError,
- "unhashable type: 'C'"):
- hash(C(1))
-
- def test_hash_rules(self):
- # There are 24 cases of:
- # hash=True/False/None
- # eq=True/False
- # order=True/False
- # frozen=True/False
- for (hash, eq, order, frozen, result ) in [
- (False, False, False, False, 'absent'),
- (False, False, False, True, 'absent'),
- (False, False, True, False, 'exception'),
- (False, False, True, True, 'exception'),
- (False, True, False, False, 'absent'),
- (False, True, False, True, 'absent'),
- (False, True, True, False, 'absent'),
- (False, True, True, True, 'absent'),
- (True, False, False, False, 'fn'),
- (True, False, False, True, 'fn'),
- (True, False, True, False, 'exception'),
- (True, False, True, True, 'exception'),
- (True, True, False, False, 'fn'),
- (True, True, False, True, 'fn'),
- (True, True, True, False, 'fn'),
- (True, True, True, True, 'fn'),
- (None, False, False, False, 'absent'),
- (None, False, False, True, 'absent'),
- (None, False, True, False, 'exception'),
- (None, False, True, True, 'exception'),
- (None, True, False, False, 'none'),
- (None, True, False, True, 'fn'),
- (None, True, True, False, 'none'),
- (None, True, True, True, 'fn'),
- ]:
- with self.subTest(hash=hash, eq=eq, order=order, frozen=frozen):
- if result == 'exception':
- with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
- @dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
- class C:
- pass
- else:
- @dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
- class C:
- pass
-
- # See if the result matches what's expected.
- if result == 'fn':
- # __hash__ contains the function we generated.
- self.assertIn('__hash__', C.__dict__)
- self.assertIsNotNone(C.__dict__['__hash__'])
- elif result == 'absent':
- # __hash__ is not present in our class.
- self.assertNotIn('__hash__', C.__dict__)
- elif result == 'none':
- # __hash__ is set to None.
- self.assertIn('__hash__', C.__dict__)
- self.assertIsNone(C.__dict__['__hash__'])
- else:
- assert False, f'unknown result {result!r}'
-
def test_eq_order(self):
+ # Test combining eq and order.
for (eq, order, result ) in [
(False, False, 'neither'),
(False, True, 'exception'),
@@ -513,21 +328,18 @@ class TestCase(unittest.TestCase):
if result == 'neither':
self.assertNotIn('__eq__', C.__dict__)
- self.assertNotIn('__ne__', C.__dict__)
self.assertNotIn('__lt__', C.__dict__)
self.assertNotIn('__le__', C.__dict__)
self.assertNotIn('__gt__', C.__dict__)
self.assertNotIn('__ge__', C.__dict__)
elif result == 'both':
self.assertIn('__eq__', C.__dict__)
- self.assertIn('__ne__', C.__dict__)
self.assertIn('__lt__', C.__dict__)
self.assertIn('__le__', C.__dict__)
self.assertIn('__gt__', C.__dict__)
self.assertIn('__ge__', C.__dict__)
elif result == 'eq_only':
self.assertIn('__eq__', C.__dict__)
- self.assertIn('__ne__', C.__dict__)
self.assertNotIn('__lt__', C.__dict__)
self.assertNotIn('__le__', C.__dict__)
self.assertNotIn('__gt__', C.__dict__)
@@ -811,19 +623,6 @@ class TestCase(unittest.TestCase):
y: int
self.assertNotEqual(Point(1, 3), C(1, 3))
- def test_base_has_init(self):
- class B:
- def __init__(self):
- pass
-
- # Make sure that declaring this class doesn't raise an error.
- # The issue is that we can't override __init__ in our class,
- # but it should be okay to add __init__ to us if our base has
- # an __init__.
- @dataclass
- class C(B):
- x: int = 0
-
def test_frozen(self):
@dataclass(frozen=True)
class C:
@@ -2065,6 +1864,7 @@ class TestCase(unittest.TestCase):
'y': int,
'z': 'typing.Any'})
+
class TestDocString(unittest.TestCase):
def assertDocStrEqual(self, a, b):
# Because 3.6 and 3.7 differ in how inspect.signature work
@@ -2154,5 +1954,445 @@ class TestDocString(unittest.TestCase):
self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
+class TestInit(unittest.TestCase):
+ def test_base_has_init(self):
+ class B:
+ def __init__(self):
+ self.z = 100
+ pass
+
+ # Make sure that declaring this class doesn't raise an error.
+ # The issue is that we can't override __init__ in our class,
+ # but it should be okay to add __init__ to us if our base has
+ # an __init__.
+ @dataclass
+ class C(B):
+ x: int = 0
+ c = C(10)
+ self.assertEqual(c.x, 10)
+ self.assertNotIn('z', vars(c))
+
+ # Make sure that if we don't add an init, the base __init__
+ # gets called.
+ @dataclass(init=False)
+ class C(B):
+ x: int = 10
+ c = C()
+ self.assertEqual(c.x, 10)
+ self.assertEqual(c.z, 100)
+
+ def test_no_init(self):
+ dataclass(init=False)
+ class C:
+ i: int = 0
+ self.assertEqual(C().i, 0)
+
+ dataclass(init=False)
+ class C:
+ i: int = 2
+ def __init__(self):
+ self.i = 3
+ self.assertEqual(C().i, 3)
+
+ def test_overwriting_init(self):
+ # If the class has __init__, use it no matter the value of
+ # init=.
+
+ @dataclass
+ class C:
+ x: int
+ def __init__(self, x):
+ self.x = 2 * x
+ self.assertEqual(C(3).x, 6)
+
+ @dataclass(init=True)
+ class C:
+ x: int
+ def __init__(self, x):
+ self.x = 2 * x
+ self.assertEqual(C(4).x, 8)
+
+ @dataclass(init=False)
+ class C:
+ x: int
+ def __init__(self, x):
+ self.x = 2 * x
+ self.assertEqual(C(5).x, 10)
+
+
+class TestRepr(unittest.TestCase):
+ def test_repr(self):
+ @dataclass
+ class B:
+ x: int
+
+ @dataclass
+ class C(B):
+ y: int = 10
+
+ o = C(4)
+ self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')
+
+ @dataclass
+ class D(C):
+ x: int = 20
+ self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')
+
+ @dataclass
+ class C:
+ @dataclass
+ class D:
+ i: int
+ @dataclass
+ class E:
+ pass
+ self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
+ self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')
+
+ def test_no_repr(self):
+ # Test a class with no __repr__ and repr=False.
+ @dataclass(repr=False)
+ class C:
+ x: int
+ self.assertIn('test_dataclasses.TestRepr.test_no_repr.<locals>.C object at',
+ repr(C(3)))
+
+ # Test a class with a __repr__ and repr=False.
+ @dataclass(repr=False)
+ class C:
+ x: int
+ def __repr__(self):
+ return 'C-class'
+ self.assertEqual(repr(C(3)), 'C-class')
+
+ def test_overwriting_repr(self):
+ # If the class has __repr__, use it no matter the value of
+ # repr=.
+
+ @dataclass
+ class C:
+ x: int
+ def __repr__(self):
+ return 'x'
+ self.assertEqual(repr(C(0)), 'x')
+
+ @dataclass(repr=True)
+ class C:
+ x: int
+ def __repr__(self):
+ return 'x'
+ self.assertEqual(repr(C(0)), 'x')
+
+ @dataclass(repr=False)
+ class C:
+ x: int
+ def __repr__(self):
+ return 'x'
+ self.assertEqual(repr(C(0)), 'x')
+
+
+class TestFrozen(unittest.TestCase):
+ def test_overwriting_frozen(self):
+ # frozen uses __setattr__ and __delattr__
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __setattr__'):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ def __setattr__(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __delattr__'):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ def __delattr__(self):
+ pass
+
+ @dataclass(frozen=False)
+ class C:
+ x: int
+ def __setattr__(self, name, value):
+ self.__dict__['x'] = value * 2
+ self.assertEqual(C(10).x, 20)
+
+
+class TestEq(unittest.TestCase):
+ def test_no_eq(self):
+ # Test a class with no __eq__ and eq=False.
+ @dataclass(eq=False)
+ class C:
+ x: int
+ self.assertNotEqual(C(0), C(0))
+ c = C(3)
+ self.assertEqual(c, c)
+
+ # Test a class with an __eq__ and eq=False.
+ @dataclass(eq=False)
+ class C:
+ x: int
+ def __eq__(self, other):
+ return other == 10
+ self.assertEqual(C(3), 10)
+
+ def test_overwriting_eq(self):
+ # If the class has __eq__, use it no matter the value of
+ # eq=.
+
+ @dataclass
+ class C:
+ x: int
+ def __eq__(self, other):
+ return other == 3
+ self.assertEqual(C(1), 3)
+ self.assertNotEqual(C(1), 1)
+
+ @dataclass(eq=True)
+ class C:
+ x: int
+ def __eq__(self, other):
+ return other == 4
+ self.assertEqual(C(1), 4)
+ self.assertNotEqual(C(1), 1)
+
+ @dataclass(eq=False)
+ class C:
+ x: int
+ def __eq__(self, other):
+ return other == 5
+ self.assertEqual(C(1), 5)
+ self.assertNotEqual(C(1), 1)
+
+
+class TestOrdering(unittest.TestCase):
+ def test_functools_total_ordering(self):
+ # Test that functools.total_ordering works with this class.
+ @total_ordering
+ @dataclass
+ class C:
+ x: int
+ def __lt__(self, other):
+ # Perform the test "backward", just to make
+ # sure this is being called.
+ return self.x >= other
+
+ self.assertLess(C(0), -1)
+ self.assertLessEqual(C(0), -1)
+ self.assertGreater(C(0), 1)
+ self.assertGreaterEqual(C(0), 1)
+
+ def test_no_order(self):
+ # Test that no ordering functions are added by default.
+ @dataclass(order=False)
+ class C:
+ x: int
+ # Make sure no order methods are added.
+ self.assertNotIn('__le__', C.__dict__)
+ self.assertNotIn('__lt__', C.__dict__)
+ self.assertNotIn('__ge__', C.__dict__)
+ self.assertNotIn('__gt__', C.__dict__)
+
+ # Test that __lt__ is still called
+ @dataclass(order=False)
+ class C:
+ x: int
+ def __lt__(self, other):
+ return False
+ # Make sure other methods aren't added.
+ self.assertNotIn('__le__', C.__dict__)
+ self.assertNotIn('__ge__', C.__dict__)
+ self.assertNotIn('__gt__', C.__dict__)
+
+ def test_overwriting_order(self):
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __lt__'
+ '.*using functools.total_ordering'):
+ @dataclass(order=True)
+ class C:
+ x: int
+ def __lt__(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __le__'
+ '.*using functools.total_ordering'):
+ @dataclass(order=True)
+ class C:
+ x: int
+ def __le__(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __gt__'
+ '.*using functools.total_ordering'):
+ @dataclass(order=True)
+ class C:
+ x: int
+ def __gt__(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __ge__'
+ '.*using functools.total_ordering'):
+ @dataclass(order=True)
+ class C:
+ x: int
+ def __ge__(self):
+ pass
+
+class TestHash(unittest.TestCase):
+ def test_hash(self):
+ @dataclass(hash=True)
+ class C:
+ x: int
+ y: str
+ self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
+
+ def test_hash_false(self):
+ @dataclass(hash=False)
+ class C:
+ x: int
+ y: str
+ self.assertNotEqual(hash(C(1, 'foo')), hash((1, 'foo')))
+
+ def test_hash_none(self):
+ @dataclass(hash=None)
+ class C:
+ x: int
+ with self.assertRaisesRegex(TypeError,
+ "unhashable type: 'C'"):
+ hash(C(1))
+
+ def test_hash_rules(self):
+ def non_bool(value):
+ # Map to something else that's True, but not a bool.
+ if value is None:
+ return None
+ if value:
+ return (3,)
+ return 0
+
+ def test(case, hash, eq, frozen, with_hash, result):
+ with self.subTest(case=case, hash=hash, eq=eq, frozen=frozen):
+ if with_hash:
+ @dataclass(hash=hash, eq=eq, frozen=frozen)
+ class C:
+ def __hash__(self):
+ return 0
+ else:
+ @dataclass(hash=hash, eq=eq, frozen=frozen)
+ class C:
+ pass
+
+ # See if the result matches what's expected.
+ if result in ('fn', 'fn-x'):
+ # __hash__ contains the function we generated.
+ self.assertIn('__hash__', C.__dict__)
+ self.assertIsNotNone(C.__dict__['__hash__'])
+
+ if result == 'fn-x':
+ # This is the "auto-hash test" case. We
+ # should overwrite __hash__ iff there's an
+ # __eq__ and if __hash__=None.
+
+ # There are two ways of getting __hash__=None:
+ # explicitely, and by defining __eq__. If
+ # __eq__ is defined, python will add __hash__
+ # when the class is created.
+ @dataclass(hash=hash, eq=eq, frozen=frozen)
+ class C:
+ def __eq__(self, other): pass
+ __hash__ = None
+
+ # Hash should be overwritten (non-None).
+ self.assertIsNotNone(C.__dict__['__hash__'])
+
+ # Same test as above, but we don't provide
+ # __hash__, it will implicitely set to None.
+ @dataclass(hash=hash, eq=eq, frozen=frozen)
+ class C:
+ def __eq__(self, other): pass
+
+ # Hash should be overwritten (non-None).
+ self.assertIsNotNone(C.__dict__['__hash__'])
+
+ elif result == '':
+ # __hash__ is not present in our class.
+ if not with_hash:
+ self.assertNotIn('__hash__', C.__dict__)
+ elif result == 'none':
+ # __hash__ is set to None.
+ self.assertIn('__hash__', C.__dict__)
+ self.assertIsNone(C.__dict__['__hash__'])
+ else:
+ assert False, f'unknown result {result!r}'
+
+ # There are 12 cases of:
+ # hash=True/False/None
+ # eq=True/False
+ # frozen=True/False
+ # And for each of these, a different result if
+ # __hash__ is defined or not.
+ for case, (hash, eq, frozen, result_no, result_yes) in enumerate([
+ (None, False, False, '', ''),
+ (None, False, True, '', ''),
+ (None, True, False, 'none', ''),
+ (None, True, True, 'fn', 'fn-x'),
+ (False, False, False, '', ''),
+ (False, False, True, '', ''),
+ (False, True, False, '', ''),
+ (False, True, True, '', ''),
+ (True, False, False, 'fn', 'fn-x'),
+ (True, False, True, 'fn', 'fn-x'),
+ (True, True, False, 'fn', 'fn-x'),
+ (True, True, True, 'fn', 'fn-x'),
+ ], 1):
+ test(case, hash, eq, frozen, False, result_no)
+ test(case, hash, eq, frozen, True, result_yes)
+
+ # Test non-bool truth values, too. This is just to
+ # make sure the data-driven table in the decorator
+ # handles non-bool values.
+ test(case, non_bool(hash), non_bool(eq), non_bool(frozen), False, result_no)
+ test(case, non_bool(hash), non_bool(eq), non_bool(frozen), True, result_yes)
+
+
+ def test_eq_only(self):
+ # If a class defines __eq__, __hash__ is automatically added
+ # and set to None. This is normal Python behavior, not
+ # related to dataclasses. Make sure we don't interfere with
+ # that (see bpo=32546).
+
+ @dataclass
+ class C:
+ i: int
+ def __eq__(self, other):
+ return self.i == other.i
+ self.assertEqual(C(1), C(1))
+ self.assertNotEqual(C(1), C(4))
+
+ # And make sure things work in this case if we specify
+ # hash=True.
+ @dataclass(hash=True)
+ class C:
+ i: int
+ def __eq__(self, other):
+ return self.i == other.i
+ self.assertEqual(C(1), C(1.0))
+ self.assertEqual(hash(C(1)), hash(C(1.0)))
+
+ # And check that the classes __eq__ is being used, despite
+ # specifying eq=True.
+ @dataclass(hash=True, eq=True)
+ class C:
+ i: int
+ def __eq__(self, other):
+ return self.i == 3 and self.i == other.i
+ self.assertEqual(C(3), C(3))
+ self.assertNotEqual(C(1), C(1))
+ self.assertEqual(hash(C(1)), hash(C(1.0)))
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2018-01-27-11-20-16.bpo-32513.ak-iD2.rst b/Misc/NEWS.d/next/Library/2018-01-27-11-20-16.bpo-32513.ak-iD2.rst
new file mode 100644
index 0000000..4807241
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2018-01-27-11-20-16.bpo-32513.ak-iD2.rst
@@ -0,0 +1,2 @@
+In dataclasses, allow easier overriding of dunder methods without specifying
+decorator parameters.