summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_csv.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_csv.py')
-rw-r--r--Lib/test/test_csv.py56
1 files changed, 34 insertions, 22 deletions
diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py
index 83f8cb3..65449ae 100644
--- a/Lib/test/test_csv.py
+++ b/Lib/test/test_csv.py
@@ -124,24 +124,31 @@ class Test_Csv(unittest.TestCase):
self.assertEqual(fileobj.read(),
expect + writer.dialect.lineterminator)
+ def _write_error_test(self, exc, fields, **kwargs):
+ with TemporaryFile("w+", newline='') as fileobj:
+ writer = csv.writer(fileobj, **kwargs)
+ with self.assertRaises(exc):
+ writer.writerow(fields)
+ fileobj.seek(0)
+ self.assertEqual(fileobj.read(), '')
+
def test_write_arg_valid(self):
- self.assertRaises(csv.Error, self._write_test, None, '')
+ self._write_error_test(csv.Error, None)
self._write_test((), '')
self._write_test([None], '""')
- self.assertRaises(csv.Error, self._write_test,
- [None], None, quoting = csv.QUOTE_NONE)
+ self._write_error_test(csv.Error, [None], quoting = csv.QUOTE_NONE)
# Check that exceptions are passed up the chain
class BadList:
def __len__(self):
return 10;
def __getitem__(self, i):
if i > 2:
- raise IOError
- self.assertRaises(IOError, self._write_test, BadList(), '')
+ raise OSError
+ self._write_error_test(OSError, BadList())
class BadItem:
def __str__(self):
- raise IOError
- self.assertRaises(IOError, self._write_test, [BadItem()], '')
+ raise OSError
+ self._write_error_test(OSError, [BadItem()])
def test_write_bigfield(self):
# This exercises the buffer realloc functionality
@@ -151,10 +158,8 @@ class Test_Csv(unittest.TestCase):
def test_write_quoting(self):
self._write_test(['a',1,'p,q'], 'a,1,"p,q"')
- self.assertRaises(csv.Error,
- self._write_test,
- ['a',1,'p,q'], 'a,1,p,q',
- quoting = csv.QUOTE_NONE)
+ self._write_error_test(csv.Error, ['a',1,'p,q'],
+ quoting = csv.QUOTE_NONE)
self._write_test(['a',1,'p,q'], 'a,1,"p,q"',
quoting = csv.QUOTE_MINIMAL)
self._write_test(['a',1,'p,q'], '"a",1,"p,q"',
@@ -167,10 +172,8 @@ class Test_Csv(unittest.TestCase):
def test_write_escape(self):
self._write_test(['a',1,'p,q'], 'a,1,"p,q"',
escapechar='\\')
- self.assertRaises(csv.Error,
- self._write_test,
- ['a',1,'p,"q"'], 'a,1,"p,\\"q\\""',
- escapechar=None, doublequote=False)
+ self._write_error_test(csv.Error, ['a',1,'p,"q"'],
+ escapechar=None, doublequote=False)
self._write_test(['a',1,'p,"q"'], 'a,1,"p,\\"q\\""',
escapechar='\\', doublequote = False)
self._write_test(['"'], '""""',
@@ -186,9 +189,9 @@ class Test_Csv(unittest.TestCase):
def test_writerows(self):
class BrokenFile:
def write(self, buf):
- raise IOError
+ raise OSError
writer = csv.writer(BrokenFile())
- self.assertRaises(IOError, writer.writerows, [['a']])
+ self.assertRaises(OSError, writer.writerows, [['a']])
with TemporaryFile("w+", newline='') as fileobj:
writer = csv.writer(fileobj)
@@ -308,6 +311,15 @@ class Test_Csv(unittest.TestCase):
for i, row in enumerate(csv.reader(fileobj)):
self.assertEqual(row, rows[i])
+ def test_roundtrip_escaped_unquoted_newlines(self):
+ with TemporaryFile("w+", newline='') as fileobj:
+ writer = csv.writer(fileobj,quoting=csv.QUOTE_NONE,escapechar="\\")
+ rows = [['a\nb','b'],['c','x\r\nd']]
+ writer.writerows(rows)
+ fileobj.seek(0)
+ for i, row in enumerate(csv.reader(fileobj,quoting=csv.QUOTE_NONE,escapechar="\\")):
+ self.assertEqual(row,rows[i])
+
class TestDialectRegistry(unittest.TestCase):
def test_registry_badargs(self):
self.assertRaises(TypeError, csv.list_dialects, None)
@@ -836,10 +848,11 @@ class TestDialectValidity(unittest.TestCase):
d = mydialect()
for field_name in ("delimiter", "escapechar", "quotechar"):
- self.assertRaises(csv.Error, create_invalid, field_name, "")
- self.assertRaises(csv.Error, create_invalid, field_name, "abc")
- self.assertRaises(csv.Error, create_invalid, field_name, b'x')
- self.assertRaises(csv.Error, create_invalid, field_name, 5)
+ with self.subTest(field_name=field_name):
+ self.assertRaises(csv.Error, create_invalid, field_name, "")
+ self.assertRaises(csv.Error, create_invalid, field_name, "abc")
+ self.assertRaises(csv.Error, create_invalid, field_name, b'x')
+ self.assertRaises(csv.Error, create_invalid, field_name, 5)
class TestSniffer(unittest.TestCase):
@@ -1053,7 +1066,6 @@ class TestUnicode(unittest.TestCase):
self.assertEqual(fileobj.read(), expected)
-
def test_main():
mod = sys.modules[__name__]
support.run_unittest(