summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAndrew Svetlov <andrew.svetlov@gmail.com>2018-10-09 04:52:57 (GMT)
committerGitHub <noreply@github.com>2018-10-09 04:52:57 (GMT)
commit2b2758d0b30f4ed7d37319d6c18552eccbc8e7b7 (patch)
treec75330f29ba7fc380dc182bfd16657c51e2c84bb
parent199a280af540e3194405eb250ca1a8d487f6a4f7 (diff)
downloadcpython-2b2758d0b30f4ed7d37319d6c18552eccbc8e7b7.zip
cpython-2b2758d0b30f4ed7d37319d6c18552eccbc8e7b7.tar.gz
cpython-2b2758d0b30f4ed7d37319d6c18552eccbc8e7b7.tar.bz2
Extract sendfile tests into a separate test file (#9757)
-rw-r--r--Lib/test/test_asyncio/test_events.py451
-rw-r--r--Lib/test/test_asyncio/test_sendfile.py550
2 files changed, 551 insertions, 450 deletions
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
index 607c195..b76cfb7 100644
--- a/Lib/test/test_asyncio/test_events.py
+++ b/Lib/test/test_asyncio/test_events.py
@@ -15,7 +15,6 @@ except ImportError:
ssl = None
import subprocess
import sys
-import tempfile
import threading
import time
import errno
@@ -1987,461 +1986,15 @@ class SubprocessTestsMixin:
self.loop.run_until_complete(connect(shell=False))
-class SendfileBase:
-
- DATA = b"SendfileBaseData" * (1024 * 8) # 128 KiB
-
- # Reduce socket buffer size to test on relative small data sets.
- BUF_SIZE = 4 * 1024 # 4 KiB
-
- @classmethod
- def setUpClass(cls):
- with open(support.TESTFN, 'wb') as fp:
- fp.write(cls.DATA)
- super().setUpClass()
-
- @classmethod
- def tearDownClass(cls):
- support.unlink(support.TESTFN)
- super().tearDownClass()
-
- def setUp(self):
- self.file = open(support.TESTFN, 'rb')
- self.addCleanup(self.file.close)
- super().setUp()
-
- def run_loop(self, coro):
- return self.loop.run_until_complete(coro)
-
-
-class SockSendfileMixin(SendfileBase):
-
- class MyProto(asyncio.Protocol):
-
- def __init__(self, loop):
- self.started = False
- self.closed = False
- self.data = bytearray()
- self.fut = loop.create_future()
- self.transport = None
-
- def connection_made(self, transport):
- self.started = True
- self.transport = transport
-
- def data_received(self, data):
- self.data.extend(data)
-
- def connection_lost(self, exc):
- self.closed = True
- self.fut.set_result(None)
-
- async def wait_closed(self):
- await self.fut
-
- @classmethod
- def setUpClass(cls):
- cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE
- constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16
- super().setUpClass()
-
- @classmethod
- def tearDownClass(cls):
- constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize
- super().tearDownClass()
-
- def make_socket(self, cleanup=True):
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.setblocking(False)
- if cleanup:
- self.addCleanup(sock.close)
- return sock
-
- def reduce_receive_buffer_size(self, sock):
- # Reduce receive socket buffer size to test on relative
- # small data sets.
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE)
-
- def reduce_send_buffer_size(self, sock, transport=None):
- # Reduce send socket buffer size to test on relative small data sets.
-
- # On macOS, SO_SNDBUF is reset by connect(). So this method
- # should be called after the socket is connected.
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE)
-
- if transport is not None:
- transport.set_write_buffer_limits(high=self.BUF_SIZE)
-
- def prepare_socksendfile(self):
- proto = self.MyProto(self.loop)
- port = support.find_unused_port()
- srv_sock = self.make_socket(cleanup=False)
- srv_sock.bind((support.HOST, port))
- server = self.run_loop(self.loop.create_server(
- lambda: proto, sock=srv_sock))
- self.reduce_receive_buffer_size(srv_sock)
-
- sock = self.make_socket()
- self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port)))
- self.reduce_send_buffer_size(sock)
-
- def cleanup():
- if proto.transport is not None:
- # can be None if the task was cancelled before
- # connection_made callback
- proto.transport.close()
- self.run_loop(proto.wait_closed())
-
- server.close()
- self.run_loop(server.wait_closed())
-
- self.addCleanup(cleanup)
-
- return sock, proto
-
- def test_sock_sendfile_success(self):
- sock, proto = self.prepare_socksendfile()
- ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
- sock.close()
- self.run_loop(proto.wait_closed())
-
- self.assertEqual(ret, len(self.DATA))
- self.assertEqual(proto.data, self.DATA)
- self.assertEqual(self.file.tell(), len(self.DATA))
-
- def test_sock_sendfile_with_offset_and_count(self):
- sock, proto = self.prepare_socksendfile()
- ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
- 1000, 2000))
- sock.close()
- self.run_loop(proto.wait_closed())
-
- self.assertEqual(proto.data, self.DATA[1000:3000])
- self.assertEqual(self.file.tell(), 3000)
- self.assertEqual(ret, 2000)
-
- def test_sock_sendfile_zero_size(self):
- sock, proto = self.prepare_socksendfile()
- with tempfile.TemporaryFile() as f:
- ret = self.run_loop(self.loop.sock_sendfile(sock, f,
- 0, None))
- sock.close()
- self.run_loop(proto.wait_closed())
-
- self.assertEqual(ret, 0)
- self.assertEqual(self.file.tell(), 0)
-
- def test_sock_sendfile_mix_with_regular_send(self):
- buf = b"mix_regular_send" * (4 * 1024) # 64 KiB
- sock, proto = self.prepare_socksendfile()
- self.run_loop(self.loop.sock_sendall(sock, buf))
- ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
- self.run_loop(self.loop.sock_sendall(sock, buf))
- sock.close()
- self.run_loop(proto.wait_closed())
-
- self.assertEqual(ret, len(self.DATA))
- expected = buf + self.DATA + buf
- self.assertEqual(proto.data, expected)
- self.assertEqual(self.file.tell(), len(self.DATA))
-
-
-class SendfileMixin(SendfileBase):
-
- class MySendfileProto(MyBaseProto):
-
- def __init__(self, loop=None, close_after=0):
- super().__init__(loop)
- self.data = bytearray()
- self.close_after = close_after
-
- def data_received(self, data):
- self.data.extend(data)
- super().data_received(data)
- if self.close_after and self.nbytes >= self.close_after:
- self.transport.close()
-
-
- # Note: sendfile via SSL transport is equal to sendfile fallback
-
- def prepare_sendfile(self, *, is_ssl=False, close_after=0):
- port = support.find_unused_port()
- srv_proto = self.MySendfileProto(loop=self.loop,
- close_after=close_after)
- if is_ssl:
- if not ssl:
- self.skipTest("No ssl module")
- srv_ctx = test_utils.simple_server_sslcontext()
- cli_ctx = test_utils.simple_client_sslcontext()
- else:
- srv_ctx = None
- cli_ctx = None
- srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- srv_sock.bind((support.HOST, port))
- server = self.run_loop(self.loop.create_server(
- lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
- self.reduce_receive_buffer_size(srv_sock)
-
- if is_ssl:
- server_hostname = support.HOST
- else:
- server_hostname = None
- cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- cli_sock.connect((support.HOST, port))
-
- cli_proto = self.MySendfileProto(loop=self.loop)
- tr, pr = self.run_loop(self.loop.create_connection(
- lambda: cli_proto, sock=cli_sock,
- ssl=cli_ctx, server_hostname=server_hostname))
- self.reduce_send_buffer_size(cli_sock, transport=tr)
-
- def cleanup():
- srv_proto.transport.close()
- cli_proto.transport.close()
- self.run_loop(srv_proto.done)
- self.run_loop(cli_proto.done)
-
- server.close()
- self.run_loop(server.wait_closed())
-
- self.addCleanup(cleanup)
- return srv_proto, cli_proto
-
- @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
- def test_sendfile_not_supported(self):
- tr, pr = self.run_loop(
- self.loop.create_datagram_endpoint(
- lambda: MyDatagramProto(loop=self.loop),
- family=socket.AF_INET))
- try:
- with self.assertRaisesRegex(RuntimeError, "not supported"):
- self.run_loop(
- self.loop.sendfile(tr, self.file))
- self.assertEqual(0, self.file.tell())
- finally:
- # don't use self.addCleanup because it produces resource warning
- tr.close()
-
- def test_sendfile(self):
- srv_proto, cli_proto = self.prepare_sendfile()
- ret = self.run_loop(
- self.loop.sendfile(cli_proto.transport, self.file))
- cli_proto.transport.close()
- self.run_loop(srv_proto.done)
- self.assertEqual(ret, len(self.DATA))
- self.assertEqual(srv_proto.nbytes, len(self.DATA))
- self.assertEqual(srv_proto.data, self.DATA)
- self.assertEqual(self.file.tell(), len(self.DATA))
-
- def test_sendfile_force_fallback(self):
- srv_proto, cli_proto = self.prepare_sendfile()
-
- def sendfile_native(transp, file, offset, count):
- # to raise SendfileNotAvailableError
- return base_events.BaseEventLoop._sendfile_native(
- self.loop, transp, file, offset, count)
-
- self.loop._sendfile_native = sendfile_native
-
- ret = self.run_loop(
- self.loop.sendfile(cli_proto.transport, self.file))
- cli_proto.transport.close()
- self.run_loop(srv_proto.done)
- self.assertEqual(ret, len(self.DATA))
- self.assertEqual(srv_proto.nbytes, len(self.DATA))
- self.assertEqual(srv_proto.data, self.DATA)
- self.assertEqual(self.file.tell(), len(self.DATA))
-
- def test_sendfile_force_unsupported_native(self):
- if sys.platform == 'win32':
- if isinstance(self.loop, asyncio.ProactorEventLoop):
- self.skipTest("Fails on proactor event loop")
- srv_proto, cli_proto = self.prepare_sendfile()
-
- def sendfile_native(transp, file, offset, count):
- # to raise SendfileNotAvailableError
- return base_events.BaseEventLoop._sendfile_native(
- self.loop, transp, file, offset, count)
-
- self.loop._sendfile_native = sendfile_native
-
- with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
- "not supported"):
- self.run_loop(
- self.loop.sendfile(cli_proto.transport, self.file,
- fallback=False))
-
- cli_proto.transport.close()
- self.run_loop(srv_proto.done)
- self.assertEqual(srv_proto.nbytes, 0)
- self.assertEqual(self.file.tell(), 0)
-
- def test_sendfile_ssl(self):
- srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
- ret = self.run_loop(
- self.loop.sendfile(cli_proto.transport, self.file))
- cli_proto.transport.close()
- self.run_loop(srv_proto.done)
- self.assertEqual(ret, len(self.DATA))
- self.assertEqual(srv_proto.nbytes, len(self.DATA))
- self.assertEqual(srv_proto.data, self.DATA)
- self.assertEqual(self.file.tell(), len(self.DATA))
-
- def test_sendfile_for_closing_transp(self):
- srv_proto, cli_proto = self.prepare_sendfile()
- cli_proto.transport.close()
- with self.assertRaisesRegex(RuntimeError, "is closing"):
- self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
- self.run_loop(srv_proto.done)
- self.assertEqual(srv_proto.nbytes, 0)
- self.assertEqual(self.file.tell(), 0)
-
- def test_sendfile_pre_and_post_data(self):
- srv_proto, cli_proto = self.prepare_sendfile()
- PREFIX = b'PREFIX__' * 1024 # 8 KiB
- SUFFIX = b'--SUFFIX' * 1024 # 8 KiB
- cli_proto.transport.write(PREFIX)
- ret = self.run_loop(
- self.loop.sendfile(cli_proto.transport, self.file))
- cli_proto.transport.write(SUFFIX)
- cli_proto.transport.close()
- self.run_loop(srv_proto.done)
- self.assertEqual(ret, len(self.DATA))
- self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
- self.assertEqual(self.file.tell(), len(self.DATA))
-
- def test_sendfile_ssl_pre_and_post_data(self):
- srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
- PREFIX = b'zxcvbnm' * 1024
- SUFFIX = b'0987654321' * 1024
- cli_proto.transport.write(PREFIX)
- ret = self.run_loop(
- self.loop.sendfile(cli_proto.transport, self.file))
- cli_proto.transport.write(SUFFIX)
- cli_proto.transport.close()
- self.run_loop(srv_proto.done)
- self.assertEqual(ret, len(self.DATA))
- self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
- self.assertEqual(self.file.tell(), len(self.DATA))
-
- def test_sendfile_partial(self):
- srv_proto, cli_proto = self.prepare_sendfile()
- ret = self.run_loop(
- self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
- cli_proto.transport.close()
- self.run_loop(srv_proto.done)
- self.assertEqual(ret, 100)
- self.assertEqual(srv_proto.nbytes, 100)
- self.assertEqual(srv_proto.data, self.DATA[1000:1100])
- self.assertEqual(self.file.tell(), 1100)
-
- def test_sendfile_ssl_partial(self):
- srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
- ret = self.run_loop(
- self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
- cli_proto.transport.close()
- self.run_loop(srv_proto.done)
- self.assertEqual(ret, 100)
- self.assertEqual(srv_proto.nbytes, 100)
- self.assertEqual(srv_proto.data, self.DATA[1000:1100])
- self.assertEqual(self.file.tell(), 1100)
-
- def test_sendfile_close_peer_after_receiving(self):
- srv_proto, cli_proto = self.prepare_sendfile(
- close_after=len(self.DATA))
- ret = self.run_loop(
- self.loop.sendfile(cli_proto.transport, self.file))
- cli_proto.transport.close()
- self.run_loop(srv_proto.done)
- self.assertEqual(ret, len(self.DATA))
- self.assertEqual(srv_proto.nbytes, len(self.DATA))
- self.assertEqual(srv_proto.data, self.DATA)
- self.assertEqual(self.file.tell(), len(self.DATA))
-
- def test_sendfile_ssl_close_peer_after_receiving(self):
- srv_proto, cli_proto = self.prepare_sendfile(
- is_ssl=True, close_after=len(self.DATA))
- ret = self.run_loop(
- self.loop.sendfile(cli_proto.transport, self.file))
- self.run_loop(srv_proto.done)
- self.assertEqual(ret, len(self.DATA))
- self.assertEqual(srv_proto.nbytes, len(self.DATA))
- self.assertEqual(srv_proto.data, self.DATA)
- self.assertEqual(self.file.tell(), len(self.DATA))
-
- def test_sendfile_close_peer_in_the_middle_of_receiving(self):
- srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
- with self.assertRaises(ConnectionError):
- self.run_loop(
- self.loop.sendfile(cli_proto.transport, self.file))
- self.run_loop(srv_proto.done)
-
- self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
- srv_proto.nbytes)
- self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
- self.file.tell())
- self.assertTrue(cli_proto.transport.is_closing())
-
- def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):
-
- def sendfile_native(transp, file, offset, count):
- # to raise SendfileNotAvailableError
- return base_events.BaseEventLoop._sendfile_native(
- self.loop, transp, file, offset, count)
-
- self.loop._sendfile_native = sendfile_native
-
- srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
- with self.assertRaises(ConnectionError):
- self.run_loop(
- self.loop.sendfile(cli_proto.transport, self.file))
- self.run_loop(srv_proto.done)
-
- self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
- srv_proto.nbytes)
- self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
- self.file.tell())
-
- @unittest.skipIf(not hasattr(os, 'sendfile'),
- "Don't have native sendfile support")
- def test_sendfile_prevents_bare_write(self):
- srv_proto, cli_proto = self.prepare_sendfile()
- fut = self.loop.create_future()
-
- async def coro():
- fut.set_result(None)
- return await self.loop.sendfile(cli_proto.transport, self.file)
-
- t = self.loop.create_task(coro())
- self.run_loop(fut)
- with self.assertRaisesRegex(RuntimeError,
- "sendfile is in progress"):
- cli_proto.transport.write(b'data')
- ret = self.run_loop(t)
- self.assertEqual(ret, len(self.DATA))
-
- def test_sendfile_no_fallback_for_fallback_transport(self):
- transport = mock.Mock()
- transport.is_closing.side_effect = lambda: False
- transport._sendfile_compatible = constants._SendfileMode.FALLBACK
- with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'):
- self.loop.run_until_complete(
- self.loop.sendfile(transport, None, fallback=False))
-
-
if sys.platform == 'win32':
class SelectEventLoopTests(EventLoopTestsMixin,
- SendfileMixin,
- SockSendfileMixin,
test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop()
class ProactorEventLoopTests(EventLoopTestsMixin,
- SendfileMixin,
- SockSendfileMixin,
SubprocessTestsMixin,
test_utils.TestCase):
@@ -2469,9 +2022,7 @@ if sys.platform == 'win32':
else:
import selectors
- class UnixEventLoopTestsMixin(EventLoopTestsMixin,
- SendfileMixin,
- SockSendfileMixin):
+ class UnixEventLoopTestsMixin(EventLoopTestsMixin):
def setUp(self):
super().setUp()
watcher = asyncio.SafeChildWatcher()
diff --git a/Lib/test/test_asyncio/test_sendfile.py b/Lib/test/test_asyncio/test_sendfile.py
new file mode 100644
index 0000000..26e44a3
--- /dev/null
+++ b/Lib/test/test_asyncio/test_sendfile.py
@@ -0,0 +1,550 @@
+"""Tests for sendfile functionality."""
+
+import asyncio
+import os
+import socket
+import sys
+import tempfile
+import unittest
+from asyncio import base_events
+from asyncio import constants
+from unittest import mock
+from test import support
+from test.test_asyncio import utils as test_utils
+
+try:
+ import ssl
+except ImportError:
+ ssl = None
+
+
+class MySendfileProto(asyncio.Protocol):
+
+ def __init__(self, loop=None, close_after=0):
+ self.transport = None
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if loop is not None:
+ self.connected = loop.create_future()
+ self.done = loop.create_future()
+ self.data = bytearray()
+ self.close_after = close_after
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ if self.connected:
+ self.connected.set_result(None)
+
+ def eof_received(self):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'EOF'
+
+ def connection_lost(self, exc):
+ assert self.state in ('CONNECTED', 'EOF'), self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+ def data_received(self, data):
+ assert self.state == 'CONNECTED', self.state
+ self.nbytes += len(data)
+ self.data.extend(data)
+ super().data_received(data)
+ if self.close_after and self.nbytes >= self.close_after:
+ self.transport.close()
+
+
+class MyProto(asyncio.Protocol):
+
+ def __init__(self, loop):
+ self.started = False
+ self.closed = False
+ self.data = bytearray()
+ self.fut = loop.create_future()
+ self.transport = None
+
+ def connection_made(self, transport):
+ self.started = True
+ self.transport = transport
+
+ def data_received(self, data):
+ self.data.extend(data)
+
+ def connection_lost(self, exc):
+ self.closed = True
+ self.fut.set_result(None)
+
+ async def wait_closed(self):
+ await self.fut
+
+
+class SendfileBase:
+
+ DATA = b"SendfileBaseData" * (1024 * 8) # 128 KiB
+
+ # Reduce socket buffer size to test on relative small data sets.
+ BUF_SIZE = 4 * 1024 # 4 KiB
+
+ def create_event_loop(self):
+ raise NotImplementedError
+
+ @classmethod
+ def setUpClass(cls):
+ with open(support.TESTFN, 'wb') as fp:
+ fp.write(cls.DATA)
+ super().setUpClass()
+
+ @classmethod
+ def tearDownClass(cls):
+ support.unlink(support.TESTFN)
+ super().tearDownClass()
+
+ def setUp(self):
+ self.file = open(support.TESTFN, 'rb')
+ self.addCleanup(self.file.close)
+ self.loop = self.create_event_loop()
+ self.set_event_loop(self.loop)
+ super().setUp()
+
+ def tearDown(self):
+ # just in case if we have transport close callbacks
+ if not self.loop.is_closed():
+ test_utils.run_briefly(self.loop)
+
+ self.doCleanups()
+ support.gc_collect()
+ super().tearDown()
+
+ def run_loop(self, coro):
+ return self.loop.run_until_complete(coro)
+
+
+class SockSendfileMixin(SendfileBase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE
+ constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16
+ super().setUpClass()
+
+ @classmethod
+ def tearDownClass(cls):
+ constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize
+ super().tearDownClass()
+
+ def make_socket(self, cleanup=True):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.setblocking(False)
+ if cleanup:
+ self.addCleanup(sock.close)
+ return sock
+
+ def reduce_receive_buffer_size(self, sock):
+ # Reduce receive socket buffer size to test on relative
+ # small data sets.
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE)
+
+ def reduce_send_buffer_size(self, sock, transport=None):
+ # Reduce send socket buffer size to test on relative small data sets.
+
+ # On macOS, SO_SNDBUF is reset by connect(). So this method
+ # should be called after the socket is connected.
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE)
+
+ if transport is not None:
+ transport.set_write_buffer_limits(high=self.BUF_SIZE)
+
+ def prepare_socksendfile(self):
+ proto = MyProto(self.loop)
+ port = support.find_unused_port()
+ srv_sock = self.make_socket(cleanup=False)
+ srv_sock.bind((support.HOST, port))
+ server = self.run_loop(self.loop.create_server(
+ lambda: proto, sock=srv_sock))
+ self.reduce_receive_buffer_size(srv_sock)
+
+ sock = self.make_socket()
+ self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port)))
+ self.reduce_send_buffer_size(sock)
+
+ def cleanup():
+ if proto.transport is not None:
+ # can be None if the task was cancelled before
+ # connection_made callback
+ proto.transport.close()
+ self.run_loop(proto.wait_closed())
+
+ server.close()
+ self.run_loop(server.wait_closed())
+
+ self.addCleanup(cleanup)
+
+ return sock, proto
+
+ def test_sock_sendfile_success(self):
+ sock, proto = self.prepare_socksendfile()
+ ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
+ sock.close()
+ self.run_loop(proto.wait_closed())
+
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sock_sendfile_with_offset_and_count(self):
+ sock, proto = self.prepare_socksendfile()
+ ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
+ 1000, 2000))
+ sock.close()
+ self.run_loop(proto.wait_closed())
+
+ self.assertEqual(proto.data, self.DATA[1000:3000])
+ self.assertEqual(self.file.tell(), 3000)
+ self.assertEqual(ret, 2000)
+
+ def test_sock_sendfile_zero_size(self):
+ sock, proto = self.prepare_socksendfile()
+ with tempfile.TemporaryFile() as f:
+ ret = self.run_loop(self.loop.sock_sendfile(sock, f,
+ 0, None))
+ sock.close()
+ self.run_loop(proto.wait_closed())
+
+ self.assertEqual(ret, 0)
+ self.assertEqual(self.file.tell(), 0)
+
+ def test_sock_sendfile_mix_with_regular_send(self):
+ buf = b"mix_regular_send" * (4 * 1024) # 64 KiB
+ sock, proto = self.prepare_socksendfile()
+ self.run_loop(self.loop.sock_sendall(sock, buf))
+ ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
+ self.run_loop(self.loop.sock_sendall(sock, buf))
+ sock.close()
+ self.run_loop(proto.wait_closed())
+
+ self.assertEqual(ret, len(self.DATA))
+ expected = buf + self.DATA + buf
+ self.assertEqual(proto.data, expected)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+
+class SendfileMixin(SendfileBase):
+
+ # Note: sendfile via SSL transport is equal to sendfile fallback
+
+ def prepare_sendfile(self, *, is_ssl=False, close_after=0):
+ port = support.find_unused_port()
+ srv_proto = MySendfileProto(loop=self.loop,
+ close_after=close_after)
+ if is_ssl:
+ if not ssl:
+ self.skipTest("No ssl module")
+ srv_ctx = test_utils.simple_server_sslcontext()
+ cli_ctx = test_utils.simple_client_sslcontext()
+ else:
+ srv_ctx = None
+ cli_ctx = None
+ srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ srv_sock.bind((support.HOST, port))
+ server = self.run_loop(self.loop.create_server(
+ lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
+ self.reduce_receive_buffer_size(srv_sock)
+
+ if is_ssl:
+ server_hostname = support.HOST
+ else:
+ server_hostname = None
+ cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ cli_sock.connect((support.HOST, port))
+
+ cli_proto = MySendfileProto(loop=self.loop)
+ tr, pr = self.run_loop(self.loop.create_connection(
+ lambda: cli_proto, sock=cli_sock,
+ ssl=cli_ctx, server_hostname=server_hostname))
+ self.reduce_send_buffer_size(cli_sock, transport=tr)
+
+ def cleanup():
+ srv_proto.transport.close()
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.run_loop(cli_proto.done)
+
+ server.close()
+ self.run_loop(server.wait_closed())
+
+ self.addCleanup(cleanup)
+ return srv_proto, cli_proto
+
+ @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
+ def test_sendfile_not_supported(self):
+ tr, pr = self.run_loop(
+ self.loop.create_datagram_endpoint(
+ asyncio.DatagramProtocol,
+ family=socket.AF_INET))
+ try:
+ with self.assertRaisesRegex(RuntimeError, "not supported"):
+ self.run_loop(
+ self.loop.sendfile(tr, self.file))
+ self.assertEqual(0, self.file.tell())
+ finally:
+ # don't use self.addCleanup because it produces resource warning
+ tr.close()
+
+ def test_sendfile(self):
+ srv_proto, cli_proto = self.prepare_sendfile()
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.nbytes, len(self.DATA))
+ self.assertEqual(srv_proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_force_fallback(self):
+ srv_proto, cli_proto = self.prepare_sendfile()
+
+ def sendfile_native(transp, file, offset, count):
+ # to raise SendfileNotAvailableError
+ return base_events.BaseEventLoop._sendfile_native(
+ self.loop, transp, file, offset, count)
+
+ self.loop._sendfile_native = sendfile_native
+
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.nbytes, len(self.DATA))
+ self.assertEqual(srv_proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_force_unsupported_native(self):
+ if sys.platform == 'win32':
+ if isinstance(self.loop, asyncio.ProactorEventLoop):
+ self.skipTest("Fails on proactor event loop")
+ srv_proto, cli_proto = self.prepare_sendfile()
+
+ def sendfile_native(transp, file, offset, count):
+ # to raise SendfileNotAvailableError
+ return base_events.BaseEventLoop._sendfile_native(
+ self.loop, transp, file, offset, count)
+
+ self.loop._sendfile_native = sendfile_native
+
+ with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
+ "not supported"):
+ self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file,
+ fallback=False))
+
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(srv_proto.nbytes, 0)
+ self.assertEqual(self.file.tell(), 0)
+
+ def test_sendfile_ssl(self):
+ srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.nbytes, len(self.DATA))
+ self.assertEqual(srv_proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_for_closing_transp(self):
+ srv_proto, cli_proto = self.prepare_sendfile()
+ cli_proto.transport.close()
+ with self.assertRaisesRegex(RuntimeError, "is closing"):
+ self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
+ self.run_loop(srv_proto.done)
+ self.assertEqual(srv_proto.nbytes, 0)
+ self.assertEqual(self.file.tell(), 0)
+
+ def test_sendfile_pre_and_post_data(self):
+ srv_proto, cli_proto = self.prepare_sendfile()
+ PREFIX = b'PREFIX__' * 1024 # 8 KiB
+ SUFFIX = b'--SUFFIX' * 1024 # 8 KiB
+ cli_proto.transport.write(PREFIX)
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.write(SUFFIX)
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_ssl_pre_and_post_data(self):
+ srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
+ PREFIX = b'zxcvbnm' * 1024
+ SUFFIX = b'0987654321' * 1024
+ cli_proto.transport.write(PREFIX)
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.write(SUFFIX)
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_partial(self):
+ srv_proto, cli_proto = self.prepare_sendfile()
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, 100)
+ self.assertEqual(srv_proto.nbytes, 100)
+ self.assertEqual(srv_proto.data, self.DATA[1000:1100])
+ self.assertEqual(self.file.tell(), 1100)
+
+ def test_sendfile_ssl_partial(self):
+ srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, 100)
+ self.assertEqual(srv_proto.nbytes, 100)
+ self.assertEqual(srv_proto.data, self.DATA[1000:1100])
+ self.assertEqual(self.file.tell(), 1100)
+
+ def test_sendfile_close_peer_after_receiving(self):
+ srv_proto, cli_proto = self.prepare_sendfile(
+ close_after=len(self.DATA))
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.nbytes, len(self.DATA))
+ self.assertEqual(srv_proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_ssl_close_peer_after_receiving(self):
+ srv_proto, cli_proto = self.prepare_sendfile(
+ is_ssl=True, close_after=len(self.DATA))
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.nbytes, len(self.DATA))
+ self.assertEqual(srv_proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_close_peer_in_the_middle_of_receiving(self):
+ srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
+ with self.assertRaises(ConnectionError):
+ self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ self.run_loop(srv_proto.done)
+
+ self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
+ srv_proto.nbytes)
+ self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
+ self.file.tell())
+ self.assertTrue(cli_proto.transport.is_closing())
+
+ def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):
+
+ def sendfile_native(transp, file, offset, count):
+ # to raise SendfileNotAvailableError
+ return base_events.BaseEventLoop._sendfile_native(
+ self.loop, transp, file, offset, count)
+
+ self.loop._sendfile_native = sendfile_native
+
+ srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
+ with self.assertRaises(ConnectionError):
+ self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ self.run_loop(srv_proto.done)
+
+ self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
+ srv_proto.nbytes)
+ self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
+ self.file.tell())
+
+ @unittest.skipIf(not hasattr(os, 'sendfile'),
+ "Don't have native sendfile support")
+ def test_sendfile_prevents_bare_write(self):
+ srv_proto, cli_proto = self.prepare_sendfile()
+ fut = self.loop.create_future()
+
+ async def coro():
+ fut.set_result(None)
+ return await self.loop.sendfile(cli_proto.transport, self.file)
+
+ t = self.loop.create_task(coro())
+ self.run_loop(fut)
+ with self.assertRaisesRegex(RuntimeError,
+ "sendfile is in progress"):
+ cli_proto.transport.write(b'data')
+ ret = self.run_loop(t)
+ self.assertEqual(ret, len(self.DATA))
+
+ def test_sendfile_no_fallback_for_fallback_transport(self):
+ transport = mock.Mock()
+ transport.is_closing.side_effect = lambda: False
+ transport._sendfile_compatible = constants._SendfileMode.FALLBACK
+ with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'):
+ self.loop.run_until_complete(
+ self.loop.sendfile(transport, None, fallback=False))
+
+
+class SendfileTestsBase(SendfileMixin, SockSendfileMixin):
+ pass
+
+
+if sys.platform == 'win32':
+
+ class SelectEventLoopTests(SendfileTestsBase,
+ test_utils.TestCase):
+
+ def create_event_loop(self):
+ return asyncio.SelectorEventLoop()
+
+ class ProactorEventLoopTests(SendfileTestsBase,
+ test_utils.TestCase):
+
+ def create_event_loop(self):
+ return asyncio.ProactorEventLoop()
+
+else:
+ import selectors
+
+ if hasattr(selectors, 'KqueueSelector'):
+ class KqueueEventLoopTests(SendfileTestsBase,
+ test_utils.TestCase):
+
+ def create_event_loop(self):
+ return asyncio.SelectorEventLoop(
+ selectors.KqueueSelector())
+
+ if hasattr(selectors, 'EpollSelector'):
+ class EPollEventLoopTests(SendfileTestsBase,
+ test_utils.TestCase):
+
+ def create_event_loop(self):
+ return asyncio.SelectorEventLoop(selectors.EpollSelector())
+
+ if hasattr(selectors, 'PollSelector'):
+ class PollEventLoopTests(SendfileTestsBase,
+ test_utils.TestCase):
+
+ def create_event_loop(self):
+ return asyncio.SelectorEventLoop(selectors.PollSelector())
+
+ # Should always exist.
+ class SelectEventLoopTests(SendfileTestsBase,
+ test_utils.TestCase):
+
+ def create_event_loop(self):
+ return asyncio.SelectorEventLoop(selectors.SelectSelector())