summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Doc/library/asyncio-eventloop.rst31
-rw-r--r--Lib/asyncio/base_events.py142
-rw-r--r--Lib/asyncio/constants.py9
-rw-r--r--Lib/asyncio/events.py8
-rw-r--r--Lib/asyncio/proactor_events.py9
-rw-r--r--Lib/asyncio/selector_events.py39
-rw-r--r--Lib/asyncio/sslproto.py7
-rw-r--r--Lib/asyncio/windows_events.py9
-rw-r--r--Lib/test/test_asyncio/test_base_events.py9
-rw-r--r--Lib/test/test_asyncio/test_events.py303
-rw-r--r--Misc/NEWS.d/next/Library/2018-01-22-18-18-44.bpo-32622.A1D6FP.rst1
-rw-r--r--Modules/overlapped.c1
12 files changed, 560 insertions, 8 deletions
diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst
index 834a4e8..fe16223 100644
--- a/Doc/library/asyncio-eventloop.rst
+++ b/Doc/library/asyncio-eventloop.rst
@@ -543,6 +543,37 @@ Creating listening connections
.. versionadded:: 3.5.3
+File Transferring
+-----------------
+
+.. coroutinemethod:: AbstractEventLoop.sendfile(sock, transport, \
+ offset=0, count=None, \
+ *, fallback=True)
+
+ Send a *file* to *transport*, return the total number of bytes
+ which were sent.
+
+ The method uses high-performance :meth:`os.sendfile` if available.
+
+ *file* must be a regular file object opened in binary mode.
+
+ *offset* tells from where to start reading the file. If specified,
+ *count* is the total number of bytes to transmit as opposed to
+ sending the file until EOF is reached. File position is updated on
+ return or also in case of error in which case :meth:`file.tell()
+ <io.IOBase.tell>` can be used to figure out the number of bytes
+ which were sent.
+
+ *fallback* set to ``True`` makes asyncio to manually read and send
+ the file when the platform does not support the sendfile syscall
+ (e.g. Windows or SSL socket on Unix).
+
+ Raise :exc:`SendfileNotAvailableError` if the system does not support
+ *sendfile* syscall and *fallback* is ``False``.
+
+ .. versionadded:: 3.7
+
+
TLS Upgrade
-----------
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 94eb308..f532dc4 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -38,8 +38,10 @@ from . import constants
from . import coroutines
from . import events
from . import futures
+from . import protocols
from . import sslproto
from . import tasks
+from . import transports
from .log import logger
@@ -155,6 +157,75 @@ def _run_until_complete_cb(fut):
futures._get_loop(fut).stop()
+
+class _SendfileFallbackProtocol(protocols.Protocol):
+ def __init__(self, transp):
+ if not isinstance(transp, transports._FlowControlMixin):
+ raise TypeError("transport should be _FlowControlMixin instance")
+ self._transport = transp
+ self._proto = transp.get_protocol()
+ self._should_resume_reading = transp.is_reading()
+ self._should_resume_writing = transp._protocol_paused
+ transp.pause_reading()
+ transp.set_protocol(self)
+ if self._should_resume_writing:
+ self._write_ready_fut = self._transport._loop.create_future()
+ else:
+ self._write_ready_fut = None
+
+ async def drain(self):
+ if self._transport.is_closing():
+ raise ConnectionError("Connection closed by peer")
+ fut = self._write_ready_fut
+ if fut is None:
+ return
+ await fut
+
+ def connection_made(self, transport):
+ raise RuntimeError("Invalid state: "
+ "connection should have been established already.")
+
+ def connection_lost(self, exc):
+ if self._write_ready_fut is not None:
+ # Never happens if peer disconnects after sending the whole content
+ # Thus disconnection is always an exception from user perspective
+ if exc is None:
+ self._write_ready_fut.set_exception(
+ ConnectionError("Connection is closed by peer"))
+ else:
+ self._write_ready_fut.set_exception(exc)
+ self._proto.connection_lost(exc)
+
+ def pause_writing(self):
+ if self._write_ready_fut is not None:
+ return
+ self._write_ready_fut = self._transport._loop.create_future()
+
+ def resume_writing(self):
+ if self._write_ready_fut is None:
+ return
+ self._write_ready_fut.set_result(False)
+ self._write_ready_fut = None
+
+ def data_received(self, data):
+ raise RuntimeError("Invalid state: reading should be paused")
+
+ def eof_received(self):
+ raise RuntimeError("Invalid state: reading should be paused")
+
+ async def restore(self):
+ self._transport.set_protocol(self._proto)
+ if self._should_resume_reading:
+ self._transport.resume_reading()
+ if self._write_ready_fut is not None:
+ # Cancel the future.
+ # Basically it has no effect because protocol is switched back,
+ # no code should wait for it anymore.
+ self._write_ready_fut.cancel()
+ if self._should_resume_writing:
+ self._proto.resume_writing()
+
+
class Server(events.AbstractServer):
def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
@@ -926,6 +997,77 @@ class BaseEventLoop(events.AbstractEventLoop):
return transport, protocol
+ async def sendfile(self, transport, file, offset=0, count=None,
+ *, fallback=True):
+ """Send a file to transport.
+
+ Return the total number of bytes which were sent.
+
+ The method uses high-performance os.sendfile if available.
+
+ file must be a regular file object opened in binary mode.
+
+ offset tells from where to start reading the file. If specified,
+ count is the total number of bytes to transmit as opposed to
+ sending the file until EOF is reached. File position is updated on
+ return or also in case of error in which case file.tell()
+ can be used to figure out the number of bytes
+ which were sent.
+
+ fallback set to True makes asyncio to manually read and send
+ the file when the platform does not support the sendfile syscall
+ (e.g. Windows or SSL socket on Unix).
+
+ Raise SendfileNotAvailableError if the system does not support
+ sendfile syscall and fallback is False.
+ """
+ if transport.is_closing():
+ raise RuntimeError("Transport is closing")
+ mode = getattr(transport, '_sendfile_compatible',
+ constants._SendfileMode.UNSUPPORTED)
+ if mode is constants._SendfileMode.UNSUPPORTED:
+ raise RuntimeError(
+ f"sendfile is not supported for transport {transport!r}")
+ if mode is constants._SendfileMode.TRY_NATIVE:
+ try:
+ return await self._sendfile_native(transport, file,
+ offset, count)
+ except events.SendfileNotAvailableError as exc:
+ if not fallback:
+ raise
+ # the mode is FALLBACK or fallback is True
+ return await self._sendfile_fallback(transport, file,
+ offset, count)
+
+ async def _sendfile_native(self, transp, file, offset, count):
+ raise events.SendfileNotAvailableError(
+ "sendfile syscall is not supported")
+
+ async def _sendfile_fallback(self, transp, file, offset, count):
+ if offset:
+ file.seek(offset)
+ blocksize = min(count, 16384) if count else 16384
+ buf = bytearray(blocksize)
+ total_sent = 0
+ proto = _SendfileFallbackProtocol(transp)
+ try:
+ while True:
+ if count:
+ blocksize = min(count - total_sent, blocksize)
+ if blocksize <= 0:
+ return total_sent
+ view = memoryview(buf)[:blocksize]
+ read = file.readinto(view)
+ if not read:
+ return total_sent # EOF
+ await proto.drain()
+ transp.write(view)
+ total_sent += read
+ finally:
+ if total_sent > 0 and hasattr(file, 'seek'):
+ file.seek(offset + total_sent)
+ await proto.restore()
+
async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
diff --git a/Lib/asyncio/constants.py b/Lib/asyncio/constants.py
index 0ad974f..739b0a7 100644
--- a/Lib/asyncio/constants.py
+++ b/Lib/asyncio/constants.py
@@ -1,3 +1,5 @@
+import enum
+
# After the connection is lost, log warnings after this many write()s.
LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5
@@ -11,3 +13,10 @@ DEBUG_STACK_DEPTH = 10
# Number of seconds to wait for SSL handshake to complete
SSL_HANDSHAKE_TIMEOUT = 10.0
+
+# The enum should be here to break circular dependencies between
+# base_events and sslproto
+class _SendfileMode(enum.Enum):
+ UNSUPPORTED = enum.auto()
+ TRY_NATIVE = enum.auto()
+ FALLBACK = enum.auto()
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
index 7aa3de0..bdefcf6 100644
--- a/Lib/asyncio/events.py
+++ b/Lib/asyncio/events.py
@@ -354,6 +354,14 @@ class AbstractEventLoop:
"""
raise NotImplementedError
+ async def sendfile(self, transport, file, offset=0, count=None,
+ *, fallback=True):
+ """Send a file through a transport.
+
+ Return an amount of sent bytes.
+ """
+ raise NotImplementedError
+
async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py
index ab1285b..6d27e53 100644
--- a/Lib/asyncio/proactor_events.py
+++ b/Lib/asyncio/proactor_events.py
@@ -180,7 +180,12 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
assert self._read_fut is fut or (self._read_fut is None and
self._closing)
self._read_fut = None
- data = fut.result() # deliver data later in "finally" clause
+ if fut.done():
+ # deliver data later in "finally" clause
+ data = fut.result()
+ else:
+ # the future will be replaced by next proactor.recv call
+ fut.cancel()
if self._closing:
# since close() has been called we ignore any read data
@@ -345,6 +350,8 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport,
transports.Transport):
"""Transport for connected sockets."""
+ _sendfile_compatible = constants._SendfileMode.FALLBACK
+
def _set_extra(self, sock):
self._extra['socket'] = sock
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
index 9446ae6..5956f2d 100644
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -540,6 +540,20 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
else:
fut.set_result((conn, address))
+ async def _sendfile_native(self, transp, file, offset, count):
+ del self._transports[transp._sock_fd]
+ resume_reading = transp.is_reading()
+ transp.pause_reading()
+ await transp._make_empty_waiter()
+ try:
+ return await self.sock_sendfile(transp._sock, file, offset, count,
+ fallback=False)
+ finally:
+ transp._reset_empty_waiter()
+ if resume_reading:
+ transp.resume_reading()
+ self._transports[transp._sock_fd] = transp
+
def _process_events(self, event_list):
for key, mask in event_list:
fileobj, (reader, writer) = key.fileobj, key.data
@@ -695,12 +709,14 @@ class _SelectorTransport(transports._FlowControlMixin,
class _SelectorSocketTransport(_SelectorTransport):
_start_tls_compatible = True
+ _sendfile_compatible = constants._SendfileMode.TRY_NATIVE
def __init__(self, loop, sock, protocol, waiter=None,
extra=None, server=None):
super().__init__(loop, sock, protocol, extra, server)
self._eof = False
self._paused = False
+ self._empty_waiter = None
# Disable the Nagle algorithm -- small writes will be
# sent without waiting for the TCP ACK. This generally
@@ -765,6 +781,8 @@ class _SelectorSocketTransport(_SelectorTransport):
f'not {type(data).__name__!r}')
if self._eof:
raise RuntimeError('Cannot call write() after write_eof()')
+ if self._empty_waiter is not None:
+ raise RuntimeError('unable to write; sendfile is in progress')
if not data:
return
@@ -807,12 +825,16 @@ class _SelectorSocketTransport(_SelectorTransport):
self._loop._remove_writer(self._sock_fd)
self._buffer.clear()
self._fatal_error(exc, 'Fatal write error on socket transport')
+ if self._empty_waiter is not None:
+ self._empty_waiter.set_exception(exc)
else:
if n:
del self._buffer[:n]
self._maybe_resume_protocol() # May append to buffer.
if not self._buffer:
self._loop._remove_writer(self._sock_fd)
+ if self._empty_waiter is not None:
+ self._empty_waiter.set_result(None)
if self._closing:
self._call_connection_lost(None)
elif self._eof:
@@ -828,6 +850,23 @@ class _SelectorSocketTransport(_SelectorTransport):
def can_write_eof(self):
return True
+ def _call_connection_lost(self, exc):
+ super()._call_connection_lost(exc)
+ if self._empty_waiter is not None:
+ self._empty_waiter.set_exception(
+ ConnectionError("Connection is closed by peer"))
+
+ def _make_empty_waiter(self):
+ if self._empty_waiter is not None:
+ raise RuntimeError("Empty waiter is already set")
+ self._empty_waiter = self._loop.create_future()
+ if not self._buffer:
+ self._empty_waiter.set_result(None)
+ return self._empty_waiter
+
+ def _reset_empty_waiter(self):
+ self._empty_waiter = None
+
class _SelectorDatagramTransport(_SelectorTransport):
diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py
index 1130bce..863b543 100644
--- a/Lib/asyncio/sslproto.py
+++ b/Lib/asyncio/sslproto.py
@@ -282,6 +282,8 @@ class _SSLPipe(object):
class _SSLProtocolTransport(transports._FlowControlMixin,
transports.Transport):
+ _sendfile_compatible = constants._SendfileMode.FALLBACK
+
def __init__(self, loop, ssl_protocol):
self._loop = loop
# SSLProtocol instance
@@ -365,6 +367,11 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
"""Return the current size of the write buffer."""
return self._ssl_protocol._transport.get_write_buffer_size()
+ @property
+ def _protocol_paused(self):
+ # Required for sendfile fallback pause_writing/resume_writing logic
+ return self._ssl_protocol._transport._protocol_paused
+
def write(self, data):
"""Write some data bytes to the transport.
diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py
index 890fce8..f91fcdd 100644
--- a/Lib/asyncio/windows_events.py
+++ b/Lib/asyncio/windows_events.py
@@ -425,7 +425,8 @@ class IocpProactor:
try:
return ov.getresult()
except OSError as exc:
- if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
+ if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+ _overlapped.ERROR_OPERATION_ABORTED):
raise ConnectionResetError(*exc.args)
else:
raise
@@ -447,7 +448,8 @@ class IocpProactor:
try:
return ov.getresult()
except OSError as exc:
- if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
+ if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+ _overlapped.ERROR_OPERATION_ABORTED):
raise ConnectionResetError(*exc.args)
else:
raise
@@ -466,7 +468,8 @@ class IocpProactor:
try:
return ov.getresult()
except OSError as exc:
- if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
+ if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+ _overlapped.ERROR_OPERATION_ABORTED):
raise ConnectionResetError(*exc.args)
else:
raise
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
index 6489f50..ab6560c 100644
--- a/Lib/test/test_asyncio/test_base_events.py
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -1788,7 +1788,7 @@ class RunningLoopTests(unittest.TestCase):
outer_loop.close()
-class BaseLoopSendfileTests(test_utils.TestCase):
+class BaseLoopSockSendfileTests(test_utils.TestCase):
DATA = b"12345abcde" * 16 * 1024 # 160 KiB
@@ -1799,9 +1799,11 @@ class BaseLoopSendfileTests(test_utils.TestCase):
self.closed = False
self.data = bytearray()
self.fut = loop.create_future()
+ self.transport = None
def connection_made(self, transport):
self.started = True
+ self.transport = transport
def data_received(self, data):
self.data.extend(data)
@@ -1809,6 +1811,7 @@ class BaseLoopSendfileTests(test_utils.TestCase):
def connection_lost(self, exc):
self.closed = True
self.fut.set_result(None)
+ self.transport = None
async def wait_closed(self):
await self.fut
@@ -1853,6 +1856,10 @@ class BaseLoopSendfileTests(test_utils.TestCase):
def cleanup():
server.close()
self.run_loop(server.wait_closed())
+ sock.close()
+ if proto.transport is not None:
+ proto.transport.close()
+ self.run_loop(proto.wait_closed())
self.addCleanup(cleanup)
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
index cf21753..0981bd6 100644
--- a/Lib/test/test_asyncio/test_events.py
+++ b/Lib/test/test_asyncio/test_events.py
@@ -26,6 +26,7 @@ if sys.platform != 'win32':
import tty
import asyncio
+from asyncio import base_events
from asyncio import coroutines
from asyncio import events
from asyncio import proactor_events
@@ -2090,14 +2091,308 @@ class SubprocessTestsMixin:
self.loop.run_until_complete(connect(shell=False))
+class MySendfileProto(MyBaseProto):
+
+ def __init__(self, loop=None, close_after=0):
+ super().__init__(loop)
+ self.data = bytearray()
+ self.close_after = close_after
+
+ def data_received(self, data):
+ self.data.extend(data)
+ super().data_received(data)
+ if self.close_after and self.nbytes >= self.close_after:
+ self.transport.close()
+
+
+class SendfileMixin:
+ # Note: sendfile via SSL transport is equal to sendfile fallback
+
+ DATA = b"12345abcde" * 160 * 1024 # 160 KiB
+
+ @classmethod
+ def setUpClass(cls):
+ with open(support.TESTFN, 'wb') as fp:
+ fp.write(cls.DATA)
+ super().setUpClass()
+
+ @classmethod
+ def tearDownClass(cls):
+ support.unlink(support.TESTFN)
+ super().tearDownClass()
+
+ def setUp(self):
+ self.file = open(support.TESTFN, 'rb')
+ self.addCleanup(self.file.close)
+ super().setUp()
+
+ def run_loop(self, coro):
+ return self.loop.run_until_complete(coro)
+
+ def prepare(self, *, is_ssl=False, close_after=0):
+ port = support.find_unused_port()
+ srv_proto = MySendfileProto(loop=self.loop, close_after=close_after)
+ if is_ssl:
+ srv_ctx = test_utils.simple_server_sslcontext()
+ cli_ctx = test_utils.simple_client_sslcontext()
+ else:
+ srv_ctx = None
+ cli_ctx = None
+ srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ # reduce recv socket buffer size to test on relative small data sets
+ srv_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
+ srv_sock.bind((support.HOST, port))
+ server = self.run_loop(self.loop.create_server(
+ lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
+
+ if is_ssl:
+ server_hostname = support.HOST
+ else:
+ server_hostname = None
+ cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ # reduce send socket buffer size to test on relative small data sets
+ cli_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
+ cli_sock.connect((support.HOST, port))
+ cli_proto = MySendfileProto(loop=self.loop)
+ tr, pr = self.run_loop(self.loop.create_connection(
+ lambda: cli_proto, sock=cli_sock,
+ ssl=cli_ctx, server_hostname=server_hostname))
+
+ def cleanup():
+ srv_proto.transport.close()
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.run_loop(cli_proto.done)
+
+ server.close()
+ self.run_loop(server.wait_closed())
+
+ self.addCleanup(cleanup)
+ return srv_proto, cli_proto
+
+ @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
+ def test_sendfile_not_supported(self):
+ tr, pr = self.run_loop(
+ self.loop.create_datagram_endpoint(
+ lambda: MyDatagramProto(loop=self.loop),
+ family=socket.AF_INET))
+ try:
+ with self.assertRaisesRegex(RuntimeError, "not supported"):
+ self.run_loop(
+ self.loop.sendfile(tr, self.file))
+ self.assertEqual(0, self.file.tell())
+ finally:
+ # don't use self.addCleanup because it produces resource warning
+ tr.close()
+
+ def test_sendfile(self):
+ srv_proto, cli_proto = self.prepare()
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.nbytes, len(self.DATA))
+ self.assertEqual(srv_proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_force_fallback(self):
+ srv_proto, cli_proto = self.prepare()
+
+ def sendfile_native(transp, file, offset, count):
+ # to raise SendfileNotAvailableError
+ return base_events.BaseEventLoop._sendfile_native(
+ self.loop, transp, file, offset, count)
+
+ self.loop._sendfile_native = sendfile_native
+
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.nbytes, len(self.DATA))
+ self.assertEqual(srv_proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_force_unsupported_native(self):
+ if sys.platform == 'win32':
+ if isinstance(self.loop, asyncio.ProactorEventLoop):
+ self.skipTest("Fails on proactor event loop")
+ srv_proto, cli_proto = self.prepare()
+
+ def sendfile_native(transp, file, offset, count):
+ # to raise SendfileNotAvailableError
+ return base_events.BaseEventLoop._sendfile_native(
+ self.loop, transp, file, offset, count)
+
+ self.loop._sendfile_native = sendfile_native
+
+ with self.assertRaisesRegex(events.SendfileNotAvailableError,
+ "not supported"):
+ self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file,
+ fallback=False))
+
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(srv_proto.nbytes, 0)
+ self.assertEqual(self.file.tell(), 0)
+
+ def test_sendfile_ssl(self):
+ srv_proto, cli_proto = self.prepare(is_ssl=True)
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.nbytes, len(self.DATA))
+ self.assertEqual(srv_proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_for_closing_transp(self):
+ srv_proto, cli_proto = self.prepare()
+ cli_proto.transport.close()
+ with self.assertRaisesRegex(RuntimeError, "is closing"):
+ self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
+ self.run_loop(srv_proto.done)
+ self.assertEqual(srv_proto.nbytes, 0)
+ self.assertEqual(self.file.tell(), 0)
+
+ def test_sendfile_pre_and_post_data(self):
+ srv_proto, cli_proto = self.prepare()
+ PREFIX = b'zxcvbnm' * 1024
+ SUFFIX = b'0987654321' * 1024
+ cli_proto.transport.write(PREFIX)
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.write(SUFFIX)
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_ssl_pre_and_post_data(self):
+ srv_proto, cli_proto = self.prepare(is_ssl=True)
+ PREFIX = b'zxcvbnm' * 1024
+ SUFFIX = b'0987654321' * 1024
+ cli_proto.transport.write(PREFIX)
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.write(SUFFIX)
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_partial(self):
+ srv_proto, cli_proto = self.prepare()
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, 100)
+ self.assertEqual(srv_proto.nbytes, 100)
+ self.assertEqual(srv_proto.data, self.DATA[1000:1100])
+ self.assertEqual(self.file.tell(), 1100)
+
+ def test_sendfile_ssl_partial(self):
+ srv_proto, cli_proto = self.prepare(is_ssl=True)
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, 100)
+ self.assertEqual(srv_proto.nbytes, 100)
+ self.assertEqual(srv_proto.data, self.DATA[1000:1100])
+ self.assertEqual(self.file.tell(), 1100)
+
+ def test_sendfile_close_peer_after_receiving(self):
+ srv_proto, cli_proto = self.prepare(close_after=len(self.DATA))
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ cli_proto.transport.close()
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.nbytes, len(self.DATA))
+ self.assertEqual(srv_proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_ssl_close_peer_after_receiving(self):
+ srv_proto, cli_proto = self.prepare(is_ssl=True,
+ close_after=len(self.DATA))
+ ret = self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ self.run_loop(srv_proto.done)
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(srv_proto.nbytes, len(self.DATA))
+ self.assertEqual(srv_proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sendfile_close_peer_in_middle_of_receiving(self):
+ srv_proto, cli_proto = self.prepare(close_after=1024)
+ with self.assertRaises(ConnectionError):
+ self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ self.run_loop(srv_proto.done)
+
+ self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
+ srv_proto.nbytes)
+ self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
+ self.file.tell())
+
+ def test_sendfile_fallback_close_peer_in_middle_of_receiving(self):
+
+ def sendfile_native(transp, file, offset, count):
+ # to raise SendfileNotAvailableError
+ return base_events.BaseEventLoop._sendfile_native(
+ self.loop, transp, file, offset, count)
+
+ self.loop._sendfile_native = sendfile_native
+
+ srv_proto, cli_proto = self.prepare(close_after=1024)
+ with self.assertRaises(ConnectionError):
+ self.run_loop(
+ self.loop.sendfile(cli_proto.transport, self.file))
+ self.run_loop(srv_proto.done)
+
+ self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
+ srv_proto.nbytes)
+ self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
+ self.file.tell())
+
+ @unittest.skipIf(not hasattr(os, 'sendfile'),
+ "Don't have native sendfile support")
+ def test_sendfile_prevents_bare_write(self):
+ srv_proto, cli_proto = self.prepare()
+ fut = self.loop.create_future()
+
+ async def coro():
+ fut.set_result(None)
+ return await self.loop.sendfile(cli_proto.transport, self.file)
+
+ t = self.loop.create_task(coro())
+ self.run_loop(fut)
+ with self.assertRaisesRegex(RuntimeError,
+ "sendfile is in progress"):
+ cli_proto.transport.write(b'data')
+ ret = self.run_loop(t)
+ self.assertEqual(ret, len(self.DATA))
+
+
if sys.platform == 'win32':
- class SelectEventLoopTests(EventLoopTestsMixin, test_utils.TestCase):
+ class SelectEventLoopTests(EventLoopTestsMixin,
+ SendfileMixin,
+ test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop()
class ProactorEventLoopTests(EventLoopTestsMixin,
+ SendfileMixin,
SubprocessTestsMixin,
test_utils.TestCase):
@@ -2125,7 +2420,7 @@ if sys.platform == 'win32':
else:
import selectors
- class UnixEventLoopTestsMixin(EventLoopTestsMixin):
+ class UnixEventLoopTestsMixin(EventLoopTestsMixin, SendfileMixin):
def setUp(self):
super().setUp()
watcher = asyncio.SafeChildWatcher()
@@ -2556,7 +2851,9 @@ class AbstractEventLoopTests(unittest.TestCase):
with self.assertRaises(NotImplementedError):
await loop.sock_accept(f)
with self.assertRaises(NotImplementedError):
- await loop.sock_sendfile(f, mock.Mock())
+ await loop.sock_sendfile(f, f)
+ with self.assertRaises(NotImplementedError):
+ await loop.sendfile(f, f)
with self.assertRaises(NotImplementedError):
await loop.connect_read_pipe(f, mock.sentinel.pipe)
with self.assertRaises(NotImplementedError):
diff --git a/Misc/NEWS.d/next/Library/2018-01-22-18-18-44.bpo-32622.A1D6FP.rst b/Misc/NEWS.d/next/Library/2018-01-22-18-18-44.bpo-32622.A1D6FP.rst
new file mode 100644
index 0000000..d7433fa
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2018-01-22-18-18-44.bpo-32622.A1D6FP.rst
@@ -0,0 +1 @@
+Add :meth:`asyncio.AbstractEventLoop.sendfile` method.
diff --git a/Modules/overlapped.c b/Modules/overlapped.c
index e66e856..447a337 100644
--- a/Modules/overlapped.c
+++ b/Modules/overlapped.c
@@ -1436,6 +1436,7 @@ PyInit__overlapped(void)
WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING);
WINAPI_CONSTANT(F_DWORD, ERROR_NETNAME_DELETED);
+ WINAPI_CONSTANT(F_DWORD, ERROR_OPERATION_ABORTED);
WINAPI_CONSTANT(F_DWORD, ERROR_SEM_TIMEOUT);
WINAPI_CONSTANT(F_DWORD, ERROR_PIPE_BUSY);
WINAPI_CONSTANT(F_DWORD, INFINITE);