summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_asyncio
diff options
context:
space:
mode:
authorPablo Galindo <Pablogsal@gmail.com>2021-05-03 15:21:59 (GMT)
committerGitHub <noreply@github.com>2021-05-03 15:21:59 (GMT)
commit7719953b30430b351ba0f153c2b51b16cc68ee36 (patch)
tree8014086b85a13ed79d45e29ab74a9a9f5c9c68eb /Lib/test/test_asyncio
parent39494285e15dc2d291ec13de5045b930eaf0a3db (diff)
downloadcpython-7719953b30430b351ba0f153c2b51b16cc68ee36.zip
cpython-7719953b30430b351ba0f153c2b51b16cc68ee36.tar.gz
cpython-7719953b30430b351ba0f153c2b51b16cc68ee36.tar.bz2
bpo-44011: Revert "New asyncio ssl implementation (GH-17975)" (GH-25848)
This reverts commit 5fb06edbbb769561e245d0fe13002bab50e2ae60 and all subsequent dependent commits.
Diffstat (limited to 'Lib/test/test_asyncio')
-rw-r--r--Lib/test/test_asyncio/test_base_events.py21
-rw-r--r--Lib/test/test_asyncio/test_selector_events.py38
-rw-r--r--Lib/test/test_asyncio/test_ssl.py1723
-rw-r--r--Lib/test/test_asyncio/test_sslproto.py35
4 files changed, 60 insertions, 1757 deletions
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
index be5ea1e..5691d42 100644
--- a/Lib/test/test_asyncio/test_base_events.py
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -1437,51 +1437,44 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
ANY = mock.ANY
handshake_timeout = object()
- shutdown_timeout = object()
# First try the default server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
- ssl_handshake_timeout=handshake_timeout,
- ssl_shutdown_timeout=shutdown_timeout)
+ ssl_handshake_timeout=handshake_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='python.org',
- ssl_handshake_timeout=handshake_timeout,
- ssl_shutdown_timeout=shutdown_timeout)
+ ssl_handshake_timeout=handshake_timeout)
# Next try an explicit server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
server_hostname='perl.com',
- ssl_handshake_timeout=handshake_timeout,
- ssl_shutdown_timeout=shutdown_timeout)
+ ssl_handshake_timeout=handshake_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='perl.com',
- ssl_handshake_timeout=handshake_timeout,
- ssl_shutdown_timeout=shutdown_timeout)
+ ssl_handshake_timeout=handshake_timeout)
# Finally try an explicit empty server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
server_hostname='',
- ssl_handshake_timeout=handshake_timeout,
- ssl_shutdown_timeout=shutdown_timeout)
+ ssl_handshake_timeout=handshake_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='',
- ssl_handshake_timeout=handshake_timeout,
- ssl_shutdown_timeout=shutdown_timeout)
+ ssl_handshake_timeout=handshake_timeout)
def test_create_connection_no_ssl_server_hostname_errors(self):
# When not using ssl, server_hostname must be None.
@@ -1888,7 +1881,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
constants.ACCEPT_RETRY_DELAY,
# self.loop._start_serving
mock.ANY,
- MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY)
+ MyProto, sock, None, None, mock.ANY, mock.ANY)
def test_call_coroutine(self):
with self.assertWarns(DeprecationWarning):
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
index 349e4f2..1613c75 100644
--- a/Lib/test/test_asyncio/test_selector_events.py
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -70,6 +70,44 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
close_transport(transport)
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_make_ssl_transport(self):
+ m = mock.Mock()
+ self.loop._add_reader = mock.Mock()
+ self.loop._add_reader._is_coroutine = False
+ self.loop._add_writer = mock.Mock()
+ self.loop._remove_reader = mock.Mock()
+ self.loop._remove_writer = mock.Mock()
+ waiter = self.loop.create_future()
+ with test_utils.disable_logger():
+ transport = self.loop._make_ssl_transport(
+ m, asyncio.Protocol(), m, waiter)
+
+ with self.assertRaisesRegex(RuntimeError,
+ r'SSL transport.*not.*initialized'):
+ transport.is_reading()
+
+ # execute the handshake while the logger is disabled
+ # to ignore SSL handshake failure
+ test_utils.run_briefly(self.loop)
+
+ self.assertTrue(transport.is_reading())
+ transport.pause_reading()
+ transport.pause_reading()
+ self.assertFalse(transport.is_reading())
+ transport.resume_reading()
+ transport.resume_reading()
+ self.assertTrue(transport.is_reading())
+
+ # Sanity check
+ class_name = transport.__class__.__name__
+ self.assertIn("ssl", class_name.lower())
+ self.assertIn("transport", class_name.lower())
+
+ transport.close()
+ # execute pending callbacks to close the socket transport
+ test_utils.run_briefly(self.loop)
+
@mock.patch('asyncio.selector_events.ssl', None)
@mock.patch('asyncio.sslproto.ssl', None)
def test_make_ssl_transport_without_ssl_error(self):
diff --git a/Lib/test/test_asyncio/test_ssl.py b/Lib/test/test_asyncio/test_ssl.py
deleted file mode 100644
index 9cdd281..0000000
--- a/Lib/test/test_asyncio/test_ssl.py
+++ /dev/null
@@ -1,1723 +0,0 @@
-import asyncio
-import asyncio.sslproto
-import contextlib
-import gc
-import logging
-import select
-import socket
-import tempfile
-import threading
-import time
-import weakref
-import unittest
-
-try:
- import ssl
-except ImportError:
- ssl = None
-
-from test import support
-from test.test_asyncio import utils as test_utils
-
-
-def tearDownModule():
- asyncio.set_event_loop_policy(None)
-
-
-class MyBaseProto(asyncio.Protocol):
- connected = None
- done = None
-
- def __init__(self, loop=None):
- self.transport = None
- self.state = 'INITIAL'
- self.nbytes = 0
- if loop is not None:
- self.connected = asyncio.Future(loop=loop)
- self.done = asyncio.Future(loop=loop)
-
- def connection_made(self, transport):
- self.transport = transport
- assert self.state == 'INITIAL', self.state
- self.state = 'CONNECTED'
- if self.connected:
- self.connected.set_result(None)
-
- def data_received(self, data):
- assert self.state == 'CONNECTED', self.state
- self.nbytes += len(data)
-
- def eof_received(self):
- assert self.state == 'CONNECTED', self.state
- self.state = 'EOF'
-
- def connection_lost(self, exc):
- assert self.state in ('CONNECTED', 'EOF'), self.state
- self.state = 'CLOSED'
- if self.done:
- self.done.set_result(None)
-
-
-@unittest.skipIf(ssl is None, 'No ssl module')
-class TestSSL(test_utils.TestCase):
-
- PAYLOAD_SIZE = 1024 * 100
- TIMEOUT = 60
-
- def setUp(self):
- super().setUp()
- self.loop = asyncio.new_event_loop()
- self.set_event_loop(self.loop)
- self.addCleanup(self.loop.close)
-
- def tearDown(self):
- # just in case if we have transport close callbacks
- if not self.loop.is_closed():
- test_utils.run_briefly(self.loop)
-
- self.doCleanups()
- support.gc_collect()
- super().tearDown()
-
- def tcp_server(self, server_prog, *,
- family=socket.AF_INET,
- addr=None,
- timeout=5,
- backlog=1,
- max_clients=10):
-
- if addr is None:
- if family == getattr(socket, "AF_UNIX", None):
- with tempfile.NamedTemporaryFile() as tmp:
- addr = tmp.name
- else:
- addr = ('127.0.0.1', 0)
-
- sock = socket.socket(family, socket.SOCK_STREAM)
-
- if timeout is None:
- raise RuntimeError('timeout is required')
- if timeout <= 0:
- raise RuntimeError('only blocking sockets are supported')
- sock.settimeout(timeout)
-
- try:
- sock.bind(addr)
- sock.listen(backlog)
- except OSError as ex:
- sock.close()
- raise ex
-
- return TestThreadedServer(
- self, sock, server_prog, timeout, max_clients)
-
- def tcp_client(self, client_prog,
- family=socket.AF_INET,
- timeout=10):
-
- sock = socket.socket(family, socket.SOCK_STREAM)
-
- if timeout is None:
- raise RuntimeError('timeout is required')
- if timeout <= 0:
- raise RuntimeError('only blocking sockets are supported')
- sock.settimeout(timeout)
-
- return TestThreadedClient(
- self, sock, client_prog, timeout)
-
- def unix_server(self, *args, **kwargs):
- return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
-
- def unix_client(self, *args, **kwargs):
- return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
-
- def _create_server_ssl_context(self, certfile, keyfile=None):
- sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
- sslcontext.options |= ssl.OP_NO_SSLv2
- sslcontext.load_cert_chain(certfile, keyfile)
- return sslcontext
-
- def _create_client_ssl_context(self, *, disable_verify=True):
- sslcontext = ssl.create_default_context()
- sslcontext.check_hostname = False
- if disable_verify:
- sslcontext.verify_mode = ssl.CERT_NONE
- return sslcontext
-
- @contextlib.contextmanager
- def _silence_eof_received_warning(self):
- # TODO This warning has to be fixed in asyncio.
- logger = logging.getLogger('asyncio')
- filter = logging.Filter('has no effect when using ssl')
- logger.addFilter(filter)
- try:
- yield
- finally:
- logger.removeFilter(filter)
-
- def _abort_socket_test(self, ex):
- try:
- self.loop.stop()
- finally:
- self.fail(ex)
-
- def new_loop(self):
- return asyncio.new_event_loop()
-
- def new_policy(self):
- return asyncio.DefaultEventLoopPolicy()
-
- async def wait_closed(self, obj):
- if not isinstance(obj, asyncio.StreamWriter):
- return
- try:
- await obj.wait_closed()
- except (BrokenPipeError, ConnectionError):
- pass
-
- def test_create_server_ssl_1(self):
- CNT = 0 # number of clients that were successful
- TOTAL_CNT = 25 # total number of clients that test will create
- TIMEOUT = 60.0 # timeout for this test
-
- A_DATA = b'A' * 1024 * 1024
- B_DATA = b'B' * 1024 * 1024
-
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT, test_utils.ONLYKEY
- )
- client_sslctx = self._create_client_ssl_context()
-
- clients = []
-
- async def handle_client(reader, writer):
- nonlocal CNT
-
- data = await reader.readexactly(len(A_DATA))
- self.assertEqual(data, A_DATA)
- writer.write(b'OK')
-
- data = await reader.readexactly(len(B_DATA))
- self.assertEqual(data, B_DATA)
- writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
-
- await writer.drain()
- writer.close()
-
- CNT += 1
-
- async def test_client(addr):
- fut = asyncio.Future()
-
- def prog(sock):
- try:
- sock.starttls(client_sslctx)
- sock.connect(addr)
- sock.send(A_DATA)
-
- data = sock.recv_all(2)
- self.assertEqual(data, b'OK')
-
- sock.send(B_DATA)
- data = sock.recv_all(4)
- self.assertEqual(data, b'SPAM')
-
- sock.close()
-
- except Exception as ex:
- self.loop.call_soon_threadsafe(fut.set_exception, ex)
- else:
- self.loop.call_soon_threadsafe(fut.set_result, None)
-
- client = self.tcp_client(prog)
- client.start()
- clients.append(client)
-
- await fut
-
- async def start_server():
- extras = {}
- extras = dict(ssl_handshake_timeout=40.0)
-
- srv = await asyncio.start_server(
- handle_client,
- '127.0.0.1', 0,
- family=socket.AF_INET,
- ssl=sslctx,
- **extras)
-
- try:
- srv_socks = srv.sockets
- self.assertTrue(srv_socks)
-
- addr = srv_socks[0].getsockname()
-
- tasks = []
- for _ in range(TOTAL_CNT):
- tasks.append(test_client(addr))
-
- await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
-
- finally:
- self.loop.call_soon(srv.close)
- await srv.wait_closed()
-
- with self._silence_eof_received_warning():
- self.loop.run_until_complete(start_server())
-
- self.assertEqual(CNT, TOTAL_CNT)
-
- for client in clients:
- client.stop()
-
- def test_create_connection_ssl_1(self):
- self.loop.set_exception_handler(None)
-
- CNT = 0
- TOTAL_CNT = 25
-
- A_DATA = b'A' * 1024 * 1024
- B_DATA = b'B' * 1024 * 1024
-
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT,
- test_utils.ONLYKEY
- )
- client_sslctx = self._create_client_ssl_context()
-
- def server(sock):
- sock.starttls(
- sslctx,
- server_side=True)
-
- data = sock.recv_all(len(A_DATA))
- self.assertEqual(data, A_DATA)
- sock.send(b'OK')
-
- data = sock.recv_all(len(B_DATA))
- self.assertEqual(data, B_DATA)
- sock.send(b'SPAM')
-
- sock.close()
-
- async def client(addr):
- extras = {}
- extras = dict(ssl_handshake_timeout=40.0)
-
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='',
- **extras)
-
- writer.write(A_DATA)
- self.assertEqual(await reader.readexactly(2), b'OK')
-
- writer.write(B_DATA)
- self.assertEqual(await reader.readexactly(4), b'SPAM')
-
- nonlocal CNT
- CNT += 1
-
- writer.close()
- await self.wait_closed(writer)
-
- async def client_sock(addr):
- sock = socket.socket()
- sock.connect(addr)
- reader, writer = await asyncio.open_connection(
- sock=sock,
- ssl=client_sslctx,
- server_hostname='')
-
- writer.write(A_DATA)
- self.assertEqual(await reader.readexactly(2), b'OK')
-
- writer.write(B_DATA)
- self.assertEqual(await reader.readexactly(4), b'SPAM')
-
- nonlocal CNT
- CNT += 1
-
- writer.close()
- await self.wait_closed(writer)
- sock.close()
-
- def run(coro):
- nonlocal CNT
- CNT = 0
-
- async def _gather(*tasks):
- # trampoline
- return await asyncio.gather(*tasks)
-
- with self.tcp_server(server,
- max_clients=TOTAL_CNT,
- backlog=TOTAL_CNT) as srv:
- tasks = []
- for _ in range(TOTAL_CNT):
- tasks.append(coro(srv.addr))
-
- self.loop.run_until_complete(_gather(*tasks))
-
- self.assertEqual(CNT, TOTAL_CNT)
-
- with self._silence_eof_received_warning():
- run(client)
-
- with self._silence_eof_received_warning():
- run(client_sock)
-
- def test_create_connection_ssl_slow_handshake(self):
- client_sslctx = self._create_client_ssl_context()
-
- # silence error logger
- self.loop.set_exception_handler(lambda *args: None)
-
- def server(sock):
- try:
- sock.recv_all(1024 * 1024)
- except ConnectionAbortedError:
- pass
- finally:
- sock.close()
-
- async def client(addr):
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='',
- ssl_handshake_timeout=1.0)
- writer.close()
- await self.wait_closed(writer)
-
- with self.tcp_server(server,
- max_clients=1,
- backlog=1) as srv:
-
- with self.assertRaisesRegex(
- ConnectionAbortedError,
- r'SSL handshake.*is taking longer'):
-
- self.loop.run_until_complete(client(srv.addr))
-
- def test_create_connection_ssl_failed_certificate(self):
- # silence error logger
- self.loop.set_exception_handler(lambda *args: None)
-
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT,
- test_utils.ONLYKEY
- )
- client_sslctx = self._create_client_ssl_context(disable_verify=False)
-
- def server(sock):
- try:
- sock.starttls(
- sslctx,
- server_side=True)
- sock.connect()
- except (ssl.SSLError, OSError):
- pass
- finally:
- sock.close()
-
- async def client(addr):
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='',
- ssl_handshake_timeout=1.0)
- writer.close()
- await self.wait_closed(writer)
-
- with self.tcp_server(server,
- max_clients=1,
- backlog=1) as srv:
-
- with self.assertRaises(ssl.SSLCertVerificationError):
- self.loop.run_until_complete(client(srv.addr))
-
- def test_ssl_handshake_timeout(self):
- # bpo-29970: Check that a connection is aborted if handshake is not
- # completed in timeout period, instead of remaining open indefinitely
- client_sslctx = test_utils.simple_client_sslcontext()
-
- # silence error logger
- messages = []
- self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
-
- server_side_aborted = False
-
- def server(sock):
- nonlocal server_side_aborted
- try:
- sock.recv_all(1024 * 1024)
- except ConnectionAbortedError:
- server_side_aborted = True
- finally:
- sock.close()
-
- async def client(addr):
- await asyncio.wait_for(
- self.loop.create_connection(
- asyncio.Protocol,
- *addr,
- ssl=client_sslctx,
- server_hostname='',
- ssl_handshake_timeout=10.0),
- 0.5)
-
- with self.tcp_server(server,
- max_clients=1,
- backlog=1) as srv:
-
- with self.assertRaises(asyncio.TimeoutError):
- self.loop.run_until_complete(client(srv.addr))
-
- self.assertTrue(server_side_aborted)
-
- # Python issue #23197: cancelling a handshake must not raise an
- # exception or log an error, even if the handshake failed
- self.assertEqual(messages, [])
-
- def test_ssl_handshake_connection_lost(self):
- # #246: make sure that no connection_lost() is called before
- # connection_made() is called first
-
- client_sslctx = test_utils.simple_client_sslcontext()
-
- # silence error logger
- self.loop.set_exception_handler(lambda loop, ctx: None)
-
- connection_made_called = False
- connection_lost_called = False
-
- def server(sock):
- sock.recv(1024)
- # break the connection during handshake
- sock.close()
-
- class ClientProto(asyncio.Protocol):
- def connection_made(self, transport):
- nonlocal connection_made_called
- connection_made_called = True
-
- def connection_lost(self, exc):
- nonlocal connection_lost_called
- connection_lost_called = True
-
- async def client(addr):
- await self.loop.create_connection(
- ClientProto,
- *addr,
- ssl=client_sslctx,
- server_hostname=''),
-
- with self.tcp_server(server,
- max_clients=1,
- backlog=1) as srv:
-
- with self.assertRaises(ConnectionResetError):
- self.loop.run_until_complete(client(srv.addr))
-
- if connection_lost_called:
- if connection_made_called:
- self.fail("unexpected call to connection_lost()")
- else:
- self.fail("unexpected call to connection_lost() without"
- "calling connection_made()")
- elif connection_made_called:
- self.fail("unexpected call to connection_made()")
-
- def test_ssl_connect_accepted_socket(self):
- proto = ssl.PROTOCOL_TLS_SERVER
- server_context = ssl.SSLContext(proto)
- server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY)
- if hasattr(server_context, 'check_hostname'):
- server_context.check_hostname = False
- server_context.verify_mode = ssl.CERT_NONE
-
- client_context = ssl.SSLContext(proto)
- if hasattr(server_context, 'check_hostname'):
- client_context.check_hostname = False
- client_context.verify_mode = ssl.CERT_NONE
-
- def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
- loop = self.loop
-
- class MyProto(MyBaseProto):
-
- def connection_lost(self, exc):
- super().connection_lost(exc)
- loop.call_soon(loop.stop)
-
- def data_received(self, data):
- super().data_received(data)
- self.transport.write(expected_response)
-
- lsock = socket.socket(socket.AF_INET)
- lsock.bind(('127.0.0.1', 0))
- lsock.listen(1)
- addr = lsock.getsockname()
-
- message = b'test data'
- response = None
- expected_response = b'roger'
-
- def client():
- nonlocal response
- try:
- csock = socket.socket(socket.AF_INET)
- if client_ssl is not None:
- csock = client_ssl.wrap_socket(csock)
- csock.connect(addr)
- csock.sendall(message)
- response = csock.recv(99)
- csock.close()
- except Exception as exc:
- print(
- "Failure in client thread in test_connect_accepted_socket",
- exc)
-
- thread = threading.Thread(target=client, daemon=True)
- thread.start()
-
- conn, _ = lsock.accept()
- proto = MyProto(loop=loop)
- proto.loop = loop
-
- extras = {}
- if server_ssl:
- extras = dict(ssl_handshake_timeout=10.0)
-
- f = loop.create_task(
- loop.connect_accepted_socket(
- (lambda: proto), conn, ssl=server_ssl,
- **extras))
- loop.run_forever()
- conn.close()
- lsock.close()
-
- thread.join(1)
- self.assertFalse(thread.is_alive())
- self.assertEqual(proto.state, 'CLOSED')
- self.assertEqual(proto.nbytes, len(message))
- self.assertEqual(response, expected_response)
- tr, _ = f.result()
-
- if server_ssl:
- self.assertIn('SSL', tr.__class__.__name__)
-
- tr.close()
- # let it close
- self.loop.run_until_complete(asyncio.sleep(0.1))
-
- def test_start_tls_client_corrupted_ssl(self):
- self.loop.set_exception_handler(lambda loop, ctx: None)
-
- sslctx = test_utils.simple_server_sslcontext()
- client_sslctx = test_utils.simple_client_sslcontext()
-
- def server(sock):
- orig_sock = sock.dup()
- try:
- sock.starttls(
- sslctx,
- server_side=True)
- sock.sendall(b'A\n')
- sock.recv_all(1)
- orig_sock.send(b'please corrupt the SSL connection')
- except ssl.SSLError:
- pass
- finally:
- sock.close()
- orig_sock.close()
-
- async def client(addr):
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='')
-
- self.assertEqual(await reader.readline(), b'A\n')
- writer.write(b'B')
- with self.assertRaises(ssl.SSLError):
- await reader.readline()
- writer.close()
- try:
- await self.wait_closed(writer)
- except ssl.SSLError:
- pass
- return 'OK'
-
- with self.tcp_server(server,
- max_clients=1,
- backlog=1) as srv:
-
- res = self.loop.run_until_complete(client(srv.addr))
-
- self.assertEqual(res, 'OK')
-
- def test_start_tls_client_reg_proto_1(self):
- HELLO_MSG = b'1' * self.PAYLOAD_SIZE
-
- server_context = test_utils.simple_server_sslcontext()
- client_context = test_utils.simple_client_sslcontext()
-
- def serve(sock):
- sock.settimeout(self.TIMEOUT)
-
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- sock.starttls(server_context, server_side=True)
-
- sock.sendall(b'O')
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- sock.unwrap()
- sock.close()
-
- class ClientProto(asyncio.Protocol):
- def __init__(self, on_data, on_eof):
- self.on_data = on_data
- self.on_eof = on_eof
- self.con_made_cnt = 0
-
- def connection_made(proto, tr):
- proto.con_made_cnt += 1
- # Ensure connection_made gets called only once.
- self.assertEqual(proto.con_made_cnt, 1)
-
- def data_received(self, data):
- self.on_data.set_result(data)
-
- def eof_received(self):
- self.on_eof.set_result(True)
-
- async def client(addr):
- await asyncio.sleep(0.5)
-
- on_data = self.loop.create_future()
- on_eof = self.loop.create_future()
-
- tr, proto = await self.loop.create_connection(
- lambda: ClientProto(on_data, on_eof), *addr)
-
- tr.write(HELLO_MSG)
- new_tr = await self.loop.start_tls(tr, proto, client_context)
-
- self.assertEqual(await on_data, b'O')
- new_tr.write(HELLO_MSG)
- await on_eof
-
- new_tr.close()
-
- with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
- self.loop.run_until_complete(
- asyncio.wait_for(client(srv.addr), timeout=10))
-
- def test_create_connection_memory_leak(self):
- HELLO_MSG = b'1' * self.PAYLOAD_SIZE
-
- server_context = self._create_server_ssl_context(
- test_utils.ONLYCERT, test_utils.ONLYKEY)
- client_context = self._create_client_ssl_context()
-
- def serve(sock):
- sock.settimeout(self.TIMEOUT)
-
- sock.starttls(server_context, server_side=True)
-
- sock.sendall(b'O')
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- sock.unwrap()
- sock.close()
-
- class ClientProto(asyncio.Protocol):
- def __init__(self, on_data, on_eof):
- self.on_data = on_data
- self.on_eof = on_eof
- self.con_made_cnt = 0
-
- def connection_made(proto, tr):
- # XXX: We assume user stores the transport in protocol
- proto.tr = tr
- proto.con_made_cnt += 1
- # Ensure connection_made gets called only once.
- self.assertEqual(proto.con_made_cnt, 1)
-
- def data_received(self, data):
- self.on_data.set_result(data)
-
- def eof_received(self):
- self.on_eof.set_result(True)
-
- async def client(addr):
- await asyncio.sleep(0.5)
-
- on_data = self.loop.create_future()
- on_eof = self.loop.create_future()
-
- tr, proto = await self.loop.create_connection(
- lambda: ClientProto(on_data, on_eof), *addr,
- ssl=client_context)
-
- self.assertEqual(await on_data, b'O')
- tr.write(HELLO_MSG)
- await on_eof
-
- tr.close()
-
- with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
- self.loop.run_until_complete(
- asyncio.wait_for(client(srv.addr), timeout=10))
-
- # No garbage is left for SSL client from loop.create_connection, even
- # if user stores the SSLTransport in corresponding protocol instance
- client_context = weakref.ref(client_context)
- self.assertIsNone(client_context())
-
- def test_start_tls_client_buf_proto_1(self):
- HELLO_MSG = b'1' * self.PAYLOAD_SIZE
-
- server_context = test_utils.simple_server_sslcontext()
- client_context = test_utils.simple_client_sslcontext()
-
- client_con_made_calls = 0
-
- def serve(sock):
- sock.settimeout(self.TIMEOUT)
-
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- sock.starttls(server_context, server_side=True)
-
- sock.sendall(b'O')
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- sock.sendall(b'2')
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- sock.unwrap()
- sock.close()
-
- class ClientProtoFirst(asyncio.BufferedProtocol):
- def __init__(self, on_data):
- self.on_data = on_data
- self.buf = bytearray(1)
-
- def connection_made(self, tr):
- nonlocal client_con_made_calls
- client_con_made_calls += 1
-
- def get_buffer(self, sizehint):
- return self.buf
-
- def buffer_updated(self, nsize):
- assert nsize == 1
- self.on_data.set_result(bytes(self.buf[:nsize]))
-
- def eof_received(self):
- pass
-
- class ClientProtoSecond(asyncio.Protocol):
- def __init__(self, on_data, on_eof):
- self.on_data = on_data
- self.on_eof = on_eof
- self.con_made_cnt = 0
-
- def connection_made(self, tr):
- nonlocal client_con_made_calls
- client_con_made_calls += 1
-
- def data_received(self, data):
- self.on_data.set_result(data)
-
- def eof_received(self):
- self.on_eof.set_result(True)
-
- async def client(addr):
- await asyncio.sleep(0.5)
-
- on_data1 = self.loop.create_future()
- on_data2 = self.loop.create_future()
- on_eof = self.loop.create_future()
-
- tr, proto = await self.loop.create_connection(
- lambda: ClientProtoFirst(on_data1), *addr)
-
- tr.write(HELLO_MSG)
- new_tr = await self.loop.start_tls(tr, proto, client_context)
-
- self.assertEqual(await on_data1, b'O')
- new_tr.write(HELLO_MSG)
-
- new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
- self.assertEqual(await on_data2, b'2')
- new_tr.write(HELLO_MSG)
- await on_eof
-
- new_tr.close()
-
- # connection_made() should be called only once -- when
- # we establish connection for the first time. Start TLS
- # doesn't call connection_made() on application protocols.
- self.assertEqual(client_con_made_calls, 1)
-
- with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
- self.loop.run_until_complete(
- asyncio.wait_for(client(srv.addr),
- timeout=self.TIMEOUT))
-
- def test_start_tls_slow_client_cancel(self):
- HELLO_MSG = b'1' * self.PAYLOAD_SIZE
-
- client_context = test_utils.simple_client_sslcontext()
- server_waits_on_handshake = self.loop.create_future()
-
- def serve(sock):
- sock.settimeout(self.TIMEOUT)
-
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- try:
- self.loop.call_soon_threadsafe(
- server_waits_on_handshake.set_result, None)
- data = sock.recv_all(1024 * 1024)
- except ConnectionAbortedError:
- pass
- finally:
- sock.close()
-
- class ClientProto(asyncio.Protocol):
- def __init__(self, on_data, on_eof):
- self.on_data = on_data
- self.on_eof = on_eof
- self.con_made_cnt = 0
-
- def connection_made(proto, tr):
- proto.con_made_cnt += 1
- # Ensure connection_made gets called only once.
- self.assertEqual(proto.con_made_cnt, 1)
-
- def data_received(self, data):
- self.on_data.set_result(data)
-
- def eof_received(self):
- self.on_eof.set_result(True)
-
- async def client(addr):
- await asyncio.sleep(0.5)
-
- on_data = self.loop.create_future()
- on_eof = self.loop.create_future()
-
- tr, proto = await self.loop.create_connection(
- lambda: ClientProto(on_data, on_eof), *addr)
-
- tr.write(HELLO_MSG)
-
- await server_waits_on_handshake
-
- with self.assertRaises(asyncio.TimeoutError):
- await asyncio.wait_for(
- self.loop.start_tls(tr, proto, client_context),
- 0.5)
-
- with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
- self.loop.run_until_complete(
- asyncio.wait_for(client(srv.addr), timeout=10))
-
- def test_start_tls_server_1(self):
- HELLO_MSG = b'1' * self.PAYLOAD_SIZE
-
- server_context = test_utils.simple_server_sslcontext()
- client_context = test_utils.simple_client_sslcontext()
-
- def client(sock, addr):
- sock.settimeout(self.TIMEOUT)
-
- sock.connect(addr)
- data = sock.recv_all(len(HELLO_MSG))
- self.assertEqual(len(data), len(HELLO_MSG))
-
- sock.starttls(client_context)
- sock.sendall(HELLO_MSG)
-
- sock.unwrap()
- sock.close()
-
- class ServerProto(asyncio.Protocol):
- def __init__(self, on_con, on_eof, on_con_lost):
- self.on_con = on_con
- self.on_eof = on_eof
- self.on_con_lost = on_con_lost
- self.data = b''
-
- def connection_made(self, tr):
- self.on_con.set_result(tr)
-
- def data_received(self, data):
- self.data += data
-
- def eof_received(self):
- self.on_eof.set_result(1)
-
- def connection_lost(self, exc):
- if exc is None:
- self.on_con_lost.set_result(None)
- else:
- self.on_con_lost.set_exception(exc)
-
- async def main(proto, on_con, on_eof, on_con_lost):
- tr = await on_con
- tr.write(HELLO_MSG)
-
- self.assertEqual(proto.data, b'')
-
- new_tr = await self.loop.start_tls(
- tr, proto, server_context,
- server_side=True,
- ssl_handshake_timeout=self.TIMEOUT)
-
- await on_eof
- await on_con_lost
- self.assertEqual(proto.data, HELLO_MSG)
- new_tr.close()
-
- async def run_main():
- on_con = self.loop.create_future()
- on_eof = self.loop.create_future()
- on_con_lost = self.loop.create_future()
- proto = ServerProto(on_con, on_eof, on_con_lost)
-
- server = await self.loop.create_server(
- lambda: proto, '127.0.0.1', 0)
- addr = server.sockets[0].getsockname()
-
- with self.tcp_client(lambda sock: client(sock, addr),
- timeout=self.TIMEOUT):
- await asyncio.wait_for(
- main(proto, on_con, on_eof, on_con_lost),
- timeout=self.TIMEOUT)
-
- server.close()
- await server.wait_closed()
-
- self.loop.run_until_complete(run_main())
-
- def test_create_server_ssl_over_ssl(self):
- CNT = 0 # number of clients that were successful
- TOTAL_CNT = 25 # total number of clients that test will create
- TIMEOUT = 10.0 # timeout for this test
-
- A_DATA = b'A' * 1024 * 1024
- B_DATA = b'B' * 1024 * 1024
-
- sslctx_1 = self._create_server_ssl_context(
- test_utils.ONLYCERT, test_utils.ONLYKEY)
- client_sslctx_1 = self._create_client_ssl_context()
- sslctx_2 = self._create_server_ssl_context(
- test_utils.ONLYCERT, test_utils.ONLYKEY)
- client_sslctx_2 = self._create_client_ssl_context()
-
- clients = []
-
- async def handle_client(reader, writer):
- nonlocal CNT
-
- data = await reader.readexactly(len(A_DATA))
- self.assertEqual(data, A_DATA)
- writer.write(b'OK')
-
- data = await reader.readexactly(len(B_DATA))
- self.assertEqual(data, B_DATA)
- writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
-
- await writer.drain()
- writer.close()
-
- CNT += 1
-
- class ServerProtocol(asyncio.StreamReaderProtocol):
- def connection_made(self, transport):
- super_ = super()
- transport.pause_reading()
- fut = self._loop.create_task(self._loop.start_tls(
- transport, self, sslctx_2, server_side=True))
-
- def cb(_):
- try:
- tr = fut.result()
- except Exception as ex:
- super_.connection_lost(ex)
- else:
- super_.connection_made(tr)
- fut.add_done_callback(cb)
-
- def server_protocol_factory():
- reader = asyncio.StreamReader()
- protocol = ServerProtocol(reader, handle_client)
- return protocol
-
- async def test_client(addr):
- fut = asyncio.Future()
-
- def prog(sock):
- try:
- sock.connect(addr)
- sock.starttls(client_sslctx_1)
-
- # because wrap_socket() doesn't work correctly on
- # SSLSocket, we have to do the 2nd level SSL manually
- incoming = ssl.MemoryBIO()
- outgoing = ssl.MemoryBIO()
- sslobj = client_sslctx_2.wrap_bio(incoming, outgoing)
-
- def do(func, *args):
- while True:
- try:
- rv = func(*args)
- break
- except ssl.SSLWantReadError:
- if outgoing.pending:
- sock.send(outgoing.read())
- incoming.write(sock.recv(65536))
- if outgoing.pending:
- sock.send(outgoing.read())
- return rv
-
- do(sslobj.do_handshake)
-
- do(sslobj.write, A_DATA)
- data = do(sslobj.read, 2)
- self.assertEqual(data, b'OK')
-
- do(sslobj.write, B_DATA)
- data = b''
- while True:
- chunk = do(sslobj.read, 4)
- if not chunk:
- break
- data += chunk
- self.assertEqual(data, b'SPAM')
-
- do(sslobj.unwrap)
- sock.close()
-
- except Exception as ex:
- self.loop.call_soon_threadsafe(fut.set_exception, ex)
- sock.close()
- else:
- self.loop.call_soon_threadsafe(fut.set_result, None)
-
- client = self.tcp_client(prog)
- client.start()
- clients.append(client)
-
- await fut
-
- async def start_server():
- extras = {}
-
- srv = await self.loop.create_server(
- server_protocol_factory,
- '127.0.0.1', 0,
- family=socket.AF_INET,
- ssl=sslctx_1,
- **extras)
-
- try:
- srv_socks = srv.sockets
- self.assertTrue(srv_socks)
-
- addr = srv_socks[0].getsockname()
-
- tasks = []
- for _ in range(TOTAL_CNT):
- tasks.append(test_client(addr))
-
- await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
-
- finally:
- self.loop.call_soon(srv.close)
- await srv.wait_closed()
-
- with self._silence_eof_received_warning():
- self.loop.run_until_complete(start_server())
-
- self.assertEqual(CNT, TOTAL_CNT)
-
- for client in clients:
- client.stop()
-
- def test_shutdown_cleanly(self):
- CNT = 0
- TOTAL_CNT = 25
-
- A_DATA = b'A' * 1024 * 1024
-
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT, test_utils.ONLYKEY)
- client_sslctx = self._create_client_ssl_context()
-
- def server(sock):
- sock.starttls(
- sslctx,
- server_side=True)
-
- data = sock.recv_all(len(A_DATA))
- self.assertEqual(data, A_DATA)
- sock.send(b'OK')
-
- sock.unwrap()
-
- sock.close()
-
- async def client(addr):
- extras = {}
- extras = dict(ssl_handshake_timeout=10.0)
-
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='',
- **extras)
-
- writer.write(A_DATA)
- self.assertEqual(await reader.readexactly(2), b'OK')
-
- self.assertEqual(await reader.read(), b'')
-
- nonlocal CNT
- CNT += 1
-
- writer.close()
- await self.wait_closed(writer)
-
- def run(coro):
- nonlocal CNT
- CNT = 0
-
- async def _gather(*tasks):
- return await asyncio.gather(*tasks)
-
- with self.tcp_server(server,
- max_clients=TOTAL_CNT,
- backlog=TOTAL_CNT) as srv:
- tasks = []
- for _ in range(TOTAL_CNT):
- tasks.append(coro(srv.addr))
-
- self.loop.run_until_complete(
- _gather(*tasks))
-
- self.assertEqual(CNT, TOTAL_CNT)
-
- with self._silence_eof_received_warning():
- run(client)
-
- def test_flush_before_shutdown(self):
- CHUNK = 1024 * 128
- SIZE = 32
-
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT, test_utils.ONLYKEY)
- client_sslctx = self._create_client_ssl_context()
- if hasattr(ssl, 'OP_NO_TLSv1_3'):
- client_sslctx.options |= ssl.OP_NO_TLSv1_3
-
- future = None
-
- def server(sock):
- sock.starttls(sslctx, server_side=True)
- self.assertEqual(sock.recv_all(4), b'ping')
- sock.send(b'pong')
- time.sleep(0.5) # hopefully stuck the TCP buffer
- data = sock.recv_all(CHUNK * SIZE)
- self.assertEqual(len(data), CHUNK * SIZE)
- sock.close()
-
- def run(meth):
- def wrapper(sock):
- try:
- meth(sock)
- except Exception as ex:
- self.loop.call_soon_threadsafe(future.set_exception, ex)
- else:
- self.loop.call_soon_threadsafe(future.set_result, None)
- return wrapper
-
- async def client(addr):
- nonlocal future
- future = self.loop.create_future()
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='')
- sslprotocol = writer.transport._ssl_protocol
- writer.write(b'ping')
- data = await reader.readexactly(4)
- self.assertEqual(data, b'pong')
-
- sslprotocol.pause_writing()
- for _ in range(SIZE):
- writer.write(b'x' * CHUNK)
-
- writer.close()
- sslprotocol.resume_writing()
-
- await self.wait_closed(writer)
- try:
- data = await reader.read()
- self.assertEqual(data, b'')
- except ConnectionResetError:
- pass
- await future
-
- with self.tcp_server(run(server)) as srv:
- self.loop.run_until_complete(client(srv.addr))
-
- def test_remote_shutdown_receives_trailing_data(self):
- CHUNK = 1024 * 128
- SIZE = 32
-
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT,
- test_utils.ONLYKEY
- )
- client_sslctx = self._create_client_ssl_context()
- future = None
-
- def server(sock):
- incoming = ssl.MemoryBIO()
- outgoing = ssl.MemoryBIO()
- sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)
-
- while True:
- try:
- sslobj.do_handshake()
- except ssl.SSLWantReadError:
- if outgoing.pending:
- sock.send(outgoing.read())
- incoming.write(sock.recv(16384))
- else:
- if outgoing.pending:
- sock.send(outgoing.read())
- break
-
- while True:
- try:
- data = sslobj.read(4)
- except ssl.SSLWantReadError:
- incoming.write(sock.recv(16384))
- else:
- break
-
- self.assertEqual(data, b'ping')
- sslobj.write(b'pong')
- sock.send(outgoing.read())
-
- time.sleep(0.2) # wait for the peer to fill its backlog
-
- # send close_notify but don't wait for response
- with self.assertRaises(ssl.SSLWantReadError):
- sslobj.unwrap()
- sock.send(outgoing.read())
-
- # should receive all data
- data_len = 0
- while True:
- try:
- chunk = len(sslobj.read(16384))
- data_len += chunk
- except ssl.SSLWantReadError:
- incoming.write(sock.recv(16384))
- except ssl.SSLZeroReturnError:
- break
-
- self.assertEqual(data_len, CHUNK * SIZE)
-
- # verify that close_notify is received
- sslobj.unwrap()
-
- sock.close()
-
- def eof_server(sock):
- sock.starttls(sslctx, server_side=True)
- self.assertEqual(sock.recv_all(4), b'ping')
- sock.send(b'pong')
-
- time.sleep(0.2) # wait for the peer to fill its backlog
-
- # send EOF
- sock.shutdown(socket.SHUT_WR)
-
- # should receive all data
- data = sock.recv_all(CHUNK * SIZE)
- self.assertEqual(len(data), CHUNK * SIZE)
-
- sock.close()
-
- async def client(addr):
- nonlocal future
- future = self.loop.create_future()
-
- reader, writer = await asyncio.open_connection(
- *addr,
- ssl=client_sslctx,
- server_hostname='')
- writer.write(b'ping')
- data = await reader.readexactly(4)
- self.assertEqual(data, b'pong')
-
- # fill write backlog in a hacky way - renegotiation won't help
- for _ in range(SIZE):
- writer.transport._test__append_write_backlog(b'x' * CHUNK)
-
- try:
- data = await reader.read()
- self.assertEqual(data, b'')
- except (BrokenPipeError, ConnectionResetError):
- pass
-
- await future
-
- writer.close()
- await self.wait_closed(writer)
-
- def run(meth):
- def wrapper(sock):
- try:
- meth(sock)
- except Exception as ex:
- self.loop.call_soon_threadsafe(future.set_exception, ex)
- else:
- self.loop.call_soon_threadsafe(future.set_result, None)
- return wrapper
-
- with self.tcp_server(run(server)) as srv:
- self.loop.run_until_complete(client(srv.addr))
-
- with self.tcp_server(run(eof_server)) as srv:
- self.loop.run_until_complete(client(srv.addr))
-
- def test_connect_timeout_warning(self):
- s = socket.socket(socket.AF_INET)
- s.bind(('127.0.0.1', 0))
- addr = s.getsockname()
-
- async def test():
- try:
- await asyncio.wait_for(
- self.loop.create_connection(asyncio.Protocol,
- *addr, ssl=True),
- 0.1)
- except (ConnectionRefusedError, asyncio.TimeoutError):
- pass
- else:
- self.fail('TimeoutError is not raised')
-
- with s:
- try:
- with self.assertWarns(ResourceWarning) as cm:
- self.loop.run_until_complete(test())
- gc.collect()
- gc.collect()
- gc.collect()
- except AssertionError as e:
- self.assertEqual(str(e), 'ResourceWarning not triggered')
- else:
- self.fail('Unexpected ResourceWarning: {}'.format(cm.warning))
-
- def test_handshake_timeout_handler_leak(self):
- s = socket.socket(socket.AF_INET)
- s.bind(('127.0.0.1', 0))
- s.listen(1)
- addr = s.getsockname()
-
- async def test(ctx):
- try:
- await asyncio.wait_for(
- self.loop.create_connection(asyncio.Protocol, *addr,
- ssl=ctx),
- 0.1)
- except (ConnectionRefusedError, asyncio.TimeoutError):
- pass
- else:
- self.fail('TimeoutError is not raised')
-
- with s:
- ctx = ssl.create_default_context()
- self.loop.run_until_complete(test(ctx))
- ctx = weakref.ref(ctx)
-
- # SSLProtocol should be DECREF to 0
- self.assertIsNone(ctx())
-
- def test_shutdown_timeout_handler_leak(self):
- loop = self.loop
-
- def server(sock):
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT,
- test_utils.ONLYKEY
- )
- sock = sslctx.wrap_socket(sock, server_side=True)
- sock.recv(32)
- sock.close()
-
- class Protocol(asyncio.Protocol):
- def __init__(self):
- self.fut = asyncio.Future(loop=loop)
-
- def connection_lost(self, exc):
- self.fut.set_result(None)
-
- async def client(addr, ctx):
- tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
- tr.close()
- await pr.fut
-
- with self.tcp_server(server) as srv:
- ctx = self._create_client_ssl_context()
- loop.run_until_complete(client(srv.addr, ctx))
- ctx = weakref.ref(ctx)
-
- # asyncio has no shutdown timeout, but it ends up with a circular
- # reference loop - not ideal (introduces gc glitches), but at least
- # not leaking
- gc.collect()
- gc.collect()
- gc.collect()
-
- # SSLProtocol should be DECREF to 0
- self.assertIsNone(ctx())
-
- def test_shutdown_timeout_handler_not_set(self):
- loop = self.loop
- eof = asyncio.Event()
- extra = None
-
- def server(sock):
- sslctx = self._create_server_ssl_context(
- test_utils.ONLYCERT,
- test_utils.ONLYKEY
- )
- sock = sslctx.wrap_socket(sock, server_side=True)
- sock.send(b'hello')
- assert sock.recv(1024) == b'world'
- sock.send(b'extra bytes')
- # sending EOF here
- sock.shutdown(socket.SHUT_WR)
- loop.call_soon_threadsafe(eof.set)
- # make sure we have enough time to reproduce the issue
- assert sock.recv(1024) == b''
- sock.close()
-
- class Protocol(asyncio.Protocol):
- def __init__(self):
- self.fut = asyncio.Future(loop=loop)
- self.transport = None
-
- def connection_made(self, transport):
- self.transport = transport
-
- def data_received(self, data):
- if data == b'hello':
- self.transport.write(b'world')
- # pause reading would make incoming data stay in the sslobj
- self.transport.pause_reading()
- else:
- nonlocal extra
- extra = data
-
- def connection_lost(self, exc):
- if exc is None:
- self.fut.set_result(None)
- else:
- self.fut.set_exception(exc)
-
- async def client(addr):
- ctx = self._create_client_ssl_context()
- tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
- await eof.wait()
- tr.resume_reading()
- await pr.fut
- tr.close()
- assert extra == b'extra bytes'
-
- with self.tcp_server(server) as srv:
- loop.run_until_complete(client(srv.addr))
-
-
-###############################################################################
-# Socket Testing Utilities
-###############################################################################
-
-
-class TestSocketWrapper:
-
- def __init__(self, sock):
- self.__sock = sock
-
- def recv_all(self, n):
- buf = b''
- while len(buf) < n:
- data = self.recv(n - len(buf))
- if data == b'':
- raise ConnectionAbortedError
- buf += data
- return buf
-
- def starttls(self, ssl_context, *,
- server_side=False,
- server_hostname=None,
- do_handshake_on_connect=True):
-
- assert isinstance(ssl_context, ssl.SSLContext)
-
- ssl_sock = ssl_context.wrap_socket(
- self.__sock, server_side=server_side,
- server_hostname=server_hostname,
- do_handshake_on_connect=do_handshake_on_connect)
-
- if server_side:
- ssl_sock.do_handshake()
-
- self.__sock.close()
- self.__sock = ssl_sock
-
- def __getattr__(self, name):
- return getattr(self.__sock, name)
-
- def __repr__(self):
- return '<{} {!r}>'.format(type(self).__name__, self.__sock)
-
-
-class SocketThread(threading.Thread):
-
- def stop(self):
- self._active = False
- self.join()
-
- def __enter__(self):
- self.start()
- return self
-
- def __exit__(self, *exc):
- self.stop()
-
-
-class TestThreadedClient(SocketThread):
-
- def __init__(self, test, sock, prog, timeout):
- threading.Thread.__init__(self, None, None, 'test-client')
- self.daemon = True
-
- self._timeout = timeout
- self._sock = sock
- self._active = True
- self._prog = prog
- self._test = test
-
- def run(self):
- try:
- self._prog(TestSocketWrapper(self._sock))
- except (KeyboardInterrupt, SystemExit):
- raise
- except BaseException as ex:
- self._test._abort_socket_test(ex)
-
-
-class TestThreadedServer(SocketThread):
-
- def __init__(self, test, sock, prog, timeout, max_clients):
- threading.Thread.__init__(self, None, None, 'test-server')
- self.daemon = True
-
- self._clients = 0
- self._finished_clients = 0
- self._max_clients = max_clients
- self._timeout = timeout
- self._sock = sock
- self._active = True
-
- self._prog = prog
-
- self._s1, self._s2 = socket.socketpair()
- self._s1.setblocking(False)
-
- self._test = test
-
- def stop(self):
- try:
- if self._s2 and self._s2.fileno() != -1:
- try:
- self._s2.send(b'stop')
- except OSError:
- pass
- finally:
- super().stop()
-
- def run(self):
- try:
- with self._sock:
- self._sock.setblocking(0)
- self._run()
- finally:
- self._s1.close()
- self._s2.close()
-
- def _run(self):
- while self._active:
- if self._clients >= self._max_clients:
- return
-
- r, w, x = select.select(
- [self._sock, self._s1], [], [], self._timeout)
-
- if self._s1 in r:
- return
-
- if self._sock in r:
- try:
- conn, addr = self._sock.accept()
- except BlockingIOError:
- continue
- except socket.timeout:
- if not self._active:
- return
- else:
- raise
- else:
- self._clients += 1
- conn.settimeout(self._timeout)
- try:
- with conn:
- self._handle_client(conn)
- except (KeyboardInterrupt, SystemExit):
- raise
- except BaseException as ex:
- self._active = False
- try:
- raise
- finally:
- self._test._abort_socket_test(ex)
-
- def _handle_client(self, sock):
- self._prog(TestSocketWrapper(sock))
-
- @property
- def addr(self):
- return self._sock.getsockname()
diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py
index 79a81bd..e87863e 100644
--- a/Lib/test/test_asyncio/test_sslproto.py
+++ b/Lib/test/test_asyncio/test_sslproto.py
@@ -15,6 +15,7 @@ import asyncio
from asyncio import log
from asyncio import protocols
from asyncio import sslproto
+from test import support
from test.test_asyncio import utils as test_utils
from test.test_asyncio import functional as func_tests
@@ -43,13 +44,16 @@ class SslProtoHandshakeTests(test_utils.TestCase):
def connection_made(self, ssl_proto, *, do_handshake=None):
transport = mock.Mock()
- sslobj = mock.Mock()
- # emulate reading decompressed data
- sslobj.read.side_effect = ssl.SSLWantReadError
- if do_handshake is not None:
- sslobj.do_handshake = do_handshake
- ssl_proto._sslobj = sslobj
- ssl_proto.connection_made(transport)
+ sslpipe = mock.Mock()
+ sslpipe.shutdown.return_value = b''
+ if do_handshake:
+ sslpipe.do_handshake.side_effect = do_handshake
+ else:
+ def mock_handshake(callback):
+ return []
+ sslpipe.do_handshake.side_effect = mock_handshake
+ with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
+ ssl_proto.connection_made(transport)
return transport
def test_handshake_timeout_zero(self):
@@ -71,10 +75,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
def test_eof_received_waiter(self):
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
- self.connection_made(
- ssl_proto,
- do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
- )
+ self.connection_made(ssl_proto)
ssl_proto.eof_received()
test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionResetError)
@@ -99,10 +100,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
# yield from waiter hang if lost_connection was called.
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
- self.connection_made(
- ssl_proto,
- do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
- )
+ self.connection_made(ssl_proto)
ssl_proto.connection_lost(ConnectionAbortedError)
test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionAbortedError)
@@ -112,10 +110,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
- transport = self.connection_made(
- ssl_proto,
- do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
- )
+ transport = self.connection_made(ssl_proto)
test_utils.run_briefly(self.loop)
ssl_proto._app_transport.close()
@@ -148,7 +143,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
transp.close()
# should not raise
- self.assertIsNone(ssl_proto.buffer_updated(5))
+ self.assertIsNone(ssl_proto.data_received(b'data'))
def test_write_after_closing(self):
ssl_proto = self.ssl_protocol()