summaryrefslogtreecommitdiffstats
path: root/Lib/asyncio
diff options
context:
space:
mode:
authorGuido van Rossum <guido@dropbox.com>2013-10-18 22:17:11 (GMT)
committerGuido van Rossum <guido@dropbox.com>2013-10-18 22:17:11 (GMT)
commit355491dc47ea4a2574ee8f9ea60a0d25fe3fba43 (patch)
tree2b9661e6f8c6d24704fae0c82467674802749d6d /Lib/asyncio
parent051a33148813d045c33745ccd0e9e20e96b1bb6f (diff)
downloadcpython-355491dc47ea4a2574ee8f9ea60a0d25fe3fba43.zip
cpython-355491dc47ea4a2574ee8f9ea60a0d25fe3fba43.tar.gz
cpython-355491dc47ea4a2574ee8f9ea60a0d25fe3fba43.tar.bz2
Write flow control for asyncio (includes asyncio.streams overhaul).
Diffstat (limited to 'Lib/asyncio')
-rw-r--r--Lib/asyncio/protocols.py28
-rw-r--r--Lib/asyncio/selector_events.py78
-rw-r--r--Lib/asyncio/streams.py208
-rw-r--r--Lib/asyncio/transports.py25
4 files changed, 267 insertions, 72 deletions
diff --git a/Lib/asyncio/protocols.py b/Lib/asyncio/protocols.py
index a94abbe..d3a8685 100644
--- a/Lib/asyncio/protocols.py
+++ b/Lib/asyncio/protocols.py
@@ -29,6 +29,34 @@ class BaseProtocol:
aborted or closed).
"""
+ def pause_writing(self):
+ """Called when the transport's buffer goes over the high-water mark.
+
+ Pause and resume calls are paired -- pause_writing() is called
+ once when the buffer goes strictly over the high-water mark
+ (even if subsequent writes increases the buffer size even
+ more), and eventually resume_writing() is called once when the
+ buffer size reaches the low-water mark.
+
+ Note that if the buffer size equals the high-water mark,
+ pause_writing() is not called -- it must go strictly over.
+ Conversely, resume_writing() is called when the buffer size is
+ equal or lower than the low-water mark. These end conditions
+ are important to ensure that things go as expected when either
+ mark is zero.
+
+ NOTE: This is the only Protocol callback that is not called
+ through EventLoop.call_soon() -- if it were, it would have no
+ effect when it's most needed (when the app keeps writing
+ without yielding until pause_writing() is called).
+ """
+
+ def resume_writing(self):
+ """Called when the transport's buffer drains below the low-water mark.
+
+ See pause_writing() for details.
+ """
+
class Protocol(BaseProtocol):
"""ABC representing a protocol.
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
index 084d9be..63164f0 100644
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -346,8 +346,10 @@ class _SelectorTransport(transports.Transport):
self._buffer = collections.deque()
self._conn_lost = 0 # Set when call to connection_lost scheduled.
self._closing = False # Set when close() called.
- if server is not None:
- server.attach(self)
+ self._protocol_paused = False
+ self.set_write_buffer_limits()
+ if self._server is not None:
+ self._server.attach(self)
def abort(self):
self._force_close(None)
@@ -392,6 +394,40 @@ class _SelectorTransport(transports.Transport):
server.detach(self)
self._server = None
+ def _maybe_pause_protocol(self):
+ size = self.get_write_buffer_size()
+ if size <= self._high_water:
+ return
+ if not self._protocol_paused:
+ self._protocol_paused = True
+ try:
+ self._protocol.pause_writing()
+ except Exception:
+ tulip_log.exception('pause_writing() failed')
+
+ def _maybe_resume_protocol(self):
+ if self._protocol_paused and self.get_write_buffer_size() <= self._low_water:
+ self._protocol_paused = False
+ try:
+ self._protocol.resume_writing()
+ except Exception:
+ tulip_log.exception('resume_writing() failed')
+
+ def set_write_buffer_limits(self, high=None, low=None):
+ if high is None:
+ if low is None:
+ high = 64*1024
+ else:
+ high = 4*low
+ if low is None:
+ low = high // 4
+ assert 0 <= low <= high, repr((low, high))
+ self._high_water = high
+ self._low_water = low
+
+ def get_write_buffer_size(self):
+ return sum(len(data) for data in self._buffer)
+
class _SelectorSocketTransport(_SelectorTransport):
@@ -447,7 +483,7 @@ class _SelectorSocketTransport(_SelectorTransport):
return
if not self._buffer:
- # Attempt to send it right away first.
+ # Optimization: try to send now.
try:
n = self._sock.send(data)
except (BlockingIOError, InterruptedError):
@@ -459,34 +495,36 @@ class _SelectorSocketTransport(_SelectorTransport):
data = data[n:]
if not data:
return
-
- # Start async I/O.
+ # Not all was written; register write handler.
self._loop.add_writer(self._sock_fd, self._write_ready)
+ # Add it to the buffer.
self._buffer.append(data)
+ self._maybe_pause_protocol()
def _write_ready(self):
data = b''.join(self._buffer)
assert data, 'Data should not be empty'
- self._buffer.clear()
+ self._buffer.clear() # Optimistically; may have to put it back later.
try:
n = self._sock.send(data)
except (BlockingIOError, InterruptedError):
- self._buffer.append(data)
+ self._buffer.append(data) # Still need to write this.
except Exception as exc:
self._loop.remove_writer(self._sock_fd)
self._fatal_error(exc)
else:
data = data[n:]
- if not data:
+ if data:
+ self._buffer.append(data) # Still need to write this.
+ self._maybe_resume_protocol() # May append to buffer.
+ if not self._buffer:
self._loop.remove_writer(self._sock_fd)
if self._closing:
self._call_connection_lost(None)
elif self._eof:
self._sock.shutdown(socket.SHUT_WR)
- return
- self._buffer.append(data) # Try again later.
def write_eof(self):
if self._eof:
@@ -546,16 +584,23 @@ class _SelectorSslTransport(_SelectorTransport):
self._loop.add_writer(self._sock_fd, self._on_handshake)
return
except Exception as exc:
+ self._loop.remove_reader(self._sock_fd)
+ self._loop.remove_writer(self._sock_fd)
self._sock.close()
if self._waiter is not None:
self._waiter.set_exception(exc)
return
except BaseException as exc:
+ self._loop.remove_reader(self._sock_fd)
+ self._loop.remove_writer(self._sock_fd)
self._sock.close()
if self._waiter is not None:
self._waiter.set_exception(exc)
raise
+ self._loop.remove_reader(self._sock_fd)
+ self._loop.remove_writer(self._sock_fd)
+
# Verify hostname if requested.
peercert = self._sock.getpeercert()
if (self._server_hostname is not None and
@@ -574,8 +619,6 @@ class _SelectorSslTransport(_SelectorTransport):
compression=self._sock.compression(),
)
- self._loop.remove_reader(self._sock_fd)
- self._loop.remove_writer(self._sock_fd)
self._loop.add_reader(self._sock_fd, self._on_ready)
self._loop.add_writer(self._sock_fd, self._on_ready)
self._loop.call_soon(self._protocol.connection_made, self)
@@ -642,6 +685,8 @@ class _SelectorSslTransport(_SelectorTransport):
if n < len(data):
self._buffer.append(data[n:])
+ self._maybe_resume_protocol() # May append to buffer.
+
if self._closing and not self._buffer:
self._loop.remove_writer(self._sock_fd)
self._call_connection_lost(None)
@@ -657,8 +702,9 @@ class _SelectorSslTransport(_SelectorTransport):
self._conn_lost += 1
return
- self._buffer.append(data)
# We could optimize, but the callback can do this for now.
+ self._buffer.append(data)
+ self._maybe_pause_protocol()
def can_write_eof(self):
return False
@@ -675,11 +721,13 @@ class _SelectorDatagramTransport(_SelectorTransport):
def __init__(self, loop, sock, protocol, address=None, extra=None):
super().__init__(loop, sock, protocol, extra)
-
self._address = address
self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self)
+ def get_write_buffer_size(self):
+ return sum(len(data) for data, _ in self._buffer)
+
def _read_ready(self):
try:
data, addr = self._sock.recvfrom(self.max_size)
@@ -723,6 +771,7 @@ class _SelectorDatagramTransport(_SelectorTransport):
return
self._buffer.append((data, addr))
+ self._maybe_pause_protocol()
def _sendto_ready(self):
while self._buffer:
@@ -743,6 +792,7 @@ class _SelectorDatagramTransport(_SelectorTransport):
self._fatal_error(exc)
return
+ self._maybe_resume_protocol() # May append to buffer.
if not self._buffer:
self._loop.remove_writer(self._sock_fd)
if self._closing:
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
index 9915aa5..e995368 100644
--- a/Lib/asyncio/streams.py
+++ b/Lib/asyncio/streams.py
@@ -39,7 +39,8 @@ def open_connection(host=None, port=None, *,
protocol = StreamReaderProtocol(reader)
transport, _ = yield from loop.create_connection(
lambda: protocol, host, port, **kwds)
- return reader, transport # (reader, writer)
+ writer = StreamWriter(transport, protocol, reader, loop)
+ return reader, writer
class StreamReaderProtocol(protocols.Protocol):
@@ -52,22 +53,113 @@ class StreamReaderProtocol(protocols.Protocol):
"""
def __init__(self, stream_reader):
- self.stream_reader = stream_reader
+ self._stream_reader = stream_reader
+ self._drain_waiter = None
+ self._paused = False
def connection_made(self, transport):
- self.stream_reader.set_transport(transport)
+ self._stream_reader.set_transport(transport)
def connection_lost(self, exc):
if exc is None:
- self.stream_reader.feed_eof()
+ self._stream_reader.feed_eof()
else:
- self.stream_reader.set_exception(exc)
+ self._stream_reader.set_exception(exc)
+ # Also wake up the writing side.
+ if self._paused:
+ waiter = self._drain_waiter
+ if waiter is not None:
+ self._drain_waiter = None
+ if not waiter.done():
+ if exc is None:
+ waiter.set_result(None)
+ else:
+ waiter.set_exception(exc)
def data_received(self, data):
- self.stream_reader.feed_data(data)
+ self._stream_reader.feed_data(data)
def eof_received(self):
- self.stream_reader.feed_eof()
+ self._stream_reader.feed_eof()
+
+ def pause_writing(self):
+ assert not self._paused
+ self._paused = True
+
+ def resume_writing(self):
+ assert self._paused
+ self._paused = False
+ waiter = self._drain_waiter
+ if waiter is not None:
+ self._drain_waiter = None
+ if not waiter.done():
+ waiter.set_result(None)
+
+
+class StreamWriter:
+ """Wraps a Transport.
+
+ This exposes write(), writelines(), [can_]write_eof(),
+ get_extra_info() and close(). It adds drain() which returns an
+ optional Future on which you can wait for flow control. It also
+ adds a transport attribute which references the Transport
+ directly.
+ """
+
+ def __init__(self, transport, protocol, reader, loop):
+ self._transport = transport
+ self._protocol = protocol
+ self._reader = reader
+ self._loop = loop
+
+ @property
+ def transport(self):
+ return self._transport
+
+ def write(self, data):
+ self._transport.write(data)
+
+ def writelines(self, data):
+ self._transport.writelines(data)
+
+ def write_eof(self):
+ return self._transport.write_eof()
+
+ def can_write_eof(self):
+ return self._transport.can_write_eof()
+
+ def close(self):
+ return self._transport.close()
+
+ def get_extra_info(self, name, default=None):
+ return self._transport.get_extra_info(name, default)
+
+ def drain(self):
+ """This method has an unusual return value.
+
+ The intended use is to write
+
+ w.write(data)
+ yield from w.drain()
+
+ When there's nothing to wait for, drain() returns (), and the
+ yield-from continues immediately. When the transport buffer
+ is full (the protocol is paused), drain() creates and returns
+ a Future and the yield-from will block until that Future is
+ completed, which will happen when the buffer is (partially)
+ drained and the protocol is resumed.
+ """
+ if self._reader._exception is not None:
+ raise self._writer._exception
+ if self._transport._conn_lost: # Uses private variable.
+ raise ConnectionResetError('Connection lost')
+ if not self._protocol._paused:
+ return ()
+ waiter = self._protocol._drain_waiter
+ assert waiter is None or waiter.cancelled()
+ waiter = futures.Future(loop=self._loop)
+ self._protocol._drain_waiter = waiter
+ return waiter
class StreamReader:
@@ -75,14 +167,14 @@ class StreamReader:
def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
# The line length limit is a security feature;
# it also doubles as half the buffer limit.
- self.limit = limit
+ self._limit = limit
if loop is None:
loop = events.get_event_loop()
- self.loop = loop
- self.buffer = collections.deque() # Deque of bytes objects.
- self.byte_count = 0 # Bytes in buffer.
- self.eof = False # Whether we're done.
- self.waiter = None # A future.
+ self._loop = loop
+ self._buffer = collections.deque() # Deque of bytes objects.
+ self._byte_count = 0 # Bytes in buffer.
+ self._eof = False # Whether we're done.
+ self._waiter = None # A future.
self._exception = None
self._transport = None
self._paused = False
@@ -93,9 +185,9 @@ class StreamReader:
def set_exception(self, exc):
self._exception = exc
- waiter = self.waiter
+ waiter = self._waiter
if waiter is not None:
- self.waiter = None
+ self._waiter = None
if not waiter.cancelled():
waiter.set_exception(exc)
@@ -104,15 +196,15 @@ class StreamReader:
self._transport = transport
def _maybe_resume_transport(self):
- if self._paused and self.byte_count <= self.limit:
+ if self._paused and self._byte_count <= self._limit:
self._paused = False
self._transport.resume_reading()
def feed_eof(self):
- self.eof = True
- waiter = self.waiter
+ self._eof = True
+ waiter = self._waiter
if waiter is not None:
- self.waiter = None
+ self._waiter = None
if not waiter.cancelled():
waiter.set_result(True)
@@ -120,18 +212,18 @@ class StreamReader:
if not data:
return
- self.buffer.append(data)
- self.byte_count += len(data)
+ self._buffer.append(data)
+ self._byte_count += len(data)
- waiter = self.waiter
+ waiter = self._waiter
if waiter is not None:
- self.waiter = None
+ self._waiter = None
if not waiter.cancelled():
waiter.set_result(False)
if (self._transport is not None and
not self._paused and
- self.byte_count > 2*self.limit):
+ self._byte_count > 2*self._limit):
try:
self._transport.pause_reading()
except NotImplementedError:
@@ -152,8 +244,8 @@ class StreamReader:
not_enough = True
while not_enough:
- while self.buffer and not_enough:
- data = self.buffer.popleft()
+ while self._buffer and not_enough:
+ data = self._buffer.popleft()
ichar = data.find(b'\n')
if ichar < 0:
parts.append(data)
@@ -162,29 +254,29 @@ class StreamReader:
ichar += 1
head, tail = data[:ichar], data[ichar:]
if tail:
- self.buffer.appendleft(tail)
+ self._buffer.appendleft(tail)
not_enough = False
parts.append(head)
parts_size += len(head)
- if parts_size > self.limit:
- self.byte_count -= parts_size
+ if parts_size > self._limit:
+ self._byte_count -= parts_size
self._maybe_resume_transport()
raise ValueError('Line is too long')
- if self.eof:
+ if self._eof:
break
if not_enough:
- assert self.waiter is None
- self.waiter = futures.Future(loop=self.loop)
+ assert self._waiter is None
+ self._waiter = futures.Future(loop=self._loop)
try:
- yield from self.waiter
+ yield from self._waiter
finally:
- self.waiter = None
+ self._waiter = None
line = b''.join(parts)
- self.byte_count -= parts_size
+ self._byte_count -= parts_size
self._maybe_resume_transport()
return line
@@ -198,42 +290,42 @@ class StreamReader:
return b''
if n < 0:
- while not self.eof:
- assert not self.waiter
- self.waiter = futures.Future(loop=self.loop)
+ while not self._eof:
+ assert not self._waiter
+ self._waiter = futures.Future(loop=self._loop)
try:
- yield from self.waiter
+ yield from self._waiter
finally:
- self.waiter = None
+ self._waiter = None
else:
- if not self.byte_count and not self.eof:
- assert not self.waiter
- self.waiter = futures.Future(loop=self.loop)
+ if not self._byte_count and not self._eof:
+ assert not self._waiter
+ self._waiter = futures.Future(loop=self._loop)
try:
- yield from self.waiter
+ yield from self._waiter
finally:
- self.waiter = None
+ self._waiter = None
- if n < 0 or self.byte_count <= n:
- data = b''.join(self.buffer)
- self.buffer.clear()
- self.byte_count = 0
+ if n < 0 or self._byte_count <= n:
+ data = b''.join(self._buffer)
+ self._buffer.clear()
+ self._byte_count = 0
self._maybe_resume_transport()
return data
parts = []
parts_bytes = 0
- while self.buffer and parts_bytes < n:
- data = self.buffer.popleft()
+ while self._buffer and parts_bytes < n:
+ data = self._buffer.popleft()
data_bytes = len(data)
if n < parts_bytes + data_bytes:
data_bytes = n - parts_bytes
data, rest = data[:data_bytes], data[data_bytes:]
- self.buffer.appendleft(rest)
+ self._buffer.appendleft(rest)
parts.append(data)
parts_bytes += data_bytes
- self.byte_count -= data_bytes
+ self._byte_count -= data_bytes
self._maybe_resume_transport()
return b''.join(parts)
@@ -246,12 +338,12 @@ class StreamReader:
if n <= 0:
return b''
- while self.byte_count < n and not self.eof:
- assert not self.waiter
- self.waiter = futures.Future(loop=self.loop)
+ while self._byte_count < n and not self._eof:
+ assert not self._waiter
+ self._waiter = futures.Future(loop=self._loop)
try:
- yield from self.waiter
+ yield from self._waiter
finally:
- self.waiter = None
+ self._waiter = None
return (yield from self.read(n))
diff --git a/Lib/asyncio/transports.py b/Lib/asyncio/transports.py
index f1a7180..8c6b189 100644
--- a/Lib/asyncio/transports.py
+++ b/Lib/asyncio/transports.py
@@ -49,6 +49,31 @@ class ReadTransport(BaseTransport):
class WriteTransport(BaseTransport):
"""ABC for write-only transports."""
+ def set_write_buffer_limits(self, high=None, low=None):
+ """Set the high- and low-water limits for write flow control.
+
+ These two values control when to call the protocol's
+ pause_writing() and resume_writing() methods. If specified,
+ the low-water limit must be less than or equal to the
+ high-water limit. Neither value can be negative.
+
+ The defaults are implementation-specific. If only the
+ high-water limit is given, the low-water limit defaults to a
+ implementation-specific value less than or equal to the
+ high-water limit. Setting high to zero forces low to zero as
+ well, and causes pause_writing() to be called whenever the
+ buffer becomes non-empty. Setting low to zero causes
+ resume_writing() to be called only once the buffer is empty.
+ Use of zero for either limit is generally sub-optimal as it
+ reduces opportunities for doing I/O and computation
+ concurrently.
+ """
+ raise NotImplementedError
+
+ def get_write_buffer_size(self):
+ """Return the current size of the write buffer."""
+ raise NotImplementedError
+
def write(self, data):
"""Write some data bytes to the transport.