diff options
author | Sam Ezeh <sam.z.ezeh@gmail.com> | 2022-08-25 10:13:24 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-08-25 10:13:24 (GMT) |
commit | cd492d43a2980faf0ef4a3f99c665023a506414c (patch) | |
tree | c5758ce600d818dfdad9b1cbacb5b62aa3adc7a1 /Lib | |
parent | c09fa7542c6d9b724e423b14c6fb5f4338eabd12 (diff) | |
download | cpython-cd492d43a2980faf0ef4a3f99c665023a506414c.zip cpython-cd492d43a2980faf0ef4a3f99c665023a506414c.tar.gz cpython-cd492d43a2980faf0ef4a3f99c665023a506414c.tar.bz2 |
gh-76728: Coerce DictReader and DictWriter fieldnames argument to a list (GH-32225)
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/csv.py | 4 | ||||
-rw-r--r-- | Lib/test/test_csv.py | 28 |
2 files changed, 32 insertions, 0 deletions
@@ -81,6 +81,8 @@ register_dialect("unix", unix_dialect) class DictReader: def __init__(self, f, fieldnames=None, restkey=None, restval=None, dialect="excel", *args, **kwds): + if fieldnames is not None and iter(fieldnames) is fieldnames: + fieldnames = list(fieldnames) self._fieldnames = fieldnames # list of keys for the dict self.restkey = restkey # key to catch long rows self.restval = restval # default value for short rows @@ -133,6 +135,8 @@ class DictReader: class DictWriter: def __init__(self, f, fieldnames, restval="", extrasaction="raise", dialect="excel", *args, **kwds): + if fieldnames is not None and iter(fieldnames) is fieldnames: + fieldnames = list(fieldnames) self.fieldnames = fieldnames # list of keys for the dict self.restval = restval # for writing short dicts if extrasaction.lower() not in ("raise", "ignore"): diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py index 95a19dd..51ca1f2 100644 --- a/Lib/test/test_csv.py +++ b/Lib/test/test_csv.py @@ -736,6 +736,34 @@ class TestDictFields(unittest.TestCase): csv.DictWriter.writerow(writer, dictrow) self.assertEqual(fileobj.getvalue(), "1,2\r\n") + def test_dict_reader_fieldnames_accepts_iter(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + reader = csv.DictReader(f, iter(fieldnames)) + self.assertEqual(reader.fieldnames, fieldnames) + + def test_dict_reader_fieldnames_accepts_list(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + reader = csv.DictReader(f, fieldnames) + self.assertEqual(reader.fieldnames, fieldnames) + + def test_dict_writer_fieldnames_rejects_iter(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + writer = csv.DictWriter(f, iter(fieldnames)) + self.assertEqual(writer.fieldnames, fieldnames) + + def test_dict_writer_fieldnames_accepts_list(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + writer = csv.DictWriter(f, fieldnames) + self.assertEqual(writer.fieldnames, fieldnames) + + def test_dict_reader_fieldnames_is_optional(self): + f = StringIO() + reader = csv.DictReader(f, fieldnames=None) + def test_read_dict_fields(self): with TemporaryFile("w+", encoding="utf-8") as fileobj: fileobj.write("1,2,abc\r\n") |