summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorAndrew Svetlov <andrew.svetlov@gmail.com>2019-05-27 19:56:22 (GMT)
committerMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>2019-05-27 19:56:22 (GMT)
commit23b4b697e5b6cc897696f9c0288c187d2d24bff2 (patch)
tree2f70e14fe527878cd69ccbefca007a1e987943ed /Lib
parent6f6ff8a56518a80da406aad6ac8364c046cc7f18 (diff)
downloadcpython-23b4b697e5b6cc897696f9c0288c187d2d24bff2.zip
cpython-23b4b697e5b6cc897696f9c0288c187d2d24bff2.tar.gz
cpython-23b4b697e5b6cc897696f9c0288c187d2d24bff2.tar.bz2
bpo-36889: Merge asyncio streams (GH-13251)
https://bugs.python.org/issue36889
Diffstat (limited to 'Lib')
-rw-r--r--Lib/asyncio/__init__.py38
-rw-r--r--Lib/asyncio/streams.py1236
-rw-r--r--Lib/asyncio/subprocess.py35
-rw-r--r--Lib/asyncio/windows_events.py2
-rw-r--r--Lib/test/test___all__.py36
-rw-r--r--Lib/test/test_asyncio/test_base_events.py5
-rw-r--r--Lib/test/test_asyncio/test_buffered_proto.py7
-rw-r--r--Lib/test/test_asyncio/test_pep492.py4
-rw-r--r--Lib/test/test_asyncio/test_server.py10
-rw-r--r--Lib/test/test_asyncio/test_sslproto.py37
-rw-r--r--Lib/test/test_asyncio/test_streams.py1008
-rw-r--r--Lib/test/test_asyncio/test_windows_events.py14
12 files changed, 2049 insertions, 383 deletions
diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py
index 28c2e2c..a6a29db 100644
--- a/Lib/asyncio/__init__.py
+++ b/Lib/asyncio/__init__.py
@@ -3,6 +3,7 @@
# flake8: noqa
import sys
+import warnings
# This relies on each of the submodules having an __all__ variable.
from .base_events import *
@@ -43,3 +44,40 @@ if sys.platform == 'win32': # pragma: no cover
else:
from .unix_events import * # pragma: no cover
__all__ += unix_events.__all__
+
+
+__all__ += ('StreamReader', 'StreamWriter', 'StreamReaderProtocol') # deprecated
+
+
+def __getattr__(name):
+ global StreamReader, StreamWriter, StreamReaderProtocol
+ if name == 'StreamReader':
+ warnings.warn("StreamReader is deprecated since Python 3.8 "
+ "in favor of Stream, and scheduled for removal "
+ "in Python 3.10",
+ DeprecationWarning,
+ stacklevel=2)
+ from .streams import StreamReader as sr
+ StreamReader = sr
+ return StreamReader
+ if name == 'StreamWriter':
+ warnings.warn("StreamWriter is deprecated since Python 3.8 "
+ "in favor of Stream, and scheduled for removal "
+ "in Python 3.10",
+ DeprecationWarning,
+ stacklevel=2)
+ from .streams import StreamWriter as sw
+ StreamWriter = sw
+ return StreamWriter
+ if name == 'StreamReaderProtocol':
+ warnings.warn("Using asyncio internal class StreamReaderProtocol "
+ "is deprecated since Python 3.8 "
+ " and scheduled for removal "
+ "in Python 3.10",
+ DeprecationWarning,
+ stacklevel=2)
+ from .streams import StreamReaderProtocol as srp
+ StreamReaderProtocol = srp
+ return StreamReaderProtocol
+
+ raise AttributeError(f"module {__name__} has no attribute {name}")
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
index 2f0cbfd..480f1a3 100644
--- a/Lib/asyncio/streams.py
+++ b/Lib/asyncio/streams.py
@@ -1,14 +1,19 @@
__all__ = (
- 'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
- 'open_connection', 'start_server')
+ 'Stream', 'StreamMode',
+ 'open_connection', 'start_server',
+ 'connect', 'connect_read_pipe', 'connect_write_pipe',
+ 'StreamServer')
+import enum
import socket
import sys
import warnings
import weakref
if hasattr(socket, 'AF_UNIX'):
- __all__ += ('open_unix_connection', 'start_unix_server')
+ __all__ += ('open_unix_connection', 'start_unix_server',
+ 'connect_unix',
+ 'UnixStreamServer')
from . import coroutines
from . import events
@@ -16,12 +21,134 @@ from . import exceptions
from . import format_helpers
from . import protocols
from .log import logger
-from .tasks import sleep
+from . import tasks
_DEFAULT_LIMIT = 2 ** 16 # 64 KiB
+class StreamMode(enum.Flag):
+ READ = enum.auto()
+ WRITE = enum.auto()
+ READWRITE = READ | WRITE
+
+
+def _ensure_can_read(mode):
+ if not mode & StreamMode.READ:
+ raise RuntimeError("The stream is write-only")
+
+
+def _ensure_can_write(mode):
+ if not mode & StreamMode.WRITE:
+ raise RuntimeError("The stream is read-only")
+
+
+class _ContextManagerHelper:
+ __slots__ = ('_awaitable', '_result')
+
+ def __init__(self, awaitable):
+ self._awaitable = awaitable
+ self._result = None
+
+ def __await__(self):
+ return self._awaitable.__await__()
+
+ async def __aenter__(self):
+ ret = await self._awaitable
+ result = await ret.__aenter__()
+ self._result = result
+ return result
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ return await self._result.__aexit__(exc_type, exc_val, exc_tb)
+
+
+def connect(host=None, port=None, *,
+ limit=_DEFAULT_LIMIT,
+ ssl=None, family=0, proto=0,
+ flags=0, sock=None, local_addr=None,
+ server_hostname=None,
+ ssl_handshake_timeout=None,
+ happy_eyeballs_delay=None, interleave=None):
+ # Design note:
+ # Don't use decorator approach but exilicit non-async
+ # function to fail fast and explicitly
+ # if passed arguments don't match the function signature
+ return _ContextManagerHelper(_connect(host, port, limit,
+ ssl, family, proto,
+ flags, sock, local_addr,
+ server_hostname,
+ ssl_handshake_timeout,
+ happy_eyeballs_delay,
+ interleave))
+
+
+async def _connect(host, port,
+ limit,
+ ssl, family, proto,
+ flags, sock, local_addr,
+ server_hostname,
+ ssl_handshake_timeout,
+ happy_eyeballs_delay, interleave):
+ loop = events.get_running_loop()
+ stream = Stream(mode=StreamMode.READWRITE,
+ limit=limit,
+ loop=loop,
+ _asyncio_internal=True)
+ await loop.create_connection(
+ lambda: _StreamProtocol(stream, loop=loop,
+ _asyncio_internal=True),
+ host, port,
+ ssl=ssl, family=family, proto=proto,
+ flags=flags, sock=sock, local_addr=local_addr,
+ server_hostname=server_hostname,
+ ssl_handshake_timeout=ssl_handshake_timeout,
+ happy_eyeballs_delay=happy_eyeballs_delay, interleave=interleave)
+ return stream
+
+
+def connect_read_pipe(pipe, *, limit=_DEFAULT_LIMIT):
+ # Design note:
+ # Don't use decorator approach but explicit non-async
+ # function to fail fast and explicitly
+ # if passed arguments don't match the function signature
+ return _ContextManagerHelper(_connect_read_pipe(pipe, limit))
+
+
+async def _connect_read_pipe(pipe, limit):
+ loop = events.get_running_loop()
+ stream = Stream(mode=StreamMode.READ,
+ limit=limit,
+ loop=loop,
+ _asyncio_internal=True)
+ await loop.connect_read_pipe(
+ lambda: _StreamProtocol(stream, loop=loop,
+ _asyncio_internal=True),
+ pipe)
+ return stream
+
+
+def connect_write_pipe(pipe, *, limit=_DEFAULT_LIMIT):
+ # Design note:
+ # Don't use decorator approach but explicit non-async
+ # function to fail fast and explicitly
+ # if passed arguments don't match the function signature
+ return _ContextManagerHelper(_connect_write_pipe(pipe, limit))
+
+
+async def _connect_write_pipe(pipe, limit):
+ loop = events.get_running_loop()
+ stream = Stream(mode=StreamMode.WRITE,
+ limit=limit,
+ loop=loop,
+ _asyncio_internal=True)
+ await loop.connect_write_pipe(
+ lambda: _StreamProtocol(stream, loop=loop,
+ _asyncio_internal=True),
+ pipe)
+ return stream
+
+
async def open_connection(host=None, port=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""A wrapper for create_connection() returning a (reader, writer) pair.
@@ -41,16 +168,18 @@ async def open_connection(host=None, port=None, *,
StreamReaderProtocol classes, just copy the code -- there's
really nothing special here except some convenience.)
"""
+ warnings.warn("open_connection() is deprecated since Python 3.8 "
+ "in favor of connect(), and scheduled for removal "
+ "in Python 3.10",
+ DeprecationWarning,
+ stacklevel=2)
if loop is None:
loop = events.get_event_loop()
- reader = StreamReader(limit=limit, loop=loop,
- _asyncio_internal=True)
- protocol = StreamReaderProtocol(reader, loop=loop,
- _asyncio_internal=True)
+ reader = StreamReader(limit=limit, loop=loop)
+ protocol = StreamReaderProtocol(reader, loop=loop, _asyncio_internal=True)
transport, _ = await loop.create_connection(
lambda: protocol, host, port, **kwds)
- writer = StreamWriter(transport, protocol, reader, loop,
- _asyncio_internal=True)
+ writer = StreamWriter(transport, protocol, reader, loop)
return reader, writer
@@ -77,12 +206,16 @@ async def start_server(client_connected_cb, host=None, port=None, *,
The return value is the same as loop.create_server(), i.e. a
Server object which can be used to stop the service.
"""
+ warnings.warn("start_server() is deprecated since Python 3.8 "
+ "in favor of StreamServer(), and scheduled for removal "
+ "in Python 3.10",
+ DeprecationWarning,
+ stacklevel=2)
if loop is None:
loop = events.get_event_loop()
def factory():
- reader = StreamReader(limit=limit, loop=loop,
- _asyncio_internal=True)
+ reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, client_connected_cb,
loop=loop,
_asyncio_internal=True)
@@ -91,33 +224,258 @@ async def start_server(client_connected_cb, host=None, port=None, *,
return await loop.create_server(factory, host, port, **kwds)
+class _BaseStreamServer:
+ # Design notes.
+ # StreamServer and UnixStreamServer are exposed as FINAL classes,
+ # not function factories.
+ # async with serve(host, port) as server:
+ # server.start_serving()
+ # looks ugly.
+ # The class doesn't provide API for enumerating connected streams
+ # It can be a subject for improvements in Python 3.9
+
+ _server_impl = None
+
+ def __init__(self, client_connected_cb,
+ /,
+ limit=_DEFAULT_LIMIT,
+ shutdown_timeout=60,
+ _asyncio_internal=False):
+ if not _asyncio_internal:
+ raise RuntimeError("_ServerStream is a private asyncio class")
+ self._client_connected_cb = client_connected_cb
+ self._limit = limit
+ self._loop = events.get_running_loop()
+ self._streams = {}
+ self._shutdown_timeout = shutdown_timeout
+
+ def __init_subclass__(cls):
+ if not cls.__module__.startswith('asyncio.'):
+ raise TypeError(f"asyncio.{cls.__name__} "
+ "class cannot be inherited from")
+
+ async def bind(self):
+ if self._server_impl is not None:
+ return
+ self._server_impl = await self._bind()
+
+ def is_bound(self):
+ return self._server_impl is not None
+
+ @property
+ def sockets(self):
+ # multiple value for socket bound to both IPv4 and IPv6 families
+ if self._server_impl is None:
+ return ()
+ return self._server_impl.sockets
+
+ def is_serving(self):
+ if self._server_impl is None:
+ return False
+ return self._server_impl.is_serving()
+
+ async def start_serving(self):
+ await self.bind()
+ await self._server_impl.start_serving()
+
+ async def serve_forever(self):
+ await self.start_serving()
+ await self._server_impl.serve_forever()
+
+ async def close(self):
+ if self._server_impl is None:
+ return
+ self._server_impl.close()
+ streams = list(self._streams.keys())
+ active_tasks = list(self._streams.values())
+ if streams:
+ await tasks.wait([stream.close() for stream in streams])
+ await self._server_impl.wait_closed()
+ self._server_impl = None
+ await self._shutdown_active_tasks(active_tasks)
+
+ async def abort(self):
+ if self._server_impl is None:
+ return
+ self._server_impl.close()
+ streams = list(self._streams.keys())
+ active_tasks = list(self._streams.values())
+ if streams:
+ await tasks.wait([stream.abort() for stream in streams])
+ await self._server_impl.wait_closed()
+ self._server_impl = None
+ await self._shutdown_active_tasks(active_tasks)
+
+ async def __aenter__(self):
+ await self.bind()
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, exc_tb):
+ await self.close()
+
+ def _attach(self, stream, task):
+ self._streams[stream] = task
+
+ def _detach(self, stream, task):
+ del self._streams[stream]
+
+ async def _shutdown_active_tasks(self, active_tasks):
+ if not active_tasks:
+ return
+ # NOTE: tasks finished with exception are reported
+ # by the Task.__del__() method.
+ done, pending = await tasks.wait(active_tasks,
+ timeout=self._shutdown_timeout)
+ if not pending:
+ return
+ for task in pending:
+ task.cancel()
+ done, pending = await tasks.wait(pending,
+ timeout=self._shutdown_timeout)
+ for task in pending:
+ self._loop.call_exception_handler({
+ "message": (f'{task!r} ignored cancellation request '
+ f'from a closing {self!r}'),
+ "stream_server": self
+ })
+
+ def __repr__(self):
+ ret = [f'{self.__class__.__name__}']
+ if self.is_serving():
+ ret.append('serving')
+ if self.sockets:
+ ret.append(f'sockets={self.sockets!r}')
+ return '<' + ' '.join(ret) + '>'
+
+ def __del__(self, _warn=warnings.warn):
+ if self._server_impl is not None:
+ _warn(f"unclosed stream server {self!r}",
+ ResourceWarning, source=self)
+ self._server_impl.close()
+
+
+class StreamServer(_BaseStreamServer):
+
+ def __init__(self, client_connected_cb, /, host=None, port=None, *,
+ limit=_DEFAULT_LIMIT,
+ family=socket.AF_UNSPEC,
+ flags=socket.AI_PASSIVE, sock=None, backlog=100,
+ ssl=None, reuse_address=None, reuse_port=None,
+ ssl_handshake_timeout=None,
+ shutdown_timeout=60):
+ super().__init__(client_connected_cb,
+ limit=limit,
+ shutdown_timeout=shutdown_timeout,
+ _asyncio_internal=True)
+ self._host = host
+ self._port = port
+ self._family = family
+ self._flags = flags
+ self._sock = sock
+ self._backlog = backlog
+ self._ssl = ssl
+ self._reuse_address = reuse_address
+ self._reuse_port = reuse_port
+ self._ssl_handshake_timeout = ssl_handshake_timeout
+
+ async def _bind(self):
+ def factory():
+ protocol = _ServerStreamProtocol(self,
+ self._limit,
+ self._client_connected_cb,
+ loop=self._loop,
+ _asyncio_internal=True)
+ return protocol
+ return await self._loop.create_server(
+ factory,
+ self._host,
+ self._port,
+ start_serving=False,
+ family=self._family,
+ flags=self._flags,
+ sock=self._sock,
+ backlog=self._backlog,
+ ssl=self._ssl,
+ reuse_address=self._reuse_address,
+ reuse_port=self._reuse_port,
+ ssl_handshake_timeout=self._ssl_handshake_timeout)
+
+
if hasattr(socket, 'AF_UNIX'):
# UNIX Domain Sockets are supported on this platform
async def open_unix_connection(path=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `open_connection` but works with UNIX Domain Sockets."""
+ warnings.warn("open_unix_connection() is deprecated since Python 3.8 "
+ "in favor of connect_unix(), and scheduled for removal "
+ "in Python 3.10",
+ DeprecationWarning,
+ stacklevel=2)
if loop is None:
loop = events.get_event_loop()
- reader = StreamReader(limit=limit, loop=loop,
- _asyncio_internal=True)
+ reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, loop=loop,
_asyncio_internal=True)
transport, _ = await loop.create_unix_connection(
lambda: protocol, path, **kwds)
- writer = StreamWriter(transport, protocol, reader, loop,
- _asyncio_internal=True)
+ writer = StreamWriter(transport, protocol, reader, loop)
return reader, writer
+
+ def connect_unix(path=None, *,
+ limit=_DEFAULT_LIMIT,
+ ssl=None, sock=None,
+ server_hostname=None,
+ ssl_handshake_timeout=None):
+ """Similar to `connect()` but works with UNIX Domain Sockets."""
+ # Design note:
+ # Don't use decorator approach but exilicit non-async
+ # function to fail fast and explicitly
+ # if passed arguments don't match the function signature
+ return _ContextManagerHelper(_connect_unix(path,
+ limit,
+ ssl, sock,
+ server_hostname,
+ ssl_handshake_timeout))
+
+
+ async def _connect_unix(path,
+ limit,
+ ssl, sock,
+ server_hostname,
+ ssl_handshake_timeout):
+ """Similar to `connect()` but works with UNIX Domain Sockets."""
+ loop = events.get_running_loop()
+ stream = Stream(mode=StreamMode.READWRITE,
+ limit=limit,
+ loop=loop,
+ _asyncio_internal=True)
+ await loop.create_unix_connection(
+ lambda: _StreamProtocol(stream,
+ loop=loop,
+ _asyncio_internal=True),
+ path,
+ ssl=ssl,
+ sock=sock,
+ server_hostname=server_hostname,
+ ssl_handshake_timeout=ssl_handshake_timeout)
+ return stream
+
+
async def start_unix_server(client_connected_cb, path=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `start_server` but works with UNIX Domain Sockets."""
+ warnings.warn("start_unix_server() is deprecated since Python 3.8 "
+ "in favor of UnixStreamServer(), and scheduled "
+ "for removal in Python 3.10",
+ DeprecationWarning,
+ stacklevel=2)
if loop is None:
loop = events.get_event_loop()
def factory():
- reader = StreamReader(limit=limit, loop=loop,
- _asyncio_internal=True)
+ reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, client_connected_cb,
loop=loop,
_asyncio_internal=True)
@@ -125,6 +483,42 @@ if hasattr(socket, 'AF_UNIX'):
return await loop.create_unix_server(factory, path, **kwds)
+ class UnixStreamServer(_BaseStreamServer):
+
+ def __init__(self, client_connected_cb, /, path=None, *,
+ limit=_DEFAULT_LIMIT,
+ sock=None,
+ backlog=100,
+ ssl=None,
+ ssl_handshake_timeout=None,
+ shutdown_timeout=60):
+ super().__init__(client_connected_cb,
+ limit=limit,
+ shutdown_timeout=shutdown_timeout,
+ _asyncio_internal=True)
+ self._path = path
+ self._sock = sock
+ self._backlog = backlog
+ self._ssl = ssl
+ self._ssl_handshake_timeout = ssl_handshake_timeout
+
+ async def _bind(self):
+ def factory():
+ protocol = _ServerStreamProtocol(self,
+ self._limit,
+ self._client_connected_cb,
+ loop=self._loop,
+ _asyncio_internal=True)
+ return protocol
+ return await self._loop.create_unix_server(
+ factory,
+ self._path,
+ start_serving=False,
+ sock=self._sock,
+ backlog=self._backlog,
+ ssl=self._ssl,
+ ssl_handshake_timeout=self._ssl_handshake_timeout)
+
class FlowControlMixin(protocols.Protocol):
"""Reusable flow control logic for StreamWriter.drain().
@@ -203,6 +597,8 @@ class FlowControlMixin(protocols.Protocol):
raise NotImplementedError
+# begin legacy stream APIs
+
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
"""Helper class to adapt between Protocol and StreamReader.
@@ -212,105 +608,47 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
call inappropriate methods of the protocol.)
"""
- _source_traceback = None
-
def __init__(self, stream_reader, client_connected_cb=None, loop=None,
*, _asyncio_internal=False):
super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
- if stream_reader is not None:
- self._stream_reader_wr = weakref.ref(stream_reader,
- self._on_reader_gc)
- self._source_traceback = stream_reader._source_traceback
- else:
- self._stream_reader_wr = None
- if client_connected_cb is not None:
- # This is a stream created by the `create_server()` function.
- # Keep a strong reference to the reader until a connection
- # is established.
- self._strong_reader = stream_reader
- self._reject_connection = False
+ self._stream_reader = stream_reader
self._stream_writer = None
- self._transport = None
self._client_connected_cb = client_connected_cb
self._over_ssl = False
self._closed = self._loop.create_future()
- def _on_reader_gc(self, wr):
- transport = self._transport
- if transport is not None:
- # connection_made was called
- context = {
- 'message': ('An open stream object is being garbage '
- 'collected; call "stream.close()" explicitly.')
- }
- if self._source_traceback:
- context['source_traceback'] = self._source_traceback
- self._loop.call_exception_handler(context)
- transport.abort()
- else:
- self._reject_connection = True
- self._stream_reader_wr = None
-
- @property
- def _stream_reader(self):
- if self._stream_reader_wr is None:
- return None
- return self._stream_reader_wr()
-
def connection_made(self, transport):
- if self._reject_connection:
- context = {
- 'message': ('An open stream was garbage collected prior to '
- 'establishing network connection; '
- 'call "stream.close()" explicitly.')
- }
- if self._source_traceback:
- context['source_traceback'] = self._source_traceback
- self._loop.call_exception_handler(context)
- transport.abort()
- return
- self._transport = transport
- reader = self._stream_reader
- if reader is not None:
- reader.set_transport(transport)
+ self._stream_reader.set_transport(transport)
self._over_ssl = transport.get_extra_info('sslcontext') is not None
if self._client_connected_cb is not None:
self._stream_writer = StreamWriter(transport, self,
- reader,
- self._loop,
- _asyncio_internal=True)
- res = self._client_connected_cb(reader,
+ self._stream_reader,
+ self._loop)
+ res = self._client_connected_cb(self._stream_reader,
self._stream_writer)
if coroutines.iscoroutine(res):
self._loop.create_task(res)
- self._strong_reader = None
def connection_lost(self, exc):
- reader = self._stream_reader
- if reader is not None:
+ if self._stream_reader is not None:
if exc is None:
- reader.feed_eof()
+ self._stream_reader.feed_eof()
else:
- reader.set_exception(exc)
+ self._stream_reader.set_exception(exc)
if not self._closed.done():
if exc is None:
self._closed.set_result(None)
else:
self._closed.set_exception(exc)
super().connection_lost(exc)
- self._stream_reader_wr = None
+ self._stream_reader = None
self._stream_writer = None
- self._transport = None
def data_received(self, data):
- reader = self._stream_reader
- if reader is not None:
- reader.feed_data(data)
+ self._stream_reader.feed_data(data)
def eof_received(self):
- reader = self._stream_reader
- if reader is not None:
- reader.feed_eof()
+ self._stream_reader.feed_eof()
if self._over_ssl:
# Prevent a warning in SSLProtocol.eof_received:
# "returning true from eof_received()
@@ -318,9 +656,6 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
return False
return True
- def _get_close_waiter(self, stream):
- return self._closed
-
def __del__(self):
# Prevent reports about unhandled exceptions.
# Better than self._closed._log_traceback = False hack
@@ -329,13 +664,6 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
closed.exception()
-def _swallow_unhandled_exception(task):
- # Do a trick to suppress unhandled exception
- # if stream.write() was used without await and
- # stream.drain() was paused and resumed with an exception
- task.exception()
-
-
class StreamWriter:
"""Wraps a Transport.
@@ -346,21 +674,13 @@ class StreamWriter:
directly.
"""
- def __init__(self, transport, protocol, reader, loop,
- *, _asyncio_internal=False):
- if not _asyncio_internal:
- warnings.warn(f"{self.__class__} should be instaniated "
- "by asyncio internals only, "
- "please avoid its creation from user code",
- DeprecationWarning)
+ def __init__(self, transport, protocol, reader, loop):
self._transport = transport
self._protocol = protocol
# drain() expects that the reader has an exception() method
assert reader is None or isinstance(reader, StreamReader)
self._reader = reader
self._loop = loop
- self._complete_fut = self._loop.create_future()
- self._complete_fut.set_result(None)
def __repr__(self):
info = [self.__class__.__name__, f'transport={self._transport!r}']
@@ -374,35 +694,9 @@ class StreamWriter:
def write(self, data):
self._transport.write(data)
- return self._fast_drain()
def writelines(self, data):
self._transport.writelines(data)
- return self._fast_drain()
-
- def _fast_drain(self):
- # The helper tries to use fast-path to return already existing complete future
- # object if underlying transport is not paused and actual waiting for writing
- # resume is not needed
- if self._reader is not None:
- # this branch will be simplified after merging reader with writer
- exc = self._reader.exception()
- if exc is not None:
- fut = self._loop.create_future()
- fut.set_exception(exc)
- return fut
- if not self._transport.is_closing():
- if self._protocol._connection_lost:
- fut = self._loop.create_future()
- fut.set_exception(ConnectionResetError('Connection lost'))
- return fut
- if not self._protocol._paused:
- # fast path, the stream is not paused
- # no need to wait for resume signal
- return self._complete_fut
- ret = self._loop.create_task(self.drain())
- ret.add_done_callback(_swallow_unhandled_exception)
- return ret
def write_eof(self):
return self._transport.write_eof()
@@ -411,14 +705,13 @@ class StreamWriter:
return self._transport.can_write_eof()
def close(self):
- self._transport.close()
- return self._protocol._get_close_waiter(self)
+ return self._transport.close()
def is_closing(self):
return self._transport.is_closing()
async def wait_closed(self):
- await self._protocol._get_close_waiter(self)
+ await self._protocol._closed
def get_extra_info(self, name, default=None):
return self._transport.get_extra_info(name, default)
@@ -436,24 +729,562 @@ class StreamWriter:
if exc is not None:
raise exc
if self._transport.is_closing():
- # Wait for protocol.connection_lost() call
- # Raise connection closing error if any,
- # ConnectionResetError otherwise
- await sleep(0)
+ # Yield to the event loop so connection_lost() may be
+ # called. Without this, _drain_helper() would return
+ # immediately, and code that calls
+ # write(...); await drain()
+ # in a loop would never call connection_lost(), so it
+ # would not see an error when the socket is closed.
+ await tasks.sleep(0, loop=self._loop)
await self._protocol._drain_helper()
class StreamReader:
+ def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
+ # The line length limit is a security feature;
+ # it also doubles as half the buffer limit.
+
+ if limit <= 0:
+ raise ValueError('Limit cannot be <= 0')
+
+ self._limit = limit
+ if loop is None:
+ self._loop = events.get_event_loop()
+ else:
+ self._loop = loop
+ self._buffer = bytearray()
+ self._eof = False # Whether we're done.
+ self._waiter = None # A future used by _wait_for_data()
+ self._exception = None
+ self._transport = None
+ self._paused = False
+
+ def __repr__(self):
+ info = ['StreamReader']
+ if self._buffer:
+ info.append(f'{len(self._buffer)} bytes')
+ if self._eof:
+ info.append('eof')
+ if self._limit != _DEFAULT_LIMIT:
+ info.append(f'limit={self._limit}')
+ if self._waiter:
+ info.append(f'waiter={self._waiter!r}')
+ if self._exception:
+ info.append(f'exception={self._exception!r}')
+ if self._transport:
+ info.append(f'transport={self._transport!r}')
+ if self._paused:
+ info.append('paused')
+ return '<{}>'.format(' '.join(info))
+
+ def exception(self):
+ return self._exception
+
+ def set_exception(self, exc):
+ self._exception = exc
+
+ waiter = self._waiter
+ if waiter is not None:
+ self._waiter = None
+ if not waiter.cancelled():
+ waiter.set_exception(exc)
+
+ def _wakeup_waiter(self):
+ """Wakeup read*() functions waiting for data or EOF."""
+ waiter = self._waiter
+ if waiter is not None:
+ self._waiter = None
+ if not waiter.cancelled():
+ waiter.set_result(None)
+
+ def set_transport(self, transport):
+ assert self._transport is None, 'Transport already set'
+ self._transport = transport
+
+ def _maybe_resume_transport(self):
+ if self._paused and len(self._buffer) <= self._limit:
+ self._paused = False
+ self._transport.resume_reading()
+
+ def feed_eof(self):
+ self._eof = True
+ self._wakeup_waiter()
+
+ def at_eof(self):
+ """Return True if the buffer is empty and 'feed_eof' was called."""
+ return self._eof and not self._buffer
+
+ def feed_data(self, data):
+ assert not self._eof, 'feed_data after feed_eof'
+
+ if not data:
+ return
+
+ self._buffer.extend(data)
+ self._wakeup_waiter()
+
+ if (self._transport is not None and
+ not self._paused and
+ len(self._buffer) > 2 * self._limit):
+ try:
+ self._transport.pause_reading()
+ except NotImplementedError:
+ # The transport can't be paused.
+ # We'll just have to buffer all data.
+ # Forget the transport so we don't keep trying.
+ self._transport = None
+ else:
+ self._paused = True
+
+ async def _wait_for_data(self, func_name):
+ """Wait until feed_data() or feed_eof() is called.
+
+ If stream was paused, automatically resume it.
+ """
+ # StreamReader uses a future to link the protocol feed_data() method
+ # to a read coroutine. Running two read coroutines at the same time
+ # would have an unexpected behaviour. It would not possible to know
+ # which coroutine would get the next data.
+ if self._waiter is not None:
+ raise RuntimeError(
+ f'{func_name}() called while another coroutine is '
+ f'already waiting for incoming data')
+
+ assert not self._eof, '_wait_for_data after EOF'
+
+ # Waiting for data while paused will make deadlock, so prevent it.
+ # This is essential for readexactly(n) for case when n > self._limit.
+ if self._paused:
+ self._paused = False
+ self._transport.resume_reading()
+
+ self._waiter = self._loop.create_future()
+ try:
+ await self._waiter
+ finally:
+ self._waiter = None
+
+ async def readline(self):
+ """Read chunk of data from the stream until newline (b'\n') is found.
+
+ On success, return chunk that ends with newline. If only partial
+ line can be read due to EOF, return incomplete line without
+ terminating newline. When EOF was reached while no bytes read, empty
+ bytes object is returned.
+
+ If limit is reached, ValueError will be raised. In that case, if
+ newline was found, complete line including newline will be removed
+ from internal buffer. Else, internal buffer will be cleared. Limit is
+ compared against part of the line without newline.
+
+ If stream was paused, this function will automatically resume it if
+ needed.
+ """
+ sep = b'\n'
+ seplen = len(sep)
+ try:
+ line = await self.readuntil(sep)
+ except exceptions.IncompleteReadError as e:
+ return e.partial
+ except exceptions.LimitOverrunError as e:
+ if self._buffer.startswith(sep, e.consumed):
+ del self._buffer[:e.consumed + seplen]
+ else:
+ self._buffer.clear()
+ self._maybe_resume_transport()
+ raise ValueError(e.args[0])
+ return line
+
+ async def readuntil(self, separator=b'\n'):
+ """Read data from the stream until ``separator`` is found.
+
+ On success, the data and separator will be removed from the
+ internal buffer (consumed). Returned data will include the
+ separator at the end.
+
+ Configured stream limit is used to check result. Limit sets the
+ maximal length of data that can be returned, not counting the
+ separator.
+
+ If an EOF occurs and the complete separator is still not found,
+ an IncompleteReadError exception will be raised, and the internal
+ buffer will be reset. The IncompleteReadError.partial attribute
+ may contain the separator partially.
+
+ If the data cannot be read because of over limit, a
+ LimitOverrunError exception will be raised, and the data
+ will be left in the internal buffer, so it can be read again.
+ """
+ seplen = len(separator)
+ if seplen == 0:
+ raise ValueError('Separator should be at least one-byte string')
+
+ if self._exception is not None:
+ raise self._exception
+
+ # Consume whole buffer except last bytes, which length is
+ # one less than seplen. Let's check corner cases with
+ # separator='SEPARATOR':
+ # * we have received almost complete separator (without last
+ # byte). i.e buffer='some textSEPARATO'. In this case we
+ # can safely consume len(separator) - 1 bytes.
+ # * last byte of buffer is first byte of separator, i.e.
+ # buffer='abcdefghijklmnopqrS'. We may safely consume
+ # everything except that last byte, but this require to
+ # analyze bytes of buffer that match partial separator.
+ # This is slow and/or require FSM. For this case our
+ # implementation is not optimal, since require rescanning
+ # of data that is known to not belong to separator. In
+ # real world, separator will not be so long to notice
+ # performance problems. Even when reading MIME-encoded
+ # messages :)
+
+ # `offset` is the number of bytes from the beginning of the buffer
+ # where there is no occurrence of `separator`.
+ offset = 0
+
+ # Loop until we find `separator` in the buffer, exceed the buffer size,
+ # or an EOF has happened.
+ while True:
+ buflen = len(self._buffer)
+
+ # Check if we now have enough data in the buffer for `separator` to
+ # fit.
+ if buflen - offset >= seplen:
+ isep = self._buffer.find(separator, offset)
+
+ if isep != -1:
+ # `separator` is in the buffer. `isep` will be used later
+ # to retrieve the data.
+ break
+
+ # see upper comment for explanation.
+ offset = buflen + 1 - seplen
+ if offset > self._limit:
+ raise exceptions.LimitOverrunError(
+ 'Separator is not found, and chunk exceed the limit',
+ offset)
+
+ # Complete message (with full separator) may be present in buffer
+ # even when EOF flag is set. This may happen when the last chunk
+ # adds data which makes separator be found. That's why we check for
+ # EOF *ater* inspecting the buffer.
+ if self._eof:
+ chunk = bytes(self._buffer)
+ self._buffer.clear()
+ raise exceptions.IncompleteReadError(chunk, None)
+
+ # _wait_for_data() will resume reading if stream was paused.
+ await self._wait_for_data('readuntil')
+
+ if isep > self._limit:
+ raise exceptions.LimitOverrunError(
+ 'Separator is found, but chunk is longer than limit', isep)
+
+ chunk = self._buffer[:isep + seplen]
+ del self._buffer[:isep + seplen]
+ self._maybe_resume_transport()
+ return bytes(chunk)
+
+ async def read(self, n=-1):
+ """Read up to `n` bytes from the stream.
+
+ If n is not provided, or set to -1, read until EOF and return all read
+ bytes. If the EOF was received and the internal buffer is empty, return
+ an empty bytes object.
+
+ If n is zero, return empty bytes object immediately.
+
+ If n is positive, this function try to read `n` bytes, and may return
+ less or equal bytes than requested, but at least one byte. If EOF was
+ received before any byte is read, this function returns empty byte
+ object.
+
+ Returned value is not limited with limit, configured at stream
+ creation.
+
+ If stream was paused, this function will automatically resume it if
+ needed.
+ """
+
+ if self._exception is not None:
+ raise self._exception
+
+ if n == 0:
+ return b''
+
+ if n < 0:
+ # This used to just loop creating a new waiter hoping to
+ # collect everything in self._buffer, but that would
+ # deadlock if the subprocess sends more than self.limit
+ # bytes. So just call self.read(self._limit) until EOF.
+ blocks = []
+ while True:
+ block = await self.read(self._limit)
+ if not block:
+ break
+ blocks.append(block)
+ return b''.join(blocks)
+
+ if not self._buffer and not self._eof:
+ await self._wait_for_data('read')
+
+ # This will work right even if buffer is less than n bytes
+ data = bytes(self._buffer[:n])
+ del self._buffer[:n]
+
+ self._maybe_resume_transport()
+ return data
+
+ async def readexactly(self, n):
+ """Read exactly `n` bytes.
+
+ Raise an IncompleteReadError if EOF is reached before `n` bytes can be
+ read. The IncompleteReadError.partial attribute of the exception will
+ contain the partial read bytes.
+
+ if n is zero, return empty bytes object.
+
+ Returned value is not limited with limit, configured at stream
+ creation.
+
+ If stream was paused, this function will automatically resume it if
+ needed.
+ """
+ if n < 0:
+ raise ValueError('readexactly size can not be less than zero')
+
+ if self._exception is not None:
+ raise self._exception
+
+ if n == 0:
+ return b''
+
+ while len(self._buffer) < n:
+ if self._eof:
+ incomplete = bytes(self._buffer)
+ self._buffer.clear()
+ raise exceptions.IncompleteReadError(incomplete, n)
+
+ await self._wait_for_data('readexactly')
+
+ if len(self._buffer) == n:
+ data = bytes(self._buffer)
+ self._buffer.clear()
+ else:
+ data = bytes(self._buffer[:n])
+ del self._buffer[:n]
+ self._maybe_resume_transport()
+ return data
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ val = await self.readline()
+ if val == b'':
+ raise StopAsyncIteration
+ return val
+
+
+# end legacy stream APIs
+
+
+class _BaseStreamProtocol(FlowControlMixin, protocols.Protocol):
+ """Helper class to adapt between Protocol and StreamReader.
+
+ (This is a helper class instead of making StreamReader itself a
+ Protocol subclass, because the StreamReader has other potential
+ uses, and to prevent the user of the StreamReader to accidentally
+ call inappropriate methods of the protocol.)
+ """
+
+ _stream = None # initialized in derived classes
+
+ def __init__(self, loop=None,
+ *, _asyncio_internal=False):
+ super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
+ self._transport = None
+ self._over_ssl = False
+ self._closed = self._loop.create_future()
+
+ def connection_made(self, transport):
+ self._transport = transport
+ self._over_ssl = transport.get_extra_info('sslcontext') is not None
+
+ def connection_lost(self, exc):
+ stream = self._stream
+ if stream is not None:
+ if exc is None:
+ stream.feed_eof()
+ else:
+ stream.set_exception(exc)
+ if not self._closed.done():
+ if exc is None:
+ self._closed.set_result(None)
+ else:
+ self._closed.set_exception(exc)
+ super().connection_lost(exc)
+ self._transport = None
+
+ def data_received(self, data):
+ stream = self._stream
+ if stream is not None:
+ stream.feed_data(data)
+
+ def eof_received(self):
+ stream = self._stream
+ if stream is not None:
+ stream.feed_eof()
+ if self._over_ssl:
+ # Prevent a warning in SSLProtocol.eof_received:
+ # "returning true from eof_received()
+ # has no effect when using ssl"
+ return False
+ return True
+
+ def _get_close_waiter(self, stream):
+ return self._closed
+
+ def __del__(self):
+ # Prevent reports about unhandled exceptions.
+ # Better than self._closed._log_traceback = False hack
+ closed = self._get_close_waiter(self._stream)
+ if closed.done() and not closed.cancelled():
+ closed.exception()
+
+
+class _StreamProtocol(_BaseStreamProtocol):
_source_traceback = None
- def __init__(self, limit=_DEFAULT_LIMIT, loop=None,
+ def __init__(self, stream, loop=None,
*, _asyncio_internal=False):
+ super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
+ self._source_traceback = stream._source_traceback
+ self._stream_wr = weakref.ref(stream, self._on_gc)
+ self._reject_connection = False
+
+ def _on_gc(self, wr):
+ transport = self._transport
+ if transport is not None:
+ # connection_made was called
+ context = {
+ 'message': ('An open stream object is being garbage '
+ 'collected; call "stream.close()" explicitly.')
+ }
+ if self._source_traceback:
+ context['source_traceback'] = self._source_traceback
+ self._loop.call_exception_handler(context)
+ transport.abort()
+ else:
+ self._reject_connection = True
+ self._stream_wr = None
+
+ @property
+ def _stream(self):
+ if self._stream_wr is None:
+ return None
+ return self._stream_wr()
+
+ def connection_made(self, transport):
+ if self._reject_connection:
+ context = {
+ 'message': ('An open stream was garbage collected prior to '
+ 'establishing network connection; '
+ 'call "stream.close()" explicitly.')
+ }
+ if self._source_traceback:
+ context['source_traceback'] = self._source_traceback
+ self._loop.call_exception_handler(context)
+ transport.abort()
+ return
+ super().connection_made(transport)
+ stream = self._stream
+ if stream is None:
+ return
+ stream.set_transport(transport)
+ stream._protocol = self
+
+ def connection_lost(self, exc):
+ super().connection_lost(exc)
+ self._stream_wr = None
+
+
+class _ServerStreamProtocol(_BaseStreamProtocol):
+ def __init__(self, server, limit, client_connected_cb, loop=None,
+ *, _asyncio_internal=False):
+ super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
+ assert self._closed
+ self._client_connected_cb = client_connected_cb
+ self._limit = limit
+ self._server = server
+ self._task = None
+
+ def connection_made(self, transport):
+ super().connection_made(transport)
+ stream = Stream(mode=StreamMode.READWRITE,
+ transport=transport,
+ protocol=self,
+ limit=self._limit,
+ loop=self._loop,
+ is_server_side=True,
+ _asyncio_internal=True)
+ self._stream = stream
+ # If self._client_connected_cb(self._stream) fails
+ # the exception is logged by transport
+ self._task = self._loop.create_task(
+ self._client_connected_cb(self._stream))
+ self._server._attach(stream, self._task)
+
+ def connection_lost(self, exc):
+ super().connection_lost(exc)
+ self._server._detach(self._stream, self._task)
+ self._stream = None
+
+
+class _OptionalAwait:
+ # The class doesn't create a coroutine
+ # if not awaited
+ # It prevents "coroutine is never awaited" message
+
+ __slots___ = ('_method',)
+
+ def __init__(self, method):
+ self._method = method
+
+ def __await__(self):
+ return self._method().__await__()
+
+
+class Stream:
+ """Wraps a Transport.
+
+ This exposes write(), writelines(), [can_]write_eof(),
+ get_extra_info() and close(). It adds drain() which returns an
+ optional Future on which you can wait for flow control. It also
+ adds a transport property which references the Transport
+ directly.
+ """
+
+ _source_traceback = None
+
+ def __init__(self, mode, *,
+ transport=None,
+ protocol=None,
+ loop=None,
+ limit=_DEFAULT_LIMIT,
+ is_server_side=False,
+ _asyncio_internal=False):
if not _asyncio_internal:
warnings.warn(f"{self.__class__} should be instaniated "
"by asyncio internals only, "
"please avoid its creation from user code",
DeprecationWarning)
+ self._mode = mode
+ self._transport = transport
+ self._protocol = protocol
+ self._is_server_side = is_server_side
# The line length limit is a security feature;
# it also doubles as half the buffer limit.
@@ -470,14 +1301,17 @@ class StreamReader:
self._eof = False # Whether we're done.
self._waiter = None # A future used by _wait_for_data()
self._exception = None
- self._transport = None
self._paused = False
+ self._complete_fut = self._loop.create_future()
+ self._complete_fut.set_result(None)
+
if self._loop.get_debug():
self._source_traceback = format_helpers.extract_stack(
sys._getframe(1))
def __repr__(self):
- info = ['StreamReader']
+ info = [self.__class__.__name__]
+ info.append(f'mode={self._mode}')
if self._buffer:
info.append(f'{len(self._buffer)} bytes')
if self._eof:
@@ -494,6 +1328,110 @@ class StreamReader:
info.append('paused')
return '<{}>'.format(' '.join(info))
+ @property
+ def mode(self):
+ return self._mode
+
+ def is_server_side(self):
+ return self._is_server_side
+
+ @property
+ def transport(self):
+ return self._transport
+
+ def write(self, data):
+ _ensure_can_write(self._mode)
+ self._transport.write(data)
+ return self._fast_drain()
+
+ def writelines(self, data):
+ _ensure_can_write(self._mode)
+ self._transport.writelines(data)
+ return self._fast_drain()
+
+ def _fast_drain(self):
+ # The helper tries to use fast-path to return already existing
+ # complete future object if underlying transport is not paused
+ #and actual waiting for writing resume is not needed
+ exc = self.exception()
+ if exc is not None:
+ fut = self._loop.create_future()
+ fut.set_exception(exc)
+ return fut
+ if not self._transport.is_closing():
+ if self._protocol._connection_lost:
+ fut = self._loop.create_future()
+ fut.set_exception(ConnectionResetError('Connection lost'))
+ return fut
+ if not self._protocol._paused:
+ # fast path, the stream is not paused
+ # no need to wait for resume signal
+ return self._complete_fut
+ return _OptionalAwait(self.drain)
+
+ def write_eof(self):
+ _ensure_can_write(self._mode)
+ return self._transport.write_eof()
+
+ def can_write_eof(self):
+ if not self._mode.is_write():
+ return False
+ return self._transport.can_write_eof()
+
+ def close(self):
+ self._transport.close()
+ return _OptionalAwait(self.wait_closed)
+
+ def is_closing(self):
+ return self._transport.is_closing()
+
+ async def abort(self):
+ self._transport.abort()
+ await self.wait_closed()
+
+ async def wait_closed(self):
+ await self._protocol._get_close_waiter(self)
+
+ def get_extra_info(self, name, default=None):
+ return self._transport.get_extra_info(name, default)
+
+ async def drain(self):
+ """Flush the write buffer.
+
+ The intended use is to write
+
+ w.write(data)
+ await w.drain()
+ """
+ _ensure_can_write(self._mode)
+ exc = self.exception()
+ if exc is not None:
+ raise exc
+ if self._transport.is_closing():
+ # Wait for protocol.connection_lost() call
+ # Raise connection closing error if any,
+ # ConnectionResetError otherwise
+ await tasks.sleep(0)
+ await self._protocol._drain_helper()
+
+ async def sendfile(self, file, offset=0, count=None, *, fallback=True):
+ await self.drain() # check for stream mode and exceptions
+ return await self._loop.sendfile(self._transport, file,
+ offset, count, fallback=fallback)
+
+ async def start_tls(self, sslcontext, *,
+ server_hostname=None,
+ ssl_handshake_timeout=None):
+ await self.drain() # check for stream mode and exceptions
+ transport = await self._loop.start_tls(
+ self._transport, self._protocol, sslcontext,
+ server_side=self._is_server_side,
+ server_hostname=server_hostname,
+ ssl_handshake_timeout=ssl_handshake_timeout)
+ self._transport = transport
+ self._protocol._transport = transport
+ self._protocol._over_ssl = True
+
def exception(self):
return self._exception
@@ -515,6 +1453,8 @@ class StreamReader:
waiter.set_result(None)
def set_transport(self, transport):
+ if transport is self._transport:
+ return
assert self._transport is None, 'Transport already set'
self._transport = transport
@@ -532,6 +1472,7 @@ class StreamReader:
return self._eof and not self._buffer
def feed_data(self, data):
+ _ensure_can_read(self._mode)
assert not self._eof, 'feed_data after feed_eof'
if not data:
@@ -597,6 +1538,7 @@ class StreamReader:
If stream was paused, this function will automatically resume it if
needed.
"""
+ _ensure_can_read(self._mode)
sep = b'\n'
seplen = len(sep)
try:
@@ -632,6 +1574,7 @@ class StreamReader:
LimitOverrunError exception will be raised, and the data
will be left in the internal buffer, so it can be read again.
"""
+ _ensure_can_read(self._mode)
seplen = len(separator)
if seplen == 0:
raise ValueError('Separator should be at least one-byte string')
@@ -723,6 +1666,7 @@ class StreamReader:
If stream was paused, this function will automatically resume it if
needed.
"""
+ _ensure_can_read(self._mode)
if self._exception is not None:
raise self._exception
@@ -768,6 +1712,7 @@ class StreamReader:
If stream was paused, this function will automatically resume it if
needed.
"""
+ _ensure_can_read(self._mode)
if n < 0:
raise ValueError('readexactly size can not be less than zero')
@@ -795,6 +1740,7 @@ class StreamReader:
return data
def __aiter__(self):
+ _ensure_can_read(self._mode)
return self
async def __anext__(self):
@@ -802,3 +1748,9 @@ class StreamReader:
if val == b'':
raise StopAsyncIteration
return val
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.close()
diff --git a/Lib/asyncio/subprocess.py b/Lib/asyncio/subprocess.py
index d34b611..e6bec71 100644
--- a/Lib/asyncio/subprocess.py
+++ b/Lib/asyncio/subprocess.py
@@ -27,6 +27,8 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
self._process_exited = False
self._pipe_fds = []
self._stdin_closed = self._loop.create_future()
+ self._stdout_closed = self._loop.create_future()
+ self._stderr_closed = self._loop.create_future()
def __repr__(self):
info = [self.__class__.__name__]
@@ -40,30 +42,35 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
def connection_made(self, transport):
self._transport = transport
-
stdout_transport = transport.get_pipe_transport(1)
if stdout_transport is not None:
- self.stdout = streams.StreamReader(limit=self._limit,
- loop=self._loop,
- _asyncio_internal=True)
+ self.stdout = streams.Stream(mode=streams.StreamMode.READ,
+ transport=stdout_transport,
+ protocol=self,
+ limit=self._limit,
+ loop=self._loop,
+ _asyncio_internal=True)
self.stdout.set_transport(stdout_transport)
self._pipe_fds.append(1)
stderr_transport = transport.get_pipe_transport(2)
if stderr_transport is not None:
- self.stderr = streams.StreamReader(limit=self._limit,
- loop=self._loop,
- _asyncio_internal=True)
+ self.stderr = streams.Stream(mode=streams.StreamMode.READ,
+ transport=stderr_transport,
+ protocol=self,
+ limit=self._limit,
+ loop=self._loop,
+ _asyncio_internal=True)
self.stderr.set_transport(stderr_transport)
self._pipe_fds.append(2)
stdin_transport = transport.get_pipe_transport(0)
if stdin_transport is not None:
- self.stdin = streams.StreamWriter(stdin_transport,
- protocol=self,
- reader=None,
- loop=self._loop,
- _asyncio_internal=True)
+ self.stdin = streams.Stream(mode=streams.StreamMode.WRITE,
+ transport=stdin_transport,
+ protocol=self,
+ loop=self._loop,
+ _asyncio_internal=True)
def pipe_data_received(self, fd, data):
if fd == 1:
@@ -114,6 +121,10 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
def _get_close_waiter(self, stream):
if stream is self.stdin:
return self._stdin_closed
+ elif stream is self.stdout:
+ return self._stdout_closed
+ elif stream is self.stderr:
+ return self._stderr_closed
class Process:
diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py
index b5b2e24..61b40ba 100644
--- a/Lib/asyncio/windows_events.py
+++ b/Lib/asyncio/windows_events.py
@@ -607,7 +607,7 @@ class IocpProactor:
# ConnectPipe() failed with ERROR_PIPE_BUSY: retry later
delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY)
- await tasks.sleep(delay, loop=self._loop)
+ await tasks.sleep(delay)
return windows_utils.PipeHandle(handle)
diff --git a/Lib/test/test___all__.py b/Lib/test/test___all__.py
index f6e82eb..c077881 100644
--- a/Lib/test/test___all__.py
+++ b/Lib/test/test___all__.py
@@ -30,21 +30,27 @@ class AllTest(unittest.TestCase):
raise NoAll(modname)
names = {}
with self.subTest(module=modname):
- try:
- exec("from %s import *" % modname, names)
- except Exception as e:
- # Include the module name in the exception string
- self.fail("__all__ failure in {}: {}: {}".format(
- modname, e.__class__.__name__, e))
- if "__builtins__" in names:
- del names["__builtins__"]
- if '__annotations__' in names:
- del names['__annotations__']
- keys = set(names)
- all_list = sys.modules[modname].__all__
- all_set = set(all_list)
- self.assertCountEqual(all_set, all_list, "in module {}".format(modname))
- self.assertEqual(keys, all_set, "in module {}".format(modname))
+ with support.check_warnings(
+ ("", DeprecationWarning),
+ ("", ResourceWarning),
+ quiet=True):
+ try:
+ exec("from %s import *" % modname, names)
+ except Exception as e:
+ # Include the module name in the exception string
+ self.fail("__all__ failure in {}: {}: {}".format(
+ modname, e.__class__.__name__, e))
+ if "__builtins__" in names:
+ del names["__builtins__"]
+ if '__annotations__' in names:
+ del names['__annotations__']
+ if "__warningregistry__" in names:
+ del names["__warningregistry__"]
+ keys = set(names)
+ all_list = sys.modules[modname].__all__
+ all_set = set(all_list)
+ self.assertCountEqual(all_set, all_list, "in module {}".format(modname))
+ self.assertEqual(keys, all_set, "in module {}".format(modname))
def walk_modules(self, basedir, modpath):
for fn in sorted(os.listdir(basedir)):
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
index 31018c5..02a97c6 100644
--- a/Lib/test/test_asyncio/test_base_events.py
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -1152,8 +1152,9 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
@unittest.skipUnless(hasattr(socket, 'AF_INET6'), 'no IPv6 support')
def test_create_server_ipv6(self):
async def main():
- srv = await asyncio.start_server(
- lambda: None, '::1', 0, loop=self.loop)
+ with self.assertWarns(DeprecationWarning):
+ srv = await asyncio.start_server(
+ lambda: None, '::1', 0, loop=self.loop)
try:
self.assertGreater(len(srv.sockets), 0)
finally:
diff --git a/Lib/test/test_asyncio/test_buffered_proto.py b/Lib/test/test_asyncio/test_buffered_proto.py
index f24e363..b1531fb 100644
--- a/Lib/test/test_asyncio/test_buffered_proto.py
+++ b/Lib/test/test_asyncio/test_buffered_proto.py
@@ -58,9 +58,10 @@ class BaseTestBufferedProtocol(func_tests.FunctionalTestCaseMixin):
writer.close()
await writer.wait_closed()
- srv = self.loop.run_until_complete(
- asyncio.start_server(
- on_server_client, '127.0.0.1', 0))
+ with self.assertWarns(DeprecationWarning):
+ srv = self.loop.run_until_complete(
+ asyncio.start_server(
+ on_server_client, '127.0.0.1', 0))
addr = srv.sockets[0].getsockname()
self.loop.run_until_complete(
diff --git a/Lib/test/test_asyncio/test_pep492.py b/Lib/test/test_asyncio/test_pep492.py
index 297a3b3..11c0ce4 100644
--- a/Lib/test/test_asyncio/test_pep492.py
+++ b/Lib/test/test_asyncio/test_pep492.py
@@ -94,7 +94,9 @@ class StreamReaderTests(BaseTest):
def test_readline(self):
DATA = b'line1\nline2\nline3'
- stream = asyncio.StreamReader(loop=self.loop, _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(DATA)
stream.feed_eof()
diff --git a/Lib/test/test_asyncio/test_server.py b/Lib/test/test_asyncio/test_server.py
index 4e758ad..0e38e6c 100644
--- a/Lib/test/test_asyncio/test_server.py
+++ b/Lib/test/test_asyncio/test_server.py
@@ -46,8 +46,9 @@ class BaseStartServer(func_tests.FunctionalTestCaseMixin):
async with srv:
await srv.serve_forever()
- srv = self.loop.run_until_complete(asyncio.start_server(
- serve, support.HOSTv4, 0, loop=self.loop, start_serving=False))
+ with self.assertWarns(DeprecationWarning):
+ srv = self.loop.run_until_complete(asyncio.start_server(
+ serve, support.HOSTv4, 0, loop=self.loop, start_serving=False))
self.assertFalse(srv.is_serving())
@@ -102,8 +103,9 @@ class SelectorStartServerTests(BaseStartServer, unittest.TestCase):
await srv.serve_forever()
with test_utils.unix_socket_path() as addr:
- srv = self.loop.run_until_complete(asyncio.start_unix_server(
- serve, addr, loop=self.loop, start_serving=False))
+ with self.assertWarns(DeprecationWarning):
+ srv = self.loop.run_until_complete(asyncio.start_unix_server(
+ serve, addr, loop=self.loop, start_serving=False))
main_task = self.loop.create_task(main(srv))
diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py
index 079b255..4215abf 100644
--- a/Lib/test/test_asyncio/test_sslproto.py
+++ b/Lib/test/test_asyncio/test_sslproto.py
@@ -649,12 +649,13 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
sock.close()
async def client(addr):
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='',
- loop=self.loop,
- ssl_handshake_timeout=1.0)
+ with self.assertWarns(DeprecationWarning):
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='',
+ loop=self.loop,
+ ssl_handshake_timeout=1.0)
with self.tcp_server(server,
max_clients=1,
@@ -688,12 +689,13 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
sock.close()
async def client(addr):
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='',
- loop=self.loop,
- ssl_handshake_timeout=1.0)
+ with self.assertWarns(DeprecationWarning):
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='',
+ loop=self.loop,
+ ssl_handshake_timeout=1.0)
with self.tcp_server(server,
max_clients=1,
@@ -724,11 +726,12 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
sock.close()
async def client(addr):
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='',
- loop=self.loop)
+ with self.assertWarns(DeprecationWarning):
+ reader, writer = await asyncio.open_connection(
+ *addr,
+ ssl=client_sslctx,
+ server_hostname='',
+ loop=self.loop)
self.assertEqual(await reader.readline(), b'A\n')
writer.write(b'B')
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py
index fed6098..df3d7e7 100644
--- a/Lib/test/test_asyncio/test_streams.py
+++ b/Lib/test/test_asyncio/test_streams.py
@@ -1,6 +1,8 @@
"""Tests for streams.py."""
+import contextlib
import gc
+import io
import os
import queue
import pickle
@@ -16,6 +18,7 @@ except ImportError:
ssl = None
import asyncio
+from asyncio.streams import _StreamProtocol, _ensure_can_read, _ensure_can_write
from test.test_asyncio import utils as test_utils
@@ -23,6 +26,24 @@ def tearDownModule():
asyncio.set_event_loop_policy(None)
+class StreamModeTests(unittest.TestCase):
+ def test__ensure_can_read_ok(self):
+ self.assertIsNone(_ensure_can_read(asyncio.StreamMode.READ))
+ self.assertIsNone(_ensure_can_read(asyncio.StreamMode.READWRITE))
+
+ def test__ensure_can_read_fail(self):
+ with self.assertRaisesRegex(RuntimeError, "The stream is write-only"):
+ _ensure_can_read(asyncio.StreamMode.WRITE)
+
+ def test__ensure_can_write_ok(self):
+ self.assertIsNone(_ensure_can_write(asyncio.StreamMode.WRITE))
+ self.assertIsNone(_ensure_can_write(asyncio.StreamMode.READWRITE))
+
+ def test__ensure_can_write_fail(self):
+ with self.assertRaisesRegex(RuntimeError, "The stream is read-only"):
+ _ensure_can_write(asyncio.StreamMode.READ)
+
+
class StreamTests(test_utils.TestCase):
DATA = b'line1\nline2\nline3\n'
@@ -42,13 +63,15 @@ class StreamTests(test_utils.TestCase):
@mock.patch('asyncio.streams.events')
def test_ctor_global_loop(self, m_events):
- stream = asyncio.StreamReader(_asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ _asyncio_internal=True)
self.assertIs(stream._loop, m_events.get_event_loop.return_value)
def _basetest_open_connection(self, open_connection_fut):
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
- reader, writer = self.loop.run_until_complete(open_connection_fut)
+ with self.assertWarns(DeprecationWarning):
+ reader, writer = self.loop.run_until_complete(open_connection_fut)
writer.write(b'GET / HTTP/1.0\r\n\r\n')
f = reader.readline()
data = self.loop.run_until_complete(f)
@@ -76,7 +99,9 @@ class StreamTests(test_utils.TestCase):
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
try:
- reader, writer = self.loop.run_until_complete(open_connection_fut)
+ with self.assertWarns(DeprecationWarning):
+ reader, writer = self.loop.run_until_complete(
+ open_connection_fut)
finally:
asyncio.set_event_loop(None)
writer.write(b'GET / HTTP/1.0\r\n\r\n')
@@ -112,7 +137,8 @@ class StreamTests(test_utils.TestCase):
def _basetest_open_connection_error(self, open_connection_fut):
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
- reader, writer = self.loop.run_until_complete(open_connection_fut)
+ with self.assertWarns(DeprecationWarning):
+ reader, writer = self.loop.run_until_complete(open_connection_fut)
writer._protocol.connection_lost(ZeroDivisionError())
f = reader.read()
with self.assertRaises(ZeroDivisionError):
@@ -135,23 +161,26 @@ class StreamTests(test_utils.TestCase):
self._basetest_open_connection_error(conn_fut)
def test_feed_empty_data(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'')
self.assertEqual(b'', stream._buffer)
def test_feed_nonempty_data(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(self.DATA)
self.assertEqual(self.DATA, stream._buffer)
def test_read_zero(self):
# Read zero bytes.
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(self.DATA)
data = self.loop.run_until_complete(stream.read(0))
@@ -160,8 +189,9 @@ class StreamTests(test_utils.TestCase):
def test_read(self):
# Read bytes.
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
read_task = asyncio.Task(stream.read(30), loop=self.loop)
def cb():
@@ -174,8 +204,9 @@ class StreamTests(test_utils.TestCase):
def test_read_line_breaks(self):
# Read bytes without line breaks.
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'line1')
stream.feed_data(b'line2')
@@ -186,8 +217,9 @@ class StreamTests(test_utils.TestCase):
def test_read_eof(self):
# Read bytes, stop at eof.
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
read_task = asyncio.Task(stream.read(1024), loop=self.loop)
def cb():
@@ -200,8 +232,9 @@ class StreamTests(test_utils.TestCase):
def test_read_until_eof(self):
# Read all bytes until eof.
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
read_task = asyncio.Task(stream.read(-1), loop=self.loop)
def cb():
@@ -216,8 +249,9 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(b'', stream._buffer)
def test_read_exception(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'line\n')
data = self.loop.run_until_complete(stream.read(2))
@@ -229,16 +263,19 @@ class StreamTests(test_utils.TestCase):
def test_invalid_limit(self):
with self.assertRaisesRegex(ValueError, 'imit'):
- asyncio.StreamReader(limit=0, loop=self.loop,
- _asyncio_internal=True)
+ asyncio.Stream(mode=asyncio.StreamMode.READ,
+ limit=0, loop=self.loop,
+ _asyncio_internal=True)
with self.assertRaisesRegex(ValueError, 'imit'):
- asyncio.StreamReader(limit=-1, loop=self.loop,
- _asyncio_internal=True)
+ asyncio.Stream(mode=asyncio.StreamMode.READ,
+ limit=-1, loop=self.loop,
+ _asyncio_internal=True)
def test_read_limit(self):
- stream = asyncio.StreamReader(limit=3, loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ limit=3, loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'chunk')
data = self.loop.run_until_complete(stream.read(5))
self.assertEqual(b'chunk', data)
@@ -247,8 +284,9 @@ class StreamTests(test_utils.TestCase):
def test_readline(self):
# Read one line. 'readline' will need to wait for the data
# to come from 'cb'
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'chunk1 ')
read_task = asyncio.Task(stream.readline(), loop=self.loop)
@@ -263,11 +301,12 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(b' chunk4', stream._buffer)
def test_readline_limit_with_existing_data(self):
- # Read one line. The data is in StreamReader's buffer
+ # Read one line. The data is in Stream's buffer
# before the event loop is run.
- stream = asyncio.StreamReader(limit=3, loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ limit=3, loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'li')
stream.feed_data(b'ne1\nline2\n')
@@ -276,8 +315,9 @@ class StreamTests(test_utils.TestCase):
# The buffer should contain the remaining data after exception
self.assertEqual(b'line2\n', stream._buffer)
- stream = asyncio.StreamReader(limit=3, loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ limit=3, loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'li')
stream.feed_data(b'ne1')
stream.feed_data(b'li')
@@ -292,8 +332,9 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(b'', stream._buffer)
def test_at_eof(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
self.assertFalse(stream.at_eof())
stream.feed_data(b'some data\n')
@@ -308,11 +349,12 @@ class StreamTests(test_utils.TestCase):
self.assertTrue(stream.at_eof())
def test_readline_limit(self):
- # Read one line. StreamReaders are fed with data after
+ # Read one line. Streams are fed with data after
# their 'readline' methods are called.
- stream = asyncio.StreamReader(limit=7, loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ limit=7, loop=self.loop,
+ _asyncio_internal=True)
def cb():
stream.feed_data(b'chunk1')
stream.feed_data(b'chunk2')
@@ -326,8 +368,9 @@ class StreamTests(test_utils.TestCase):
# a ValueError it should be empty.
self.assertEqual(b'', stream._buffer)
- stream = asyncio.StreamReader(limit=7, loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ limit=7, loop=self.loop,
+ _asyncio_internal=True)
def cb():
stream.feed_data(b'chunk1')
stream.feed_data(b'chunk2\n')
@@ -340,8 +383,9 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(b'chunk3\n', stream._buffer)
# check strictness of the limit
- stream = asyncio.StreamReader(limit=7, loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ limit=7, loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'1234567\n')
line = self.loop.run_until_complete(stream.readline())
self.assertEqual(b'1234567\n', line)
@@ -360,8 +404,9 @@ class StreamTests(test_utils.TestCase):
def test_readline_nolimit_nowait(self):
# All needed data for the first 'readline' call will be
# in the buffer.
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(self.DATA[:6])
stream.feed_data(self.DATA[6:])
@@ -371,8 +416,9 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(b'line2\nline3\n', stream._buffer)
def test_readline_eof(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'some data')
stream.feed_eof()
@@ -380,16 +426,18 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(b'some data', line)
def test_readline_empty_eof(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_eof()
line = self.loop.run_until_complete(stream.readline())
self.assertEqual(b'', line)
def test_readline_read_byte_count(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(self.DATA)
self.loop.run_until_complete(stream.readline())
@@ -400,8 +448,9 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(b'ine3\n', stream._buffer)
def test_readline_exception(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'line\n')
data = self.loop.run_until_complete(stream.readline())
@@ -413,14 +462,16 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(b'', stream._buffer)
def test_readuntil_separator(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
with self.assertRaisesRegex(ValueError, 'Separator should be'):
self.loop.run_until_complete(stream.readuntil(separator=b''))
def test_readuntil_multi_chunks(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'lineAAA')
data = self.loop.run_until_complete(stream.readuntil(separator=b'AAA'))
@@ -438,8 +489,9 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(b'xxx', stream._buffer)
def test_readuntil_multi_chunks_1(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'QWEaa')
stream.feed_data(b'XYaa')
@@ -474,8 +526,9 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(b'', stream._buffer)
def test_readuntil_eof(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'some dataAA')
stream.feed_eof()
@@ -486,8 +539,9 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(b'', stream._buffer)
def test_readuntil_limit_found_sep(self):
- stream = asyncio.StreamReader(loop=self.loop, limit=3,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop, limit=3,
+ _asyncio_internal=True)
stream.feed_data(b'some dataAA')
with self.assertRaisesRegex(asyncio.LimitOverrunError,
@@ -505,8 +559,9 @@ class StreamTests(test_utils.TestCase):
def test_readexactly_zero_or_less(self):
# Read exact number of bytes (zero or less).
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(self.DATA)
data = self.loop.run_until_complete(stream.readexactly(0))
@@ -519,8 +574,9 @@ class StreamTests(test_utils.TestCase):
def test_readexactly(self):
# Read exact number of bytes.
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
n = 2 * len(self.DATA)
read_task = asyncio.Task(stream.readexactly(n), loop=self.loop)
@@ -536,8 +592,9 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(self.DATA, stream._buffer)
def test_readexactly_limit(self):
- stream = asyncio.StreamReader(limit=3, loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ limit=3, loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'chunk')
data = self.loop.run_until_complete(stream.readexactly(5))
self.assertEqual(b'chunk', data)
@@ -545,8 +602,9 @@ class StreamTests(test_utils.TestCase):
def test_readexactly_eof(self):
# Read exact number of bytes (eof).
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
n = 2 * len(self.DATA)
read_task = asyncio.Task(stream.readexactly(n), loop=self.loop)
@@ -564,8 +622,9 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(b'', stream._buffer)
def test_readexactly_exception(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'line\n')
data = self.loop.run_until_complete(stream.readexactly(2))
@@ -576,8 +635,9 @@ class StreamTests(test_utils.TestCase):
ValueError, self.loop.run_until_complete, stream.readexactly(2))
def test_exception(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
self.assertIsNone(stream.exception())
exc = ValueError()
@@ -585,8 +645,9 @@ class StreamTests(test_utils.TestCase):
self.assertIs(stream.exception(), exc)
def test_exception_waiter(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
async def set_err():
stream.set_exception(ValueError())
@@ -599,8 +660,9 @@ class StreamTests(test_utils.TestCase):
self.assertRaises(ValueError, t1.result)
def test_exception_cancel(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
t = asyncio.Task(stream.readline(), loop=self.loop)
test_utils.run_briefly(self.loop)
@@ -655,8 +717,9 @@ class StreamTests(test_utils.TestCase):
self.server = None
async def client(addr):
- reader, writer = await asyncio.open_connection(
- *addr, loop=self.loop)
+ with self.assertWarns(DeprecationWarning):
+ reader, writer = await asyncio.open_connection(
+ *addr, loop=self.loop)
# send a line
writer.write(b"hello world!\n")
# read it back
@@ -670,7 +733,8 @@ class StreamTests(test_utils.TestCase):
# test the server variant with a coroutine as client handler
server = MyServer(self.loop)
- addr = server.start()
+ with self.assertWarns(DeprecationWarning):
+ addr = server.start()
msg = self.loop.run_until_complete(asyncio.Task(client(addr),
loop=self.loop))
server.stop()
@@ -678,7 +742,8 @@ class StreamTests(test_utils.TestCase):
# test the server variant with a callback as client handler
server = MyServer(self.loop)
- addr = server.start_callback()
+ with self.assertWarns(DeprecationWarning):
+ addr = server.start_callback()
msg = self.loop.run_until_complete(asyncio.Task(client(addr),
loop=self.loop))
server.stop()
@@ -726,8 +791,9 @@ class StreamTests(test_utils.TestCase):
self.server = None
async def client(path):
- reader, writer = await asyncio.open_unix_connection(
- path, loop=self.loop)
+ with self.assertWarns(DeprecationWarning):
+ reader, writer = await asyncio.open_unix_connection(
+ path, loop=self.loop)
# send a line
writer.write(b"hello world!\n")
# read it back
@@ -742,7 +808,8 @@ class StreamTests(test_utils.TestCase):
# test the server variant with a coroutine as client handler
with test_utils.unix_socket_path() as path:
server = MyServer(self.loop, path)
- server.start()
+ with self.assertWarns(DeprecationWarning):
+ server.start()
msg = self.loop.run_until_complete(asyncio.Task(client(path),
loop=self.loop))
server.stop()
@@ -751,7 +818,8 @@ class StreamTests(test_utils.TestCase):
# test the server variant with a callback as client handler
with test_utils.unix_socket_path() as path:
server = MyServer(self.loop, path)
- server.start_callback()
+ with self.assertWarns(DeprecationWarning):
+ server.start_callback()
msg = self.loop.run_until_complete(asyncio.Task(client(path),
loop=self.loop))
server.stop()
@@ -763,7 +831,7 @@ class StreamTests(test_utils.TestCase):
def test_read_all_from_pipe_reader(self):
# See asyncio issue 168. This test is derived from the example
# subprocess_attach_read_pipe.py, but we configure the
- # StreamReader's limit so that twice it is less than the size
+ # Stream's limit so that twice it is less than the size
# of the data writter. Also we must explicitly attach a child
# watcher to the event loop.
@@ -777,10 +845,11 @@ os.close(fd)
args = [sys.executable, '-c', code, str(wfd)]
pipe = open(rfd, 'rb', 0)
- reader = asyncio.StreamReader(loop=self.loop, limit=1,
- _asyncio_internal=True)
- protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop, limit=1,
+ _asyncio_internal=True)
+ protocol = _StreamProtocol(stream, loop=self.loop,
+ _asyncio_internal=True)
transport, _ = self.loop.run_until_complete(
self.loop.connect_read_pipe(lambda: protocol, pipe))
@@ -797,29 +866,30 @@ os.close(fd)
asyncio.set_child_watcher(None)
os.close(wfd)
- data = self.loop.run_until_complete(reader.read(-1))
+ data = self.loop.run_until_complete(stream.read(-1))
self.assertEqual(data, b'data')
def test_streamreader_constructor(self):
self.addCleanup(asyncio.set_event_loop, None)
asyncio.set_event_loop(self.loop)
- # asyncio issue #184: Ensure that StreamReaderProtocol constructor
+ # asyncio issue #184: Ensure that _StreamProtocol constructor
# retrieves the current loop if the loop parameter is not set
- reader = asyncio.StreamReader(_asyncio_internal=True)
+ reader = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ _asyncio_internal=True)
self.assertIs(reader._loop, self.loop)
def test_streamreaderprotocol_constructor(self):
self.addCleanup(asyncio.set_event_loop, None)
asyncio.set_event_loop(self.loop)
- # asyncio issue #184: Ensure that StreamReaderProtocol constructor
+ # asyncio issue #184: Ensure that _StreamProtocol constructor
# retrieves the current loop if the loop parameter is not set
- reader = mock.Mock()
- protocol = asyncio.StreamReaderProtocol(reader, _asyncio_internal=True)
+ stream = mock.Mock()
+ protocol = _StreamProtocol(stream, _asyncio_internal=True)
self.assertIs(protocol._loop, self.loop)
- def test_drain_raises(self):
+ def test_drain_raises_deprecated(self):
# See http://bugs.python.org/issue25441
# This test should not use asyncio for the mock server; the
@@ -833,15 +903,16 @@ os.close(fd)
def server():
# Runs in a separate thread.
- with socket.create_server(('localhost', 0)) as sock:
+ with socket.create_server(('127.0.0.1', 0)) as sock:
addr = sock.getsockname()
q.put(addr)
clt, _ = sock.accept()
clt.close()
async def client(host, port):
- reader, writer = await asyncio.open_connection(
- host, port, loop=self.loop)
+ with self.assertWarns(DeprecationWarning):
+ reader, writer = await asyncio.open_connection(
+ host, port, loop=self.loop)
while True:
writer.write(b"foo\n")
@@ -863,55 +934,106 @@ os.close(fd)
thread.join()
self.assertEqual([], messages)
+ def test_drain_raises(self):
+ # See http://bugs.python.org/issue25441
+
+ # This test should not use asyncio for the mock server; the
+ # whole point of the test is to test for a bug in drain()
+ # where it never gives up the event loop but the socket is
+ # closed on the server side.
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+ q = queue.Queue()
+
+ def server():
+ # Runs in a separate thread.
+ with socket.create_server(('localhost', 0)) as sock:
+ addr = sock.getsockname()
+ q.put(addr)
+ clt, _ = sock.accept()
+ clt.close()
+
+ async def client(host, port):
+ stream = await asyncio.connect(host, port)
+
+ while True:
+ stream.write(b"foo\n")
+ await stream.drain()
+
+ # Start the server thread and wait for it to be listening.
+ thread = threading.Thread(target=server)
+ thread.setDaemon(True)
+ thread.start()
+ addr = q.get()
+
+ # Should not be stuck in an infinite loop.
+ with self.assertRaises((ConnectionResetError, ConnectionAbortedError,
+ BrokenPipeError)):
+ self.loop.run_until_complete(client(*addr))
+
+ # Clean up the thread. (Only on success; on failure, it may
+ # be stuck in accept().)
+ thread.join()
+ self.assertEqual([], messages)
+
def test___repr__(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
- self.assertEqual("<StreamReader>", repr(stream))
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
+ self.assertEqual("<Stream mode=StreamMode.READ>", repr(stream))
def test___repr__nondefault_limit(self):
- stream = asyncio.StreamReader(loop=self.loop, limit=123,
- _asyncio_internal=True)
- self.assertEqual("<StreamReader limit=123>", repr(stream))
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop, limit=123,
+ _asyncio_internal=True)
+ self.assertEqual("<Stream mode=StreamMode.READ limit=123>", repr(stream))
def test___repr__eof(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_eof()
- self.assertEqual("<StreamReader eof>", repr(stream))
+ self.assertEqual("<Stream mode=StreamMode.READ eof>", repr(stream))
def test___repr__data(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream.feed_data(b'data')
- self.assertEqual("<StreamReader 4 bytes>", repr(stream))
+ self.assertEqual("<Stream mode=StreamMode.READ 4 bytes>", repr(stream))
def test___repr__exception(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
exc = RuntimeError()
stream.set_exception(exc)
- self.assertEqual("<StreamReader exception=RuntimeError()>",
+ self.assertEqual("<Stream mode=StreamMode.READ exception=RuntimeError()>",
repr(stream))
def test___repr__waiter(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream._waiter = asyncio.Future(loop=self.loop)
self.assertRegex(
repr(stream),
- r"<StreamReader waiter=<Future pending[\S ]*>>")
+ r"<Stream .+ waiter=<Future pending[\S ]*>>")
stream._waiter.set_result(None)
self.loop.run_until_complete(stream._waiter)
stream._waiter = None
- self.assertEqual("<StreamReader>", repr(stream))
+ self.assertEqual("<Stream mode=StreamMode.READ>", repr(stream))
def test___repr__transport(self):
- stream = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
stream._transport = mock.Mock()
stream._transport.__repr__ = mock.Mock()
stream._transport.__repr__.return_value = "<Transport>"
- self.assertEqual("<StreamReader transport=<Transport>>", repr(stream))
+ self.assertEqual("<Stream mode=StreamMode.READ transport=<Transport>>",
+ repr(stream))
def test_IncompleteReadError_pickleable(self):
e = asyncio.IncompleteReadError(b'abc', 10)
@@ -930,10 +1052,11 @@ os.close(fd)
self.assertEqual(str(e), str(e2))
self.assertEqual(e.consumed, e2.consumed)
- def test_wait_closed_on_close(self):
+ def test_wait_closed_on_close_deprecated(self):
with test_utils.run_test_server() as httpd:
- rd, wr = self.loop.run_until_complete(
- asyncio.open_connection(*httpd.address, loop=self.loop))
+ with self.assertWarns(DeprecationWarning):
+ rd, wr = self.loop.run_until_complete(
+ asyncio.open_connection(*httpd.address, loop=self.loop))
wr.write(b'GET / HTTP/1.0\r\n\r\n')
f = rd.readline()
@@ -947,10 +1070,28 @@ os.close(fd)
self.assertTrue(wr.is_closing())
self.loop.run_until_complete(wr.wait_closed())
- def test_wait_closed_on_close_with_unread_data(self):
+ def test_wait_closed_on_close(self):
with test_utils.run_test_server() as httpd:
- rd, wr = self.loop.run_until_complete(
- asyncio.open_connection(*httpd.address, loop=self.loop))
+ stream = self.loop.run_until_complete(
+ asyncio.connect(*httpd.address))
+
+ stream.write(b'GET / HTTP/1.0\r\n\r\n')
+ f = stream.readline()
+ data = self.loop.run_until_complete(f)
+ self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+ f = stream.read()
+ data = self.loop.run_until_complete(f)
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+ self.assertFalse(stream.is_closing())
+ stream.close()
+ self.assertTrue(stream.is_closing())
+ self.loop.run_until_complete(stream.wait_closed())
+
+ def test_wait_closed_on_close_with_unread_data_deprecated(self):
+ with test_utils.run_test_server() as httpd:
+ with self.assertWarns(DeprecationWarning):
+ rd, wr = self.loop.run_until_complete(
+ asyncio.open_connection(*httpd.address, loop=self.loop))
wr.write(b'GET / HTTP/1.0\r\n\r\n')
f = rd.readline()
@@ -959,32 +1100,44 @@ os.close(fd)
wr.close()
self.loop.run_until_complete(wr.wait_closed())
+ def test_wait_closed_on_close_with_unread_data(self):
+ with test_utils.run_test_server() as httpd:
+ stream = self.loop.run_until_complete(
+ asyncio.connect(*httpd.address))
+
+ stream.write(b'GET / HTTP/1.0\r\n\r\n')
+ f = stream.readline()
+ data = self.loop.run_until_complete(f)
+ self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+ stream.close()
+ self.loop.run_until_complete(stream.wait_closed())
+
def test_del_stream_before_sock_closing(self):
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
- with test_utils.run_test_server() as httpd:
- rd, wr = self.loop.run_until_complete(
- asyncio.open_connection(*httpd.address, loop=self.loop))
- sock = wr.get_extra_info('socket')
- self.assertNotEqual(sock.fileno(), -1)
+ async def test():
- wr.write(b'GET / HTTP/1.0\r\n\r\n')
- f = rd.readline()
- data = self.loop.run_until_complete(f)
- self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+ with test_utils.run_test_server() as httpd:
+ stream = await asyncio.connect(*httpd.address)
+ sock = stream.get_extra_info('socket')
+ self.assertNotEqual(sock.fileno(), -1)
- # drop refs to reader/writer
- del rd
- del wr
- gc.collect()
- # make a chance to close the socket
- test_utils.run_briefly(self.loop)
+ await stream.write(b'GET / HTTP/1.0\r\n\r\n')
+ data = await stream.readline()
+ self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
- self.assertEqual(1, len(messages))
- self.assertEqual(sock.fileno(), -1)
+ # drop refs to reader/writer
+ del stream
+ gc.collect()
+ # make a chance to close the socket
+ await asyncio.sleep(0)
- self.assertEqual(1, len(messages))
+ self.assertEqual(1, len(messages), messages)
+ self.assertEqual(sock.fileno(), -1)
+
+ self.loop.run_until_complete(test())
+ self.assertEqual(1, len(messages), messages)
self.assertEqual('An open stream object is being garbage '
'collected; call "stream.close()" explicitly.',
messages[0]['message'])
@@ -994,11 +1147,12 @@ os.close(fd)
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
with test_utils.run_test_server() as httpd:
- rd = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
- pr = asyncio.StreamReaderProtocol(rd, loop=self.loop,
- _asyncio_internal=True)
- del rd
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop,
+ _asyncio_internal=True)
+ pr = _StreamProtocol(stream, loop=self.loop,
+ _asyncio_internal=True)
+ del stream
gc.collect()
tr, _ = self.loop.run_until_complete(
self.loop.create_connection(
@@ -1015,14 +1169,14 @@ os.close(fd)
def test_async_writer_api(self):
async def inner(httpd):
- rd, wr = await asyncio.open_connection(*httpd.address)
+ stream = await asyncio.connect(*httpd.address)
- await wr.write(b'GET / HTTP/1.0\r\n\r\n')
- data = await rd.readline()
+ await stream.write(b'GET / HTTP/1.0\r\n\r\n')
+ data = await stream.readline()
self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
- data = await rd.read()
+ data = await stream.read()
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
- await wr.close()
+ await stream.close()
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
@@ -1032,18 +1186,18 @@ os.close(fd)
self.assertEqual(messages, [])
- def test_async_writer_api(self):
+ def test_async_writer_api_exception_after_close(self):
async def inner(httpd):
- rd, wr = await asyncio.open_connection(*httpd.address)
+ stream = await asyncio.connect(*httpd.address)
- await wr.write(b'GET / HTTP/1.0\r\n\r\n')
- data = await rd.readline()
+ await stream.write(b'GET / HTTP/1.0\r\n\r\n')
+ data = await stream.readline()
self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
- data = await rd.read()
+ data = await stream.read()
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
- wr.close()
+ stream.close()
with self.assertRaises(ConnectionResetError):
- await wr.write(b'data')
+ await stream.write(b'data')
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
@@ -1059,11 +1213,13 @@ os.close(fd)
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
with test_utils.run_test_server() as httpd:
- rd, wr = self.loop.run_until_complete(
- asyncio.open_connection(*httpd.address,
- loop=self.loop))
+ with self.assertWarns(DeprecationWarning):
+ rd, wr = self.loop.run_until_complete(
+ asyncio.open_connection(*httpd.address,
+ loop=self.loop))
- f = wr.close()
+ wr.close()
+ f = wr.wait_closed()
self.loop.run_until_complete(f)
assert rd.at_eof()
f = rd.read()
@@ -1074,22 +1230,514 @@ os.close(fd)
def test_stream_reader_create_warning(self):
with self.assertWarns(DeprecationWarning):
- asyncio.StreamReader(loop=self.loop)
-
- def test_stream_reader_protocol_create_warning(self):
- reader = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
- with self.assertWarns(DeprecationWarning):
- asyncio.StreamReaderProtocol(reader, loop=self.loop)
+ asyncio.StreamReader
def test_stream_writer_create_warning(self):
- reader = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
- proto = asyncio.StreamReaderProtocol(reader, loop=self.loop,
- _asyncio_internal=True)
with self.assertWarns(DeprecationWarning):
- asyncio.StreamWriter('transport', proto, reader, self.loop)
+ asyncio.StreamWriter
+
+ def test_stream_reader_forbidden_ops(self):
+ async def inner():
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ _asyncio_internal=True)
+ with self.assertRaisesRegex(RuntimeError, "The stream is read-only"):
+ await stream.write(b'data')
+ with self.assertRaisesRegex(RuntimeError, "The stream is read-only"):
+ await stream.writelines([b'data', b'other'])
+ with self.assertRaisesRegex(RuntimeError, "The stream is read-only"):
+ stream.write_eof()
+ with self.assertRaisesRegex(RuntimeError, "The stream is read-only"):
+ await stream.drain()
+
+ self.loop.run_until_complete(inner())
+
+ def test_stream_writer_forbidden_ops(self):
+ async def inner():
+ stream = asyncio.Stream(mode=asyncio.StreamMode.WRITE,
+ _asyncio_internal=True)
+ with self.assertRaisesRegex(RuntimeError, "The stream is write-only"):
+ stream.feed_data(b'data')
+ with self.assertRaisesRegex(RuntimeError, "The stream is write-only"):
+ await stream.readline()
+ with self.assertRaisesRegex(RuntimeError, "The stream is write-only"):
+ await stream.readuntil()
+ with self.assertRaisesRegex(RuntimeError, "The stream is write-only"):
+ await stream.read()
+ with self.assertRaisesRegex(RuntimeError, "The stream is write-only"):
+ await stream.readexactly(10)
+ with self.assertRaisesRegex(RuntimeError, "The stream is write-only"):
+ async for chunk in stream:
+ pass
+
+ self.loop.run_until_complete(inner())
+
+ def _basetest_connect(self, stream):
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+ stream.write(b'GET / HTTP/1.0\r\n\r\n')
+ f = stream.readline()
+ data = self.loop.run_until_complete(f)
+ self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+ f = stream.read()
+ data = self.loop.run_until_complete(f)
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+ stream.close()
+ self.loop.run_until_complete(stream.wait_closed())
+
+ self.assertEqual([], messages)
+
+ def test_connect(self):
+ with test_utils.run_test_server() as httpd:
+ stream = self.loop.run_until_complete(
+ asyncio.connect(*httpd.address))
+ self.assertFalse(stream.is_server_side())
+ self._basetest_connect(stream)
+
+ @support.skip_unless_bind_unix_socket
+ def test_connect_unix(self):
+ with test_utils.run_test_unix_server() as httpd:
+ stream = self.loop.run_until_complete(
+ asyncio.connect_unix(httpd.address))
+ self._basetest_connect(stream)
+
+ def test_stream_async_context_manager(self):
+ async def test(httpd):
+ stream = await asyncio.connect(*httpd.address)
+ async with stream:
+ await stream.write(b'GET / HTTP/1.0\r\n\r\n')
+ data = await stream.readline()
+ self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+ data = await stream.read()
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+ self.assertTrue(stream.is_closing())
+
+ with test_utils.run_test_server() as httpd:
+ self.loop.run_until_complete(test(httpd))
+
+ def test_connect_async_context_manager(self):
+ async def test(httpd):
+ async with asyncio.connect(*httpd.address) as stream:
+ await stream.write(b'GET / HTTP/1.0\r\n\r\n')
+ data = await stream.readline()
+ self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+ data = await stream.read()
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+ self.assertTrue(stream.is_closing())
+
+ with test_utils.run_test_server() as httpd:
+ self.loop.run_until_complete(test(httpd))
+
+ @support.skip_unless_bind_unix_socket
+ def test_connect_unix_async_context_manager(self):
+ async def test(httpd):
+ async with asyncio.connect_unix(httpd.address) as stream:
+ await stream.write(b'GET / HTTP/1.0\r\n\r\n')
+ data = await stream.readline()
+ self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+ data = await stream.read()
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+ self.assertTrue(stream.is_closing())
+
+ with test_utils.run_test_unix_server() as httpd:
+ self.loop.run_until_complete(test(httpd))
+
+ def test_stream_server(self):
+
+ async def handle_client(stream):
+ self.assertTrue(stream.is_server_side())
+ data = await stream.readline()
+ await stream.write(data)
+ await stream.close()
+
+ async def client(srv):
+ addr = srv.sockets[0].getsockname()
+ stream = await asyncio.connect(*addr)
+ # send a line
+ await stream.write(b"hello world!\n")
+ # read it back
+ msgback = await stream.readline()
+ await stream.close()
+ self.assertEqual(msgback, b"hello world!\n")
+ await srv.close()
+
+ async def test():
+ async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server:
+ await server.start_serving()
+ task = asyncio.create_task(client(server))
+ with contextlib.suppress(asyncio.CancelledError):
+ await server.serve_forever()
+ await task
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+ self.loop.run_until_complete(test())
+ self.assertEqual(messages, [])
+
+ @support.skip_unless_bind_unix_socket
+ def test_unix_stream_server(self):
+
+ async def handle_client(stream):
+ data = await stream.readline()
+ await stream.write(data)
+ await stream.close()
+
+ async def client(srv):
+ addr = srv.sockets[0].getsockname()
+ stream = await asyncio.connect_unix(addr)
+ # send a line
+ await stream.write(b"hello world!\n")
+ # read it back
+ msgback = await stream.readline()
+ await stream.close()
+ self.assertEqual(msgback, b"hello world!\n")
+ await srv.close()
+
+ async def test():
+ with test_utils.unix_socket_path() as path:
+ async with asyncio.UnixStreamServer(handle_client, path) as server:
+ await server.start_serving()
+ task = asyncio.create_task(client(server))
+ with contextlib.suppress(asyncio.CancelledError):
+ await server.serve_forever()
+ await task
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+ self.loop.run_until_complete(test())
+ self.assertEqual(messages, [])
+
+ def test_stream_server_inheritance_forbidden(self):
+ with self.assertRaises(TypeError):
+ class MyServer(asyncio.StreamServer):
+ pass
+
+ @support.skip_unless_bind_unix_socket
+ def test_unix_stream_server_inheritance_forbidden(self):
+ with self.assertRaises(TypeError):
+ class MyServer(asyncio.UnixStreamServer):
+ pass
+
+ def test_stream_server_bind(self):
+ async def handle_client(stream):
+ await stream.close()
+
+ async def test():
+ srv = asyncio.StreamServer(handle_client, '127.0.0.1', 0)
+ self.assertFalse(srv.is_bound())
+ self.assertEqual(0, len(srv.sockets))
+ await srv.bind()
+ self.assertTrue(srv.is_bound())
+ self.assertEqual(1, len(srv.sockets))
+ await srv.close()
+ self.assertFalse(srv.is_bound())
+ self.assertEqual(0, len(srv.sockets))
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+ self.loop.run_until_complete(test())
+ self.assertEqual(messages, [])
+
+ def test_stream_server_bind_async_with(self):
+ async def handle_client(stream):
+ await stream.close()
+
+ async def test():
+ async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as srv:
+ self.assertTrue(srv.is_bound())
+ self.assertEqual(1, len(srv.sockets))
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+ self.loop.run_until_complete(test())
+ self.assertEqual(messages, [])
+
+ def test_stream_server_start_serving(self):
+ async def handle_client(stream):
+ await stream.close()
+
+ async def test():
+ async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as srv:
+ self.assertFalse(srv.is_serving())
+ await srv.start_serving()
+ self.assertTrue(srv.is_serving())
+ await srv.close()
+ self.assertFalse(srv.is_serving())
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+ self.loop.run_until_complete(test())
+ self.assertEqual(messages, [])
+
+ def test_stream_server_close(self):
+ server_stream_aborted = False
+ fut = self.loop.create_future()
+
+ async def handle_client(stream):
+ await fut
+ self.assertEqual(b'', await stream.readline())
+ nonlocal server_stream_aborted
+ server_stream_aborted = True
+
+ async def client(srv):
+ addr = srv.sockets[0].getsockname()
+ stream = await asyncio.connect(*addr)
+ fut.set_result(None)
+ self.assertEqual(b'', await stream.readline())
+ await stream.close()
+
+ async def test():
+ async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server:
+ await server.start_serving()
+ task = asyncio.create_task(client(server))
+ await fut
+ await server.close()
+ await task
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+ self.loop.run_until_complete(test())
+ self.assertEqual(messages, [])
+ self.assertTrue(fut.done())
+ self.assertTrue(server_stream_aborted)
+
+ def test_stream_server_abort(self):
+ server_stream_aborted = False
+ fut = self.loop.create_future()
+
+ async def handle_client(stream):
+ await fut
+ self.assertEqual(b'', await stream.readline())
+ nonlocal server_stream_aborted
+ server_stream_aborted = True
+
+ async def client(srv):
+ addr = srv.sockets[0].getsockname()
+ stream = await asyncio.connect(*addr)
+ fut.set_result(None)
+ self.assertEqual(b'', await stream.readline())
+ await stream.close()
+
+ async def test():
+ async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server:
+ await server.start_serving()
+ task = asyncio.create_task(client(server))
+ await fut
+ await server.abort()
+ await task
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+ self.loop.run_until_complete(test())
+ self.assertEqual(messages, [])
+ self.assertTrue(fut.done())
+ self.assertTrue(server_stream_aborted)
+
+ def test_stream_shutdown_hung_task(self):
+ fut1 = self.loop.create_future()
+ fut2 = self.loop.create_future()
+
+ async def handle_client(stream):
+ while True:
+ await asyncio.sleep(0.01)
+
+ async def client(srv):
+ addr = srv.sockets[0].getsockname()
+ stream = await asyncio.connect(*addr)
+ fut1.set_result(None)
+ await fut2
+ self.assertEqual(b'', await stream.readline())
+ await stream.close()
+
+ async def test():
+ async with asyncio.StreamServer(handle_client,
+ '127.0.0.1',
+ 0,
+ shutdown_timeout=0.3) as server:
+ await server.start_serving()
+ task = asyncio.create_task(client(server))
+ await fut1
+ await server.close()
+ fut2.set_result(None)
+ await task
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+ self.loop.run_until_complete(test())
+ self.assertEqual(messages, [])
+ self.assertTrue(fut1.done())
+ self.assertTrue(fut2.done())
+
+ def test_stream_shutdown_hung_task_prevents_cancellation(self):
+ fut1 = self.loop.create_future()
+ fut2 = self.loop.create_future()
+ do_handle_client = True
+
+ async def handle_client(stream):
+ while do_handle_client:
+ with contextlib.suppress(asyncio.CancelledError):
+ await asyncio.sleep(0.01)
+
+ async def client(srv):
+ addr = srv.sockets[0].getsockname()
+ stream = await asyncio.connect(*addr)
+ fut1.set_result(None)
+ await fut2
+ self.assertEqual(b'', await stream.readline())
+ await stream.close()
+
+ async def test():
+ async with asyncio.StreamServer(handle_client,
+ '127.0.0.1',
+ 0,
+ shutdown_timeout=0.3) as server:
+ await server.start_serving()
+ task = asyncio.create_task(client(server))
+ await fut1
+ await server.close()
+ nonlocal do_handle_client
+ do_handle_client = False
+ fut2.set_result(None)
+ await task
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+ self.loop.run_until_complete(test())
+ self.assertEqual(1, len(messages))
+ self.assertRegex(messages[0]['message'],
+ "<Task pending .+ ignored cancellation request")
+ self.assertTrue(fut1.done())
+ self.assertTrue(fut2.done())
+
+ def test_sendfile(self):
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+ with open(support.TESTFN, 'wb') as fp:
+ fp.write(b'data\n')
+ self.addCleanup(support.unlink, support.TESTFN)
+
+ async def serve_callback(stream):
+ data = await stream.readline()
+ self.assertEqual(data, b'begin\n')
+ data = await stream.readline()
+ self.assertEqual(data, b'data\n')
+ data = await stream.readline()
+ self.assertEqual(data, b'end\n')
+ await stream.write(b'done\n')
+ await stream.close()
+
+ async def do_connect(host, port):
+ stream = await asyncio.connect(host, port)
+ await stream.write(b'begin\n')
+ with open(support.TESTFN, 'rb') as fp:
+ await stream.sendfile(fp)
+ await stream.write(b'end\n')
+ data = await stream.readline()
+ self.assertEqual(data, b'done\n')
+ await stream.close()
+
+ async def test():
+ async with asyncio.StreamServer(serve_callback, '127.0.0.1', 0) as srv:
+ await srv.start_serving()
+ await do_connect(*srv.sockets[0].getsockname())
+
+ self.loop.run_until_complete(test())
+
+ self.assertEqual([], messages)
+
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_connect_start_tls(self):
+ with test_utils.run_test_server(use_ssl=True) as httpd:
+ # connect without SSL but upgrade to TLS just after
+ # connection is established
+ stream = self.loop.run_until_complete(
+ asyncio.connect(*httpd.address))
+
+ self.loop.run_until_complete(
+ stream.start_tls(
+ sslcontext=test_utils.dummy_ssl_context()))
+ self._basetest_connect(stream)
+
+ def test_repr_unbound(self):
+ async def serve(stream):
+ pass
+
+ async def test():
+ srv = asyncio.StreamServer(serve)
+ self.assertEqual('<StreamServer>', repr(srv))
+ await srv.close()
+
+ self.loop.run_until_complete(test())
+
+ def test_repr_bound(self):
+ async def serve(stream):
+ pass
+
+ async def test():
+ srv = asyncio.StreamServer(serve, '127.0.0.1', 0)
+ await srv.bind()
+ self.assertRegex(repr(srv), r'<StreamServer sockets=\(.+\)>')
+ await srv.close()
+
+ self.loop.run_until_complete(test())
+
+ def test_repr_serving(self):
+ async def serve(stream):
+ pass
+
+ async def test():
+ srv = asyncio.StreamServer(serve, '127.0.0.1', 0)
+ await srv.start_serving()
+ self.assertRegex(repr(srv), r'<StreamServer serving sockets=\(.+\)>')
+ await srv.close()
+
+ self.loop.run_until_complete(test())
+
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ def test_read_pipe(self):
+ async def test():
+ rpipe, wpipe = os.pipe()
+ pipeobj = io.open(rpipe, 'rb', 1024)
+
+ async with asyncio.connect_read_pipe(pipeobj) as stream:
+ self.assertEqual(stream.mode, asyncio.StreamMode.READ)
+
+ os.write(wpipe, b'1')
+ data = await stream.readexactly(1)
+ self.assertEqual(data, b'1')
+
+ os.write(wpipe, b'2345')
+ data = await stream.readexactly(4)
+ self.assertEqual(data, b'2345')
+ os.close(wpipe)
+
+ self.loop.run_until_complete(test())
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ def test_write_pipe(self):
+ async def test():
+ rpipe, wpipe = os.pipe()
+ pipeobj = io.open(wpipe, 'wb', 1024)
+
+ async with asyncio.connect_write_pipe(pipeobj) as stream:
+ self.assertEqual(stream.mode, asyncio.StreamMode.WRITE)
+
+ await stream.write(b'1')
+ data = os.read(rpipe, 1024)
+ self.assertEqual(data, b'1')
+
+ await stream.write(b'2345')
+ data = os.read(rpipe, 1024)
+ self.assertEqual(data, b'2345')
+
+ os.close(rpipe)
+ self.loop.run_until_complete(test())
if __name__ == '__main__':
diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py
index e201a06..13aef7c 100644
--- a/Lib/test/test_asyncio/test_windows_events.py
+++ b/Lib/test/test_asyncio/test_windows_events.py
@@ -17,6 +17,7 @@ import _winapi
import asyncio
from asyncio import windows_events
+from asyncio.streams import _StreamProtocol
from test.test_asyncio import utils as test_utils
from test.support.script_helper import spawn_python
@@ -100,16 +101,16 @@ class ProactorTests(test_utils.TestCase):
clients = []
for i in range(5):
- stream_reader = asyncio.StreamReader(loop=self.loop,
- _asyncio_internal=True)
- protocol = asyncio.StreamReaderProtocol(stream_reader,
- loop=self.loop,
- _asyncio_internal=True)
+ stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+ loop=self.loop, _asyncio_internal=True)
+ protocol = _StreamProtocol(stream,
+ loop=self.loop,
+ _asyncio_internal=True)
trans, proto = await self.loop.create_pipe_connection(
lambda: protocol, ADDRESS)
self.assertIsInstance(trans, asyncio.Transport)
self.assertEqual(protocol, proto)
- clients.append((stream_reader, trans))
+ clients.append((stream, trans))
for i, (r, w) in enumerate(clients):
w.write('lower-{}\n'.format(i).encode())
@@ -118,6 +119,7 @@ class ProactorTests(test_utils.TestCase):
response = await r.readline()
self.assertEqual(response, 'LOWER-{}\n'.format(i).encode())
w.close()
+ await r.close()
server.close()