summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_asyncio/test_events.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_asyncio/test_events.py')
-rw-r--r--Lib/test/test_asyncio/test_events.py303
1 files changed, 300 insertions, 3 deletions
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
index cf21753..0981bd6 100644
--- a/Lib/test/test_asyncio/test_events.py
+++ b/Lib/test/test_asyncio/test_events.py
@@ -26,6 +26,7 @@ if sys.platform != 'win32':
import tty
import asyncio
+from asyncio import base_events
from asyncio import coroutines
from asyncio import events
from asyncio import proactor_events
@@ -2090,14 +2091,308 @@ class SubprocessTestsMixin:
self.loop.run_until_complete(connect(shell=False))
+class MySendfileProto(MyBaseProto):
+
+ def __init__(self, loop=None, close_after=0):
+ super().__init__(loop)
+ self.data = bytearray()
+ self.close_after = close_after
+
+ def data_received(self, data):
+ self.data.extend(data)
+ super().data_received(data)
+ if self.close_after and self.nbytes >= self.close_after:
+ self.transport.close()
+
+
+class SendfileMixin:
+ # Note: sendfile via SSL transport is equal to sendfile fallback
+
+ DATA = b"12345abcde" * 160 * 1024 # 160 KiB
+
+ @classmethod
+ def setUpClass(cls):
+ with open(support.TESTFN, 'wb') as fp:
+ fp.write(cls.DATA)
+ super().setUpClass()
+
+ @classmethod
+ def tearDownClass(cls):
+ support.unlink(support.TESTFN)
+ super().tearDownClass()
+
+ def setUp(self):
+ self.file = open(support.TESTFN, 'rb')
+ self.addCleanup(self.file.close)
+ super().setUp()
+
+ def run_loop(self, coro):
+ return self.loop.run_until_complete(coro)
+
+ def prepare(self, *, is_ssl=False, close_after=0):
+ port = support.find_unused_port()
+ srv_proto = MySendfileProto(loop=self.loop, close_after=close_after)
+ if is_ssl:
+ srv_ctx = test_utils.simple_server_sslcontext()
+ cli_ctx = test_utils.simple_client_sslcontext()
+ else:
+ srv_ctx = None
+ cli_ctx = None
+ srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ # reduce recv socket buffer size to test on relative small data sets
+ srv_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
+ srv_sock.bind((support.HOST, port))
+ server = self.run_loop(self.loop.create_server(
+ lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
+
+ if is_ssl:
+ server_hostname = support.HOST
+ else:
+ server_hostname = None
+ cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ # reduce send socket buffer size to test on relative small data sets
+ cli_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
+ cli_sock.connect((support.HOST, port))
+ cli_proto = MySendfileProto(loop=self.loop)
+ tr, pr = self.run_loop(self.loop.create_connection(
+ lambda: cli_proto, sock=cli_sock,
+ ssl=cli_ctx, server_hostname=server_hostname))
+
+ def cleanup():
+ srv_proto.transport.close()
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.run_loop(cli_proto.done)
+
+ server.close()
+ self.run_loop(server.wait_closed())
+
+ self.addCleanup(cleanup)
+ return srv_proto, cli_proto
+
+ @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
+ def test_sendfile_not_supported(self):
+ tr, pr = self.run_loop(
+ self.loop.create_datagram_endpoint(
+ lambda: MyDatagramProto(loop=self.loop),
+ family=socket.AF_INET))
+ try:
+ with self.assertRaisesRegex(RuntimeError, "not supported"):
+ self.run_loop(
+ self.loop.sendfile(tr, self.file))
+ self.assertEqual(0, self.file.tell())
+ finally:
+ # don't use self.addCleanup because it produces resource warning
+ tr.close()
+
+ def test_sendfile(self):
+ srv_proto, cli_proto = self.prepare()
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.nbytes, len(self.DATA))
+ self.assertEqual(srv_proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_force_fallback(self):
+ srv_proto, cli_proto = self.prepare()
+
+ def sendfile_native(transp, file, offset, count):
+ # to raise SendfileNotAvailableError
+ return base_events.BaseEventLoop._sendfile_native(
+ self.loop, transp, file, offset, count)
+
+ self.loop._sendfile_native = sendfile_native
+
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.nbytes, len(self.DATA))
+ self.assertEqual(srv_proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_force_unsupported_native(self):
+ if sys.platform == 'win32':
+ if isinstance(self.loop, asyncio.ProactorEventLoop):
+ self.skipTest("Fails on proactor event loop")
+ srv_proto, cli_proto = self.prepare()
+
+ def sendfile_native(transp, file, offset, count):
+ # to raise SendfileNotAvailableError
+ return base_events.BaseEventLoop._sendfile_native(
+ self.loop, transp, file, offset, count)
+
+ self.loop._sendfile_native = sendfile_native
+
+ with self.assertRaisesRegex(events.SendfileNotAvailableError,
+ "not supported"):
+ self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file,
+ fallback=False))
+
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(srv_proto.nbytes, 0)
+ self.assertEqual(self.file.tell(), 0)
+
+ def test_sendfile_ssl(self):
+ srv_proto, cli_proto = self.prepare(is_ssl=True)
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.nbytes, len(self.DATA))
+ self.assertEqual(srv_proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_for_closing_transp(self):
+ srv_proto, cli_proto = self.prepare()
+ cli_proto.transport.close()
+ with self.assertRaisesRegex(RuntimeError, "is closing"):
+ self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
+ self.run_loop(srv_proto.done)
+ self.assertEqual(srv_proto.nbytes, 0)
+ self.assertEqual(self.file.tell(), 0)
+
+ def test_sendfile_pre_and_post_data(self):
+ srv_proto, cli_proto = self.prepare()
+ PREFIX = b'zxcvbnm' * 1024
+ SUFFIX = b'0987654321' * 1024
+ cli_proto.transport.write(PREFIX)
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.write(SUFFIX)
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_ssl_pre_and_post_data(self):
+ srv_proto, cli_proto = self.prepare(is_ssl=True)
+ PREFIX = b'zxcvbnm' * 1024
+ SUFFIX = b'0987654321' * 1024
+ cli_proto.transport.write(PREFIX)
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.write(SUFFIX)
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_partial(self):
+ srv_proto, cli_proto = self.prepare()
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, 100)
+ self.assertEqual(srv_proto.nbytes, 100)
+ self.assertEqual(srv_proto.data, self.DATA[1000:1100])
+ self.assertEqual(self.file.tell(), 1100)
+
+ def test_sendfile_ssl_partial(self):
+ srv_proto, cli_proto = self.prepare(is_ssl=True)
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, 100)
+ self.assertEqual(srv_proto.nbytes, 100)
+ self.assertEqual(srv_proto.data, self.DATA[1000:1100])
+ self.assertEqual(self.file.tell(), 1100)
+
+ def test_sendfile_close_peer_after_receiving(self):
+ srv_proto, cli_proto = self.prepare(close_after=len(self.DATA))
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.nbytes, len(self.DATA))
+ self.assertEqual(srv_proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_ssl_close_peer_after_receiving(self):
+ srv_proto, cli_proto = self.prepare(is_ssl=True,
+ close_after=len(self.DATA))
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.nbytes, len(self.DATA))
+ self.assertEqual(srv_proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_close_peer_in_middle_of_receiving(self):
+ srv_proto, cli_proto = self.prepare(close_after=1024)
+ with self.assertRaises(ConnectionError):
+ self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ self.run_loop(srv_proto.done)
+
+ self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
+ srv_proto.nbytes)
+ self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
+ self.file.tell())
+
+ def test_sendfile_fallback_close_peer_in_middle_of_receiving(self):
+
+ def sendfile_native(transp, file, offset, count):
+ # to raise SendfileNotAvailableError
+ return base_events.BaseEventLoop._sendfile_native(
+ self.loop, transp, file, offset, count)
+
+ self.loop._sendfile_native = sendfile_native
+
+ srv_proto, cli_proto = self.prepare(close_after=1024)
+ with self.assertRaises(ConnectionError):
+ self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ self.run_loop(srv_proto.done)
+
+ self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
+ srv_proto.nbytes)
+ self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
+ self.file.tell())
+
+ @unittest.skipIf(not hasattr(os, 'sendfile'),
+ "Don't have native sendfile support")
+ def test_sendfile_prevents_bare_write(self):
+ srv_proto, cli_proto = self.prepare()
+ fut = self.loop.create_future()
+
+ async def coro():
+ fut.set_result(None)
+ return await self.loop.sendfile(cli_proto.transport, self.file)
+
+ t = self.loop.create_task(coro())
+ self.run_loop(fut)
+ with self.assertRaisesRegex(RuntimeError,
+ "sendfile is in progress"):
+ cli_proto.transport.write(b'data')
+ ret = self.run_loop(t)
+ self.assertEqual(ret, len(self.DATA))
+
+
if sys.platform == 'win32':
- class SelectEventLoopTests(EventLoopTestsMixin, test_utils.TestCase):
+ class SelectEventLoopTests(EventLoopTestsMixin,
+ SendfileMixin,
+ test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop()
class ProactorEventLoopTests(EventLoopTestsMixin,
+ SendfileMixin,
SubprocessTestsMixin,
test_utils.TestCase):
@@ -2125,7 +2420,7 @@ if sys.platform == 'win32':
else:
import selectors
- class UnixEventLoopTestsMixin(EventLoopTestsMixin):
+ class UnixEventLoopTestsMixin(EventLoopTestsMixin, SendfileMixin):
def setUp(self):
super().setUp()
watcher = asyncio.SafeChildWatcher()
@@ -2556,7 +2851,9 @@ class AbstractEventLoopTests(unittest.TestCase):
with self.assertRaises(NotImplementedError):
await loop.sock_accept(f)
with self.assertRaises(NotImplementedError):
- await loop.sock_sendfile(f, mock.Mock())
+ await loop.sock_sendfile(f, f)
+ with self.assertRaises(NotImplementedError):
+ await loop.sendfile(f, f)
with self.assertRaises(NotImplementedError):
await loop.connect_read_pipe(f, mock.sentinel.pipe)
with self.assertRaises(NotImplementedError):