summaryrefslogtreecommitdiffstats
path: root/Lib/asyncio/streams.py
diff options
context:
space:
mode:
authorAndrew Svetlov <andrew.svetlov@gmail.com>2019-05-27 19:56:22 (GMT)
committerMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>2019-05-27 19:56:22 (GMT)
commit23b4b697e5b6cc897696f9c0288c187d2d24bff2 (patch)
tree2f70e14fe527878cd69ccbefca007a1e987943ed /Lib/asyncio/streams.py
parent6f6ff8a56518a80da406aad6ac8364c046cc7f18 (diff)
downloadcpython-23b4b697e5b6cc897696f9c0288c187d2d24bff2.zip
cpython-23b4b697e5b6cc897696f9c0288c187d2d24bff2.tar.gz
cpython-23b4b697e5b6cc897696f9c0288c187d2d24bff2.tar.bz2
bpo-36889: Merge asyncio streams (GH-13251)
https://bugs.python.org/issue36889
Diffstat (limited to 'Lib/asyncio/streams.py')
-rw-r--r--Lib/asyncio/streams.py1236
1 files changed, 1094 insertions, 142 deletions
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
index 2f0cbfd..480f1a3 100644
--- a/Lib/asyncio/streams.py
+++ b/Lib/asyncio/streams.py
@@ -1,14 +1,19 @@
__all__ = (
- 'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
- 'open_connection', 'start_server')
+ 'Stream', 'StreamMode',
+ 'open_connection', 'start_server',
+ 'connect', 'connect_read_pipe', 'connect_write_pipe',
+ 'StreamServer')
+import enum
import socket
import sys
import warnings
import weakref
if hasattr(socket, 'AF_UNIX'):
- __all__ += ('open_unix_connection', 'start_unix_server')
+ __all__ += ('open_unix_connection', 'start_unix_server',
+ 'connect_unix',
+ 'UnixStreamServer')
from . import coroutines
from . import events
@@ -16,12 +21,134 @@ from . import exceptions
from . import format_helpers
from . import protocols
from .log import logger
-from .tasks import sleep
+from . import tasks
_DEFAULT_LIMIT = 2 ** 16 # 64 KiB
+class StreamMode(enum.Flag):
+ READ = enum.auto()
+ WRITE = enum.auto()
+ READWRITE = READ | WRITE
+
+
+def _ensure_can_read(mode):
+ if not mode & StreamMode.READ:
+ raise RuntimeError("The stream is write-only")
+
+
+def _ensure_can_write(mode):
+ if not mode & StreamMode.WRITE:
+ raise RuntimeError("The stream is read-only")
+
+
+class _ContextManagerHelper:
+ __slots__ = ('_awaitable', '_result')
+
+ def __init__(self, awaitable):
+ self._awaitable = awaitable
+ self._result = None
+
+ def __await__(self):
+ return self._awaitable.__await__()
+
+ async def __aenter__(self):
+ ret = await self._awaitable
+ result = await ret.__aenter__()
+ self._result = result
+ return result
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ return await self._result.__aexit__(exc_type, exc_val, exc_tb)
+
+
+def connect(host=None, port=None, *,
+ limit=_DEFAULT_LIMIT,
+ ssl=None, family=0, proto=0,
+ flags=0, sock=None, local_addr=None,
+ server_hostname=None,
+ ssl_handshake_timeout=None,
+ happy_eyeballs_delay=None, interleave=None):
+ # Design note:
+ # Don't use decorator approach but exilicit non-async
+ # function to fail fast and explicitly
+ # if passed arguments don't match the function signature
+ return _ContextManagerHelper(_connect(host, port, limit,
+ ssl, family, proto,
+ flags, sock, local_addr,
+ server_hostname,
+ ssl_handshake_timeout,
+ happy_eyeballs_delay,
+ interleave))
+
+
+async def _connect(host, port,
+ limit,
+ ssl, family, proto,
+ flags, sock, local_addr,
+ server_hostname,
+ ssl_handshake_timeout,
+ happy_eyeballs_delay, interleave):
+ loop = events.get_running_loop()
+ stream = Stream(mode=StreamMode.READWRITE,
+ limit=limit,
+ loop=loop,
+ _asyncio_internal=True)
+ await loop.create_connection(
+ lambda: _StreamProtocol(stream, loop=loop,
+ _asyncio_internal=True),
+ host, port,
+ ssl=ssl, family=family, proto=proto,
+ flags=flags, sock=sock, local_addr=local_addr,
+ server_hostname=server_hostname,
+ ssl_handshake_timeout=ssl_handshake_timeout,
+ happy_eyeballs_delay=happy_eyeballs_delay, interleave=interleave)
+ return stream
+
+
+def connect_read_pipe(pipe, *, limit=_DEFAULT_LIMIT):
+ # Design note:
+ # Don't use decorator approach but explicit non-async
+ # function to fail fast and explicitly
+ # if passed arguments don't match the function signature
+ return _ContextManagerHelper(_connect_read_pipe(pipe, limit))
+
+
+async def _connect_read_pipe(pipe, limit):
+ loop = events.get_running_loop()
+ stream = Stream(mode=StreamMode.READ,
+ limit=limit,
+ loop=loop,
+ _asyncio_internal=True)
+ await loop.connect_read_pipe(
+ lambda: _StreamProtocol(stream, loop=loop,
+ _asyncio_internal=True),
+ pipe)
+ return stream
+
+
+def connect_write_pipe(pipe, *, limit=_DEFAULT_LIMIT):
+ # Design note:
+ # Don't use decorator approach but explicit non-async
+ # function to fail fast and explicitly
+ # if passed arguments don't match the function signature
+ return _ContextManagerHelper(_connect_write_pipe(pipe, limit))
+
+
+async def _connect_write_pipe(pipe, limit):
+ loop = events.get_running_loop()
+ stream = Stream(mode=StreamMode.WRITE,
+ limit=limit,
+ loop=loop,
+ _asyncio_internal=True)
+ await loop.connect_write_pipe(
+ lambda: _StreamProtocol(stream, loop=loop,
+ _asyncio_internal=True),
+ pipe)
+ return stream
+
+
async def open_connection(host=None, port=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""A wrapper for create_connection() returning a (reader, writer) pair.
@@ -41,16 +168,18 @@ async def open_connection(host=None, port=None, *,
StreamReaderProtocol classes, just copy the code -- there's
really nothing special here except some convenience.)
"""
+ warnings.warn("open_connection() is deprecated since Python 3.8 "
+ "in favor of connect(), and scheduled for removal "
+ "in Python 3.10",
+ DeprecationWarning,
+ stacklevel=2)
if loop is None:
loop = events.get_event_loop()
- reader = StreamReader(limit=limit, loop=loop,
- _asyncio_internal=True)
- protocol = StreamReaderProtocol(reader, loop=loop,
- _asyncio_internal=True)
+ reader = StreamReader(limit=limit, loop=loop)
+ protocol = StreamReaderProtocol(reader, loop=loop, _asyncio_internal=True)
transport, _ = await loop.create_connection(
lambda: protocol, host, port, **kwds)
- writer = StreamWriter(transport, protocol, reader, loop,
- _asyncio_internal=True)
+ writer = StreamWriter(transport, protocol, reader, loop)
return reader, writer
@@ -77,12 +206,16 @@ async def start_server(client_connected_cb, host=None, port=None, *,
The return value is the same as loop.create_server(), i.e. a
Server object which can be used to stop the service.
"""
+ warnings.warn("start_server() is deprecated since Python 3.8 "
+ "in favor of StreamServer(), and scheduled for removal "
+ "in Python 3.10",
+ DeprecationWarning,
+ stacklevel=2)
if loop is None:
loop = events.get_event_loop()
def factory():
- reader = StreamReader(limit=limit, loop=loop,
- _asyncio_internal=True)
+ reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, client_connected_cb,
loop=loop,
_asyncio_internal=True)
@@ -91,33 +224,258 @@ async def start_server(client_connected_cb, host=None, port=None, *,
return await loop.create_server(factory, host, port, **kwds)
+class _BaseStreamServer:
+ # Design notes.
+ # StreamServer and UnixStreamServer are exposed as FINAL classes,
+ # not function factories.
+ # async with serve(host, port) as server:
+ # server.start_serving()
+ # looks ugly.
+ # The class doesn't provide API for enumerating connected streams
+ # It can be a subject for improvements in Python 3.9
+
+ _server_impl = None
+
+ def __init__(self, client_connected_cb,
+ /,
+ limit=_DEFAULT_LIMIT,
+ shutdown_timeout=60,
+ _asyncio_internal=False):
+ if not _asyncio_internal:
+ raise RuntimeError("_ServerStream is a private asyncio class")
+ self._client_connected_cb = client_connected_cb
+ self._limit = limit
+ self._loop = events.get_running_loop()
+ self._streams = {}
+ self._shutdown_timeout = shutdown_timeout
+
+ def __init_subclass__(cls):
+ if not cls.__module__.startswith('asyncio.'):
+ raise TypeError(f"asyncio.{cls.__name__} "
+ "class cannot be inherited from")
+
+ async def bind(self):
+ if self._server_impl is not None:
+ return
+ self._server_impl = await self._bind()
+
+ def is_bound(self):
+ return self._server_impl is not None
+
+ @property
+ def sockets(self):
+ # multiple value for socket bound to both IPv4 and IPv6 families
+ if self._server_impl is None:
+ return ()
+ return self._server_impl.sockets
+
+ def is_serving(self):
+ if self._server_impl is None:
+ return False
+ return self._server_impl.is_serving()
+
+ async def start_serving(self):
+ await self.bind()
+ await self._server_impl.start_serving()
+
+ async def serve_forever(self):
+ await self.start_serving()
+ await self._server_impl.serve_forever()
+
+ async def close(self):
+ if self._server_impl is None:
+ return
+ self._server_impl.close()
+ streams = list(self._streams.keys())
+ active_tasks = list(self._streams.values())
+ if streams:
+ await tasks.wait([stream.close() for stream in streams])
+ await self._server_impl.wait_closed()
+ self._server_impl = None
+ await self._shutdown_active_tasks(active_tasks)
+
+ async def abort(self):
+ if self._server_impl is None:
+ return
+ self._server_impl.close()
+ streams = list(self._streams.keys())
+ active_tasks = list(self._streams.values())
+ if streams:
+ await tasks.wait([stream.abort() for stream in streams])
+ await self._server_impl.wait_closed()
+ self._server_impl = None
+ await self._shutdown_active_tasks(active_tasks)
+
+ async def __aenter__(self):
+ await self.bind()
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, exc_tb):
+ await self.close()
+
+ def _attach(self, stream, task):
+ self._streams[stream] = task
+
+ def _detach(self, stream, task):
+ del self._streams[stream]
+
+ async def _shutdown_active_tasks(self, active_tasks):
+ if not active_tasks:
+ return
+ # NOTE: tasks finished with exception are reported
+ # by the Task.__del__() method.
+ done, pending = await tasks.wait(active_tasks,
+ timeout=self._shutdown_timeout)
+ if not pending:
+ return
+ for task in pending:
+ task.cancel()
+ done, pending = await tasks.wait(pending,
+ timeout=self._shutdown_timeout)
+ for task in pending:
+ self._loop.call_exception_handler({
+ "message": (f'{task!r} ignored cancellation request '
+ f'from a closing {self!r}'),
+ "stream_server": self
+ })
+
+ def __repr__(self):
+ ret = [f'{self.__class__.__name__}']
+ if self.is_serving():
+ ret.append('serving')
+ if self.sockets:
+ ret.append(f'sockets={self.sockets!r}')
+ return '<' + ' '.join(ret) + '>'
+
+ def __del__(self, _warn=warnings.warn):
+ if self._server_impl is not None:
+ _warn(f"unclosed stream server {self!r}",
+ ResourceWarning, source=self)
+ self._server_impl.close()
+
+
+class StreamServer(_BaseStreamServer):
+
+ def __init__(self, client_connected_cb, /, host=None, port=None, *,
+ limit=_DEFAULT_LIMIT,
+ family=socket.AF_UNSPEC,
+ flags=socket.AI_PASSIVE, sock=None, backlog=100,
+ ssl=None, reuse_address=None, reuse_port=None,
+ ssl_handshake_timeout=None,
+ shutdown_timeout=60):
+ super().__init__(client_connected_cb,
+ limit=limit,
+ shutdown_timeout=shutdown_timeout,
+ _asyncio_internal=True)
+ self._host = host
+ self._port = port
+ self._family = family
+ self._flags = flags
+ self._sock = sock
+ self._backlog = backlog
+ self._ssl = ssl
+ self._reuse_address = reuse_address
+ self._reuse_port = reuse_port
+ self._ssl_handshake_timeout = ssl_handshake_timeout
+
+ async def _bind(self):
+ def factory():
+ protocol = _ServerStreamProtocol(self,
+ self._limit,
+ self._client_connected_cb,
+ loop=self._loop,
+ _asyncio_internal=True)
+ return protocol
+ return await self._loop.create_server(
+ factory,
+ self._host,
+ self._port,
+ start_serving=False,
+ family=self._family,
+ flags=self._flags,
+ sock=self._sock,
+ backlog=self._backlog,
+ ssl=self._ssl,
+ reuse_address=self._reuse_address,
+ reuse_port=self._reuse_port,
+ ssl_handshake_timeout=self._ssl_handshake_timeout)
+
+
if hasattr(socket, 'AF_UNIX'):
# UNIX Domain Sockets are supported on this platform
async def open_unix_connection(path=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `open_connection` but works with UNIX Domain Sockets."""
+ warnings.warn("open_unix_connection() is deprecated since Python 3.8 "
+ "in favor of connect_unix(), and scheduled for removal "
+ "in Python 3.10",
+ DeprecationWarning,
+ stacklevel=2)
if loop is None:
loop = events.get_event_loop()
- reader = StreamReader(limit=limit, loop=loop,
- _asyncio_internal=True)
+ reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, loop=loop,
_asyncio_internal=True)
transport, _ = await loop.create_unix_connection(
lambda: protocol, path, **kwds)
- writer = StreamWriter(transport, protocol, reader, loop,
- _asyncio_internal=True)
+ writer = StreamWriter(transport, protocol, reader, loop)
return reader, writer
+
+ def connect_unix(path=None, *,
+ limit=_DEFAULT_LIMIT,
+ ssl=None, sock=None,
+ server_hostname=None,
+ ssl_handshake_timeout=None):
+ """Similar to `connect()` but works with UNIX Domain Sockets."""
+ # Design note:
+ # Don't use decorator approach but exilicit non-async
+ # function to fail fast and explicitly
+ # if passed arguments don't match the function signature
+ return _ContextManagerHelper(_connect_unix(path,
+ limit,
+ ssl, sock,
+ server_hostname,
+ ssl_handshake_timeout))
+
+
+ async def _connect_unix(path,
+ limit,
+ ssl, sock,
+ server_hostname,
+ ssl_handshake_timeout):
+ """Similar to `connect()` but works with UNIX Domain Sockets."""
+ loop = events.get_running_loop()
+ stream = Stream(mode=StreamMode.READWRITE,
+ limit=limit,
+ loop=loop,
+ _asyncio_internal=True)
+ await loop.create_unix_connection(
+ lambda: _StreamProtocol(stream,
+ loop=loop,
+ _asyncio_internal=True),
+ path,
+ ssl=ssl,
+ sock=sock,
+ server_hostname=server_hostname,
+ ssl_handshake_timeout=ssl_handshake_timeout)
+ return stream
+
+
async def start_unix_server(client_connected_cb, path=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `start_server` but works with UNIX Domain Sockets."""
+ warnings.warn("start_unix_server() is deprecated since Python 3.8 "
+ "in favor of UnixStreamServer(), and scheduled "
+ "for removal in Python 3.10",
+ DeprecationWarning,
+ stacklevel=2)
if loop is None:
loop = events.get_event_loop()
def factory():
- reader = StreamReader(limit=limit, loop=loop,
- _asyncio_internal=True)
+ reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, client_connected_cb,
loop=loop,
_asyncio_internal=True)
@@ -125,6 +483,42 @@ if hasattr(socket, 'AF_UNIX'):
return await loop.create_unix_server(factory, path, **kwds)
+ class UnixStreamServer(_BaseStreamServer):
+
+ def __init__(self, client_connected_cb, /, path=None, *,
+ limit=_DEFAULT_LIMIT,
+ sock=None,
+ backlog=100,
+ ssl=None,
+ ssl_handshake_timeout=None,
+ shutdown_timeout=60):
+ super().__init__(client_connected_cb,
+ limit=limit,
+ shutdown_timeout=shutdown_timeout,
+ _asyncio_internal=True)
+ self._path = path
+ self._sock = sock
+ self._backlog = backlog
+ self._ssl = ssl
+ self._ssl_handshake_timeout = ssl_handshake_timeout
+
+ async def _bind(self):
+ def factory():
+ protocol = _ServerStreamProtocol(self,
+ self._limit,
+ self._client_connected_cb,
+ loop=self._loop,
+ _asyncio_internal=True)
+ return protocol
+ return await self._loop.create_unix_server(
+ factory,
+ self._path,
+ start_serving=False,
+ sock=self._sock,
+ backlog=self._backlog,
+ ssl=self._ssl,
+ ssl_handshake_timeout=self._ssl_handshake_timeout)
+
class FlowControlMixin(protocols.Protocol):
"""Reusable flow control logic for StreamWriter.drain().
@@ -203,6 +597,8 @@ class FlowControlMixin(protocols.Protocol):
raise NotImplementedError
+# begin legacy stream APIs
+
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
"""Helper class to adapt between Protocol and StreamReader.
@@ -212,105 +608,47 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
call inappropriate methods of the protocol.)
"""
- _source_traceback = None
-
def __init__(self, stream_reader, client_connected_cb=None, loop=None,
*, _asyncio_internal=False):
super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
- if stream_reader is not None:
- self._stream_reader_wr = weakref.ref(stream_reader,
- self._on_reader_gc)
- self._source_traceback = stream_reader._source_traceback
- else:
- self._stream_reader_wr = None
- if client_connected_cb is not None:
- # This is a stream created by the `create_server()` function.
- # Keep a strong reference to the reader until a connection
- # is established.
- self._strong_reader = stream_reader
- self._reject_connection = False
+ self._stream_reader = stream_reader
self._stream_writer = None
- self._transport = None
self._client_connected_cb = client_connected_cb
self._over_ssl = False
self._closed = self._loop.create_future()
- def _on_reader_gc(self, wr):
- transport = self._transport
- if transport is not None:
- # connection_made was called
- context = {
- 'message': ('An open stream object is being garbage '
- 'collected; call "stream.close()" explicitly.')
- }
- if self._source_traceback:
- context['source_traceback'] = self._source_traceback
- self._loop.call_exception_handler(context)
- transport.abort()
- else:
- self._reject_connection = True
- self._stream_reader_wr = None
-
- @property
- def _stream_reader(self):
- if self._stream_reader_wr is None:
- return None
- return self._stream_reader_wr()
-
def connection_made(self, transport):
- if self._reject_connection:
- context = {
- 'message': ('An open stream was garbage collected prior to '
- 'establishing network connection; '
- 'call "stream.close()" explicitly.')
- }
- if self._source_traceback:
- context['source_traceback'] = self._source_traceback
- self._loop.call_exception_handler(context)
- transport.abort()
- return
- self._transport = transport
- reader = self._stream_reader
- if reader is not None:
- reader.set_transport(transport)
+ self._stream_reader.set_transport(transport)
self._over_ssl = transport.get_extra_info('sslcontext') is not None
if self._client_connected_cb is not None:
self._stream_writer = StreamWriter(transport, self,
- reader,
- self._loop,
- _asyncio_internal=True)
- res = self._client_connected_cb(reader,
+ self._stream_reader,
+ self._loop)
+ res = self._client_connected_cb(self._stream_reader,
self._stream_writer)
if coroutines.iscoroutine(res):
self._loop.create_task(res)
- self._strong_reader = None
def connection_lost(self, exc):
- reader = self._stream_reader
- if reader is not None:
+ if self._stream_reader is not None:
if exc is None:
- reader.feed_eof()
+ self._stream_reader.feed_eof()
else:
- reader.set_exception(exc)
+ self._stream_reader.set_exception(exc)
if not self._closed.done():
if exc is None:
self._closed.set_result(None)
else:
self._closed.set_exception(exc)
super().connection_lost(exc)
- self._stream_reader_wr = None
+ self._stream_reader = None
self._stream_writer = None
- self._transport = None
def data_received(self, data):
- reader = self._stream_reader
- if reader is not None:
- reader.feed_data(data)
+ self._stream_reader.feed_data(data)
def eof_received(self):
- reader = self._stream_reader
- if reader is not None:
- reader.feed_eof()
+ self._stream_reader.feed_eof()
if self._over_ssl:
# Prevent a warning in SSLProtocol.eof_received:
# "returning true from eof_received()
@@ -318,9 +656,6 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
return False
return True
- def _get_close_waiter(self, stream):
- return self._closed
-
def __del__(self):
# Prevent reports about unhandled exceptions.
# Better than self._closed._log_traceback = False hack
@@ -329,13 +664,6 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
closed.exception()
-def _swallow_unhandled_exception(task):
- # Do a trick to suppress unhandled exception
- # if stream.write() was used without await and
- # stream.drain() was paused and resumed with an exception
- task.exception()
-
-
class StreamWriter:
"""Wraps a Transport.
@@ -346,21 +674,13 @@ class StreamWriter:
directly.
"""
- def __init__(self, transport, protocol, reader, loop,
- *, _asyncio_internal=False):
- if not _asyncio_internal:
- warnings.warn(f"{self.__class__} should be instaniated "
- "by asyncio internals only, "
- "please avoid its creation from user code",
- DeprecationWarning)
+ def __init__(self, transport, protocol, reader, loop):
self._transport = transport
self._protocol = protocol
# drain() expects that the reader has an exception() method
assert reader is None or isinstance(reader, StreamReader)
self._reader = reader
self._loop = loop
- self._complete_fut = self._loop.create_future()
- self._complete_fut.set_result(None)
def __repr__(self):
info = [self.__class__.__name__, f'transport={self._transport!r}']
@@ -374,35 +694,9 @@ class StreamWriter:
def write(self, data):
self._transport.write(data)
- return self._fast_drain()
def writelines(self, data):
self._transport.writelines(data)
- return self._fast_drain()
-
- def _fast_drain(self):
- # The helper tries to use fast-path to return already existing complete future
- # object if underlying transport is not paused and actual waiting for writing
- # resume is not needed
- if self._reader is not None:
- # this branch will be simplified after merging reader with writer
- exc = self._reader.exception()
- if exc is not None:
- fut = self._loop.create_future()
- fut.set_exception(exc)
- return fut
- if not self._transport.is_closing():
- if self._protocol._connection_lost:
- fut = self._loop.create_future()
- fut.set_exception(ConnectionResetError('Connection lost'))
- return fut
- if not self._protocol._paused:
- # fast path, the stream is not paused
- # no need to wait for resume signal
- return self._complete_fut
- ret = self._loop.create_task(self.drain())
- ret.add_done_callback(_swallow_unhandled_exception)
- return ret
def write_eof(self):
return self._transport.write_eof()
@@ -411,14 +705,13 @@ class StreamWriter:
return self._transport.can_write_eof()
def close(self):
- self._transport.close()
- return self._protocol._get_close_waiter(self)
+ return self._transport.close()
def is_closing(self):
return self._transport.is_closing()
async def wait_closed(self):
- await self._protocol._get_close_waiter(self)
+ await self._protocol._closed
def get_extra_info(self, name, default=None):
return self._transport.get_extra_info(name, default)
@@ -436,24 +729,562 @@ class StreamWriter:
if exc is not None:
raise exc
if self._transport.is_closing():
- # Wait for protocol.connection_lost() call
- # Raise connection closing error if any,
- # ConnectionResetError otherwise
- await sleep(0)
+ # Yield to the event loop so connection_lost() may be
+ # called. Without this, _drain_helper() would return
+ # immediately, and code that calls
+ # write(...); await drain()
+ # in a loop would never call connection_lost(), so it
+ # would not see an error when the socket is closed.
+ await tasks.sleep(0, loop=self._loop)
await self._protocol._drain_helper()
class StreamReader:
+ def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
+ # The line length limit is a security feature;
+ # it also doubles as half the buffer limit.
+
+ if limit <= 0:
+ raise ValueError('Limit cannot be <= 0')
+
+ self._limit = limit
+ if loop is None:
+ self._loop = events.get_event_loop()
+ else:
+ self._loop = loop
+ self._buffer = bytearray()
+ self._eof = False # Whether we're done.
+ self._waiter = None # A future used by _wait_for_data()
+ self._exception = None
+ self._transport = None
+ self._paused = False
+
+ def __repr__(self):
+ info = ['StreamReader']
+ if self._buffer:
+ info.append(f'{len(self._buffer)} bytes')
+ if self._eof:
+ info.append('eof')
+ if self._limit != _DEFAULT_LIMIT:
+ info.append(f'limit={self._limit}')
+ if self._waiter:
+ info.append(f'waiter={self._waiter!r}')
+ if self._exception:
+ info.append(f'exception={self._exception!r}')
+ if self._transport:
+ info.append(f'transport={self._transport!r}')
+ if self._paused:
+ info.append('paused')
+ return '<{}>'.format(' '.join(info))
+
+ def exception(self):
+ return self._exception
+
+ def set_exception(self, exc):
+ self._exception = exc
+
+ waiter = self._waiter
+ if waiter is not None:
+ self._waiter = None
+ if not waiter.cancelled():
+ waiter.set_exception(exc)
+
+ def _wakeup_waiter(self):
+ """Wakeup read*() functions waiting for data or EOF."""
+ waiter = self._waiter
+ if waiter is not None:
+ self._waiter = None
+ if not waiter.cancelled():
+ waiter.set_result(None)
+
+ def set_transport(self, transport):
+ assert self._transport is None, 'Transport already set'
+ self._transport = transport
+
+ def _maybe_resume_transport(self):
+ if self._paused and len(self._buffer) <= self._limit:
+ self._paused = False
+ self._transport.resume_reading()
+
+ def feed_eof(self):
+ self._eof = True
+ self._wakeup_waiter()
+
+ def at_eof(self):
+ """Return True if the buffer is empty and 'feed_eof' was called."""
+ return self._eof and not self._buffer
+
+ def feed_data(self, data):
+ assert not self._eof, 'feed_data after feed_eof'
+
+ if not data:
+ return
+
+ self._buffer.extend(data)
+ self._wakeup_waiter()
+
+ if (self._transport is not None and
+ not self._paused and
+ len(self._buffer) > 2 * self._limit):
+ try:
+ self._transport.pause_reading()
+ except NotImplementedError:
+ # The transport can't be paused.
+ # We'll just have to buffer all data.
+ # Forget the transport so we don't keep trying.
+ self._transport = None
+ else:
+ self._paused = True
+
+ async def _wait_for_data(self, func_name):
+ """Wait until feed_data() or feed_eof() is called.
+
+ If stream was paused, automatically resume it.
+ """
+ # StreamReader uses a future to link the protocol feed_data() method
+ # to a read coroutine. Running two read coroutines at the same time
+ # would have an unexpected behaviour. It would not possible to know
+ # which coroutine would get the next data.
+ if self._waiter is not None:
+ raise RuntimeError(
+ f'{func_name}() called while another coroutine is '
+ f'already waiting for incoming data')
+
+ assert not self._eof, '_wait_for_data after EOF'
+
+ # Waiting for data while paused will make deadlock, so prevent it.
+ # This is essential for readexactly(n) for case when n > self._limit.
+ if self._paused:
+ self._paused = False
+ self._transport.resume_reading()
+
+ self._waiter = self._loop.create_future()
+ try:
+ await self._waiter
+ finally:
+ self._waiter = None
+
+ async def readline(self):
+ """Read chunk of data from the stream until newline (b'\n') is found.
+
+ On success, return chunk that ends with newline. If only partial
+ line can be read due to EOF, return incomplete line without
+ terminating newline. When EOF was reached while no bytes read, empty
+ bytes object is returned.
+
+ If limit is reached, ValueError will be raised. In that case, if
+ newline was found, complete line including newline will be removed
+ from internal buffer. Else, internal buffer will be cleared. Limit is
+ compared against part of the line without newline.
+
+ If stream was paused, this function will automatically resume it if
+ needed.
+ """
+ sep = b'\n'
+ seplen = len(sep)
+ try:
+ line = await self.readuntil(sep)
+ except exceptions.IncompleteReadError as e:
+ return e.partial
+ except exceptions.LimitOverrunError as e:
+ if self._buffer.startswith(sep, e.consumed):
+ del self._buffer[:e.consumed + seplen]
+ else:
+ self._buffer.clear()
+ self._maybe_resume_transport()
+ raise ValueError(e.args[0])
+ return line
+
+ async def readuntil(self, separator=b'\n'):
+ """Read data from the stream until ``separator`` is found.
+
+ On success, the data and separator will be removed from the
+ internal buffer (consumed). Returned data will include the
+ separator at the end.
+
+ Configured stream limit is used to check result. Limit sets the
+ maximal length of data that can be returned, not counting the
+ separator.
+
+ If an EOF occurs and the complete separator is still not found,
+ an IncompleteReadError exception will be raised, and the internal
+ buffer will be reset. The IncompleteReadError.partial attribute
+ may contain the separator partially.
+
+ If the data cannot be read because of over limit, a
+ LimitOverrunError exception will be raised, and the data
+ will be left in the internal buffer, so it can be read again.
+ """
+ seplen = len(separator)
+ if seplen == 0:
+ raise ValueError('Separator should be at least one-byte string')
+
+ if self._exception is not None:
+ raise self._exception
+
+ # Consume whole buffer except last bytes, which length is
+ # one less than seplen. Let's check corner cases with
+ # separator='SEPARATOR':
+ # * we have received almost complete separator (without last
+ # byte). i.e buffer='some textSEPARATO'. In this case we
+ # can safely consume len(separator) - 1 bytes.
+ # * last byte of buffer is first byte of separator, i.e.
+ # buffer='abcdefghijklmnopqrS'. We may safely consume
+ # everything except that last byte, but this require to
+ # analyze bytes of buffer that match partial separator.
+ # This is slow and/or require FSM. For this case our
+ # implementation is not optimal, since require rescanning
+ # of data that is known to not belong to separator. In
+ # real world, separator will not be so long to notice
+ # performance problems. Even when reading MIME-encoded
+ # messages :)
+
+ # `offset` is the number of bytes from the beginning of the buffer
+ # where there is no occurrence of `separator`.
+ offset = 0
+
+ # Loop until we find `separator` in the buffer, exceed the buffer size,
+ # or an EOF has happened.
+ while True:
+ buflen = len(self._buffer)
+
+ # Check if we now have enough data in the buffer for `separator` to
+ # fit.
+ if buflen - offset >= seplen:
+ isep = self._buffer.find(separator, offset)
+
+ if isep != -1:
+ # `separator` is in the buffer. `isep` will be used later
+ # to retrieve the data.
+ break
+
+ # see upper comment for explanation.
+ offset = buflen + 1 - seplen
+ if offset > self._limit:
+ raise exceptions.LimitOverrunError(
+ 'Separator is not found, and chunk exceed the limit',
+ offset)
+
+ # Complete message (with full separator) may be present in buffer
+ # even when EOF flag is set. This may happen when the last chunk
+ # adds data which makes separator be found. That's why we check for
+ # EOF *ater* inspecting the buffer.
+ if self._eof:
+ chunk = bytes(self._buffer)
+ self._buffer.clear()
+ raise exceptions.IncompleteReadError(chunk, None)
+
+ # _wait_for_data() will resume reading if stream was paused.
+ await self._wait_for_data('readuntil')
+
+ if isep > self._limit:
+ raise exceptions.LimitOverrunError(
+ 'Separator is found, but chunk is longer than limit', isep)
+
+ chunk = self._buffer[:isep + seplen]
+ del self._buffer[:isep + seplen]
+ self._maybe_resume_transport()
+ return bytes(chunk)
+
+ async def read(self, n=-1):
+ """Read up to `n` bytes from the stream.
+
+ If n is not provided, or set to -1, read until EOF and return all read
+ bytes. If the EOF was received and the internal buffer is empty, return
+ an empty bytes object.
+
+ If n is zero, return empty bytes object immediately.
+
+ If n is positive, this function try to read `n` bytes, and may return
+ less or equal bytes than requested, but at least one byte. If EOF was
+ received before any byte is read, this function returns empty byte
+ object.
+
+ Returned value is not limited with limit, configured at stream
+ creation.
+
+ If stream was paused, this function will automatically resume it if
+ needed.
+ """
+
+ if self._exception is not None:
+ raise self._exception
+
+ if n == 0:
+ return b''
+
+ if n < 0:
+ # This used to just loop creating a new waiter hoping to
+ # collect everything in self._buffer, but that would
+ # deadlock if the subprocess sends more than self.limit
+ # bytes. So just call self.read(self._limit) until EOF.
+ blocks = []
+ while True:
+ block = await self.read(self._limit)
+ if not block:
+ break
+ blocks.append(block)
+ return b''.join(blocks)
+
+ if not self._buffer and not self._eof:
+ await self._wait_for_data('read')
+
+ # This will work right even if buffer is less than n bytes
+ data = bytes(self._buffer[:n])
+ del self._buffer[:n]
+
+ self._maybe_resume_transport()
+ return data
+
+ async def readexactly(self, n):
+ """Read exactly `n` bytes.
+
+ Raise an IncompleteReadError if EOF is reached before `n` bytes can be
+ read. The IncompleteReadError.partial attribute of the exception will
+ contain the partial read bytes.
+
+ if n is zero, return empty bytes object.
+
+ Returned value is not limited with limit, configured at stream
+ creation.
+
+ If stream was paused, this function will automatically resume it if
+ needed.
+ """
+ if n < 0:
+ raise ValueError('readexactly size can not be less than zero')
+
+ if self._exception is not None:
+ raise self._exception
+
+ if n == 0:
+ return b''
+
+ while len(self._buffer) < n:
+ if self._eof:
+ incomplete = bytes(self._buffer)
+ self._buffer.clear()
+ raise exceptions.IncompleteReadError(incomplete, n)
+
+ await self._wait_for_data('readexactly')
+
+ if len(self._buffer) == n:
+ data = bytes(self._buffer)
+ self._buffer.clear()
+ else:
+ data = bytes(self._buffer[:n])
+ del self._buffer[:n]
+ self._maybe_resume_transport()
+ return data
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ val = await self.readline()
+ if val == b'':
+ raise StopAsyncIteration
+ return val
+
+
+# end legacy stream APIs
+
+
+class _BaseStreamProtocol(FlowControlMixin, protocols.Protocol):
+ """Helper class to adapt between Protocol and StreamReader.
+
+ (This is a helper class instead of making StreamReader itself a
+ Protocol subclass, because the StreamReader has other potential
+ uses, and to prevent the user of the StreamReader to accidentally
+ call inappropriate methods of the protocol.)
+ """
+
+ _stream = None # initialized in derived classes
+
+ def __init__(self, loop=None,
+ *, _asyncio_internal=False):
+ super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
+ self._transport = None
+ self._over_ssl = False
+ self._closed = self._loop.create_future()
+
+ def connection_made(self, transport):
+ self._transport = transport
+ self._over_ssl = transport.get_extra_info('sslcontext') is not None
+
+ def connection_lost(self, exc):
+ stream = self._stream
+ if stream is not None:
+ if exc is None:
+ stream.feed_eof()
+ else:
+ stream.set_exception(exc)
+ if not self._closed.done():
+ if exc is None:
+ self._closed.set_result(None)
+ else:
+ self._closed.set_exception(exc)
+ super().connection_lost(exc)
+ self._transport = None
+
+ def data_received(self, data):
+ stream = self._stream
+ if stream is not None:
+ stream.feed_data(data)
+
+ def eof_received(self):
+ stream = self._stream
+ if stream is not None:
+ stream.feed_eof()
+ if self._over_ssl:
+ # Prevent a warning in SSLProtocol.eof_received:
+ # "returning true from eof_received()
+ # has no effect when using ssl"
+ return False
+ return True
+
+ def _get_close_waiter(self, stream):
+ return self._closed
+
+ def __del__(self):
+ # Prevent reports about unhandled exceptions.
+ # Better than self._closed._log_traceback = False hack
+ closed = self._get_close_waiter(self._stream)
+ if closed.done() and not closed.cancelled():
+ closed.exception()
+
+
+class _StreamProtocol(_BaseStreamProtocol):
_source_traceback = None
- def __init__(self, limit=_DEFAULT_LIMIT, loop=None,
+ def __init__(self, stream, loop=None,
*, _asyncio_internal=False):
+ super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
+ self._source_traceback = stream._source_traceback
+ self._stream_wr = weakref.ref(stream, self._on_gc)
+ self._reject_connection = False
+
+ def _on_gc(self, wr):
+ transport = self._transport
+ if transport is not None:
+ # connection_made was called
+ context = {
+ 'message': ('An open stream object is being garbage '
+ 'collected; call "stream.close()" explicitly.')
+ }
+ if self._source_traceback:
+ context['source_traceback'] = self._source_traceback
+ self._loop.call_exception_handler(context)
+ transport.abort()
+ else:
+ self._reject_connection = True
+ self._stream_wr = None
+
+ @property
+ def _stream(self):
+ if self._stream_wr is None:
+ return None
+ return self._stream_wr()
+
+ def connection_made(self, transport):
+ if self._reject_connection:
+ context = {
+ 'message': ('An open stream was garbage collected prior to '
+ 'establishing network connection; '
+ 'call "stream.close()" explicitly.')
+ }
+ if self._source_traceback:
+ context['source_traceback'] = self._source_traceback
+ self._loop.call_exception_handler(context)
+ transport.abort()
+ return
+ super().connection_made(transport)
+ stream = self._stream
+ if stream is None:
+ return
+ stream.set_transport(transport)
+ stream._protocol = self
+
+ def connection_lost(self, exc):
+ super().connection_lost(exc)
+ self._stream_wr = None
+
+
+class _ServerStreamProtocol(_BaseStreamProtocol):
+ def __init__(self, server, limit, client_connected_cb, loop=None,
+ *, _asyncio_internal=False):
+ super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
+ assert self._closed
+ self._client_connected_cb = client_connected_cb
+ self._limit = limit
+ self._server = server
+ self._task = None
+
+ def connection_made(self, transport):
+ super().connection_made(transport)
+ stream = Stream(mode=StreamMode.READWRITE,
+ transport=transport,
+ protocol=self,
+ limit=self._limit,
+ loop=self._loop,
+ is_server_side=True,
+ _asyncio_internal=True)
+ self._stream = stream
+ # If self._client_connected_cb(self._stream) fails
+ # the exception is logged by transport
+ self._task = self._loop.create_task(
+ self._client_connected_cb(self._stream))
+ self._server._attach(stream, self._task)
+
+ def connection_lost(self, exc):
+ super().connection_lost(exc)
+ self._server._detach(self._stream, self._task)
+ self._stream = None
+
+
+class _OptionalAwait:
+ # The class doesn't create a coroutine
+ # if not awaited
+ # It prevents "coroutine is never awaited" message
+
+ __slots___ = ('_method',)
+
+ def __init__(self, method):
+ self._method = method
+
+ def __await__(self):
+ return self._method().__await__()
+
+
+class Stream:
+ """Wraps a Transport.
+
+ This exposes write(), writelines(), [can_]write_eof(),
+ get_extra_info() and close(). It adds drain() which returns an
+ optional Future on which you can wait for flow control. It also
+ adds a transport property which references the Transport
+ directly.
+ """
+
+ _source_traceback = None
+
+ def __init__(self, mode, *,
+ transport=None,
+ protocol=None,
+ loop=None,
+ limit=_DEFAULT_LIMIT,
+ is_server_side=False,
+ _asyncio_internal=False):
if not _asyncio_internal:
warnings.warn(f"{self.__class__} should be instaniated "
"by asyncio internals only, "
"please avoid its creation from user code",
DeprecationWarning)
+ self._mode = mode
+ self._transport = transport
+ self._protocol = protocol
+ self._is_server_side = is_server_side
# The line length limit is a security feature;
# it also doubles as half the buffer limit.
@@ -470,14 +1301,17 @@ class StreamReader:
self._eof = False # Whether we're done.
self._waiter = None # A future used by _wait_for_data()
self._exception = None
- self._transport = None
self._paused = False
+ self._complete_fut = self._loop.create_future()
+ self._complete_fut.set_result(None)
+
if self._loop.get_debug():
self._source_traceback = format_helpers.extract_stack(
sys._getframe(1))
def __repr__(self):
- info = ['StreamReader']
+ info = [self.__class__.__name__]
+ info.append(f'mode={self._mode}')
if self._buffer:
info.append(f'{len(self._buffer)} bytes')
if self._eof:
@@ -494,6 +1328,110 @@ class StreamReader:
info.append('paused')
return '<{}>'.format(' '.join(info))
+ @property
+ def mode(self):
+ return self._mode
+
+ def is_server_side(self):
+ return self._is_server_side
+
+ @property
+ def transport(self):
+ return self._transport
+
+ def write(self, data):
+ _ensure_can_write(self._mode)
+ self._transport.write(data)
+ return self._fast_drain()
+
+ def writelines(self, data):
+ _ensure_can_write(self._mode)
+ self._transport.writelines(data)
+ return self._fast_drain()
+
+ def _fast_drain(self):
+ # The helper tries to use fast-path to return already existing
+ # complete future object if underlying transport is not paused
+ #and actual waiting for writing resume is not needed
+ exc = self.exception()
+ if exc is not None:
+ fut = self._loop.create_future()
+ fut.set_exception(exc)
+ return fut
+ if not self._transport.is_closing():
+ if self._protocol._connection_lost:
+ fut = self._loop.create_future()
+ fut.set_exception(ConnectionResetError('Connection lost'))
+ return fut
+ if not self._protocol._paused:
+ # fast path, the stream is not paused
+ # no need to wait for resume signal
+ return self._complete_fut
+ return _OptionalAwait(self.drain)
+
+ def write_eof(self):
+ _ensure_can_write(self._mode)
+ return self._transport.write_eof()
+
+ def can_write_eof(self):
+ if not self._mode.is_write():
+ return False
+ return self._transport.can_write_eof()
+
+ def close(self):
+ self._transport.close()
+ return _OptionalAwait(self.wait_closed)
+
+ def is_closing(self):
+ return self._transport.is_closing()
+
+ async def abort(self):
+ self._transport.abort()
+ await self.wait_closed()
+
+ async def wait_closed(self):
+ await self._protocol._get_close_waiter(self)
+
+ def get_extra_info(self, name, default=None):
+ return self._transport.get_extra_info(name, default)
+
+ async def drain(self):
+ """Flush the write buffer.
+
+ The intended use is to write
+
+ w.write(data)
+ await w.drain()
+ """
+ _ensure_can_write(self._mode)
+ exc = self.exception()
+ if exc is not None:
+ raise exc
+ if self._transport.is_closing():
+ # Wait for protocol.connection_lost() call
+ # Raise connection closing error if any,
+ # ConnectionResetError otherwise
+ await tasks.sleep(0)
+ await self._protocol._drain_helper()
+
+ async def sendfile(self, file, offset=0, count=None, *, fallback=True):
+ await self.drain() # check for stream mode and exceptions
+ return await self._loop.sendfile(self._transport, file,
+ offset, count, fallback=fallback)
+
+ async def start_tls(self, sslcontext, *,
+ server_hostname=None,
+ ssl_handshake_timeout=None):
+ await self.drain() # check for stream mode and exceptions
+ transport = await self._loop.start_tls(
+ self._transport, self._protocol, sslcontext,
+ server_side=self._is_server_side,
+ server_hostname=server_hostname,
+ ssl_handshake_timeout=ssl_handshake_timeout)
+ self._transport = transport
+ self._protocol._transport = transport
+ self._protocol._over_ssl = True
+
def exception(self):
return self._exception
@@ -515,6 +1453,8 @@ class StreamReader:
waiter.set_result(None)
def set_transport(self, transport):
+ if transport is self._transport:
+ return
assert self._transport is None, 'Transport already set'
self._transport = transport
@@ -532,6 +1472,7 @@ class StreamReader:
return self._eof and not self._buffer
def feed_data(self, data):
+ _ensure_can_read(self._mode)
assert not self._eof, 'feed_data after feed_eof'
if not data:
@@ -597,6 +1538,7 @@ class StreamReader:
If stream was paused, this function will automatically resume it if
needed.
"""
+ _ensure_can_read(self._mode)
sep = b'\n'
seplen = len(sep)
try:
@@ -632,6 +1574,7 @@ class StreamReader:
LimitOverrunError exception will be raised, and the data
will be left in the internal buffer, so it can be read again.
"""
+ _ensure_can_read(self._mode)
seplen = len(separator)
if seplen == 0:
raise ValueError('Separator should be at least one-byte string')
@@ -723,6 +1666,7 @@ class StreamReader:
If stream was paused, this function will automatically resume it if
needed.
"""
+ _ensure_can_read(self._mode)
if self._exception is not None:
raise self._exception
@@ -768,6 +1712,7 @@ class StreamReader:
If stream was paused, this function will automatically resume it if
needed.
"""
+ _ensure_can_read(self._mode)
if n < 0:
raise ValueError('readexactly size can not be less than zero')
@@ -795,6 +1740,7 @@ class StreamReader:
return data
def __aiter__(self):
+ _ensure_can_read(self._mode)
return self
async def __anext__(self):
@@ -802,3 +1748,9 @@ class StreamReader:
if val == b'':
raise StopAsyncIteration
return val
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.close()