diff options
author | Yury Selivanov <yury@magic.io> | 2017-12-30 05:35:36 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-12-30 05:35:36 (GMT) |
commit | f111b3dcb414093a4efb9d74b69925e535ddc470 (patch) | |
tree | 9905a970a809f7f14cb378b5b90f1f9d06aebbeb /Lib/test/test_asyncio/test_sslproto.py | |
parent | bbdb17d19bb1d5443ca4417254e014ad64c04540 (diff) | |
download | cpython-f111b3dcb414093a4efb9d74b69925e535ddc470.zip cpython-f111b3dcb414093a4efb9d74b69925e535ddc470.tar.gz cpython-f111b3dcb414093a4efb9d74b69925e535ddc470.tar.bz2 |
bpo-23749: Implement loop.start_tls() (#5039)
Diffstat (limited to 'Lib/test/test_asyncio/test_sslproto.py')
-rw-r--r-- | Lib/test/test_asyncio/test_sslproto.py | 152 |
1 files changed, 152 insertions, 0 deletions
diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index a7498e8..886c5cf 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -13,6 +13,7 @@ from asyncio import log from asyncio import sslproto from asyncio import tasks from test.test_asyncio import utils as test_utils +from test.test_asyncio import functional as func_tests @unittest.skipIf(ssl is None, 'No ssl module') @@ -158,5 +159,156 @@ class SslProtoHandshakeTests(test_utils.TestCase): self.assertIs(ssl_proto._app_protocol, new_app_proto) +############################################################################## +# Start TLS Tests +############################################################################## + + +class BaseStartTLS(func_tests.FunctionalTestCaseMixin): + + def new_loop(self): + raise NotImplementedError + + def test_start_tls_client_1(self): + HELLO_MSG = b'1' * 1024 * 1024 * 5 + + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + + def serve(sock): + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.start_tls(server_context, server_side=True) + + sock.sendall(b'O') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + sock.close() + + class ClientProto(asyncio.Protocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + + def connection_made(proto, tr): + proto.con_made_cnt += 1 + # Ensure connection_made gets called only once. + self.assertEqual(proto.con_made_cnt, 1) + + def data_received(self, data): + self.on_data.set_result(data) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + on_data = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProto(on_data, on_eof), *addr) + + tr.write(HELLO_MSG) + new_tr = await self.loop.start_tls(tr, proto, client_context) + + self.assertEqual(await on_data, b'O') + new_tr.write(HELLO_MSG) + await on_eof + + new_tr.close() + + with self.tcp_server(serve) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10)) + + def test_start_tls_server_1(self): + HELLO_MSG = b'1' * 1024 * 1024 * 5 + + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + + def client(sock, addr): + sock.connect(addr) + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.start_tls(client_context) + sock.sendall(HELLO_MSG) + sock.close() + + class ServerProto(asyncio.Protocol): + def __init__(self, on_con, on_eof): + self.on_con = on_con + self.on_eof = on_eof + self.data = b'' + + def connection_made(self, tr): + self.on_con.set_result(tr) + + def data_received(self, data): + self.data += data + + def eof_received(self): + self.on_eof.set_result(1) + + async def main(): + tr = await on_con + tr.write(HELLO_MSG) + + self.assertEqual(proto.data, b'') + + new_tr = await self.loop.start_tls( + tr, proto, server_context, + server_side=True) + + await on_eof + self.assertEqual(proto.data, HELLO_MSG) + new_tr.close() + + server.close() + await server.wait_closed() + + on_con = self.loop.create_future() + on_eof = self.loop.create_future() + proto = ServerProto(on_con, on_eof) + + server = self.loop.run_until_complete( + self.loop.create_server( + lambda: proto, '127.0.0.1', 0)) + addr = server.sockets[0].getsockname() + + with self.tcp_client(lambda sock: client(sock, addr)): + self.loop.run_until_complete( + asyncio.wait_for(main(), loop=self.loop, timeout=10)) + + def test_start_tls_wrong_args(self): + async def main(): + with self.assertRaisesRegex(TypeError, 'SSLContext, got'): + await self.loop.start_tls(None, None, None) + + sslctx = test_utils.simple_server_sslcontext() + with self.assertRaisesRegex(TypeError, 'is not supported'): + await self.loop.start_tls(None, None, sslctx) + + self.loop.run_until_complete(main()) + + +@unittest.skipIf(ssl is None, 'No ssl module') +class SelectorStartTLS(BaseStartTLS, unittest.TestCase): + + def new_loop(self): + return asyncio.SelectorEventLoop() + + +@unittest.skipIf(ssl is None, 'No ssl module') +@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only') +class ProactorStartTLS(BaseStartTLS, unittest.TestCase): + + def new_loop(self): + return asyncio.ProactorEventLoop() + + if __name__ == '__main__': unittest.main() |