summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/tempfile.py8
-rw-r--r--Lib/test/test_tempfile.py19
2 files changed, 22 insertions, 5 deletions
diff --git a/Lib/tempfile.py b/Lib/tempfile.py
index 109dc59..68cc540 100644
--- a/Lib/tempfile.py
+++ b/Lib/tempfile.py
@@ -29,6 +29,7 @@ __all__ = [
# Imports.
+import io as _io
import os as _os
import errno as _errno
from random import Random as _Random
@@ -37,8 +38,6 @@ if _os.name == 'mac':
import Carbon.Folder as _Folder
import Carbon.Folders as _Folders
-from io import StringIO as _StringIO
-
try:
import fcntl as _fcntl
except ImportError:
@@ -486,7 +485,10 @@ class SpooledTemporaryFile:
def __init__(self, max_size=0, mode='w+b', bufsize=-1,
suffix="", prefix=template, dir=None):
- self._file = _StringIO()
+ if 'b' in mode:
+ self._file = _io.BytesIO()
+ else:
+ self._file = _io.StringIO()
self._max_size = max_size
self._rolled = False
self._TemporaryFileArgs = (mode, bufsize, suffix, prefix, dir)
diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py
index caa8f4e..c925f68 100644
--- a/Lib/test/test_tempfile.py
+++ b/Lib/test/test_tempfile.py
@@ -664,7 +664,7 @@ class test_SpooledTemporaryFile(TC):
self.failUnless(f._rolled)
filename = f.name
f.close()
- self.failIf(os.path.exists(filename),
+ self.failIf(isinstance(filename, str) and os.path.exists(filename),
"SpooledTemporaryFile %s exists after close" % filename)
finally:
os.rmdir(dir)
@@ -730,7 +730,22 @@ class test_SpooledTemporaryFile(TC):
write("a" * 35)
write("b" * 35)
seek(0, 0)
- self.assertEqual(read(70), 'a'*35 + 'b'*35)
+ self.assertEqual(read(70), b'a'*35 + b'b'*35)
+
+ def test_text_mode(self):
+ # Creating a SpooledTemporaryFile with a text mode should produce
+ # a file object reading and writing (Unicode) text strings.
+ f = tempfile.SpooledTemporaryFile(mode='w+', max_size=10)
+ f.write("abc\n")
+ f.seek(0)
+ self.assertEqual(f.read(), "abc\n")
+ f.write("def\n")
+ f.seek(0)
+ self.assertEqual(f.read(), "abc\ndef\n")
+ f.write("xyzzy\n")
+ f.seek(0)
+ self.assertEqual(f.read(), "abc\ndef\nxyzzy\n")
+
test_classes.append(test_SpooledTemporaryFile)