summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/csv.py22
-rw-r--r--Lib/test/test_csv.py24
2 files changed, 42 insertions, 4 deletions
diff --git a/Lib/csv.py b/Lib/csv.py
index 09f4cf4..e0558c7 100644
--- a/Lib/csv.py
+++ b/Lib/csv.py
@@ -68,7 +68,7 @@ register_dialect("excel-tab", excel_tab)
class DictReader:
def __init__(self, f, fieldnames=None, restkey=None, restval=None,
dialect="excel", *args, **kwds):
- self.fieldnames = fieldnames # list of keys for the dict
+ self._fieldnames = fieldnames # list of keys for the dict
self.restkey = restkey # key to catch long rows
self.restval = restval # default value for short rows
self.reader = reader(f, dialect, *args, **kwds)
@@ -78,11 +78,25 @@ class DictReader:
def __iter__(self):
return self
+ @property
+ def fieldnames(self):
+ if self._fieldnames is None:
+ try:
+ self._fieldnames = next(self.reader)
+ except StopIteration:
+ pass
+ self.line_num = self.reader.line_num
+ return self._fieldnames
+
+ @fieldnames.setter
+ def fieldnames(self, value):
+ self._fieldnames = value
+
def __next__(self):
+ if self.line_num == 0:
+ # Used only for its side effect.
+ self.fieldnames
row = next(self.reader)
- if self.fieldnames is None:
- self.fieldnames = row
- row = next(self.reader)
self.line_num = self.reader.line_num
# unlike the basic reader, we prefer not to return blanks,
diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py
index 1dbb71a..9c9840b 100644
--- a/Lib/test/test_csv.py
+++ b/Lib/test/test_csv.py
@@ -544,6 +544,29 @@ class TestDictFields(unittest.TestCase):
fileobj.seek(0)
reader = csv.DictReader(fileobj)
self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'})
+ self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"])
+
+ # Two test cases to make sure existing ways of implicitly setting
+ # fieldnames continue to work. Both arise from discussion in issue3436.
+ def test_read_dict_fieldnames_from_file(self):
+ with TemporaryFile("w+") as fileobj:
+ fileobj.write("f1,f2,f3\r\n1,2,abc\r\n")
+ fileobj.seek(0)
+ reader = csv.DictReader(fileobj,
+ fieldnames=next(csv.reader(fileobj)))
+ self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"])
+ self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'})
+
+ def test_read_dict_fieldnames_chain(self):
+ import itertools
+ with TemporaryFile("w+") as fileobj:
+ fileobj.write("f1,f2,f3\r\n1,2,abc\r\n")
+ fileobj.seek(0)
+ reader = csv.DictReader(fileobj)
+ first = next(reader)
+ for row in itertools.chain([first], reader):
+ self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"])
+ self.assertEqual(row, {"f1": '1', "f2": '2', "f3": 'abc'})
def test_read_long(self):
with TemporaryFile("w+") as fileobj:
@@ -568,6 +591,7 @@ class TestDictFields(unittest.TestCase):
fileobj.write("f1,f2\r\n1,2,abc,4,5,6\r\n")
fileobj.seek(0)
reader = csv.DictReader(fileobj, restkey="_rest")
+ self.assertEqual(reader.fieldnames, ["f1", "f2"])
self.assertEqual(next(reader), {"f1": '1', "f2": '2',
"_rest": ["abc", "4", "5", "6"]})