diff options
-rw-r--r-- | Lib/asyncio/streams.py | 91 | ||||
-rw-r--r-- | Lib/asyncio/subprocess.py | 5 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_streams.py | 71 | ||||
-rw-r--r-- | Misc/NEWS.d/next/Library/2018-09-12-10-33-44.bpo-34638.xaeZX5.rst | 3 |
4 files changed, 160 insertions, 10 deletions
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 9dab49b..e7fb22e 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -3,6 +3,8 @@ __all__ = ( 'open_connection', 'start_server') import socket +import sys +import weakref if hasattr(socket, 'AF_UNIX'): __all__ += ('open_unix_connection', 'start_unix_server') @@ -10,6 +12,7 @@ if hasattr(socket, 'AF_UNIX'): from . import coroutines from . import events from . import exceptions +from . import format_helpers from . import protocols from .log import logger from .tasks import sleep @@ -186,46 +189,106 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): call inappropriate methods of the protocol.) """ + _source_traceback = None + def __init__(self, stream_reader, client_connected_cb=None, loop=None): super().__init__(loop=loop) - self._stream_reader = stream_reader + if stream_reader is not None: + self._stream_reader_wr = weakref.ref(stream_reader, + self._on_reader_gc) + self._source_traceback = stream_reader._source_traceback + else: + self._stream_reader_wr = None + if client_connected_cb is not None: + # This is a stream created by the `create_server()` function. + # Keep a strong reference to the reader until a connection + # is established. + self._strong_reader = stream_reader + self._reject_connection = False self._stream_writer = None + self._transport = None self._client_connected_cb = client_connected_cb self._over_ssl = False self._closed = self._loop.create_future() + def _on_reader_gc(self, wr): + transport = self._transport + if transport is not None: + # connection_made was called + context = { + 'message': ('An open stream object is being garbage ' + 'collected; call "stream.close()" explicitly.') + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + transport.abort() + else: + self._reject_connection = True + self._stream_reader_wr = None + + def _untrack_reader(self): + self._stream_reader_wr = None + + @property + def _stream_reader(self): + if self._stream_reader_wr is None: + return None + return self._stream_reader_wr() + def connection_made(self, transport): - self._stream_reader.set_transport(transport) + if self._reject_connection: + context = { + 'message': ('An open stream was garbage collected prior to ' + 'establishing network connection; ' + 'call "stream.close()" explicitly.') + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + transport.abort() + return + self._transport = transport + reader = self._stream_reader + if reader is not None: + reader.set_transport(transport) self._over_ssl = transport.get_extra_info('sslcontext') is not None if self._client_connected_cb is not None: self._stream_writer = StreamWriter(transport, self, - self._stream_reader, + reader, self._loop) - res = self._client_connected_cb(self._stream_reader, + res = self._client_connected_cb(reader, self._stream_writer) if coroutines.iscoroutine(res): self._loop.create_task(res) + self._strong_reader = None def connection_lost(self, exc): - if self._stream_reader is not None: + reader = self._stream_reader + if reader is not None: if exc is None: - self._stream_reader.feed_eof() + reader.feed_eof() else: - self._stream_reader.set_exception(exc) + reader.set_exception(exc) if not self._closed.done(): if exc is None: self._closed.set_result(None) else: self._closed.set_exception(exc) super().connection_lost(exc) - self._stream_reader = None + self._stream_reader_wr = None self._stream_writer = None + self._transport = None def data_received(self, data): - self._stream_reader.feed_data(data) + reader = self._stream_reader + if reader is not None: + reader.feed_data(data) def eof_received(self): - self._stream_reader.feed_eof() + reader = self._stream_reader + if reader is not None: + reader.feed_eof() if self._over_ssl: # Prevent a warning in SSLProtocol.eof_received: # "returning true from eof_received() @@ -282,6 +345,9 @@ class StreamWriter: return self._transport.can_write_eof() def close(self): + # a reader can be garbage collected + # after connection closing + self._protocol._untrack_reader() return self._transport.close() def is_closing(self): @@ -318,6 +384,8 @@ class StreamWriter: class StreamReader: + _source_traceback = None + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): # The line length limit is a security feature; # it also doubles as half the buffer limit. @@ -336,6 +404,9 @@ class StreamReader: self._exception = None self._transport = None self._paused = False + if self._loop.get_debug(): + self._source_traceback = format_helpers.extract_stack( + sys._getframe(1)) def __repr__(self): info = ['StreamReader'] diff --git a/Lib/asyncio/subprocess.py b/Lib/asyncio/subprocess.py index 90fc00d..c86de3d 100644 --- a/Lib/asyncio/subprocess.py +++ b/Lib/asyncio/subprocess.py @@ -36,6 +36,11 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, info.append(f'stderr={self.stderr!r}') return '<{}>'.format(' '.join(info)) + def _untrack_reader(self): + # StreamWriter.close() expects the protocol + # to have this method defined. + pass + def connection_made(self, transport): self._transport = transport diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 66d1873..67ac9d9 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -46,6 +46,8 @@ class StreamTests(test_utils.TestCase): self.assertIs(stream._loop, m_events.get_event_loop.return_value) def _basetest_open_connection(self, open_connection_fut): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) reader, writer = self.loop.run_until_complete(open_connection_fut) writer.write(b'GET / HTTP/1.0\r\n\r\n') f = reader.readline() @@ -55,6 +57,7 @@ class StreamTests(test_utils.TestCase): data = self.loop.run_until_complete(f) self.assertTrue(data.endswith(b'\r\n\r\nTest message')) writer.close() + self.assertEqual(messages, []) def test_open_connection(self): with test_utils.run_test_server() as httpd: @@ -70,6 +73,8 @@ class StreamTests(test_utils.TestCase): self._basetest_open_connection(conn_fut) def _basetest_open_connection_no_loop_ssl(self, open_connection_fut): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) try: reader, writer = self.loop.run_until_complete(open_connection_fut) finally: @@ -80,6 +85,7 @@ class StreamTests(test_utils.TestCase): self.assertTrue(data.endswith(b'\r\n\r\nTest message')) writer.close() + self.assertEqual(messages, []) @unittest.skipIf(ssl is None, 'No ssl module') def test_open_connection_no_loop_ssl(self): @@ -104,6 +110,8 @@ class StreamTests(test_utils.TestCase): self._basetest_open_connection_no_loop_ssl(conn_fut) def _basetest_open_connection_error(self, open_connection_fut): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) reader, writer = self.loop.run_until_complete(open_connection_fut) writer._protocol.connection_lost(ZeroDivisionError()) f = reader.read() @@ -111,6 +119,7 @@ class StreamTests(test_utils.TestCase): self.loop.run_until_complete(f) writer.close() test_utils.run_briefly(self.loop) + self.assertEqual(messages, []) def test_open_connection_error(self): with test_utils.run_test_server() as httpd: @@ -621,6 +630,9 @@ class StreamTests(test_utils.TestCase): writer.close() return msgback + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + # test the server variant with a coroutine as client handler server = MyServer(self.loop) addr = server.start() @@ -637,6 +649,8 @@ class StreamTests(test_utils.TestCase): server.stop() self.assertEqual(msg, b"hello world!\n") + self.assertEqual(messages, []) + @support.skip_unless_bind_unix_socket def test_start_unix_server(self): @@ -685,6 +699,9 @@ class StreamTests(test_utils.TestCase): writer.close() return msgback + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + # test the server variant with a coroutine as client handler with test_utils.unix_socket_path() as path: server = MyServer(self.loop, path) @@ -703,6 +720,8 @@ class StreamTests(test_utils.TestCase): server.stop() self.assertEqual(msg, b"hello world!\n") + self.assertEqual(messages, []) + @unittest.skipIf(sys.platform == 'win32', "Don't have pipes") def test_read_all_from_pipe_reader(self): # See asyncio issue 168. This test is derived from the example @@ -893,6 +912,58 @@ os.close(fd) wr.close() self.loop.run_until_complete(wr.wait_closed()) + def test_del_stream_before_sock_closing(self): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + with test_utils.run_test_server() as httpd: + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address, loop=self.loop)) + sock = wr.get_extra_info('socket') + self.assertNotEqual(sock.fileno(), -1) + + wr.write(b'GET / HTTP/1.0\r\n\r\n') + f = rd.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + + # drop refs to reader/writer + del rd + del wr + gc.collect() + # make a chance to close the socket + test_utils.run_briefly(self.loop) + + self.assertEqual(1, len(messages)) + self.assertEqual(sock.fileno(), -1) + + self.assertEqual(1, len(messages)) + self.assertEqual('An open stream object is being garbage ' + 'collected; call "stream.close()" explicitly.', + messages[0]['message']) + + def test_del_stream_before_connection_made(self): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + with test_utils.run_test_server() as httpd: + rd = asyncio.StreamReader(loop=self.loop) + pr = asyncio.StreamReaderProtocol(rd, loop=self.loop) + del rd + gc.collect() + tr, _ = self.loop.run_until_complete( + self.loop.create_connection( + lambda: pr, *httpd.address)) + + sock = tr.get_extra_info('socket') + self.assertEqual(sock.fileno(), -1) + + self.assertEqual(1, len(messages)) + self.assertEqual('An open stream was garbage collected prior to ' + 'establishing network connection; ' + 'call "stream.close()" explicitly.', + messages[0]['message']) + if __name__ == '__main__': unittest.main() diff --git a/Misc/NEWS.d/next/Library/2018-09-12-10-33-44.bpo-34638.xaeZX5.rst b/Misc/NEWS.d/next/Library/2018-09-12-10-33-44.bpo-34638.xaeZX5.rst new file mode 100644 index 0000000..13b3952 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-09-12-10-33-44.bpo-34638.xaeZX5.rst @@ -0,0 +1,3 @@ +Store a weak reference to stream reader to break strong references loop +between reader and protocol. It allows to detect and close the socket if +the stream is deleted (garbage collected) without ``close()`` call. |