diff options
author | Raymond Hettinger <rhettinger@users.noreply.github.com> | 2018-01-11 05:45:19 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-01-11 05:45:19 (GMT) |
commit | 3948207c610e931831828d33aaef258185df31db (patch) | |
tree | 784b0122b54543b540559d23385a060f9d924d44 /Lib | |
parent | d55209d5b1e097cde55fa3f83149d614c8ccaf09 (diff) | |
download | cpython-3948207c610e931831828d33aaef258185df31db.zip cpython-3948207c610e931831828d33aaef258185df31db.tar.gz cpython-3948207c610e931831828d33aaef258185df31db.tar.bz2 |
bpo-32320: Add default value support to collections.namedtuple() (#4859)
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/collections/__init__.py | 20 | ||||
-rw-r--r-- | Lib/test/test_collections.py | 51 |
2 files changed, 68 insertions, 3 deletions
diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py index 50cf814..7088b88 100644 --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -303,7 +303,7 @@ except ImportError: _nt_itemgetters = {} -def namedtuple(typename, field_names, *, rename=False, module=None): +def namedtuple(typename, field_names, *, rename=False, defaults=None, module=None): """Returns a new subclass of tuple with named fields. >>> Point = namedtuple('Point', ['x', 'y']) @@ -332,7 +332,8 @@ def namedtuple(typename, field_names, *, rename=False, module=None): if isinstance(field_names, str): field_names = field_names.replace(',', ' ').split() field_names = list(map(str, field_names)) - typename = str(typename) + typename = _sys.intern(str(typename)) + if rename: seen = set() for index, name in enumerate(field_names): @@ -342,6 +343,7 @@ def namedtuple(typename, field_names, *, rename=False, module=None): or name in seen): field_names[index] = f'_{index}' seen.add(name) + for name in [typename] + field_names: if type(name) is not str: raise TypeError('Type names and field names must be strings') @@ -351,6 +353,7 @@ def namedtuple(typename, field_names, *, rename=False, module=None): if _iskeyword(name): raise ValueError('Type names and field names cannot be a ' f'keyword: {name!r}') + seen = set() for name in field_names: if name.startswith('_') and not rename: @@ -360,6 +363,14 @@ def namedtuple(typename, field_names, *, rename=False, module=None): raise ValueError(f'Encountered duplicate field name: {name!r}') seen.add(name) + field_defaults = {} + if defaults is not None: + defaults = tuple(defaults) + if len(defaults) > len(field_names): + raise TypeError('Got more default values than field names') + field_defaults = dict(reversed(list(zip(reversed(field_names), + reversed(defaults))))) + # Variables used in the methods and docstrings field_names = tuple(map(_sys.intern, field_names)) num_fields = len(field_names) @@ -372,10 +383,12 @@ def namedtuple(typename, field_names, *, rename=False, module=None): s = f'def __new__(_cls, {arg_list}): return _tuple_new(_cls, ({arg_list}))' namespace = {'_tuple_new': tuple_new, '__name__': f'namedtuple_{typename}'} - # Note: exec() has the side-effect of interning the typename and field names + # Note: exec() has the side-effect of interning the field names exec(s, namespace) __new__ = namespace['__new__'] __new__.__doc__ = f'Create new instance of {typename}({arg_list})' + if defaults is not None: + __new__.__defaults__ = defaults @classmethod def _make(cls, iterable): @@ -420,6 +433,7 @@ def namedtuple(typename, field_names, *, rename=False, module=None): '__doc__': f'{typename}({arg_list})', '__slots__': (), '_fields': field_names, + '_fields_defaults': field_defaults, '__new__': __new__, '_make': _make, '_replace': _replace, diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index 6c466f4..cb66235 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -216,6 +216,57 @@ class TestNamedTuple(unittest.TestCase): self.assertRaises(TypeError, Point._make, [11]) # catch too few args self.assertRaises(TypeError, Point._make, [11, 22, 33]) # catch too many args + def test_defaults(self): + Point = namedtuple('Point', 'x y', defaults=(10, 20)) # 2 defaults + self.assertEqual(Point._fields_defaults, {'x': 10, 'y': 20}) + self.assertEqual(Point(1, 2), (1, 2)) + self.assertEqual(Point(1), (1, 20)) + self.assertEqual(Point(), (10, 20)) + + Point = namedtuple('Point', 'x y', defaults=(20,)) # 1 default + self.assertEqual(Point._fields_defaults, {'y': 20}) + self.assertEqual(Point(1, 2), (1, 2)) + self.assertEqual(Point(1), (1, 20)) + + Point = namedtuple('Point', 'x y', defaults=()) # 0 defaults + self.assertEqual(Point._fields_defaults, {}) + self.assertEqual(Point(1, 2), (1, 2)) + with self.assertRaises(TypeError): + Point(1) + + with self.assertRaises(TypeError): # catch too few args + Point() + with self.assertRaises(TypeError): # catch too many args + Point(1, 2, 3) + with self.assertRaises(TypeError): # too many defaults + Point = namedtuple('Point', 'x y', defaults=(10, 20, 30)) + with self.assertRaises(TypeError): # non-iterable defaults + Point = namedtuple('Point', 'x y', defaults=10) + with self.assertRaises(TypeError): # another non-iterable default + Point = namedtuple('Point', 'x y', defaults=False) + + Point = namedtuple('Point', 'x y', defaults=None) # default is None + self.assertEqual(Point._fields_defaults, {}) + self.assertIsNone(Point.__new__.__defaults__, None) + self.assertEqual(Point(10, 20), (10, 20)) + with self.assertRaises(TypeError): # catch too few args + Point(10) + + Point = namedtuple('Point', 'x y', defaults=[10, 20]) # allow non-tuple iterable + self.assertEqual(Point._fields_defaults, {'x': 10, 'y': 20}) + self.assertEqual(Point.__new__.__defaults__, (10, 20)) + self.assertEqual(Point(1, 2), (1, 2)) + self.assertEqual(Point(1), (1, 20)) + self.assertEqual(Point(), (10, 20)) + + Point = namedtuple('Point', 'x y', defaults=iter([10, 20])) # allow plain iterator + self.assertEqual(Point._fields_defaults, {'x': 10, 'y': 20}) + self.assertEqual(Point.__new__.__defaults__, (10, 20)) + self.assertEqual(Point(1, 2), (1, 2)) + self.assertEqual(Point(1), (1, 20)) + self.assertEqual(Point(), (10, 20)) + + @unittest.skipIf(sys.flags.optimize >= 2, "Docstrings are omitted with -O2 and above") def test_factory_doc_attr(self): |