diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/csv.py | 22 | ||||
-rw-r--r-- | Lib/test/test_csv.py | 24 |
2 files changed, 42 insertions, 4 deletions
@@ -68,7 +68,7 @@ register_dialect("excel-tab", excel_tab) class DictReader: def __init__(self, f, fieldnames=None, restkey=None, restval=None, dialect="excel", *args, **kwds): - self.fieldnames = fieldnames # list of keys for the dict + self._fieldnames = fieldnames # list of keys for the dict self.restkey = restkey # key to catch long rows self.restval = restval # default value for short rows self.reader = reader(f, dialect, *args, **kwds) @@ -78,11 +78,25 @@ class DictReader: def __iter__(self): return self + @property + def fieldnames(self): + if self._fieldnames is None: + try: + self._fieldnames = next(self.reader) + except StopIteration: + pass + self.line_num = self.reader.line_num + return self._fieldnames + + @fieldnames.setter + def fieldnames(self, value): + self._fieldnames = value + def __next__(self): + if self.line_num == 0: + # Used only for its side effect. + self.fieldnames row = next(self.reader) - if self.fieldnames is None: - self.fieldnames = row - row = next(self.reader) self.line_num = self.reader.line_num # unlike the basic reader, we prefer not to return blanks, diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py index 1dbb71a..9c9840b 100644 --- a/Lib/test/test_csv.py +++ b/Lib/test/test_csv.py @@ -544,6 +544,29 @@ class TestDictFields(unittest.TestCase): fileobj.seek(0) reader = csv.DictReader(fileobj) self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'}) + self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) + + # Two test cases to make sure existing ways of implicitly setting + # fieldnames continue to work. Both arise from discussion in issue3436. + def test_read_dict_fieldnames_from_file(self): + with TemporaryFile("w+") as fileobj: + fileobj.write("f1,f2,f3\r\n1,2,abc\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj, + fieldnames=next(csv.reader(fileobj))) + self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) + self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'}) + + def test_read_dict_fieldnames_chain(self): + import itertools + with TemporaryFile("w+") as fileobj: + fileobj.write("f1,f2,f3\r\n1,2,abc\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj) + first = next(reader) + for row in itertools.chain([first], reader): + self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) + self.assertEqual(row, {"f1": '1', "f2": '2', "f3": 'abc'}) def test_read_long(self): with TemporaryFile("w+") as fileobj: @@ -568,6 +591,7 @@ class TestDictFields(unittest.TestCase): fileobj.write("f1,f2\r\n1,2,abc,4,5,6\r\n") fileobj.seek(0) reader = csv.DictReader(fileobj, restkey="_rest") + self.assertEqual(reader.fieldnames, ["f1", "f2"]) self.assertEqual(next(reader), {"f1": '1', "f2": '2', "_rest": ["abc", "4", "5", "6"]}) |