author     Andrew Svetlov <andrew.svetlov@gmail.com>    2019-05-27 19:56:22 (GMT)
committer  Miss Islington (bot) <31488909+miss-islington@users.noreply.github.com>    2019-05-27 19:56:22 (GMT)
commit     23b4b697e5b6cc897696f9c0288c187d2d24bff2 (patch)
tree       2f70e14fe527878cd69ccbefca007a1e987943ed /Lib/asyncio/streams.py
parent     6f6ff8a56518a80da406aad6ac8364c046cc7f18 (diff)
bpo-36889: Merge asyncio streams (GH-13251)
https://bugs.python.org/issue36889
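This merge replaces the legacy (reader, writer) pair with a single bidirectional Stream object created by a new connect() factory. As a rough, non-authoritative sketch of the client-side usage implied by the diff below (the address is a placeholder; only names that appear in this patch are used):

    import asyncio
    from asyncio.streams import connect   # added to __all__ by this patch

    async def main():
        # connect() is a plain (non-async) factory; it returns a helper
        # (_ContextManagerHelper in the diff) that can be awaited directly
        # or used as an async context manager.
        async with connect('127.0.0.1', 8888) as stream:
            await stream.write(b'GET / HTTP/1.0\r\n\r\n')
            body = await stream.read()
            print(len(body), 'bytes received')

    asyncio.run(main())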
Diffstat (limited to 'Lib/asyncio/streams.py')
-rw-r--r--   Lib/asyncio/streams.py   1236
1 file changed, 1094 insertions, 142 deletions
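On the server side the patch adds StreamServer (and UnixStreamServer), whose connection callback receives one Stream rather than a (reader, writer) pair. A minimal echo-server sketch based on the code below; the address and echo behaviour are illustrative only:

    import asyncio
    from asyncio.streams import StreamServer   # added by this patch

    async def handle(stream):
        # _ServerStreamProtocol passes a single read/write Stream to the callback.
        data = await stream.readline()
        await stream.write(data)       # echo the line back
        await stream.close()

    async def main():
        # Entering the context manager binds the sockets (bind() in the diff);
        # serve_forever() then starts accepting connections.
        async with StreamServer(handle, '127.0.0.1', 8888) as server:
            await server.serve_forever()

    asyncio.run(main())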
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 2f0cbfd..480f1a3 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1,14 +1,19 @@ __all__ = ( - 'StreamReader', 'StreamWriter', 'StreamReaderProtocol', - 'open_connection', 'start_server') + 'Stream', 'StreamMode', + 'open_connection', 'start_server', + 'connect', 'connect_read_pipe', 'connect_write_pipe', + 'StreamServer') +import enum import socket import sys import warnings import weakref if hasattr(socket, 'AF_UNIX'): - __all__ += ('open_unix_connection', 'start_unix_server') + __all__ += ('open_unix_connection', 'start_unix_server', + 'connect_unix', + 'UnixStreamServer') from . import coroutines from . import events @@ -16,12 +21,134 @@ from . import exceptions from . import format_helpers from . import protocols from .log import logger -from .tasks import sleep +from . import tasks _DEFAULT_LIMIT = 2 ** 16 # 64 KiB +class StreamMode(enum.Flag): + READ = enum.auto() + WRITE = enum.auto() + READWRITE = READ | WRITE + + +def _ensure_can_read(mode): + if not mode & StreamMode.READ: + raise RuntimeError("The stream is write-only") + + +def _ensure_can_write(mode): + if not mode & StreamMode.WRITE: + raise RuntimeError("The stream is read-only") + + +class _ContextManagerHelper: + __slots__ = ('_awaitable', '_result') + + def __init__(self, awaitable): + self._awaitable = awaitable + self._result = None + + def __await__(self): + return self._awaitable.__await__() + + async def __aenter__(self): + ret = await self._awaitable + result = await ret.__aenter__() + self._result = result + return result + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return await self._result.__aexit__(exc_type, exc_val, exc_tb) + + +def connect(host=None, port=None, *, + limit=_DEFAULT_LIMIT, + ssl=None, family=0, proto=0, + flags=0, sock=None, local_addr=None, + server_hostname=None, + ssl_handshake_timeout=None, + happy_eyeballs_delay=None, interleave=None): + # Design note: + # Don't use decorator approach but exilicit non-async + # function to fail fast and explicitly + # if passed arguments don't match the function signature + return _ContextManagerHelper(_connect(host, port, limit, + ssl, family, proto, + flags, sock, local_addr, + server_hostname, + ssl_handshake_timeout, + happy_eyeballs_delay, + interleave)) + + +async def _connect(host, port, + limit, + ssl, family, proto, + flags, sock, local_addr, + server_hostname, + ssl_handshake_timeout, + happy_eyeballs_delay, interleave): + loop = events.get_running_loop() + stream = Stream(mode=StreamMode.READWRITE, + limit=limit, + loop=loop, + _asyncio_internal=True) + await loop.create_connection( + lambda: _StreamProtocol(stream, loop=loop, + _asyncio_internal=True), + host, port, + ssl=ssl, family=family, proto=proto, + flags=flags, sock=sock, local_addr=local_addr, + server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + happy_eyeballs_delay=happy_eyeballs_delay, interleave=interleave) + return stream + + +def connect_read_pipe(pipe, *, limit=_DEFAULT_LIMIT): + # Design note: + # Don't use decorator approach but explicit non-async + # function to fail fast and explicitly + # if passed arguments don't match the function signature + return _ContextManagerHelper(_connect_read_pipe(pipe, limit)) + + +async def _connect_read_pipe(pipe, limit): + loop = events.get_running_loop() + stream = Stream(mode=StreamMode.READ, + limit=limit, + loop=loop, + _asyncio_internal=True) + await loop.connect_read_pipe( + lambda: _StreamProtocol(stream, 
loop=loop, + _asyncio_internal=True), + pipe) + return stream + + +def connect_write_pipe(pipe, *, limit=_DEFAULT_LIMIT): + # Design note: + # Don't use decorator approach but explicit non-async + # function to fail fast and explicitly + # if passed arguments don't match the function signature + return _ContextManagerHelper(_connect_write_pipe(pipe, limit)) + + +async def _connect_write_pipe(pipe, limit): + loop = events.get_running_loop() + stream = Stream(mode=StreamMode.WRITE, + limit=limit, + loop=loop, + _asyncio_internal=True) + await loop.connect_write_pipe( + lambda: _StreamProtocol(stream, loop=loop, + _asyncio_internal=True), + pipe) + return stream + + async def open_connection(host=None, port=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """A wrapper for create_connection() returning a (reader, writer) pair. @@ -41,16 +168,18 @@ async def open_connection(host=None, port=None, *, StreamReaderProtocol classes, just copy the code -- there's really nothing special here except some convenience.) """ + warnings.warn("open_connection() is deprecated since Python 3.8 " + "in favor of connect(), and scheduled for removal " + "in Python 3.10", + DeprecationWarning, + stacklevel=2) if loop is None: loop = events.get_event_loop() - reader = StreamReader(limit=limit, loop=loop, - _asyncio_internal=True) - protocol = StreamReaderProtocol(reader, loop=loop, - _asyncio_internal=True) + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, loop=loop, _asyncio_internal=True) transport, _ = await loop.create_connection( lambda: protocol, host, port, **kwds) - writer = StreamWriter(transport, protocol, reader, loop, - _asyncio_internal=True) + writer = StreamWriter(transport, protocol, reader, loop) return reader, writer @@ -77,12 +206,16 @@ async def start_server(client_connected_cb, host=None, port=None, *, The return value is the same as loop.create_server(), i.e. a Server object which can be used to stop the service. """ + warnings.warn("start_server() is deprecated since Python 3.8 " + "in favor of StreamServer(), and scheduled for removal " + "in Python 3.10", + DeprecationWarning, + stacklevel=2) if loop is None: loop = events.get_event_loop() def factory(): - reader = StreamReader(limit=limit, loop=loop, - _asyncio_internal=True) + reader = StreamReader(limit=limit, loop=loop) protocol = StreamReaderProtocol(reader, client_connected_cb, loop=loop, _asyncio_internal=True) @@ -91,33 +224,258 @@ async def start_server(client_connected_cb, host=None, port=None, *, return await loop.create_server(factory, host, port, **kwds) +class _BaseStreamServer: + # Design notes. + # StreamServer and UnixStreamServer are exposed as FINAL classes, + # not function factories. + # async with serve(host, port) as server: + # server.start_serving() + # looks ugly. 
+ # The class doesn't provide API for enumerating connected streams + # It can be a subject for improvements in Python 3.9 + + _server_impl = None + + def __init__(self, client_connected_cb, + /, + limit=_DEFAULT_LIMIT, + shutdown_timeout=60, + _asyncio_internal=False): + if not _asyncio_internal: + raise RuntimeError("_ServerStream is a private asyncio class") + self._client_connected_cb = client_connected_cb + self._limit = limit + self._loop = events.get_running_loop() + self._streams = {} + self._shutdown_timeout = shutdown_timeout + + def __init_subclass__(cls): + if not cls.__module__.startswith('asyncio.'): + raise TypeError(f"asyncio.{cls.__name__} " + "class cannot be inherited from") + + async def bind(self): + if self._server_impl is not None: + return + self._server_impl = await self._bind() + + def is_bound(self): + return self._server_impl is not None + + @property + def sockets(self): + # multiple value for socket bound to both IPv4 and IPv6 families + if self._server_impl is None: + return () + return self._server_impl.sockets + + def is_serving(self): + if self._server_impl is None: + return False + return self._server_impl.is_serving() + + async def start_serving(self): + await self.bind() + await self._server_impl.start_serving() + + async def serve_forever(self): + await self.start_serving() + await self._server_impl.serve_forever() + + async def close(self): + if self._server_impl is None: + return + self._server_impl.close() + streams = list(self._streams.keys()) + active_tasks = list(self._streams.values()) + if streams: + await tasks.wait([stream.close() for stream in streams]) + await self._server_impl.wait_closed() + self._server_impl = None + await self._shutdown_active_tasks(active_tasks) + + async def abort(self): + if self._server_impl is None: + return + self._server_impl.close() + streams = list(self._streams.keys()) + active_tasks = list(self._streams.values()) + if streams: + await tasks.wait([stream.abort() for stream in streams]) + await self._server_impl.wait_closed() + self._server_impl = None + await self._shutdown_active_tasks(active_tasks) + + async def __aenter__(self): + await self.bind() + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + await self.close() + + def _attach(self, stream, task): + self._streams[stream] = task + + def _detach(self, stream, task): + del self._streams[stream] + + async def _shutdown_active_tasks(self, active_tasks): + if not active_tasks: + return + # NOTE: tasks finished with exception are reported + # by the Task.__del__() method. 
+ done, pending = await tasks.wait(active_tasks, + timeout=self._shutdown_timeout) + if not pending: + return + for task in pending: + task.cancel() + done, pending = await tasks.wait(pending, + timeout=self._shutdown_timeout) + for task in pending: + self._loop.call_exception_handler({ + "message": (f'{task!r} ignored cancellation request ' + f'from a closing {self!r}'), + "stream_server": self + }) + + def __repr__(self): + ret = [f'{self.__class__.__name__}'] + if self.is_serving(): + ret.append('serving') + if self.sockets: + ret.append(f'sockets={self.sockets!r}') + return '<' + ' '.join(ret) + '>' + + def __del__(self, _warn=warnings.warn): + if self._server_impl is not None: + _warn(f"unclosed stream server {self!r}", + ResourceWarning, source=self) + self._server_impl.close() + + +class StreamServer(_BaseStreamServer): + + def __init__(self, client_connected_cb, /, host=None, port=None, *, + limit=_DEFAULT_LIMIT, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, sock=None, backlog=100, + ssl=None, reuse_address=None, reuse_port=None, + ssl_handshake_timeout=None, + shutdown_timeout=60): + super().__init__(client_connected_cb, + limit=limit, + shutdown_timeout=shutdown_timeout, + _asyncio_internal=True) + self._host = host + self._port = port + self._family = family + self._flags = flags + self._sock = sock + self._backlog = backlog + self._ssl = ssl + self._reuse_address = reuse_address + self._reuse_port = reuse_port + self._ssl_handshake_timeout = ssl_handshake_timeout + + async def _bind(self): + def factory(): + protocol = _ServerStreamProtocol(self, + self._limit, + self._client_connected_cb, + loop=self._loop, + _asyncio_internal=True) + return protocol + return await self._loop.create_server( + factory, + self._host, + self._port, + start_serving=False, + family=self._family, + flags=self._flags, + sock=self._sock, + backlog=self._backlog, + ssl=self._ssl, + reuse_address=self._reuse_address, + reuse_port=self._reuse_port, + ssl_handshake_timeout=self._ssl_handshake_timeout) + + if hasattr(socket, 'AF_UNIX'): # UNIX Domain Sockets are supported on this platform async def open_unix_connection(path=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Similar to `open_connection` but works with UNIX Domain Sockets.""" + warnings.warn("open_unix_connection() is deprecated since Python 3.8 " + "in favor of connect_unix(), and scheduled for removal " + "in Python 3.10", + DeprecationWarning, + stacklevel=2) if loop is None: loop = events.get_event_loop() - reader = StreamReader(limit=limit, loop=loop, - _asyncio_internal=True) + reader = StreamReader(limit=limit, loop=loop) protocol = StreamReaderProtocol(reader, loop=loop, _asyncio_internal=True) transport, _ = await loop.create_unix_connection( lambda: protocol, path, **kwds) - writer = StreamWriter(transport, protocol, reader, loop, - _asyncio_internal=True) + writer = StreamWriter(transport, protocol, reader, loop) return reader, writer + + def connect_unix(path=None, *, + limit=_DEFAULT_LIMIT, + ssl=None, sock=None, + server_hostname=None, + ssl_handshake_timeout=None): + """Similar to `connect()` but works with UNIX Domain Sockets.""" + # Design note: + # Don't use decorator approach but exilicit non-async + # function to fail fast and explicitly + # if passed arguments don't match the function signature + return _ContextManagerHelper(_connect_unix(path, + limit, + ssl, sock, + server_hostname, + ssl_handshake_timeout)) + + + async def _connect_unix(path, + limit, + ssl, sock, + server_hostname, + 
ssl_handshake_timeout): + """Similar to `connect()` but works with UNIX Domain Sockets.""" + loop = events.get_running_loop() + stream = Stream(mode=StreamMode.READWRITE, + limit=limit, + loop=loop, + _asyncio_internal=True) + await loop.create_unix_connection( + lambda: _StreamProtocol(stream, + loop=loop, + _asyncio_internal=True), + path, + ssl=ssl, + sock=sock, + server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout) + return stream + + async def start_unix_server(client_connected_cb, path=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Similar to `start_server` but works with UNIX Domain Sockets.""" + warnings.warn("start_unix_server() is deprecated since Python 3.8 " + "in favor of UnixStreamServer(), and scheduled " + "for removal in Python 3.10", + DeprecationWarning, + stacklevel=2) if loop is None: loop = events.get_event_loop() def factory(): - reader = StreamReader(limit=limit, loop=loop, - _asyncio_internal=True) + reader = StreamReader(limit=limit, loop=loop) protocol = StreamReaderProtocol(reader, client_connected_cb, loop=loop, _asyncio_internal=True) @@ -125,6 +483,42 @@ if hasattr(socket, 'AF_UNIX'): return await loop.create_unix_server(factory, path, **kwds) + class UnixStreamServer(_BaseStreamServer): + + def __init__(self, client_connected_cb, /, path=None, *, + limit=_DEFAULT_LIMIT, + sock=None, + backlog=100, + ssl=None, + ssl_handshake_timeout=None, + shutdown_timeout=60): + super().__init__(client_connected_cb, + limit=limit, + shutdown_timeout=shutdown_timeout, + _asyncio_internal=True) + self._path = path + self._sock = sock + self._backlog = backlog + self._ssl = ssl + self._ssl_handshake_timeout = ssl_handshake_timeout + + async def _bind(self): + def factory(): + protocol = _ServerStreamProtocol(self, + self._limit, + self._client_connected_cb, + loop=self._loop, + _asyncio_internal=True) + return protocol + return await self._loop.create_unix_server( + factory, + self._path, + start_serving=False, + sock=self._sock, + backlog=self._backlog, + ssl=self._ssl, + ssl_handshake_timeout=self._ssl_handshake_timeout) + class FlowControlMixin(protocols.Protocol): """Reusable flow control logic for StreamWriter.drain(). @@ -203,6 +597,8 @@ class FlowControlMixin(protocols.Protocol): raise NotImplementedError +# begin legacy stream APIs + class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): """Helper class to adapt between Protocol and StreamReader. @@ -212,105 +608,47 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): call inappropriate methods of the protocol.) """ - _source_traceback = None - def __init__(self, stream_reader, client_connected_cb=None, loop=None, *, _asyncio_internal=False): super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) - if stream_reader is not None: - self._stream_reader_wr = weakref.ref(stream_reader, - self._on_reader_gc) - self._source_traceback = stream_reader._source_traceback - else: - self._stream_reader_wr = None - if client_connected_cb is not None: - # This is a stream created by the `create_server()` function. - # Keep a strong reference to the reader until a connection - # is established. 
- self._strong_reader = stream_reader - self._reject_connection = False + self._stream_reader = stream_reader self._stream_writer = None - self._transport = None self._client_connected_cb = client_connected_cb self._over_ssl = False self._closed = self._loop.create_future() - def _on_reader_gc(self, wr): - transport = self._transport - if transport is not None: - # connection_made was called - context = { - 'message': ('An open stream object is being garbage ' - 'collected; call "stream.close()" explicitly.') - } - if self._source_traceback: - context['source_traceback'] = self._source_traceback - self._loop.call_exception_handler(context) - transport.abort() - else: - self._reject_connection = True - self._stream_reader_wr = None - - @property - def _stream_reader(self): - if self._stream_reader_wr is None: - return None - return self._stream_reader_wr() - def connection_made(self, transport): - if self._reject_connection: - context = { - 'message': ('An open stream was garbage collected prior to ' - 'establishing network connection; ' - 'call "stream.close()" explicitly.') - } - if self._source_traceback: - context['source_traceback'] = self._source_traceback - self._loop.call_exception_handler(context) - transport.abort() - return - self._transport = transport - reader = self._stream_reader - if reader is not None: - reader.set_transport(transport) + self._stream_reader.set_transport(transport) self._over_ssl = transport.get_extra_info('sslcontext') is not None if self._client_connected_cb is not None: self._stream_writer = StreamWriter(transport, self, - reader, - self._loop, - _asyncio_internal=True) - res = self._client_connected_cb(reader, + self._stream_reader, + self._loop) + res = self._client_connected_cb(self._stream_reader, self._stream_writer) if coroutines.iscoroutine(res): self._loop.create_task(res) - self._strong_reader = None def connection_lost(self, exc): - reader = self._stream_reader - if reader is not None: + if self._stream_reader is not None: if exc is None: - reader.feed_eof() + self._stream_reader.feed_eof() else: - reader.set_exception(exc) + self._stream_reader.set_exception(exc) if not self._closed.done(): if exc is None: self._closed.set_result(None) else: self._closed.set_exception(exc) super().connection_lost(exc) - self._stream_reader_wr = None + self._stream_reader = None self._stream_writer = None - self._transport = None def data_received(self, data): - reader = self._stream_reader - if reader is not None: - reader.feed_data(data) + self._stream_reader.feed_data(data) def eof_received(self): - reader = self._stream_reader - if reader is not None: - reader.feed_eof() + self._stream_reader.feed_eof() if self._over_ssl: # Prevent a warning in SSLProtocol.eof_received: # "returning true from eof_received() @@ -318,9 +656,6 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): return False return True - def _get_close_waiter(self, stream): - return self._closed - def __del__(self): # Prevent reports about unhandled exceptions. # Better than self._closed._log_traceback = False hack @@ -329,13 +664,6 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): closed.exception() -def _swallow_unhandled_exception(task): - # Do a trick to suppress unhandled exception - # if stream.write() was used without await and - # stream.drain() was paused and resumed with an exception - task.exception() - - class StreamWriter: """Wraps a Transport. @@ -346,21 +674,13 @@ class StreamWriter: directly. 
""" - def __init__(self, transport, protocol, reader, loop, - *, _asyncio_internal=False): - if not _asyncio_internal: - warnings.warn(f"{self.__class__} should be instaniated " - "by asyncio internals only, " - "please avoid its creation from user code", - DeprecationWarning) + def __init__(self, transport, protocol, reader, loop): self._transport = transport self._protocol = protocol # drain() expects that the reader has an exception() method assert reader is None or isinstance(reader, StreamReader) self._reader = reader self._loop = loop - self._complete_fut = self._loop.create_future() - self._complete_fut.set_result(None) def __repr__(self): info = [self.__class__.__name__, f'transport={self._transport!r}'] @@ -374,35 +694,9 @@ class StreamWriter: def write(self, data): self._transport.write(data) - return self._fast_drain() def writelines(self, data): self._transport.writelines(data) - return self._fast_drain() - - def _fast_drain(self): - # The helper tries to use fast-path to return already existing complete future - # object if underlying transport is not paused and actual waiting for writing - # resume is not needed - if self._reader is not None: - # this branch will be simplified after merging reader with writer - exc = self._reader.exception() - if exc is not None: - fut = self._loop.create_future() - fut.set_exception(exc) - return fut - if not self._transport.is_closing(): - if self._protocol._connection_lost: - fut = self._loop.create_future() - fut.set_exception(ConnectionResetError('Connection lost')) - return fut - if not self._protocol._paused: - # fast path, the stream is not paused - # no need to wait for resume signal - return self._complete_fut - ret = self._loop.create_task(self.drain()) - ret.add_done_callback(_swallow_unhandled_exception) - return ret def write_eof(self): return self._transport.write_eof() @@ -411,14 +705,13 @@ class StreamWriter: return self._transport.can_write_eof() def close(self): - self._transport.close() - return self._protocol._get_close_waiter(self) + return self._transport.close() def is_closing(self): return self._transport.is_closing() async def wait_closed(self): - await self._protocol._get_close_waiter(self) + await self._protocol._closed def get_extra_info(self, name, default=None): return self._transport.get_extra_info(name, default) @@ -436,24 +729,562 @@ class StreamWriter: if exc is not None: raise exc if self._transport.is_closing(): - # Wait for protocol.connection_lost() call - # Raise connection closing error if any, - # ConnectionResetError otherwise - await sleep(0) + # Yield to the event loop so connection_lost() may be + # called. Without this, _drain_helper() would return + # immediately, and code that calls + # write(...); await drain() + # in a loop would never call connection_lost(), so it + # would not see an error when the socket is closed. + await tasks.sleep(0, loop=self._loop) await self._protocol._drain_helper() class StreamReader: + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): + # The line length limit is a security feature; + # it also doubles as half the buffer limit. + + if limit <= 0: + raise ValueError('Limit cannot be <= 0') + + self._limit = limit + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._buffer = bytearray() + self._eof = False # Whether we're done. 
+ self._waiter = None # A future used by _wait_for_data() + self._exception = None + self._transport = None + self._paused = False + + def __repr__(self): + info = ['StreamReader'] + if self._buffer: + info.append(f'{len(self._buffer)} bytes') + if self._eof: + info.append('eof') + if self._limit != _DEFAULT_LIMIT: + info.append(f'limit={self._limit}') + if self._waiter: + info.append(f'waiter={self._waiter!r}') + if self._exception: + info.append(f'exception={self._exception!r}') + if self._transport: + info.append(f'transport={self._transport!r}') + if self._paused: + info.append('paused') + return '<{}>'.format(' '.join(info)) + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_exception(exc) + + def _wakeup_waiter(self): + """Wakeup read*() functions waiting for data or EOF.""" + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_result(None) + + def set_transport(self, transport): + assert self._transport is None, 'Transport already set' + self._transport = transport + + def _maybe_resume_transport(self): + if self._paused and len(self._buffer) <= self._limit: + self._paused = False + self._transport.resume_reading() + + def feed_eof(self): + self._eof = True + self._wakeup_waiter() + + def at_eof(self): + """Return True if the buffer is empty and 'feed_eof' was called.""" + return self._eof and not self._buffer + + def feed_data(self, data): + assert not self._eof, 'feed_data after feed_eof' + + if not data: + return + + self._buffer.extend(data) + self._wakeup_waiter() + + if (self._transport is not None and + not self._paused and + len(self._buffer) > 2 * self._limit): + try: + self._transport.pause_reading() + except NotImplementedError: + # The transport can't be paused. + # We'll just have to buffer all data. + # Forget the transport so we don't keep trying. + self._transport = None + else: + self._paused = True + + async def _wait_for_data(self, func_name): + """Wait until feed_data() or feed_eof() is called. + + If stream was paused, automatically resume it. + """ + # StreamReader uses a future to link the protocol feed_data() method + # to a read coroutine. Running two read coroutines at the same time + # would have an unexpected behaviour. It would not possible to know + # which coroutine would get the next data. + if self._waiter is not None: + raise RuntimeError( + f'{func_name}() called while another coroutine is ' + f'already waiting for incoming data') + + assert not self._eof, '_wait_for_data after EOF' + + # Waiting for data while paused will make deadlock, so prevent it. + # This is essential for readexactly(n) for case when n > self._limit. + if self._paused: + self._paused = False + self._transport.resume_reading() + + self._waiter = self._loop.create_future() + try: + await self._waiter + finally: + self._waiter = None + + async def readline(self): + """Read chunk of data from the stream until newline (b'\n') is found. + + On success, return chunk that ends with newline. If only partial + line can be read due to EOF, return incomplete line without + terminating newline. When EOF was reached while no bytes read, empty + bytes object is returned. + + If limit is reached, ValueError will be raised. In that case, if + newline was found, complete line including newline will be removed + from internal buffer. 
Else, internal buffer will be cleared. Limit is + compared against part of the line without newline. + + If stream was paused, this function will automatically resume it if + needed. + """ + sep = b'\n' + seplen = len(sep) + try: + line = await self.readuntil(sep) + except exceptions.IncompleteReadError as e: + return e.partial + except exceptions.LimitOverrunError as e: + if self._buffer.startswith(sep, e.consumed): + del self._buffer[:e.consumed + seplen] + else: + self._buffer.clear() + self._maybe_resume_transport() + raise ValueError(e.args[0]) + return line + + async def readuntil(self, separator=b'\n'): + """Read data from the stream until ``separator`` is found. + + On success, the data and separator will be removed from the + internal buffer (consumed). Returned data will include the + separator at the end. + + Configured stream limit is used to check result. Limit sets the + maximal length of data that can be returned, not counting the + separator. + + If an EOF occurs and the complete separator is still not found, + an IncompleteReadError exception will be raised, and the internal + buffer will be reset. The IncompleteReadError.partial attribute + may contain the separator partially. + + If the data cannot be read because of over limit, a + LimitOverrunError exception will be raised, and the data + will be left in the internal buffer, so it can be read again. + """ + seplen = len(separator) + if seplen == 0: + raise ValueError('Separator should be at least one-byte string') + + if self._exception is not None: + raise self._exception + + # Consume whole buffer except last bytes, which length is + # one less than seplen. Let's check corner cases with + # separator='SEPARATOR': + # * we have received almost complete separator (without last + # byte). i.e buffer='some textSEPARATO'. In this case we + # can safely consume len(separator) - 1 bytes. + # * last byte of buffer is first byte of separator, i.e. + # buffer='abcdefghijklmnopqrS'. We may safely consume + # everything except that last byte, but this require to + # analyze bytes of buffer that match partial separator. + # This is slow and/or require FSM. For this case our + # implementation is not optimal, since require rescanning + # of data that is known to not belong to separator. In + # real world, separator will not be so long to notice + # performance problems. Even when reading MIME-encoded + # messages :) + + # `offset` is the number of bytes from the beginning of the buffer + # where there is no occurrence of `separator`. + offset = 0 + + # Loop until we find `separator` in the buffer, exceed the buffer size, + # or an EOF has happened. + while True: + buflen = len(self._buffer) + + # Check if we now have enough data in the buffer for `separator` to + # fit. + if buflen - offset >= seplen: + isep = self._buffer.find(separator, offset) + + if isep != -1: + # `separator` is in the buffer. `isep` will be used later + # to retrieve the data. + break + + # see upper comment for explanation. + offset = buflen + 1 - seplen + if offset > self._limit: + raise exceptions.LimitOverrunError( + 'Separator is not found, and chunk exceed the limit', + offset) + + # Complete message (with full separator) may be present in buffer + # even when EOF flag is set. This may happen when the last chunk + # adds data which makes separator be found. That's why we check for + # EOF *ater* inspecting the buffer. 
+ if self._eof: + chunk = bytes(self._buffer) + self._buffer.clear() + raise exceptions.IncompleteReadError(chunk, None) + + # _wait_for_data() will resume reading if stream was paused. + await self._wait_for_data('readuntil') + + if isep > self._limit: + raise exceptions.LimitOverrunError( + 'Separator is found, but chunk is longer than limit', isep) + + chunk = self._buffer[:isep + seplen] + del self._buffer[:isep + seplen] + self._maybe_resume_transport() + return bytes(chunk) + + async def read(self, n=-1): + """Read up to `n` bytes from the stream. + + If n is not provided, or set to -1, read until EOF and return all read + bytes. If the EOF was received and the internal buffer is empty, return + an empty bytes object. + + If n is zero, return empty bytes object immediately. + + If n is positive, this function try to read `n` bytes, and may return + less or equal bytes than requested, but at least one byte. If EOF was + received before any byte is read, this function returns empty byte + object. + + Returned value is not limited with limit, configured at stream + creation. + + If stream was paused, this function will automatically resume it if + needed. + """ + + if self._exception is not None: + raise self._exception + + if n == 0: + return b'' + + if n < 0: + # This used to just loop creating a new waiter hoping to + # collect everything in self._buffer, but that would + # deadlock if the subprocess sends more than self.limit + # bytes. So just call self.read(self._limit) until EOF. + blocks = [] + while True: + block = await self.read(self._limit) + if not block: + break + blocks.append(block) + return b''.join(blocks) + + if not self._buffer and not self._eof: + await self._wait_for_data('read') + + # This will work right even if buffer is less than n bytes + data = bytes(self._buffer[:n]) + del self._buffer[:n] + + self._maybe_resume_transport() + return data + + async def readexactly(self, n): + """Read exactly `n` bytes. + + Raise an IncompleteReadError if EOF is reached before `n` bytes can be + read. The IncompleteReadError.partial attribute of the exception will + contain the partial read bytes. + + if n is zero, return empty bytes object. + + Returned value is not limited with limit, configured at stream + creation. + + If stream was paused, this function will automatically resume it if + needed. + """ + if n < 0: + raise ValueError('readexactly size can not be less than zero') + + if self._exception is not None: + raise self._exception + + if n == 0: + return b'' + + while len(self._buffer) < n: + if self._eof: + incomplete = bytes(self._buffer) + self._buffer.clear() + raise exceptions.IncompleteReadError(incomplete, n) + + await self._wait_for_data('readexactly') + + if len(self._buffer) == n: + data = bytes(self._buffer) + self._buffer.clear() + else: + data = bytes(self._buffer[:n]) + del self._buffer[:n] + self._maybe_resume_transport() + return data + + def __aiter__(self): + return self + + async def __anext__(self): + val = await self.readline() + if val == b'': + raise StopAsyncIteration + return val + + +# end legacy stream APIs + + +class _BaseStreamProtocol(FlowControlMixin, protocols.Protocol): + """Helper class to adapt between Protocol and StreamReader. + + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) 
+ """ + + _stream = None # initialized in derived classes + + def __init__(self, loop=None, + *, _asyncio_internal=False): + super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) + self._transport = None + self._over_ssl = False + self._closed = self._loop.create_future() + + def connection_made(self, transport): + self._transport = transport + self._over_ssl = transport.get_extra_info('sslcontext') is not None + + def connection_lost(self, exc): + stream = self._stream + if stream is not None: + if exc is None: + stream.feed_eof() + else: + stream.set_exception(exc) + if not self._closed.done(): + if exc is None: + self._closed.set_result(None) + else: + self._closed.set_exception(exc) + super().connection_lost(exc) + self._transport = None + + def data_received(self, data): + stream = self._stream + if stream is not None: + stream.feed_data(data) + + def eof_received(self): + stream = self._stream + if stream is not None: + stream.feed_eof() + if self._over_ssl: + # Prevent a warning in SSLProtocol.eof_received: + # "returning true from eof_received() + # has no effect when using ssl" + return False + return True + + def _get_close_waiter(self, stream): + return self._closed + + def __del__(self): + # Prevent reports about unhandled exceptions. + # Better than self._closed._log_traceback = False hack + closed = self._get_close_waiter(self._stream) + if closed.done() and not closed.cancelled(): + closed.exception() + + +class _StreamProtocol(_BaseStreamProtocol): _source_traceback = None - def __init__(self, limit=_DEFAULT_LIMIT, loop=None, + def __init__(self, stream, loop=None, *, _asyncio_internal=False): + super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) + self._source_traceback = stream._source_traceback + self._stream_wr = weakref.ref(stream, self._on_gc) + self._reject_connection = False + + def _on_gc(self, wr): + transport = self._transport + if transport is not None: + # connection_made was called + context = { + 'message': ('An open stream object is being garbage ' + 'collected; call "stream.close()" explicitly.') + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + transport.abort() + else: + self._reject_connection = True + self._stream_wr = None + + @property + def _stream(self): + if self._stream_wr is None: + return None + return self._stream_wr() + + def connection_made(self, transport): + if self._reject_connection: + context = { + 'message': ('An open stream was garbage collected prior to ' + 'establishing network connection; ' + 'call "stream.close()" explicitly.') + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + transport.abort() + return + super().connection_made(transport) + stream = self._stream + if stream is None: + return + stream.set_transport(transport) + stream._protocol = self + + def connection_lost(self, exc): + super().connection_lost(exc) + self._stream_wr = None + + +class _ServerStreamProtocol(_BaseStreamProtocol): + def __init__(self, server, limit, client_connected_cb, loop=None, + *, _asyncio_internal=False): + super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) + assert self._closed + self._client_connected_cb = client_connected_cb + self._limit = limit + self._server = server + self._task = None + + def connection_made(self, transport): + super().connection_made(transport) + stream = Stream(mode=StreamMode.READWRITE, + transport=transport, + 
protocol=self, + limit=self._limit, + loop=self._loop, + is_server_side=True, + _asyncio_internal=True) + self._stream = stream + # If self._client_connected_cb(self._stream) fails + # the exception is logged by transport + self._task = self._loop.create_task( + self._client_connected_cb(self._stream)) + self._server._attach(stream, self._task) + + def connection_lost(self, exc): + super().connection_lost(exc) + self._server._detach(self._stream, self._task) + self._stream = None + + +class _OptionalAwait: + # The class doesn't create a coroutine + # if not awaited + # It prevents "coroutine is never awaited" message + + __slots___ = ('_method',) + + def __init__(self, method): + self._method = method + + def __await__(self): + return self._method().__await__() + + +class Stream: + """Wraps a Transport. + + This exposes write(), writelines(), [can_]write_eof(), + get_extra_info() and close(). It adds drain() which returns an + optional Future on which you can wait for flow control. It also + adds a transport property which references the Transport + directly. + """ + + _source_traceback = None + + def __init__(self, mode, *, + transport=None, + protocol=None, + loop=None, + limit=_DEFAULT_LIMIT, + is_server_side=False, + _asyncio_internal=False): if not _asyncio_internal: warnings.warn(f"{self.__class__} should be instaniated " "by asyncio internals only, " "please avoid its creation from user code", DeprecationWarning) + self._mode = mode + self._transport = transport + self._protocol = protocol + self._is_server_side = is_server_side # The line length limit is a security feature; # it also doubles as half the buffer limit. @@ -470,14 +1301,17 @@ class StreamReader: self._eof = False # Whether we're done. self._waiter = None # A future used by _wait_for_data() self._exception = None - self._transport = None self._paused = False + self._complete_fut = self._loop.create_future() + self._complete_fut.set_result(None) + if self._loop.get_debug(): self._source_traceback = format_helpers.extract_stack( sys._getframe(1)) def __repr__(self): - info = ['StreamReader'] + info = [self.__class__.__name__] + info.append(f'mode={self._mode}') if self._buffer: info.append(f'{len(self._buffer)} bytes') if self._eof: @@ -494,6 +1328,110 @@ class StreamReader: info.append('paused') return '<{}>'.format(' '.join(info)) + @property + def mode(self): + return self._mode + + def is_server_side(self): + return self._is_server_side + + @property + def transport(self): + return self._transport + + def write(self, data): + _ensure_can_write(self._mode) + self._transport.write(data) + return self._fast_drain() + + def writelines(self, data): + _ensure_can_write(self._mode) + self._transport.writelines(data) + return self._fast_drain() + + def _fast_drain(self): + # The helper tries to use fast-path to return already existing + # complete future object if underlying transport is not paused + #and actual waiting for writing resume is not needed + exc = self.exception() + if exc is not None: + fut = self._loop.create_future() + fut.set_exception(exc) + return fut + if not self._transport.is_closing(): + if self._protocol._connection_lost: + fut = self._loop.create_future() + fut.set_exception(ConnectionResetError('Connection lost')) + return fut + if not self._protocol._paused: + # fast path, the stream is not paused + # no need to wait for resume signal + return self._complete_fut + return _OptionalAwait(self.drain) + + def write_eof(self): + _ensure_can_write(self._mode) + return self._transport.write_eof() + + def 
can_write_eof(self): + if not self._mode.is_write(): + return False + return self._transport.can_write_eof() + + def close(self): + self._transport.close() + return _OptionalAwait(self.wait_closed) + + def is_closing(self): + return self._transport.is_closing() + + async def abort(self): + self._transport.abort() + await self.wait_closed() + + async def wait_closed(self): + await self._protocol._get_close_waiter(self) + + def get_extra_info(self, name, default=None): + return self._transport.get_extra_info(name, default) + + async def drain(self): + """Flush the write buffer. + + The intended use is to write + + w.write(data) + await w.drain() + """ + _ensure_can_write(self._mode) + exc = self.exception() + if exc is not None: + raise exc + if self._transport.is_closing(): + # Wait for protocol.connection_lost() call + # Raise connection closing error if any, + # ConnectionResetError otherwise + await tasks.sleep(0) + await self._protocol._drain_helper() + + async def sendfile(self, file, offset=0, count=None, *, fallback=True): + await self.drain() # check for stream mode and exceptions + return await self._loop.sendfile(self._transport, file, + offset, count, fallback=fallback) + + async def start_tls(self, sslcontext, *, + server_hostname=None, + ssl_handshake_timeout=None): + await self.drain() # check for stream mode and exceptions + transport = await self._loop.start_tls( + self._transport, self._protocol, sslcontext, + server_side=self._is_server_side, + server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout) + self._transport = transport + self._protocol._transport = transport + self._protocol._over_ssl = True + def exception(self): return self._exception @@ -515,6 +1453,8 @@ class StreamReader: waiter.set_result(None) def set_transport(self, transport): + if transport is self._transport: + return assert self._transport is None, 'Transport already set' self._transport = transport @@ -532,6 +1472,7 @@ class StreamReader: return self._eof and not self._buffer def feed_data(self, data): + _ensure_can_read(self._mode) assert not self._eof, 'feed_data after feed_eof' if not data: @@ -597,6 +1538,7 @@ class StreamReader: If stream was paused, this function will automatically resume it if needed. """ + _ensure_can_read(self._mode) sep = b'\n' seplen = len(sep) try: @@ -632,6 +1574,7 @@ class StreamReader: LimitOverrunError exception will be raised, and the data will be left in the internal buffer, so it can be read again. """ + _ensure_can_read(self._mode) seplen = len(separator) if seplen == 0: raise ValueError('Separator should be at least one-byte string') @@ -723,6 +1666,7 @@ class StreamReader: If stream was paused, this function will automatically resume it if needed. """ + _ensure_can_read(self._mode) if self._exception is not None: raise self._exception @@ -768,6 +1712,7 @@ class StreamReader: If stream was paused, this function will automatically resume it if needed. """ + _ensure_can_read(self._mode) if n < 0: raise ValueError('readexactly size can not be less than zero') @@ -795,6 +1740,7 @@ class StreamReader: return data def __aiter__(self): + _ensure_can_read(self._mode) return self async def __anext__(self): @@ -802,3 +1748,9 @@ class StreamReader: if val == b'': raise StopAsyncIteration return val + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() |
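One design point that runs through the new Stream class above: write() and close() return _OptionalAwait helpers, so both the old fire-and-forget style and the awaited style work without "coroutine was never awaited" warnings. A minimal sketch, assuming something from this patch is listening on the placeholder address:

    import asyncio
    from asyncio.streams import connect

    async def main():
        stream = await connect('127.0.0.1', 8888)   # awaiting connect() works too, without "async with"
        stream.write(b'ping\n')           # fire-and-forget: nothing needs to be awaited
        await stream.write(b'pong\n')     # awaited form also waits for drain() under flow control
        print(await stream.readline())
        await stream.close()              # awaiting close() waits for the connection teardown

    asyncio.run(main())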