diff options
author | Yury Selivanov <yselivanov@sprymix.com> | 2016-01-11 17:28:19 (GMT) |
---|---|---|
committer | Yury Selivanov <yselivanov@sprymix.com> | 2016-01-11 17:28:19 (GMT) |
commit | d9d0e864b901792c412ece6e9583e4ad74f9031d (patch) | |
tree | 42c7ee9321ffb7c3fb12b1e2d437e824dfc49643 /Lib | |
parent | f1240169b351288d11b0a8125ca8dc1f3e840e63 (diff) | |
download | cpython-d9d0e864b901792c412ece6e9583e4ad74f9031d.zip cpython-d9d0e864b901792c412ece6e9583e4ad74f9031d.tar.gz cpython-d9d0e864b901792c412ece6e9583e4ad74f9031d.tar.bz2 |
Issue #26050: Add asyncio.StreamReader.readuntil() method.
Patch by Марк Коренберг.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/asyncio/streams.py | 225 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_streams.py | 128 |
2 files changed, 314 insertions, 39 deletions
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 9097e38..0008d51 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -3,6 +3,7 @@ __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', 'open_connection', 'start_server', 'IncompleteReadError', + 'LimitOverrunError', ] import socket @@ -27,15 +28,28 @@ class IncompleteReadError(EOFError): Incomplete read error. Attributes: - partial: read bytes string before the end of stream was reached - - expected: total number of expected bytes + - expected: total number of expected bytes (or None if unknown) """ def __init__(self, partial, expected): - EOFError.__init__(self, "%s bytes read on a total of %s expected bytes" - % (len(partial), expected)) + super().__init__("%d bytes read on a total of %r expected bytes" + % (len(partial), expected)) self.partial = partial self.expected = expected +class LimitOverrunError(Exception): + """Reached buffer limit while looking for the separator. + + Attributes: + - message: error message + - consumed: total number of bytes that should be consumed + """ + def __init__(self, message, consumed): + super().__init__(message) + self.message = message + self.consumed = consumed + + @coroutine def open_connection(host=None, port=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): @@ -318,6 +332,10 @@ class StreamReader: def __init__(self, limit=_DEFAULT_LIMIT, loop=None): # The line length limit is a security feature; # it also doubles as half the buffer limit. + + if limit <= 0: + raise ValueError('Limit cannot be <= 0') + self._limit = limit if loop is None: self._loop = events.get_event_loop() @@ -361,7 +379,7 @@ class StreamReader: waiter.set_exception(exc) def _wakeup_waiter(self): - """Wakeup read() or readline() function waiting for data or EOF.""" + """Wakeup read*() functions waiting for data or EOF.""" waiter = self._waiter if waiter is not None: self._waiter = None @@ -409,7 +427,10 @@ class StreamReader: @coroutine def _wait_for_data(self, func_name): - """Wait until feed_data() or feed_eof() is called.""" + """Wait until feed_data() or feed_eof() is called. + + If stream was paused, automatically resume it. + """ # StreamReader uses a future to link the protocol feed_data() method # to a read coroutine. Running two read coroutines at the same time # would have an unexpected behaviour. It would not possible to know @@ -418,6 +439,13 @@ class StreamReader: raise RuntimeError('%s() called while another coroutine is ' 'already waiting for incoming data' % func_name) + assert not self._eof, '_wait_for_data after EOF' + + # Waiting for data while paused will make deadlock, so prevent it. + if self._paused: + self._paused = False + self._transport.resume_reading() + self._waiter = futures.Future(loop=self._loop) try: yield from self._waiter @@ -426,43 +454,150 @@ class StreamReader: @coroutine def readline(self): + """Read chunk of data from the stream until newline (b'\n') is found. + + On success, return chunk that ends with newline. If only partial + line can be read due to EOF, return incomplete line without + terminating newline. When EOF was reached while no bytes read, empty + bytes object is returned. + + If limit is reached, ValueError will be raised. In that case, if + newline was found, complete line including newline will be removed + from internal buffer. Else, internal buffer will be cleared. Limit is + compared against part of the line without newline. + + If stream was paused, this function will automatically resume it if + needed. + """ + sep = b'\n' + seplen = len(sep) + try: + line = yield from self.readuntil(sep) + except IncompleteReadError as e: + return e.partial + except LimitOverrunError as e: + if self._buffer.startswith(sep, e.consumed): + del self._buffer[:e.consumed + seplen] + else: + self._buffer.clear() + self._maybe_resume_transport() + raise ValueError(e.args[0]) + return line + + @coroutine + def readuntil(self, separator=b'\n'): + """Read chunk of data from the stream until `separator` is found. + + On success, chunk and its separator will be removed from internal buffer + (i.e. consumed). Returned chunk will include separator at the end. + + Configured stream limit is used to check result. Limit means maximal + length of chunk that can be returned, not counting the separator. + + If EOF occurs and complete separator still not found, + IncompleteReadError(<partial data>, None) will be raised and internal + buffer becomes empty. This partial data may contain a partial separator. + + If chunk cannot be read due to overlimit, LimitOverrunError will be raised + and data will be left in internal buffer, so it can be read again, in + some different way. + + If stream was paused, this function will automatically resume it if + needed. + """ + seplen = len(separator) + if seplen == 0: + raise ValueError('Separator should be at least one-byte string') + if self._exception is not None: raise self._exception - line = bytearray() - not_enough = True - - while not_enough: - while self._buffer and not_enough: - ichar = self._buffer.find(b'\n') - if ichar < 0: - line.extend(self._buffer) - self._buffer.clear() - else: - ichar += 1 - line.extend(self._buffer[:ichar]) - del self._buffer[:ichar] - not_enough = False - - if len(line) > self._limit: - self._maybe_resume_transport() - raise ValueError('Line is too long') + # Consume whole buffer except last bytes, which length is + # one less than seplen. Let's check corner cases with + # separator='SEPARATOR': + # * we have received almost complete separator (without last + # byte). i.e buffer='some textSEPARATO'. In this case we + # can safely consume len(separator) - 1 bytes. + # * last byte of buffer is first byte of separator, i.e. + # buffer='abcdefghijklmnopqrS'. We may safely consume + # everything except that last byte, but this require to + # analyze bytes of buffer that match partial separator. + # This is slow and/or require FSM. For this case our + # implementation is not optimal, since require rescanning + # of data that is known to not belong to separator. In + # real world, separator will not be so long to notice + # performance problems. Even when reading MIME-encoded + # messages :) + + # `offset` is the number of bytes from the beginning of the buffer where + # is no occurrence of `separator`. + offset = 0 + + # Loop until we find `separator` in the buffer, exceed the buffer size, + # or an EOF has happened. + while True: + buflen = len(self._buffer) + + # Check if we now have enough data in the buffer for `separator` to + # fit. + if buflen - offset >= seplen: + isep = self._buffer.find(separator, offset) + + if isep != -1: + # `separator` is in the buffer. `isep` will be used later to + # retrieve the data. + break + + # see upper comment for explanation. + offset = buflen + 1 - seplen + if offset > self._limit: + raise LimitOverrunError('Separator is not found, and chunk exceed the limit', offset) + # Complete message (with full separator) may be present in buffer + # even when EOF flag is set. This may happen when the last chunk + # adds data which makes separator be found. That's why we check for + # EOF *ater* inspecting the buffer. if self._eof: - break + chunk = bytes(self._buffer) + self._buffer.clear() + raise IncompleteReadError(chunk, None) + + # _wait_for_data() will resume reading if stream was paused. + yield from self._wait_for_data('readuntil') - if not_enough: - yield from self._wait_for_data('readline') + if isep > self._limit: + raise LimitOverrunError('Separator is found, but chunk is longer than limit', isep) + chunk = self._buffer[:isep + seplen] + del self._buffer[:isep + seplen] self._maybe_resume_transport() - return bytes(line) + return bytes(chunk) @coroutine def read(self, n=-1): + """Read up to `n` bytes from the stream. + + If n is not provided, or set to -1, read until EOF and return all read + bytes. If the EOF was received and the internal buffer is empty, return + an empty bytes object. + + If n is zero, return empty bytes object immediatelly. + + If n is positive, this function try to read `n` bytes, and may return + less or equal bytes than requested, but at least one byte. If EOF was + received before any byte is read, this function returns empty byte + object. + + Returned value is not limited with limit, configured at stream creation. + + If stream was paused, this function will automatically resume it if + needed. + """ + if self._exception is not None: raise self._exception - if not n: + if n == 0: return b'' if n < 0: @@ -477,29 +612,41 @@ class StreamReader: break blocks.append(block) return b''.join(blocks) - else: - if not self._buffer and not self._eof: - yield from self._wait_for_data('read') - if n < 0 or len(self._buffer) <= n: - data = bytes(self._buffer) - self._buffer.clear() - else: - # n > 0 and len(self._buffer) > n - data = bytes(self._buffer[:n]) - del self._buffer[:n] + if not self._buffer and not self._eof: + yield from self._wait_for_data('read') + + # This will work right even if buffer is less than n bytes + data = bytes(self._buffer[:n]) + del self._buffer[:n] self._maybe_resume_transport() return data @coroutine def readexactly(self, n): + """Read exactly `n` bytes. + + Raise an `IncompleteReadError` if EOF is reached before `n` bytes can be + read. The `IncompleteReadError.partial` attribute of the exception will + contain the partial read bytes. + + if n is zero, return empty bytes object. + + Returned value is not limited with limit, configured at stream creation. + + If stream was paused, this function will automatically resume it if + needed. + """ if n < 0: raise ValueError('readexactly size can not be less than zero') if self._exception is not None: raise self._exception + if n == 0: + return b'' + # There used to be "optimized" code here. It created its own # Future and waited until self._buffer had at least the n # bytes, then called read(n). Unfortunately, this could pause @@ -516,6 +663,8 @@ class StreamReader: blocks.append(block) n -= len(block) + assert n == 0 + return b''.join(blocks) if compat.PY35: diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 3b115b1..1783d5f 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -203,6 +203,20 @@ class StreamReaderTests(test_utils.TestCase): self.assertRaises( ValueError, self.loop.run_until_complete, stream.read(2)) + def test_invalid_limit(self): + with self.assertRaisesRegex(ValueError, 'imit'): + asyncio.StreamReader(limit=0, loop=self.loop) + + with self.assertRaisesRegex(ValueError, 'imit'): + asyncio.StreamReader(limit=-1, loop=self.loop) + + def test_read_limit(self): + stream = asyncio.StreamReader(limit=3, loop=self.loop) + stream.feed_data(b'chunk') + data = self.loop.run_until_complete(stream.read(5)) + self.assertEqual(b'chunk', data) + self.assertEqual(b'', stream._buffer) + def test_readline(self): # Read one line. 'readline' will need to wait for the data # to come from 'cb' @@ -292,6 +306,23 @@ class StreamReaderTests(test_utils.TestCase): ValueError, self.loop.run_until_complete, stream.readline()) self.assertEqual(b'chunk3\n', stream._buffer) + # check strictness of the limit + stream = asyncio.StreamReader(limit=7, loop=self.loop) + stream.feed_data(b'1234567\n') + line = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'1234567\n', line) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'12345678\n') + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'12345678') + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'', stream._buffer) + def test_readline_nolimit_nowait(self): # All needed data for the first 'readline' call will be # in the buffer. @@ -342,6 +373,92 @@ class StreamReaderTests(test_utils.TestCase): ValueError, self.loop.run_until_complete, stream.readline()) self.assertEqual(b'', stream._buffer) + def test_readuntil_separator(self): + stream = asyncio.StreamReader(loop=self.loop) + with self.assertRaisesRegex(ValueError, 'Separator should be'): + self.loop.run_until_complete(stream.readuntil(separator=b'')) + + def test_readuntil_multi_chunks(self): + stream = asyncio.StreamReader(loop=self.loop) + + stream.feed_data(b'lineAAA') + data = self.loop.run_until_complete(stream.readuntil(separator=b'AAA')) + self.assertEqual(b'lineAAA', data) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'lineAAA') + data = self.loop.run_until_complete(stream.readuntil(b'AAA')) + self.assertEqual(b'lineAAA', data) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'lineAAAxxx') + data = self.loop.run_until_complete(stream.readuntil(b'AAA')) + self.assertEqual(b'lineAAA', data) + self.assertEqual(b'xxx', stream._buffer) + + def test_readuntil_multi_chunks_1(self): + stream = asyncio.StreamReader(loop=self.loop) + + stream.feed_data(b'QWEaa') + stream.feed_data(b'XYaa') + stream.feed_data(b'a') + data = self.loop.run_until_complete(stream.readuntil(b'aaa')) + self.assertEqual(b'QWEaaXYaaa', data) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'QWEaa') + stream.feed_data(b'XYa') + stream.feed_data(b'aa') + data = self.loop.run_until_complete(stream.readuntil(b'aaa')) + self.assertEqual(b'QWEaaXYaaa', data) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'aaa') + data = self.loop.run_until_complete(stream.readuntil(b'aaa')) + self.assertEqual(b'aaa', data) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'Xaaa') + data = self.loop.run_until_complete(stream.readuntil(b'aaa')) + self.assertEqual(b'Xaaa', data) + self.assertEqual(b'', stream._buffer) + + stream.feed_data(b'XXX') + stream.feed_data(b'a') + stream.feed_data(b'a') + stream.feed_data(b'a') + data = self.loop.run_until_complete(stream.readuntil(b'aaa')) + self.assertEqual(b'XXXaaa', data) + self.assertEqual(b'', stream._buffer) + + def test_readuntil_eof(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'some dataAA') + stream.feed_eof() + + with self.assertRaises(asyncio.IncompleteReadError) as cm: + self.loop.run_until_complete(stream.readuntil(b'AAA')) + self.assertEqual(cm.exception.partial, b'some dataAA') + self.assertIsNone(cm.exception.expected) + self.assertEqual(b'', stream._buffer) + + def test_readuntil_limit_found_sep(self): + stream = asyncio.StreamReader(loop=self.loop, limit=3) + stream.feed_data(b'some dataAA') + + with self.assertRaisesRegex(asyncio.LimitOverrunError, + 'not found') as cm: + self.loop.run_until_complete(stream.readuntil(b'AAA')) + + self.assertEqual(b'some dataAA', stream._buffer) + + stream.feed_data(b'A') + with self.assertRaisesRegex(asyncio.LimitOverrunError, + 'is found') as cm: + self.loop.run_until_complete(stream.readuntil(b'AAA')) + + self.assertEqual(b'some dataAAA', stream._buffer) + def test_readexactly_zero_or_less(self): # Read exact number of bytes (zero or less). stream = asyncio.StreamReader(loop=self.loop) @@ -372,6 +489,13 @@ class StreamReaderTests(test_utils.TestCase): self.assertEqual(self.DATA + self.DATA, data) self.assertEqual(self.DATA, stream._buffer) + def test_readexactly_limit(self): + stream = asyncio.StreamReader(limit=3, loop=self.loop) + stream.feed_data(b'chunk') + data = self.loop.run_until_complete(stream.readexactly(5)) + self.assertEqual(b'chunk', data) + self.assertEqual(b'', stream._buffer) + def test_readexactly_eof(self): # Read exact number of bytes (eof). stream = asyncio.StreamReader(loop=self.loop) @@ -657,7 +781,9 @@ os.close(fd) @asyncio.coroutine def client(host, port): - reader, writer = yield from asyncio.open_connection(host, port, loop=self.loop) + reader, writer = yield from asyncio.open_connection( + host, port, loop=self.loop) + while True: writer.write(b"foo\n") yield from writer.drain() |