summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorEric V. Smith <ericvsmith@users.noreply.github.com>2018-05-16 15:31:29 (GMT)
committerGitHub <noreply@github.com>2018-05-16 15:31:29 (GMT)
commit4e81296b1874829912c687eba4d39361ab51e145 (patch)
tree46db8179f7b5f6eccd66a688ef4126cc34457b38 /Lib
parent5db5c0669e624767375593cc1a01f32092c91c58 (diff)
downloadcpython-4e81296b1874829912c687eba4d39361ab51e145.zip
cpython-4e81296b1874829912c687eba4d39361ab51e145.tar.gz
cpython-4e81296b1874829912c687eba4d39361ab51e145.tar.bz2
bpo-33536: Validate make_dataclass() field names. (GH-6906)
Diffstat (limited to 'Lib')
-rw-r--r--Lib/dataclasses.py15
-rwxr-xr-xLib/test/test_dataclasses.py273
2 files changed, 180 insertions, 108 deletions
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py
index bb77d3b..2c5593b 100644
--- a/Lib/dataclasses.py
+++ b/Lib/dataclasses.py
@@ -3,6 +3,7 @@ import sys
import copy
import types
import inspect
+import keyword
__all__ = ['dataclass',
'field',
@@ -1100,6 +1101,9 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
# Copy namespace since we're going to mutate it.
namespace = namespace.copy()
+ # While we're looking through the field names, validate that they
+ # are identifiers, are not keywords, and not duplicates.
+ seen = set()
anns = {}
for item in fields:
if isinstance(item, str):
@@ -1110,6 +1114,17 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
elif len(item) == 3:
name, tp, spec = item
namespace[name] = spec
+ else:
+ raise TypeError(f'Invalid field: {item!r}')
+
+ if not isinstance(name, str) or not name.isidentifier():
+ raise TypeError(f'Field names must be valid identifers: {name!r}')
+ if keyword.iskeyword(name):
+ raise TypeError(f'Field names must not be keywords: {name!r}')
+ if name in seen:
+ raise TypeError(f'Field name duplicated: {name!r}')
+
+ seen.add(name)
anns[name] = tp
namespace['__annotations__'] = anns
diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py
index b251c04..7c39b79 100755
--- a/Lib/test/test_dataclasses.py
+++ b/Lib/test/test_dataclasses.py
@@ -1826,114 +1826,6 @@ class TestCase(unittest.TestCase):
self.assertEqual(new_sample.x, another_new_sample.x)
self.assertEqual(sample.y, another_new_sample.y)
- def test_helper_make_dataclass(self):
- C = make_dataclass('C',
- [('x', int),
- ('y', int, field(default=5))],
- namespace={'add_one': lambda self: self.x + 1})
- c = C(10)
- self.assertEqual((c.x, c.y), (10, 5))
- self.assertEqual(c.add_one(), 11)
-
-
- def test_helper_make_dataclass_no_mutate_namespace(self):
- # Make sure a provided namespace isn't mutated.
- ns = {}
- C = make_dataclass('C',
- [('x', int),
- ('y', int, field(default=5))],
- namespace=ns)
- self.assertEqual(ns, {})
-
- def test_helper_make_dataclass_base(self):
- class Base1:
- pass
- class Base2:
- pass
- C = make_dataclass('C',
- [('x', int)],
- bases=(Base1, Base2))
- c = C(2)
- self.assertIsInstance(c, C)
- self.assertIsInstance(c, Base1)
- self.assertIsInstance(c, Base2)
-
- def test_helper_make_dataclass_base_dataclass(self):
- @dataclass
- class Base1:
- x: int
- class Base2:
- pass
- C = make_dataclass('C',
- [('y', int)],
- bases=(Base1, Base2))
- with self.assertRaisesRegex(TypeError, 'required positional'):
- c = C(2)
- c = C(1, 2)
- self.assertIsInstance(c, C)
- self.assertIsInstance(c, Base1)
- self.assertIsInstance(c, Base2)
-
- self.assertEqual((c.x, c.y), (1, 2))
-
- def test_helper_make_dataclass_init_var(self):
- def post_init(self, y):
- self.x *= y
-
- C = make_dataclass('C',
- [('x', int),
- ('y', InitVar[int]),
- ],
- namespace={'__post_init__': post_init},
- )
- c = C(2, 3)
- self.assertEqual(vars(c), {'x': 6})
- self.assertEqual(len(fields(c)), 1)
-
- def test_helper_make_dataclass_class_var(self):
- C = make_dataclass('C',
- [('x', int),
- ('y', ClassVar[int], 10),
- ('z', ClassVar[int], field(default=20)),
- ])
- c = C(1)
- self.assertEqual(vars(c), {'x': 1})
- self.assertEqual(len(fields(c)), 1)
- self.assertEqual(C.y, 10)
- self.assertEqual(C.z, 20)
-
- def test_helper_make_dataclass_other_params(self):
- C = make_dataclass('C',
- [('x', int),
- ('y', ClassVar[int], 10),
- ('z', ClassVar[int], field(default=20)),
- ],
- init=False)
- # Make sure we have a repr, but no init.
- self.assertNotIn('__init__', vars(C))
- self.assertIn('__repr__', vars(C))
-
- # Make sure random other params don't work.
- with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
- C = make_dataclass('C',
- [],
- xxinit=False)
-
- def test_helper_make_dataclass_no_types(self):
- C = make_dataclass('Point', ['x', 'y', 'z'])
- c = C(1, 2, 3)
- self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
- self.assertEqual(C.__annotations__, {'x': 'typing.Any',
- 'y': 'typing.Any',
- 'z': 'typing.Any'})
-
- C = make_dataclass('Point', ['x', ('y', int), 'z'])
- c = C(1, 2, 3)
- self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
- self.assertEqual(C.__annotations__, {'x': 'typing.Any',
- 'y': int,
- 'z': 'typing.Any'})
-
class TestFieldNoAnnotation(unittest.TestCase):
def test_field_without_annotation(self):
@@ -2947,5 +2839,170 @@ class TestStringAnnotations(unittest.TestCase):
self.assertNotIn('not_iv4', c.__dict__)
+class TestMakeDataclass(unittest.TestCase):
+ def test_simple(self):
+ C = make_dataclass('C',
+ [('x', int),
+ ('y', int, field(default=5))],
+ namespace={'add_one': lambda self: self.x + 1})
+ c = C(10)
+ self.assertEqual((c.x, c.y), (10, 5))
+ self.assertEqual(c.add_one(), 11)
+
+
+ def test_no_mutate_namespace(self):
+ # Make sure a provided namespace isn't mutated.
+ ns = {}
+ C = make_dataclass('C',
+ [('x', int),
+ ('y', int, field(default=5))],
+ namespace=ns)
+ self.assertEqual(ns, {})
+
+ def test_base(self):
+ class Base1:
+ pass
+ class Base2:
+ pass
+ C = make_dataclass('C',
+ [('x', int)],
+ bases=(Base1, Base2))
+ c = C(2)
+ self.assertIsInstance(c, C)
+ self.assertIsInstance(c, Base1)
+ self.assertIsInstance(c, Base2)
+
+ def test_base_dataclass(self):
+ @dataclass
+ class Base1:
+ x: int
+ class Base2:
+ pass
+ C = make_dataclass('C',
+ [('y', int)],
+ bases=(Base1, Base2))
+ with self.assertRaisesRegex(TypeError, 'required positional'):
+ c = C(2)
+ c = C(1, 2)
+ self.assertIsInstance(c, C)
+ self.assertIsInstance(c, Base1)
+ self.assertIsInstance(c, Base2)
+
+ self.assertEqual((c.x, c.y), (1, 2))
+
+ def test_init_var(self):
+ def post_init(self, y):
+ self.x *= y
+
+ C = make_dataclass('C',
+ [('x', int),
+ ('y', InitVar[int]),
+ ],
+ namespace={'__post_init__': post_init},
+ )
+ c = C(2, 3)
+ self.assertEqual(vars(c), {'x': 6})
+ self.assertEqual(len(fields(c)), 1)
+
+ def test_class_var(self):
+ C = make_dataclass('C',
+ [('x', int),
+ ('y', ClassVar[int], 10),
+ ('z', ClassVar[int], field(default=20)),
+ ])
+ c = C(1)
+ self.assertEqual(vars(c), {'x': 1})
+ self.assertEqual(len(fields(c)), 1)
+ self.assertEqual(C.y, 10)
+ self.assertEqual(C.z, 20)
+
+ def test_other_params(self):
+ C = make_dataclass('C',
+ [('x', int),
+ ('y', ClassVar[int], 10),
+ ('z', ClassVar[int], field(default=20)),
+ ],
+ init=False)
+ # Make sure we have a repr, but no init.
+ self.assertNotIn('__init__', vars(C))
+ self.assertIn('__repr__', vars(C))
+
+ # Make sure random other params don't work.
+ with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
+ C = make_dataclass('C',
+ [],
+ xxinit=False)
+
+ def test_no_types(self):
+ C = make_dataclass('Point', ['x', 'y', 'z'])
+ c = C(1, 2, 3)
+ self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
+ self.assertEqual(C.__annotations__, {'x': 'typing.Any',
+ 'y': 'typing.Any',
+ 'z': 'typing.Any'})
+
+ C = make_dataclass('Point', ['x', ('y', int), 'z'])
+ c = C(1, 2, 3)
+ self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
+ self.assertEqual(C.__annotations__, {'x': 'typing.Any',
+ 'y': int,
+ 'z': 'typing.Any'})
+
+ def test_invalid_type_specification(self):
+ for bad_field in [(),
+ (1, 2, 3, 4),
+ ]:
+ with self.subTest(bad_field=bad_field):
+ with self.assertRaisesRegex(TypeError, r'Invalid field: '):
+ make_dataclass('C', ['a', bad_field])
+
+ # And test for things with no len().
+ for bad_field in [float,
+ lambda x:x,
+ ]:
+ with self.subTest(bad_field=bad_field):
+ with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
+ make_dataclass('C', ['a', bad_field])
+
+ def test_duplicate_field_names(self):
+ for field in ['a', 'ab']:
+ with self.subTest(field=field):
+ with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
+ make_dataclass('C', [field, 'a', field])
+
+ def test_keyword_field_names(self):
+ for field in ['for', 'async', 'await', 'as']:
+ with self.subTest(field=field):
+ with self.assertRaisesRegex(TypeError, 'must not be keywords'):
+ make_dataclass('C', ['a', field])
+ with self.assertRaisesRegex(TypeError, 'must not be keywords'):
+ make_dataclass('C', [field])
+ with self.assertRaisesRegex(TypeError, 'must not be keywords'):
+ make_dataclass('C', [field, 'a'])
+
+ def test_non_identifier_field_names(self):
+ for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
+ with self.subTest(field=field):
+ with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
+ make_dataclass('C', ['a', field])
+ with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
+ make_dataclass('C', [field])
+ with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
+ make_dataclass('C', [field, 'a'])
+
+ def test_underscore_field_names(self):
+ # Unlike namedtuple, it's okay if dataclass field names have
+ # an underscore.
+ make_dataclass('C', ['_', '_a', 'a_a', 'a_'])
+
+ def test_funny_class_names_names(self):
+ # No reason to prevent weird class names, since
+ # types.new_class allows them.
+ for classname in ['()', 'x,y', '*', '2@3', '']:
+ with self.subTest(classname=classname):
+ C = make_dataclass(classname, ['a', 'b'])
+ self.assertEqual(C.__name__, classname)
+
+
if __name__ == '__main__':
unittest.main()