diff options
author | Andrew Svetlov <andrew.svetlov@gmail.com> | 2019-05-27 19:56:22 (GMT) |
---|---|---|
committer | Miss Islington (bot) <31488909+miss-islington@users.noreply.github.com> | 2019-05-27 19:56:22 (GMT) |
commit | 23b4b697e5b6cc897696f9c0288c187d2d24bff2 (patch) | |
tree | 2f70e14fe527878cd69ccbefca007a1e987943ed /Lib | |
parent | 6f6ff8a56518a80da406aad6ac8364c046cc7f18 (diff) | |
download | cpython-23b4b697e5b6cc897696f9c0288c187d2d24bff2.zip cpython-23b4b697e5b6cc897696f9c0288c187d2d24bff2.tar.gz cpython-23b4b697e5b6cc897696f9c0288c187d2d24bff2.tar.bz2 |
bpo-36889: Merge asyncio streams (GH-13251)
https://bugs.python.org/issue36889
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/asyncio/__init__.py | 38 | ||||
-rw-r--r-- | Lib/asyncio/streams.py | 1236 | ||||
-rw-r--r-- | Lib/asyncio/subprocess.py | 35 | ||||
-rw-r--r-- | Lib/asyncio/windows_events.py | 2 | ||||
-rw-r--r-- | Lib/test/test___all__.py | 36 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_base_events.py | 5 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_buffered_proto.py | 7 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_pep492.py | 4 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_server.py | 10 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_sslproto.py | 37 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_streams.py | 1008 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_windows_events.py | 14 |
12 files changed, 2049 insertions, 383 deletions
diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py index 28c2e2c..a6a29db 100644 --- a/Lib/asyncio/__init__.py +++ b/Lib/asyncio/__init__.py @@ -3,6 +3,7 @@ # flake8: noqa import sys +import warnings # This relies on each of the submodules having an __all__ variable. from .base_events import * @@ -43,3 +44,40 @@ if sys.platform == 'win32': # pragma: no cover else: from .unix_events import * # pragma: no cover __all__ += unix_events.__all__ + + +__all__ += ('StreamReader', 'StreamWriter', 'StreamReaderProtocol') # deprecated + + +def __getattr__(name): + global StreamReader, StreamWriter, StreamReaderProtocol + if name == 'StreamReader': + warnings.warn("StreamReader is deprecated since Python 3.8 " + "in favor of Stream, and scheduled for removal " + "in Python 3.10", + DeprecationWarning, + stacklevel=2) + from .streams import StreamReader as sr + StreamReader = sr + return StreamReader + if name == 'StreamWriter': + warnings.warn("StreamWriter is deprecated since Python 3.8 " + "in favor of Stream, and scheduled for removal " + "in Python 3.10", + DeprecationWarning, + stacklevel=2) + from .streams import StreamWriter as sw + StreamWriter = sw + return StreamWriter + if name == 'StreamReaderProtocol': + warnings.warn("Using asyncio internal class StreamReaderProtocol " + "is deprecated since Python 3.8 " + " and scheduled for removal " + "in Python 3.10", + DeprecationWarning, + stacklevel=2) + from .streams import StreamReaderProtocol as srp + StreamReaderProtocol = srp + return StreamReaderProtocol + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 2f0cbfd..480f1a3 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1,14 +1,19 @@ __all__ = ( - 'StreamReader', 'StreamWriter', 'StreamReaderProtocol', - 'open_connection', 'start_server') + 'Stream', 'StreamMode', + 'open_connection', 'start_server', + 'connect', 'connect_read_pipe', 'connect_write_pipe', + 'StreamServer') +import enum import socket import sys import warnings import weakref if hasattr(socket, 'AF_UNIX'): - __all__ += ('open_unix_connection', 'start_unix_server') + __all__ += ('open_unix_connection', 'start_unix_server', + 'connect_unix', + 'UnixStreamServer') from . import coroutines from . import events @@ -16,12 +21,134 @@ from . import exceptions from . import format_helpers from . import protocols from .log import logger -from .tasks import sleep +from . import tasks _DEFAULT_LIMIT = 2 ** 16 # 64 KiB +class StreamMode(enum.Flag): + READ = enum.auto() + WRITE = enum.auto() + READWRITE = READ | WRITE + + +def _ensure_can_read(mode): + if not mode & StreamMode.READ: + raise RuntimeError("The stream is write-only") + + +def _ensure_can_write(mode): + if not mode & StreamMode.WRITE: + raise RuntimeError("The stream is read-only") + + +class _ContextManagerHelper: + __slots__ = ('_awaitable', '_result') + + def __init__(self, awaitable): + self._awaitable = awaitable + self._result = None + + def __await__(self): + return self._awaitable.__await__() + + async def __aenter__(self): + ret = await self._awaitable + result = await ret.__aenter__() + self._result = result + return result + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return await self._result.__aexit__(exc_type, exc_val, exc_tb) + + +def connect(host=None, port=None, *, + limit=_DEFAULT_LIMIT, + ssl=None, family=0, proto=0, + flags=0, sock=None, local_addr=None, + server_hostname=None, + ssl_handshake_timeout=None, + happy_eyeballs_delay=None, interleave=None): + # Design note: + # Don't use decorator approach but exilicit non-async + # function to fail fast and explicitly + # if passed arguments don't match the function signature + return _ContextManagerHelper(_connect(host, port, limit, + ssl, family, proto, + flags, sock, local_addr, + server_hostname, + ssl_handshake_timeout, + happy_eyeballs_delay, + interleave)) + + +async def _connect(host, port, + limit, + ssl, family, proto, + flags, sock, local_addr, + server_hostname, + ssl_handshake_timeout, + happy_eyeballs_delay, interleave): + loop = events.get_running_loop() + stream = Stream(mode=StreamMode.READWRITE, + limit=limit, + loop=loop, + _asyncio_internal=True) + await loop.create_connection( + lambda: _StreamProtocol(stream, loop=loop, + _asyncio_internal=True), + host, port, + ssl=ssl, family=family, proto=proto, + flags=flags, sock=sock, local_addr=local_addr, + server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + happy_eyeballs_delay=happy_eyeballs_delay, interleave=interleave) + return stream + + +def connect_read_pipe(pipe, *, limit=_DEFAULT_LIMIT): + # Design note: + # Don't use decorator approach but explicit non-async + # function to fail fast and explicitly + # if passed arguments don't match the function signature + return _ContextManagerHelper(_connect_read_pipe(pipe, limit)) + + +async def _connect_read_pipe(pipe, limit): + loop = events.get_running_loop() + stream = Stream(mode=StreamMode.READ, + limit=limit, + loop=loop, + _asyncio_internal=True) + await loop.connect_read_pipe( + lambda: _StreamProtocol(stream, loop=loop, + _asyncio_internal=True), + pipe) + return stream + + +def connect_write_pipe(pipe, *, limit=_DEFAULT_LIMIT): + # Design note: + # Don't use decorator approach but explicit non-async + # function to fail fast and explicitly + # if passed arguments don't match the function signature + return _ContextManagerHelper(_connect_write_pipe(pipe, limit)) + + +async def _connect_write_pipe(pipe, limit): + loop = events.get_running_loop() + stream = Stream(mode=StreamMode.WRITE, + limit=limit, + loop=loop, + _asyncio_internal=True) + await loop.connect_write_pipe( + lambda: _StreamProtocol(stream, loop=loop, + _asyncio_internal=True), + pipe) + return stream + + async def open_connection(host=None, port=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """A wrapper for create_connection() returning a (reader, writer) pair. @@ -41,16 +168,18 @@ async def open_connection(host=None, port=None, *, StreamReaderProtocol classes, just copy the code -- there's really nothing special here except some convenience.) """ + warnings.warn("open_connection() is deprecated since Python 3.8 " + "in favor of connect(), and scheduled for removal " + "in Python 3.10", + DeprecationWarning, + stacklevel=2) if loop is None: loop = events.get_event_loop() - reader = StreamReader(limit=limit, loop=loop, - _asyncio_internal=True) - protocol = StreamReaderProtocol(reader, loop=loop, - _asyncio_internal=True) + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, loop=loop, _asyncio_internal=True) transport, _ = await loop.create_connection( lambda: protocol, host, port, **kwds) - writer = StreamWriter(transport, protocol, reader, loop, - _asyncio_internal=True) + writer = StreamWriter(transport, protocol, reader, loop) return reader, writer @@ -77,12 +206,16 @@ async def start_server(client_connected_cb, host=None, port=None, *, The return value is the same as loop.create_server(), i.e. a Server object which can be used to stop the service. """ + warnings.warn("start_server() is deprecated since Python 3.8 " + "in favor of StreamServer(), and scheduled for removal " + "in Python 3.10", + DeprecationWarning, + stacklevel=2) if loop is None: loop = events.get_event_loop() def factory(): - reader = StreamReader(limit=limit, loop=loop, - _asyncio_internal=True) + reader = StreamReader(limit=limit, loop=loop) protocol = StreamReaderProtocol(reader, client_connected_cb, loop=loop, _asyncio_internal=True) @@ -91,33 +224,258 @@ async def start_server(client_connected_cb, host=None, port=None, *, return await loop.create_server(factory, host, port, **kwds) +class _BaseStreamServer: + # Design notes. + # StreamServer and UnixStreamServer are exposed as FINAL classes, + # not function factories. + # async with serve(host, port) as server: + # server.start_serving() + # looks ugly. + # The class doesn't provide API for enumerating connected streams + # It can be a subject for improvements in Python 3.9 + + _server_impl = None + + def __init__(self, client_connected_cb, + /, + limit=_DEFAULT_LIMIT, + shutdown_timeout=60, + _asyncio_internal=False): + if not _asyncio_internal: + raise RuntimeError("_ServerStream is a private asyncio class") + self._client_connected_cb = client_connected_cb + self._limit = limit + self._loop = events.get_running_loop() + self._streams = {} + self._shutdown_timeout = shutdown_timeout + + def __init_subclass__(cls): + if not cls.__module__.startswith('asyncio.'): + raise TypeError(f"asyncio.{cls.__name__} " + "class cannot be inherited from") + + async def bind(self): + if self._server_impl is not None: + return + self._server_impl = await self._bind() + + def is_bound(self): + return self._server_impl is not None + + @property + def sockets(self): + # multiple value for socket bound to both IPv4 and IPv6 families + if self._server_impl is None: + return () + return self._server_impl.sockets + + def is_serving(self): + if self._server_impl is None: + return False + return self._server_impl.is_serving() + + async def start_serving(self): + await self.bind() + await self._server_impl.start_serving() + + async def serve_forever(self): + await self.start_serving() + await self._server_impl.serve_forever() + + async def close(self): + if self._server_impl is None: + return + self._server_impl.close() + streams = list(self._streams.keys()) + active_tasks = list(self._streams.values()) + if streams: + await tasks.wait([stream.close() for stream in streams]) + await self._server_impl.wait_closed() + self._server_impl = None + await self._shutdown_active_tasks(active_tasks) + + async def abort(self): + if self._server_impl is None: + return + self._server_impl.close() + streams = list(self._streams.keys()) + active_tasks = list(self._streams.values()) + if streams: + await tasks.wait([stream.abort() for stream in streams]) + await self._server_impl.wait_closed() + self._server_impl = None + await self._shutdown_active_tasks(active_tasks) + + async def __aenter__(self): + await self.bind() + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + await self.close() + + def _attach(self, stream, task): + self._streams[stream] = task + + def _detach(self, stream, task): + del self._streams[stream] + + async def _shutdown_active_tasks(self, active_tasks): + if not active_tasks: + return + # NOTE: tasks finished with exception are reported + # by the Task.__del__() method. + done, pending = await tasks.wait(active_tasks, + timeout=self._shutdown_timeout) + if not pending: + return + for task in pending: + task.cancel() + done, pending = await tasks.wait(pending, + timeout=self._shutdown_timeout) + for task in pending: + self._loop.call_exception_handler({ + "message": (f'{task!r} ignored cancellation request ' + f'from a closing {self!r}'), + "stream_server": self + }) + + def __repr__(self): + ret = [f'{self.__class__.__name__}'] + if self.is_serving(): + ret.append('serving') + if self.sockets: + ret.append(f'sockets={self.sockets!r}') + return '<' + ' '.join(ret) + '>' + + def __del__(self, _warn=warnings.warn): + if self._server_impl is not None: + _warn(f"unclosed stream server {self!r}", + ResourceWarning, source=self) + self._server_impl.close() + + +class StreamServer(_BaseStreamServer): + + def __init__(self, client_connected_cb, /, host=None, port=None, *, + limit=_DEFAULT_LIMIT, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, sock=None, backlog=100, + ssl=None, reuse_address=None, reuse_port=None, + ssl_handshake_timeout=None, + shutdown_timeout=60): + super().__init__(client_connected_cb, + limit=limit, + shutdown_timeout=shutdown_timeout, + _asyncio_internal=True) + self._host = host + self._port = port + self._family = family + self._flags = flags + self._sock = sock + self._backlog = backlog + self._ssl = ssl + self._reuse_address = reuse_address + self._reuse_port = reuse_port + self._ssl_handshake_timeout = ssl_handshake_timeout + + async def _bind(self): + def factory(): + protocol = _ServerStreamProtocol(self, + self._limit, + self._client_connected_cb, + loop=self._loop, + _asyncio_internal=True) + return protocol + return await self._loop.create_server( + factory, + self._host, + self._port, + start_serving=False, + family=self._family, + flags=self._flags, + sock=self._sock, + backlog=self._backlog, + ssl=self._ssl, + reuse_address=self._reuse_address, + reuse_port=self._reuse_port, + ssl_handshake_timeout=self._ssl_handshake_timeout) + + if hasattr(socket, 'AF_UNIX'): # UNIX Domain Sockets are supported on this platform async def open_unix_connection(path=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Similar to `open_connection` but works with UNIX Domain Sockets.""" + warnings.warn("open_unix_connection() is deprecated since Python 3.8 " + "in favor of connect_unix(), and scheduled for removal " + "in Python 3.10", + DeprecationWarning, + stacklevel=2) if loop is None: loop = events.get_event_loop() - reader = StreamReader(limit=limit, loop=loop, - _asyncio_internal=True) + reader = StreamReader(limit=limit, loop=loop) protocol = StreamReaderProtocol(reader, loop=loop, _asyncio_internal=True) transport, _ = await loop.create_unix_connection( lambda: protocol, path, **kwds) - writer = StreamWriter(transport, protocol, reader, loop, - _asyncio_internal=True) + writer = StreamWriter(transport, protocol, reader, loop) return reader, writer + + def connect_unix(path=None, *, + limit=_DEFAULT_LIMIT, + ssl=None, sock=None, + server_hostname=None, + ssl_handshake_timeout=None): + """Similar to `connect()` but works with UNIX Domain Sockets.""" + # Design note: + # Don't use decorator approach but exilicit non-async + # function to fail fast and explicitly + # if passed arguments don't match the function signature + return _ContextManagerHelper(_connect_unix(path, + limit, + ssl, sock, + server_hostname, + ssl_handshake_timeout)) + + + async def _connect_unix(path, + limit, + ssl, sock, + server_hostname, + ssl_handshake_timeout): + """Similar to `connect()` but works with UNIX Domain Sockets.""" + loop = events.get_running_loop() + stream = Stream(mode=StreamMode.READWRITE, + limit=limit, + loop=loop, + _asyncio_internal=True) + await loop.create_unix_connection( + lambda: _StreamProtocol(stream, + loop=loop, + _asyncio_internal=True), + path, + ssl=ssl, + sock=sock, + server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout) + return stream + + async def start_unix_server(client_connected_cb, path=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Similar to `start_server` but works with UNIX Domain Sockets.""" + warnings.warn("start_unix_server() is deprecated since Python 3.8 " + "in favor of UnixStreamServer(), and scheduled " + "for removal in Python 3.10", + DeprecationWarning, + stacklevel=2) if loop is None: loop = events.get_event_loop() def factory(): - reader = StreamReader(limit=limit, loop=loop, - _asyncio_internal=True) + reader = StreamReader(limit=limit, loop=loop) protocol = StreamReaderProtocol(reader, client_connected_cb, loop=loop, _asyncio_internal=True) @@ -125,6 +483,42 @@ if hasattr(socket, 'AF_UNIX'): return await loop.create_unix_server(factory, path, **kwds) + class UnixStreamServer(_BaseStreamServer): + + def __init__(self, client_connected_cb, /, path=None, *, + limit=_DEFAULT_LIMIT, + sock=None, + backlog=100, + ssl=None, + ssl_handshake_timeout=None, + shutdown_timeout=60): + super().__init__(client_connected_cb, + limit=limit, + shutdown_timeout=shutdown_timeout, + _asyncio_internal=True) + self._path = path + self._sock = sock + self._backlog = backlog + self._ssl = ssl + self._ssl_handshake_timeout = ssl_handshake_timeout + + async def _bind(self): + def factory(): + protocol = _ServerStreamProtocol(self, + self._limit, + self._client_connected_cb, + loop=self._loop, + _asyncio_internal=True) + return protocol + return await self._loop.create_unix_server( + factory, + self._path, + start_serving=False, + sock=self._sock, + backlog=self._backlog, + ssl=self._ssl, + ssl_handshake_timeout=self._ssl_handshake_timeout) + class FlowControlMixin(protocols.Protocol): """Reusable flow control logic for StreamWriter.drain(). @@ -203,6 +597,8 @@ class FlowControlMixin(protocols.Protocol): raise NotImplementedError +# begin legacy stream APIs + class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): """Helper class to adapt between Protocol and StreamReader. @@ -212,105 +608,47 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): call inappropriate methods of the protocol.) """ - _source_traceback = None - def __init__(self, stream_reader, client_connected_cb=None, loop=None, *, _asyncio_internal=False): super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) - if stream_reader is not None: - self._stream_reader_wr = weakref.ref(stream_reader, - self._on_reader_gc) - self._source_traceback = stream_reader._source_traceback - else: - self._stream_reader_wr = None - if client_connected_cb is not None: - # This is a stream created by the `create_server()` function. - # Keep a strong reference to the reader until a connection - # is established. - self._strong_reader = stream_reader - self._reject_connection = False + self._stream_reader = stream_reader self._stream_writer = None - self._transport = None self._client_connected_cb = client_connected_cb self._over_ssl = False self._closed = self._loop.create_future() - def _on_reader_gc(self, wr): - transport = self._transport - if transport is not None: - # connection_made was called - context = { - 'message': ('An open stream object is being garbage ' - 'collected; call "stream.close()" explicitly.') - } - if self._source_traceback: - context['source_traceback'] = self._source_traceback - self._loop.call_exception_handler(context) - transport.abort() - else: - self._reject_connection = True - self._stream_reader_wr = None - - @property - def _stream_reader(self): - if self._stream_reader_wr is None: - return None - return self._stream_reader_wr() - def connection_made(self, transport): - if self._reject_connection: - context = { - 'message': ('An open stream was garbage collected prior to ' - 'establishing network connection; ' - 'call "stream.close()" explicitly.') - } - if self._source_traceback: - context['source_traceback'] = self._source_traceback - self._loop.call_exception_handler(context) - transport.abort() - return - self._transport = transport - reader = self._stream_reader - if reader is not None: - reader.set_transport(transport) + self._stream_reader.set_transport(transport) self._over_ssl = transport.get_extra_info('sslcontext') is not None if self._client_connected_cb is not None: self._stream_writer = StreamWriter(transport, self, - reader, - self._loop, - _asyncio_internal=True) - res = self._client_connected_cb(reader, + self._stream_reader, + self._loop) + res = self._client_connected_cb(self._stream_reader, self._stream_writer) if coroutines.iscoroutine(res): self._loop.create_task(res) - self._strong_reader = None def connection_lost(self, exc): - reader = self._stream_reader - if reader is not None: + if self._stream_reader is not None: if exc is None: - reader.feed_eof() + self._stream_reader.feed_eof() else: - reader.set_exception(exc) + self._stream_reader.set_exception(exc) if not self._closed.done(): if exc is None: self._closed.set_result(None) else: self._closed.set_exception(exc) super().connection_lost(exc) - self._stream_reader_wr = None + self._stream_reader = None self._stream_writer = None - self._transport = None def data_received(self, data): - reader = self._stream_reader - if reader is not None: - reader.feed_data(data) + self._stream_reader.feed_data(data) def eof_received(self): - reader = self._stream_reader - if reader is not None: - reader.feed_eof() + self._stream_reader.feed_eof() if self._over_ssl: # Prevent a warning in SSLProtocol.eof_received: # "returning true from eof_received() @@ -318,9 +656,6 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): return False return True - def _get_close_waiter(self, stream): - return self._closed - def __del__(self): # Prevent reports about unhandled exceptions. # Better than self._closed._log_traceback = False hack @@ -329,13 +664,6 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): closed.exception() -def _swallow_unhandled_exception(task): - # Do a trick to suppress unhandled exception - # if stream.write() was used without await and - # stream.drain() was paused and resumed with an exception - task.exception() - - class StreamWriter: """Wraps a Transport. @@ -346,21 +674,13 @@ class StreamWriter: directly. """ - def __init__(self, transport, protocol, reader, loop, - *, _asyncio_internal=False): - if not _asyncio_internal: - warnings.warn(f"{self.__class__} should be instaniated " - "by asyncio internals only, " - "please avoid its creation from user code", - DeprecationWarning) + def __init__(self, transport, protocol, reader, loop): self._transport = transport self._protocol = protocol # drain() expects that the reader has an exception() method assert reader is None or isinstance(reader, StreamReader) self._reader = reader self._loop = loop - self._complete_fut = self._loop.create_future() - self._complete_fut.set_result(None) def __repr__(self): info = [self.__class__.__name__, f'transport={self._transport!r}'] @@ -374,35 +694,9 @@ class StreamWriter: def write(self, data): self._transport.write(data) - return self._fast_drain() def writelines(self, data): self._transport.writelines(data) - return self._fast_drain() - - def _fast_drain(self): - # The helper tries to use fast-path to return already existing complete future - # object if underlying transport is not paused and actual waiting for writing - # resume is not needed - if self._reader is not None: - # this branch will be simplified after merging reader with writer - exc = self._reader.exception() - if exc is not None: - fut = self._loop.create_future() - fut.set_exception(exc) - return fut - if not self._transport.is_closing(): - if self._protocol._connection_lost: - fut = self._loop.create_future() - fut.set_exception(ConnectionResetError('Connection lost')) - return fut - if not self._protocol._paused: - # fast path, the stream is not paused - # no need to wait for resume signal - return self._complete_fut - ret = self._loop.create_task(self.drain()) - ret.add_done_callback(_swallow_unhandled_exception) - return ret def write_eof(self): return self._transport.write_eof() @@ -411,14 +705,13 @@ class StreamWriter: return self._transport.can_write_eof() def close(self): - self._transport.close() - return self._protocol._get_close_waiter(self) + return self._transport.close() def is_closing(self): return self._transport.is_closing() async def wait_closed(self): - await self._protocol._get_close_waiter(self) + await self._protocol._closed def get_extra_info(self, name, default=None): return self._transport.get_extra_info(name, default) @@ -436,24 +729,562 @@ class StreamWriter: if exc is not None: raise exc if self._transport.is_closing(): - # Wait for protocol.connection_lost() call - # Raise connection closing error if any, - # ConnectionResetError otherwise - await sleep(0) + # Yield to the event loop so connection_lost() may be + # called. Without this, _drain_helper() would return + # immediately, and code that calls + # write(...); await drain() + # in a loop would never call connection_lost(), so it + # would not see an error when the socket is closed. + await tasks.sleep(0, loop=self._loop) await self._protocol._drain_helper() class StreamReader: + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): + # The line length limit is a security feature; + # it also doubles as half the buffer limit. + + if limit <= 0: + raise ValueError('Limit cannot be <= 0') + + self._limit = limit + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._buffer = bytearray() + self._eof = False # Whether we're done. + self._waiter = None # A future used by _wait_for_data() + self._exception = None + self._transport = None + self._paused = False + + def __repr__(self): + info = ['StreamReader'] + if self._buffer: + info.append(f'{len(self._buffer)} bytes') + if self._eof: + info.append('eof') + if self._limit != _DEFAULT_LIMIT: + info.append(f'limit={self._limit}') + if self._waiter: + info.append(f'waiter={self._waiter!r}') + if self._exception: + info.append(f'exception={self._exception!r}') + if self._transport: + info.append(f'transport={self._transport!r}') + if self._paused: + info.append('paused') + return '<{}>'.format(' '.join(info)) + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_exception(exc) + + def _wakeup_waiter(self): + """Wakeup read*() functions waiting for data or EOF.""" + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_result(None) + + def set_transport(self, transport): + assert self._transport is None, 'Transport already set' + self._transport = transport + + def _maybe_resume_transport(self): + if self._paused and len(self._buffer) <= self._limit: + self._paused = False + self._transport.resume_reading() + + def feed_eof(self): + self._eof = True + self._wakeup_waiter() + + def at_eof(self): + """Return True if the buffer is empty and 'feed_eof' was called.""" + return self._eof and not self._buffer + + def feed_data(self, data): + assert not self._eof, 'feed_data after feed_eof' + + if not data: + return + + self._buffer.extend(data) + self._wakeup_waiter() + + if (self._transport is not None and + not self._paused and + len(self._buffer) > 2 * self._limit): + try: + self._transport.pause_reading() + except NotImplementedError: + # The transport can't be paused. + # We'll just have to buffer all data. + # Forget the transport so we don't keep trying. + self._transport = None + else: + self._paused = True + + async def _wait_for_data(self, func_name): + """Wait until feed_data() or feed_eof() is called. + + If stream was paused, automatically resume it. + """ + # StreamReader uses a future to link the protocol feed_data() method + # to a read coroutine. Running two read coroutines at the same time + # would have an unexpected behaviour. It would not possible to know + # which coroutine would get the next data. + if self._waiter is not None: + raise RuntimeError( + f'{func_name}() called while another coroutine is ' + f'already waiting for incoming data') + + assert not self._eof, '_wait_for_data after EOF' + + # Waiting for data while paused will make deadlock, so prevent it. + # This is essential for readexactly(n) for case when n > self._limit. + if self._paused: + self._paused = False + self._transport.resume_reading() + + self._waiter = self._loop.create_future() + try: + await self._waiter + finally: + self._waiter = None + + async def readline(self): + """Read chunk of data from the stream until newline (b'\n') is found. + + On success, return chunk that ends with newline. If only partial + line can be read due to EOF, return incomplete line without + terminating newline. When EOF was reached while no bytes read, empty + bytes object is returned. + + If limit is reached, ValueError will be raised. In that case, if + newline was found, complete line including newline will be removed + from internal buffer. Else, internal buffer will be cleared. Limit is + compared against part of the line without newline. + + If stream was paused, this function will automatically resume it if + needed. + """ + sep = b'\n' + seplen = len(sep) + try: + line = await self.readuntil(sep) + except exceptions.IncompleteReadError as e: + return e.partial + except exceptions.LimitOverrunError as e: + if self._buffer.startswith(sep, e.consumed): + del self._buffer[:e.consumed + seplen] + else: + self._buffer.clear() + self._maybe_resume_transport() + raise ValueError(e.args[0]) + return line + + async def readuntil(self, separator=b'\n'): + """Read data from the stream until ``separator`` is found. + + On success, the data and separator will be removed from the + internal buffer (consumed). Returned data will include the + separator at the end. + + Configured stream limit is used to check result. Limit sets the + maximal length of data that can be returned, not counting the + separator. + + If an EOF occurs and the complete separator is still not found, + an IncompleteReadError exception will be raised, and the internal + buffer will be reset. The IncompleteReadError.partial attribute + may contain the separator partially. + + If the data cannot be read because of over limit, a + LimitOverrunError exception will be raised, and the data + will be left in the internal buffer, so it can be read again. + """ + seplen = len(separator) + if seplen == 0: + raise ValueError('Separator should be at least one-byte string') + + if self._exception is not None: + raise self._exception + + # Consume whole buffer except last bytes, which length is + # one less than seplen. Let's check corner cases with + # separator='SEPARATOR': + # * we have received almost complete separator (without last + # byte). i.e buffer='some textSEPARATO'. In this case we + # can safely consume len(separator) - 1 bytes. + # * last byte of buffer is first byte of separator, i.e. + # buffer='abcdefghijklmnopqrS'. We may safely consume + # everything except that last byte, but this require to + # analyze bytes of buffer that match partial separator. + # This is slow and/or require FSM. For this case our + # implementation is not optimal, since require rescanning + # of data that is known to not belong to separator. In + # real world, separator will not be so long to notice + # performance problems. Even when reading MIME-encoded + # messages :) + + # `offset` is the number of bytes from the beginning of the buffer + # where there is no occurrence of `separator`. + offset = 0 + + # Loop until we find `separator` in the buffer, exceed the buffer size, + # or an EOF has happened. + while True: + buflen = len(self._buffer) + + # Check if we now have enough data in the buffer for `separator` to + # fit. + if buflen - offset >= seplen: + isep = self._buffer.find(separator, offset) + + if isep != -1: + # `separator` is in the buffer. `isep` will be used later + # to retrieve the data. + break + + # see upper comment for explanation. + offset = buflen + 1 - seplen + if offset > self._limit: + raise exceptions.LimitOverrunError( + 'Separator is not found, and chunk exceed the limit', + offset) + + # Complete message (with full separator) may be present in buffer + # even when EOF flag is set. This may happen when the last chunk + # adds data which makes separator be found. That's why we check for + # EOF *ater* inspecting the buffer. + if self._eof: + chunk = bytes(self._buffer) + self._buffer.clear() + raise exceptions.IncompleteReadError(chunk, None) + + # _wait_for_data() will resume reading if stream was paused. + await self._wait_for_data('readuntil') + + if isep > self._limit: + raise exceptions.LimitOverrunError( + 'Separator is found, but chunk is longer than limit', isep) + + chunk = self._buffer[:isep + seplen] + del self._buffer[:isep + seplen] + self._maybe_resume_transport() + return bytes(chunk) + + async def read(self, n=-1): + """Read up to `n` bytes from the stream. + + If n is not provided, or set to -1, read until EOF and return all read + bytes. If the EOF was received and the internal buffer is empty, return + an empty bytes object. + + If n is zero, return empty bytes object immediately. + + If n is positive, this function try to read `n` bytes, and may return + less or equal bytes than requested, but at least one byte. If EOF was + received before any byte is read, this function returns empty byte + object. + + Returned value is not limited with limit, configured at stream + creation. + + If stream was paused, this function will automatically resume it if + needed. + """ + + if self._exception is not None: + raise self._exception + + if n == 0: + return b'' + + if n < 0: + # This used to just loop creating a new waiter hoping to + # collect everything in self._buffer, but that would + # deadlock if the subprocess sends more than self.limit + # bytes. So just call self.read(self._limit) until EOF. + blocks = [] + while True: + block = await self.read(self._limit) + if not block: + break + blocks.append(block) + return b''.join(blocks) + + if not self._buffer and not self._eof: + await self._wait_for_data('read') + + # This will work right even if buffer is less than n bytes + data = bytes(self._buffer[:n]) + del self._buffer[:n] + + self._maybe_resume_transport() + return data + + async def readexactly(self, n): + """Read exactly `n` bytes. + + Raise an IncompleteReadError if EOF is reached before `n` bytes can be + read. The IncompleteReadError.partial attribute of the exception will + contain the partial read bytes. + + if n is zero, return empty bytes object. + + Returned value is not limited with limit, configured at stream + creation. + + If stream was paused, this function will automatically resume it if + needed. + """ + if n < 0: + raise ValueError('readexactly size can not be less than zero') + + if self._exception is not None: + raise self._exception + + if n == 0: + return b'' + + while len(self._buffer) < n: + if self._eof: + incomplete = bytes(self._buffer) + self._buffer.clear() + raise exceptions.IncompleteReadError(incomplete, n) + + await self._wait_for_data('readexactly') + + if len(self._buffer) == n: + data = bytes(self._buffer) + self._buffer.clear() + else: + data = bytes(self._buffer[:n]) + del self._buffer[:n] + self._maybe_resume_transport() + return data + + def __aiter__(self): + return self + + async def __anext__(self): + val = await self.readline() + if val == b'': + raise StopAsyncIteration + return val + + +# end legacy stream APIs + + +class _BaseStreamProtocol(FlowControlMixin, protocols.Protocol): + """Helper class to adapt between Protocol and StreamReader. + + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + """ + + _stream = None # initialized in derived classes + + def __init__(self, loop=None, + *, _asyncio_internal=False): + super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) + self._transport = None + self._over_ssl = False + self._closed = self._loop.create_future() + + def connection_made(self, transport): + self._transport = transport + self._over_ssl = transport.get_extra_info('sslcontext') is not None + + def connection_lost(self, exc): + stream = self._stream + if stream is not None: + if exc is None: + stream.feed_eof() + else: + stream.set_exception(exc) + if not self._closed.done(): + if exc is None: + self._closed.set_result(None) + else: + self._closed.set_exception(exc) + super().connection_lost(exc) + self._transport = None + + def data_received(self, data): + stream = self._stream + if stream is not None: + stream.feed_data(data) + + def eof_received(self): + stream = self._stream + if stream is not None: + stream.feed_eof() + if self._over_ssl: + # Prevent a warning in SSLProtocol.eof_received: + # "returning true from eof_received() + # has no effect when using ssl" + return False + return True + + def _get_close_waiter(self, stream): + return self._closed + + def __del__(self): + # Prevent reports about unhandled exceptions. + # Better than self._closed._log_traceback = False hack + closed = self._get_close_waiter(self._stream) + if closed.done() and not closed.cancelled(): + closed.exception() + + +class _StreamProtocol(_BaseStreamProtocol): _source_traceback = None - def __init__(self, limit=_DEFAULT_LIMIT, loop=None, + def __init__(self, stream, loop=None, *, _asyncio_internal=False): + super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) + self._source_traceback = stream._source_traceback + self._stream_wr = weakref.ref(stream, self._on_gc) + self._reject_connection = False + + def _on_gc(self, wr): + transport = self._transport + if transport is not None: + # connection_made was called + context = { + 'message': ('An open stream object is being garbage ' + 'collected; call "stream.close()" explicitly.') + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + transport.abort() + else: + self._reject_connection = True + self._stream_wr = None + + @property + def _stream(self): + if self._stream_wr is None: + return None + return self._stream_wr() + + def connection_made(self, transport): + if self._reject_connection: + context = { + 'message': ('An open stream was garbage collected prior to ' + 'establishing network connection; ' + 'call "stream.close()" explicitly.') + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + transport.abort() + return + super().connection_made(transport) + stream = self._stream + if stream is None: + return + stream.set_transport(transport) + stream._protocol = self + + def connection_lost(self, exc): + super().connection_lost(exc) + self._stream_wr = None + + +class _ServerStreamProtocol(_BaseStreamProtocol): + def __init__(self, server, limit, client_connected_cb, loop=None, + *, _asyncio_internal=False): + super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) + assert self._closed + self._client_connected_cb = client_connected_cb + self._limit = limit + self._server = server + self._task = None + + def connection_made(self, transport): + super().connection_made(transport) + stream = Stream(mode=StreamMode.READWRITE, + transport=transport, + protocol=self, + limit=self._limit, + loop=self._loop, + is_server_side=True, + _asyncio_internal=True) + self._stream = stream + # If self._client_connected_cb(self._stream) fails + # the exception is logged by transport + self._task = self._loop.create_task( + self._client_connected_cb(self._stream)) + self._server._attach(stream, self._task) + + def connection_lost(self, exc): + super().connection_lost(exc) + self._server._detach(self._stream, self._task) + self._stream = None + + +class _OptionalAwait: + # The class doesn't create a coroutine + # if not awaited + # It prevents "coroutine is never awaited" message + + __slots___ = ('_method',) + + def __init__(self, method): + self._method = method + + def __await__(self): + return self._method().__await__() + + +class Stream: + """Wraps a Transport. + + This exposes write(), writelines(), [can_]write_eof(), + get_extra_info() and close(). It adds drain() which returns an + optional Future on which you can wait for flow control. It also + adds a transport property which references the Transport + directly. + """ + + _source_traceback = None + + def __init__(self, mode, *, + transport=None, + protocol=None, + loop=None, + limit=_DEFAULT_LIMIT, + is_server_side=False, + _asyncio_internal=False): if not _asyncio_internal: warnings.warn(f"{self.__class__} should be instaniated " "by asyncio internals only, " "please avoid its creation from user code", DeprecationWarning) + self._mode = mode + self._transport = transport + self._protocol = protocol + self._is_server_side = is_server_side # The line length limit is a security feature; # it also doubles as half the buffer limit. @@ -470,14 +1301,17 @@ class StreamReader: self._eof = False # Whether we're done. self._waiter = None # A future used by _wait_for_data() self._exception = None - self._transport = None self._paused = False + self._complete_fut = self._loop.create_future() + self._complete_fut.set_result(None) + if self._loop.get_debug(): self._source_traceback = format_helpers.extract_stack( sys._getframe(1)) def __repr__(self): - info = ['StreamReader'] + info = [self.__class__.__name__] + info.append(f'mode={self._mode}') if self._buffer: info.append(f'{len(self._buffer)} bytes') if self._eof: @@ -494,6 +1328,110 @@ class StreamReader: info.append('paused') return '<{}>'.format(' '.join(info)) + @property + def mode(self): + return self._mode + + def is_server_side(self): + return self._is_server_side + + @property + def transport(self): + return self._transport + + def write(self, data): + _ensure_can_write(self._mode) + self._transport.write(data) + return self._fast_drain() + + def writelines(self, data): + _ensure_can_write(self._mode) + self._transport.writelines(data) + return self._fast_drain() + + def _fast_drain(self): + # The helper tries to use fast-path to return already existing + # complete future object if underlying transport is not paused + #and actual waiting for writing resume is not needed + exc = self.exception() + if exc is not None: + fut = self._loop.create_future() + fut.set_exception(exc) + return fut + if not self._transport.is_closing(): + if self._protocol._connection_lost: + fut = self._loop.create_future() + fut.set_exception(ConnectionResetError('Connection lost')) + return fut + if not self._protocol._paused: + # fast path, the stream is not paused + # no need to wait for resume signal + return self._complete_fut + return _OptionalAwait(self.drain) + + def write_eof(self): + _ensure_can_write(self._mode) + return self._transport.write_eof() + + def can_write_eof(self): + if not self._mode.is_write(): + return False + return self._transport.can_write_eof() + + def close(self): + self._transport.close() + return _OptionalAwait(self.wait_closed) + + def is_closing(self): + return self._transport.is_closing() + + async def abort(self): + self._transport.abort() + await self.wait_closed() + + async def wait_closed(self): + await self._protocol._get_close_waiter(self) + + def get_extra_info(self, name, default=None): + return self._transport.get_extra_info(name, default) + + async def drain(self): + """Flush the write buffer. + + The intended use is to write + + w.write(data) + await w.drain() + """ + _ensure_can_write(self._mode) + exc = self.exception() + if exc is not None: + raise exc + if self._transport.is_closing(): + # Wait for protocol.connection_lost() call + # Raise connection closing error if any, + # ConnectionResetError otherwise + await tasks.sleep(0) + await self._protocol._drain_helper() + + async def sendfile(self, file, offset=0, count=None, *, fallback=True): + await self.drain() # check for stream mode and exceptions + return await self._loop.sendfile(self._transport, file, + offset, count, fallback=fallback) + + async def start_tls(self, sslcontext, *, + server_hostname=None, + ssl_handshake_timeout=None): + await self.drain() # check for stream mode and exceptions + transport = await self._loop.start_tls( + self._transport, self._protocol, sslcontext, + server_side=self._is_server_side, + server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout) + self._transport = transport + self._protocol._transport = transport + self._protocol._over_ssl = True + def exception(self): return self._exception @@ -515,6 +1453,8 @@ class StreamReader: waiter.set_result(None) def set_transport(self, transport): + if transport is self._transport: + return assert self._transport is None, 'Transport already set' self._transport = transport @@ -532,6 +1472,7 @@ class StreamReader: return self._eof and not self._buffer def feed_data(self, data): + _ensure_can_read(self._mode) assert not self._eof, 'feed_data after feed_eof' if not data: @@ -597,6 +1538,7 @@ class StreamReader: If stream was paused, this function will automatically resume it if needed. """ + _ensure_can_read(self._mode) sep = b'\n' seplen = len(sep) try: @@ -632,6 +1574,7 @@ class StreamReader: LimitOverrunError exception will be raised, and the data will be left in the internal buffer, so it can be read again. """ + _ensure_can_read(self._mode) seplen = len(separator) if seplen == 0: raise ValueError('Separator should be at least one-byte string') @@ -723,6 +1666,7 @@ class StreamReader: If stream was paused, this function will automatically resume it if needed. """ + _ensure_can_read(self._mode) if self._exception is not None: raise self._exception @@ -768,6 +1712,7 @@ class StreamReader: If stream was paused, this function will automatically resume it if needed. """ + _ensure_can_read(self._mode) if n < 0: raise ValueError('readexactly size can not be less than zero') @@ -795,6 +1740,7 @@ class StreamReader: return data def __aiter__(self): + _ensure_can_read(self._mode) return self async def __anext__(self): @@ -802,3 +1748,9 @@ class StreamReader: if val == b'': raise StopAsyncIteration return val + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() diff --git a/Lib/asyncio/subprocess.py b/Lib/asyncio/subprocess.py index d34b611..e6bec71 100644 --- a/Lib/asyncio/subprocess.py +++ b/Lib/asyncio/subprocess.py @@ -27,6 +27,8 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, self._process_exited = False self._pipe_fds = [] self._stdin_closed = self._loop.create_future() + self._stdout_closed = self._loop.create_future() + self._stderr_closed = self._loop.create_future() def __repr__(self): info = [self.__class__.__name__] @@ -40,30 +42,35 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, def connection_made(self, transport): self._transport = transport - stdout_transport = transport.get_pipe_transport(1) if stdout_transport is not None: - self.stdout = streams.StreamReader(limit=self._limit, - loop=self._loop, - _asyncio_internal=True) + self.stdout = streams.Stream(mode=streams.StreamMode.READ, + transport=stdout_transport, + protocol=self, + limit=self._limit, + loop=self._loop, + _asyncio_internal=True) self.stdout.set_transport(stdout_transport) self._pipe_fds.append(1) stderr_transport = transport.get_pipe_transport(2) if stderr_transport is not None: - self.stderr = streams.StreamReader(limit=self._limit, - loop=self._loop, - _asyncio_internal=True) + self.stderr = streams.Stream(mode=streams.StreamMode.READ, + transport=stderr_transport, + protocol=self, + limit=self._limit, + loop=self._loop, + _asyncio_internal=True) self.stderr.set_transport(stderr_transport) self._pipe_fds.append(2) stdin_transport = transport.get_pipe_transport(0) if stdin_transport is not None: - self.stdin = streams.StreamWriter(stdin_transport, - protocol=self, - reader=None, - loop=self._loop, - _asyncio_internal=True) + self.stdin = streams.Stream(mode=streams.StreamMode.WRITE, + transport=stdin_transport, + protocol=self, + loop=self._loop, + _asyncio_internal=True) def pipe_data_received(self, fd, data): if fd == 1: @@ -114,6 +121,10 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, def _get_close_waiter(self, stream): if stream is self.stdin: return self._stdin_closed + elif stream is self.stdout: + return self._stdout_closed + elif stream is self.stderr: + return self._stderr_closed class Process: diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py index b5b2e24..61b40ba 100644 --- a/Lib/asyncio/windows_events.py +++ b/Lib/asyncio/windows_events.py @@ -607,7 +607,7 @@ class IocpProactor: # ConnectPipe() failed with ERROR_PIPE_BUSY: retry later delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY) - await tasks.sleep(delay, loop=self._loop) + await tasks.sleep(delay) return windows_utils.PipeHandle(handle) diff --git a/Lib/test/test___all__.py b/Lib/test/test___all__.py index f6e82eb..c077881 100644 --- a/Lib/test/test___all__.py +++ b/Lib/test/test___all__.py @@ -30,21 +30,27 @@ class AllTest(unittest.TestCase): raise NoAll(modname) names = {} with self.subTest(module=modname): - try: - exec("from %s import *" % modname, names) - except Exception as e: - # Include the module name in the exception string - self.fail("__all__ failure in {}: {}: {}".format( - modname, e.__class__.__name__, e)) - if "__builtins__" in names: - del names["__builtins__"] - if '__annotations__' in names: - del names['__annotations__'] - keys = set(names) - all_list = sys.modules[modname].__all__ - all_set = set(all_list) - self.assertCountEqual(all_set, all_list, "in module {}".format(modname)) - self.assertEqual(keys, all_set, "in module {}".format(modname)) + with support.check_warnings( + ("", DeprecationWarning), + ("", ResourceWarning), + quiet=True): + try: + exec("from %s import *" % modname, names) + except Exception as e: + # Include the module name in the exception string + self.fail("__all__ failure in {}: {}: {}".format( + modname, e.__class__.__name__, e)) + if "__builtins__" in names: + del names["__builtins__"] + if '__annotations__' in names: + del names['__annotations__'] + if "__warningregistry__" in names: + del names["__warningregistry__"] + keys = set(names) + all_list = sys.modules[modname].__all__ + all_set = set(all_list) + self.assertCountEqual(all_set, all_list, "in module {}".format(modname)) + self.assertEqual(keys, all_set, "in module {}".format(modname)) def walk_modules(self, basedir, modpath): for fn in sorted(os.listdir(basedir)): diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index 31018c5..02a97c6 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -1152,8 +1152,9 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): @unittest.skipUnless(hasattr(socket, 'AF_INET6'), 'no IPv6 support') def test_create_server_ipv6(self): async def main(): - srv = await asyncio.start_server( - lambda: None, '::1', 0, loop=self.loop) + with self.assertWarns(DeprecationWarning): + srv = await asyncio.start_server( + lambda: None, '::1', 0, loop=self.loop) try: self.assertGreater(len(srv.sockets), 0) finally: diff --git a/Lib/test/test_asyncio/test_buffered_proto.py b/Lib/test/test_asyncio/test_buffered_proto.py index f24e363..b1531fb 100644 --- a/Lib/test/test_asyncio/test_buffered_proto.py +++ b/Lib/test/test_asyncio/test_buffered_proto.py @@ -58,9 +58,10 @@ class BaseTestBufferedProtocol(func_tests.FunctionalTestCaseMixin): writer.close() await writer.wait_closed() - srv = self.loop.run_until_complete( - asyncio.start_server( - on_server_client, '127.0.0.1', 0)) + with self.assertWarns(DeprecationWarning): + srv = self.loop.run_until_complete( + asyncio.start_server( + on_server_client, '127.0.0.1', 0)) addr = srv.sockets[0].getsockname() self.loop.run_until_complete( diff --git a/Lib/test/test_asyncio/test_pep492.py b/Lib/test/test_asyncio/test_pep492.py index 297a3b3..11c0ce4 100644 --- a/Lib/test/test_asyncio/test_pep492.py +++ b/Lib/test/test_asyncio/test_pep492.py @@ -94,7 +94,9 @@ class StreamReaderTests(BaseTest): def test_readline(self): DATA = b'line1\nline2\nline3' - stream = asyncio.StreamReader(loop=self.loop, _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(DATA) stream.feed_eof() diff --git a/Lib/test/test_asyncio/test_server.py b/Lib/test/test_asyncio/test_server.py index 4e758ad..0e38e6c 100644 --- a/Lib/test/test_asyncio/test_server.py +++ b/Lib/test/test_asyncio/test_server.py @@ -46,8 +46,9 @@ class BaseStartServer(func_tests.FunctionalTestCaseMixin): async with srv: await srv.serve_forever() - srv = self.loop.run_until_complete(asyncio.start_server( - serve, support.HOSTv4, 0, loop=self.loop, start_serving=False)) + with self.assertWarns(DeprecationWarning): + srv = self.loop.run_until_complete(asyncio.start_server( + serve, support.HOSTv4, 0, loop=self.loop, start_serving=False)) self.assertFalse(srv.is_serving()) @@ -102,8 +103,9 @@ class SelectorStartServerTests(BaseStartServer, unittest.TestCase): await srv.serve_forever() with test_utils.unix_socket_path() as addr: - srv = self.loop.run_until_complete(asyncio.start_unix_server( - serve, addr, loop=self.loop, start_serving=False)) + with self.assertWarns(DeprecationWarning): + srv = self.loop.run_until_complete(asyncio.start_unix_server( + serve, addr, loop=self.loop, start_serving=False)) main_task = self.loop.create_task(main(srv)) diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index 079b255..4215abf 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -649,12 +649,13 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin): sock.close() async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='', - loop=self.loop, - ssl_handshake_timeout=1.0) + with self.assertWarns(DeprecationWarning): + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + loop=self.loop, + ssl_handshake_timeout=1.0) with self.tcp_server(server, max_clients=1, @@ -688,12 +689,13 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin): sock.close() async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='', - loop=self.loop, - ssl_handshake_timeout=1.0) + with self.assertWarns(DeprecationWarning): + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + loop=self.loop, + ssl_handshake_timeout=1.0) with self.tcp_server(server, max_clients=1, @@ -724,11 +726,12 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin): sock.close() async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='', - loop=self.loop) + with self.assertWarns(DeprecationWarning): + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + loop=self.loop) self.assertEqual(await reader.readline(), b'A\n') writer.write(b'B') diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index fed6098..df3d7e7 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1,6 +1,8 @@ """Tests for streams.py.""" +import contextlib import gc +import io import os import queue import pickle @@ -16,6 +18,7 @@ except ImportError: ssl = None import asyncio +from asyncio.streams import _StreamProtocol, _ensure_can_read, _ensure_can_write from test.test_asyncio import utils as test_utils @@ -23,6 +26,24 @@ def tearDownModule(): asyncio.set_event_loop_policy(None) +class StreamModeTests(unittest.TestCase): + def test__ensure_can_read_ok(self): + self.assertIsNone(_ensure_can_read(asyncio.StreamMode.READ)) + self.assertIsNone(_ensure_can_read(asyncio.StreamMode.READWRITE)) + + def test__ensure_can_read_fail(self): + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + _ensure_can_read(asyncio.StreamMode.WRITE) + + def test__ensure_can_write_ok(self): + self.assertIsNone(_ensure_can_write(asyncio.StreamMode.WRITE)) + self.assertIsNone(_ensure_can_write(asyncio.StreamMode.READWRITE)) + + def test__ensure_can_write_fail(self): + with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): + _ensure_can_write(asyncio.StreamMode.READ) + + class StreamTests(test_utils.TestCase): DATA = b'line1\nline2\nline3\n' @@ -42,13 +63,15 @@ class StreamTests(test_utils.TestCase): @mock.patch('asyncio.streams.events') def test_ctor_global_loop(self, m_events): - stream = asyncio.StreamReader(_asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + _asyncio_internal=True) self.assertIs(stream._loop, m_events.get_event_loop.return_value) def _basetest_open_connection(self, open_connection_fut): messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) - reader, writer = self.loop.run_until_complete(open_connection_fut) + with self.assertWarns(DeprecationWarning): + reader, writer = self.loop.run_until_complete(open_connection_fut) writer.write(b'GET / HTTP/1.0\r\n\r\n') f = reader.readline() data = self.loop.run_until_complete(f) @@ -76,7 +99,9 @@ class StreamTests(test_utils.TestCase): messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) try: - reader, writer = self.loop.run_until_complete(open_connection_fut) + with self.assertWarns(DeprecationWarning): + reader, writer = self.loop.run_until_complete( + open_connection_fut) finally: asyncio.set_event_loop(None) writer.write(b'GET / HTTP/1.0\r\n\r\n') @@ -112,7 +137,8 @@ class StreamTests(test_utils.TestCase): def _basetest_open_connection_error(self, open_connection_fut): messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) - reader, writer = self.loop.run_until_complete(open_connection_fut) + with self.assertWarns(DeprecationWarning): + reader, writer = self.loop.run_until_complete(open_connection_fut) writer._protocol.connection_lost(ZeroDivisionError()) f = reader.read() with self.assertRaises(ZeroDivisionError): @@ -135,23 +161,26 @@ class StreamTests(test_utils.TestCase): self._basetest_open_connection_error(conn_fut) def test_feed_empty_data(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'') self.assertEqual(b'', stream._buffer) def test_feed_nonempty_data(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(self.DATA) self.assertEqual(self.DATA, stream._buffer) def test_read_zero(self): # Read zero bytes. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(self.DATA) data = self.loop.run_until_complete(stream.read(0)) @@ -160,8 +189,9 @@ class StreamTests(test_utils.TestCase): def test_read(self): # Read bytes. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) read_task = asyncio.Task(stream.read(30), loop=self.loop) def cb(): @@ -174,8 +204,9 @@ class StreamTests(test_utils.TestCase): def test_read_line_breaks(self): # Read bytes without line breaks. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'line1') stream.feed_data(b'line2') @@ -186,8 +217,9 @@ class StreamTests(test_utils.TestCase): def test_read_eof(self): # Read bytes, stop at eof. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) read_task = asyncio.Task(stream.read(1024), loop=self.loop) def cb(): @@ -200,8 +232,9 @@ class StreamTests(test_utils.TestCase): def test_read_until_eof(self): # Read all bytes until eof. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) read_task = asyncio.Task(stream.read(-1), loop=self.loop) def cb(): @@ -216,8 +249,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'', stream._buffer) def test_read_exception(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'line\n') data = self.loop.run_until_complete(stream.read(2)) @@ -229,16 +263,19 @@ class StreamTests(test_utils.TestCase): def test_invalid_limit(self): with self.assertRaisesRegex(ValueError, 'imit'): - asyncio.StreamReader(limit=0, loop=self.loop, - _asyncio_internal=True) + asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=0, loop=self.loop, + _asyncio_internal=True) with self.assertRaisesRegex(ValueError, 'imit'): - asyncio.StreamReader(limit=-1, loop=self.loop, - _asyncio_internal=True) + asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=-1, loop=self.loop, + _asyncio_internal=True) def test_read_limit(self): - stream = asyncio.StreamReader(limit=3, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=3, loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'chunk') data = self.loop.run_until_complete(stream.read(5)) self.assertEqual(b'chunk', data) @@ -247,8 +284,9 @@ class StreamTests(test_utils.TestCase): def test_readline(self): # Read one line. 'readline' will need to wait for the data # to come from 'cb' - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'chunk1 ') read_task = asyncio.Task(stream.readline(), loop=self.loop) @@ -263,11 +301,12 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b' chunk4', stream._buffer) def test_readline_limit_with_existing_data(self): - # Read one line. The data is in StreamReader's buffer + # Read one line. The data is in Stream's buffer # before the event loop is run. - stream = asyncio.StreamReader(limit=3, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=3, loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'li') stream.feed_data(b'ne1\nline2\n') @@ -276,8 +315,9 @@ class StreamTests(test_utils.TestCase): # The buffer should contain the remaining data after exception self.assertEqual(b'line2\n', stream._buffer) - stream = asyncio.StreamReader(limit=3, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=3, loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'li') stream.feed_data(b'ne1') stream.feed_data(b'li') @@ -292,8 +332,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'', stream._buffer) def test_at_eof(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) self.assertFalse(stream.at_eof()) stream.feed_data(b'some data\n') @@ -308,11 +349,12 @@ class StreamTests(test_utils.TestCase): self.assertTrue(stream.at_eof()) def test_readline_limit(self): - # Read one line. StreamReaders are fed with data after + # Read one line. Streams are fed with data after # their 'readline' methods are called. - stream = asyncio.StreamReader(limit=7, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=7, loop=self.loop, + _asyncio_internal=True) def cb(): stream.feed_data(b'chunk1') stream.feed_data(b'chunk2') @@ -326,8 +368,9 @@ class StreamTests(test_utils.TestCase): # a ValueError it should be empty. self.assertEqual(b'', stream._buffer) - stream = asyncio.StreamReader(limit=7, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=7, loop=self.loop, + _asyncio_internal=True) def cb(): stream.feed_data(b'chunk1') stream.feed_data(b'chunk2\n') @@ -340,8 +383,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'chunk3\n', stream._buffer) # check strictness of the limit - stream = asyncio.StreamReader(limit=7, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=7, loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'1234567\n') line = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'1234567\n', line) @@ -360,8 +404,9 @@ class StreamTests(test_utils.TestCase): def test_readline_nolimit_nowait(self): # All needed data for the first 'readline' call will be # in the buffer. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(self.DATA[:6]) stream.feed_data(self.DATA[6:]) @@ -371,8 +416,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'line2\nline3\n', stream._buffer) def test_readline_eof(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'some data') stream.feed_eof() @@ -380,16 +426,18 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'some data', line) def test_readline_empty_eof(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_eof() line = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'', line) def test_readline_read_byte_count(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(self.DATA) self.loop.run_until_complete(stream.readline()) @@ -400,8 +448,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'ine3\n', stream._buffer) def test_readline_exception(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'line\n') data = self.loop.run_until_complete(stream.readline()) @@ -413,14 +462,16 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'', stream._buffer) def test_readuntil_separator(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) with self.assertRaisesRegex(ValueError, 'Separator should be'): self.loop.run_until_complete(stream.readuntil(separator=b'')) def test_readuntil_multi_chunks(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'lineAAA') data = self.loop.run_until_complete(stream.readuntil(separator=b'AAA')) @@ -438,8 +489,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'xxx', stream._buffer) def test_readuntil_multi_chunks_1(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'QWEaa') stream.feed_data(b'XYaa') @@ -474,8 +526,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'', stream._buffer) def test_readuntil_eof(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'some dataAA') stream.feed_eof() @@ -486,8 +539,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'', stream._buffer) def test_readuntil_limit_found_sep(self): - stream = asyncio.StreamReader(loop=self.loop, limit=3, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, limit=3, + _asyncio_internal=True) stream.feed_data(b'some dataAA') with self.assertRaisesRegex(asyncio.LimitOverrunError, @@ -505,8 +559,9 @@ class StreamTests(test_utils.TestCase): def test_readexactly_zero_or_less(self): # Read exact number of bytes (zero or less). - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(self.DATA) data = self.loop.run_until_complete(stream.readexactly(0)) @@ -519,8 +574,9 @@ class StreamTests(test_utils.TestCase): def test_readexactly(self): # Read exact number of bytes. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) n = 2 * len(self.DATA) read_task = asyncio.Task(stream.readexactly(n), loop=self.loop) @@ -536,8 +592,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(self.DATA, stream._buffer) def test_readexactly_limit(self): - stream = asyncio.StreamReader(limit=3, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=3, loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'chunk') data = self.loop.run_until_complete(stream.readexactly(5)) self.assertEqual(b'chunk', data) @@ -545,8 +602,9 @@ class StreamTests(test_utils.TestCase): def test_readexactly_eof(self): # Read exact number of bytes (eof). - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) n = 2 * len(self.DATA) read_task = asyncio.Task(stream.readexactly(n), loop=self.loop) @@ -564,8 +622,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'', stream._buffer) def test_readexactly_exception(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'line\n') data = self.loop.run_until_complete(stream.readexactly(2)) @@ -576,8 +635,9 @@ class StreamTests(test_utils.TestCase): ValueError, self.loop.run_until_complete, stream.readexactly(2)) def test_exception(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) self.assertIsNone(stream.exception()) exc = ValueError() @@ -585,8 +645,9 @@ class StreamTests(test_utils.TestCase): self.assertIs(stream.exception(), exc) def test_exception_waiter(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) async def set_err(): stream.set_exception(ValueError()) @@ -599,8 +660,9 @@ class StreamTests(test_utils.TestCase): self.assertRaises(ValueError, t1.result) def test_exception_cancel(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) t = asyncio.Task(stream.readline(), loop=self.loop) test_utils.run_briefly(self.loop) @@ -655,8 +717,9 @@ class StreamTests(test_utils.TestCase): self.server = None async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, loop=self.loop) + with self.assertWarns(DeprecationWarning): + reader, writer = await asyncio.open_connection( + *addr, loop=self.loop) # send a line writer.write(b"hello world!\n") # read it back @@ -670,7 +733,8 @@ class StreamTests(test_utils.TestCase): # test the server variant with a coroutine as client handler server = MyServer(self.loop) - addr = server.start() + with self.assertWarns(DeprecationWarning): + addr = server.start() msg = self.loop.run_until_complete(asyncio.Task(client(addr), loop=self.loop)) server.stop() @@ -678,7 +742,8 @@ class StreamTests(test_utils.TestCase): # test the server variant with a callback as client handler server = MyServer(self.loop) - addr = server.start_callback() + with self.assertWarns(DeprecationWarning): + addr = server.start_callback() msg = self.loop.run_until_complete(asyncio.Task(client(addr), loop=self.loop)) server.stop() @@ -726,8 +791,9 @@ class StreamTests(test_utils.TestCase): self.server = None async def client(path): - reader, writer = await asyncio.open_unix_connection( - path, loop=self.loop) + with self.assertWarns(DeprecationWarning): + reader, writer = await asyncio.open_unix_connection( + path, loop=self.loop) # send a line writer.write(b"hello world!\n") # read it back @@ -742,7 +808,8 @@ class StreamTests(test_utils.TestCase): # test the server variant with a coroutine as client handler with test_utils.unix_socket_path() as path: server = MyServer(self.loop, path) - server.start() + with self.assertWarns(DeprecationWarning): + server.start() msg = self.loop.run_until_complete(asyncio.Task(client(path), loop=self.loop)) server.stop() @@ -751,7 +818,8 @@ class StreamTests(test_utils.TestCase): # test the server variant with a callback as client handler with test_utils.unix_socket_path() as path: server = MyServer(self.loop, path) - server.start_callback() + with self.assertWarns(DeprecationWarning): + server.start_callback() msg = self.loop.run_until_complete(asyncio.Task(client(path), loop=self.loop)) server.stop() @@ -763,7 +831,7 @@ class StreamTests(test_utils.TestCase): def test_read_all_from_pipe_reader(self): # See asyncio issue 168. This test is derived from the example # subprocess_attach_read_pipe.py, but we configure the - # StreamReader's limit so that twice it is less than the size + # Stream's limit so that twice it is less than the size # of the data writter. Also we must explicitly attach a child # watcher to the event loop. @@ -777,10 +845,11 @@ os.close(fd) args = [sys.executable, '-c', code, str(wfd)] pipe = open(rfd, 'rb', 0) - reader = asyncio.StreamReader(loop=self.loop, limit=1, - _asyncio_internal=True) - protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, limit=1, + _asyncio_internal=True) + protocol = _StreamProtocol(stream, loop=self.loop, + _asyncio_internal=True) transport, _ = self.loop.run_until_complete( self.loop.connect_read_pipe(lambda: protocol, pipe)) @@ -797,29 +866,30 @@ os.close(fd) asyncio.set_child_watcher(None) os.close(wfd) - data = self.loop.run_until_complete(reader.read(-1)) + data = self.loop.run_until_complete(stream.read(-1)) self.assertEqual(data, b'data') def test_streamreader_constructor(self): self.addCleanup(asyncio.set_event_loop, None) asyncio.set_event_loop(self.loop) - # asyncio issue #184: Ensure that StreamReaderProtocol constructor + # asyncio issue #184: Ensure that _StreamProtocol constructor # retrieves the current loop if the loop parameter is not set - reader = asyncio.StreamReader(_asyncio_internal=True) + reader = asyncio.Stream(mode=asyncio.StreamMode.READ, + _asyncio_internal=True) self.assertIs(reader._loop, self.loop) def test_streamreaderprotocol_constructor(self): self.addCleanup(asyncio.set_event_loop, None) asyncio.set_event_loop(self.loop) - # asyncio issue #184: Ensure that StreamReaderProtocol constructor + # asyncio issue #184: Ensure that _StreamProtocol constructor # retrieves the current loop if the loop parameter is not set - reader = mock.Mock() - protocol = asyncio.StreamReaderProtocol(reader, _asyncio_internal=True) + stream = mock.Mock() + protocol = _StreamProtocol(stream, _asyncio_internal=True) self.assertIs(protocol._loop, self.loop) - def test_drain_raises(self): + def test_drain_raises_deprecated(self): # See http://bugs.python.org/issue25441 # This test should not use asyncio for the mock server; the @@ -833,15 +903,16 @@ os.close(fd) def server(): # Runs in a separate thread. - with socket.create_server(('localhost', 0)) as sock: + with socket.create_server(('127.0.0.1', 0)) as sock: addr = sock.getsockname() q.put(addr) clt, _ = sock.accept() clt.close() async def client(host, port): - reader, writer = await asyncio.open_connection( - host, port, loop=self.loop) + with self.assertWarns(DeprecationWarning): + reader, writer = await asyncio.open_connection( + host, port, loop=self.loop) while True: writer.write(b"foo\n") @@ -863,55 +934,106 @@ os.close(fd) thread.join() self.assertEqual([], messages) + def test_drain_raises(self): + # See http://bugs.python.org/issue25441 + + # This test should not use asyncio for the mock server; the + # whole point of the test is to test for a bug in drain() + # where it never gives up the event loop but the socket is + # closed on the server side. + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + q = queue.Queue() + + def server(): + # Runs in a separate thread. + with socket.create_server(('localhost', 0)) as sock: + addr = sock.getsockname() + q.put(addr) + clt, _ = sock.accept() + clt.close() + + async def client(host, port): + stream = await asyncio.connect(host, port) + + while True: + stream.write(b"foo\n") + await stream.drain() + + # Start the server thread and wait for it to be listening. + thread = threading.Thread(target=server) + thread.setDaemon(True) + thread.start() + addr = q.get() + + # Should not be stuck in an infinite loop. + with self.assertRaises((ConnectionResetError, ConnectionAbortedError, + BrokenPipeError)): + self.loop.run_until_complete(client(*addr)) + + # Clean up the thread. (Only on success; on failure, it may + # be stuck in accept().) + thread.join() + self.assertEqual([], messages) + def test___repr__(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) - self.assertEqual("<StreamReader>", repr(stream)) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) + self.assertEqual("<Stream mode=StreamMode.READ>", repr(stream)) def test___repr__nondefault_limit(self): - stream = asyncio.StreamReader(loop=self.loop, limit=123, - _asyncio_internal=True) - self.assertEqual("<StreamReader limit=123>", repr(stream)) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, limit=123, + _asyncio_internal=True) + self.assertEqual("<Stream mode=StreamMode.READ limit=123>", repr(stream)) def test___repr__eof(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_eof() - self.assertEqual("<StreamReader eof>", repr(stream)) + self.assertEqual("<Stream mode=StreamMode.READ eof>", repr(stream)) def test___repr__data(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'data') - self.assertEqual("<StreamReader 4 bytes>", repr(stream)) + self.assertEqual("<Stream mode=StreamMode.READ 4 bytes>", repr(stream)) def test___repr__exception(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) exc = RuntimeError() stream.set_exception(exc) - self.assertEqual("<StreamReader exception=RuntimeError()>", + self.assertEqual("<Stream mode=StreamMode.READ exception=RuntimeError()>", repr(stream)) def test___repr__waiter(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream._waiter = asyncio.Future(loop=self.loop) self.assertRegex( repr(stream), - r"<StreamReader waiter=<Future pending[\S ]*>>") + r"<Stream .+ waiter=<Future pending[\S ]*>>") stream._waiter.set_result(None) self.loop.run_until_complete(stream._waiter) stream._waiter = None - self.assertEqual("<StreamReader>", repr(stream)) + self.assertEqual("<Stream mode=StreamMode.READ>", repr(stream)) def test___repr__transport(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream._transport = mock.Mock() stream._transport.__repr__ = mock.Mock() stream._transport.__repr__.return_value = "<Transport>" - self.assertEqual("<StreamReader transport=<Transport>>", repr(stream)) + self.assertEqual("<Stream mode=StreamMode.READ transport=<Transport>>", + repr(stream)) def test_IncompleteReadError_pickleable(self): e = asyncio.IncompleteReadError(b'abc', 10) @@ -930,10 +1052,11 @@ os.close(fd) self.assertEqual(str(e), str(e2)) self.assertEqual(e.consumed, e2.consumed) - def test_wait_closed_on_close(self): + def test_wait_closed_on_close_deprecated(self): with test_utils.run_test_server() as httpd: - rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address, loop=self.loop)) + with self.assertWarns(DeprecationWarning): + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address, loop=self.loop)) wr.write(b'GET / HTTP/1.0\r\n\r\n') f = rd.readline() @@ -947,10 +1070,28 @@ os.close(fd) self.assertTrue(wr.is_closing()) self.loop.run_until_complete(wr.wait_closed()) - def test_wait_closed_on_close_with_unread_data(self): + def test_wait_closed_on_close(self): with test_utils.run_test_server() as httpd: - rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address, loop=self.loop)) + stream = self.loop.run_until_complete( + asyncio.connect(*httpd.address)) + + stream.write(b'GET / HTTP/1.0\r\n\r\n') + f = stream.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = stream.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + self.assertFalse(stream.is_closing()) + stream.close() + self.assertTrue(stream.is_closing()) + self.loop.run_until_complete(stream.wait_closed()) + + def test_wait_closed_on_close_with_unread_data_deprecated(self): + with test_utils.run_test_server() as httpd: + with self.assertWarns(DeprecationWarning): + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address, loop=self.loop)) wr.write(b'GET / HTTP/1.0\r\n\r\n') f = rd.readline() @@ -959,32 +1100,44 @@ os.close(fd) wr.close() self.loop.run_until_complete(wr.wait_closed()) + def test_wait_closed_on_close_with_unread_data(self): + with test_utils.run_test_server() as httpd: + stream = self.loop.run_until_complete( + asyncio.connect(*httpd.address)) + + stream.write(b'GET / HTTP/1.0\r\n\r\n') + f = stream.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + stream.close() + self.loop.run_until_complete(stream.wait_closed()) + def test_del_stream_before_sock_closing(self): messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) - with test_utils.run_test_server() as httpd: - rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address, loop=self.loop)) - sock = wr.get_extra_info('socket') - self.assertNotEqual(sock.fileno(), -1) + async def test(): - wr.write(b'GET / HTTP/1.0\r\n\r\n') - f = rd.readline() - data = self.loop.run_until_complete(f) - self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + with test_utils.run_test_server() as httpd: + stream = await asyncio.connect(*httpd.address) + sock = stream.get_extra_info('socket') + self.assertNotEqual(sock.fileno(), -1) - # drop refs to reader/writer - del rd - del wr - gc.collect() - # make a chance to close the socket - test_utils.run_briefly(self.loop) + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') - self.assertEqual(1, len(messages)) - self.assertEqual(sock.fileno(), -1) + # drop refs to reader/writer + del stream + gc.collect() + # make a chance to close the socket + await asyncio.sleep(0) - self.assertEqual(1, len(messages)) + self.assertEqual(1, len(messages), messages) + self.assertEqual(sock.fileno(), -1) + + self.loop.run_until_complete(test()) + self.assertEqual(1, len(messages), messages) self.assertEqual('An open stream object is being garbage ' 'collected; call "stream.close()" explicitly.', messages[0]['message']) @@ -994,11 +1147,12 @@ os.close(fd) self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) with test_utils.run_test_server() as httpd: - rd = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) - pr = asyncio.StreamReaderProtocol(rd, loop=self.loop, - _asyncio_internal=True) - del rd + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) + pr = _StreamProtocol(stream, loop=self.loop, + _asyncio_internal=True) + del stream gc.collect() tr, _ = self.loop.run_until_complete( self.loop.create_connection( @@ -1015,14 +1169,14 @@ os.close(fd) def test_async_writer_api(self): async def inner(httpd): - rd, wr = await asyncio.open_connection(*httpd.address) + stream = await asyncio.connect(*httpd.address) - await wr.write(b'GET / HTTP/1.0\r\n\r\n') - data = await rd.readline() + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') - data = await rd.read() + data = await stream.read() self.assertTrue(data.endswith(b'\r\n\r\nTest message')) - await wr.close() + await stream.close() messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) @@ -1032,18 +1186,18 @@ os.close(fd) self.assertEqual(messages, []) - def test_async_writer_api(self): + def test_async_writer_api_exception_after_close(self): async def inner(httpd): - rd, wr = await asyncio.open_connection(*httpd.address) + stream = await asyncio.connect(*httpd.address) - await wr.write(b'GET / HTTP/1.0\r\n\r\n') - data = await rd.readline() + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') - data = await rd.read() + data = await stream.read() self.assertTrue(data.endswith(b'\r\n\r\nTest message')) - wr.close() + stream.close() with self.assertRaises(ConnectionResetError): - await wr.write(b'data') + await stream.write(b'data') messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) @@ -1059,11 +1213,13 @@ os.close(fd) self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) with test_utils.run_test_server() as httpd: - rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address, - loop=self.loop)) + with self.assertWarns(DeprecationWarning): + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address, + loop=self.loop)) - f = wr.close() + wr.close() + f = wr.wait_closed() self.loop.run_until_complete(f) assert rd.at_eof() f = rd.read() @@ -1074,22 +1230,514 @@ os.close(fd) def test_stream_reader_create_warning(self): with self.assertWarns(DeprecationWarning): - asyncio.StreamReader(loop=self.loop) - - def test_stream_reader_protocol_create_warning(self): - reader = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) - with self.assertWarns(DeprecationWarning): - asyncio.StreamReaderProtocol(reader, loop=self.loop) + asyncio.StreamReader def test_stream_writer_create_warning(self): - reader = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) - proto = asyncio.StreamReaderProtocol(reader, loop=self.loop, - _asyncio_internal=True) with self.assertWarns(DeprecationWarning): - asyncio.StreamWriter('transport', proto, reader, self.loop) + asyncio.StreamWriter + + def test_stream_reader_forbidden_ops(self): + async def inner(): + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + _asyncio_internal=True) + with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): + await stream.write(b'data') + with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): + await stream.writelines([b'data', b'other']) + with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): + stream.write_eof() + with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): + await stream.drain() + + self.loop.run_until_complete(inner()) + + def test_stream_writer_forbidden_ops(self): + async def inner(): + stream = asyncio.Stream(mode=asyncio.StreamMode.WRITE, + _asyncio_internal=True) + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + stream.feed_data(b'data') + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + await stream.readline() + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + await stream.readuntil() + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + await stream.read() + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + await stream.readexactly(10) + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + async for chunk in stream: + pass + + self.loop.run_until_complete(inner()) + + def _basetest_connect(self, stream): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + stream.write(b'GET / HTTP/1.0\r\n\r\n') + f = stream.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = stream.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + stream.close() + self.loop.run_until_complete(stream.wait_closed()) + + self.assertEqual([], messages) + + def test_connect(self): + with test_utils.run_test_server() as httpd: + stream = self.loop.run_until_complete( + asyncio.connect(*httpd.address)) + self.assertFalse(stream.is_server_side()) + self._basetest_connect(stream) + + @support.skip_unless_bind_unix_socket + def test_connect_unix(self): + with test_utils.run_test_unix_server() as httpd: + stream = self.loop.run_until_complete( + asyncio.connect_unix(httpd.address)) + self._basetest_connect(stream) + + def test_stream_async_context_manager(self): + async def test(httpd): + stream = await asyncio.connect(*httpd.address) + async with stream: + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + data = await stream.read() + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + self.assertTrue(stream.is_closing()) + + with test_utils.run_test_server() as httpd: + self.loop.run_until_complete(test(httpd)) + + def test_connect_async_context_manager(self): + async def test(httpd): + async with asyncio.connect(*httpd.address) as stream: + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + data = await stream.read() + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + self.assertTrue(stream.is_closing()) + + with test_utils.run_test_server() as httpd: + self.loop.run_until_complete(test(httpd)) + + @support.skip_unless_bind_unix_socket + def test_connect_unix_async_context_manager(self): + async def test(httpd): + async with asyncio.connect_unix(httpd.address) as stream: + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + data = await stream.read() + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + self.assertTrue(stream.is_closing()) + + with test_utils.run_test_unix_server() as httpd: + self.loop.run_until_complete(test(httpd)) + + def test_stream_server(self): + + async def handle_client(stream): + self.assertTrue(stream.is_server_side()) + data = await stream.readline() + await stream.write(data) + await stream.close() + + async def client(srv): + addr = srv.sockets[0].getsockname() + stream = await asyncio.connect(*addr) + # send a line + await stream.write(b"hello world!\n") + # read it back + msgback = await stream.readline() + await stream.close() + self.assertEqual(msgback, b"hello world!\n") + await srv.close() + + async def test(): + async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + with contextlib.suppress(asyncio.CancelledError): + await server.serve_forever() + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + + @support.skip_unless_bind_unix_socket + def test_unix_stream_server(self): + + async def handle_client(stream): + data = await stream.readline() + await stream.write(data) + await stream.close() + + async def client(srv): + addr = srv.sockets[0].getsockname() + stream = await asyncio.connect_unix(addr) + # send a line + await stream.write(b"hello world!\n") + # read it back + msgback = await stream.readline() + await stream.close() + self.assertEqual(msgback, b"hello world!\n") + await srv.close() + + async def test(): + with test_utils.unix_socket_path() as path: + async with asyncio.UnixStreamServer(handle_client, path) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + with contextlib.suppress(asyncio.CancelledError): + await server.serve_forever() + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + + def test_stream_server_inheritance_forbidden(self): + with self.assertRaises(TypeError): + class MyServer(asyncio.StreamServer): + pass + + @support.skip_unless_bind_unix_socket + def test_unix_stream_server_inheritance_forbidden(self): + with self.assertRaises(TypeError): + class MyServer(asyncio.UnixStreamServer): + pass + + def test_stream_server_bind(self): + async def handle_client(stream): + await stream.close() + + async def test(): + srv = asyncio.StreamServer(handle_client, '127.0.0.1', 0) + self.assertFalse(srv.is_bound()) + self.assertEqual(0, len(srv.sockets)) + await srv.bind() + self.assertTrue(srv.is_bound()) + self.assertEqual(1, len(srv.sockets)) + await srv.close() + self.assertFalse(srv.is_bound()) + self.assertEqual(0, len(srv.sockets)) + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + + def test_stream_server_bind_async_with(self): + async def handle_client(stream): + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as srv: + self.assertTrue(srv.is_bound()) + self.assertEqual(1, len(srv.sockets)) + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + + def test_stream_server_start_serving(self): + async def handle_client(stream): + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as srv: + self.assertFalse(srv.is_serving()) + await srv.start_serving() + self.assertTrue(srv.is_serving()) + await srv.close() + self.assertFalse(srv.is_serving()) + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + + def test_stream_server_close(self): + server_stream_aborted = False + fut = self.loop.create_future() + + async def handle_client(stream): + await fut + self.assertEqual(b'', await stream.readline()) + nonlocal server_stream_aborted + server_stream_aborted = True + + async def client(srv): + addr = srv.sockets[0].getsockname() + stream = await asyncio.connect(*addr) + fut.set_result(None) + self.assertEqual(b'', await stream.readline()) + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + await fut + await server.close() + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + self.assertTrue(fut.done()) + self.assertTrue(server_stream_aborted) + + def test_stream_server_abort(self): + server_stream_aborted = False + fut = self.loop.create_future() + + async def handle_client(stream): + await fut + self.assertEqual(b'', await stream.readline()) + nonlocal server_stream_aborted + server_stream_aborted = True + + async def client(srv): + addr = srv.sockets[0].getsockname() + stream = await asyncio.connect(*addr) + fut.set_result(None) + self.assertEqual(b'', await stream.readline()) + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + await fut + await server.abort() + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + self.assertTrue(fut.done()) + self.assertTrue(server_stream_aborted) + + def test_stream_shutdown_hung_task(self): + fut1 = self.loop.create_future() + fut2 = self.loop.create_future() + + async def handle_client(stream): + while True: + await asyncio.sleep(0.01) + + async def client(srv): + addr = srv.sockets[0].getsockname() + stream = await asyncio.connect(*addr) + fut1.set_result(None) + await fut2 + self.assertEqual(b'', await stream.readline()) + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, + '127.0.0.1', + 0, + shutdown_timeout=0.3) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + await fut1 + await server.close() + fut2.set_result(None) + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + self.assertTrue(fut1.done()) + self.assertTrue(fut2.done()) + + def test_stream_shutdown_hung_task_prevents_cancellation(self): + fut1 = self.loop.create_future() + fut2 = self.loop.create_future() + do_handle_client = True + + async def handle_client(stream): + while do_handle_client: + with contextlib.suppress(asyncio.CancelledError): + await asyncio.sleep(0.01) + + async def client(srv): + addr = srv.sockets[0].getsockname() + stream = await asyncio.connect(*addr) + fut1.set_result(None) + await fut2 + self.assertEqual(b'', await stream.readline()) + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, + '127.0.0.1', + 0, + shutdown_timeout=0.3) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + await fut1 + await server.close() + nonlocal do_handle_client + do_handle_client = False + fut2.set_result(None) + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(1, len(messages)) + self.assertRegex(messages[0]['message'], + "<Task pending .+ ignored cancellation request") + self.assertTrue(fut1.done()) + self.assertTrue(fut2.done()) + + def test_sendfile(self): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + with open(support.TESTFN, 'wb') as fp: + fp.write(b'data\n') + self.addCleanup(support.unlink, support.TESTFN) + + async def serve_callback(stream): + data = await stream.readline() + self.assertEqual(data, b'begin\n') + data = await stream.readline() + self.assertEqual(data, b'data\n') + data = await stream.readline() + self.assertEqual(data, b'end\n') + await stream.write(b'done\n') + await stream.close() + + async def do_connect(host, port): + stream = await asyncio.connect(host, port) + await stream.write(b'begin\n') + with open(support.TESTFN, 'rb') as fp: + await stream.sendfile(fp) + await stream.write(b'end\n') + data = await stream.readline() + self.assertEqual(data, b'done\n') + await stream.close() + + async def test(): + async with asyncio.StreamServer(serve_callback, '127.0.0.1', 0) as srv: + await srv.start_serving() + await do_connect(*srv.sockets[0].getsockname()) + + self.loop.run_until_complete(test()) + + self.assertEqual([], messages) + + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_connect_start_tls(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + # connect without SSL but upgrade to TLS just after + # connection is established + stream = self.loop.run_until_complete( + asyncio.connect(*httpd.address)) + + self.loop.run_until_complete( + stream.start_tls( + sslcontext=test_utils.dummy_ssl_context())) + self._basetest_connect(stream) + + def test_repr_unbound(self): + async def serve(stream): + pass + + async def test(): + srv = asyncio.StreamServer(serve) + self.assertEqual('<StreamServer>', repr(srv)) + await srv.close() + + self.loop.run_until_complete(test()) + + def test_repr_bound(self): + async def serve(stream): + pass + + async def test(): + srv = asyncio.StreamServer(serve, '127.0.0.1', 0) + await srv.bind() + self.assertRegex(repr(srv), r'<StreamServer sockets=\(.+\)>') + await srv.close() + + self.loop.run_until_complete(test()) + + def test_repr_serving(self): + async def serve(stream): + pass + + async def test(): + srv = asyncio.StreamServer(serve, '127.0.0.1', 0) + await srv.start_serving() + self.assertRegex(repr(srv), r'<StreamServer serving sockets=\(.+\)>') + await srv.close() + + self.loop.run_until_complete(test()) + + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + async def test(): + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + async with asyncio.connect_read_pipe(pipeobj) as stream: + self.assertEqual(stream.mode, asyncio.StreamMode.READ) + + os.write(wpipe, b'1') + data = await stream.readexactly(1) + self.assertEqual(data, b'1') + + os.write(wpipe, b'2345') + data = await stream.readexactly(4) + self.assertEqual(data, b'2345') + os.close(wpipe) + + self.loop.run_until_complete(test()) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + async def test(): + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + async with asyncio.connect_write_pipe(pipeobj) as stream: + self.assertEqual(stream.mode, asyncio.StreamMode.WRITE) + + await stream.write(b'1') + data = os.read(rpipe, 1024) + self.assertEqual(data, b'1') + + await stream.write(b'2345') + data = os.read(rpipe, 1024) + self.assertEqual(data, b'2345') + + os.close(rpipe) + self.loop.run_until_complete(test()) if __name__ == '__main__': diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py index e201a06..13aef7c 100644 --- a/Lib/test/test_asyncio/test_windows_events.py +++ b/Lib/test/test_asyncio/test_windows_events.py @@ -17,6 +17,7 @@ import _winapi import asyncio from asyncio import windows_events +from asyncio.streams import _StreamProtocol from test.test_asyncio import utils as test_utils from test.support.script_helper import spawn_python @@ -100,16 +101,16 @@ class ProactorTests(test_utils.TestCase): clients = [] for i in range(5): - stream_reader = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) - protocol = asyncio.StreamReaderProtocol(stream_reader, - loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, _asyncio_internal=True) + protocol = _StreamProtocol(stream, + loop=self.loop, + _asyncio_internal=True) trans, proto = await self.loop.create_pipe_connection( lambda: protocol, ADDRESS) self.assertIsInstance(trans, asyncio.Transport) self.assertEqual(protocol, proto) - clients.append((stream_reader, trans)) + clients.append((stream, trans)) for i, (r, w) in enumerate(clients): w.write('lower-{}\n'.format(i).encode()) @@ -118,6 +119,7 @@ class ProactorTests(test_utils.TestCase): response = await r.readline() self.assertEqual(response, 'LOWER-{}\n'.format(i).encode()) w.close() + await r.close() server.close() |