diff options
-rw-r--r-- | Lib/asyncio/sslproto.py | 4 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_sslproto.py | 72 | ||||
-rw-r--r-- | Misc/NEWS.d/next/Library/2019-03-17-16-43-29.bpo-34745.nOfm7_.rst | 1 |
3 files changed, 77 insertions, 0 deletions
diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index 4278560..97a6fc6 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -498,7 +498,11 @@ class SSLProtocol(protocols.Protocol): self._app_transport._closed = True self._transport = None self._app_transport = None + if getattr(self, '_handshake_timeout_handle', None): + self._handshake_timeout_handle.cancel() self._wakeup_waiter(exc) + self._app_protocol = None + self._sslpipe = None def pause_writing(self): """Called when the low-level transport's buffer goes over diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index 19b7a43..7bc2ccf 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -4,6 +4,7 @@ import logging import socket import sys import unittest +import weakref from unittest import mock try: import ssl @@ -274,6 +275,72 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin): self.loop.run_until_complete( asyncio.wait_for(client(srv.addr), timeout=10)) + # No garbage is left if SSL is closed uncleanly + client_context = weakref.ref(client_context) + self.assertIsNone(client_context()) + + def test_create_connection_memory_leak(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + + def serve(sock): + sock.settimeout(self.TIMEOUT) + + sock.start_tls(server_context, server_side=True) + + sock.sendall(b'O') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.shutdown(socket.SHUT_RDWR) + sock.close() + + class ClientProto(asyncio.Protocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + + def connection_made(proto, tr): + # XXX: We assume user stores the transport in protocol + proto.tr = tr + proto.con_made_cnt += 1 + # Ensure connection_made gets called only once. + self.assertEqual(proto.con_made_cnt, 1) + + def data_received(self, data): + self.on_data.set_result(data) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + await asyncio.sleep(0.5) + + on_data = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProto(on_data, on_eof), *addr, + ssl=client_context) + + self.assertEqual(await on_data, b'O') + tr.write(HELLO_MSG) + await on_eof + + tr.close() + + with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), timeout=10)) + + # No garbage is left for SSL client from loop.create_connection, even + # if user stores the SSLTransport in corresponding protocol instance + client_context = weakref.ref(client_context) + self.assertIsNone(client_context()) + def test_start_tls_client_buf_proto_1(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE @@ -562,6 +629,11 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin): # exception or log an error, even if the handshake failed self.assertEqual(messages, []) + # The 10s handshake timeout should be cancelled to free related + # objects without really waiting for 10s + client_sslctx = weakref.ref(client_sslctx) + self.assertIsNone(client_sslctx()) + def test_create_connection_ssl_slow_handshake(self): client_sslctx = test_utils.simple_client_sslcontext() diff --git a/Misc/NEWS.d/next/Library/2019-03-17-16-43-29.bpo-34745.nOfm7_.rst b/Misc/NEWS.d/next/Library/2019-03-17-16-43-29.bpo-34745.nOfm7_.rst new file mode 100644 index 0000000..d88f36a --- /dev/null +++ b/Misc/NEWS.d/next/Library/2019-03-17-16-43-29.bpo-34745.nOfm7_.rst @@ -0,0 +1 @@ +Fix :mod:`asyncio` ssl memory issues caused by circular references |