diff options
Diffstat (limited to 'Lib/test/test_asyncio/test_sslproto.py')
-rw-r--r-- | Lib/test/test_asyncio/test_sslproto.py | 177 |
1 files changed, 152 insertions, 25 deletions
diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index c534a34..932487a 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -1,8 +1,7 @@ """Tests for asyncio/sslproto.py.""" -import os import logging -import time +import socket import unittest from unittest import mock try: @@ -185,17 +184,67 @@ class SslProtoHandshakeTests(test_utils.TestCase): class BaseStartTLS(func_tests.FunctionalTestCaseMixin): + PAYLOAD_SIZE = 1024 * 100 + TIMEOUT = 60 + def new_loop(self): raise NotImplementedError - def test_start_tls_client_1(self): - HELLO_MSG = b'1' * 1024 * 1024 + def test_buf_feed_data(self): + + class Proto(asyncio.BufferedProtocol): + + def __init__(self, bufsize, usemv): + self.buf = bytearray(bufsize) + self.mv = memoryview(self.buf) + self.data = b'' + self.usemv = usemv + + def get_buffer(self, sizehint): + if self.usemv: + return self.mv + else: + return self.buf + + def buffer_updated(self, nsize): + if self.usemv: + self.data += self.mv[:nsize] + else: + self.data += self.buf[:nsize] + + for usemv in [False, True]: + proto = Proto(1, usemv) + sslproto._feed_data_to_bufferred_proto(proto, b'12345') + self.assertEqual(proto.data, b'12345') + + proto = Proto(2, usemv) + sslproto._feed_data_to_bufferred_proto(proto, b'12345') + self.assertEqual(proto.data, b'12345') + + proto = Proto(2, usemv) + sslproto._feed_data_to_bufferred_proto(proto, b'1234') + self.assertEqual(proto.data, b'1234') + + proto = Proto(4, usemv) + sslproto._feed_data_to_bufferred_proto(proto, b'1234') + self.assertEqual(proto.data, b'1234') + + proto = Proto(100, usemv) + sslproto._feed_data_to_bufferred_proto(proto, b'12345') + self.assertEqual(proto.data, b'12345') + + proto = Proto(0, usemv) + with self.assertRaisesRegex(RuntimeError, 'empty buffer'): + sslproto._feed_data_to_bufferred_proto(proto, b'12345') + + def test_start_tls_client_reg_proto_1(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE server_context = test_utils.simple_server_sslcontext() client_context = test_utils.simple_client_sslcontext() def serve(sock): - sock.settimeout(5) + sock.settimeout(self.TIMEOUT) data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) @@ -205,6 +254,8 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin): sock.sendall(b'O') data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) + + sock.shutdown(socket.SHUT_RDWR) sock.close() class ClientProto(asyncio.Protocol): @@ -246,17 +297,80 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin): self.loop.run_until_complete( asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10)) + def test_start_tls_client_buf_proto_1(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + + def serve(sock): + sock.settimeout(self.TIMEOUT) + + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.start_tls(server_context, server_side=True) + + sock.sendall(b'O') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.shutdown(socket.SHUT_RDWR) + sock.close() + + class ClientProto(asyncio.BufferedProtocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + self.buf = bytearray(1) + + def connection_made(proto, tr): + proto.con_made_cnt += 1 + # Ensure connection_made gets called only once. + self.assertEqual(proto.con_made_cnt, 1) + + def get_buffer(self, sizehint): + return self.buf + + def buffer_updated(self, nsize): + assert nsize == 1 + self.on_data.set_result(bytes(self.buf[:nsize])) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + await asyncio.sleep(0.5, loop=self.loop) + + on_data = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProto(on_data, on_eof), *addr) + + tr.write(HELLO_MSG) + new_tr = await self.loop.start_tls(tr, proto, client_context) + + self.assertEqual(await on_data, b'O') + new_tr.write(HELLO_MSG) + await on_eof + + new_tr.close() + + with self.tcp_server(serve) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), + loop=self.loop, timeout=self.TIMEOUT)) + def test_start_tls_server_1(self): - HELLO_MSG = b'1' * 1024 * 1024 + HELLO_MSG = b'1' * self.PAYLOAD_SIZE server_context = test_utils.simple_server_sslcontext() client_context = test_utils.simple_client_sslcontext() - # TODO: fix TLSv1.3 support - client_context.options |= ssl.OP_NO_TLSv1_3 def client(sock, addr): - time.sleep(0.5) - sock.settimeout(5) + sock.settimeout(self.TIMEOUT) sock.connect(addr) data = sock.recv_all(len(HELLO_MSG)) @@ -264,12 +378,15 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin): sock.start_tls(client_context) sock.sendall(HELLO_MSG) + + sock.shutdown(socket.SHUT_RDWR) sock.close() class ServerProto(asyncio.Protocol): - def __init__(self, on_con, on_eof): + def __init__(self, on_con, on_eof, on_con_lost): self.on_con = on_con self.on_eof = on_eof + self.on_con_lost = on_con_lost self.data = b'' def connection_made(self, tr): @@ -281,7 +398,13 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin): def eof_received(self): self.on_eof.set_result(1) - async def main(): + def connection_lost(self, exc): + if exc is None: + self.on_con_lost.set_result(None) + else: + self.on_con_lost.set_exception(exc) + + async def main(proto, on_con, on_eof, on_con_lost): tr = await on_con tr.write(HELLO_MSG) @@ -292,24 +415,29 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin): server_side=True) await on_eof + await on_con_lost self.assertEqual(proto.data, HELLO_MSG) new_tr.close() - server.close() - await server.wait_closed() + async def run_main(): + on_con = self.loop.create_future() + on_eof = self.loop.create_future() + on_con_lost = self.loop.create_future() + proto = ServerProto(on_con, on_eof, on_con_lost) - on_con = self.loop.create_future() - on_eof = self.loop.create_future() - proto = ServerProto(on_con, on_eof) + server = await self.loop.create_server( + lambda: proto, '127.0.0.1', 0) + addr = server.sockets[0].getsockname() - server = self.loop.run_until_complete( - self.loop.create_server( - lambda: proto, '127.0.0.1', 0)) - addr = server.sockets[0].getsockname() + with self.tcp_client(lambda sock: client(sock, addr)): + await asyncio.wait_for( + main(proto, on_con, on_eof, on_con_lost), + loop=self.loop, timeout=self.TIMEOUT) - with self.tcp_client(lambda sock: client(sock, addr)): - self.loop.run_until_complete( - asyncio.wait_for(main(), loop=self.loop, timeout=10)) + server.close() + await server.wait_closed() + + self.loop.run_until_complete(run_main()) def test_start_tls_wrong_args(self): async def main(): @@ -332,7 +460,6 @@ class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase): @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only') -@unittest.skipIf(os.environ.get('APPVEYOR'), 'XXX: issue 32458') class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase): def new_loop(self): |