diff options
Diffstat (limited to 'Lib/test/test_csv.py')
| -rw-r--r-- | Lib/test/test_csv.py | 56 |
1 files changed, 34 insertions, 22 deletions
diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py index 83f8cb3..65449ae 100644 --- a/Lib/test/test_csv.py +++ b/Lib/test/test_csv.py @@ -124,24 +124,31 @@ class Test_Csv(unittest.TestCase): self.assertEqual(fileobj.read(), expect + writer.dialect.lineterminator) + def _write_error_test(self, exc, fields, **kwargs): + with TemporaryFile("w+", newline='') as fileobj: + writer = csv.writer(fileobj, **kwargs) + with self.assertRaises(exc): + writer.writerow(fields) + fileobj.seek(0) + self.assertEqual(fileobj.read(), '') + def test_write_arg_valid(self): - self.assertRaises(csv.Error, self._write_test, None, '') + self._write_error_test(csv.Error, None) self._write_test((), '') self._write_test([None], '""') - self.assertRaises(csv.Error, self._write_test, - [None], None, quoting = csv.QUOTE_NONE) + self._write_error_test(csv.Error, [None], quoting = csv.QUOTE_NONE) # Check that exceptions are passed up the chain class BadList: def __len__(self): return 10; def __getitem__(self, i): if i > 2: - raise IOError - self.assertRaises(IOError, self._write_test, BadList(), '') + raise OSError + self._write_error_test(OSError, BadList()) class BadItem: def __str__(self): - raise IOError - self.assertRaises(IOError, self._write_test, [BadItem()], '') + raise OSError + self._write_error_test(OSError, [BadItem()]) def test_write_bigfield(self): # This exercises the buffer realloc functionality @@ -151,10 +158,8 @@ class Test_Csv(unittest.TestCase): def test_write_quoting(self): self._write_test(['a',1,'p,q'], 'a,1,"p,q"') - self.assertRaises(csv.Error, - self._write_test, - ['a',1,'p,q'], 'a,1,p,q', - quoting = csv.QUOTE_NONE) + self._write_error_test(csv.Error, ['a',1,'p,q'], + quoting = csv.QUOTE_NONE) self._write_test(['a',1,'p,q'], 'a,1,"p,q"', quoting = csv.QUOTE_MINIMAL) self._write_test(['a',1,'p,q'], '"a",1,"p,q"', @@ -167,10 +172,8 @@ class Test_Csv(unittest.TestCase): def test_write_escape(self): self._write_test(['a',1,'p,q'], 'a,1,"p,q"', escapechar='\\') - self.assertRaises(csv.Error, - self._write_test, - ['a',1,'p,"q"'], 'a,1,"p,\\"q\\""', - escapechar=None, doublequote=False) + self._write_error_test(csv.Error, ['a',1,'p,"q"'], + escapechar=None, doublequote=False) self._write_test(['a',1,'p,"q"'], 'a,1,"p,\\"q\\""', escapechar='\\', doublequote = False) self._write_test(['"'], '""""', @@ -186,9 +189,9 @@ class Test_Csv(unittest.TestCase): def test_writerows(self): class BrokenFile: def write(self, buf): - raise IOError + raise OSError writer = csv.writer(BrokenFile()) - self.assertRaises(IOError, writer.writerows, [['a']]) + self.assertRaises(OSError, writer.writerows, [['a']]) with TemporaryFile("w+", newline='') as fileobj: writer = csv.writer(fileobj) @@ -308,6 +311,15 @@ class Test_Csv(unittest.TestCase): for i, row in enumerate(csv.reader(fileobj)): self.assertEqual(row, rows[i]) + def test_roundtrip_escaped_unquoted_newlines(self): + with TemporaryFile("w+", newline='') as fileobj: + writer = csv.writer(fileobj,quoting=csv.QUOTE_NONE,escapechar="\\") + rows = [['a\nb','b'],['c','x\r\nd']] + writer.writerows(rows) + fileobj.seek(0) + for i, row in enumerate(csv.reader(fileobj,quoting=csv.QUOTE_NONE,escapechar="\\")): + self.assertEqual(row,rows[i]) + class TestDialectRegistry(unittest.TestCase): def test_registry_badargs(self): self.assertRaises(TypeError, csv.list_dialects, None) @@ -836,10 +848,11 @@ class TestDialectValidity(unittest.TestCase): d = mydialect() for field_name in ("delimiter", "escapechar", "quotechar"): - self.assertRaises(csv.Error, create_invalid, field_name, "") - self.assertRaises(csv.Error, create_invalid, field_name, "abc") - self.assertRaises(csv.Error, create_invalid, field_name, b'x') - self.assertRaises(csv.Error, create_invalid, field_name, 5) + with self.subTest(field_name=field_name): + self.assertRaises(csv.Error, create_invalid, field_name, "") + self.assertRaises(csv.Error, create_invalid, field_name, "abc") + self.assertRaises(csv.Error, create_invalid, field_name, b'x') + self.assertRaises(csv.Error, create_invalid, field_name, 5) class TestSniffer(unittest.TestCase): @@ -1053,7 +1066,6 @@ class TestUnicode(unittest.TestCase): self.assertEqual(fileobj.read(), expected) - def test_main(): mod = sys.modules[__name__] support.run_unittest( |
