From f111b3dcb414093a4efb9d74b69925e535ddc470 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Sat, 30 Dec 2017 00:35:36 -0500 Subject: bpo-23749: Implement loop.start_tls() (#5039) --- Doc/library/asyncio-eventloop.rst | 32 +++ Lib/asyncio/base_events.py | 45 +++- Lib/asyncio/events.py | 11 + Lib/asyncio/proactor_events.py | 2 + Lib/asyncio/selector_events.py | 2 + Lib/test/test_asyncio/functional.py | 279 +++++++++++++++++++++ Lib/test/test_asyncio/test_events.py | 67 ++--- Lib/test/test_asyncio/test_sslproto.py | 152 +++++++++++ Lib/test/test_asyncio/utils.py | 43 ++++ .../2017-12-29-00-44-42.bpo-23749.QL1Cxd.rst | 1 + 10 files changed, 580 insertions(+), 54 deletions(-) create mode 100644 Lib/test/test_asyncio/functional.py create mode 100644 Misc/NEWS.d/next/Library/2017-12-29-00-44-42.bpo-23749.QL1Cxd.rst diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst index 5dd258d..33b86d6 100644 --- a/Doc/library/asyncio-eventloop.rst +++ b/Doc/library/asyncio-eventloop.rst @@ -537,6 +537,38 @@ Creating listening connections .. versionadded:: 3.5.3 +TLS Upgrade +----------- + +.. coroutinemethod:: AbstractEventLoop.start_tls(transport, protocol, sslcontext, \*, server_side=False, server_hostname=None, ssl_handshake_timeout=None) + + Upgrades an existing connection to TLS. + + Returns a new transport instance, that the *protocol* must start using + immediately after the *await*. The *transport* instance passed to + the *start_tls* method should never be used again. + + Parameters: + + * *transport* and *protocol* instances that methods like + :meth:`~AbstractEventLoop.create_server` and + :meth:`~AbstractEventLoop.create_connection` return. + + * *sslcontext*: a configured instance of :class:`~ssl.SSLContext`. + + * *server_side* pass ``True`` when a server-side connection is being + upgraded (like the one created by :meth:`~AbstractEventLoop.create_server`). + + * *server_hostname*: sets or overrides the host name that the target + server's certificate will be matched against. + + * *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to + wait for the SSL handshake to complete before aborting the connection. + ``10.0`` seconds if ``None`` (default). + + .. versionadded:: 3.7 + + Watch file descriptors ---------------------- diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 96cc4f0..00831b3 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -29,9 +29,15 @@ import sys import warnings import weakref +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + from . import coroutines from . import events from . import futures +from . import sslproto from . import tasks from .log import logger @@ -279,7 +285,8 @@ class BaseEventLoop(events.AbstractEventLoop): self, rawsock, protocol, sslcontext, waiter=None, *, server_side=False, server_hostname=None, extra=None, server=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + call_connection_made=True): """Create SSL transport.""" raise NotImplementedError @@ -795,6 +802,42 @@ class BaseEventLoop(events.AbstractEventLoop): return transport, protocol + async def start_tls(self, transport, protocol, sslcontext, *, + server_side=False, + server_hostname=None, + ssl_handshake_timeout=None): + """Upgrade transport to TLS. + + Return a new transport that *protocol* should start using + immediately. + """ + if ssl is None: + raise RuntimeError('Python ssl module is not available') + + if not isinstance(sslcontext, ssl.SSLContext): + raise TypeError( + f'sslcontext is expected to be an instance of ssl.SSLContext, ' + f'got {sslcontext!r}') + + if not getattr(transport, '_start_tls_compatible', False): + raise TypeError( + f'transport {self!r} is not supported by start_tls()') + + waiter = self.create_future() + ssl_protocol = sslproto.SSLProtocol( + self, protocol, sslcontext, waiter, + server_side, server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + call_connection_made=False) + + transport.set_protocol(ssl_protocol) + self.call_soon(ssl_protocol.connection_made, transport) + if not transport.is_reading(): + self.call_soon(transport.resume_reading) + + await waiter + return ssl_protocol._app_transport + async def create_datagram_endpoint(self, protocol_factory, local_addr=None, remote_addr=None, *, family=0, proto=0, flags=0, diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index 3a5dbad..9496d5c 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -305,6 +305,17 @@ class AbstractEventLoop: """ raise NotImplementedError + async def start_tls(self, transport, protocol, sslcontext, *, + server_side=False, + server_hostname=None, + ssl_handshake_timeout=None): + """Upgrade a transport to TLS. + + Return a new transport that *protocol* should start using + immediately. + """ + raise NotImplementedError + async def create_unix_connection( self, protocol_factory, path=None, *, ssl=None, sock=None, diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index 2661cdd..ab1285b 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -223,6 +223,8 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, transports.WriteTransport): """Transport for write pipes.""" + _start_tls_compatible = True + def write(self, data): if not isinstance(data, (bytes, bytearray, memoryview)): raise TypeError( diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 1e4bd83..5692e38 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -694,6 +694,8 @@ class _SelectorTransport(transports._FlowControlMixin, class _SelectorSocketTransport(_SelectorTransport): + _start_tls_compatible = True + def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): super().__init__(loop, sock, protocol, extra, server) diff --git a/Lib/test/test_asyncio/functional.py b/Lib/test/test_asyncio/functional.py new file mode 100644 index 0000000..5fd174b --- /dev/null +++ b/Lib/test/test_asyncio/functional.py @@ -0,0 +1,279 @@ +import asyncio +import asyncio.events +import contextlib +import os +import pprint +import select +import socket +import ssl +import tempfile +import threading + + +class FunctionalTestCaseMixin: + + def new_loop(self): + return asyncio.new_event_loop() + + def run_loop_briefly(self, *, delay=0.01): + self.loop.run_until_complete(asyncio.sleep(delay, loop=self.loop)) + + def loop_exception_handler(self, loop, context): + self.__unhandled_exceptions.append(context) + self.loop.default_exception_handler(context) + + def setUp(self): + self.loop = self.new_loop() + asyncio.set_event_loop(None) + + self.loop.set_exception_handler(self.loop_exception_handler) + self.__unhandled_exceptions = [] + + # Disable `_get_running_loop`. + self._old_get_running_loop = asyncio.events._get_running_loop + asyncio.events._get_running_loop = lambda: None + + def tearDown(self): + try: + self.loop.close() + + if self.__unhandled_exceptions: + print('Unexpected calls to loop.call_exception_handler():') + pprint.pprint(self.__unhandled_exceptions) + self.fail('unexpected calls to loop.call_exception_handler()') + + finally: + asyncio.events._get_running_loop = self._old_get_running_loop + asyncio.set_event_loop(None) + self.loop = None + + def tcp_server(self, server_prog, *, + family=socket.AF_INET, + addr=None, + timeout=5, + backlog=1, + max_clients=10): + + if addr is None: + if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX: + with tempfile.NamedTemporaryFile() as tmp: + addr = tmp.name + else: + addr = ('127.0.0.1', 0) + + sock = socket.socket(family, socket.SOCK_STREAM) + + if timeout is None: + raise RuntimeError('timeout is required') + if timeout <= 0: + raise RuntimeError('only blocking sockets are supported') + sock.settimeout(timeout) + + try: + sock.bind(addr) + sock.listen(backlog) + except OSError as ex: + sock.close() + raise ex + + return TestThreadedServer( + self, sock, server_prog, timeout, max_clients) + + def tcp_client(self, client_prog, + family=socket.AF_INET, + timeout=10): + + sock = socket.socket(family, socket.SOCK_STREAM) + + if timeout is None: + raise RuntimeError('timeout is required') + if timeout <= 0: + raise RuntimeError('only blocking sockets are supported') + sock.settimeout(timeout) + + return TestThreadedClient( + self, sock, client_prog, timeout) + + def unix_server(self, *args, **kwargs): + if not hasattr(socket, 'AF_UNIX'): + raise NotImplementedError + return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs) + + def unix_client(self, *args, **kwargs): + if not hasattr(socket, 'AF_UNIX'): + raise NotImplementedError + return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs) + + @contextlib.contextmanager + def unix_sock_name(self): + with tempfile.TemporaryDirectory() as td: + fn = os.path.join(td, 'sock') + try: + yield fn + finally: + try: + os.unlink(fn) + except OSError: + pass + + def _abort_socket_test(self, ex): + try: + self.loop.stop() + finally: + self.fail(ex) + + +############################################################################## +# Socket Testing Utilities +############################################################################## + + +class TestSocketWrapper: + + def __init__(self, sock): + self.__sock = sock + + def recv_all(self, n): + buf = b'' + while len(buf) < n: + data = self.recv(n - len(buf)) + if data == b'': + raise ConnectionAbortedError + buf += data + return buf + + def start_tls(self, ssl_context, *, + server_side=False, + server_hostname=None): + + assert isinstance(ssl_context, ssl.SSLContext) + + ssl_sock = ssl_context.wrap_socket( + self.__sock, server_side=server_side, + server_hostname=server_hostname, + do_handshake_on_connect=False) + + ssl_sock.do_handshake() + + self.__sock.close() + self.__sock = ssl_sock + + def __getattr__(self, name): + return getattr(self.__sock, name) + + def __repr__(self): + return '<{} {!r}>'.format(type(self).__name__, self.__sock) + + +class SocketThread(threading.Thread): + + def stop(self): + self._active = False + self.join() + + def __enter__(self): + self.start() + return self + + def __exit__(self, *exc): + self.stop() + + +class TestThreadedClient(SocketThread): + + def __init__(self, test, sock, prog, timeout): + threading.Thread.__init__(self, None, None, 'test-client') + self.daemon = True + + self._timeout = timeout + self._sock = sock + self._active = True + self._prog = prog + self._test = test + + def run(self): + try: + self._prog(TestSocketWrapper(self._sock)) + except Exception as ex: + self._test._abort_socket_test(ex) + + +class TestThreadedServer(SocketThread): + + def __init__(self, test, sock, prog, timeout, max_clients): + threading.Thread.__init__(self, None, None, 'test-server') + self.daemon = True + + self._clients = 0 + self._finished_clients = 0 + self._max_clients = max_clients + self._timeout = timeout + self._sock = sock + self._active = True + + self._prog = prog + + self._s1, self._s2 = socket.socketpair() + self._s1.setblocking(False) + + self._test = test + + def stop(self): + try: + if self._s2 and self._s2.fileno() != -1: + try: + self._s2.send(b'stop') + except OSError: + pass + finally: + super().stop() + + def run(self): + try: + with self._sock: + self._sock.setblocking(0) + self._run() + finally: + self._s1.close() + self._s2.close() + + def _run(self): + while self._active: + if self._clients >= self._max_clients: + return + + r, w, x = select.select( + [self._sock, self._s1], [], [], self._timeout) + + if self._s1 in r: + return + + if self._sock in r: + try: + conn, addr = self._sock.accept() + except BlockingIOError: + continue + except socket.timeout: + if not self._active: + return + else: + raise + else: + self._clients += 1 + conn.settimeout(self._timeout) + try: + with conn: + self._handle_client(conn) + except Exception as ex: + self._active = False + try: + raise + finally: + self._test._abort_socket_test(ex) + + def _handle_client(self, sock): + self._prog(TestSocketWrapper(sock)) + + @property + def addr(self): + return self._sock.getsockname() diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index 79e8d79..da2e036 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -31,21 +31,7 @@ from asyncio import events from asyncio import proactor_events from asyncio import selector_events from test.test_asyncio import utils as test_utils -try: - from test import support -except ImportError: - from asyncio import test_support as support - - -def data_file(filename): - if hasattr(support, 'TEST_HOME_DIR'): - fullname = os.path.join(support.TEST_HOME_DIR, filename) - if os.path.isfile(fullname): - return fullname - fullname = os.path.join(os.path.dirname(__file__), filename) - if os.path.isfile(fullname): - return fullname - raise FileNotFoundError(filename) +from test import support def osx_tiger(): @@ -80,23 +66,6 @@ class CoroLike: pass -ONLYCERT = data_file('ssl_cert.pem') -ONLYKEY = data_file('ssl_key.pem') -SIGNED_CERTFILE = data_file('keycert3.pem') -SIGNING_CA = data_file('pycacert.pem') -PEERCERT = {'serialNumber': 'B09264B1F2DA21D1', - 'version': 1, - 'subject': ((('countryName', 'XY'),), - (('localityName', 'Castle Anthrax'),), - (('organizationName', 'Python Software Foundation'),), - (('commonName', 'localhost'),)), - 'issuer': ((('countryName', 'XY'),), - (('organizationName', 'Python Software Foundation CA'),), - (('commonName', 'our-ca-server'),)), - 'notAfter': 'Nov 13 19:47:07 2022 GMT', - 'notBefore': 'Jan 4 19:47:07 2013 GMT'} - - class MyBaseProto(asyncio.Protocol): connected = None done = None @@ -853,16 +822,8 @@ class EventLoopTestsMixin: 'SSL not supported with proactor event loops before Python 3.5' ) - server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - server_context.load_cert_chain(ONLYCERT, ONLYKEY) - if hasattr(server_context, 'check_hostname'): - server_context.check_hostname = False - server_context.verify_mode = ssl.CERT_NONE - - client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - if hasattr(server_context, 'check_hostname'): - client_context.check_hostname = False - client_context.verify_mode = ssl.CERT_NONE + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() self.test_connect_accepted_socket(server_context, client_context) @@ -1048,7 +1009,7 @@ class EventLoopTestsMixin: def test_create_server_ssl(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( - lambda: proto, ONLYCERT, ONLYKEY) + lambda: proto, test_utils.ONLYCERT, test_utils.ONLYKEY) f_c = self.loop.create_connection(MyBaseProto, host, port, ssl=test_utils.dummy_ssl_context()) @@ -1081,7 +1042,7 @@ class EventLoopTestsMixin: def test_create_unix_server_ssl(self): proto = MyProto(loop=self.loop) server, path = self._make_ssl_unix_server( - lambda: proto, ONLYCERT, ONLYKEY) + lambda: proto, test_utils.ONLYCERT, test_utils.ONLYKEY) f_c = self.loop.create_unix_connection( MyBaseProto, path, ssl=test_utils.dummy_ssl_context(), @@ -1111,7 +1072,7 @@ class EventLoopTestsMixin: def test_create_server_ssl_verify_failed(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( - lambda: proto, SIGNED_CERTFILE) + lambda: proto, test_utils.SIGNED_CERTFILE) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) sslcontext_client.options |= ssl.OP_NO_SSLv2 @@ -1141,7 +1102,7 @@ class EventLoopTestsMixin: def test_create_unix_server_ssl_verify_failed(self): proto = MyProto(loop=self.loop) server, path = self._make_ssl_unix_server( - lambda: proto, SIGNED_CERTFILE) + lambda: proto, test_utils.SIGNED_CERTFILE) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) sslcontext_client.options |= ssl.OP_NO_SSLv2 @@ -1170,13 +1131,13 @@ class EventLoopTestsMixin: def test_create_server_ssl_match_failed(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( - lambda: proto, SIGNED_CERTFILE) + lambda: proto, test_utils.SIGNED_CERTFILE) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) sslcontext_client.options |= ssl.OP_NO_SSLv2 sslcontext_client.verify_mode = ssl.CERT_REQUIRED sslcontext_client.load_verify_locations( - cafile=SIGNING_CA) + cafile=test_utils.SIGNING_CA) if hasattr(sslcontext_client, 'check_hostname'): sslcontext_client.check_hostname = True @@ -1199,12 +1160,12 @@ class EventLoopTestsMixin: def test_create_unix_server_ssl_verified(self): proto = MyProto(loop=self.loop) server, path = self._make_ssl_unix_server( - lambda: proto, SIGNED_CERTFILE) + lambda: proto, test_utils.SIGNED_CERTFILE) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) sslcontext_client.options |= ssl.OP_NO_SSLv2 sslcontext_client.verify_mode = ssl.CERT_REQUIRED - sslcontext_client.load_verify_locations(cafile=SIGNING_CA) + sslcontext_client.load_verify_locations(cafile=test_utils.SIGNING_CA) if hasattr(sslcontext_client, 'check_hostname'): sslcontext_client.check_hostname = True @@ -1224,12 +1185,12 @@ class EventLoopTestsMixin: def test_create_server_ssl_verified(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( - lambda: proto, SIGNED_CERTFILE) + lambda: proto, test_utils.SIGNED_CERTFILE) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) sslcontext_client.options |= ssl.OP_NO_SSLv2 sslcontext_client.verify_mode = ssl.CERT_REQUIRED - sslcontext_client.load_verify_locations(cafile=SIGNING_CA) + sslcontext_client.load_verify_locations(cafile=test_utils.SIGNING_CA) if hasattr(sslcontext_client, 'check_hostname'): sslcontext_client.check_hostname = True @@ -1241,7 +1202,7 @@ class EventLoopTestsMixin: # extra info is available self.check_ssl_extra_info(client,peername=(host, port), - peercert=PEERCERT) + peercert=test_utils.PEERCERT) # close connection proto.transport.close() diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index a7498e8..886c5cf 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -13,6 +13,7 @@ from asyncio import log from asyncio import sslproto from asyncio import tasks from test.test_asyncio import utils as test_utils +from test.test_asyncio import functional as func_tests @unittest.skipIf(ssl is None, 'No ssl module') @@ -158,5 +159,156 @@ class SslProtoHandshakeTests(test_utils.TestCase): self.assertIs(ssl_proto._app_protocol, new_app_proto) +############################################################################## +# Start TLS Tests +############################################################################## + + +class BaseStartTLS(func_tests.FunctionalTestCaseMixin): + + def new_loop(self): + raise NotImplementedError + + def test_start_tls_client_1(self): + HELLO_MSG = b'1' * 1024 * 1024 * 5 + + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + + def serve(sock): + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.start_tls(server_context, server_side=True) + + sock.sendall(b'O') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + sock.close() + + class ClientProto(asyncio.Protocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + + def connection_made(proto, tr): + proto.con_made_cnt += 1 + # Ensure connection_made gets called only once. + self.assertEqual(proto.con_made_cnt, 1) + + def data_received(self, data): + self.on_data.set_result(data) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + on_data = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProto(on_data, on_eof), *addr) + + tr.write(HELLO_MSG) + new_tr = await self.loop.start_tls(tr, proto, client_context) + + self.assertEqual(await on_data, b'O') + new_tr.write(HELLO_MSG) + await on_eof + + new_tr.close() + + with self.tcp_server(serve) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10)) + + def test_start_tls_server_1(self): + HELLO_MSG = b'1' * 1024 * 1024 * 5 + + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + + def client(sock, addr): + sock.connect(addr) + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.start_tls(client_context) + sock.sendall(HELLO_MSG) + sock.close() + + class ServerProto(asyncio.Protocol): + def __init__(self, on_con, on_eof): + self.on_con = on_con + self.on_eof = on_eof + self.data = b'' + + def connection_made(self, tr): + self.on_con.set_result(tr) + + def data_received(self, data): + self.data += data + + def eof_received(self): + self.on_eof.set_result(1) + + async def main(): + tr = await on_con + tr.write(HELLO_MSG) + + self.assertEqual(proto.data, b'') + + new_tr = await self.loop.start_tls( + tr, proto, server_context, + server_side=True) + + await on_eof + self.assertEqual(proto.data, HELLO_MSG) + new_tr.close() + + server.close() + await server.wait_closed() + + on_con = self.loop.create_future() + on_eof = self.loop.create_future() + proto = ServerProto(on_con, on_eof) + + server = self.loop.run_until_complete( + self.loop.create_server( + lambda: proto, '127.0.0.1', 0)) + addr = server.sockets[0].getsockname() + + with self.tcp_client(lambda sock: client(sock, addr)): + self.loop.run_until_complete( + asyncio.wait_for(main(), loop=self.loop, timeout=10)) + + def test_start_tls_wrong_args(self): + async def main(): + with self.assertRaisesRegex(TypeError, 'SSLContext, got'): + await self.loop.start_tls(None, None, None) + + sslctx = test_utils.simple_server_sslcontext() + with self.assertRaisesRegex(TypeError, 'is not supported'): + await self.loop.start_tls(None, None, sslctx) + + self.loop.run_until_complete(main()) + + +@unittest.skipIf(ssl is None, 'No ssl module') +class SelectorStartTLS(BaseStartTLS, unittest.TestCase): + + def new_loop(self): + return asyncio.SelectorEventLoop() + + +@unittest.skipIf(ssl is None, 'No ssl module') +@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only') +class ProactorStartTLS(BaseStartTLS, unittest.TestCase): + + def new_loop(self): + return asyncio.ProactorEventLoop() + + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_asyncio/utils.py b/Lib/test/test_asyncio/utils.py index eaafe3a..a78e019 100644 --- a/Lib/test/test_asyncio/utils.py +++ b/Lib/test/test_asyncio/utils.py @@ -35,6 +35,49 @@ from asyncio.log import logger from test import support +def data_file(filename): + if hasattr(support, 'TEST_HOME_DIR'): + fullname = os.path.join(support.TEST_HOME_DIR, filename) + if os.path.isfile(fullname): + return fullname + fullname = os.path.join(os.path.dirname(__file__), filename) + if os.path.isfile(fullname): + return fullname + raise FileNotFoundError(filename) + + +ONLYCERT = data_file('ssl_cert.pem') +ONLYKEY = data_file('ssl_key.pem') +SIGNED_CERTFILE = data_file('keycert3.pem') +SIGNING_CA = data_file('pycacert.pem') +PEERCERT = {'serialNumber': 'B09264B1F2DA21D1', + 'version': 1, + 'subject': ((('countryName', 'XY'),), + (('localityName', 'Castle Anthrax'),), + (('organizationName', 'Python Software Foundation'),), + (('commonName', 'localhost'),)), + 'issuer': ((('countryName', 'XY'),), + (('organizationName', 'Python Software Foundation CA'),), + (('commonName', 'our-ca-server'),)), + 'notAfter': 'Nov 13 19:47:07 2022 GMT', + 'notBefore': 'Jan 4 19:47:07 2013 GMT'} + + +def simple_server_sslcontext(): + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_context.load_cert_chain(ONLYCERT, ONLYKEY) + server_context.check_hostname = False + server_context.verify_mode = ssl.CERT_NONE + return server_context + + +def simple_client_sslcontext(): + client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + client_context.check_hostname = False + client_context.verify_mode = ssl.CERT_NONE + return client_context + + def dummy_ssl_context(): if ssl is None: return None diff --git a/Misc/NEWS.d/next/Library/2017-12-29-00-44-42.bpo-23749.QL1Cxd.rst b/Misc/NEWS.d/next/Library/2017-12-29-00-44-42.bpo-23749.QL1Cxd.rst new file mode 100644 index 0000000..d6de1fe --- /dev/null +++ b/Misc/NEWS.d/next/Library/2017-12-29-00-44-42.bpo-23749.QL1Cxd.rst @@ -0,0 +1 @@ +asyncio: Implement loop.start_tls() -- cgit v0.12