diff options
author | Pablo Galindo <Pablogsal@gmail.com> | 2021-05-03 15:21:59 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-05-03 15:21:59 (GMT) |
commit | 7719953b30430b351ba0f153c2b51b16cc68ee36 (patch) | |
tree | 8014086b85a13ed79d45e29ab74a9a9f5c9c68eb /Lib/test/test_asyncio | |
parent | 39494285e15dc2d291ec13de5045b930eaf0a3db (diff) | |
download | cpython-7719953b30430b351ba0f153c2b51b16cc68ee36.zip cpython-7719953b30430b351ba0f153c2b51b16cc68ee36.tar.gz cpython-7719953b30430b351ba0f153c2b51b16cc68ee36.tar.bz2 |
bpo-44011: Revert "New asyncio ssl implementation (GH-17975)" (GH-25848)
This reverts commit 5fb06edbbb769561e245d0fe13002bab50e2ae60 and all
subsequent dependent commits.
Diffstat (limited to 'Lib/test/test_asyncio')
-rw-r--r-- | Lib/test/test_asyncio/test_base_events.py | 21 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_selector_events.py | 38 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_ssl.py | 1723 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_sslproto.py | 35 |
4 files changed, 60 insertions, 1757 deletions
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index be5ea1e..5691d42 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -1437,51 +1437,44 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport ANY = mock.ANY handshake_timeout = object() - shutdown_timeout = object() # First try the default server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection( MyProto, 'python.org', 80, ssl=True, - ssl_handshake_timeout=handshake_timeout, - ssl_shutdown_timeout=shutdown_timeout) + ssl_handshake_timeout=handshake_timeout) transport, _ = self.loop.run_until_complete(coro) transport.close() self.loop._make_ssl_transport.assert_called_with( ANY, ANY, ANY, ANY, server_side=False, server_hostname='python.org', - ssl_handshake_timeout=handshake_timeout, - ssl_shutdown_timeout=shutdown_timeout) + ssl_handshake_timeout=handshake_timeout) # Next try an explicit server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection( MyProto, 'python.org', 80, ssl=True, server_hostname='perl.com', - ssl_handshake_timeout=handshake_timeout, - ssl_shutdown_timeout=shutdown_timeout) + ssl_handshake_timeout=handshake_timeout) transport, _ = self.loop.run_until_complete(coro) transport.close() self.loop._make_ssl_transport.assert_called_with( ANY, ANY, ANY, ANY, server_side=False, server_hostname='perl.com', - ssl_handshake_timeout=handshake_timeout, - ssl_shutdown_timeout=shutdown_timeout) + ssl_handshake_timeout=handshake_timeout) # Finally try an explicit empty server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection( MyProto, 'python.org', 80, ssl=True, server_hostname='', - ssl_handshake_timeout=handshake_timeout, - ssl_shutdown_timeout=shutdown_timeout) + ssl_handshake_timeout=handshake_timeout) transport, _ = self.loop.run_until_complete(coro) transport.close() self.loop._make_ssl_transport.assert_called_with( ANY, ANY, ANY, ANY, server_side=False, server_hostname='', - ssl_handshake_timeout=handshake_timeout, - ssl_shutdown_timeout=shutdown_timeout) + ssl_handshake_timeout=handshake_timeout) def test_create_connection_no_ssl_server_hostname_errors(self): # When not using ssl, server_hostname must be None. @@ -1888,7 +1881,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): constants.ACCEPT_RETRY_DELAY, # self.loop._start_serving mock.ANY, - MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY) + MyProto, sock, None, None, mock.ANY, mock.ANY) def test_call_coroutine(self): with self.assertWarns(DeprecationWarning): diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index 349e4f2..1613c75 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -70,6 +70,44 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): close_transport(transport) + @unittest.skipIf(ssl is None, 'No ssl module') + def test_make_ssl_transport(self): + m = mock.Mock() + self.loop._add_reader = mock.Mock() + self.loop._add_reader._is_coroutine = False + self.loop._add_writer = mock.Mock() + self.loop._remove_reader = mock.Mock() + self.loop._remove_writer = mock.Mock() + waiter = self.loop.create_future() + with test_utils.disable_logger(): + transport = self.loop._make_ssl_transport( + m, asyncio.Protocol(), m, waiter) + + with self.assertRaisesRegex(RuntimeError, + r'SSL transport.*not.*initialized'): + transport.is_reading() + + # execute the handshake while the logger is disabled + # to ignore SSL handshake failure + test_utils.run_briefly(self.loop) + + self.assertTrue(transport.is_reading()) + transport.pause_reading() + transport.pause_reading() + self.assertFalse(transport.is_reading()) + transport.resume_reading() + transport.resume_reading() + self.assertTrue(transport.is_reading()) + + # Sanity check + class_name = transport.__class__.__name__ + self.assertIn("ssl", class_name.lower()) + self.assertIn("transport", class_name.lower()) + + transport.close() + # execute pending callbacks to close the socket transport + test_utils.run_briefly(self.loop) + @mock.patch('asyncio.selector_events.ssl', None) @mock.patch('asyncio.sslproto.ssl', None) def test_make_ssl_transport_without_ssl_error(self): diff --git a/Lib/test/test_asyncio/test_ssl.py b/Lib/test/test_asyncio/test_ssl.py deleted file mode 100644 index 9cdd281..0000000 --- a/Lib/test/test_asyncio/test_ssl.py +++ /dev/null @@ -1,1723 +0,0 @@ -import asyncio -import asyncio.sslproto -import contextlib -import gc -import logging -import select -import socket -import tempfile -import threading -import time -import weakref -import unittest - -try: - import ssl -except ImportError: - ssl = None - -from test import support -from test.test_asyncio import utils as test_utils - - -def tearDownModule(): - asyncio.set_event_loop_policy(None) - - -class MyBaseProto(asyncio.Protocol): - connected = None - done = None - - def __init__(self, loop=None): - self.transport = None - self.state = 'INITIAL' - self.nbytes = 0 - if loop is not None: - self.connected = asyncio.Future(loop=loop) - self.done = asyncio.Future(loop=loop) - - def connection_made(self, transport): - self.transport = transport - assert self.state == 'INITIAL', self.state - self.state = 'CONNECTED' - if self.connected: - self.connected.set_result(None) - - def data_received(self, data): - assert self.state == 'CONNECTED', self.state - self.nbytes += len(data) - - def eof_received(self): - assert self.state == 'CONNECTED', self.state - self.state = 'EOF' - - def connection_lost(self, exc): - assert self.state in ('CONNECTED', 'EOF'), self.state - self.state = 'CLOSED' - if self.done: - self.done.set_result(None) - - -@unittest.skipIf(ssl is None, 'No ssl module') -class TestSSL(test_utils.TestCase): - - PAYLOAD_SIZE = 1024 * 100 - TIMEOUT = 60 - - def setUp(self): - super().setUp() - self.loop = asyncio.new_event_loop() - self.set_event_loop(self.loop) - self.addCleanup(self.loop.close) - - def tearDown(self): - # just in case if we have transport close callbacks - if not self.loop.is_closed(): - test_utils.run_briefly(self.loop) - - self.doCleanups() - support.gc_collect() - super().tearDown() - - def tcp_server(self, server_prog, *, - family=socket.AF_INET, - addr=None, - timeout=5, - backlog=1, - max_clients=10): - - if addr is None: - if family == getattr(socket, "AF_UNIX", None): - with tempfile.NamedTemporaryFile() as tmp: - addr = tmp.name - else: - addr = ('127.0.0.1', 0) - - sock = socket.socket(family, socket.SOCK_STREAM) - - if timeout is None: - raise RuntimeError('timeout is required') - if timeout <= 0: - raise RuntimeError('only blocking sockets are supported') - sock.settimeout(timeout) - - try: - sock.bind(addr) - sock.listen(backlog) - except OSError as ex: - sock.close() - raise ex - - return TestThreadedServer( - self, sock, server_prog, timeout, max_clients) - - def tcp_client(self, client_prog, - family=socket.AF_INET, - timeout=10): - - sock = socket.socket(family, socket.SOCK_STREAM) - - if timeout is None: - raise RuntimeError('timeout is required') - if timeout <= 0: - raise RuntimeError('only blocking sockets are supported') - sock.settimeout(timeout) - - return TestThreadedClient( - self, sock, client_prog, timeout) - - def unix_server(self, *args, **kwargs): - return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs) - - def unix_client(self, *args, **kwargs): - return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs) - - def _create_server_ssl_context(self, certfile, keyfile=None): - sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - sslcontext.options |= ssl.OP_NO_SSLv2 - sslcontext.load_cert_chain(certfile, keyfile) - return sslcontext - - def _create_client_ssl_context(self, *, disable_verify=True): - sslcontext = ssl.create_default_context() - sslcontext.check_hostname = False - if disable_verify: - sslcontext.verify_mode = ssl.CERT_NONE - return sslcontext - - @contextlib.contextmanager - def _silence_eof_received_warning(self): - # TODO This warning has to be fixed in asyncio. - logger = logging.getLogger('asyncio') - filter = logging.Filter('has no effect when using ssl') - logger.addFilter(filter) - try: - yield - finally: - logger.removeFilter(filter) - - def _abort_socket_test(self, ex): - try: - self.loop.stop() - finally: - self.fail(ex) - - def new_loop(self): - return asyncio.new_event_loop() - - def new_policy(self): - return asyncio.DefaultEventLoopPolicy() - - async def wait_closed(self, obj): - if not isinstance(obj, asyncio.StreamWriter): - return - try: - await obj.wait_closed() - except (BrokenPipeError, ConnectionError): - pass - - def test_create_server_ssl_1(self): - CNT = 0 # number of clients that were successful - TOTAL_CNT = 25 # total number of clients that test will create - TIMEOUT = 60.0 # timeout for this test - - A_DATA = b'A' * 1024 * 1024 - B_DATA = b'B' * 1024 * 1024 - - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, test_utils.ONLYKEY - ) - client_sslctx = self._create_client_ssl_context() - - clients = [] - - async def handle_client(reader, writer): - nonlocal CNT - - data = await reader.readexactly(len(A_DATA)) - self.assertEqual(data, A_DATA) - writer.write(b'OK') - - data = await reader.readexactly(len(B_DATA)) - self.assertEqual(data, B_DATA) - writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) - - await writer.drain() - writer.close() - - CNT += 1 - - async def test_client(addr): - fut = asyncio.Future() - - def prog(sock): - try: - sock.starttls(client_sslctx) - sock.connect(addr) - sock.send(A_DATA) - - data = sock.recv_all(2) - self.assertEqual(data, b'OK') - - sock.send(B_DATA) - data = sock.recv_all(4) - self.assertEqual(data, b'SPAM') - - sock.close() - - except Exception as ex: - self.loop.call_soon_threadsafe(fut.set_exception, ex) - else: - self.loop.call_soon_threadsafe(fut.set_result, None) - - client = self.tcp_client(prog) - client.start() - clients.append(client) - - await fut - - async def start_server(): - extras = {} - extras = dict(ssl_handshake_timeout=40.0) - - srv = await asyncio.start_server( - handle_client, - '127.0.0.1', 0, - family=socket.AF_INET, - ssl=sslctx, - **extras) - - try: - srv_socks = srv.sockets - self.assertTrue(srv_socks) - - addr = srv_socks[0].getsockname() - - tasks = [] - for _ in range(TOTAL_CNT): - tasks.append(test_client(addr)) - - await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) - - finally: - self.loop.call_soon(srv.close) - await srv.wait_closed() - - with self._silence_eof_received_warning(): - self.loop.run_until_complete(start_server()) - - self.assertEqual(CNT, TOTAL_CNT) - - for client in clients: - client.stop() - - def test_create_connection_ssl_1(self): - self.loop.set_exception_handler(None) - - CNT = 0 - TOTAL_CNT = 25 - - A_DATA = b'A' * 1024 * 1024 - B_DATA = b'B' * 1024 * 1024 - - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, - test_utils.ONLYKEY - ) - client_sslctx = self._create_client_ssl_context() - - def server(sock): - sock.starttls( - sslctx, - server_side=True) - - data = sock.recv_all(len(A_DATA)) - self.assertEqual(data, A_DATA) - sock.send(b'OK') - - data = sock.recv_all(len(B_DATA)) - self.assertEqual(data, B_DATA) - sock.send(b'SPAM') - - sock.close() - - async def client(addr): - extras = {} - extras = dict(ssl_handshake_timeout=40.0) - - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='', - **extras) - - writer.write(A_DATA) - self.assertEqual(await reader.readexactly(2), b'OK') - - writer.write(B_DATA) - self.assertEqual(await reader.readexactly(4), b'SPAM') - - nonlocal CNT - CNT += 1 - - writer.close() - await self.wait_closed(writer) - - async def client_sock(addr): - sock = socket.socket() - sock.connect(addr) - reader, writer = await asyncio.open_connection( - sock=sock, - ssl=client_sslctx, - server_hostname='') - - writer.write(A_DATA) - self.assertEqual(await reader.readexactly(2), b'OK') - - writer.write(B_DATA) - self.assertEqual(await reader.readexactly(4), b'SPAM') - - nonlocal CNT - CNT += 1 - - writer.close() - await self.wait_closed(writer) - sock.close() - - def run(coro): - nonlocal CNT - CNT = 0 - - async def _gather(*tasks): - # trampoline - return await asyncio.gather(*tasks) - - with self.tcp_server(server, - max_clients=TOTAL_CNT, - backlog=TOTAL_CNT) as srv: - tasks = [] - for _ in range(TOTAL_CNT): - tasks.append(coro(srv.addr)) - - self.loop.run_until_complete(_gather(*tasks)) - - self.assertEqual(CNT, TOTAL_CNT) - - with self._silence_eof_received_warning(): - run(client) - - with self._silence_eof_received_warning(): - run(client_sock) - - def test_create_connection_ssl_slow_handshake(self): - client_sslctx = self._create_client_ssl_context() - - # silence error logger - self.loop.set_exception_handler(lambda *args: None) - - def server(sock): - try: - sock.recv_all(1024 * 1024) - except ConnectionAbortedError: - pass - finally: - sock.close() - - async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='', - ssl_handshake_timeout=1.0) - writer.close() - await self.wait_closed(writer) - - with self.tcp_server(server, - max_clients=1, - backlog=1) as srv: - - with self.assertRaisesRegex( - ConnectionAbortedError, - r'SSL handshake.*is taking longer'): - - self.loop.run_until_complete(client(srv.addr)) - - def test_create_connection_ssl_failed_certificate(self): - # silence error logger - self.loop.set_exception_handler(lambda *args: None) - - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, - test_utils.ONLYKEY - ) - client_sslctx = self._create_client_ssl_context(disable_verify=False) - - def server(sock): - try: - sock.starttls( - sslctx, - server_side=True) - sock.connect() - except (ssl.SSLError, OSError): - pass - finally: - sock.close() - - async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='', - ssl_handshake_timeout=1.0) - writer.close() - await self.wait_closed(writer) - - with self.tcp_server(server, - max_clients=1, - backlog=1) as srv: - - with self.assertRaises(ssl.SSLCertVerificationError): - self.loop.run_until_complete(client(srv.addr)) - - def test_ssl_handshake_timeout(self): - # bpo-29970: Check that a connection is aborted if handshake is not - # completed in timeout period, instead of remaining open indefinitely - client_sslctx = test_utils.simple_client_sslcontext() - - # silence error logger - messages = [] - self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) - - server_side_aborted = False - - def server(sock): - nonlocal server_side_aborted - try: - sock.recv_all(1024 * 1024) - except ConnectionAbortedError: - server_side_aborted = True - finally: - sock.close() - - async def client(addr): - await asyncio.wait_for( - self.loop.create_connection( - asyncio.Protocol, - *addr, - ssl=client_sslctx, - server_hostname='', - ssl_handshake_timeout=10.0), - 0.5) - - with self.tcp_server(server, - max_clients=1, - backlog=1) as srv: - - with self.assertRaises(asyncio.TimeoutError): - self.loop.run_until_complete(client(srv.addr)) - - self.assertTrue(server_side_aborted) - - # Python issue #23197: cancelling a handshake must not raise an - # exception or log an error, even if the handshake failed - self.assertEqual(messages, []) - - def test_ssl_handshake_connection_lost(self): - # #246: make sure that no connection_lost() is called before - # connection_made() is called first - - client_sslctx = test_utils.simple_client_sslcontext() - - # silence error logger - self.loop.set_exception_handler(lambda loop, ctx: None) - - connection_made_called = False - connection_lost_called = False - - def server(sock): - sock.recv(1024) - # break the connection during handshake - sock.close() - - class ClientProto(asyncio.Protocol): - def connection_made(self, transport): - nonlocal connection_made_called - connection_made_called = True - - def connection_lost(self, exc): - nonlocal connection_lost_called - connection_lost_called = True - - async def client(addr): - await self.loop.create_connection( - ClientProto, - *addr, - ssl=client_sslctx, - server_hostname=''), - - with self.tcp_server(server, - max_clients=1, - backlog=1) as srv: - - with self.assertRaises(ConnectionResetError): - self.loop.run_until_complete(client(srv.addr)) - - if connection_lost_called: - if connection_made_called: - self.fail("unexpected call to connection_lost()") - else: - self.fail("unexpected call to connection_lost() without" - "calling connection_made()") - elif connection_made_called: - self.fail("unexpected call to connection_made()") - - def test_ssl_connect_accepted_socket(self): - proto = ssl.PROTOCOL_TLS_SERVER - server_context = ssl.SSLContext(proto) - server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY) - if hasattr(server_context, 'check_hostname'): - server_context.check_hostname = False - server_context.verify_mode = ssl.CERT_NONE - - client_context = ssl.SSLContext(proto) - if hasattr(server_context, 'check_hostname'): - client_context.check_hostname = False - client_context.verify_mode = ssl.CERT_NONE - - def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None): - loop = self.loop - - class MyProto(MyBaseProto): - - def connection_lost(self, exc): - super().connection_lost(exc) - loop.call_soon(loop.stop) - - def data_received(self, data): - super().data_received(data) - self.transport.write(expected_response) - - lsock = socket.socket(socket.AF_INET) - lsock.bind(('127.0.0.1', 0)) - lsock.listen(1) - addr = lsock.getsockname() - - message = b'test data' - response = None - expected_response = b'roger' - - def client(): - nonlocal response - try: - csock = socket.socket(socket.AF_INET) - if client_ssl is not None: - csock = client_ssl.wrap_socket(csock) - csock.connect(addr) - csock.sendall(message) - response = csock.recv(99) - csock.close() - except Exception as exc: - print( - "Failure in client thread in test_connect_accepted_socket", - exc) - - thread = threading.Thread(target=client, daemon=True) - thread.start() - - conn, _ = lsock.accept() - proto = MyProto(loop=loop) - proto.loop = loop - - extras = {} - if server_ssl: - extras = dict(ssl_handshake_timeout=10.0) - - f = loop.create_task( - loop.connect_accepted_socket( - (lambda: proto), conn, ssl=server_ssl, - **extras)) - loop.run_forever() - conn.close() - lsock.close() - - thread.join(1) - self.assertFalse(thread.is_alive()) - self.assertEqual(proto.state, 'CLOSED') - self.assertEqual(proto.nbytes, len(message)) - self.assertEqual(response, expected_response) - tr, _ = f.result() - - if server_ssl: - self.assertIn('SSL', tr.__class__.__name__) - - tr.close() - # let it close - self.loop.run_until_complete(asyncio.sleep(0.1)) - - def test_start_tls_client_corrupted_ssl(self): - self.loop.set_exception_handler(lambda loop, ctx: None) - - sslctx = test_utils.simple_server_sslcontext() - client_sslctx = test_utils.simple_client_sslcontext() - - def server(sock): - orig_sock = sock.dup() - try: - sock.starttls( - sslctx, - server_side=True) - sock.sendall(b'A\n') - sock.recv_all(1) - orig_sock.send(b'please corrupt the SSL connection') - except ssl.SSLError: - pass - finally: - sock.close() - orig_sock.close() - - async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='') - - self.assertEqual(await reader.readline(), b'A\n') - writer.write(b'B') - with self.assertRaises(ssl.SSLError): - await reader.readline() - writer.close() - try: - await self.wait_closed(writer) - except ssl.SSLError: - pass - return 'OK' - - with self.tcp_server(server, - max_clients=1, - backlog=1) as srv: - - res = self.loop.run_until_complete(client(srv.addr)) - - self.assertEqual(res, 'OK') - - def test_start_tls_client_reg_proto_1(self): - HELLO_MSG = b'1' * self.PAYLOAD_SIZE - - server_context = test_utils.simple_server_sslcontext() - client_context = test_utils.simple_client_sslcontext() - - def serve(sock): - sock.settimeout(self.TIMEOUT) - - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - sock.starttls(server_context, server_side=True) - - sock.sendall(b'O') - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - sock.unwrap() - sock.close() - - class ClientProto(asyncio.Protocol): - def __init__(self, on_data, on_eof): - self.on_data = on_data - self.on_eof = on_eof - self.con_made_cnt = 0 - - def connection_made(proto, tr): - proto.con_made_cnt += 1 - # Ensure connection_made gets called only once. - self.assertEqual(proto.con_made_cnt, 1) - - def data_received(self, data): - self.on_data.set_result(data) - - def eof_received(self): - self.on_eof.set_result(True) - - async def client(addr): - await asyncio.sleep(0.5) - - on_data = self.loop.create_future() - on_eof = self.loop.create_future() - - tr, proto = await self.loop.create_connection( - lambda: ClientProto(on_data, on_eof), *addr) - - tr.write(HELLO_MSG) - new_tr = await self.loop.start_tls(tr, proto, client_context) - - self.assertEqual(await on_data, b'O') - new_tr.write(HELLO_MSG) - await on_eof - - new_tr.close() - - with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: - self.loop.run_until_complete( - asyncio.wait_for(client(srv.addr), timeout=10)) - - def test_create_connection_memory_leak(self): - HELLO_MSG = b'1' * self.PAYLOAD_SIZE - - server_context = self._create_server_ssl_context( - test_utils.ONLYCERT, test_utils.ONLYKEY) - client_context = self._create_client_ssl_context() - - def serve(sock): - sock.settimeout(self.TIMEOUT) - - sock.starttls(server_context, server_side=True) - - sock.sendall(b'O') - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - sock.unwrap() - sock.close() - - class ClientProto(asyncio.Protocol): - def __init__(self, on_data, on_eof): - self.on_data = on_data - self.on_eof = on_eof - self.con_made_cnt = 0 - - def connection_made(proto, tr): - # XXX: We assume user stores the transport in protocol - proto.tr = tr - proto.con_made_cnt += 1 - # Ensure connection_made gets called only once. - self.assertEqual(proto.con_made_cnt, 1) - - def data_received(self, data): - self.on_data.set_result(data) - - def eof_received(self): - self.on_eof.set_result(True) - - async def client(addr): - await asyncio.sleep(0.5) - - on_data = self.loop.create_future() - on_eof = self.loop.create_future() - - tr, proto = await self.loop.create_connection( - lambda: ClientProto(on_data, on_eof), *addr, - ssl=client_context) - - self.assertEqual(await on_data, b'O') - tr.write(HELLO_MSG) - await on_eof - - tr.close() - - with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: - self.loop.run_until_complete( - asyncio.wait_for(client(srv.addr), timeout=10)) - - # No garbage is left for SSL client from loop.create_connection, even - # if user stores the SSLTransport in corresponding protocol instance - client_context = weakref.ref(client_context) - self.assertIsNone(client_context()) - - def test_start_tls_client_buf_proto_1(self): - HELLO_MSG = b'1' * self.PAYLOAD_SIZE - - server_context = test_utils.simple_server_sslcontext() - client_context = test_utils.simple_client_sslcontext() - - client_con_made_calls = 0 - - def serve(sock): - sock.settimeout(self.TIMEOUT) - - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - sock.starttls(server_context, server_side=True) - - sock.sendall(b'O') - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - sock.sendall(b'2') - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - sock.unwrap() - sock.close() - - class ClientProtoFirst(asyncio.BufferedProtocol): - def __init__(self, on_data): - self.on_data = on_data - self.buf = bytearray(1) - - def connection_made(self, tr): - nonlocal client_con_made_calls - client_con_made_calls += 1 - - def get_buffer(self, sizehint): - return self.buf - - def buffer_updated(self, nsize): - assert nsize == 1 - self.on_data.set_result(bytes(self.buf[:nsize])) - - def eof_received(self): - pass - - class ClientProtoSecond(asyncio.Protocol): - def __init__(self, on_data, on_eof): - self.on_data = on_data - self.on_eof = on_eof - self.con_made_cnt = 0 - - def connection_made(self, tr): - nonlocal client_con_made_calls - client_con_made_calls += 1 - - def data_received(self, data): - self.on_data.set_result(data) - - def eof_received(self): - self.on_eof.set_result(True) - - async def client(addr): - await asyncio.sleep(0.5) - - on_data1 = self.loop.create_future() - on_data2 = self.loop.create_future() - on_eof = self.loop.create_future() - - tr, proto = await self.loop.create_connection( - lambda: ClientProtoFirst(on_data1), *addr) - - tr.write(HELLO_MSG) - new_tr = await self.loop.start_tls(tr, proto, client_context) - - self.assertEqual(await on_data1, b'O') - new_tr.write(HELLO_MSG) - - new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof)) - self.assertEqual(await on_data2, b'2') - new_tr.write(HELLO_MSG) - await on_eof - - new_tr.close() - - # connection_made() should be called only once -- when - # we establish connection for the first time. Start TLS - # doesn't call connection_made() on application protocols. - self.assertEqual(client_con_made_calls, 1) - - with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: - self.loop.run_until_complete( - asyncio.wait_for(client(srv.addr), - timeout=self.TIMEOUT)) - - def test_start_tls_slow_client_cancel(self): - HELLO_MSG = b'1' * self.PAYLOAD_SIZE - - client_context = test_utils.simple_client_sslcontext() - server_waits_on_handshake = self.loop.create_future() - - def serve(sock): - sock.settimeout(self.TIMEOUT) - - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - try: - self.loop.call_soon_threadsafe( - server_waits_on_handshake.set_result, None) - data = sock.recv_all(1024 * 1024) - except ConnectionAbortedError: - pass - finally: - sock.close() - - class ClientProto(asyncio.Protocol): - def __init__(self, on_data, on_eof): - self.on_data = on_data - self.on_eof = on_eof - self.con_made_cnt = 0 - - def connection_made(proto, tr): - proto.con_made_cnt += 1 - # Ensure connection_made gets called only once. - self.assertEqual(proto.con_made_cnt, 1) - - def data_received(self, data): - self.on_data.set_result(data) - - def eof_received(self): - self.on_eof.set_result(True) - - async def client(addr): - await asyncio.sleep(0.5) - - on_data = self.loop.create_future() - on_eof = self.loop.create_future() - - tr, proto = await self.loop.create_connection( - lambda: ClientProto(on_data, on_eof), *addr) - - tr.write(HELLO_MSG) - - await server_waits_on_handshake - - with self.assertRaises(asyncio.TimeoutError): - await asyncio.wait_for( - self.loop.start_tls(tr, proto, client_context), - 0.5) - - with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: - self.loop.run_until_complete( - asyncio.wait_for(client(srv.addr), timeout=10)) - - def test_start_tls_server_1(self): - HELLO_MSG = b'1' * self.PAYLOAD_SIZE - - server_context = test_utils.simple_server_sslcontext() - client_context = test_utils.simple_client_sslcontext() - - def client(sock, addr): - sock.settimeout(self.TIMEOUT) - - sock.connect(addr) - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - sock.starttls(client_context) - sock.sendall(HELLO_MSG) - - sock.unwrap() - sock.close() - - class ServerProto(asyncio.Protocol): - def __init__(self, on_con, on_eof, on_con_lost): - self.on_con = on_con - self.on_eof = on_eof - self.on_con_lost = on_con_lost - self.data = b'' - - def connection_made(self, tr): - self.on_con.set_result(tr) - - def data_received(self, data): - self.data += data - - def eof_received(self): - self.on_eof.set_result(1) - - def connection_lost(self, exc): - if exc is None: - self.on_con_lost.set_result(None) - else: - self.on_con_lost.set_exception(exc) - - async def main(proto, on_con, on_eof, on_con_lost): - tr = await on_con - tr.write(HELLO_MSG) - - self.assertEqual(proto.data, b'') - - new_tr = await self.loop.start_tls( - tr, proto, server_context, - server_side=True, - ssl_handshake_timeout=self.TIMEOUT) - - await on_eof - await on_con_lost - self.assertEqual(proto.data, HELLO_MSG) - new_tr.close() - - async def run_main(): - on_con = self.loop.create_future() - on_eof = self.loop.create_future() - on_con_lost = self.loop.create_future() - proto = ServerProto(on_con, on_eof, on_con_lost) - - server = await self.loop.create_server( - lambda: proto, '127.0.0.1', 0) - addr = server.sockets[0].getsockname() - - with self.tcp_client(lambda sock: client(sock, addr), - timeout=self.TIMEOUT): - await asyncio.wait_for( - main(proto, on_con, on_eof, on_con_lost), - timeout=self.TIMEOUT) - - server.close() - await server.wait_closed() - - self.loop.run_until_complete(run_main()) - - def test_create_server_ssl_over_ssl(self): - CNT = 0 # number of clients that were successful - TOTAL_CNT = 25 # total number of clients that test will create - TIMEOUT = 10.0 # timeout for this test - - A_DATA = b'A' * 1024 * 1024 - B_DATA = b'B' * 1024 * 1024 - - sslctx_1 = self._create_server_ssl_context( - test_utils.ONLYCERT, test_utils.ONLYKEY) - client_sslctx_1 = self._create_client_ssl_context() - sslctx_2 = self._create_server_ssl_context( - test_utils.ONLYCERT, test_utils.ONLYKEY) - client_sslctx_2 = self._create_client_ssl_context() - - clients = [] - - async def handle_client(reader, writer): - nonlocal CNT - - data = await reader.readexactly(len(A_DATA)) - self.assertEqual(data, A_DATA) - writer.write(b'OK') - - data = await reader.readexactly(len(B_DATA)) - self.assertEqual(data, B_DATA) - writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) - - await writer.drain() - writer.close() - - CNT += 1 - - class ServerProtocol(asyncio.StreamReaderProtocol): - def connection_made(self, transport): - super_ = super() - transport.pause_reading() - fut = self._loop.create_task(self._loop.start_tls( - transport, self, sslctx_2, server_side=True)) - - def cb(_): - try: - tr = fut.result() - except Exception as ex: - super_.connection_lost(ex) - else: - super_.connection_made(tr) - fut.add_done_callback(cb) - - def server_protocol_factory(): - reader = asyncio.StreamReader() - protocol = ServerProtocol(reader, handle_client) - return protocol - - async def test_client(addr): - fut = asyncio.Future() - - def prog(sock): - try: - sock.connect(addr) - sock.starttls(client_sslctx_1) - - # because wrap_socket() doesn't work correctly on - # SSLSocket, we have to do the 2nd level SSL manually - incoming = ssl.MemoryBIO() - outgoing = ssl.MemoryBIO() - sslobj = client_sslctx_2.wrap_bio(incoming, outgoing) - - def do(func, *args): - while True: - try: - rv = func(*args) - break - except ssl.SSLWantReadError: - if outgoing.pending: - sock.send(outgoing.read()) - incoming.write(sock.recv(65536)) - if outgoing.pending: - sock.send(outgoing.read()) - return rv - - do(sslobj.do_handshake) - - do(sslobj.write, A_DATA) - data = do(sslobj.read, 2) - self.assertEqual(data, b'OK') - - do(sslobj.write, B_DATA) - data = b'' - while True: - chunk = do(sslobj.read, 4) - if not chunk: - break - data += chunk - self.assertEqual(data, b'SPAM') - - do(sslobj.unwrap) - sock.close() - - except Exception as ex: - self.loop.call_soon_threadsafe(fut.set_exception, ex) - sock.close() - else: - self.loop.call_soon_threadsafe(fut.set_result, None) - - client = self.tcp_client(prog) - client.start() - clients.append(client) - - await fut - - async def start_server(): - extras = {} - - srv = await self.loop.create_server( - server_protocol_factory, - '127.0.0.1', 0, - family=socket.AF_INET, - ssl=sslctx_1, - **extras) - - try: - srv_socks = srv.sockets - self.assertTrue(srv_socks) - - addr = srv_socks[0].getsockname() - - tasks = [] - for _ in range(TOTAL_CNT): - tasks.append(test_client(addr)) - - await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) - - finally: - self.loop.call_soon(srv.close) - await srv.wait_closed() - - with self._silence_eof_received_warning(): - self.loop.run_until_complete(start_server()) - - self.assertEqual(CNT, TOTAL_CNT) - - for client in clients: - client.stop() - - def test_shutdown_cleanly(self): - CNT = 0 - TOTAL_CNT = 25 - - A_DATA = b'A' * 1024 * 1024 - - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, test_utils.ONLYKEY) - client_sslctx = self._create_client_ssl_context() - - def server(sock): - sock.starttls( - sslctx, - server_side=True) - - data = sock.recv_all(len(A_DATA)) - self.assertEqual(data, A_DATA) - sock.send(b'OK') - - sock.unwrap() - - sock.close() - - async def client(addr): - extras = {} - extras = dict(ssl_handshake_timeout=10.0) - - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='', - **extras) - - writer.write(A_DATA) - self.assertEqual(await reader.readexactly(2), b'OK') - - self.assertEqual(await reader.read(), b'') - - nonlocal CNT - CNT += 1 - - writer.close() - await self.wait_closed(writer) - - def run(coro): - nonlocal CNT - CNT = 0 - - async def _gather(*tasks): - return await asyncio.gather(*tasks) - - with self.tcp_server(server, - max_clients=TOTAL_CNT, - backlog=TOTAL_CNT) as srv: - tasks = [] - for _ in range(TOTAL_CNT): - tasks.append(coro(srv.addr)) - - self.loop.run_until_complete( - _gather(*tasks)) - - self.assertEqual(CNT, TOTAL_CNT) - - with self._silence_eof_received_warning(): - run(client) - - def test_flush_before_shutdown(self): - CHUNK = 1024 * 128 - SIZE = 32 - - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, test_utils.ONLYKEY) - client_sslctx = self._create_client_ssl_context() - if hasattr(ssl, 'OP_NO_TLSv1_3'): - client_sslctx.options |= ssl.OP_NO_TLSv1_3 - - future = None - - def server(sock): - sock.starttls(sslctx, server_side=True) - self.assertEqual(sock.recv_all(4), b'ping') - sock.send(b'pong') - time.sleep(0.5) # hopefully stuck the TCP buffer - data = sock.recv_all(CHUNK * SIZE) - self.assertEqual(len(data), CHUNK * SIZE) - sock.close() - - def run(meth): - def wrapper(sock): - try: - meth(sock) - except Exception as ex: - self.loop.call_soon_threadsafe(future.set_exception, ex) - else: - self.loop.call_soon_threadsafe(future.set_result, None) - return wrapper - - async def client(addr): - nonlocal future - future = self.loop.create_future() - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='') - sslprotocol = writer.transport._ssl_protocol - writer.write(b'ping') - data = await reader.readexactly(4) - self.assertEqual(data, b'pong') - - sslprotocol.pause_writing() - for _ in range(SIZE): - writer.write(b'x' * CHUNK) - - writer.close() - sslprotocol.resume_writing() - - await self.wait_closed(writer) - try: - data = await reader.read() - self.assertEqual(data, b'') - except ConnectionResetError: - pass - await future - - with self.tcp_server(run(server)) as srv: - self.loop.run_until_complete(client(srv.addr)) - - def test_remote_shutdown_receives_trailing_data(self): - CHUNK = 1024 * 128 - SIZE = 32 - - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, - test_utils.ONLYKEY - ) - client_sslctx = self._create_client_ssl_context() - future = None - - def server(sock): - incoming = ssl.MemoryBIO() - outgoing = ssl.MemoryBIO() - sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True) - - while True: - try: - sslobj.do_handshake() - except ssl.SSLWantReadError: - if outgoing.pending: - sock.send(outgoing.read()) - incoming.write(sock.recv(16384)) - else: - if outgoing.pending: - sock.send(outgoing.read()) - break - - while True: - try: - data = sslobj.read(4) - except ssl.SSLWantReadError: - incoming.write(sock.recv(16384)) - else: - break - - self.assertEqual(data, b'ping') - sslobj.write(b'pong') - sock.send(outgoing.read()) - - time.sleep(0.2) # wait for the peer to fill its backlog - - # send close_notify but don't wait for response - with self.assertRaises(ssl.SSLWantReadError): - sslobj.unwrap() - sock.send(outgoing.read()) - - # should receive all data - data_len = 0 - while True: - try: - chunk = len(sslobj.read(16384)) - data_len += chunk - except ssl.SSLWantReadError: - incoming.write(sock.recv(16384)) - except ssl.SSLZeroReturnError: - break - - self.assertEqual(data_len, CHUNK * SIZE) - - # verify that close_notify is received - sslobj.unwrap() - - sock.close() - - def eof_server(sock): - sock.starttls(sslctx, server_side=True) - self.assertEqual(sock.recv_all(4), b'ping') - sock.send(b'pong') - - time.sleep(0.2) # wait for the peer to fill its backlog - - # send EOF - sock.shutdown(socket.SHUT_WR) - - # should receive all data - data = sock.recv_all(CHUNK * SIZE) - self.assertEqual(len(data), CHUNK * SIZE) - - sock.close() - - async def client(addr): - nonlocal future - future = self.loop.create_future() - - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='') - writer.write(b'ping') - data = await reader.readexactly(4) - self.assertEqual(data, b'pong') - - # fill write backlog in a hacky way - renegotiation won't help - for _ in range(SIZE): - writer.transport._test__append_write_backlog(b'x' * CHUNK) - - try: - data = await reader.read() - self.assertEqual(data, b'') - except (BrokenPipeError, ConnectionResetError): - pass - - await future - - writer.close() - await self.wait_closed(writer) - - def run(meth): - def wrapper(sock): - try: - meth(sock) - except Exception as ex: - self.loop.call_soon_threadsafe(future.set_exception, ex) - else: - self.loop.call_soon_threadsafe(future.set_result, None) - return wrapper - - with self.tcp_server(run(server)) as srv: - self.loop.run_until_complete(client(srv.addr)) - - with self.tcp_server(run(eof_server)) as srv: - self.loop.run_until_complete(client(srv.addr)) - - def test_connect_timeout_warning(self): - s = socket.socket(socket.AF_INET) - s.bind(('127.0.0.1', 0)) - addr = s.getsockname() - - async def test(): - try: - await asyncio.wait_for( - self.loop.create_connection(asyncio.Protocol, - *addr, ssl=True), - 0.1) - except (ConnectionRefusedError, asyncio.TimeoutError): - pass - else: - self.fail('TimeoutError is not raised') - - with s: - try: - with self.assertWarns(ResourceWarning) as cm: - self.loop.run_until_complete(test()) - gc.collect() - gc.collect() - gc.collect() - except AssertionError as e: - self.assertEqual(str(e), 'ResourceWarning not triggered') - else: - self.fail('Unexpected ResourceWarning: {}'.format(cm.warning)) - - def test_handshake_timeout_handler_leak(self): - s = socket.socket(socket.AF_INET) - s.bind(('127.0.0.1', 0)) - s.listen(1) - addr = s.getsockname() - - async def test(ctx): - try: - await asyncio.wait_for( - self.loop.create_connection(asyncio.Protocol, *addr, - ssl=ctx), - 0.1) - except (ConnectionRefusedError, asyncio.TimeoutError): - pass - else: - self.fail('TimeoutError is not raised') - - with s: - ctx = ssl.create_default_context() - self.loop.run_until_complete(test(ctx)) - ctx = weakref.ref(ctx) - - # SSLProtocol should be DECREF to 0 - self.assertIsNone(ctx()) - - def test_shutdown_timeout_handler_leak(self): - loop = self.loop - - def server(sock): - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, - test_utils.ONLYKEY - ) - sock = sslctx.wrap_socket(sock, server_side=True) - sock.recv(32) - sock.close() - - class Protocol(asyncio.Protocol): - def __init__(self): - self.fut = asyncio.Future(loop=loop) - - def connection_lost(self, exc): - self.fut.set_result(None) - - async def client(addr, ctx): - tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) - tr.close() - await pr.fut - - with self.tcp_server(server) as srv: - ctx = self._create_client_ssl_context() - loop.run_until_complete(client(srv.addr, ctx)) - ctx = weakref.ref(ctx) - - # asyncio has no shutdown timeout, but it ends up with a circular - # reference loop - not ideal (introduces gc glitches), but at least - # not leaking - gc.collect() - gc.collect() - gc.collect() - - # SSLProtocol should be DECREF to 0 - self.assertIsNone(ctx()) - - def test_shutdown_timeout_handler_not_set(self): - loop = self.loop - eof = asyncio.Event() - extra = None - - def server(sock): - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, - test_utils.ONLYKEY - ) - sock = sslctx.wrap_socket(sock, server_side=True) - sock.send(b'hello') - assert sock.recv(1024) == b'world' - sock.send(b'extra bytes') - # sending EOF here - sock.shutdown(socket.SHUT_WR) - loop.call_soon_threadsafe(eof.set) - # make sure we have enough time to reproduce the issue - assert sock.recv(1024) == b'' - sock.close() - - class Protocol(asyncio.Protocol): - def __init__(self): - self.fut = asyncio.Future(loop=loop) - self.transport = None - - def connection_made(self, transport): - self.transport = transport - - def data_received(self, data): - if data == b'hello': - self.transport.write(b'world') - # pause reading would make incoming data stay in the sslobj - self.transport.pause_reading() - else: - nonlocal extra - extra = data - - def connection_lost(self, exc): - if exc is None: - self.fut.set_result(None) - else: - self.fut.set_exception(exc) - - async def client(addr): - ctx = self._create_client_ssl_context() - tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) - await eof.wait() - tr.resume_reading() - await pr.fut - tr.close() - assert extra == b'extra bytes' - - with self.tcp_server(server) as srv: - loop.run_until_complete(client(srv.addr)) - - -############################################################################### -# Socket Testing Utilities -############################################################################### - - -class TestSocketWrapper: - - def __init__(self, sock): - self.__sock = sock - - def recv_all(self, n): - buf = b'' - while len(buf) < n: - data = self.recv(n - len(buf)) - if data == b'': - raise ConnectionAbortedError - buf += data - return buf - - def starttls(self, ssl_context, *, - server_side=False, - server_hostname=None, - do_handshake_on_connect=True): - - assert isinstance(ssl_context, ssl.SSLContext) - - ssl_sock = ssl_context.wrap_socket( - self.__sock, server_side=server_side, - server_hostname=server_hostname, - do_handshake_on_connect=do_handshake_on_connect) - - if server_side: - ssl_sock.do_handshake() - - self.__sock.close() - self.__sock = ssl_sock - - def __getattr__(self, name): - return getattr(self.__sock, name) - - def __repr__(self): - return '<{} {!r}>'.format(type(self).__name__, self.__sock) - - -class SocketThread(threading.Thread): - - def stop(self): - self._active = False - self.join() - - def __enter__(self): - self.start() - return self - - def __exit__(self, *exc): - self.stop() - - -class TestThreadedClient(SocketThread): - - def __init__(self, test, sock, prog, timeout): - threading.Thread.__init__(self, None, None, 'test-client') - self.daemon = True - - self._timeout = timeout - self._sock = sock - self._active = True - self._prog = prog - self._test = test - - def run(self): - try: - self._prog(TestSocketWrapper(self._sock)) - except (KeyboardInterrupt, SystemExit): - raise - except BaseException as ex: - self._test._abort_socket_test(ex) - - -class TestThreadedServer(SocketThread): - - def __init__(self, test, sock, prog, timeout, max_clients): - threading.Thread.__init__(self, None, None, 'test-server') - self.daemon = True - - self._clients = 0 - self._finished_clients = 0 - self._max_clients = max_clients - self._timeout = timeout - self._sock = sock - self._active = True - - self._prog = prog - - self._s1, self._s2 = socket.socketpair() - self._s1.setblocking(False) - - self._test = test - - def stop(self): - try: - if self._s2 and self._s2.fileno() != -1: - try: - self._s2.send(b'stop') - except OSError: - pass - finally: - super().stop() - - def run(self): - try: - with self._sock: - self._sock.setblocking(0) - self._run() - finally: - self._s1.close() - self._s2.close() - - def _run(self): - while self._active: - if self._clients >= self._max_clients: - return - - r, w, x = select.select( - [self._sock, self._s1], [], [], self._timeout) - - if self._s1 in r: - return - - if self._sock in r: - try: - conn, addr = self._sock.accept() - except BlockingIOError: - continue - except socket.timeout: - if not self._active: - return - else: - raise - else: - self._clients += 1 - conn.settimeout(self._timeout) - try: - with conn: - self._handle_client(conn) - except (KeyboardInterrupt, SystemExit): - raise - except BaseException as ex: - self._active = False - try: - raise - finally: - self._test._abort_socket_test(ex) - - def _handle_client(self, sock): - self._prog(TestSocketWrapper(sock)) - - @property - def addr(self): - return self._sock.getsockname() diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index 79a81bd..e87863e 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -15,6 +15,7 @@ import asyncio from asyncio import log from asyncio import protocols from asyncio import sslproto +from test import support from test.test_asyncio import utils as test_utils from test.test_asyncio import functional as func_tests @@ -43,13 +44,16 @@ class SslProtoHandshakeTests(test_utils.TestCase): def connection_made(self, ssl_proto, *, do_handshake=None): transport = mock.Mock() - sslobj = mock.Mock() - # emulate reading decompressed data - sslobj.read.side_effect = ssl.SSLWantReadError - if do_handshake is not None: - sslobj.do_handshake = do_handshake - ssl_proto._sslobj = sslobj - ssl_proto.connection_made(transport) + sslpipe = mock.Mock() + sslpipe.shutdown.return_value = b'' + if do_handshake: + sslpipe.do_handshake.side_effect = do_handshake + else: + def mock_handshake(callback): + return [] + sslpipe.do_handshake.side_effect = mock_handshake + with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe): + ssl_proto.connection_made(transport) return transport def test_handshake_timeout_zero(self): @@ -71,10 +75,7 @@ class SslProtoHandshakeTests(test_utils.TestCase): def test_eof_received_waiter(self): waiter = self.loop.create_future() ssl_proto = self.ssl_protocol(waiter=waiter) - self.connection_made( - ssl_proto, - do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError) - ) + self.connection_made(ssl_proto) ssl_proto.eof_received() test_utils.run_briefly(self.loop) self.assertIsInstance(waiter.exception(), ConnectionResetError) @@ -99,10 +100,7 @@ class SslProtoHandshakeTests(test_utils.TestCase): # yield from waiter hang if lost_connection was called. waiter = self.loop.create_future() ssl_proto = self.ssl_protocol(waiter=waiter) - self.connection_made( - ssl_proto, - do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError) - ) + self.connection_made(ssl_proto) ssl_proto.connection_lost(ConnectionAbortedError) test_utils.run_briefly(self.loop) self.assertIsInstance(waiter.exception(), ConnectionAbortedError) @@ -112,10 +110,7 @@ class SslProtoHandshakeTests(test_utils.TestCase): waiter = self.loop.create_future() ssl_proto = self.ssl_protocol(waiter=waiter) - transport = self.connection_made( - ssl_proto, - do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError) - ) + transport = self.connection_made(ssl_proto) test_utils.run_briefly(self.loop) ssl_proto._app_transport.close() @@ -148,7 +143,7 @@ class SslProtoHandshakeTests(test_utils.TestCase): transp.close() # should not raise - self.assertIsNone(ssl_proto.buffer_updated(5)) + self.assertIsNone(ssl_proto.data_received(b'data')) def test_write_after_closing(self): ssl_proto = self.ssl_protocol() |