summaryrefslogtreecommitdiffstats
path: root/Lib/asyncio/streams.py
diff options
context:
space:
mode:
authorYury Selivanov <yury@edgedb.com>2019-09-30 04:59:55 (GMT)
committerGitHub <noreply@github.com>2019-09-30 04:59:55 (GMT)
commit6758e6e12a71ef5530146161881f88df1fa43382 (patch)
treeda1f89f35e54ddcfffc3706b87bb13f54907f7ea /Lib/asyncio/streams.py
parent3667e1ee6c90e6d3b6a745cd590ece87118f81ad (diff)
downloadcpython-6758e6e12a71ef5530146161881f88df1fa43382.zip
cpython-6758e6e12a71ef5530146161881f88df1fa43382.tar.gz
cpython-6758e6e12a71ef5530146161881f88df1fa43382.tar.bz2
bpo-38242: Revert "bpo-36889: Merge asyncio streams (GH-13251)" (#16482)
See https://bugs.python.org/issue38242 for more details
Diffstat (limited to 'Lib/asyncio/streams.py')
-rw-r--r--Lib/asyncio/streams.py1252
1 files changed, 91 insertions, 1161 deletions
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
index b709dc1..795530e 100644
--- a/Lib/asyncio/streams.py
+++ b/Lib/asyncio/streams.py
@@ -1,19 +1,14 @@
__all__ = (
- 'Stream', 'StreamMode',
- 'open_connection', 'start_server',
- 'connect', 'connect_read_pipe', 'connect_write_pipe',
- 'StreamServer')
+ 'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
+ 'open_connection', 'start_server')
-import enum
import socket
import sys
import warnings
import weakref
if hasattr(socket, 'AF_UNIX'):
- __all__ += ('open_unix_connection', 'start_unix_server',
- 'connect_unix',
- 'UnixStreamServer')
+ __all__ += ('open_unix_connection', 'start_unix_server')
from . import coroutines
from . import events
@@ -21,155 +16,12 @@ from . import exceptions
from . import format_helpers
from . import protocols
from .log import logger
-from . import tasks
+from .tasks import sleep
_DEFAULT_LIMIT = 2 ** 16 # 64 KiB
-class StreamMode(enum.Flag):
- READ = enum.auto()
- WRITE = enum.auto()
- READWRITE = READ | WRITE
-
-
-def _ensure_can_read(mode):
- if not mode & StreamMode.READ:
- raise RuntimeError("The stream is write-only")
-
-
-def _ensure_can_write(mode):
- if not mode & StreamMode.WRITE:
- raise RuntimeError("The stream is read-only")
-
-
-class _ContextManagerHelper:
- __slots__ = ('_awaitable', '_result')
-
- def __init__(self, awaitable):
- self._awaitable = awaitable
- self._result = None
-
- def __await__(self):
- return self._awaitable.__await__()
-
- async def __aenter__(self):
- ret = await self._awaitable
- result = await ret.__aenter__()
- self._result = result
- return result
-
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- return await self._result.__aexit__(exc_type, exc_val, exc_tb)
-
-
-def connect(host=None, port=None, *,
- limit=_DEFAULT_LIMIT,
- ssl=None, family=0, proto=0,
- flags=0, sock=None, local_addr=None,
- server_hostname=None,
- ssl_handshake_timeout=None,
- happy_eyeballs_delay=None, interleave=None):
- """Connect to TCP socket on *host* : *port* address to send and receive data.
-
- *limit* determines the buffer size limit used by the returned `Stream`
- instance. By default the *limit* is set to 64 KiB.
-
- The rest of the arguments are passed directly to `loop.create_connection()`.
- """
- # Design note:
- # Don't use decorator approach but explicit non-async
- # function to fail fast and explicitly
- # if passed arguments don't match the function signature
- return _ContextManagerHelper(_connect(host, port, limit,
- ssl, family, proto,
- flags, sock, local_addr,
- server_hostname,
- ssl_handshake_timeout,
- happy_eyeballs_delay,
- interleave))
-
-
-async def _connect(host, port,
- limit,
- ssl, family, proto,
- flags, sock, local_addr,
- server_hostname,
- ssl_handshake_timeout,
- happy_eyeballs_delay, interleave):
- loop = events.get_running_loop()
- stream = Stream(mode=StreamMode.READWRITE,
- limit=limit,
- loop=loop,
- _asyncio_internal=True)
- await loop.create_connection(
- lambda: _StreamProtocol(stream, loop=loop,
- _asyncio_internal=True),
- host, port,
- ssl=ssl, family=family, proto=proto,
- flags=flags, sock=sock, local_addr=local_addr,
- server_hostname=server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout,
- happy_eyeballs_delay=happy_eyeballs_delay, interleave=interleave)
- return stream
-
-
-def connect_read_pipe(pipe, *, limit=_DEFAULT_LIMIT):
- """Establish a connection to a file-like object *pipe* to receive data.
-
- Takes a file-like object *pipe* to return a Stream object of the mode
- StreamMode.READ that has similar API of StreamReader. It can also be used
- as an async context manager.
- """
-
- # Design note:
- # Don't use decorator approach but explicit non-async
- # function to fail fast and explicitly
- # if passed arguments don't match the function signature
- return _ContextManagerHelper(_connect_read_pipe(pipe, limit))
-
-
-async def _connect_read_pipe(pipe, limit):
- loop = events.get_running_loop()
- stream = Stream(mode=StreamMode.READ,
- limit=limit,
- loop=loop,
- _asyncio_internal=True)
- await loop.connect_read_pipe(
- lambda: _StreamProtocol(stream, loop=loop,
- _asyncio_internal=True),
- pipe)
- return stream
-
-
-def connect_write_pipe(pipe, *, limit=_DEFAULT_LIMIT):
- """Establish a connection to a file-like object *pipe* to send data.
-
- Takes a file-like object *pipe* to return a Stream object of the mode
- StreamMode.WRITE that has similar API of StreamWriter. It can also be used
- as an async context manager.
- """
-
- # Design note:
- # Don't use decorator approach but explicit non-async
- # function to fail fast and explicitly
- # if passed arguments don't match the function signature
- return _ContextManagerHelper(_connect_write_pipe(pipe, limit))
-
-
-async def _connect_write_pipe(pipe, limit):
- loop = events.get_running_loop()
- stream = Stream(mode=StreamMode.WRITE,
- limit=limit,
- loop=loop,
- _asyncio_internal=True)
- await loop.connect_write_pipe(
- lambda: _StreamProtocol(stream, loop=loop,
- _asyncio_internal=True),
- pipe)
- return stream
-
-
async def open_connection(host=None, port=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""A wrapper for create_connection() returning a (reader, writer) pair.
@@ -189,11 +41,6 @@ async def open_connection(host=None, port=None, *,
StreamReaderProtocol classes, just copy the code -- there's
really nothing special here except some convenience.)
"""
- warnings.warn("open_connection() is deprecated since Python 3.8 "
- "in favor of connect(), and scheduled for removal "
- "in Python 3.10",
- DeprecationWarning,
- stacklevel=2)
if loop is None:
loop = events.get_event_loop()
else:
@@ -201,7 +48,7 @@ async def open_connection(host=None, port=None, *,
"and scheduled for removal in Python 3.10.",
DeprecationWarning, stacklevel=2)
reader = StreamReader(limit=limit, loop=loop)
- protocol = StreamReaderProtocol(reader, loop=loop, _asyncio_internal=True)
+ protocol = StreamReaderProtocol(reader, loop=loop)
transport, _ = await loop.create_connection(
lambda: protocol, host, port, **kwds)
writer = StreamWriter(transport, protocol, reader, loop)
@@ -231,11 +78,6 @@ async def start_server(client_connected_cb, host=None, port=None, *,
The return value is the same as loop.create_server(), i.e. a
Server object which can be used to stop the service.
"""
- warnings.warn("start_server() is deprecated since Python 3.8 "
- "in favor of StreamServer(), and scheduled for removal "
- "in Python 3.10",
- DeprecationWarning,
- stacklevel=2)
if loop is None:
loop = events.get_event_loop()
else:
@@ -246,201 +88,18 @@ async def start_server(client_connected_cb, host=None, port=None, *,
def factory():
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, client_connected_cb,
- loop=loop,
- _asyncio_internal=True)
+ loop=loop)
return protocol
return await loop.create_server(factory, host, port, **kwds)
-class _BaseStreamServer:
- # Design notes.
- # StreamServer and UnixStreamServer are exposed as FINAL classes,
- # not function factories.
- # async with serve(host, port) as server:
- # server.start_serving()
- # looks ugly.
- # The class doesn't provide API for enumerating connected streams
- # It can be a subject for improvements in Python 3.9
-
- _server_impl = None
-
- def __init__(self, client_connected_cb,
- /,
- limit=_DEFAULT_LIMIT,
- shutdown_timeout=60,
- _asyncio_internal=False):
- if not _asyncio_internal:
- raise RuntimeError("_ServerStream is a private asyncio class")
- self._client_connected_cb = client_connected_cb
- self._limit = limit
- self._loop = events.get_running_loop()
- self._streams = {}
- self._shutdown_timeout = shutdown_timeout
-
- def __init_subclass__(cls):
- if not cls.__module__.startswith('asyncio.'):
- raise TypeError(f"asyncio.{cls.__name__} "
- "class cannot be inherited from")
-
- async def bind(self):
- if self._server_impl is not None:
- return
- self._server_impl = await self._bind()
-
- def is_bound(self):
- return self._server_impl is not None
-
- @property
- def sockets(self):
- # multiple value for socket bound to both IPv4 and IPv6 families
- if self._server_impl is None:
- return ()
- return self._server_impl.sockets
-
- def is_serving(self):
- if self._server_impl is None:
- return False
- return self._server_impl.is_serving()
-
- async def start_serving(self):
- await self.bind()
- await self._server_impl.start_serving()
-
- async def serve_forever(self):
- await self.start_serving()
- await self._server_impl.serve_forever()
-
- async def close(self):
- if self._server_impl is None:
- return
- self._server_impl.close()
- streams = list(self._streams.keys())
- active_tasks = list(self._streams.values())
- if streams:
- await tasks.wait([stream.close() for stream in streams])
- await self._server_impl.wait_closed()
- self._server_impl = None
- await self._shutdown_active_tasks(active_tasks)
-
- async def abort(self):
- if self._server_impl is None:
- return
- self._server_impl.close()
- streams = list(self._streams.keys())
- active_tasks = list(self._streams.values())
- if streams:
- await tasks.wait([stream.abort() for stream in streams])
- await self._server_impl.wait_closed()
- self._server_impl = None
- await self._shutdown_active_tasks(active_tasks)
-
- async def __aenter__(self):
- await self.bind()
- return self
-
- async def __aexit__(self, exc_type, exc_value, exc_tb):
- await self.close()
-
- def _attach(self, stream, task):
- self._streams[stream] = task
-
- def _detach(self, stream, task):
- del self._streams[stream]
-
- async def _shutdown_active_tasks(self, active_tasks):
- if not active_tasks:
- return
- # NOTE: tasks finished with exception are reported
- # by the Task.__del__() method.
- done, pending = await tasks.wait(active_tasks,
- timeout=self._shutdown_timeout)
- if not pending:
- return
- for task in pending:
- task.cancel()
- done, pending = await tasks.wait(pending,
- timeout=self._shutdown_timeout)
- for task in pending:
- self._loop.call_exception_handler({
- "message": (f'{task!r} ignored cancellation request '
- f'from a closing {self!r}'),
- "stream_server": self
- })
-
- def __repr__(self):
- ret = [f'{self.__class__.__name__}']
- if self.is_serving():
- ret.append('serving')
- if self.sockets:
- ret.append(f'sockets={self.sockets!r}')
- return '<' + ' '.join(ret) + '>'
-
- def __del__(self, _warn=warnings.warn):
- if self._server_impl is not None:
- _warn(f"unclosed stream server {self!r}",
- ResourceWarning, source=self)
- self._server_impl.close()
-
-
-class StreamServer(_BaseStreamServer):
-
- def __init__(self, client_connected_cb, /, host=None, port=None, *,
- limit=_DEFAULT_LIMIT,
- family=socket.AF_UNSPEC,
- flags=socket.AI_PASSIVE, sock=None, backlog=100,
- ssl=None, reuse_address=None, reuse_port=None,
- ssl_handshake_timeout=None,
- shutdown_timeout=60):
- super().__init__(client_connected_cb,
- limit=limit,
- shutdown_timeout=shutdown_timeout,
- _asyncio_internal=True)
- self._host = host
- self._port = port
- self._family = family
- self._flags = flags
- self._sock = sock
- self._backlog = backlog
- self._ssl = ssl
- self._reuse_address = reuse_address
- self._reuse_port = reuse_port
- self._ssl_handshake_timeout = ssl_handshake_timeout
-
- async def _bind(self):
- def factory():
- protocol = _ServerStreamProtocol(self,
- self._limit,
- self._client_connected_cb,
- loop=self._loop,
- _asyncio_internal=True)
- return protocol
- return await self._loop.create_server(
- factory,
- self._host,
- self._port,
- start_serving=False,
- family=self._family,
- flags=self._flags,
- sock=self._sock,
- backlog=self._backlog,
- ssl=self._ssl,
- reuse_address=self._reuse_address,
- reuse_port=self._reuse_port,
- ssl_handshake_timeout=self._ssl_handshake_timeout)
-
-
if hasattr(socket, 'AF_UNIX'):
# UNIX Domain Sockets are supported on this platform
async def open_unix_connection(path=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `open_connection` but works with UNIX Domain Sockets."""
- warnings.warn("open_unix_connection() is deprecated since Python 3.8 "
- "in favor of connect_unix(), and scheduled for removal "
- "in Python 3.10",
- DeprecationWarning,
- stacklevel=2)
if loop is None:
loop = events.get_event_loop()
else:
@@ -448,62 +107,15 @@ if hasattr(socket, 'AF_UNIX'):
"and scheduled for removal in Python 3.10.",
DeprecationWarning, stacklevel=2)
reader = StreamReader(limit=limit, loop=loop)
- protocol = StreamReaderProtocol(reader, loop=loop,
- _asyncio_internal=True)
+ protocol = StreamReaderProtocol(reader, loop=loop)
transport, _ = await loop.create_unix_connection(
lambda: protocol, path, **kwds)
writer = StreamWriter(transport, protocol, reader, loop)
return reader, writer
-
- def connect_unix(path=None, *,
- limit=_DEFAULT_LIMIT,
- ssl=None, sock=None,
- server_hostname=None,
- ssl_handshake_timeout=None):
- """Similar to `connect()` but works with UNIX Domain Sockets."""
- # Design note:
- # Don't use decorator approach but explicit non-async
- # function to fail fast and explicitly
- # if passed arguments don't match the function signature
- return _ContextManagerHelper(_connect_unix(path,
- limit,
- ssl, sock,
- server_hostname,
- ssl_handshake_timeout))
-
-
- async def _connect_unix(path,
- limit,
- ssl, sock,
- server_hostname,
- ssl_handshake_timeout):
- """Similar to `connect()` but works with UNIX Domain Sockets."""
- loop = events.get_running_loop()
- stream = Stream(mode=StreamMode.READWRITE,
- limit=limit,
- loop=loop,
- _asyncio_internal=True)
- await loop.create_unix_connection(
- lambda: _StreamProtocol(stream,
- loop=loop,
- _asyncio_internal=True),
- path,
- ssl=ssl,
- sock=sock,
- server_hostname=server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout)
- return stream
-
-
async def start_unix_server(client_connected_cb, path=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `start_server` but works with UNIX Domain Sockets."""
- warnings.warn("start_unix_server() is deprecated since Python 3.8 "
- "in favor of UnixStreamServer(), and scheduled "
- "for removal in Python 3.10",
- DeprecationWarning,
- stacklevel=2)
if loop is None:
loop = events.get_event_loop()
else:
@@ -514,48 +126,11 @@ if hasattr(socket, 'AF_UNIX'):
def factory():
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, client_connected_cb,
- loop=loop,
- _asyncio_internal=True)
+ loop=loop)
return protocol
return await loop.create_unix_server(factory, path, **kwds)
- class UnixStreamServer(_BaseStreamServer):
-
- def __init__(self, client_connected_cb, /, path=None, *,
- limit=_DEFAULT_LIMIT,
- sock=None,
- backlog=100,
- ssl=None,
- ssl_handshake_timeout=None,
- shutdown_timeout=60):
- super().__init__(client_connected_cb,
- limit=limit,
- shutdown_timeout=shutdown_timeout,
- _asyncio_internal=True)
- self._path = path
- self._sock = sock
- self._backlog = backlog
- self._ssl = ssl
- self._ssl_handshake_timeout = ssl_handshake_timeout
-
- async def _bind(self):
- def factory():
- protocol = _ServerStreamProtocol(self,
- self._limit,
- self._client_connected_cb,
- loop=self._loop,
- _asyncio_internal=True)
- return protocol
- return await self._loop.create_unix_server(
- factory,
- self._path,
- start_serving=False,
- sock=self._sock,
- backlog=self._backlog,
- ssl=self._ssl,
- ssl_handshake_timeout=self._ssl_handshake_timeout)
-
class FlowControlMixin(protocols.Protocol):
"""Reusable flow control logic for StreamWriter.drain().
@@ -567,20 +142,11 @@ class FlowControlMixin(protocols.Protocol):
StreamWriter.drain() must wait for _drain_helper() coroutine.
"""
- def __init__(self, loop=None, *, _asyncio_internal=False):
+ def __init__(self, loop=None):
if loop is None:
self._loop = events.get_event_loop()
else:
self._loop = loop
- if not _asyncio_internal:
- # NOTE:
- # Avoid inheritance from FlowControlMixin
- # Copy-paste the code to your project
- # if you need flow control helpers
- warnings.warn(f"{self.__class__} should be instantiated "
- "by asyncio internals only, "
- "please avoid its creation from user code",
- DeprecationWarning)
self._paused = False
self._drain_waiter = None
self._connection_lost = False
@@ -634,8 +200,6 @@ class FlowControlMixin(protocols.Protocol):
raise NotImplementedError
-# begin legacy stream APIs
-
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
"""Helper class to adapt between Protocol and StreamReader.
@@ -645,47 +209,103 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
call inappropriate methods of the protocol.)
"""
- def __init__(self, stream_reader, client_connected_cb=None, loop=None,
- *, _asyncio_internal=False):
- super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
- self._stream_reader = stream_reader
+ _source_traceback = None
+
+ def __init__(self, stream_reader, client_connected_cb=None, loop=None):
+ super().__init__(loop=loop)
+ if stream_reader is not None:
+ self._stream_reader_wr = weakref.ref(stream_reader,
+ self._on_reader_gc)
+ self._source_traceback = stream_reader._source_traceback
+ else:
+ self._stream_reader_wr = None
+ if client_connected_cb is not None:
+ # This is a stream created by the `create_server()` function.
+ # Keep a strong reference to the reader until a connection
+ # is established.
+ self._strong_reader = stream_reader
+ self._reject_connection = False
self._stream_writer = None
+ self._transport = None
self._client_connected_cb = client_connected_cb
self._over_ssl = False
self._closed = self._loop.create_future()
+ def _on_reader_gc(self, wr):
+ transport = self._transport
+ if transport is not None:
+ # connection_made was called
+ context = {
+ 'message': ('An open stream object is being garbage '
+ 'collected; call "stream.close()" explicitly.')
+ }
+ if self._source_traceback:
+ context['source_traceback'] = self._source_traceback
+ self._loop.call_exception_handler(context)
+ transport.abort()
+ else:
+ self._reject_connection = True
+ self._stream_reader_wr = None
+
+ @property
+ def _stream_reader(self):
+ if self._stream_reader_wr is None:
+ return None
+ return self._stream_reader_wr()
+
def connection_made(self, transport):
- self._stream_reader.set_transport(transport)
+ if self._reject_connection:
+ context = {
+ 'message': ('An open stream was garbage collected prior to '
+ 'establishing network connection; '
+ 'call "stream.close()" explicitly.')
+ }
+ if self._source_traceback:
+ context['source_traceback'] = self._source_traceback
+ self._loop.call_exception_handler(context)
+ transport.abort()
+ return
+ self._transport = transport
+ reader = self._stream_reader
+ if reader is not None:
+ reader.set_transport(transport)
self._over_ssl = transport.get_extra_info('sslcontext') is not None
if self._client_connected_cb is not None:
self._stream_writer = StreamWriter(transport, self,
- self._stream_reader,
+ reader,
self._loop)
- res = self._client_connected_cb(self._stream_reader,
+ res = self._client_connected_cb(reader,
self._stream_writer)
if coroutines.iscoroutine(res):
self._loop.create_task(res)
+ self._strong_reader = None
def connection_lost(self, exc):
- if self._stream_reader is not None:
+ reader = self._stream_reader
+ if reader is not None:
if exc is None:
- self._stream_reader.feed_eof()
+ reader.feed_eof()
else:
- self._stream_reader.set_exception(exc)
+ reader.set_exception(exc)
if not self._closed.done():
if exc is None:
self._closed.set_result(None)
else:
self._closed.set_exception(exc)
super().connection_lost(exc)
- self._stream_reader = None
+ self._stream_reader_wr = None
self._stream_writer = None
+ self._transport = None
def data_received(self, data):
- self._stream_reader.feed_data(data)
+ reader = self._stream_reader
+ if reader is not None:
+ reader.feed_data(data)
def eof_received(self):
- self._stream_reader.feed_eof()
+ reader = self._stream_reader
+ if reader is not None:
+ reader.feed_eof()
if self._over_ssl:
# Prevent a warning in SSLProtocol.eof_received:
# "returning true from eof_received()
@@ -693,6 +313,9 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
return False
return True
+ def _get_close_waiter(self, stream):
+ return self._closed
+
def __del__(self):
# Prevent reports about unhandled exceptions.
# Better than self._closed._log_traceback = False hack
@@ -718,6 +341,8 @@ class StreamWriter:
assert reader is None or isinstance(reader, StreamReader)
self._reader = reader
self._loop = loop
+ self._complete_fut = self._loop.create_future()
+ self._complete_fut.set_result(None)
def __repr__(self):
info = [self.__class__.__name__, f'transport={self._transport!r}']
@@ -748,7 +373,7 @@ class StreamWriter:
return self._transport.is_closing()
async def wait_closed(self):
- await self._protocol._closed
+ await self._protocol._get_close_waiter(self)
def get_extra_info(self, name, default=None):
return self._transport.get_extra_info(name, default)
@@ -766,561 +391,24 @@ class StreamWriter:
if exc is not None:
raise exc
if self._transport.is_closing():
+ # Wait for protocol.connection_lost() call
+ # Raise connection closing error if any,
+ # ConnectionResetError otherwise
# Yield to the event loop so connection_lost() may be
# called. Without this, _drain_helper() would return
# immediately, and code that calls
# write(...); await drain()
# in a loop would never call connection_lost(), so it
# would not see an error when the socket is closed.
- await tasks.sleep(0, loop=self._loop)
+ await sleep(0)
await self._protocol._drain_helper()
class StreamReader:
- def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
- # The line length limit is a security feature;
- # it also doubles as half the buffer limit.
-
- if limit <= 0:
- raise ValueError('Limit cannot be <= 0')
-
- self._limit = limit
- if loop is None:
- self._loop = events.get_event_loop()
- else:
- self._loop = loop
- self._buffer = bytearray()
- self._eof = False # Whether we're done.
- self._waiter = None # A future used by _wait_for_data()
- self._exception = None
- self._transport = None
- self._paused = False
-
- def __repr__(self):
- info = ['StreamReader']
- if self._buffer:
- info.append(f'{len(self._buffer)} bytes')
- if self._eof:
- info.append('eof')
- if self._limit != _DEFAULT_LIMIT:
- info.append(f'limit={self._limit}')
- if self._waiter:
- info.append(f'waiter={self._waiter!r}')
- if self._exception:
- info.append(f'exception={self._exception!r}')
- if self._transport:
- info.append(f'transport={self._transport!r}')
- if self._paused:
- info.append('paused')
- return '<{}>'.format(' '.join(info))
-
- def exception(self):
- return self._exception
-
- def set_exception(self, exc):
- self._exception = exc
-
- waiter = self._waiter
- if waiter is not None:
- self._waiter = None
- if not waiter.cancelled():
- waiter.set_exception(exc)
-
- def _wakeup_waiter(self):
- """Wakeup read*() functions waiting for data or EOF."""
- waiter = self._waiter
- if waiter is not None:
- self._waiter = None
- if not waiter.cancelled():
- waiter.set_result(None)
-
- def set_transport(self, transport):
- assert self._transport is None, 'Transport already set'
- self._transport = transport
-
- def _maybe_resume_transport(self):
- if self._paused and len(self._buffer) <= self._limit:
- self._paused = False
- self._transport.resume_reading()
-
- def feed_eof(self):
- self._eof = True
- self._wakeup_waiter()
-
- def at_eof(self):
- """Return True if the buffer is empty and 'feed_eof' was called."""
- return self._eof and not self._buffer
-
- def feed_data(self, data):
- assert not self._eof, 'feed_data after feed_eof'
-
- if not data:
- return
-
- self._buffer.extend(data)
- self._wakeup_waiter()
-
- if (self._transport is not None and
- not self._paused and
- len(self._buffer) > 2 * self._limit):
- try:
- self._transport.pause_reading()
- except NotImplementedError:
- # The transport can't be paused.
- # We'll just have to buffer all data.
- # Forget the transport so we don't keep trying.
- self._transport = None
- else:
- self._paused = True
-
- async def _wait_for_data(self, func_name):
- """Wait until feed_data() or feed_eof() is called.
-
- If stream was paused, automatically resume it.
- """
- # StreamReader uses a future to link the protocol feed_data() method
- # to a read coroutine. Running two read coroutines at the same time
- # would have an unexpected behaviour. It would not possible to know
- # which coroutine would get the next data.
- if self._waiter is not None:
- raise RuntimeError(
- f'{func_name}() called while another coroutine is '
- f'already waiting for incoming data')
-
- assert not self._eof, '_wait_for_data after EOF'
-
- # Waiting for data while paused will make deadlock, so prevent it.
- # This is essential for readexactly(n) for case when n > self._limit.
- if self._paused:
- self._paused = False
- self._transport.resume_reading()
-
- self._waiter = self._loop.create_future()
- try:
- await self._waiter
- finally:
- self._waiter = None
-
- async def readline(self):
- """Read chunk of data from the stream until newline (b'\n') is found.
-
- On success, return chunk that ends with newline. If only partial
- line can be read due to EOF, return incomplete line without
- terminating newline. When EOF was reached while no bytes read, empty
- bytes object is returned.
-
- If limit is reached, ValueError will be raised. In that case, if
- newline was found, complete line including newline will be removed
- from internal buffer. Else, internal buffer will be cleared. Limit is
- compared against part of the line without newline.
-
- If stream was paused, this function will automatically resume it if
- needed.
- """
- sep = b'\n'
- seplen = len(sep)
- try:
- line = await self.readuntil(sep)
- except exceptions.IncompleteReadError as e:
- return e.partial
- except exceptions.LimitOverrunError as e:
- if self._buffer.startswith(sep, e.consumed):
- del self._buffer[:e.consumed + seplen]
- else:
- self._buffer.clear()
- self._maybe_resume_transport()
- raise ValueError(e.args[0])
- return line
-
- async def readuntil(self, separator=b'\n'):
- """Read data from the stream until ``separator`` is found.
-
- On success, the data and separator will be removed from the
- internal buffer (consumed). Returned data will include the
- separator at the end.
-
- Configured stream limit is used to check result. Limit sets the
- maximal length of data that can be returned, not counting the
- separator.
-
- If an EOF occurs and the complete separator is still not found,
- an IncompleteReadError exception will be raised, and the internal
- buffer will be reset. The IncompleteReadError.partial attribute
- may contain the separator partially.
-
- If the data cannot be read because of over limit, a
- LimitOverrunError exception will be raised, and the data
- will be left in the internal buffer, so it can be read again.
- """
- seplen = len(separator)
- if seplen == 0:
- raise ValueError('Separator should be at least one-byte string')
-
- if self._exception is not None:
- raise self._exception
-
- # Consume whole buffer except last bytes, which length is
- # one less than seplen. Let's check corner cases with
- # separator='SEPARATOR':
- # * we have received almost complete separator (without last
- # byte). i.e buffer='some textSEPARATO'. In this case we
- # can safely consume len(separator) - 1 bytes.
- # * last byte of buffer is first byte of separator, i.e.
- # buffer='abcdefghijklmnopqrS'. We may safely consume
- # everything except that last byte, but this require to
- # analyze bytes of buffer that match partial separator.
- # This is slow and/or require FSM. For this case our
- # implementation is not optimal, since require rescanning
- # of data that is known to not belong to separator. In
- # real world, separator will not be so long to notice
- # performance problems. Even when reading MIME-encoded
- # messages :)
-
- # `offset` is the number of bytes from the beginning of the buffer
- # where there is no occurrence of `separator`.
- offset = 0
-
- # Loop until we find `separator` in the buffer, exceed the buffer size,
- # or an EOF has happened.
- while True:
- buflen = len(self._buffer)
-
- # Check if we now have enough data in the buffer for `separator` to
- # fit.
- if buflen - offset >= seplen:
- isep = self._buffer.find(separator, offset)
-
- if isep != -1:
- # `separator` is in the buffer. `isep` will be used later
- # to retrieve the data.
- break
-
- # see upper comment for explanation.
- offset = buflen + 1 - seplen
- if offset > self._limit:
- raise exceptions.LimitOverrunError(
- 'Separator is not found, and chunk exceed the limit',
- offset)
-
- # Complete message (with full separator) may be present in buffer
- # even when EOF flag is set. This may happen when the last chunk
- # adds data which makes separator be found. That's why we check for
- # EOF *ater* inspecting the buffer.
- if self._eof:
- chunk = bytes(self._buffer)
- self._buffer.clear()
- raise exceptions.IncompleteReadError(chunk, None)
-
- # _wait_for_data() will resume reading if stream was paused.
- await self._wait_for_data('readuntil')
-
- if isep > self._limit:
- raise exceptions.LimitOverrunError(
- 'Separator is found, but chunk is longer than limit', isep)
-
- chunk = self._buffer[:isep + seplen]
- del self._buffer[:isep + seplen]
- self._maybe_resume_transport()
- return bytes(chunk)
-
- async def read(self, n=-1):
- """Read up to `n` bytes from the stream.
-
- If n is not provided, or set to -1, read until EOF and return all read
- bytes. If the EOF was received and the internal buffer is empty, return
- an empty bytes object.
-
- If n is zero, return empty bytes object immediately.
-
- If n is positive, this function try to read `n` bytes, and may return
- less or equal bytes than requested, but at least one byte. If EOF was
- received before any byte is read, this function returns empty byte
- object.
-
- Returned value is not limited with limit, configured at stream
- creation.
-
- If stream was paused, this function will automatically resume it if
- needed.
- """
-
- if self._exception is not None:
- raise self._exception
-
- if n == 0:
- return b''
-
- if n < 0:
- # This used to just loop creating a new waiter hoping to
- # collect everything in self._buffer, but that would
- # deadlock if the subprocess sends more than self.limit
- # bytes. So just call self.read(self._limit) until EOF.
- blocks = []
- while True:
- block = await self.read(self._limit)
- if not block:
- break
- blocks.append(block)
- return b''.join(blocks)
-
- if not self._buffer and not self._eof:
- await self._wait_for_data('read')
-
- # This will work right even if buffer is less than n bytes
- data = bytes(self._buffer[:n])
- del self._buffer[:n]
-
- self._maybe_resume_transport()
- return data
-
- async def readexactly(self, n):
- """Read exactly `n` bytes.
-
- Raise an IncompleteReadError if EOF is reached before `n` bytes can be
- read. The IncompleteReadError.partial attribute of the exception will
- contain the partial read bytes.
-
- if n is zero, return empty bytes object.
-
- Returned value is not limited with limit, configured at stream
- creation.
-
- If stream was paused, this function will automatically resume it if
- needed.
- """
- if n < 0:
- raise ValueError('readexactly size can not be less than zero')
-
- if self._exception is not None:
- raise self._exception
-
- if n == 0:
- return b''
-
- while len(self._buffer) < n:
- if self._eof:
- incomplete = bytes(self._buffer)
- self._buffer.clear()
- raise exceptions.IncompleteReadError(incomplete, n)
-
- await self._wait_for_data('readexactly')
-
- if len(self._buffer) == n:
- data = bytes(self._buffer)
- self._buffer.clear()
- else:
- data = bytes(self._buffer[:n])
- del self._buffer[:n]
- self._maybe_resume_transport()
- return data
-
- def __aiter__(self):
- return self
-
- async def __anext__(self):
- val = await self.readline()
- if val == b'':
- raise StopAsyncIteration
- return val
-
-
-# end legacy stream APIs
-
-
-class _BaseStreamProtocol(FlowControlMixin, protocols.Protocol):
- """Helper class to adapt between Protocol and StreamReader.
-
- (This is a helper class instead of making StreamReader itself a
- Protocol subclass, because the StreamReader has other potential
- uses, and to prevent the user of the StreamReader to accidentally
- call inappropriate methods of the protocol.)
- """
-
- _stream = None # initialized in derived classes
-
- def __init__(self, loop=None,
- *, _asyncio_internal=False):
- super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
- self._transport = None
- self._over_ssl = False
- self._closed = self._loop.create_future()
-
- def connection_made(self, transport):
- self._transport = transport
- self._over_ssl = transport.get_extra_info('sslcontext') is not None
-
- def connection_lost(self, exc):
- stream = self._stream
- if stream is not None:
- if exc is None:
- stream._feed_eof()
- else:
- stream._set_exception(exc)
- if not self._closed.done():
- if exc is None:
- self._closed.set_result(None)
- else:
- self._closed.set_exception(exc)
- super().connection_lost(exc)
- self._transport = None
-
- def data_received(self, data):
- stream = self._stream
- if stream is not None:
- stream._feed_data(data)
-
- def eof_received(self):
- stream = self._stream
- if stream is not None:
- stream._feed_eof()
- if self._over_ssl:
- # Prevent a warning in SSLProtocol.eof_received:
- # "returning true from eof_received()
- # has no effect when using ssl"
- return False
- return True
-
- def _get_close_waiter(self, stream):
- return self._closed
-
- def __del__(self):
- # Prevent reports about unhandled exceptions.
- # Better than self._closed._log_traceback = False hack
- closed = self._get_close_waiter(self._stream)
- if closed.done() and not closed.cancelled():
- closed.exception()
-
-
-class _StreamProtocol(_BaseStreamProtocol):
- _source_traceback = None
-
- def __init__(self, stream, loop=None,
- *, _asyncio_internal=False):
- super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
- self._source_traceback = stream._source_traceback
- self._stream_wr = weakref.ref(stream, self._on_gc)
- self._reject_connection = False
-
- def _on_gc(self, wr):
- transport = self._transport
- if transport is not None:
- # connection_made was called
- context = {
- 'message': ('An open stream object is being garbage '
- 'collected; call "stream.close()" explicitly.')
- }
- if self._source_traceback:
- context['source_traceback'] = self._source_traceback
- self._loop.call_exception_handler(context)
- transport.abort()
- else:
- self._reject_connection = True
- self._stream_wr = None
-
- @property
- def _stream(self):
- if self._stream_wr is None:
- return None
- return self._stream_wr()
-
- def connection_made(self, transport):
- if self._reject_connection:
- context = {
- 'message': ('An open stream was garbage collected prior to '
- 'establishing network connection; '
- 'call "stream.close()" explicitly.')
- }
- if self._source_traceback:
- context['source_traceback'] = self._source_traceback
- self._loop.call_exception_handler(context)
- transport.abort()
- return
- super().connection_made(transport)
- stream = self._stream
- if stream is None:
- return
- stream._set_transport(transport)
- stream._protocol = self
-
- def connection_lost(self, exc):
- super().connection_lost(exc)
- self._stream_wr = None
-
-
-class _ServerStreamProtocol(_BaseStreamProtocol):
- def __init__(self, server, limit, client_connected_cb, loop=None,
- *, _asyncio_internal=False):
- super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
- assert self._closed
- self._client_connected_cb = client_connected_cb
- self._limit = limit
- self._server = server
- self._task = None
-
- def connection_made(self, transport):
- super().connection_made(transport)
- stream = Stream(mode=StreamMode.READWRITE,
- transport=transport,
- protocol=self,
- limit=self._limit,
- loop=self._loop,
- is_server_side=True,
- _asyncio_internal=True)
- self._stream = stream
- # If self._client_connected_cb(self._stream) fails
- # the exception is logged by transport
- self._task = self._loop.create_task(
- self._client_connected_cb(self._stream))
- self._server._attach(stream, self._task)
-
- def connection_lost(self, exc):
- super().connection_lost(exc)
- self._server._detach(self._stream, self._task)
- self._stream = None
-
-
-class _OptionalAwait:
- # The class doesn't create a coroutine
- # if not awaited
- # It prevents "coroutine is never awaited" message
-
- __slots___ = ('_method',)
-
- def __init__(self, method):
- self._method = method
-
- def __await__(self):
- return self._method().__await__()
-
-
-class Stream:
- """Wraps a Transport.
-
- This exposes write(), writelines(), [can_]write_eof(),
- get_extra_info() and close(). It adds drain() which returns an
- optional Future on which you can wait for flow control. It also
- adds a transport property which references the Transport
- directly.
- """
-
_source_traceback = None
- def __init__(self, mode, *,
- transport=None,
- protocol=None,
- loop=None,
- limit=_DEFAULT_LIMIT,
- is_server_side=False,
- _asyncio_internal=False):
- if not _asyncio_internal:
- raise RuntimeError(f"{self.__class__} should be instantiated "
- "by asyncio internals only")
- self._mode = mode
- self._transport = transport
- self._protocol = protocol
- self._is_server_side = is_server_side
-
+ def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
# The line length limit is a security feature;
# it also doubles as half the buffer limit.
@@ -1336,17 +424,14 @@ class Stream:
self._eof = False # Whether we're done.
self._waiter = None # A future used by _wait_for_data()
self._exception = None
+ self._transport = None
self._paused = False
- self._complete_fut = self._loop.create_future()
- self._complete_fut.set_result(None)
-
if self._loop.get_debug():
self._source_traceback = format_helpers.extract_stack(
sys._getframe(1))
def __repr__(self):
- info = [self.__class__.__name__]
- info.append(f'mode={self._mode}')
+ info = ['StreamReader']
if self._buffer:
info.append(f'{len(self._buffer)} bytes')
if self._eof:
@@ -1363,127 +448,10 @@ class Stream:
info.append('paused')
return '<{}>'.format(' '.join(info))
- @property
- def mode(self):
- return self._mode
-
- def is_server_side(self):
- return self._is_server_side
-
- @property
- def transport(self):
- warnings.warn("Stream.transport attribute is deprecated "
- "since Python 3.8 and is scheduled for removal in 3.10; "
- "it is an internal API",
- DeprecationWarning,
- stacklevel=2)
- return self._transport
-
- def write(self, data):
- _ensure_can_write(self._mode)
- self._transport.write(data)
- return self._fast_drain()
-
- def writelines(self, data):
- _ensure_can_write(self._mode)
- self._transport.writelines(data)
- return self._fast_drain()
-
- def _fast_drain(self):
- # The helper tries to use fast-path to return already existing
- # complete future object if underlying transport is not paused
- # and actual waiting for writing resume is not needed
- exc = self.exception()
- if exc is not None:
- fut = self._loop.create_future()
- fut.set_exception(exc)
- return fut
- if not self._transport.is_closing():
- if self._protocol._connection_lost:
- fut = self._loop.create_future()
- fut.set_exception(ConnectionResetError('Connection lost'))
- return fut
- if not self._protocol._paused:
- # fast path, the stream is not paused
- # no need to wait for resume signal
- return self._complete_fut
- return _OptionalAwait(self.drain)
-
- def write_eof(self):
- _ensure_can_write(self._mode)
- return self._transport.write_eof()
-
- def can_write_eof(self):
- if not self._mode.is_write():
- return False
- return self._transport.can_write_eof()
-
- def close(self):
- self._transport.close()
- return _OptionalAwait(self.wait_closed)
-
- def is_closing(self):
- return self._transport.is_closing()
-
- async def abort(self):
- self._transport.abort()
- await self.wait_closed()
-
- async def wait_closed(self):
- await self._protocol._get_close_waiter(self)
-
- def get_extra_info(self, name, default=None):
- return self._transport.get_extra_info(name, default)
-
- async def drain(self):
- """Flush the write buffer.
-
- The intended use is to write
-
- w.write(data)
- await w.drain()
- """
- _ensure_can_write(self._mode)
- exc = self.exception()
- if exc is not None:
- raise exc
- if self._transport.is_closing():
- # Wait for protocol.connection_lost() call
- # Raise connection closing error if any,
- # ConnectionResetError otherwise
- await tasks.sleep(0)
- await self._protocol._drain_helper()
-
- async def sendfile(self, file, offset=0, count=None, *, fallback=True):
- await self.drain() # check for stream mode and exceptions
- return await self._loop.sendfile(self._transport, file,
- offset, count, fallback=fallback)
-
- async def start_tls(self, sslcontext, *,
- server_hostname=None,
- ssl_handshake_timeout=None):
- await self.drain() # check for stream mode and exceptions
- transport = await self._loop.start_tls(
- self._transport, self._protocol, sslcontext,
- server_side=self._is_server_side,
- server_hostname=server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout)
- self._transport = transport
- self._protocol._transport = transport
- self._protocol._over_ssl = True
-
def exception(self):
return self._exception
def set_exception(self, exc):
- warnings.warn("Stream.set_exception() is deprecated "
- "since Python 3.8 and is scheduled for removal in 3.10; "
- "it is an internal API",
- DeprecationWarning,
- stacklevel=2)
- self._set_exception(exc)
-
- def _set_exception(self, exc):
self._exception = exc
waiter = self._waiter
@@ -1501,16 +469,6 @@ class Stream:
waiter.set_result(None)
def set_transport(self, transport):
- warnings.warn("Stream.set_transport() is deprecated "
- "since Python 3.8 and is scheduled for removal in 3.10; "
- "it is an internal API",
- DeprecationWarning,
- stacklevel=2)
- self._set_transport(transport)
-
- def _set_transport(self, transport):
- if transport is self._transport:
- return
assert self._transport is None, 'Transport already set'
self._transport = transport
@@ -1520,14 +478,6 @@ class Stream:
self._transport.resume_reading()
def feed_eof(self):
- warnings.warn("Stream.feed_eof() is deprecated "
- "since Python 3.8 and is scheduled for removal in 3.10; "
- "it is an internal API",
- DeprecationWarning,
- stacklevel=2)
- self._feed_eof()
-
- def _feed_eof(self):
self._eof = True
self._wakeup_waiter()
@@ -1536,15 +486,6 @@ class Stream:
return self._eof and not self._buffer
def feed_data(self, data):
- warnings.warn("Stream.feed_data() is deprecated "
- "since Python 3.8 and is scheduled for removal in 3.10; "
- "it is an internal API",
- DeprecationWarning,
- stacklevel=2)
- self._feed_data(data)
-
- def _feed_data(self, data):
- _ensure_can_read(self._mode)
assert not self._eof, 'feed_data after feed_eof'
if not data:
@@ -1610,7 +551,6 @@ class Stream:
If stream was paused, this function will automatically resume it if
needed.
"""
- _ensure_can_read(self._mode)
sep = b'\n'
seplen = len(sep)
try:
@@ -1646,7 +586,6 @@ class Stream:
LimitOverrunError exception will be raised, and the data
will be left in the internal buffer, so it can be read again.
"""
- _ensure_can_read(self._mode)
seplen = len(separator)
if seplen == 0:
raise ValueError('Separator should be at least one-byte string')
@@ -1738,7 +677,6 @@ class Stream:
If stream was paused, this function will automatically resume it if
needed.
"""
- _ensure_can_read(self._mode)
if self._exception is not None:
raise self._exception
@@ -1784,7 +722,6 @@ class Stream:
If stream was paused, this function will automatically resume it if
needed.
"""
- _ensure_can_read(self._mode)
if n < 0:
raise ValueError('readexactly size can not be less than zero')
@@ -1812,7 +749,6 @@ class Stream:
return data
def __aiter__(self):
- _ensure_can_read(self._mode)
return self
async def __anext__(self):
@@ -1820,9 +756,3 @@ class Stream:
if val == b'':
raise StopAsyncIteration
return val
-
- async def __aenter__(self):
- return self
-
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- await self.close()