summaryrefslogtreecommitdiffstats
path: root/Lib/asyncio/streams.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/asyncio/streams.py')
-rw-r--r--Lib/asyncio/streams.py208
1 files changed, 150 insertions, 58 deletions
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
index 9915aa5..e995368 100644
--- a/Lib/asyncio/streams.py
+++ b/Lib/asyncio/streams.py
@@ -39,7 +39,8 @@ def open_connection(host=None, port=None, *,
protocol = StreamReaderProtocol(reader)
transport, _ = yield from loop.create_connection(
lambda: protocol, host, port, **kwds)
- return reader, transport # (reader, writer)
+ writer = StreamWriter(transport, protocol, reader, loop)
+ return reader, writer
class StreamReaderProtocol(protocols.Protocol):
@@ -52,22 +53,113 @@ class StreamReaderProtocol(protocols.Protocol):
"""
def __init__(self, stream_reader):
- self.stream_reader = stream_reader
+ self._stream_reader = stream_reader
+ self._drain_waiter = None
+ self._paused = False
def connection_made(self, transport):
- self.stream_reader.set_transport(transport)
+ self._stream_reader.set_transport(transport)
def connection_lost(self, exc):
if exc is None:
- self.stream_reader.feed_eof()
+ self._stream_reader.feed_eof()
else:
- self.stream_reader.set_exception(exc)
+ self._stream_reader.set_exception(exc)
+ # Also wake up the writing side.
+ if self._paused:
+ waiter = self._drain_waiter
+ if waiter is not None:
+ self._drain_waiter = None
+ if not waiter.done():
+ if exc is None:
+ waiter.set_result(None)
+ else:
+ waiter.set_exception(exc)
def data_received(self, data):
- self.stream_reader.feed_data(data)
+ self._stream_reader.feed_data(data)
def eof_received(self):
- self.stream_reader.feed_eof()
+ self._stream_reader.feed_eof()
+
+ def pause_writing(self):
+ assert not self._paused
+ self._paused = True
+
+ def resume_writing(self):
+ assert self._paused
+ self._paused = False
+ waiter = self._drain_waiter
+ if waiter is not None:
+ self._drain_waiter = None
+ if not waiter.done():
+ waiter.set_result(None)
+
+
+class StreamWriter:
+ """Wraps a Transport.
+
+ This exposes write(), writelines(), [can_]write_eof(),
+ get_extra_info() and close(). It adds drain() which returns an
+ optional Future on which you can wait for flow control. It also
+ adds a transport attribute which references the Transport
+ directly.
+ """
+
+ def __init__(self, transport, protocol, reader, loop):
+ self._transport = transport
+ self._protocol = protocol
+ self._reader = reader
+ self._loop = loop
+
+ @property
+ def transport(self):
+ return self._transport
+
+ def write(self, data):
+ self._transport.write(data)
+
+ def writelines(self, data):
+ self._transport.writelines(data)
+
+ def write_eof(self):
+ return self._transport.write_eof()
+
+ def can_write_eof(self):
+ return self._transport.can_write_eof()
+
+ def close(self):
+ return self._transport.close()
+
+ def get_extra_info(self, name, default=None):
+ return self._transport.get_extra_info(name, default)
+
+ def drain(self):
+ """This method has an unusual return value.
+
+ The intended use is to write
+
+ w.write(data)
+ yield from w.drain()
+
+ When there's nothing to wait for, drain() returns (), and the
+ yield-from continues immediately. When the transport buffer
+ is full (the protocol is paused), drain() creates and returns
+ a Future and the yield-from will block until that Future is
+ completed, which will happen when the buffer is (partially)
+ drained and the protocol is resumed.
+ """
+ if self._reader._exception is not None:
+ raise self._writer._exception
+ if self._transport._conn_lost: # Uses private variable.
+ raise ConnectionResetError('Connection lost')
+ if not self._protocol._paused:
+ return ()
+ waiter = self._protocol._drain_waiter
+ assert waiter is None or waiter.cancelled()
+ waiter = futures.Future(loop=self._loop)
+ self._protocol._drain_waiter = waiter
+ return waiter
class StreamReader:
@@ -75,14 +167,14 @@ class StreamReader:
def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
# The line length limit is a security feature;
# it also doubles as half the buffer limit.
- self.limit = limit
+ self._limit = limit
if loop is None:
loop = events.get_event_loop()
- self.loop = loop
- self.buffer = collections.deque() # Deque of bytes objects.
- self.byte_count = 0 # Bytes in buffer.
- self.eof = False # Whether we're done.
- self.waiter = None # A future.
+ self._loop = loop
+ self._buffer = collections.deque() # Deque of bytes objects.
+ self._byte_count = 0 # Bytes in buffer.
+ self._eof = False # Whether we're done.
+ self._waiter = None # A future.
self._exception = None
self._transport = None
self._paused = False
@@ -93,9 +185,9 @@ class StreamReader:
def set_exception(self, exc):
self._exception = exc
- waiter = self.waiter
+ waiter = self._waiter
if waiter is not None:
- self.waiter = None
+ self._waiter = None
if not waiter.cancelled():
waiter.set_exception(exc)
@@ -104,15 +196,15 @@ class StreamReader:
self._transport = transport
def _maybe_resume_transport(self):
- if self._paused and self.byte_count <= self.limit:
+ if self._paused and self._byte_count <= self._limit:
self._paused = False
self._transport.resume_reading()
def feed_eof(self):
- self.eof = True
- waiter = self.waiter
+ self._eof = True
+ waiter = self._waiter
if waiter is not None:
- self.waiter = None
+ self._waiter = None
if not waiter.cancelled():
waiter.set_result(True)
@@ -120,18 +212,18 @@ class StreamReader:
if not data:
return
- self.buffer.append(data)
- self.byte_count += len(data)
+ self._buffer.append(data)
+ self._byte_count += len(data)
- waiter = self.waiter
+ waiter = self._waiter
if waiter is not None:
- self.waiter = None
+ self._waiter = None
if not waiter.cancelled():
waiter.set_result(False)
if (self._transport is not None and
not self._paused and
- self.byte_count > 2*self.limit):
+ self._byte_count > 2*self._limit):
try:
self._transport.pause_reading()
except NotImplementedError:
@@ -152,8 +244,8 @@ class StreamReader:
not_enough = True
while not_enough:
- while self.buffer and not_enough:
- data = self.buffer.popleft()
+ while self._buffer and not_enough:
+ data = self._buffer.popleft()
ichar = data.find(b'\n')
if ichar < 0:
parts.append(data)
@@ -162,29 +254,29 @@ class StreamReader:
ichar += 1
head, tail = data[:ichar], data[ichar:]
if tail:
- self.buffer.appendleft(tail)
+ self._buffer.appendleft(tail)
not_enough = False
parts.append(head)
parts_size += len(head)
- if parts_size > self.limit:
- self.byte_count -= parts_size
+ if parts_size > self._limit:
+ self._byte_count -= parts_size
self._maybe_resume_transport()
raise ValueError('Line is too long')
- if self.eof:
+ if self._eof:
break
if not_enough:
- assert self.waiter is None
- self.waiter = futures.Future(loop=self.loop)
+ assert self._waiter is None
+ self._waiter = futures.Future(loop=self._loop)
try:
- yield from self.waiter
+ yield from self._waiter
finally:
- self.waiter = None
+ self._waiter = None
line = b''.join(parts)
- self.byte_count -= parts_size
+ self._byte_count -= parts_size
self._maybe_resume_transport()
return line
@@ -198,42 +290,42 @@ class StreamReader:
return b''
if n < 0:
- while not self.eof:
- assert not self.waiter
- self.waiter = futures.Future(loop=self.loop)
+ while not self._eof:
+ assert not self._waiter
+ self._waiter = futures.Future(loop=self._loop)
try:
- yield from self.waiter
+ yield from self._waiter
finally:
- self.waiter = None
+ self._waiter = None
else:
- if not self.byte_count and not self.eof:
- assert not self.waiter
- self.waiter = futures.Future(loop=self.loop)
+ if not self._byte_count and not self._eof:
+ assert not self._waiter
+ self._waiter = futures.Future(loop=self._loop)
try:
- yield from self.waiter
+ yield from self._waiter
finally:
- self.waiter = None
+ self._waiter = None
- if n < 0 or self.byte_count <= n:
- data = b''.join(self.buffer)
- self.buffer.clear()
- self.byte_count = 0
+ if n < 0 or self._byte_count <= n:
+ data = b''.join(self._buffer)
+ self._buffer.clear()
+ self._byte_count = 0
self._maybe_resume_transport()
return data
parts = []
parts_bytes = 0
- while self.buffer and parts_bytes < n:
- data = self.buffer.popleft()
+ while self._buffer and parts_bytes < n:
+ data = self._buffer.popleft()
data_bytes = len(data)
if n < parts_bytes + data_bytes:
data_bytes = n - parts_bytes
data, rest = data[:data_bytes], data[data_bytes:]
- self.buffer.appendleft(rest)
+ self._buffer.appendleft(rest)
parts.append(data)
parts_bytes += data_bytes
- self.byte_count -= data_bytes
+ self._byte_count -= data_bytes
self._maybe_resume_transport()
return b''.join(parts)
@@ -246,12 +338,12 @@ class StreamReader:
if n <= 0:
return b''
- while self.byte_count < n and not self.eof:
- assert not self.waiter
- self.waiter = futures.Future(loop=self.loop)
+ while self._byte_count < n and not self._eof:
+ assert not self._waiter
+ self._waiter = futures.Future(loop=self._loop)
try:
- yield from self.waiter
+ yield from self._waiter
finally:
- self.waiter = None
+ self._waiter = None
return (yield from self.read(n))