summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorAntoine Pitrou <solipsis@pitrou.net>2008-12-14 16:36:46 (GMT)
committerAntoine Pitrou <solipsis@pitrou.net>2008-12-14 16:36:46 (GMT)
commit180a336f1afdcef332189e6bcee314576cadc2bf (patch)
tree645f10b148e39a1a9a602d426dbe9846a48c62f7 /Lib
parentff94552763d5ceb33dd646a534b4d1b56e6162cb (diff)
downloadcpython-180a336f1afdcef332189e6bcee314576cadc2bf.zip
cpython-180a336f1afdcef332189e6bcee314576cadc2bf.tar.gz
cpython-180a336f1afdcef332189e6bcee314576cadc2bf.tar.bz2
Issue #4574: reading an UTF16-encoded text file crashes if \r on 64-char boundary.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/io.py27
-rw-r--r--Lib/test/test_io.py112
2 files changed, 84 insertions, 55 deletions
diff --git a/Lib/io.py b/Lib/io.py
index 041e8eb..af5a144 100644
--- a/Lib/io.py
+++ b/Lib/io.py
@@ -1282,25 +1282,23 @@ class IncrementalNewlineDecoder(codecs.IncrementalDecoder):
"""
def __init__(self, decoder, translate, errors='strict'):
codecs.IncrementalDecoder.__init__(self, errors=errors)
- self.buffer = b''
self.translate = translate
self.decoder = decoder
self.seennl = 0
+ self.pendingcr = False
def decode(self, input, final=False):
# decode input (with the eventual \r from a previous pass)
- if self.buffer:
- input = self.buffer + input
-
output = self.decoder.decode(input, final=final)
+ if self.pendingcr and (output or final):
+ output = "\r" + output
+ self.pendingcr = False
# retain last \r even when not translating data:
# then readline() is sure to get \r\n in one pass
if output.endswith("\r") and not final:
output = output[:-1]
- self.buffer = b'\r'
- else:
- self.buffer = b''
+ self.pendingcr = True
# Record which newlines are read
crlf = output.count('\r\n')
@@ -1319,20 +1317,19 @@ class IncrementalNewlineDecoder(codecs.IncrementalDecoder):
def getstate(self):
buf, flag = self.decoder.getstate()
- return buf + self.buffer, flag
+ flag <<= 1
+ if self.pendingcr:
+ flag |= 1
+ return buf, flag
def setstate(self, state):
buf, flag = state
- if buf.endswith(b'\r'):
- self.buffer = b'\r'
- buf = buf[:-1]
- else:
- self.buffer = b''
- self.decoder.setstate((buf, flag))
+ self.pendingcr = bool(flag & 1)
+ self.decoder.setstate((buf, flag >> 1))
def reset(self):
self.seennl = 0
- self.buffer = b''
+ self.pendingcr = False
self.decoder.reset()
_LF = 1
diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py
index 58203ed..1733440 100644
--- a/Lib/test/test_io.py
+++ b/Lib/test/test_io.py
@@ -679,8 +679,9 @@ class StatefulIncrementalDecoder(codecs.IncrementalDecoder):
@classmethod
def lookupTestDecoder(cls, name):
if cls.codecEnabled and name == 'test_decoder':
+ latin1 = codecs.lookup('latin-1')
return codecs.CodecInfo(
- name='test_decoder', encode=None, decode=None,
+ name='test_decoder', encode=latin1.encode, decode=None,
incrementalencoder=None,
streamreader=None, streamwriter=None,
incrementaldecoder=cls)
@@ -840,8 +841,11 @@ class TextIOWrapperTest(unittest.TestCase):
[ '\r\n', [ "unix\nwindows\r\n", "os9\rlast\nnonl" ] ],
[ '\r', [ "unix\nwindows\r", "\nos9\r", "last\nnonl" ] ],
]
-
- encodings = ('utf-8', 'latin-1')
+ encodings = (
+ 'utf-8', 'latin-1',
+ 'utf-16', 'utf-16-le', 'utf-16-be',
+ 'utf-32', 'utf-32-le', 'utf-32-be',
+ )
# Try a range of buffer sizes to test the case where \r is the last
# character in TextIOWrapper._pending_line.
@@ -1195,56 +1199,84 @@ class TextIOWrapperTest(unittest.TestCase):
self.assertEqual(buffer.seekable(), txt.seekable())
- def test_newline_decoder(self):
- import codecs
- decoder = codecs.getincrementaldecoder("utf-8")()
- decoder = io.IncrementalNewlineDecoder(decoder, translate=True)
+ def check_newline_decoder_utf8(self, decoder):
+ # UTF-8 specific tests for a newline decoder
+ def _check_decode(b, s, **kwargs):
+ # We exercise getstate() / setstate() as well as decode()
+ state = decoder.getstate()
+ self.assertEquals(decoder.decode(b, **kwargs), s)
+ decoder.setstate(state)
+ self.assertEquals(decoder.decode(b, **kwargs), s)
- self.assertEquals(decoder.decode(b'\xe8\xa2\x88'), "\u8888")
+ _check_decode(b'\xe8\xa2\x88', "\u8888")
- self.assertEquals(decoder.decode(b'\xe8'), "")
- self.assertEquals(decoder.decode(b'\xa2'), "")
- self.assertEquals(decoder.decode(b'\x88'), "\u8888")
+ _check_decode(b'\xe8', "")
+ _check_decode(b'\xa2', "")
+ _check_decode(b'\x88', "\u8888")
- self.assertEquals(decoder.decode(b'\xe8'), "")
- self.assertRaises(UnicodeDecodeError, decoder.decode, b'', final=True)
+ _check_decode(b'\xe8', "")
+ _check_decode(b'\xa2', "")
+ _check_decode(b'\x88', "\u8888")
- decoder.setstate((b'', 0))
- self.assertEquals(decoder.decode(b'\n'), "\n")
- self.assertEquals(decoder.decode(b'\r'), "")
- self.assertEquals(decoder.decode(b'', final=True), "\n")
- self.assertEquals(decoder.decode(b'\r', final=True), "\n")
-
- self.assertEquals(decoder.decode(b'\r'), "")
- self.assertEquals(decoder.decode(b'a'), "\na")
-
- self.assertEquals(decoder.decode(b'\r\r\n'), "\n\n")
- self.assertEquals(decoder.decode(b'\r'), "")
- self.assertEquals(decoder.decode(b'\r'), "\n")
- self.assertEquals(decoder.decode(b'\na'), "\na")
-
- self.assertEquals(decoder.decode(b'\xe8\xa2\x88\r\n'), "\u8888\n")
- self.assertEquals(decoder.decode(b'\xe8\xa2\x88'), "\u8888")
- self.assertEquals(decoder.decode(b'\n'), "\n")
- self.assertEquals(decoder.decode(b'\xe8\xa2\x88\r'), "\u8888")
- self.assertEquals(decoder.decode(b'\n'), "\n")
+ _check_decode(b'\xe8', "")
+ self.assertRaises(UnicodeDecodeError, decoder.decode, b'', final=True)
- decoder = codecs.getincrementaldecoder("utf-8")()
- decoder = io.IncrementalNewlineDecoder(decoder, translate=True)
+ decoder.reset()
+ _check_decode(b'\n', "\n")
+ _check_decode(b'\r', "")
+ _check_decode(b'', "\n", final=True)
+ _check_decode(b'\r', "\n", final=True)
+
+ _check_decode(b'\r', "")
+ _check_decode(b'a', "\na")
+
+ _check_decode(b'\r\r\n', "\n\n")
+ _check_decode(b'\r', "")
+ _check_decode(b'\r', "\n")
+ _check_decode(b'\na', "\na")
+
+ _check_decode(b'\xe8\xa2\x88\r\n', "\u8888\n")
+ _check_decode(b'\xe8\xa2\x88', "\u8888")
+ _check_decode(b'\n', "\n")
+ _check_decode(b'\xe8\xa2\x88\r', "\u8888")
+ _check_decode(b'\n', "\n")
+
+ def check_newline_decoder(self, decoder, encoding):
+ result = []
+ encoder = codecs.getincrementalencoder(encoding)()
+ def _decode_bytewise(s):
+ for b in encoder.encode(s):
+ result.append(decoder.decode(bytes([b])))
self.assertEquals(decoder.newlines, None)
- decoder.decode(b"abc\n\r")
+ _decode_bytewise("abc\n\r")
self.assertEquals(decoder.newlines, '\n')
- decoder.decode(b"\nabc")
+ _decode_bytewise("\nabc")
self.assertEquals(decoder.newlines, ('\n', '\r\n'))
- decoder.decode(b"abc\r")
+ _decode_bytewise("abc\r")
self.assertEquals(decoder.newlines, ('\n', '\r\n'))
- decoder.decode(b"abc")
+ _decode_bytewise("abc")
self.assertEquals(decoder.newlines, ('\r', '\n', '\r\n'))
- decoder.decode(b"abc\r")
+ _decode_bytewise("abc\r")
+ self.assertEquals("".join(result), "abc\n\nabcabc\nabcabc")
decoder.reset()
- self.assertEquals(decoder.decode(b"abc"), "abc")
+ self.assertEquals(decoder.decode("abc".encode(encoding)), "abc")
self.assertEquals(decoder.newlines, None)
+ def test_newline_decoder(self):
+ encodings = (
+ 'utf-8', 'latin-1',
+ 'utf-16', 'utf-16-le', 'utf-16-be',
+ 'utf-32', 'utf-32-le', 'utf-32-be',
+ )
+ for enc in encodings:
+ decoder = codecs.getincrementaldecoder(enc)()
+ decoder = io.IncrementalNewlineDecoder(decoder, translate=True)
+ self.check_newline_decoder(decoder, enc)
+ decoder = codecs.getincrementaldecoder("utf-8")()
+ decoder = io.IncrementalNewlineDecoder(decoder, translate=True)
+ self.check_newline_decoder_utf8(decoder)
+
+
# XXX Tests for open()
class MiscIOTest(unittest.TestCase):