diff options
author | Andrew Svetlov <andrew.svetlov@gmail.com> | 2018-10-09 04:52:57 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-10-09 04:52:57 (GMT) |
commit | 2b2758d0b30f4ed7d37319d6c18552eccbc8e7b7 (patch) | |
tree | c75330f29ba7fc380dc182bfd16657c51e2c84bb | |
parent | 199a280af540e3194405eb250ca1a8d487f6a4f7 (diff) | |
download | cpython-2b2758d0b30f4ed7d37319d6c18552eccbc8e7b7.zip cpython-2b2758d0b30f4ed7d37319d6c18552eccbc8e7b7.tar.gz cpython-2b2758d0b30f4ed7d37319d6c18552eccbc8e7b7.tar.bz2 |
Extract sendfile tests into a separate test file (#9757)
-rw-r--r-- | Lib/test/test_asyncio/test_events.py | 451 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_sendfile.py | 550 |
2 files changed, 551 insertions, 450 deletions
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index 607c195..b76cfb7 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -15,7 +15,6 @@ except ImportError: ssl = None import subprocess import sys -import tempfile import threading import time import errno @@ -1987,461 +1986,15 @@ class SubprocessTestsMixin: self.loop.run_until_complete(connect(shell=False)) -class SendfileBase: - - DATA = b"SendfileBaseData" * (1024 * 8) # 128 KiB - - # Reduce socket buffer size to test on relative small data sets. - BUF_SIZE = 4 * 1024 # 4 KiB - - @classmethod - def setUpClass(cls): - with open(support.TESTFN, 'wb') as fp: - fp.write(cls.DATA) - super().setUpClass() - - @classmethod - def tearDownClass(cls): - support.unlink(support.TESTFN) - super().tearDownClass() - - def setUp(self): - self.file = open(support.TESTFN, 'rb') - self.addCleanup(self.file.close) - super().setUp() - - def run_loop(self, coro): - return self.loop.run_until_complete(coro) - - -class SockSendfileMixin(SendfileBase): - - class MyProto(asyncio.Protocol): - - def __init__(self, loop): - self.started = False - self.closed = False - self.data = bytearray() - self.fut = loop.create_future() - self.transport = None - - def connection_made(self, transport): - self.started = True - self.transport = transport - - def data_received(self, data): - self.data.extend(data) - - def connection_lost(self, exc): - self.closed = True - self.fut.set_result(None) - - async def wait_closed(self): - await self.fut - - @classmethod - def setUpClass(cls): - cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE - constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16 - super().setUpClass() - - @classmethod - def tearDownClass(cls): - constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize - super().tearDownClass() - - def make_socket(self, cleanup=True): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setblocking(False) - if cleanup: - self.addCleanup(sock.close) - return sock - - def reduce_receive_buffer_size(self, sock): - # Reduce receive socket buffer size to test on relative - # small data sets. - sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE) - - def reduce_send_buffer_size(self, sock, transport=None): - # Reduce send socket buffer size to test on relative small data sets. - - # On macOS, SO_SNDBUF is reset by connect(). So this method - # should be called after the socket is connected. - sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE) - - if transport is not None: - transport.set_write_buffer_limits(high=self.BUF_SIZE) - - def prepare_socksendfile(self): - proto = self.MyProto(self.loop) - port = support.find_unused_port() - srv_sock = self.make_socket(cleanup=False) - srv_sock.bind((support.HOST, port)) - server = self.run_loop(self.loop.create_server( - lambda: proto, sock=srv_sock)) - self.reduce_receive_buffer_size(srv_sock) - - sock = self.make_socket() - self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port))) - self.reduce_send_buffer_size(sock) - - def cleanup(): - if proto.transport is not None: - # can be None if the task was cancelled before - # connection_made callback - proto.transport.close() - self.run_loop(proto.wait_closed()) - - server.close() - self.run_loop(server.wait_closed()) - - self.addCleanup(cleanup) - - return sock, proto - - def test_sock_sendfile_success(self): - sock, proto = self.prepare_socksendfile() - ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) - sock.close() - self.run_loop(proto.wait_closed()) - - self.assertEqual(ret, len(self.DATA)) - self.assertEqual(proto.data, self.DATA) - self.assertEqual(self.file.tell(), len(self.DATA)) - - def test_sock_sendfile_with_offset_and_count(self): - sock, proto = self.prepare_socksendfile() - ret = self.run_loop(self.loop.sock_sendfile(sock, self.file, - 1000, 2000)) - sock.close() - self.run_loop(proto.wait_closed()) - - self.assertEqual(proto.data, self.DATA[1000:3000]) - self.assertEqual(self.file.tell(), 3000) - self.assertEqual(ret, 2000) - - def test_sock_sendfile_zero_size(self): - sock, proto = self.prepare_socksendfile() - with tempfile.TemporaryFile() as f: - ret = self.run_loop(self.loop.sock_sendfile(sock, f, - 0, None)) - sock.close() - self.run_loop(proto.wait_closed()) - - self.assertEqual(ret, 0) - self.assertEqual(self.file.tell(), 0) - - def test_sock_sendfile_mix_with_regular_send(self): - buf = b"mix_regular_send" * (4 * 1024) # 64 KiB - sock, proto = self.prepare_socksendfile() - self.run_loop(self.loop.sock_sendall(sock, buf)) - ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) - self.run_loop(self.loop.sock_sendall(sock, buf)) - sock.close() - self.run_loop(proto.wait_closed()) - - self.assertEqual(ret, len(self.DATA)) - expected = buf + self.DATA + buf - self.assertEqual(proto.data, expected) - self.assertEqual(self.file.tell(), len(self.DATA)) - - -class SendfileMixin(SendfileBase): - - class MySendfileProto(MyBaseProto): - - def __init__(self, loop=None, close_after=0): - super().__init__(loop) - self.data = bytearray() - self.close_after = close_after - - def data_received(self, data): - self.data.extend(data) - super().data_received(data) - if self.close_after and self.nbytes >= self.close_after: - self.transport.close() - - - # Note: sendfile via SSL transport is equal to sendfile fallback - - def prepare_sendfile(self, *, is_ssl=False, close_after=0): - port = support.find_unused_port() - srv_proto = self.MySendfileProto(loop=self.loop, - close_after=close_after) - if is_ssl: - if not ssl: - self.skipTest("No ssl module") - srv_ctx = test_utils.simple_server_sslcontext() - cli_ctx = test_utils.simple_client_sslcontext() - else: - srv_ctx = None - cli_ctx = None - srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - srv_sock.bind((support.HOST, port)) - server = self.run_loop(self.loop.create_server( - lambda: srv_proto, sock=srv_sock, ssl=srv_ctx)) - self.reduce_receive_buffer_size(srv_sock) - - if is_ssl: - server_hostname = support.HOST - else: - server_hostname = None - cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - cli_sock.connect((support.HOST, port)) - - cli_proto = self.MySendfileProto(loop=self.loop) - tr, pr = self.run_loop(self.loop.create_connection( - lambda: cli_proto, sock=cli_sock, - ssl=cli_ctx, server_hostname=server_hostname)) - self.reduce_send_buffer_size(cli_sock, transport=tr) - - def cleanup(): - srv_proto.transport.close() - cli_proto.transport.close() - self.run_loop(srv_proto.done) - self.run_loop(cli_proto.done) - - server.close() - self.run_loop(server.wait_closed()) - - self.addCleanup(cleanup) - return srv_proto, cli_proto - - @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported") - def test_sendfile_not_supported(self): - tr, pr = self.run_loop( - self.loop.create_datagram_endpoint( - lambda: MyDatagramProto(loop=self.loop), - family=socket.AF_INET)) - try: - with self.assertRaisesRegex(RuntimeError, "not supported"): - self.run_loop( - self.loop.sendfile(tr, self.file)) - self.assertEqual(0, self.file.tell()) - finally: - # don't use self.addCleanup because it produces resource warning - tr.close() - - def test_sendfile(self): - srv_proto, cli_proto = self.prepare_sendfile() - ret = self.run_loop( - self.loop.sendfile(cli_proto.transport, self.file)) - cli_proto.transport.close() - self.run_loop(srv_proto.done) - self.assertEqual(ret, len(self.DATA)) - self.assertEqual(srv_proto.nbytes, len(self.DATA)) - self.assertEqual(srv_proto.data, self.DATA) - self.assertEqual(self.file.tell(), len(self.DATA)) - - def test_sendfile_force_fallback(self): - srv_proto, cli_proto = self.prepare_sendfile() - - def sendfile_native(transp, file, offset, count): - # to raise SendfileNotAvailableError - return base_events.BaseEventLoop._sendfile_native( - self.loop, transp, file, offset, count) - - self.loop._sendfile_native = sendfile_native - - ret = self.run_loop( - self.loop.sendfile(cli_proto.transport, self.file)) - cli_proto.transport.close() - self.run_loop(srv_proto.done) - self.assertEqual(ret, len(self.DATA)) - self.assertEqual(srv_proto.nbytes, len(self.DATA)) - self.assertEqual(srv_proto.data, self.DATA) - self.assertEqual(self.file.tell(), len(self.DATA)) - - def test_sendfile_force_unsupported_native(self): - if sys.platform == 'win32': - if isinstance(self.loop, asyncio.ProactorEventLoop): - self.skipTest("Fails on proactor event loop") - srv_proto, cli_proto = self.prepare_sendfile() - - def sendfile_native(transp, file, offset, count): - # to raise SendfileNotAvailableError - return base_events.BaseEventLoop._sendfile_native( - self.loop, transp, file, offset, count) - - self.loop._sendfile_native = sendfile_native - - with self.assertRaisesRegex(asyncio.SendfileNotAvailableError, - "not supported"): - self.run_loop( - self.loop.sendfile(cli_proto.transport, self.file, - fallback=False)) - - cli_proto.transport.close() - self.run_loop(srv_proto.done) - self.assertEqual(srv_proto.nbytes, 0) - self.assertEqual(self.file.tell(), 0) - - def test_sendfile_ssl(self): - srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) - ret = self.run_loop( - self.loop.sendfile(cli_proto.transport, self.file)) - cli_proto.transport.close() - self.run_loop(srv_proto.done) - self.assertEqual(ret, len(self.DATA)) - self.assertEqual(srv_proto.nbytes, len(self.DATA)) - self.assertEqual(srv_proto.data, self.DATA) - self.assertEqual(self.file.tell(), len(self.DATA)) - - def test_sendfile_for_closing_transp(self): - srv_proto, cli_proto = self.prepare_sendfile() - cli_proto.transport.close() - with self.assertRaisesRegex(RuntimeError, "is closing"): - self.run_loop(self.loop.sendfile(cli_proto.transport, self.file)) - self.run_loop(srv_proto.done) - self.assertEqual(srv_proto.nbytes, 0) - self.assertEqual(self.file.tell(), 0) - - def test_sendfile_pre_and_post_data(self): - srv_proto, cli_proto = self.prepare_sendfile() - PREFIX = b'PREFIX__' * 1024 # 8 KiB - SUFFIX = b'--SUFFIX' * 1024 # 8 KiB - cli_proto.transport.write(PREFIX) - ret = self.run_loop( - self.loop.sendfile(cli_proto.transport, self.file)) - cli_proto.transport.write(SUFFIX) - cli_proto.transport.close() - self.run_loop(srv_proto.done) - self.assertEqual(ret, len(self.DATA)) - self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) - self.assertEqual(self.file.tell(), len(self.DATA)) - - def test_sendfile_ssl_pre_and_post_data(self): - srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) - PREFIX = b'zxcvbnm' * 1024 - SUFFIX = b'0987654321' * 1024 - cli_proto.transport.write(PREFIX) - ret = self.run_loop( - self.loop.sendfile(cli_proto.transport, self.file)) - cli_proto.transport.write(SUFFIX) - cli_proto.transport.close() - self.run_loop(srv_proto.done) - self.assertEqual(ret, len(self.DATA)) - self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) - self.assertEqual(self.file.tell(), len(self.DATA)) - - def test_sendfile_partial(self): - srv_proto, cli_proto = self.prepare_sendfile() - ret = self.run_loop( - self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) - cli_proto.transport.close() - self.run_loop(srv_proto.done) - self.assertEqual(ret, 100) - self.assertEqual(srv_proto.nbytes, 100) - self.assertEqual(srv_proto.data, self.DATA[1000:1100]) - self.assertEqual(self.file.tell(), 1100) - - def test_sendfile_ssl_partial(self): - srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) - ret = self.run_loop( - self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) - cli_proto.transport.close() - self.run_loop(srv_proto.done) - self.assertEqual(ret, 100) - self.assertEqual(srv_proto.nbytes, 100) - self.assertEqual(srv_proto.data, self.DATA[1000:1100]) - self.assertEqual(self.file.tell(), 1100) - - def test_sendfile_close_peer_after_receiving(self): - srv_proto, cli_proto = self.prepare_sendfile( - close_after=len(self.DATA)) - ret = self.run_loop( - self.loop.sendfile(cli_proto.transport, self.file)) - cli_proto.transport.close() - self.run_loop(srv_proto.done) - self.assertEqual(ret, len(self.DATA)) - self.assertEqual(srv_proto.nbytes, len(self.DATA)) - self.assertEqual(srv_proto.data, self.DATA) - self.assertEqual(self.file.tell(), len(self.DATA)) - - def test_sendfile_ssl_close_peer_after_receiving(self): - srv_proto, cli_proto = self.prepare_sendfile( - is_ssl=True, close_after=len(self.DATA)) - ret = self.run_loop( - self.loop.sendfile(cli_proto.transport, self.file)) - self.run_loop(srv_proto.done) - self.assertEqual(ret, len(self.DATA)) - self.assertEqual(srv_proto.nbytes, len(self.DATA)) - self.assertEqual(srv_proto.data, self.DATA) - self.assertEqual(self.file.tell(), len(self.DATA)) - - def test_sendfile_close_peer_in_the_middle_of_receiving(self): - srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) - with self.assertRaises(ConnectionError): - self.run_loop( - self.loop.sendfile(cli_proto.transport, self.file)) - self.run_loop(srv_proto.done) - - self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), - srv_proto.nbytes) - self.assertTrue(1024 <= self.file.tell() < len(self.DATA), - self.file.tell()) - self.assertTrue(cli_proto.transport.is_closing()) - - def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self): - - def sendfile_native(transp, file, offset, count): - # to raise SendfileNotAvailableError - return base_events.BaseEventLoop._sendfile_native( - self.loop, transp, file, offset, count) - - self.loop._sendfile_native = sendfile_native - - srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) - with self.assertRaises(ConnectionError): - self.run_loop( - self.loop.sendfile(cli_proto.transport, self.file)) - self.run_loop(srv_proto.done) - - self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), - srv_proto.nbytes) - self.assertTrue(1024 <= self.file.tell() < len(self.DATA), - self.file.tell()) - - @unittest.skipIf(not hasattr(os, 'sendfile'), - "Don't have native sendfile support") - def test_sendfile_prevents_bare_write(self): - srv_proto, cli_proto = self.prepare_sendfile() - fut = self.loop.create_future() - - async def coro(): - fut.set_result(None) - return await self.loop.sendfile(cli_proto.transport, self.file) - - t = self.loop.create_task(coro()) - self.run_loop(fut) - with self.assertRaisesRegex(RuntimeError, - "sendfile is in progress"): - cli_proto.transport.write(b'data') - ret = self.run_loop(t) - self.assertEqual(ret, len(self.DATA)) - - def test_sendfile_no_fallback_for_fallback_transport(self): - transport = mock.Mock() - transport.is_closing.side_effect = lambda: False - transport._sendfile_compatible = constants._SendfileMode.FALLBACK - with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'): - self.loop.run_until_complete( - self.loop.sendfile(transport, None, fallback=False)) - - if sys.platform == 'win32': class SelectEventLoopTests(EventLoopTestsMixin, - SendfileMixin, - SockSendfileMixin, test_utils.TestCase): def create_event_loop(self): return asyncio.SelectorEventLoop() class ProactorEventLoopTests(EventLoopTestsMixin, - SendfileMixin, - SockSendfileMixin, SubprocessTestsMixin, test_utils.TestCase): @@ -2469,9 +2022,7 @@ if sys.platform == 'win32': else: import selectors - class UnixEventLoopTestsMixin(EventLoopTestsMixin, - SendfileMixin, - SockSendfileMixin): + class UnixEventLoopTestsMixin(EventLoopTestsMixin): def setUp(self): super().setUp() watcher = asyncio.SafeChildWatcher() diff --git a/Lib/test/test_asyncio/test_sendfile.py b/Lib/test/test_asyncio/test_sendfile.py new file mode 100644 index 0000000..26e44a3 --- /dev/null +++ b/Lib/test/test_asyncio/test_sendfile.py @@ -0,0 +1,550 @@ +"""Tests for sendfile functionality.""" + +import asyncio +import os +import socket +import sys +import tempfile +import unittest +from asyncio import base_events +from asyncio import constants +from unittest import mock +from test import support +from test.test_asyncio import utils as test_utils + +try: + import ssl +except ImportError: + ssl = None + + +class MySendfileProto(asyncio.Protocol): + + def __init__(self, loop=None, close_after=0): + self.transport = None + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.connected = loop.create_future() + self.done = loop.create_future() + self.data = bytearray() + self.close_after = close_after + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + if self.connected: + self.connected.set_result(None) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + self.data.extend(data) + super().data_received(data) + if self.close_after and self.nbytes >= self.close_after: + self.transport.close() + + +class MyProto(asyncio.Protocol): + + def __init__(self, loop): + self.started = False + self.closed = False + self.data = bytearray() + self.fut = loop.create_future() + self.transport = None + + def connection_made(self, transport): + self.started = True + self.transport = transport + + def data_received(self, data): + self.data.extend(data) + + def connection_lost(self, exc): + self.closed = True + self.fut.set_result(None) + + async def wait_closed(self): + await self.fut + + +class SendfileBase: + + DATA = b"SendfileBaseData" * (1024 * 8) # 128 KiB + + # Reduce socket buffer size to test on relative small data sets. + BUF_SIZE = 4 * 1024 # 4 KiB + + def create_event_loop(self): + raise NotImplementedError + + @classmethod + def setUpClass(cls): + with open(support.TESTFN, 'wb') as fp: + fp.write(cls.DATA) + super().setUpClass() + + @classmethod + def tearDownClass(cls): + support.unlink(support.TESTFN) + super().tearDownClass() + + def setUp(self): + self.file = open(support.TESTFN, 'rb') + self.addCleanup(self.file.close) + self.loop = self.create_event_loop() + self.set_event_loop(self.loop) + super().setUp() + + def tearDown(self): + # just in case if we have transport close callbacks + if not self.loop.is_closed(): + test_utils.run_briefly(self.loop) + + self.doCleanups() + support.gc_collect() + super().tearDown() + + def run_loop(self, coro): + return self.loop.run_until_complete(coro) + + +class SockSendfileMixin(SendfileBase): + + @classmethod + def setUpClass(cls): + cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE + constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16 + super().setUpClass() + + @classmethod + def tearDownClass(cls): + constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize + super().tearDownClass() + + def make_socket(self, cleanup=True): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(False) + if cleanup: + self.addCleanup(sock.close) + return sock + + def reduce_receive_buffer_size(self, sock): + # Reduce receive socket buffer size to test on relative + # small data sets. + sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE) + + def reduce_send_buffer_size(self, sock, transport=None): + # Reduce send socket buffer size to test on relative small data sets. + + # On macOS, SO_SNDBUF is reset by connect(). So this method + # should be called after the socket is connected. + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE) + + if transport is not None: + transport.set_write_buffer_limits(high=self.BUF_SIZE) + + def prepare_socksendfile(self): + proto = MyProto(self.loop) + port = support.find_unused_port() + srv_sock = self.make_socket(cleanup=False) + srv_sock.bind((support.HOST, port)) + server = self.run_loop(self.loop.create_server( + lambda: proto, sock=srv_sock)) + self.reduce_receive_buffer_size(srv_sock) + + sock = self.make_socket() + self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port))) + self.reduce_send_buffer_size(sock) + + def cleanup(): + if proto.transport is not None: + # can be None if the task was cancelled before + # connection_made callback + proto.transport.close() + self.run_loop(proto.wait_closed()) + + server.close() + self.run_loop(server.wait_closed()) + + self.addCleanup(cleanup) + + return sock, proto + + def test_sock_sendfile_success(self): + sock, proto = self.prepare_socksendfile() + ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sock_sendfile_with_offset_and_count(self): + sock, proto = self.prepare_socksendfile() + ret = self.run_loop(self.loop.sock_sendfile(sock, self.file, + 1000, 2000)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(proto.data, self.DATA[1000:3000]) + self.assertEqual(self.file.tell(), 3000) + self.assertEqual(ret, 2000) + + def test_sock_sendfile_zero_size(self): + sock, proto = self.prepare_socksendfile() + with tempfile.TemporaryFile() as f: + ret = self.run_loop(self.loop.sock_sendfile(sock, f, + 0, None)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(ret, 0) + self.assertEqual(self.file.tell(), 0) + + def test_sock_sendfile_mix_with_regular_send(self): + buf = b"mix_regular_send" * (4 * 1024) # 64 KiB + sock, proto = self.prepare_socksendfile() + self.run_loop(self.loop.sock_sendall(sock, buf)) + ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) + self.run_loop(self.loop.sock_sendall(sock, buf)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(ret, len(self.DATA)) + expected = buf + self.DATA + buf + self.assertEqual(proto.data, expected) + self.assertEqual(self.file.tell(), len(self.DATA)) + + +class SendfileMixin(SendfileBase): + + # Note: sendfile via SSL transport is equal to sendfile fallback + + def prepare_sendfile(self, *, is_ssl=False, close_after=0): + port = support.find_unused_port() + srv_proto = MySendfileProto(loop=self.loop, + close_after=close_after) + if is_ssl: + if not ssl: + self.skipTest("No ssl module") + srv_ctx = test_utils.simple_server_sslcontext() + cli_ctx = test_utils.simple_client_sslcontext() + else: + srv_ctx = None + cli_ctx = None + srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + srv_sock.bind((support.HOST, port)) + server = self.run_loop(self.loop.create_server( + lambda: srv_proto, sock=srv_sock, ssl=srv_ctx)) + self.reduce_receive_buffer_size(srv_sock) + + if is_ssl: + server_hostname = support.HOST + else: + server_hostname = None + cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + cli_sock.connect((support.HOST, port)) + + cli_proto = MySendfileProto(loop=self.loop) + tr, pr = self.run_loop(self.loop.create_connection( + lambda: cli_proto, sock=cli_sock, + ssl=cli_ctx, server_hostname=server_hostname)) + self.reduce_send_buffer_size(cli_sock, transport=tr) + + def cleanup(): + srv_proto.transport.close() + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.run_loop(cli_proto.done) + + server.close() + self.run_loop(server.wait_closed()) + + self.addCleanup(cleanup) + return srv_proto, cli_proto + + @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported") + def test_sendfile_not_supported(self): + tr, pr = self.run_loop( + self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, + family=socket.AF_INET)) + try: + with self.assertRaisesRegex(RuntimeError, "not supported"): + self.run_loop( + self.loop.sendfile(tr, self.file)) + self.assertEqual(0, self.file.tell()) + finally: + # don't use self.addCleanup because it produces resource warning + tr.close() + + def test_sendfile(self): + srv_proto, cli_proto = self.prepare_sendfile() + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_force_fallback(self): + srv_proto, cli_proto = self.prepare_sendfile() + + def sendfile_native(transp, file, offset, count): + # to raise SendfileNotAvailableError + return base_events.BaseEventLoop._sendfile_native( + self.loop, transp, file, offset, count) + + self.loop._sendfile_native = sendfile_native + + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_force_unsupported_native(self): + if sys.platform == 'win32': + if isinstance(self.loop, asyncio.ProactorEventLoop): + self.skipTest("Fails on proactor event loop") + srv_proto, cli_proto = self.prepare_sendfile() + + def sendfile_native(transp, file, offset, count): + # to raise SendfileNotAvailableError + return base_events.BaseEventLoop._sendfile_native( + self.loop, transp, file, offset, count) + + self.loop._sendfile_native = sendfile_native + + with self.assertRaisesRegex(asyncio.SendfileNotAvailableError, + "not supported"): + self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file, + fallback=False)) + + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(srv_proto.nbytes, 0) + self.assertEqual(self.file.tell(), 0) + + def test_sendfile_ssl(self): + srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_for_closing_transp(self): + srv_proto, cli_proto = self.prepare_sendfile() + cli_proto.transport.close() + with self.assertRaisesRegex(RuntimeError, "is closing"): + self.run_loop(self.loop.sendfile(cli_proto.transport, self.file)) + self.run_loop(srv_proto.done) + self.assertEqual(srv_proto.nbytes, 0) + self.assertEqual(self.file.tell(), 0) + + def test_sendfile_pre_and_post_data(self): + srv_proto, cli_proto = self.prepare_sendfile() + PREFIX = b'PREFIX__' * 1024 # 8 KiB + SUFFIX = b'--SUFFIX' * 1024 # 8 KiB + cli_proto.transport.write(PREFIX) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.write(SUFFIX) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_ssl_pre_and_post_data(self): + srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) + PREFIX = b'zxcvbnm' * 1024 + SUFFIX = b'0987654321' * 1024 + cli_proto.transport.write(PREFIX) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.write(SUFFIX) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_partial(self): + srv_proto, cli_proto = self.prepare_sendfile() + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, 100) + self.assertEqual(srv_proto.nbytes, 100) + self.assertEqual(srv_proto.data, self.DATA[1000:1100]) + self.assertEqual(self.file.tell(), 1100) + + def test_sendfile_ssl_partial(self): + srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, 100) + self.assertEqual(srv_proto.nbytes, 100) + self.assertEqual(srv_proto.data, self.DATA[1000:1100]) + self.assertEqual(self.file.tell(), 1100) + + def test_sendfile_close_peer_after_receiving(self): + srv_proto, cli_proto = self.prepare_sendfile( + close_after=len(self.DATA)) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_ssl_close_peer_after_receiving(self): + srv_proto, cli_proto = self.prepare_sendfile( + is_ssl=True, close_after=len(self.DATA)) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_close_peer_in_the_middle_of_receiving(self): + srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) + with self.assertRaises(ConnectionError): + self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + self.run_loop(srv_proto.done) + + self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), + srv_proto.nbytes) + self.assertTrue(1024 <= self.file.tell() < len(self.DATA), + self.file.tell()) + self.assertTrue(cli_proto.transport.is_closing()) + + def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self): + + def sendfile_native(transp, file, offset, count): + # to raise SendfileNotAvailableError + return base_events.BaseEventLoop._sendfile_native( + self.loop, transp, file, offset, count) + + self.loop._sendfile_native = sendfile_native + + srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) + with self.assertRaises(ConnectionError): + self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + self.run_loop(srv_proto.done) + + self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), + srv_proto.nbytes) + self.assertTrue(1024 <= self.file.tell() < len(self.DATA), + self.file.tell()) + + @unittest.skipIf(not hasattr(os, 'sendfile'), + "Don't have native sendfile support") + def test_sendfile_prevents_bare_write(self): + srv_proto, cli_proto = self.prepare_sendfile() + fut = self.loop.create_future() + + async def coro(): + fut.set_result(None) + return await self.loop.sendfile(cli_proto.transport, self.file) + + t = self.loop.create_task(coro()) + self.run_loop(fut) + with self.assertRaisesRegex(RuntimeError, + "sendfile is in progress"): + cli_proto.transport.write(b'data') + ret = self.run_loop(t) + self.assertEqual(ret, len(self.DATA)) + + def test_sendfile_no_fallback_for_fallback_transport(self): + transport = mock.Mock() + transport.is_closing.side_effect = lambda: False + transport._sendfile_compatible = constants._SendfileMode.FALLBACK + with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'): + self.loop.run_until_complete( + self.loop.sendfile(transport, None, fallback=False)) + + +class SendfileTestsBase(SendfileMixin, SockSendfileMixin): + pass + + +if sys.platform == 'win32': + + class SelectEventLoopTests(SendfileTestsBase, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop() + + class ProactorEventLoopTests(SendfileTestsBase, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.ProactorEventLoop() + +else: + import selectors + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(SendfileTestsBase, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(SendfileTestsBase, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(SendfileTestsBase, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(SendfileTestsBase, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.SelectSelector()) |