diff options
Diffstat (limited to 'Lib/asyncio/streams.py')
-rw-r--r-- | Lib/asyncio/streams.py | 208 |
1 files changed, 150 insertions, 58 deletions
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 9915aa5..e995368 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -39,7 +39,8 @@ def open_connection(host=None, port=None, *, protocol = StreamReaderProtocol(reader) transport, _ = yield from loop.create_connection( lambda: protocol, host, port, **kwds) - return reader, transport # (reader, writer) + writer = StreamWriter(transport, protocol, reader, loop) + return reader, writer class StreamReaderProtocol(protocols.Protocol): @@ -52,22 +53,113 @@ class StreamReaderProtocol(protocols.Protocol): """ def __init__(self, stream_reader): - self.stream_reader = stream_reader + self._stream_reader = stream_reader + self._drain_waiter = None + self._paused = False def connection_made(self, transport): - self.stream_reader.set_transport(transport) + self._stream_reader.set_transport(transport) def connection_lost(self, exc): if exc is None: - self.stream_reader.feed_eof() + self._stream_reader.feed_eof() else: - self.stream_reader.set_exception(exc) + self._stream_reader.set_exception(exc) + # Also wake up the writing side. + if self._paused: + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) def data_received(self, data): - self.stream_reader.feed_data(data) + self._stream_reader.feed_data(data) def eof_received(self): - self.stream_reader.feed_eof() + self._stream_reader.feed_eof() + + def pause_writing(self): + assert not self._paused + self._paused = True + + def resume_writing(self): + assert self._paused + self._paused = False + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + +class StreamWriter: + """Wraps a Transport. + + This exposes write(), writelines(), [can_]write_eof(), + get_extra_info() and close(). It adds drain() which returns an + optional Future on which you can wait for flow control. It also + adds a transport attribute which references the Transport + directly. + """ + + def __init__(self, transport, protocol, reader, loop): + self._transport = transport + self._protocol = protocol + self._reader = reader + self._loop = loop + + @property + def transport(self): + return self._transport + + def write(self, data): + self._transport.write(data) + + def writelines(self, data): + self._transport.writelines(data) + + def write_eof(self): + return self._transport.write_eof() + + def can_write_eof(self): + return self._transport.can_write_eof() + + def close(self): + return self._transport.close() + + def get_extra_info(self, name, default=None): + return self._transport.get_extra_info(name, default) + + def drain(self): + """This method has an unusual return value. + + The intended use is to write + + w.write(data) + yield from w.drain() + + When there's nothing to wait for, drain() returns (), and the + yield-from continues immediately. When the transport buffer + is full (the protocol is paused), drain() creates and returns + a Future and the yield-from will block until that Future is + completed, which will happen when the buffer is (partially) + drained and the protocol is resumed. + """ + if self._reader._exception is not None: + raise self._writer._exception + if self._transport._conn_lost: # Uses private variable. + raise ConnectionResetError('Connection lost') + if not self._protocol._paused: + return () + waiter = self._protocol._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = futures.Future(loop=self._loop) + self._protocol._drain_waiter = waiter + return waiter class StreamReader: @@ -75,14 +167,14 @@ class StreamReader: def __init__(self, limit=_DEFAULT_LIMIT, loop=None): # The line length limit is a security feature; # it also doubles as half the buffer limit. - self.limit = limit + self._limit = limit if loop is None: loop = events.get_event_loop() - self.loop = loop - self.buffer = collections.deque() # Deque of bytes objects. - self.byte_count = 0 # Bytes in buffer. - self.eof = False # Whether we're done. - self.waiter = None # A future. + self._loop = loop + self._buffer = collections.deque() # Deque of bytes objects. + self._byte_count = 0 # Bytes in buffer. + self._eof = False # Whether we're done. + self._waiter = None # A future. self._exception = None self._transport = None self._paused = False @@ -93,9 +185,9 @@ class StreamReader: def set_exception(self, exc): self._exception = exc - waiter = self.waiter + waiter = self._waiter if waiter is not None: - self.waiter = None + self._waiter = None if not waiter.cancelled(): waiter.set_exception(exc) @@ -104,15 +196,15 @@ class StreamReader: self._transport = transport def _maybe_resume_transport(self): - if self._paused and self.byte_count <= self.limit: + if self._paused and self._byte_count <= self._limit: self._paused = False self._transport.resume_reading() def feed_eof(self): - self.eof = True - waiter = self.waiter + self._eof = True + waiter = self._waiter if waiter is not None: - self.waiter = None + self._waiter = None if not waiter.cancelled(): waiter.set_result(True) @@ -120,18 +212,18 @@ class StreamReader: if not data: return - self.buffer.append(data) - self.byte_count += len(data) + self._buffer.append(data) + self._byte_count += len(data) - waiter = self.waiter + waiter = self._waiter if waiter is not None: - self.waiter = None + self._waiter = None if not waiter.cancelled(): waiter.set_result(False) if (self._transport is not None and not self._paused and - self.byte_count > 2*self.limit): + self._byte_count > 2*self._limit): try: self._transport.pause_reading() except NotImplementedError: @@ -152,8 +244,8 @@ class StreamReader: not_enough = True while not_enough: - while self.buffer and not_enough: - data = self.buffer.popleft() + while self._buffer and not_enough: + data = self._buffer.popleft() ichar = data.find(b'\n') if ichar < 0: parts.append(data) @@ -162,29 +254,29 @@ class StreamReader: ichar += 1 head, tail = data[:ichar], data[ichar:] if tail: - self.buffer.appendleft(tail) + self._buffer.appendleft(tail) not_enough = False parts.append(head) parts_size += len(head) - if parts_size > self.limit: - self.byte_count -= parts_size + if parts_size > self._limit: + self._byte_count -= parts_size self._maybe_resume_transport() raise ValueError('Line is too long') - if self.eof: + if self._eof: break if not_enough: - assert self.waiter is None - self.waiter = futures.Future(loop=self.loop) + assert self._waiter is None + self._waiter = futures.Future(loop=self._loop) try: - yield from self.waiter + yield from self._waiter finally: - self.waiter = None + self._waiter = None line = b''.join(parts) - self.byte_count -= parts_size + self._byte_count -= parts_size self._maybe_resume_transport() return line @@ -198,42 +290,42 @@ class StreamReader: return b'' if n < 0: - while not self.eof: - assert not self.waiter - self.waiter = futures.Future(loop=self.loop) + while not self._eof: + assert not self._waiter + self._waiter = futures.Future(loop=self._loop) try: - yield from self.waiter + yield from self._waiter finally: - self.waiter = None + self._waiter = None else: - if not self.byte_count and not self.eof: - assert not self.waiter - self.waiter = futures.Future(loop=self.loop) + if not self._byte_count and not self._eof: + assert not self._waiter + self._waiter = futures.Future(loop=self._loop) try: - yield from self.waiter + yield from self._waiter finally: - self.waiter = None + self._waiter = None - if n < 0 or self.byte_count <= n: - data = b''.join(self.buffer) - self.buffer.clear() - self.byte_count = 0 + if n < 0 or self._byte_count <= n: + data = b''.join(self._buffer) + self._buffer.clear() + self._byte_count = 0 self._maybe_resume_transport() return data parts = [] parts_bytes = 0 - while self.buffer and parts_bytes < n: - data = self.buffer.popleft() + while self._buffer and parts_bytes < n: + data = self._buffer.popleft() data_bytes = len(data) if n < parts_bytes + data_bytes: data_bytes = n - parts_bytes data, rest = data[:data_bytes], data[data_bytes:] - self.buffer.appendleft(rest) + self._buffer.appendleft(rest) parts.append(data) parts_bytes += data_bytes - self.byte_count -= data_bytes + self._byte_count -= data_bytes self._maybe_resume_transport() return b''.join(parts) @@ -246,12 +338,12 @@ class StreamReader: if n <= 0: return b'' - while self.byte_count < n and not self.eof: - assert not self.waiter - self.waiter = futures.Future(loop=self.loop) + while self._byte_count < n and not self._eof: + assert not self._waiter + self._waiter = futures.Future(loop=self._loop) try: - yield from self.waiter + yield from self._waiter finally: - self.waiter = None + self._waiter = None return (yield from self.read(n)) |