diff options
Diffstat (limited to 'Lib/test/test_asyncio/test_events.py')
-rw-r--r-- | Lib/test/test_asyncio/test_events.py | 303 |
1 files changed, 300 insertions, 3 deletions
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index cf21753..0981bd6 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -26,6 +26,7 @@ if sys.platform != 'win32': import tty import asyncio +from asyncio import base_events from asyncio import coroutines from asyncio import events from asyncio import proactor_events @@ -2090,14 +2091,308 @@ class SubprocessTestsMixin: self.loop.run_until_complete(connect(shell=False)) +class MySendfileProto(MyBaseProto): + + def __init__(self, loop=None, close_after=0): + super().__init__(loop) + self.data = bytearray() + self.close_after = close_after + + def data_received(self, data): + self.data.extend(data) + super().data_received(data) + if self.close_after and self.nbytes >= self.close_after: + self.transport.close() + + +class SendfileMixin: + # Note: sendfile via SSL transport is equal to sendfile fallback + + DATA = b"12345abcde" * 160 * 1024 # 160 KiB + + @classmethod + def setUpClass(cls): + with open(support.TESTFN, 'wb') as fp: + fp.write(cls.DATA) + super().setUpClass() + + @classmethod + def tearDownClass(cls): + support.unlink(support.TESTFN) + super().tearDownClass() + + def setUp(self): + self.file = open(support.TESTFN, 'rb') + self.addCleanup(self.file.close) + super().setUp() + + def run_loop(self, coro): + return self.loop.run_until_complete(coro) + + def prepare(self, *, is_ssl=False, close_after=0): + port = support.find_unused_port() + srv_proto = MySendfileProto(loop=self.loop, close_after=close_after) + if is_ssl: + srv_ctx = test_utils.simple_server_sslcontext() + cli_ctx = test_utils.simple_client_sslcontext() + else: + srv_ctx = None + cli_ctx = None + srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # reduce recv socket buffer size to test on relative small data sets + srv_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024) + srv_sock.bind((support.HOST, port)) + server = self.run_loop(self.loop.create_server( + lambda: srv_proto, sock=srv_sock, ssl=srv_ctx)) + + if is_ssl: + server_hostname = support.HOST + else: + server_hostname = None + cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # reduce send socket buffer size to test on relative small data sets + cli_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) + cli_sock.connect((support.HOST, port)) + cli_proto = MySendfileProto(loop=self.loop) + tr, pr = self.run_loop(self.loop.create_connection( + lambda: cli_proto, sock=cli_sock, + ssl=cli_ctx, server_hostname=server_hostname)) + + def cleanup(): + srv_proto.transport.close() + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.run_loop(cli_proto.done) + + server.close() + self.run_loop(server.wait_closed()) + + self.addCleanup(cleanup) + return srv_proto, cli_proto + + @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported") + def test_sendfile_not_supported(self): + tr, pr = self.run_loop( + self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + family=socket.AF_INET)) + try: + with self.assertRaisesRegex(RuntimeError, "not supported"): + self.run_loop( + self.loop.sendfile(tr, self.file)) + self.assertEqual(0, self.file.tell()) + finally: + # don't use self.addCleanup because it produces resource warning + tr.close() + + def test_sendfile(self): + srv_proto, cli_proto = self.prepare() + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_force_fallback(self): + srv_proto, cli_proto = self.prepare() + + def sendfile_native(transp, file, offset, count): + # to raise SendfileNotAvailableError + return base_events.BaseEventLoop._sendfile_native( + self.loop, transp, file, offset, count) + + self.loop._sendfile_native = sendfile_native + + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_force_unsupported_native(self): + if sys.platform == 'win32': + if isinstance(self.loop, asyncio.ProactorEventLoop): + self.skipTest("Fails on proactor event loop") + srv_proto, cli_proto = self.prepare() + + def sendfile_native(transp, file, offset, count): + # to raise SendfileNotAvailableError + return base_events.BaseEventLoop._sendfile_native( + self.loop, transp, file, offset, count) + + self.loop._sendfile_native = sendfile_native + + with self.assertRaisesRegex(events.SendfileNotAvailableError, + "not supported"): + self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file, + fallback=False)) + + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(srv_proto.nbytes, 0) + self.assertEqual(self.file.tell(), 0) + + def test_sendfile_ssl(self): + srv_proto, cli_proto = self.prepare(is_ssl=True) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_for_closing_transp(self): + srv_proto, cli_proto = self.prepare() + cli_proto.transport.close() + with self.assertRaisesRegex(RuntimeError, "is closing"): + self.run_loop(self.loop.sendfile(cli_proto.transport, self.file)) + self.run_loop(srv_proto.done) + self.assertEqual(srv_proto.nbytes, 0) + self.assertEqual(self.file.tell(), 0) + + def test_sendfile_pre_and_post_data(self): + srv_proto, cli_proto = self.prepare() + PREFIX = b'zxcvbnm' * 1024 + SUFFIX = b'0987654321' * 1024 + cli_proto.transport.write(PREFIX) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.write(SUFFIX) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_ssl_pre_and_post_data(self): + srv_proto, cli_proto = self.prepare(is_ssl=True) + PREFIX = b'zxcvbnm' * 1024 + SUFFIX = b'0987654321' * 1024 + cli_proto.transport.write(PREFIX) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.write(SUFFIX) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_partial(self): + srv_proto, cli_proto = self.prepare() + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, 100) + self.assertEqual(srv_proto.nbytes, 100) + self.assertEqual(srv_proto.data, self.DATA[1000:1100]) + self.assertEqual(self.file.tell(), 1100) + + def test_sendfile_ssl_partial(self): + srv_proto, cli_proto = self.prepare(is_ssl=True) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, 100) + self.assertEqual(srv_proto.nbytes, 100) + self.assertEqual(srv_proto.data, self.DATA[1000:1100]) + self.assertEqual(self.file.tell(), 1100) + + def test_sendfile_close_peer_after_receiving(self): + srv_proto, cli_proto = self.prepare(close_after=len(self.DATA)) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_ssl_close_peer_after_receiving(self): + srv_proto, cli_proto = self.prepare(is_ssl=True, + close_after=len(self.DATA)) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_close_peer_in_middle_of_receiving(self): + srv_proto, cli_proto = self.prepare(close_after=1024) + with self.assertRaises(ConnectionError): + self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + self.run_loop(srv_proto.done) + + self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), + srv_proto.nbytes) + self.assertTrue(1024 <= self.file.tell() < len(self.DATA), + self.file.tell()) + + def test_sendfile_fallback_close_peer_in_middle_of_receiving(self): + + def sendfile_native(transp, file, offset, count): + # to raise SendfileNotAvailableError + return base_events.BaseEventLoop._sendfile_native( + self.loop, transp, file, offset, count) + + self.loop._sendfile_native = sendfile_native + + srv_proto, cli_proto = self.prepare(close_after=1024) + with self.assertRaises(ConnectionError): + self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + self.run_loop(srv_proto.done) + + self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), + srv_proto.nbytes) + self.assertTrue(1024 <= self.file.tell() < len(self.DATA), + self.file.tell()) + + @unittest.skipIf(not hasattr(os, 'sendfile'), + "Don't have native sendfile support") + def test_sendfile_prevents_bare_write(self): + srv_proto, cli_proto = self.prepare() + fut = self.loop.create_future() + + async def coro(): + fut.set_result(None) + return await self.loop.sendfile(cli_proto.transport, self.file) + + t = self.loop.create_task(coro()) + self.run_loop(fut) + with self.assertRaisesRegex(RuntimeError, + "sendfile is in progress"): + cli_proto.transport.write(b'data') + ret = self.run_loop(t) + self.assertEqual(ret, len(self.DATA)) + + if sys.platform == 'win32': - class SelectEventLoopTests(EventLoopTestsMixin, test_utils.TestCase): + class SelectEventLoopTests(EventLoopTestsMixin, + SendfileMixin, + test_utils.TestCase): def create_event_loop(self): return asyncio.SelectorEventLoop() class ProactorEventLoopTests(EventLoopTestsMixin, + SendfileMixin, SubprocessTestsMixin, test_utils.TestCase): @@ -2125,7 +2420,7 @@ if sys.platform == 'win32': else: import selectors - class UnixEventLoopTestsMixin(EventLoopTestsMixin): + class UnixEventLoopTestsMixin(EventLoopTestsMixin, SendfileMixin): def setUp(self): super().setUp() watcher = asyncio.SafeChildWatcher() @@ -2556,7 +2851,9 @@ class AbstractEventLoopTests(unittest.TestCase): with self.assertRaises(NotImplementedError): await loop.sock_accept(f) with self.assertRaises(NotImplementedError): - await loop.sock_sendfile(f, mock.Mock()) + await loop.sock_sendfile(f, f) + with self.assertRaises(NotImplementedError): + await loop.sendfile(f, f) with self.assertRaises(NotImplementedError): await loop.connect_read_pipe(f, mock.sentinel.pipe) with self.assertRaises(NotImplementedError): |