summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorRaymond Hettinger <rhettinger@users.noreply.github.com>2018-01-11 05:45:19 (GMT)
committerGitHub <noreply@github.com>2018-01-11 05:45:19 (GMT)
commit3948207c610e931831828d33aaef258185df31db (patch)
tree784b0122b54543b540559d23385a060f9d924d44 /Lib
parentd55209d5b1e097cde55fa3f83149d614c8ccaf09 (diff)
downloadcpython-3948207c610e931831828d33aaef258185df31db.zip
cpython-3948207c610e931831828d33aaef258185df31db.tar.gz
cpython-3948207c610e931831828d33aaef258185df31db.tar.bz2
bpo-32320: Add default value support to collections.namedtuple() (#4859)
Diffstat (limited to 'Lib')
-rw-r--r--Lib/collections/__init__.py20
-rw-r--r--Lib/test/test_collections.py51
2 files changed, 68 insertions, 3 deletions
diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py
index 50cf814..7088b88 100644
--- a/Lib/collections/__init__.py
+++ b/Lib/collections/__init__.py
@@ -303,7 +303,7 @@ except ImportError:
_nt_itemgetters = {}
-def namedtuple(typename, field_names, *, rename=False, module=None):
+def namedtuple(typename, field_names, *, rename=False, defaults=None, module=None):
"""Returns a new subclass of tuple with named fields.
>>> Point = namedtuple('Point', ['x', 'y'])
@@ -332,7 +332,8 @@ def namedtuple(typename, field_names, *, rename=False, module=None):
if isinstance(field_names, str):
field_names = field_names.replace(',', ' ').split()
field_names = list(map(str, field_names))
- typename = str(typename)
+ typename = _sys.intern(str(typename))
+
if rename:
seen = set()
for index, name in enumerate(field_names):
@@ -342,6 +343,7 @@ def namedtuple(typename, field_names, *, rename=False, module=None):
or name in seen):
field_names[index] = f'_{index}'
seen.add(name)
+
for name in [typename] + field_names:
if type(name) is not str:
raise TypeError('Type names and field names must be strings')
@@ -351,6 +353,7 @@ def namedtuple(typename, field_names, *, rename=False, module=None):
if _iskeyword(name):
raise ValueError('Type names and field names cannot be a '
f'keyword: {name!r}')
+
seen = set()
for name in field_names:
if name.startswith('_') and not rename:
@@ -360,6 +363,14 @@ def namedtuple(typename, field_names, *, rename=False, module=None):
raise ValueError(f'Encountered duplicate field name: {name!r}')
seen.add(name)
+ field_defaults = {}
+ if defaults is not None:
+ defaults = tuple(defaults)
+ if len(defaults) > len(field_names):
+ raise TypeError('Got more default values than field names')
+ field_defaults = dict(reversed(list(zip(reversed(field_names),
+ reversed(defaults)))))
+
# Variables used in the methods and docstrings
field_names = tuple(map(_sys.intern, field_names))
num_fields = len(field_names)
@@ -372,10 +383,12 @@ def namedtuple(typename, field_names, *, rename=False, module=None):
s = f'def __new__(_cls, {arg_list}): return _tuple_new(_cls, ({arg_list}))'
namespace = {'_tuple_new': tuple_new, '__name__': f'namedtuple_{typename}'}
- # Note: exec() has the side-effect of interning the typename and field names
+ # Note: exec() has the side-effect of interning the field names
exec(s, namespace)
__new__ = namespace['__new__']
__new__.__doc__ = f'Create new instance of {typename}({arg_list})'
+ if defaults is not None:
+ __new__.__defaults__ = defaults
@classmethod
def _make(cls, iterable):
@@ -420,6 +433,7 @@ def namedtuple(typename, field_names, *, rename=False, module=None):
'__doc__': f'{typename}({arg_list})',
'__slots__': (),
'_fields': field_names,
+ '_fields_defaults': field_defaults,
'__new__': __new__,
'_make': _make,
'_replace': _replace,
diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py
index 6c466f4..cb66235 100644
--- a/Lib/test/test_collections.py
+++ b/Lib/test/test_collections.py
@@ -216,6 +216,57 @@ class TestNamedTuple(unittest.TestCase):
self.assertRaises(TypeError, Point._make, [11]) # catch too few args
self.assertRaises(TypeError, Point._make, [11, 22, 33]) # catch too many args
+ def test_defaults(self):
+ Point = namedtuple('Point', 'x y', defaults=(10, 20)) # 2 defaults
+ self.assertEqual(Point._fields_defaults, {'x': 10, 'y': 20})
+ self.assertEqual(Point(1, 2), (1, 2))
+ self.assertEqual(Point(1), (1, 20))
+ self.assertEqual(Point(), (10, 20))
+
+ Point = namedtuple('Point', 'x y', defaults=(20,)) # 1 default
+ self.assertEqual(Point._fields_defaults, {'y': 20})
+ self.assertEqual(Point(1, 2), (1, 2))
+ self.assertEqual(Point(1), (1, 20))
+
+ Point = namedtuple('Point', 'x y', defaults=()) # 0 defaults
+ self.assertEqual(Point._fields_defaults, {})
+ self.assertEqual(Point(1, 2), (1, 2))
+ with self.assertRaises(TypeError):
+ Point(1)
+
+ with self.assertRaises(TypeError): # catch too few args
+ Point()
+ with self.assertRaises(TypeError): # catch too many args
+ Point(1, 2, 3)
+ with self.assertRaises(TypeError): # too many defaults
+ Point = namedtuple('Point', 'x y', defaults=(10, 20, 30))
+ with self.assertRaises(TypeError): # non-iterable defaults
+ Point = namedtuple('Point', 'x y', defaults=10)
+ with self.assertRaises(TypeError): # another non-iterable default
+ Point = namedtuple('Point', 'x y', defaults=False)
+
+ Point = namedtuple('Point', 'x y', defaults=None) # default is None
+ self.assertEqual(Point._fields_defaults, {})
+ self.assertIsNone(Point.__new__.__defaults__, None)
+ self.assertEqual(Point(10, 20), (10, 20))
+ with self.assertRaises(TypeError): # catch too few args
+ Point(10)
+
+ Point = namedtuple('Point', 'x y', defaults=[10, 20]) # allow non-tuple iterable
+ self.assertEqual(Point._fields_defaults, {'x': 10, 'y': 20})
+ self.assertEqual(Point.__new__.__defaults__, (10, 20))
+ self.assertEqual(Point(1, 2), (1, 2))
+ self.assertEqual(Point(1), (1, 20))
+ self.assertEqual(Point(), (10, 20))
+
+ Point = namedtuple('Point', 'x y', defaults=iter([10, 20])) # allow plain iterator
+ self.assertEqual(Point._fields_defaults, {'x': 10, 'y': 20})
+ self.assertEqual(Point.__new__.__defaults__, (10, 20))
+ self.assertEqual(Point(1, 2), (1, 2))
+ self.assertEqual(Point(1), (1, 20))
+ self.assertEqual(Point(), (10, 20))
+
+
@unittest.skipIf(sys.flags.optimize >= 2,
"Docstrings are omitted with -O2 and above")
def test_factory_doc_attr(self):