From 355491dc47ea4a2574ee8f9ea60a0d25fe3fba43 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Oct 2013 15:17:11 -0700 Subject: Write flow control for asyncio (includes asyncio.streams overhaul). --- Lib/asyncio/protocols.py | 28 +++++ Lib/asyncio/selector_events.py | 78 ++++++++++--- Lib/asyncio/streams.py | 208 ++++++++++++++++++++++++---------- Lib/asyncio/transports.py | 25 ++++ Lib/test/test_asyncio/test_streams.py | 42 +++---- 5 files changed, 288 insertions(+), 93 deletions(-) diff --git a/Lib/asyncio/protocols.py b/Lib/asyncio/protocols.py index a94abbe..d3a8685 100644 --- a/Lib/asyncio/protocols.py +++ b/Lib/asyncio/protocols.py @@ -29,6 +29,34 @@ class BaseProtocol: aborted or closed). """ + def pause_writing(self): + """Called when the transport's buffer goes over the high-water mark. + + Pause and resume calls are paired -- pause_writing() is called + once when the buffer goes strictly over the high-water mark + (even if subsequent writes increases the buffer size even + more), and eventually resume_writing() is called once when the + buffer size reaches the low-water mark. + + Note that if the buffer size equals the high-water mark, + pause_writing() is not called -- it must go strictly over. + Conversely, resume_writing() is called when the buffer size is + equal or lower than the low-water mark. These end conditions + are important to ensure that things go as expected when either + mark is zero. + + NOTE: This is the only Protocol callback that is not called + through EventLoop.call_soon() -- if it were, it would have no + effect when it's most needed (when the app keeps writing + without yielding until pause_writing() is called). + """ + + def resume_writing(self): + """Called when the transport's buffer drains below the low-water mark. + + See pause_writing() for details. + """ + class Protocol(BaseProtocol): """ABC representing a protocol. diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 084d9be..63164f0 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -346,8 +346,10 @@ class _SelectorTransport(transports.Transport): self._buffer = collections.deque() self._conn_lost = 0 # Set when call to connection_lost scheduled. self._closing = False # Set when close() called. - if server is not None: - server.attach(self) + self._protocol_paused = False + self.set_write_buffer_limits() + if self._server is not None: + self._server.attach(self) def abort(self): self._force_close(None) @@ -392,6 +394,40 @@ class _SelectorTransport(transports.Transport): server.detach(self) self._server = None + def _maybe_pause_protocol(self): + size = self.get_write_buffer_size() + if size <= self._high_water: + return + if not self._protocol_paused: + self._protocol_paused = True + try: + self._protocol.pause_writing() + except Exception: + tulip_log.exception('pause_writing() failed') + + def _maybe_resume_protocol(self): + if self._protocol_paused and self.get_write_buffer_size() <= self._low_water: + self._protocol_paused = False + try: + self._protocol.resume_writing() + except Exception: + tulip_log.exception('resume_writing() failed') + + def set_write_buffer_limits(self, high=None, low=None): + if high is None: + if low is None: + high = 64*1024 + else: + high = 4*low + if low is None: + low = high // 4 + assert 0 <= low <= high, repr((low, high)) + self._high_water = high + self._low_water = low + + def get_write_buffer_size(self): + return sum(len(data) for data in self._buffer) + class _SelectorSocketTransport(_SelectorTransport): @@ -447,7 +483,7 @@ class _SelectorSocketTransport(_SelectorTransport): return if not self._buffer: - # Attempt to send it right away first. + # Optimization: try to send now. try: n = self._sock.send(data) except (BlockingIOError, InterruptedError): @@ -459,34 +495,36 @@ class _SelectorSocketTransport(_SelectorTransport): data = data[n:] if not data: return - - # Start async I/O. + # Not all was written; register write handler. self._loop.add_writer(self._sock_fd, self._write_ready) + # Add it to the buffer. self._buffer.append(data) + self._maybe_pause_protocol() def _write_ready(self): data = b''.join(self._buffer) assert data, 'Data should not be empty' - self._buffer.clear() + self._buffer.clear() # Optimistically; may have to put it back later. try: n = self._sock.send(data) except (BlockingIOError, InterruptedError): - self._buffer.append(data) + self._buffer.append(data) # Still need to write this. except Exception as exc: self._loop.remove_writer(self._sock_fd) self._fatal_error(exc) else: data = data[n:] - if not data: + if data: + self._buffer.append(data) # Still need to write this. + self._maybe_resume_protocol() # May append to buffer. + if not self._buffer: self._loop.remove_writer(self._sock_fd) if self._closing: self._call_connection_lost(None) elif self._eof: self._sock.shutdown(socket.SHUT_WR) - return - self._buffer.append(data) # Try again later. def write_eof(self): if self._eof: @@ -546,16 +584,23 @@ class _SelectorSslTransport(_SelectorTransport): self._loop.add_writer(self._sock_fd, self._on_handshake) return except Exception as exc: + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) self._sock.close() if self._waiter is not None: self._waiter.set_exception(exc) return except BaseException as exc: + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) self._sock.close() if self._waiter is not None: self._waiter.set_exception(exc) raise + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + # Verify hostname if requested. peercert = self._sock.getpeercert() if (self._server_hostname is not None and @@ -574,8 +619,6 @@ class _SelectorSslTransport(_SelectorTransport): compression=self._sock.compression(), ) - self._loop.remove_reader(self._sock_fd) - self._loop.remove_writer(self._sock_fd) self._loop.add_reader(self._sock_fd, self._on_ready) self._loop.add_writer(self._sock_fd, self._on_ready) self._loop.call_soon(self._protocol.connection_made, self) @@ -642,6 +685,8 @@ class _SelectorSslTransport(_SelectorTransport): if n < len(data): self._buffer.append(data[n:]) + self._maybe_resume_protocol() # May append to buffer. + if self._closing and not self._buffer: self._loop.remove_writer(self._sock_fd) self._call_connection_lost(None) @@ -657,8 +702,9 @@ class _SelectorSslTransport(_SelectorTransport): self._conn_lost += 1 return - self._buffer.append(data) # We could optimize, but the callback can do this for now. + self._buffer.append(data) + self._maybe_pause_protocol() def can_write_eof(self): return False @@ -675,11 +721,13 @@ class _SelectorDatagramTransport(_SelectorTransport): def __init__(self, loop, sock, protocol, address=None, extra=None): super().__init__(loop, sock, protocol, extra) - self._address = address self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) + def get_write_buffer_size(self): + return sum(len(data) for data, _ in self._buffer) + def _read_ready(self): try: data, addr = self._sock.recvfrom(self.max_size) @@ -723,6 +771,7 @@ class _SelectorDatagramTransport(_SelectorTransport): return self._buffer.append((data, addr)) + self._maybe_pause_protocol() def _sendto_ready(self): while self._buffer: @@ -743,6 +792,7 @@ class _SelectorDatagramTransport(_SelectorTransport): self._fatal_error(exc) return + self._maybe_resume_protocol() # May append to buffer. if not self._buffer: self._loop.remove_writer(self._sock_fd) if self._closing: diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 9915aa5..e995368 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -39,7 +39,8 @@ def open_connection(host=None, port=None, *, protocol = StreamReaderProtocol(reader) transport, _ = yield from loop.create_connection( lambda: protocol, host, port, **kwds) - return reader, transport # (reader, writer) + writer = StreamWriter(transport, protocol, reader, loop) + return reader, writer class StreamReaderProtocol(protocols.Protocol): @@ -52,22 +53,113 @@ class StreamReaderProtocol(protocols.Protocol): """ def __init__(self, stream_reader): - self.stream_reader = stream_reader + self._stream_reader = stream_reader + self._drain_waiter = None + self._paused = False def connection_made(self, transport): - self.stream_reader.set_transport(transport) + self._stream_reader.set_transport(transport) def connection_lost(self, exc): if exc is None: - self.stream_reader.feed_eof() + self._stream_reader.feed_eof() else: - self.stream_reader.set_exception(exc) + self._stream_reader.set_exception(exc) + # Also wake up the writing side. + if self._paused: + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) def data_received(self, data): - self.stream_reader.feed_data(data) + self._stream_reader.feed_data(data) def eof_received(self): - self.stream_reader.feed_eof() + self._stream_reader.feed_eof() + + def pause_writing(self): + assert not self._paused + self._paused = True + + def resume_writing(self): + assert self._paused + self._paused = False + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + +class StreamWriter: + """Wraps a Transport. + + This exposes write(), writelines(), [can_]write_eof(), + get_extra_info() and close(). It adds drain() which returns an + optional Future on which you can wait for flow control. It also + adds a transport attribute which references the Transport + directly. + """ + + def __init__(self, transport, protocol, reader, loop): + self._transport = transport + self._protocol = protocol + self._reader = reader + self._loop = loop + + @property + def transport(self): + return self._transport + + def write(self, data): + self._transport.write(data) + + def writelines(self, data): + self._transport.writelines(data) + + def write_eof(self): + return self._transport.write_eof() + + def can_write_eof(self): + return self._transport.can_write_eof() + + def close(self): + return self._transport.close() + + def get_extra_info(self, name, default=None): + return self._transport.get_extra_info(name, default) + + def drain(self): + """This method has an unusual return value. + + The intended use is to write + + w.write(data) + yield from w.drain() + + When there's nothing to wait for, drain() returns (), and the + yield-from continues immediately. When the transport buffer + is full (the protocol is paused), drain() creates and returns + a Future and the yield-from will block until that Future is + completed, which will happen when the buffer is (partially) + drained and the protocol is resumed. + """ + if self._reader._exception is not None: + raise self._writer._exception + if self._transport._conn_lost: # Uses private variable. + raise ConnectionResetError('Connection lost') + if not self._protocol._paused: + return () + waiter = self._protocol._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = futures.Future(loop=self._loop) + self._protocol._drain_waiter = waiter + return waiter class StreamReader: @@ -75,14 +167,14 @@ class StreamReader: def __init__(self, limit=_DEFAULT_LIMIT, loop=None): # The line length limit is a security feature; # it also doubles as half the buffer limit. - self.limit = limit + self._limit = limit if loop is None: loop = events.get_event_loop() - self.loop = loop - self.buffer = collections.deque() # Deque of bytes objects. - self.byte_count = 0 # Bytes in buffer. - self.eof = False # Whether we're done. - self.waiter = None # A future. + self._loop = loop + self._buffer = collections.deque() # Deque of bytes objects. + self._byte_count = 0 # Bytes in buffer. + self._eof = False # Whether we're done. + self._waiter = None # A future. self._exception = None self._transport = None self._paused = False @@ -93,9 +185,9 @@ class StreamReader: def set_exception(self, exc): self._exception = exc - waiter = self.waiter + waiter = self._waiter if waiter is not None: - self.waiter = None + self._waiter = None if not waiter.cancelled(): waiter.set_exception(exc) @@ -104,15 +196,15 @@ class StreamReader: self._transport = transport def _maybe_resume_transport(self): - if self._paused and self.byte_count <= self.limit: + if self._paused and self._byte_count <= self._limit: self._paused = False self._transport.resume_reading() def feed_eof(self): - self.eof = True - waiter = self.waiter + self._eof = True + waiter = self._waiter if waiter is not None: - self.waiter = None + self._waiter = None if not waiter.cancelled(): waiter.set_result(True) @@ -120,18 +212,18 @@ class StreamReader: if not data: return - self.buffer.append(data) - self.byte_count += len(data) + self._buffer.append(data) + self._byte_count += len(data) - waiter = self.waiter + waiter = self._waiter if waiter is not None: - self.waiter = None + self._waiter = None if not waiter.cancelled(): waiter.set_result(False) if (self._transport is not None and not self._paused and - self.byte_count > 2*self.limit): + self._byte_count > 2*self._limit): try: self._transport.pause_reading() except NotImplementedError: @@ -152,8 +244,8 @@ class StreamReader: not_enough = True while not_enough: - while self.buffer and not_enough: - data = self.buffer.popleft() + while self._buffer and not_enough: + data = self._buffer.popleft() ichar = data.find(b'\n') if ichar < 0: parts.append(data) @@ -162,29 +254,29 @@ class StreamReader: ichar += 1 head, tail = data[:ichar], data[ichar:] if tail: - self.buffer.appendleft(tail) + self._buffer.appendleft(tail) not_enough = False parts.append(head) parts_size += len(head) - if parts_size > self.limit: - self.byte_count -= parts_size + if parts_size > self._limit: + self._byte_count -= parts_size self._maybe_resume_transport() raise ValueError('Line is too long') - if self.eof: + if self._eof: break if not_enough: - assert self.waiter is None - self.waiter = futures.Future(loop=self.loop) + assert self._waiter is None + self._waiter = futures.Future(loop=self._loop) try: - yield from self.waiter + yield from self._waiter finally: - self.waiter = None + self._waiter = None line = b''.join(parts) - self.byte_count -= parts_size + self._byte_count -= parts_size self._maybe_resume_transport() return line @@ -198,42 +290,42 @@ class StreamReader: return b'' if n < 0: - while not self.eof: - assert not self.waiter - self.waiter = futures.Future(loop=self.loop) + while not self._eof: + assert not self._waiter + self._waiter = futures.Future(loop=self._loop) try: - yield from self.waiter + yield from self._waiter finally: - self.waiter = None + self._waiter = None else: - if not self.byte_count and not self.eof: - assert not self.waiter - self.waiter = futures.Future(loop=self.loop) + if not self._byte_count and not self._eof: + assert not self._waiter + self._waiter = futures.Future(loop=self._loop) try: - yield from self.waiter + yield from self._waiter finally: - self.waiter = None + self._waiter = None - if n < 0 or self.byte_count <= n: - data = b''.join(self.buffer) - self.buffer.clear() - self.byte_count = 0 + if n < 0 or self._byte_count <= n: + data = b''.join(self._buffer) + self._buffer.clear() + self._byte_count = 0 self._maybe_resume_transport() return data parts = [] parts_bytes = 0 - while self.buffer and parts_bytes < n: - data = self.buffer.popleft() + while self._buffer and parts_bytes < n: + data = self._buffer.popleft() data_bytes = len(data) if n < parts_bytes + data_bytes: data_bytes = n - parts_bytes data, rest = data[:data_bytes], data[data_bytes:] - self.buffer.appendleft(rest) + self._buffer.appendleft(rest) parts.append(data) parts_bytes += data_bytes - self.byte_count -= data_bytes + self._byte_count -= data_bytes self._maybe_resume_transport() return b''.join(parts) @@ -246,12 +338,12 @@ class StreamReader: if n <= 0: return b'' - while self.byte_count < n and not self.eof: - assert not self.waiter - self.waiter = futures.Future(loop=self.loop) + while self._byte_count < n and not self._eof: + assert not self._waiter + self._waiter = futures.Future(loop=self._loop) try: - yield from self.waiter + yield from self._waiter finally: - self.waiter = None + self._waiter = None return (yield from self.read(n)) diff --git a/Lib/asyncio/transports.py b/Lib/asyncio/transports.py index f1a7180..8c6b189 100644 --- a/Lib/asyncio/transports.py +++ b/Lib/asyncio/transports.py @@ -49,6 +49,31 @@ class ReadTransport(BaseTransport): class WriteTransport(BaseTransport): """ABC for write-only transports.""" + def set_write_buffer_limits(self, high=None, low=None): + """Set the high- and low-water limits for write flow control. + + These two values control when to call the protocol's + pause_writing() and resume_writing() methods. If specified, + the low-water limit must be less than or equal to the + high-water limit. Neither value can be negative. + + The defaults are implementation-specific. If only the + high-water limit is given, the low-water limit defaults to a + implementation-specific value less than or equal to the + high-water limit. Setting high to zero forces low to zero as + well, and causes pause_writing() to be called whenever the + buffer becomes non-empty. Setting low to zero causes + resume_writing() to be called only once the buffer is empty. + Use of zero for either limit is generally sub-optimal as it + reduces opportunities for doing I/O and computation + concurrently. + """ + raise NotImplementedError + + def get_write_buffer_size(self): + """Return the current size of the write buffer.""" + raise NotImplementedError + def write(self, data): """Write some data bytes to the transport. diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 31d8151..69e2246 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -32,7 +32,7 @@ class StreamReaderTests(unittest.TestCase): @unittest.mock.patch('asyncio.streams.events') def test_ctor_global_loop(self, m_events): stream = streams.StreamReader() - self.assertIs(stream.loop, m_events.get_event_loop.return_value) + self.assertIs(stream._loop, m_events.get_event_loop.return_value) def test_open_connection(self): with test_utils.run_test_server() as httpd: @@ -81,13 +81,13 @@ class StreamReaderTests(unittest.TestCase): stream = streams.StreamReader(loop=self.loop) stream.feed_data(b'') - self.assertEqual(0, stream.byte_count) + self.assertEqual(0, stream._byte_count) def test_feed_data_byte_count(self): stream = streams.StreamReader(loop=self.loop) stream.feed_data(self.DATA) - self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(len(self.DATA), stream._byte_count) def test_read_zero(self): # Read zero bytes. @@ -96,7 +96,7 @@ class StreamReaderTests(unittest.TestCase): data = self.loop.run_until_complete(stream.read(0)) self.assertEqual(b'', data) - self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(len(self.DATA), stream._byte_count) def test_read(self): # Read bytes. @@ -109,7 +109,7 @@ class StreamReaderTests(unittest.TestCase): data = self.loop.run_until_complete(read_task) self.assertEqual(self.DATA, data) - self.assertFalse(stream.byte_count) + self.assertFalse(stream._byte_count) def test_read_line_breaks(self): # Read bytes without line breaks. @@ -120,7 +120,7 @@ class StreamReaderTests(unittest.TestCase): data = self.loop.run_until_complete(stream.read(5)) self.assertEqual(b'line1', data) - self.assertEqual(5, stream.byte_count) + self.assertEqual(5, stream._byte_count) def test_read_eof(self): # Read bytes, stop at eof. @@ -133,7 +133,7 @@ class StreamReaderTests(unittest.TestCase): data = self.loop.run_until_complete(read_task) self.assertEqual(b'', data) - self.assertFalse(stream.byte_count) + self.assertFalse(stream._byte_count) def test_read_until_eof(self): # Read all bytes until eof. @@ -149,7 +149,7 @@ class StreamReaderTests(unittest.TestCase): data = self.loop.run_until_complete(read_task) self.assertEqual(b'chunk1\nchunk2', data) - self.assertFalse(stream.byte_count) + self.assertFalse(stream._byte_count) def test_read_exception(self): stream = streams.StreamReader(loop=self.loop) @@ -176,7 +176,7 @@ class StreamReaderTests(unittest.TestCase): line = self.loop.run_until_complete(read_task) self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) - self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + self.assertEqual(len(b'\n chunk4')-1, stream._byte_count) def test_readline_limit_with_existing_data(self): stream = streams.StreamReader(3, loop=self.loop) @@ -185,7 +185,7 @@ class StreamReaderTests(unittest.TestCase): self.assertRaises( ValueError, self.loop.run_until_complete, stream.readline()) - self.assertEqual([b'line2\n'], list(stream.buffer)) + self.assertEqual([b'line2\n'], list(stream._buffer)) stream = streams.StreamReader(3, loop=self.loop) stream.feed_data(b'li') @@ -194,8 +194,8 @@ class StreamReaderTests(unittest.TestCase): self.assertRaises( ValueError, self.loop.run_until_complete, stream.readline()) - self.assertEqual([b'li'], list(stream.buffer)) - self.assertEqual(2, stream.byte_count) + self.assertEqual([b'li'], list(stream._buffer)) + self.assertEqual(2, stream._byte_count) def test_readline_limit(self): stream = streams.StreamReader(7, loop=self.loop) @@ -209,8 +209,8 @@ class StreamReaderTests(unittest.TestCase): self.assertRaises( ValueError, self.loop.run_until_complete, stream.readline()) - self.assertEqual([b'chunk3\n'], list(stream.buffer)) - self.assertEqual(7, stream.byte_count) + self.assertEqual([b'chunk3\n'], list(stream._buffer)) + self.assertEqual(7, stream._byte_count) def test_readline_line_byte_count(self): stream = streams.StreamReader(loop=self.loop) @@ -220,7 +220,7 @@ class StreamReaderTests(unittest.TestCase): line = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'line1\n', line) - self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream._byte_count) def test_readline_eof(self): stream = streams.StreamReader(loop=self.loop) @@ -248,7 +248,7 @@ class StreamReaderTests(unittest.TestCase): self.assertEqual(b'line2\nl', data) self.assertEqual( len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), - stream.byte_count) + stream._byte_count) def test_readline_exception(self): stream = streams.StreamReader(loop=self.loop) @@ -268,11 +268,11 @@ class StreamReaderTests(unittest.TestCase): data = self.loop.run_until_complete(stream.readexactly(0)) self.assertEqual(b'', data) - self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(len(self.DATA), stream._byte_count) data = self.loop.run_until_complete(stream.readexactly(-1)) self.assertEqual(b'', data) - self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(len(self.DATA), stream._byte_count) def test_readexactly(self): # Read exact number of bytes. @@ -289,7 +289,7 @@ class StreamReaderTests(unittest.TestCase): data = self.loop.run_until_complete(read_task) self.assertEqual(self.DATA + self.DATA, data) - self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(len(self.DATA), stream._byte_count) def test_readexactly_eof(self): # Read exact number of bytes (eof). @@ -304,7 +304,7 @@ class StreamReaderTests(unittest.TestCase): data = self.loop.run_until_complete(read_task) self.assertEqual(self.DATA, data) - self.assertFalse(stream.byte_count) + self.assertFalse(stream._byte_count) def test_readexactly_exception(self): stream = streams.StreamReader(loop=self.loop) @@ -357,7 +357,7 @@ class StreamReaderTests(unittest.TestCase): # The following line fails if set_exception() isn't careful. stream.set_exception(RuntimeError('message')) test_utils.run_briefly(self.loop) - self.assertIs(stream.waiter, None) + self.assertIs(stream._waiter, None) if __name__ == '__main__': -- cgit v0.12