import asyncio import asyncio.sslproto import contextlib import gc import logging import select import socket import tempfile import threading import time import weakref import unittest try: import ssl except ImportError: ssl = None from test import support from test.test_asyncio import utils as test_utils def tearDownModule(): asyncio.set_event_loop_policy(None) class MyBaseProto(asyncio.Protocol): connected = None done = None def __init__(self, loop=None): self.transport = None self.state = 'INITIAL' self.nbytes = 0 if loop is not None: self.connected = asyncio.Future(loop=loop) self.done = asyncio.Future(loop=loop) def connection_made(self, transport): self.transport = transport assert self.state == 'INITIAL', self.state self.state = 'CONNECTED' if self.connected: self.connected.set_result(None) def data_received(self, data): assert self.state == 'CONNECTED', self.state self.nbytes += len(data) def eof_received(self): assert self.state == 'CONNECTED', self.state self.state = 'EOF' def connection_lost(self, exc): assert self.state in ('CONNECTED', 'EOF'), self.state self.state = 'CLOSED' if self.done: self.done.set_result(None) @unittest.skipIf(ssl is None, 'No ssl module') class TestSSL(test_utils.TestCase): PAYLOAD_SIZE = 1024 * 100 TIMEOUT = 60 def setUp(self): super().setUp() self.loop = asyncio.new_event_loop() self.set_event_loop(self.loop) self.addCleanup(self.loop.close) def tearDown(self): # just in case if we have transport close callbacks if not self.loop.is_closed(): test_utils.run_briefly(self.loop) self.doCleanups() support.gc_collect() super().tearDown() def tcp_server(self, server_prog, *, family=socket.AF_INET, addr=None, timeout=5, backlog=1, max_clients=10): if addr is None: if family == getattr(socket, "AF_UNIX", None): with tempfile.NamedTemporaryFile() as tmp: addr = tmp.name else: addr = ('127.0.0.1', 0) sock = socket.socket(family, socket.SOCK_STREAM) if timeout is None: raise RuntimeError('timeout is required') if timeout <= 0: raise RuntimeError('only blocking sockets are supported') sock.settimeout(timeout) try: sock.bind(addr) sock.listen(backlog) except OSError as ex: sock.close() raise ex return TestThreadedServer( self, sock, server_prog, timeout, max_clients) def tcp_client(self, client_prog, family=socket.AF_INET, timeout=10): sock = socket.socket(family, socket.SOCK_STREAM) if timeout is None: raise RuntimeError('timeout is required') if timeout <= 0: raise RuntimeError('only blocking sockets are supported') sock.settimeout(timeout) return TestThreadedClient( self, sock, client_prog, timeout) def unix_server(self, *args, **kwargs): return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs) def unix_client(self, *args, **kwargs): return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs) def _create_server_ssl_context(self, certfile, keyfile=None): sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslcontext.options |= ssl.OP_NO_SSLv2 sslcontext.load_cert_chain(certfile, keyfile) return sslcontext def _create_client_ssl_context(self, *, disable_verify=True): sslcontext = ssl.create_default_context() sslcontext.check_hostname = False if disable_verify: sslcontext.verify_mode = ssl.CERT_NONE return sslcontext @contextlib.contextmanager def _silence_eof_received_warning(self): # TODO This warning has to be fixed in asyncio. logger = logging.getLogger('asyncio') filter = logging.Filter('has no effect when using ssl') logger.addFilter(filter) try: yield finally: logger.removeFilter(filter) def _abort_socket_test(self, ex): try: self.loop.stop() finally: self.fail(ex) def new_loop(self): return asyncio.new_event_loop() def new_policy(self): return asyncio.DefaultEventLoopPolicy() async def wait_closed(self, obj): if not isinstance(obj, asyncio.StreamWriter): return try: await obj.wait_closed() except (BrokenPipeError, ConnectionError): pass def test_create_server_ssl_1(self): CNT = 0 # number of clients that were successful TOTAL_CNT = 25 # total number of clients that test will create TIMEOUT = 60.0 # timeout for this test A_DATA = b'A' * 1024 * 1024 B_DATA = b'B' * 1024 * 1024 sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY ) client_sslctx = self._create_client_ssl_context() clients = [] async def handle_client(reader, writer): nonlocal CNT data = await reader.readexactly(len(A_DATA)) self.assertEqual(data, A_DATA) writer.write(b'OK') data = await reader.readexactly(len(B_DATA)) self.assertEqual(data, B_DATA) writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) await writer.drain() writer.close() CNT += 1 async def test_client(addr): fut = asyncio.Future() def prog(sock): try: sock.starttls(client_sslctx) sock.connect(addr) sock.send(A_DATA) data = sock.recv_all(2) self.assertEqual(data, b'OK') sock.send(B_DATA) data = sock.recv_all(4) self.assertEqual(data, b'SPAM') sock.close() except Exception as ex: self.loop.call_soon_threadsafe(fut.set_exception, ex) else: self.loop.call_soon_threadsafe(fut.set_result, None) client = self.tcp_client(prog) client.start() clients.append(client) await fut async def start_server(): extras = {} extras = dict(ssl_handshake_timeout=40.0) srv = await asyncio.start_server( handle_client, '127.0.0.1', 0, family=socket.AF_INET, ssl=sslctx, **extras) try: srv_socks = srv.sockets self.assertTrue(srv_socks) addr = srv_socks[0].getsockname() tasks = [] for _ in range(TOTAL_CNT): tasks.append(test_client(addr)) await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) finally: self.loop.call_soon(srv.close) await srv.wait_closed() with self._silence_eof_received_warning(): self.loop.run_until_complete(start_server()) self.assertEqual(CNT, TOTAL_CNT) for client in clients: client.stop() def test_create_connection_ssl_1(self): self.loop.set_exception_handler(None) CNT = 0 TOTAL_CNT = 25 A_DATA = b'A' * 1024 * 1024 B_DATA = b'B' * 1024 * 1024 sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY ) client_sslctx = self._create_client_ssl_context() def server(sock): sock.starttls( sslctx, server_side=True) data = sock.recv_all(len(A_DATA)) self.assertEqual(data, A_DATA) sock.send(b'OK') data = sock.recv_all(len(B_DATA)) self.assertEqual(data, B_DATA) sock.send(b'SPAM') sock.close() async def client(addr): extras = {} extras = dict(ssl_handshake_timeout=40.0) reader, writer = await asyncio.open_connection( *addr, ssl=client_sslctx, server_hostname='', **extras) writer.write(A_DATA) self.assertEqual(await reader.readexactly(2), b'OK') writer.write(B_DATA) self.assertEqual(await reader.readexactly(4), b'SPAM') nonlocal CNT CNT += 1 writer.close() await self.wait_closed(writer) async def client_sock(addr): sock = socket.socket() sock.connect(addr) reader, writer = await asyncio.open_connection( sock=sock, ssl=client_sslctx, server_hostname='') writer.write(A_DATA) self.assertEqual(await reader.readexactly(2), b'OK') writer.write(B_DATA) self.assertEqual(await reader.readexactly(4), b'SPAM') nonlocal CNT CNT += 1 writer.close() await self.wait_closed(writer) sock.close() def run(coro): nonlocal CNT CNT = 0 async def _gather(*tasks): # trampoline return await asyncio.gather(*tasks) with self.tcp_server(server, max_clients=TOTAL_CNT, backlog=TOTAL_CNT) as srv: tasks = [] for _ in range(TOTAL_CNT): tasks.append(coro(srv.addr)) self.loop.run_until_complete(_gather(*tasks)) self.assertEqual(CNT, TOTAL_CNT) with self._silence_eof_received_warning(): run(client) with self._silence_eof_received_warning(): run(client_sock) def test_create_connection_ssl_slow_handshake(self): client_sslctx = self._create_client_ssl_context() # silence error logger self.loop.set_exception_handler(lambda *args: None) def server(sock): try: sock.recv_all(1024 * 1024) except ConnectionAbortedError: pass finally: sock.close() async def client(addr): reader, writer = await asyncio.open_connection( *addr, ssl=client_sslctx, server_hostname='', ssl_handshake_timeout=1.0) writer.close() await self.wait_closed(writer) with self.tcp_server(server, max_clients=1, backlog=1) as srv: with self.assertRaisesRegex( ConnectionAbortedError, r'SSL handshake.*is taking longer'): self.loop.run_until_complete(client(srv.addr)) def test_create_connection_ssl_failed_certificate(self): # silence error logger self.loop.set_exception_handler(lambda *args: None) sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY ) client_sslctx = self._create_client_ssl_context(disable_verify=False) def server(sock): try: sock.starttls( sslctx, server_side=True) sock.connect() except (ssl.SSLError, OSError): pass finally: sock.close() async def client(addr): reader, writer = await asyncio.open_connection( *addr, ssl=client_sslctx, server_hostname='', ssl_handshake_timeout=1.0) writer.close() await self.wait_closed(writer) with self.tcp_server(server, max_clients=1, backlog=1) as srv: with self.assertRaises(ssl.SSLCertVerificationError): self.loop.run_until_complete(client(srv.addr)) def test_ssl_handshake_timeout(self): # bpo-29970: Check that a connection is aborted if handshake is not # completed in timeout period, instead of remaining open indefinitely client_sslctx = test_utils.simple_client_sslcontext() # silence error logger messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) server_side_aborted = False def server(sock): nonlocal server_side_aborted try: sock.recv_all(1024 * 1024) except ConnectionAbortedError: server_side_aborted = True finally: sock.close() async def client(addr): await asyncio.wait_for( self.loop.create_connection( asyncio.Protocol, *addr, ssl=client_sslctx, server_hostname='', ssl_handshake_timeout=10.0), 0.5) with self.tcp_server(server, max_clients=1, backlog=1) as srv: with self.assertRaises(asyncio.TimeoutError): self.loop.run_until_complete(client(srv.addr)) self.assertTrue(server_side_aborted) # Python issue #23197: cancelling a handshake must not raise an # exception or log an error, even if the handshake failed self.assertEqual(messages, []) def test_ssl_handshake_connection_lost(self): # #246: make sure that no connection_lost() is called before # connection_made() is called first client_sslctx = test_utils.simple_client_sslcontext() # silence error logger self.loop.set_exception_handler(lambda loop, ctx: None) connection_made_called = False connection_lost_called = False def server(sock): sock.recv(1024) # break the connection during handshake sock.close() class ClientProto(asyncio.Protocol): def connection_made(self, transport): nonlocal connection_made_called connection_made_called = True def connection_lost(self, exc): nonlocal connection_lost_called connection_lost_called = True async def client(addr): await self.loop.create_connection( ClientProto, *addr, ssl=client_sslctx, server_hostname=''), with self.tcp_server(server, max_clients=1, backlog=1) as srv: with self.assertRaises(ConnectionResetError): self.loop.run_until_complete(client(srv.addr)) if connection_lost_called: if connection_made_called: self.fail("unexpected call to connection_lost()") else: self.fail("unexpected call to connection_lost() without" "calling connection_made()") elif connection_made_called: self.fail("unexpected call to connection_made()") def test_ssl_connect_accepted_socket(self): proto = ssl.PROTOCOL_TLS_SERVER server_context = ssl.SSLContext(proto) server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY) if hasattr(server_context, 'check_hostname'): server_context.check_hostname = False server_context.verify_mode = ssl.CERT_NONE client_context = ssl.SSLContext(proto) if hasattr(server_context, 'check_hostname'): client_context.check_hostname = False client_context.verify_mode = ssl.CERT_NONE def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None): loop = self.loop class MyProto(MyBaseProto): def connection_lost(self, exc): super().connection_lost(exc) loop.call_soon(loop.stop) def data_received(self, data): super().data_received(data) self.transport.write(expected_response) lsock = socket.socket(socket.AF_INET) lsock.bind(('127.0.0.1', 0)) lsock.listen(1) addr = lsock.getsockname() message = b'test data' response = None expected_response = b'roger' def client(): nonlocal response try: csock = socket.socket(socket.AF_INET) if client_ssl is not None: csock = client_ssl.wrap_socket(csock) csock.connect(addr) csock.sendall(message) response = csock.recv(99) csock.close() except Exception as exc: print( "Failure in client thread in test_connect_accepted_socket", exc) thread = threading.Thread(target=client, daemon=True) thread.start() conn, _ = lsock.accept() proto = MyProto(loop=loop) proto.loop = loop extras = {} if server_ssl: extras = dict(ssl_handshake_timeout=10.0) f = loop.create_task( loop.connect_accepted_socket( (lambda: proto), conn, ssl=server_ssl, **extras)) loop.run_forever() conn.close() lsock.close() thread.join(1) self.assertFalse(thread.is_alive()) self.assertEqual(proto.state, 'CLOSED') self.assertEqual(proto.nbytes, len(message)) self.assertEqual(response, expected_response) tr, _ = f.result() if server_ssl: self.assertIn('SSL', tr.__class__.__name__) tr.close() # let it close self.loop.run_until_complete(asyncio.sleep(0.1)) def test_start_tls_client_corrupted_ssl(self): self.loop.set_exception_handler(lambda loop, ctx: None) sslctx = test_utils.simple_server_sslcontext() client_sslctx = test_utils.simple_client_sslcontext() def server(sock): orig_sock = sock.dup() try: sock.starttls( sslctx, server_side=True) sock.sendall(b'A\n') sock.recv_all(1) orig_sock.send(b'please corrupt the SSL connection') except ssl.SSLError: pass finally: sock.close() orig_sock.close() async def client(addr): reader, writer = await asyncio.open_connection( *addr, ssl=client_sslctx, server_hostname='') self.assertEqual(await reader.readline(), b'A\n') writer.write(b'B') with self.assertRaises(ssl.SSLError): await reader.readline() writer.close() try: await self.wait_closed(writer) except ssl.SSLError: pass return 'OK' with self.tcp_server(server, max_clients=1, backlog=1) as srv: res = self.loop.run_until_complete(client(srv.addr)) self.assertEqual(res, 'OK') def test_start_tls_client_reg_proto_1(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE server_context = test_utils.simple_server_sslcontext() client_context = test_utils.simple_client_sslcontext() def serve(sock): sock.settimeout(self.TIMEOUT) data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) sock.starttls(server_context, server_side=True) sock.sendall(b'O') data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) sock.unwrap() sock.close() class ClientProto(asyncio.Protocol): def __init__(self, on_data, on_eof): self.on_data = on_data self.on_eof = on_eof self.con_made_cnt = 0 def connection_made(proto, tr): proto.con_made_cnt += 1 # Ensure connection_made gets called only once. self.assertEqual(proto.con_made_cnt, 1) def data_received(self, data): self.on_data.set_result(data) def eof_received(self): self.on_eof.set_result(True) async def client(addr): await asyncio.sleep(0.5) on_data = self.loop.create_future() on_eof = self.loop.create_future() tr, proto = await self.loop.create_connection( lambda: ClientProto(on_data, on_eof), *addr) tr.write(HELLO_MSG) new_tr = await self.loop.start_tls(tr, proto, client_context) self.assertEqual(await on_data, b'O') new_tr.write(HELLO_MSG) await on_eof new_tr.close() with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: self.loop.run_until_complete( asyncio.wait_for(client(srv.addr), timeout=10)) def test_create_connection_memory_leak(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE server_context = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY) client_context = self._create_client_ssl_context() def serve(sock): sock.settimeout(self.TIMEOUT) sock.starttls(server_context, server_side=True) sock.sendall(b'O') data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) sock.unwrap() sock.close() class ClientProto(asyncio.Protocol): def __init__(self, on_data, on_eof): self.on_data = on_data self.on_eof = on_eof self.con_made_cnt = 0 def connection_made(proto, tr): # XXX: We assume user stores the transport in protocol proto.tr = tr proto.con_made_cnt += 1 # Ensure connection_made gets called only once. self.assertEqual(proto.con_made_cnt, 1) def data_received(self, data): self.on_data.set_result(data) def eof_received(self): self.on_eof.set_result(True) async def client(addr): await asyncio.sleep(0.5) on_data = self.loop.create_future() on_eof = self.loop.create_future() tr, proto = await self.loop.create_connection( lambda: ClientProto(on_data, on_eof), *addr, ssl=client_context) self.assertEqual(await on_data, b'O') tr.write(HELLO_MSG) await on_eof tr.close() with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: self.loop.run_until_complete( asyncio.wait_for(client(srv.addr), timeout=10)) # No garbage is left for SSL client from loop.create_connection, even # if user stores the SSLTransport in corresponding protocol instance client_context = weakref.ref(client_context) self.assertIsNone(client_context()) def test_start_tls_client_buf_proto_1(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE server_context = test_utils.simple_server_sslcontext() client_context = test_utils.simple_client_sslcontext() client_con_made_calls = 0 def serve(sock): sock.settimeout(self.TIMEOUT) data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) sock.starttls(server_context, server_side=True) sock.sendall(b'O') data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) sock.sendall(b'2') data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) sock.unwrap() sock.close() class ClientProtoFirst(asyncio.BufferedProtocol): def __init__(self, on_data): self.on_data = on_data self.buf = bytearray(1) def connection_made(self, tr): nonlocal client_con_made_calls client_con_made_calls += 1 def get_buffer(self, sizehint): return self.buf def buffer_updated(self, nsize): assert nsize == 1 self.on_data.set_result(bytes(self.buf[:nsize])) def eof_received(self): pass class ClientProtoSecond(asyncio.Protocol): def __init__(self, on_data, on_eof): self.on_data = on_data self.on_eof = on_eof self.con_made_cnt = 0 def connection_made(self, tr): nonlocal client_con_made_calls client_con_made_calls += 1 def data_received(self, data): self.on_data.set_result(data) def eof_received(self): self.on_eof.set_result(True) async def client(addr): await asyncio.sleep(0.5) on_data1 = self.loop.create_future() on_data2 = self.loop.create_future() on_eof = self.loop.create_future() tr, proto = await self.loop.create_connection( lambda: ClientProtoFirst(on_data1), *addr) tr.write(HELLO_MSG) new_tr = await self.loop.start_tls(tr, proto, client_context) self.assertEqual(await on_data1, b'O') new_tr.write(HELLO_MSG) new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof)) self.assertEqual(await on_data2, b'2') new_tr.write(HELLO_MSG) await on_eof new_tr.close() # connection_made() should be called only once -- when # we establish connection for the first time. Start TLS # doesn't call connection_made() on application protocols. self.assertEqual(client_con_made_calls, 1) with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: self.loop.run_until_complete( asyncio.wait_for(client(srv.addr), timeout=self.TIMEOUT)) def test_start_tls_slow_client_cancel(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE client_context = test_utils.simple_client_sslcontext() server_waits_on_handshake = self.loop.create_future() def serve(sock): sock.settimeout(self.TIMEOUT) data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) try: self.loop.call_soon_threadsafe( server_waits_on_handshake.set_result, None) data = sock.recv_all(1024 * 1024) except ConnectionAbortedError: pass finally: sock.close() class ClientProto(asyncio.Protocol): def __init__(self, on_data, on_eof): self.on_data = on_data self.on_eof = on_eof self.con_made_cnt = 0 def connection_made(proto, tr): proto.con_made_cnt += 1 # Ensure connection_made gets called only once. self.assertEqual(proto.con_made_cnt, 1) def data_received(self, data): self.on_data.set_result(data) def eof_received(self): self.on_eof.set_result(True) async def client(addr): await asyncio.sleep(0.5) on_data = self.loop.create_future() on_eof = self.loop.create_future() tr, proto = await self.loop.create_connection( lambda: ClientProto(on_data, on_eof), *addr) tr.write(HELLO_MSG) await server_waits_on_handshake with self.assertRaises(asyncio.TimeoutError): await asyncio.wait_for( self.loop.start_tls(tr, proto, client_context), 0.5) with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: self.loop.run_until_complete( asyncio.wait_for(client(srv.addr), timeout=10)) def test_start_tls_server_1(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE server_context = test_utils.simple_server_sslcontext() client_context = test_utils.simple_client_sslcontext() def client(sock, addr): sock.settimeout(self.TIMEOUT) sock.connect(addr) data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) sock.starttls(client_context) sock.sendall(HELLO_MSG) sock.unwrap() sock.close() class ServerProto(asyncio.Protocol): def __init__(self, on_con, on_eof, on_con_lost): self.on_con = on_con self.on_eof = on_eof self.on_con_lost = on_con_lost self.data = b'' def connection_made(self, tr): self.on_con.set_result(tr) def data_received(self, data): self.data += data def eof_received(self): self.on_eof.set_result(1) def connection_lost(self, exc): if exc is None: self.on_con_lost.set_result(None) else: self.on_con_lost.set_exception(exc) async def main(proto, on_con, on_eof, on_con_lost): tr = await on_con tr.write(HELLO_MSG) self.assertEqual(proto.data, b'') new_tr = await self.loop.start_tls( tr, proto, server_context, server_side=True, ssl_handshake_timeout=self.TIMEOUT) await on_eof await on_con_lost self.assertEqual(proto.data, HELLO_MSG) new_tr.close() async def run_main(): on_con = self.loop.create_future() on_eof = self.loop.create_future() on_con_lost = self.loop.create_future() proto = ServerProto(on_con, on_eof, on_con_lost) server = await self.loop.create_server( lambda: proto, '127.0.0.1', 0) addr = server.sockets[0].getsockname() with self.tcp_client(lambda sock: client(sock, addr), timeout=self.TIMEOUT): await asyncio.wait_for( main(proto, on_con, on_eof, on_con_lost), timeout=self.TIMEOUT) server.close() await server.wait_closed() self.loop.run_until_complete(run_main()) def test_create_server_ssl_over_ssl(self): CNT = 0 # number of clients that were successful TOTAL_CNT = 25 # total number of clients that test will create TIMEOUT = 10.0 # timeout for this test A_DATA = b'A' * 1024 * 1024 B_DATA = b'B' * 1024 * 1024 sslctx_1 = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY) client_sslctx_1 = self._create_client_ssl_context() sslctx_2 = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY) client_sslctx_2 = self._create_client_ssl_context() clients = [] async def handle_client(reader, writer): nonlocal CNT data = await reader.readexactly(len(A_DATA)) self.assertEqual(data, A_DATA) writer.write(b'OK') data = await reader.readexactly(len(B_DATA)) self.assertEqual(data, B_DATA) writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) await writer.drain() writer.close() CNT += 1 class ServerProtocol(asyncio.StreamReaderProtocol): def connection_made(self, transport): super_ = super() transport.pause_reading() fut = self._loop.create_task(self._loop.start_tls( transport, self, sslctx_2, server_side=True)) def cb(_): try: tr = fut.result() except Exception as ex: super_.connection_lost(ex) else: super_.connection_made(tr) fut.add_done_callback(cb) def server_protocol_factory(): reader = asyncio.StreamReader() protocol = ServerProtocol(reader, handle_client) return protocol async def test_client(addr): fut = asyncio.Future() def prog(sock): try: sock.connect(addr) sock.starttls(client_sslctx_1) # because wrap_socket() doesn't work correctly on # SSLSocket, we have to do the 2nd level SSL manually incoming = ssl.MemoryBIO() outgoing = ssl.MemoryBIO() sslobj = client_sslctx_2.wrap_bio(incoming, outgoing) def do(func, *args): while True: try: rv = func(*args) break except ssl.SSLWantReadError: if outgoing.pending: sock.send(outgoing.read()) incoming.write(sock.recv(65536)) if outgoing.pending: sock.send(outgoing.read()) return rv do(sslobj.do_handshake) do(sslobj.write, A_DATA) data = do(sslobj.read, 2) self.assertEqual(data, b'OK') do(sslobj.write, B_DATA) data = b'' while True: chunk = do(sslobj.read, 4) if not chunk: break data += chunk self.assertEqual(data, b'SPAM') do(sslobj.unwrap) sock.close() except Exception as ex: self.loop.call_soon_threadsafe(fut.set_exception, ex) sock.close() else: self.loop.call_soon_threadsafe(fut.set_result, None) client = self.tcp_client(prog) client.start() clients.append(client) await fut async def start_server(): extras = {} srv = await self.loop.create_server( server_protocol_factory, '127.0.0.1', 0, family=socket.AF_INET, ssl=sslctx_1, **extras) try: srv_socks = srv.sockets self.assertTrue(srv_socks) addr = srv_socks[0].getsockname() tasks = [] for _ in range(TOTAL_CNT): tasks.append(test_client(addr)) await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) finally: self.loop.call_soon(srv.close) await srv.wait_closed() with self._silence_eof_received_warning(): self.loop.run_until_complete(start_server()) self.assertEqual(CNT, TOTAL_CNT) for client in clients: client.stop() def test_shutdown_cleanly(self): CNT = 0 TOTAL_CNT = 25 A_DATA = b'A' * 1024 * 1024 sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY) client_sslctx = self._create_client_ssl_context() def server(sock): sock.starttls( sslctx, server_side=True) data = sock.recv_all(len(A_DATA)) self.assertEqual(data, A_DATA) sock.send(b'OK') sock.unwrap() sock.close() async def client(addr): extras = {} extras = dict(ssl_handshake_timeout=10.0) reader, writer = await asyncio.open_connection( *addr, ssl=client_sslctx, server_hostname='', **extras) writer.write(A_DATA) self.assertEqual(await reader.readexactly(2), b'OK') self.assertEqual(await reader.read(), b'') nonlocal CNT CNT += 1 writer.close() await self.wait_closed(writer) def run(coro): nonlocal CNT CNT = 0 async def _gather(*tasks): return await asyncio.gather(*tasks) with self.tcp_server(server, max_clients=TOTAL_CNT, backlog=TOTAL_CNT) as srv: tasks = [] for _ in range(TOTAL_CNT): tasks.append(coro(srv.addr)) self.loop.run_until_complete( _gather(*tasks)) self.assertEqual(CNT, TOTAL_CNT) with self._silence_eof_received_warning(): run(client) def test_flush_before_shutdown(self): CHUNK = 1024 * 128 SIZE = 32 sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY) client_sslctx = self._create_client_ssl_context() if hasattr(ssl, 'OP_NO_TLSv1_3'): client_sslctx.options |= ssl.OP_NO_TLSv1_3 future = None def server(sock): sock.starttls(sslctx, server_side=True) self.assertEqual(sock.recv_all(4), b'ping') sock.send(b'pong') time.sleep(0.5) # hopefully stuck the TCP buffer data = sock.recv_all(CHUNK * SIZE) self.assertEqual(len(data), CHUNK * SIZE) sock.close() def run(meth): def wrapper(sock): try: meth(sock) except Exception as ex: self.loop.call_soon_threadsafe(future.set_exception, ex) else: self.loop.call_soon_threadsafe(future.set_result, None) return wrapper async def client(addr): nonlocal future future = self.loop.create_future() reader, writer = await asyncio.open_connection( *addr, ssl=client_sslctx, server_hostname='') sslprotocol = writer.transport._ssl_protocol writer.write(b'ping') data = await reader.readexactly(4) self.assertEqual(data, b'pong') sslprotocol.pause_writing() for _ in range(SIZE): writer.write(b'x' * CHUNK) writer.close() sslprotocol.resume_writing() await self.wait_closed(writer) try: data = await reader.read() self.assertEqual(data, b'') except ConnectionResetError: pass await future with self.tcp_server(run(server)) as srv: self.loop.run_until_complete(client(srv.addr)) def test_remote_shutdown_receives_trailing_data(self): CHUNK = 1024 * 128 SIZE = 32 sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY ) client_sslctx = self._create_client_ssl_context() future = None def server(sock): incoming = ssl.MemoryBIO() outgoing = ssl.MemoryBIO() sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True) while True: try: sslobj.do_handshake() except ssl.SSLWantReadError: if outgoing.pending: sock.send(outgoing.read()) incoming.write(sock.recv(16384)) else: if outgoing.pending: sock.send(outgoing.read()) break while True: try: data = sslobj.read(4) except ssl.SSLWantReadError: incoming.write(sock.recv(16384)) else: break self.assertEqual(data, b'ping') sslobj.write(b'pong') sock.send(outgoing.read()) time.sleep(0.2) # wait for the peer to fill its backlog # send close_notify but don't wait for response with self.assertRaises(ssl.SSLWantReadError): sslobj.unwrap() sock.send(outgoing.read()) # should receive all data data_len = 0 while True: try: chunk = len(sslobj.read(16384)) data_len += chunk except ssl.SSLWantReadError: incoming.write(sock.recv(16384)) except ssl.SSLZeroReturnError: break self.assertEqual(data_len, CHUNK * SIZE) # verify that close_notify is received sslobj.unwrap() sock.close() def eof_server(sock): sock.starttls(sslctx, server_side=True) self.assertEqual(sock.recv_all(4), b'ping') sock.send(b'pong') time.sleep(0.2) # wait for the peer to fill its backlog # send EOF sock.shutdown(socket.SHUT_WR) # should receive all data data = sock.recv_all(CHUNK * SIZE) self.assertEqual(len(data), CHUNK * SIZE) sock.close() async def client(addr): nonlocal future future = self.loop.create_future() reader, writer = await asyncio.open_connection( *addr, ssl=client_sslctx, server_hostname='') writer.write(b'ping') data = await reader.readexactly(4) self.assertEqual(data, b'pong') # fill write backlog in a hacky way - renegotiation won't help for _ in range(SIZE): writer.transport._test__append_write_backlog(b'x' * CHUNK) try: data = await reader.read() self.assertEqual(data, b'') except (BrokenPipeError, ConnectionResetError): pass await future writer.close() await self.wait_closed(writer) def run(meth): def wrapper(sock): try: meth(sock) except Exception as ex: self.loop.call_soon_threadsafe(future.set_exception, ex) else: self.loop.call_soon_threadsafe(future.set_result, None) return wrapper with self.tcp_server(run(server)) as srv: self.loop.run_until_complete(client(srv.addr)) with self.tcp_server(run(eof_server)) as srv: self.loop.run_until_complete(client(srv.addr)) def test_connect_timeout_warning(self): s = socket.socket(socket.AF_INET) s.bind(('127.0.0.1', 0)) addr = s.getsockname() async def test(): try: await asyncio.wait_for( self.loop.create_connection(asyncio.Protocol, *addr, ssl=True), 0.1) except (ConnectionRefusedError, asyncio.TimeoutError): pass else: self.fail('TimeoutError is not raised') with s: try: with self.assertWarns(ResourceWarning) as cm: self.loop.run_until_complete(test()) gc.collect() gc.collect() gc.collect() except AssertionError as e: self.assertEqual(str(e), 'ResourceWarning not triggered') else: self.fail('Unexpected ResourceWarning: {}'.format(cm.warning)) def test_handshake_timeout_handler_leak(self): s = socket.socket(socket.AF_INET) s.bind(('127.0.0.1', 0)) s.listen(1) addr = s.getsockname() async def test(ctx): try: await asyncio.wait_for( self.loop.create_connection(asyncio.Protocol, *addr, ssl=ctx), 0.1) except (ConnectionRefusedError, asyncio.TimeoutError): pass else: self.fail('TimeoutError is not raised') with s: ctx = ssl.create_default_context() self.loop.run_until_complete(test(ctx)) ctx = weakref.ref(ctx) # SSLProtocol should be DECREF to 0 self.assertIsNone(ctx()) def test_shutdown_timeout_handler_leak(self): loop = self.loop def server(sock): sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY ) sock = sslctx.wrap_socket(sock, server_side=True) sock.recv(32) sock.close() class Protocol(asyncio.Protocol): def __init__(self): self.fut = asyncio.Future(loop=loop) def connection_lost(self, exc): self.fut.set_result(None) async def client(addr, ctx): tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) tr.close() await pr.fut with self.tcp_server(server) as srv: ctx = self._create_client_ssl_context() loop.run_until_complete(client(srv.addr, ctx)) ctx = weakref.ref(ctx) # asyncio has no shutdown timeout, but it ends up with a circular # reference loop - not ideal (introduces gc glitches), but at least # not leaking gc.collect() gc.collect() gc.collect() # SSLProtocol should be DECREF to 0 self.assertIsNone(ctx()) def test_shutdown_timeout_handler_not_set(self): loop = self.loop eof = asyncio.Event() extra = None def server(sock): sslctx = self._create_server_ssl_context( test_utils.ONLYCERT, test_utils.ONLYKEY ) sock = sslctx.wrap_socket(sock, server_side=True) sock.send(b'hello') assert sock.recv(1024) == b'world' sock.send(b'extra bytes') # sending EOF here sock.shutdown(socket.SHUT_WR) loop.call_soon_threadsafe(eof.set) # make sure we have enough time to reproduce the issue assert sock.recv(1024) == b'' sock.close() class Protocol(asyncio.Protocol): def __init__(self): self.fut = asyncio.Future(loop=loop) self.transport = None def connection_made(self, transport): self.transport = transport def data_received(self, data): if data == b'hello': self.transport.write(b'world') # pause reading would make incoming data stay in the sslobj self.transport.pause_reading() else: nonlocal extra extra = data def connection_lost(self, exc): if exc is None: self.fut.set_result(None) else: self.fut.set_exception(exc) async def client(addr): ctx = self._create_client_ssl_context() tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) await eof.wait() tr.resume_reading() await pr.fut tr.close() assert extra == b'extra bytes' with self.tcp_server(server) as srv: loop.run_until_complete(client(srv.addr)) ############################################################################### # Socket Testing Utilities ############################################################################### class TestSocketWrapper: def __init__(self, sock): self.__sock = sock def recv_all(self, n): buf = b'' while len(buf) < n: data = self.recv(n - len(buf)) if data == b'': raise ConnectionAbortedError buf += data return buf def starttls(self, ssl_context, *, server_side=False, server_hostname=None, do_handshake_on_connect=True): assert isinstance(ssl_context, ssl.SSLContext) ssl_sock = ssl_context.wrap_socket( self.__sock, server_side=server_side, server_hostname=server_hostname, do_handshake_on_connect=do_handshake_on_connect) if server_side: ssl_sock.do_handshake() self.__sock.close() self.__sock = ssl_sock def __getattr__(self, name): return getattr(self.__sock, name) def __repr__(self): return '<{} {!r}>'.format(type(self).__name__, self.__sock) class SocketThread(threading.Thread): def stop(self): self._active = False self.join() def __enter__(self): self.start() return self def __exit__(self, *exc): self.stop() class TestThreadedClient(SocketThread): def __init__(self, test, sock, prog, timeout): threading.Thread.__init__(self, None, None, 'test-client') self.daemon = True self._timeout = timeout self._sock = sock self._active = True self._prog = prog self._test = test def run(self): try: self._prog(TestSocketWrapper(self._sock)) except (KeyboardInterrupt, SystemExit): raise except BaseException as ex: self._test._abort_socket_test(ex) class TestThreadedServer(SocketThread): def __init__(self, test, sock, prog, timeout, max_clients): threading.Thread.__init__(self, None, None, 'test-server') self.daemon = True self._clients = 0 self._finished_clients = 0 self._max_clients = max_clients self._timeout = timeout self._sock = sock self._active = True self._prog = prog self._s1, self._s2 = socket.socketpair() self._s1.setblocking(False) self._test = test def stop(self): try: if self._s2 and self._s2.fileno() != -1: try: self._s2.send(b'stop') except OSError: pass finally: super().stop() def run(self): try: with self._sock: self._sock.setblocking(0) self._run() finally: self._s1.close() self._s2.close() def _run(self): while self._active: if self._clients >= self._max_clients: return r, w, x = select.select( [self._sock, self._s1], [], [], self._timeout) if self._s1 in r: return if self._sock in r: try: conn, addr = self._sock.accept() except BlockingIOError: continue except socket.timeout: if not self._active: return else: raise else: self._clients += 1 conn.settimeout(self._timeout) try: with conn: self._handle_client(conn) except (KeyboardInterrupt, SystemExit): raise except BaseException as ex: self._active = False try: raise finally: self._test._abort_socket_test(ex) def _handle_client(self, sock): self._prog(TestSocketWrapper(sock)) @property def addr(self): return self._sock.getsockname()