summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_asyncio/test_sslproto.py
diff options
context:
space:
mode:
authorYury Selivanov <yury@magic.io>2017-12-30 05:35:36 (GMT)
committerGitHub <noreply@github.com>2017-12-30 05:35:36 (GMT)
commitf111b3dcb414093a4efb9d74b69925e535ddc470 (patch)
tree9905a970a809f7f14cb378b5b90f1f9d06aebbeb /Lib/test/test_asyncio/test_sslproto.py
parentbbdb17d19bb1d5443ca4417254e014ad64c04540 (diff)
downloadcpython-f111b3dcb414093a4efb9d74b69925e535ddc470.zip
cpython-f111b3dcb414093a4efb9d74b69925e535ddc470.tar.gz
cpython-f111b3dcb414093a4efb9d74b69925e535ddc470.tar.bz2
bpo-23749: Implement loop.start_tls() (#5039)
Diffstat (limited to 'Lib/test/test_asyncio/test_sslproto.py')
-rw-r--r--Lib/test/test_asyncio/test_sslproto.py152
1 files changed, 152 insertions, 0 deletions
diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py
index a7498e8..886c5cf 100644
--- a/Lib/test/test_asyncio/test_sslproto.py
+++ b/Lib/test/test_asyncio/test_sslproto.py
@@ -13,6 +13,7 @@ from asyncio import log
from asyncio import sslproto
from asyncio import tasks
from test.test_asyncio import utils as test_utils
+from test.test_asyncio import functional as func_tests
@unittest.skipIf(ssl is None, 'No ssl module')
@@ -158,5 +159,156 @@ class SslProtoHandshakeTests(test_utils.TestCase):
self.assertIs(ssl_proto._app_protocol, new_app_proto)
+##############################################################################
+# Start TLS Tests
+##############################################################################
+
+
+class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
+
+ def new_loop(self):
+ raise NotImplementedError
+
+ def test_start_tls_client_1(self):
+ HELLO_MSG = b'1' * 1024 * 1024 * 5
+
+ server_context = test_utils.simple_server_sslcontext()
+ client_context = test_utils.simple_client_sslcontext()
+
+ def serve(sock):
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.start_tls(server_context, server_side=True)
+
+ sock.sendall(b'O')
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+ sock.close()
+
+ class ClientProto(asyncio.Protocol):
+ def __init__(self, on_data, on_eof):
+ self.on_data = on_data
+ self.on_eof = on_eof
+ self.con_made_cnt = 0
+
+ def connection_made(proto, tr):
+ proto.con_made_cnt += 1
+ # Ensure connection_made gets called only once.
+ self.assertEqual(proto.con_made_cnt, 1)
+
+ def data_received(self, data):
+ self.on_data.set_result(data)
+
+ def eof_received(self):
+ self.on_eof.set_result(True)
+
+ async def client(addr):
+ on_data = self.loop.create_future()
+ on_eof = self.loop.create_future()
+
+ tr, proto = await self.loop.create_connection(
+ lambda: ClientProto(on_data, on_eof), *addr)
+
+ tr.write(HELLO_MSG)
+ new_tr = await self.loop.start_tls(tr, proto, client_context)
+
+ self.assertEqual(await on_data, b'O')
+ new_tr.write(HELLO_MSG)
+ await on_eof
+
+ new_tr.close()
+
+ with self.tcp_server(serve) as srv:
+ self.loop.run_until_complete(
+ asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
+
+ def test_start_tls_server_1(self):
+ HELLO_MSG = b'1' * 1024 * 1024 * 5
+
+ server_context = test_utils.simple_server_sslcontext()
+ client_context = test_utils.simple_client_sslcontext()
+
+ def client(sock, addr):
+ sock.connect(addr)
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
+ sock.start_tls(client_context)
+ sock.sendall(HELLO_MSG)
+ sock.close()
+
+ class ServerProto(asyncio.Protocol):
+ def __init__(self, on_con, on_eof):
+ self.on_con = on_con
+ self.on_eof = on_eof
+ self.data = b''
+
+ def connection_made(self, tr):
+ self.on_con.set_result(tr)
+
+ def data_received(self, data):
+ self.data += data
+
+ def eof_received(self):
+ self.on_eof.set_result(1)
+
+ async def main():
+ tr = await on_con
+ tr.write(HELLO_MSG)
+
+ self.assertEqual(proto.data, b'')
+
+ new_tr = await self.loop.start_tls(
+ tr, proto, server_context,
+ server_side=True)
+
+ await on_eof
+ self.assertEqual(proto.data, HELLO_MSG)
+ new_tr.close()
+
+ server.close()
+ await server.wait_closed()
+
+ on_con = self.loop.create_future()
+ on_eof = self.loop.create_future()
+ proto = ServerProto(on_con, on_eof)
+
+ server = self.loop.run_until_complete(
+ self.loop.create_server(
+ lambda: proto, '127.0.0.1', 0))
+ addr = server.sockets[0].getsockname()
+
+ with self.tcp_client(lambda sock: client(sock, addr)):
+ self.loop.run_until_complete(
+ asyncio.wait_for(main(), loop=self.loop, timeout=10))
+
+ def test_start_tls_wrong_args(self):
+ async def main():
+ with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
+ await self.loop.start_tls(None, None, None)
+
+ sslctx = test_utils.simple_server_sslcontext()
+ with self.assertRaisesRegex(TypeError, 'is not supported'):
+ await self.loop.start_tls(None, None, sslctx)
+
+ self.loop.run_until_complete(main())
+
+
+@unittest.skipIf(ssl is None, 'No ssl module')
+class SelectorStartTLS(BaseStartTLS, unittest.TestCase):
+
+ def new_loop(self):
+ return asyncio.SelectorEventLoop()
+
+
+@unittest.skipIf(ssl is None, 'No ssl module')
+@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
+class ProactorStartTLS(BaseStartTLS, unittest.TestCase):
+
+ def new_loop(self):
+ return asyncio.ProactorEventLoop()
+
+
if __name__ == '__main__':
unittest.main()