diff options
author | Andrew Svetlov <andrew.svetlov@gmail.com> | 2018-02-25 16:32:14 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-02-25 16:32:14 (GMT) |
commit | a19fb3c6aaa7632410d1d9dcb395d7101d124da4 (patch) | |
tree | 476902dc75526cc8bb22c41cf213416384c36805 /Lib/test/test_asyncio/test_events.py | |
parent | 5fb632e83136399bad9427ee23ec8b771695290a (diff) | |
download | cpython-a19fb3c6aaa7632410d1d9dcb395d7101d124da4.zip cpython-a19fb3c6aaa7632410d1d9dcb395d7101d124da4.tar.gz cpython-a19fb3c6aaa7632410d1d9dcb395d7101d124da4.tar.bz2 |
bpo-32622: Native sendfile on windows (#5565)
* Support sendfile on Windows Proactor event loop naively.
Diffstat (limited to 'Lib/test/test_asyncio/test_events.py')
-rw-r--r-- | Lib/test/test_asyncio/test_events.py | 187 |
1 files changed, 152 insertions, 35 deletions
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index f599597..6accbda 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -15,6 +15,7 @@ except ImportError: ssl = None import subprocess import sys +import tempfile import threading import time import errno @@ -2092,22 +2093,7 @@ class SubprocessTestsMixin: self.loop.run_until_complete(connect(shell=False)) -class MySendfileProto(MyBaseProto): - - def __init__(self, loop=None, close_after=0): - super().__init__(loop) - self.data = bytearray() - self.close_after = close_after - - def data_received(self, data): - self.data.extend(data) - super().data_received(data) - if self.close_after and self.nbytes >= self.close_after: - self.transport.close() - - -class SendfileMixin: - # Note: sendfile via SSL transport is equal to sendfile fallback +class SendfileBase: DATA = b"12345abcde" * 160 * 1024 # 160 KiB @@ -2130,9 +2116,134 @@ class SendfileMixin: def run_loop(self, coro): return self.loop.run_until_complete(coro) - def prepare(self, *, is_ssl=False, close_after=0): + +class SockSendfileMixin(SendfileBase): + + class MyProto(asyncio.Protocol): + + def __init__(self, loop): + self.started = False + self.closed = False + self.data = bytearray() + self.fut = loop.create_future() + self.transport = None + + def connection_made(self, transport): + self.started = True + self.transport = transport + + def data_received(self, data): + self.data.extend(data) + + def connection_lost(self, exc): + self.closed = True + self.fut.set_result(None) + + async def wait_closed(self): + await self.fut + + def make_socket(self, cleanup=True): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(False) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024) + if cleanup: + self.addCleanup(sock.close) + return sock + + def prepare_socksendfile(self): + sock = self.make_socket() + proto = self.MyProto(self.loop) + port = support.find_unused_port() + srv_sock = self.make_socket(cleanup=False) + srv_sock.bind((support.HOST, port)) + server = self.run_loop(self.loop.create_server( + lambda: proto, sock=srv_sock)) + self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port))) + + def cleanup(): + if proto.transport is not None: + # can be None if the task was cancelled before + # connection_made callback + proto.transport.close() + self.run_loop(proto.wait_closed()) + + server.close() + self.run_loop(server.wait_closed()) + + self.addCleanup(cleanup) + + return sock, proto + + def test_sock_sendfile_success(self): + sock, proto = self.prepare_socksendfile() + ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sock_sendfile_with_offset_and_count(self): + sock, proto = self.prepare_socksendfile() + ret = self.run_loop(self.loop.sock_sendfile(sock, self.file, + 1000, 2000)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(proto.data, self.DATA[1000:3000]) + self.assertEqual(self.file.tell(), 3000) + self.assertEqual(ret, 2000) + + def test_sock_sendfile_zero_size(self): + sock, proto = self.prepare_socksendfile() + with tempfile.TemporaryFile() as f: + ret = self.run_loop(self.loop.sock_sendfile(sock, f, + 0, None)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(ret, 0) + self.assertEqual(self.file.tell(), 0) + + def test_sock_sendfile_mix_with_regular_send(self): + buf = b'1234567890' * 1024 * 1024 # 10 MB + sock, proto = self.prepare_socksendfile() + self.run_loop(self.loop.sock_sendall(sock, buf)) + ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) + self.run_loop(self.loop.sock_sendall(sock, buf)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(ret, len(self.DATA)) + expected = buf + self.DATA + buf + self.assertEqual(proto.data, expected) + self.assertEqual(self.file.tell(), len(self.DATA)) + + +class SendfileMixin(SendfileBase): + + class MySendfileProto(MyBaseProto): + + def __init__(self, loop=None, close_after=0): + super().__init__(loop) + self.data = bytearray() + self.close_after = close_after + + def data_received(self, data): + self.data.extend(data) + super().data_received(data) + if self.close_after and self.nbytes >= self.close_after: + self.transport.close() + + + # Note: sendfile via SSL transport is equal to sendfile fallback + + def prepare_sendfile(self, *, is_ssl=False, close_after=0): port = support.find_unused_port() - srv_proto = MySendfileProto(loop=self.loop, close_after=close_after) + srv_proto = self.MySendfileProto(loop=self.loop, + close_after=close_after) if is_ssl: if not ssl: self.skipTest("No ssl module") @@ -2156,7 +2267,7 @@ class SendfileMixin: # reduce send socket buffer size to test on relative small data sets cli_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) cli_sock.connect((support.HOST, port)) - cli_proto = MySendfileProto(loop=self.loop) + cli_proto = self.MySendfileProto(loop=self.loop) tr, pr = self.run_loop(self.loop.create_connection( lambda: cli_proto, sock=cli_sock, ssl=cli_ctx, server_hostname=server_hostname)) @@ -2189,7 +2300,7 @@ class SendfileMixin: tr.close() def test_sendfile(self): - srv_proto, cli_proto = self.prepare() + srv_proto, cli_proto = self.prepare_sendfile() ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) cli_proto.transport.close() @@ -2200,7 +2311,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_force_fallback(self): - srv_proto, cli_proto = self.prepare() + srv_proto, cli_proto = self.prepare_sendfile() def sendfile_native(transp, file, offset, count): # to raise SendfileNotAvailableError @@ -2222,7 +2333,7 @@ class SendfileMixin: if sys.platform == 'win32': if isinstance(self.loop, asyncio.ProactorEventLoop): self.skipTest("Fails on proactor event loop") - srv_proto, cli_proto = self.prepare() + srv_proto, cli_proto = self.prepare_sendfile() def sendfile_native(transp, file, offset, count): # to raise SendfileNotAvailableError @@ -2243,7 +2354,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), 0) def test_sendfile_ssl(self): - srv_proto, cli_proto = self.prepare(is_ssl=True) + srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) cli_proto.transport.close() @@ -2254,7 +2365,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_for_closing_transp(self): - srv_proto, cli_proto = self.prepare() + srv_proto, cli_proto = self.prepare_sendfile() cli_proto.transport.close() with self.assertRaisesRegex(RuntimeError, "is closing"): self.run_loop(self.loop.sendfile(cli_proto.transport, self.file)) @@ -2263,7 +2374,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), 0) def test_sendfile_pre_and_post_data(self): - srv_proto, cli_proto = self.prepare() + srv_proto, cli_proto = self.prepare_sendfile() PREFIX = b'zxcvbnm' * 1024 SUFFIX = b'0987654321' * 1024 cli_proto.transport.write(PREFIX) @@ -2277,7 +2388,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_ssl_pre_and_post_data(self): - srv_proto, cli_proto = self.prepare(is_ssl=True) + srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) PREFIX = b'zxcvbnm' * 1024 SUFFIX = b'0987654321' * 1024 cli_proto.transport.write(PREFIX) @@ -2291,7 +2402,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_partial(self): - srv_proto, cli_proto = self.prepare() + srv_proto, cli_proto = self.prepare_sendfile() ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) cli_proto.transport.close() @@ -2302,7 +2413,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), 1100) def test_sendfile_ssl_partial(self): - srv_proto, cli_proto = self.prepare(is_ssl=True) + srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) cli_proto.transport.close() @@ -2313,7 +2424,8 @@ class SendfileMixin: self.assertEqual(self.file.tell(), 1100) def test_sendfile_close_peer_after_receiving(self): - srv_proto, cli_proto = self.prepare(close_after=len(self.DATA)) + srv_proto, cli_proto = self.prepare_sendfile( + close_after=len(self.DATA)) ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) cli_proto.transport.close() @@ -2324,8 +2436,8 @@ class SendfileMixin: self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_ssl_close_peer_after_receiving(self): - srv_proto, cli_proto = self.prepare(is_ssl=True, - close_after=len(self.DATA)) + srv_proto, cli_proto = self.prepare_sendfile( + is_ssl=True, close_after=len(self.DATA)) ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) self.run_loop(srv_proto.done) @@ -2335,7 +2447,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_close_peer_in_middle_of_receiving(self): - srv_proto, cli_proto = self.prepare(close_after=1024) + srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) with self.assertRaises(ConnectionError): self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) @@ -2345,6 +2457,7 @@ class SendfileMixin: srv_proto.nbytes) self.assertTrue(1024 <= self.file.tell() < len(self.DATA), self.file.tell()) + self.assertTrue(cli_proto.transport.is_closing()) def test_sendfile_fallback_close_peer_in_middle_of_receiving(self): @@ -2355,7 +2468,7 @@ class SendfileMixin: self.loop._sendfile_native = sendfile_native - srv_proto, cli_proto = self.prepare(close_after=1024) + srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) with self.assertRaises(ConnectionError): self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) @@ -2369,7 +2482,7 @@ class SendfileMixin: @unittest.skipIf(not hasattr(os, 'sendfile'), "Don't have native sendfile support") def test_sendfile_prevents_bare_write(self): - srv_proto, cli_proto = self.prepare() + srv_proto, cli_proto = self.prepare_sendfile() fut = self.loop.create_future() async def coro(): @@ -2397,6 +2510,7 @@ if sys.platform == 'win32': class SelectEventLoopTests(EventLoopTestsMixin, SendfileMixin, + SockSendfileMixin, test_utils.TestCase): def create_event_loop(self): @@ -2404,6 +2518,7 @@ if sys.platform == 'win32': class ProactorEventLoopTests(EventLoopTestsMixin, SendfileMixin, + SockSendfileMixin, SubprocessTestsMixin, test_utils.TestCase): @@ -2431,7 +2546,9 @@ if sys.platform == 'win32': else: import selectors - class UnixEventLoopTestsMixin(EventLoopTestsMixin, SendfileMixin): + class UnixEventLoopTestsMixin(EventLoopTestsMixin, + SendfileMixin, + SockSendfileMixin): def setUp(self): super().setUp() watcher = asyncio.SafeChildWatcher() |