From f0c7416157f471dbb8ac88aa72202cc984c02700 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 28 Aug 2007 03:29:45 +0000 Subject: Patch # 1033 by Adam Hupp: 1) tempfile.TemporaryFile, NamedTemporaryFile, and SpooledTemporaryFile now pass newline and encoding to the underlying io.open call. 2) test_tempfile is updated 3) test_csv is updated to use the new arguments. --- Lib/tempfile.py | 41 ++++++++++++++++++++++++------------- Lib/test/test_csv.py | 51 +++++++++++++++++++++++------------------------ Lib/test/test_tempfile.py | 24 ++++++++++++++++++++++ 3 files changed, 76 insertions(+), 40 deletions(-) diff --git a/Lib/tempfile.py b/Lib/tempfile.py index c71cebc..b6b3b96 100644 --- a/Lib/tempfile.py +++ b/Lib/tempfile.py @@ -406,13 +406,16 @@ class _TemporaryFileWrapper: def __del__(self): self.close() -def NamedTemporaryFile(mode='w+b', bufsize=-1, suffix="", - prefix=template, dir=None, delete=True): +def NamedTemporaryFile(mode='w+b', buffering=-1, encoding=None, + newline=None, suffix="", prefix=template, + dir=None, delete=True): """Create and return a temporary file. Arguments: 'prefix', 'suffix', 'dir' -- as for mkstemp. - 'mode' -- the mode argument to os.fdopen (default "w+b"). - 'bufsize' -- the buffer size argument to os.fdopen (default -1). + 'mode' -- the mode argument to io.open (default "w+b"). + 'buffering' -- the buffer size argument to io.open (default -1). + 'encoding' -- the encoding argument to io.open (default None) + 'newline' -- the newline argument to io.open (default None) 'delete' -- whether the file is deleted on close (default True). The file is created as mkstemp() would do it. @@ -435,7 +438,9 @@ def NamedTemporaryFile(mode='w+b', bufsize=-1, suffix="", flags |= _os.O_TEMPORARY (fd, name) = _mkstemp_inner(dir, prefix, suffix, flags) - file = _io.open(fd, mode, bufsize) + file = _io.open(fd, mode, buffering=buffering, + newline=newline, encoding=encoding) + return _TemporaryFileWrapper(file, name, delete) if _os.name != 'posix' or _os.sys.platform == 'cygwin': @@ -444,13 +449,16 @@ if _os.name != 'posix' or _os.sys.platform == 'cygwin': TemporaryFile = NamedTemporaryFile else: - def TemporaryFile(mode='w+b', bufsize=-1, suffix="", - prefix=template, dir=None): + def TemporaryFile(mode='w+b', buffering=-1, encoding=None, + newline=None, suffix="", prefix=template, + dir=None): """Create and return a temporary file. Arguments: 'prefix', 'suffix', 'dir' -- as for mkstemp. - 'mode' -- the mode argument to os.fdopen (default "w+b"). - 'bufsize' -- the buffer size argument to os.fdopen (default -1). + 'mode' -- the mode argument to io.open (default "w+b"). + 'buffering' -- the buffer size argument to io.open (default -1). + 'encoding' -- the encoding argument to io.open (default None) + 'newline' -- the newline argument to io.open (default None) The file is created as mkstemp() would do it. Returns an object with a file-like interface. The file has no @@ -468,7 +476,8 @@ else: (fd, name) = _mkstemp_inner(dir, prefix, suffix, flags) try: _os.unlink(name) - return _io.open(fd, mode, bufsize) + return _io.open(fd, mode, buffering=buffering, + newline=newline, encoding=encoding) except: _os.close(fd) raise @@ -480,15 +489,19 @@ class SpooledTemporaryFile: """ _rolled = False - def __init__(self, max_size=0, mode='w+b', bufsize=-1, + def __init__(self, max_size=0, mode='w+b', buffering=-1, + encoding=None, newline=None, suffix="", prefix=template, dir=None): if 'b' in mode: self._file = _io.BytesIO() else: - self._file = _io.StringIO() + self._file = _io.StringIO(encoding=encoding, newline=newline) self._max_size = max_size self._rolled = False - self._TemporaryFileArgs = (mode, bufsize, suffix, prefix, dir) + self._TemporaryFileArgs = {'mode': mode, 'buffering': buffering, + 'suffix': suffix, 'prefix': prefix, + 'encoding': encoding, 'newline': newline, + 'dir': dir} def _check(self, file): if self._rolled: return @@ -499,7 +512,7 @@ class SpooledTemporaryFile: def rollover(self): if self._rolled: return file = self._file - newfile = self._file = TemporaryFile(*self._TemporaryFileArgs) + newfile = self._file = TemporaryFile(**self._TemporaryFileArgs) del self._TemporaryFileArgs newfile.write(file.getvalue()) diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py index 7ac52e3..6c1c542 100644 --- a/Lib/test/test_csv.py +++ b/Lib/test/test_csv.py @@ -2,6 +2,7 @@ # Copyright (C) 2001,2002 Python Software Foundation # csv package unit tests +import io import sys import os import unittest @@ -117,11 +118,11 @@ class Test_Csv(unittest.TestCase): def _write_test(self, fields, expect, **kwargs): - with TemporaryFile("w+b") as fileobj: + with TemporaryFile("w+", newline='') as fileobj: writer = csv.writer(fileobj, **kwargs) writer.writerow(fields) fileobj.seek(0) - self.assertEqual(str(fileobj.read()), + self.assertEqual(fileobj.read(), expect + writer.dialect.lineterminator) def test_write_arg_valid(self): @@ -188,12 +189,12 @@ class Test_Csv(unittest.TestCase): writer = csv.writer(BrokenFile()) self.assertRaises(IOError, writer.writerows, [['a']]) - with TemporaryFile("w+b") as fileobj: + with TemporaryFile("w+", newline='') as fileobj: writer = csv.writer(fileobj) self.assertRaises(TypeError, writer.writerows, None) writer.writerows([['a','b'],['c','d']]) fileobj.seek(0) - self.assertEqual(fileobj.read(), b"a,b\r\nc,d\r\n") + self.assertEqual(fileobj.read(), "a,b\r\nc,d\r\n") def _read_test(self, input, expect, **kwargs): reader = csv.reader(input, **kwargs) @@ -332,11 +333,13 @@ class TestDialectRegistry(unittest.TestCase): self.assertEqual(next(reader), ["c1ccccc1", "benzene"]) def compare_dialect_123(self, expected, *writeargs, **kwwriteargs): - with TemporaryFile("w+b") as fileobj: + + with TemporaryFile("w+", newline='', encoding="utf-8") as fileobj: + writer = csv.writer(fileobj, *writeargs, **kwwriteargs) writer.writerow([1,2,3]) fileobj.seek(0) - self.assertEqual(str(fileobj.read()), expected) + self.assertEqual(fileobj.read(), expected) def test_dialect_apply(self): class testA(csv.excel): @@ -380,11 +383,11 @@ class TestCsvBase(unittest.TestCase): self.assertEqual(fields, expected_result) def writerAssertEqual(self, input, expected_result): - with TemporaryFile("w+b") as fileobj: + with TemporaryFile("w+", newline='') as fileobj: writer = csv.writer(fileobj, dialect = self.dialect) writer.writerows(input) fileobj.seek(0) - self.assertEqual(str(fileobj.read()), expected_result) + self.assertEqual(fileobj.read(), expected_result) class TestDialectExcel(TestCsvBase): dialect = 'excel' @@ -513,11 +516,11 @@ class TestDictFields(unittest.TestCase): ### "long" means the row is longer than the number of fieldnames ### "short" means there are fewer elements in the row than fieldnames def test_write_simple_dict(self): - with TemporaryFile("w+b") as fileobj: + with TemporaryFile("w+", newline='') as fileobj: writer = csv.DictWriter(fileobj, fieldnames = ["f1", "f2", "f3"]) writer.writerow({"f1": 10, "f3": "abc"}) fileobj.seek(0) - self.assertEqual(str(fileobj.read()), "10,,abc\r\n") + self.assertEqual(fileobj.read(), "10,,abc\r\n") def test_write_no_fields(self): fileobj = StringIO() @@ -614,45 +617,45 @@ class TestArrayWrites(unittest.TestCase): contents = [(20-i) for i in range(20)] a = array.array('i', contents) - with TemporaryFile("w+b") as fileobj: + with TemporaryFile("w+", newline='') as fileobj: writer = csv.writer(fileobj, dialect="excel") writer.writerow(a) expected = ",".join([str(i) for i in a])+"\r\n" fileobj.seek(0) - self.assertEqual(str(fileobj.read()), expected) + self.assertEqual(fileobj.read(), expected) def test_double_write(self): import array contents = [(20-i)*0.1 for i in range(20)] a = array.array('d', contents) - with TemporaryFile("w+b") as fileobj: + with TemporaryFile("w+", newline='') as fileobj: writer = csv.writer(fileobj, dialect="excel") writer.writerow(a) expected = ",".join([str(i) for i in a])+"\r\n" fileobj.seek(0) - self.assertEqual(str(fileobj.read()), expected) + self.assertEqual(fileobj.read(), expected) def test_float_write(self): import array contents = [(20-i)*0.1 for i in range(20)] a = array.array('f', contents) - with TemporaryFile("w+b") as fileobj: + with TemporaryFile("w+", newline='') as fileobj: writer = csv.writer(fileobj, dialect="excel") writer.writerow(a) expected = ",".join([str(i) for i in a])+"\r\n" fileobj.seek(0) - self.assertEqual(str(fileobj.read()), expected) + self.assertEqual(fileobj.read(), expected) def test_char_write(self): import array, string a = array.array('u', string.ascii_letters) - with TemporaryFile("w+b") as fileobj: + with TemporaryFile("w+", newline='') as fileobj: writer = csv.writer(fileobj, dialect="excel") writer.writerow(a) expected = ",".join(a)+"\r\n" fileobj.seek(0) - self.assertEqual(str(fileobj.read()), expected) + self.assertEqual(fileobj.read(), expected) class TestDialectValidity(unittest.TestCase): def test_quoting(self): @@ -864,10 +867,8 @@ class TestUnicode(unittest.TestCase): def test_unicode_read(self): import io - fileobj = io.TextIOWrapper(TemporaryFile("w+b"), encoding="utf-16") - with fileobj as fileobj: + with TemporaryFile("w+", newline='', encoding="utf-8") as fileobj: fileobj.write(",".join(self.names) + "\r\n") - fileobj.seek(0) reader = csv.reader(fileobj) self.assertEqual(list(reader), [self.names]) @@ -875,14 +876,12 @@ class TestUnicode(unittest.TestCase): def test_unicode_write(self): import io - with TemporaryFile("w+b") as fileobj: - encwriter = io.TextIOWrapper(fileobj, encoding="utf-8") - writer = csv.writer(encwriter) + with TemporaryFile("w+", newline='', encoding="utf-8") as fileobj: + writer = csv.writer(fileobj) writer.writerow(self.names) expected = ",".join(self.names)+"\r\n" fileobj.seek(0) - self.assertEqual(str(fileobj.read()), expected) - + self.assertEqual(fileobj.read(), expected) diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py index 07f32e8..3e4f803 100644 --- a/Lib/test/test_tempfile.py +++ b/Lib/test/test_tempfile.py @@ -746,6 +746,18 @@ class test_SpooledTemporaryFile(TC): f.seek(0) self.assertEqual(f.read(), "abc\ndef\nxyzzy\n") + def test_text_newline_and_encoding(self): + f = tempfile.SpooledTemporaryFile(mode='w+', max_size=10, + newline='', encoding='utf-8') + f.write("\u039B\r\n") + f.seek(0) + self.assertEqual(f.read(), "\u039B\r\n") + self.failIf(f._rolled) + + f.write("\u039B" * 20 + "\r\n") + f.seek(0) + self.assertEqual(f.read(), "\u039B\r\n" + ("\u039B" * 20) + "\r\n") + self.failUnless(f._rolled) test_classes.append(test_SpooledTemporaryFile) @@ -790,6 +802,18 @@ class test_TemporaryFile(TC): self.failOnException("close") # How to test the mode and bufsize parameters? + def test_mode_and_encoding(self): + + def roundtrip(input, *args, **kwargs): + with tempfile.TemporaryFile(*args, **kwargs) as fileobj: + fileobj.write(input) + fileobj.seek(0) + self.assertEquals(input, fileobj.read()) + + roundtrip(b"1234", "w+b") + roundtrip("abdc\n", "w+") + roundtrip("\u039B", "w+", encoding="utf-16") + roundtrip("foo\r\n", "w+", newline="") if tempfile.NamedTemporaryFile is not tempfile.TemporaryFile: -- cgit v0.12