diff options
-rw-r--r-- | Lib/dataclasses.py | 15 | ||||
-rwxr-xr-x | Lib/test/test_dataclasses.py | 273 | ||||
-rw-r--r-- | Misc/NEWS.d/next/Library/2018-05-16-10-07-40.bpo-33536._s0TE8.rst | 2 |
3 files changed, 182 insertions, 108 deletions
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index bb77d3b..2c5593b 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -3,6 +3,7 @@ import sys import copy import types import inspect +import keyword __all__ = ['dataclass', 'field', @@ -1100,6 +1101,9 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, # Copy namespace since we're going to mutate it. namespace = namespace.copy() + # While we're looking through the field names, validate that they + # are identifiers, are not keywords, and not duplicates. + seen = set() anns = {} for item in fields: if isinstance(item, str): @@ -1110,6 +1114,17 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, elif len(item) == 3: name, tp, spec = item namespace[name] = spec + else: + raise TypeError(f'Invalid field: {item!r}') + + if not isinstance(name, str) or not name.isidentifier(): + raise TypeError(f'Field names must be valid identifers: {name!r}') + if keyword.iskeyword(name): + raise TypeError(f'Field names must not be keywords: {name!r}') + if name in seen: + raise TypeError(f'Field name duplicated: {name!r}') + + seen.add(name) anns[name] = tp namespace['__annotations__'] = anns diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py index b251c04..7c39b79 100755 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses.py @@ -1826,114 +1826,6 @@ class TestCase(unittest.TestCase): self.assertEqual(new_sample.x, another_new_sample.x) self.assertEqual(sample.y, another_new_sample.y) - def test_helper_make_dataclass(self): - C = make_dataclass('C', - [('x', int), - ('y', int, field(default=5))], - namespace={'add_one': lambda self: self.x + 1}) - c = C(10) - self.assertEqual((c.x, c.y), (10, 5)) - self.assertEqual(c.add_one(), 11) - - - def test_helper_make_dataclass_no_mutate_namespace(self): - # Make sure a provided namespace isn't mutated. - ns = {} - C = make_dataclass('C', - [('x', int), - ('y', int, field(default=5))], - namespace=ns) - self.assertEqual(ns, {}) - - def test_helper_make_dataclass_base(self): - class Base1: - pass - class Base2: - pass - C = make_dataclass('C', - [('x', int)], - bases=(Base1, Base2)) - c = C(2) - self.assertIsInstance(c, C) - self.assertIsInstance(c, Base1) - self.assertIsInstance(c, Base2) - - def test_helper_make_dataclass_base_dataclass(self): - @dataclass - class Base1: - x: int - class Base2: - pass - C = make_dataclass('C', - [('y', int)], - bases=(Base1, Base2)) - with self.assertRaisesRegex(TypeError, 'required positional'): - c = C(2) - c = C(1, 2) - self.assertIsInstance(c, C) - self.assertIsInstance(c, Base1) - self.assertIsInstance(c, Base2) - - self.assertEqual((c.x, c.y), (1, 2)) - - def test_helper_make_dataclass_init_var(self): - def post_init(self, y): - self.x *= y - - C = make_dataclass('C', - [('x', int), - ('y', InitVar[int]), - ], - namespace={'__post_init__': post_init}, - ) - c = C(2, 3) - self.assertEqual(vars(c), {'x': 6}) - self.assertEqual(len(fields(c)), 1) - - def test_helper_make_dataclass_class_var(self): - C = make_dataclass('C', - [('x', int), - ('y', ClassVar[int], 10), - ('z', ClassVar[int], field(default=20)), - ]) - c = C(1) - self.assertEqual(vars(c), {'x': 1}) - self.assertEqual(len(fields(c)), 1) - self.assertEqual(C.y, 10) - self.assertEqual(C.z, 20) - - def test_helper_make_dataclass_other_params(self): - C = make_dataclass('C', - [('x', int), - ('y', ClassVar[int], 10), - ('z', ClassVar[int], field(default=20)), - ], - init=False) - # Make sure we have a repr, but no init. - self.assertNotIn('__init__', vars(C)) - self.assertIn('__repr__', vars(C)) - - # Make sure random other params don't work. - with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'): - C = make_dataclass('C', - [], - xxinit=False) - - def test_helper_make_dataclass_no_types(self): - C = make_dataclass('Point', ['x', 'y', 'z']) - c = C(1, 2, 3) - self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) - self.assertEqual(C.__annotations__, {'x': 'typing.Any', - 'y': 'typing.Any', - 'z': 'typing.Any'}) - - C = make_dataclass('Point', ['x', ('y', int), 'z']) - c = C(1, 2, 3) - self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) - self.assertEqual(C.__annotations__, {'x': 'typing.Any', - 'y': int, - 'z': 'typing.Any'}) - class TestFieldNoAnnotation(unittest.TestCase): def test_field_without_annotation(self): @@ -2947,5 +2839,170 @@ class TestStringAnnotations(unittest.TestCase): self.assertNotIn('not_iv4', c.__dict__) +class TestMakeDataclass(unittest.TestCase): + def test_simple(self): + C = make_dataclass('C', + [('x', int), + ('y', int, field(default=5))], + namespace={'add_one': lambda self: self.x + 1}) + c = C(10) + self.assertEqual((c.x, c.y), (10, 5)) + self.assertEqual(c.add_one(), 11) + + + def test_no_mutate_namespace(self): + # Make sure a provided namespace isn't mutated. + ns = {} + C = make_dataclass('C', + [('x', int), + ('y', int, field(default=5))], + namespace=ns) + self.assertEqual(ns, {}) + + def test_base(self): + class Base1: + pass + class Base2: + pass + C = make_dataclass('C', + [('x', int)], + bases=(Base1, Base2)) + c = C(2) + self.assertIsInstance(c, C) + self.assertIsInstance(c, Base1) + self.assertIsInstance(c, Base2) + + def test_base_dataclass(self): + @dataclass + class Base1: + x: int + class Base2: + pass + C = make_dataclass('C', + [('y', int)], + bases=(Base1, Base2)) + with self.assertRaisesRegex(TypeError, 'required positional'): + c = C(2) + c = C(1, 2) + self.assertIsInstance(c, C) + self.assertIsInstance(c, Base1) + self.assertIsInstance(c, Base2) + + self.assertEqual((c.x, c.y), (1, 2)) + + def test_init_var(self): + def post_init(self, y): + self.x *= y + + C = make_dataclass('C', + [('x', int), + ('y', InitVar[int]), + ], + namespace={'__post_init__': post_init}, + ) + c = C(2, 3) + self.assertEqual(vars(c), {'x': 6}) + self.assertEqual(len(fields(c)), 1) + + def test_class_var(self): + C = make_dataclass('C', + [('x', int), + ('y', ClassVar[int], 10), + ('z', ClassVar[int], field(default=20)), + ]) + c = C(1) + self.assertEqual(vars(c), {'x': 1}) + self.assertEqual(len(fields(c)), 1) + self.assertEqual(C.y, 10) + self.assertEqual(C.z, 20) + + def test_other_params(self): + C = make_dataclass('C', + [('x', int), + ('y', ClassVar[int], 10), + ('z', ClassVar[int], field(default=20)), + ], + init=False) + # Make sure we have a repr, but no init. + self.assertNotIn('__init__', vars(C)) + self.assertIn('__repr__', vars(C)) + + # Make sure random other params don't work. + with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'): + C = make_dataclass('C', + [], + xxinit=False) + + def test_no_types(self): + C = make_dataclass('Point', ['x', 'y', 'z']) + c = C(1, 2, 3) + self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) + self.assertEqual(C.__annotations__, {'x': 'typing.Any', + 'y': 'typing.Any', + 'z': 'typing.Any'}) + + C = make_dataclass('Point', ['x', ('y', int), 'z']) + c = C(1, 2, 3) + self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) + self.assertEqual(C.__annotations__, {'x': 'typing.Any', + 'y': int, + 'z': 'typing.Any'}) + + def test_invalid_type_specification(self): + for bad_field in [(), + (1, 2, 3, 4), + ]: + with self.subTest(bad_field=bad_field): + with self.assertRaisesRegex(TypeError, r'Invalid field: '): + make_dataclass('C', ['a', bad_field]) + + # And test for things with no len(). + for bad_field in [float, + lambda x:x, + ]: + with self.subTest(bad_field=bad_field): + with self.assertRaisesRegex(TypeError, r'has no len\(\)'): + make_dataclass('C', ['a', bad_field]) + + def test_duplicate_field_names(self): + for field in ['a', 'ab']: + with self.subTest(field=field): + with self.assertRaisesRegex(TypeError, 'Field name duplicated'): + make_dataclass('C', [field, 'a', field]) + + def test_keyword_field_names(self): + for field in ['for', 'async', 'await', 'as']: + with self.subTest(field=field): + with self.assertRaisesRegex(TypeError, 'must not be keywords'): + make_dataclass('C', ['a', field]) + with self.assertRaisesRegex(TypeError, 'must not be keywords'): + make_dataclass('C', [field]) + with self.assertRaisesRegex(TypeError, 'must not be keywords'): + make_dataclass('C', [field, 'a']) + + def test_non_identifier_field_names(self): + for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']: + with self.subTest(field=field): + with self.assertRaisesRegex(TypeError, 'must be valid identifers'): + make_dataclass('C', ['a', field]) + with self.assertRaisesRegex(TypeError, 'must be valid identifers'): + make_dataclass('C', [field]) + with self.assertRaisesRegex(TypeError, 'must be valid identifers'): + make_dataclass('C', [field, 'a']) + + def test_underscore_field_names(self): + # Unlike namedtuple, it's okay if dataclass field names have + # an underscore. + make_dataclass('C', ['_', '_a', 'a_a', 'a_']) + + def test_funny_class_names_names(self): + # No reason to prevent weird class names, since + # types.new_class allows them. + for classname in ['()', 'x,y', '*', '2@3', '']: + with self.subTest(classname=classname): + C = make_dataclass(classname, ['a', 'b']) + self.assertEqual(C.__name__, classname) + + if __name__ == '__main__': unittest.main() diff --git a/Misc/NEWS.d/next/Library/2018-05-16-10-07-40.bpo-33536._s0TE8.rst b/Misc/NEWS.d/next/Library/2018-05-16-10-07-40.bpo-33536._s0TE8.rst new file mode 100644 index 0000000..2c10241 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-05-16-10-07-40.bpo-33536._s0TE8.rst @@ -0,0 +1,2 @@ +dataclasses.make_dataclass now checks for invalid field names and duplicate +fields. Also, added a check for invalid field specifications. |