diff options
author | Giampaolo Rodola' <g.rodola@gmail.com> | 2014-06-11 01:54:30 (GMT) |
---|---|---|
committer | Giampaolo Rodola' <g.rodola@gmail.com> | 2014-06-11 01:54:30 (GMT) |
commit | 915d14190ee07b66e29b385943a942a51c825ae6 (patch) | |
tree | c822a64395af65a3f22f02aa5b87bd39c37a10b4 | |
parent | b398d33c65636aed68246179a4f6ae4b3fbcf182 (diff) | |
download | cpython-915d14190ee07b66e29b385943a942a51c825ae6.zip cpython-915d14190ee07b66e29b385943a942a51c825ae6.tar.gz cpython-915d14190ee07b66e29b385943a942a51c825ae6.tar.bz2 |
fix issue #17552: add socket.sendfile() method allowing to send a file over a socket by using high-performance os.sendfile() on UNIX. Patch by Giampaolo Rodola'·
-rw-r--r-- | Doc/library/os.rst | 4 | ||||
-rw-r--r-- | Doc/library/socket.rst | 15 | ||||
-rw-r--r-- | Doc/library/ssl.rst | 3 | ||||
-rw-r--r-- | Doc/whatsnew/3.5.rst | 11 | ||||
-rw-r--r-- | Lib/socket.py | 148 | ||||
-rw-r--r-- | Lib/ssl.py | 10 | ||||
-rw-r--r-- | Lib/test/test_socket.py | 273 | ||||
-rw-r--r-- | Lib/test/test_ssl.py | 17 | ||||
-rw-r--r-- | Misc/NEWS | 4 |
9 files changed, 483 insertions, 2 deletions
diff --git a/Doc/library/os.rst b/Doc/library/os.rst index 2de4888..170d07b 100644 --- a/Doc/library/os.rst +++ b/Doc/library/os.rst @@ -1092,6 +1092,10 @@ or `the MSDN <http://msdn.microsoft.com/en-us/library/z0kc8e3z.aspx>`_ on Window Availability: Unix. + .. note:: + + For a higher-level version of this see :mod:`socket.socket.sendfile`. + .. versionadded:: 3.3 diff --git a/Doc/library/socket.rst b/Doc/library/socket.rst index 3a9d611..ceeb776 100644 --- a/Doc/library/socket.rst +++ b/Doc/library/socket.rst @@ -1148,6 +1148,21 @@ to sockets. .. versionadded:: 3.3 +.. method:: socket.sendfile(file, offset=0, count=None) + + Send a file until EOF is reached by using high-performance + :mod:`os.sendfile` and return the total number of bytes which were sent. + *file* must be a regular file object opened in binary mode. If + :mod:`os.sendfile` is not available (e.g. Windows) or *file* is not a + regular file :meth:`send` will be used instead. *offset* tells from where to + start reading the file. If specified, *count* is the total number of bytes + to transmit as opposed to sending the file until EOF is reached. File + position is updated on return or also in case of error in which case + :meth:`file.tell() <io.IOBase.tell>` can be used to figure out the number of + bytes which were sent. The socket must be of :const:`SOCK_STREAM` type. Non- + blocking sockets are not supported. + + .. versionadded:: 3.5 .. method:: socket.set_inheritable(inheritable) diff --git a/Doc/library/ssl.rst b/Doc/library/ssl.rst index 0b0edd8..6eb47e3 100644 --- a/Doc/library/ssl.rst +++ b/Doc/library/ssl.rst @@ -789,6 +789,9 @@ SSL sockets provide the following methods of :ref:`socket-objects`: (but passing a non-zero ``flags`` argument is not allowed) - :meth:`~socket.socket.send()`, :meth:`~socket.socket.sendall()` (with the same limitation) +- :meth:`~socket.socket.sendfile()` (but :mod:`os.sendfile` will be used + for plain-text sockets only, else :meth:`~socket.socket.send()` will be used) + .. versionadded:: 3.5 - :meth:`~socket.socket.shutdown()` However, since the SSL (and TLS) protocol has its own framing atop diff --git a/Doc/whatsnew/3.5.rst b/Doc/whatsnew/3.5.rst index 5c4e0b5..af734c8 100644 --- a/Doc/whatsnew/3.5.rst +++ b/Doc/whatsnew/3.5.rst @@ -181,9 +181,18 @@ signal * Different constants of :mod:`signal` module are now enumeration values using the :mod:`enum` module. This allows meaningful names to be printed during - debugging, instead of integer “magic numbers”. (contribute by Giampaolo + debugging, instead of integer “magic numbers”. (contributed by Giampaolo Rodola' in :issue:`21076`) +socket +------ + +* New :meth:`socket.socket.sendfile` method allows to send a file over a socket + by using high-performance :func:`os.sendfile` function on UNIX resulting in + uploads being from 2x to 3x faster than when using plain + :meth:`socket.socket.send`. + (contributed by Giampaolo Rodola' in :issue:`17552`) + xmlrpc ------ diff --git a/Lib/socket.py b/Lib/socket.py index 6d67b3d..cbadff7 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -47,7 +47,7 @@ the setsockopt() and getsockopt() methods. import _socket from _socket import * -import os, sys, io +import os, sys, io, selectors from enum import IntEnum try: @@ -109,6 +109,9 @@ if sys.platform.lower().startswith("win"): __all__.append("errorTab") +class _GiveupOnSendfile(Exception): pass + + class socket(_socket.socket): """A subclass of _socket.socket adding the makefile() method.""" @@ -233,6 +236,149 @@ class socket(_socket.socket): text.mode = mode return text + if hasattr(os, 'sendfile'): + + def _sendfile_use_sendfile(self, file, offset=0, count=None): + self._check_sendfile_params(file, offset, count) + sockno = self.fileno() + try: + fileno = file.fileno() + except (AttributeError, io.UnsupportedOperation) as err: + raise _GiveupOnSendfile(err) # not a regular file + try: + fsize = os.fstat(fileno).st_size + except OSError: + raise _GiveupOnSendfile(err) # not a regular file + if not fsize: + return 0 # empty file + blocksize = fsize if not count else count + + timeout = self.gettimeout() + if timeout == 0: + raise ValueError("non-blocking sockets are not supported") + # poll/select have the advantage of not requiring any + # extra file descriptor, contrarily to epoll/kqueue + # (also, they require a single syscall). + if hasattr(selectors, 'PollSelector'): + selector = selectors.PollSelector() + else: + selector = selectors.SelectSelector() + selector.register(sockno, selectors.EVENT_WRITE) + + total_sent = 0 + # localize variable access to minimize overhead + selector_select = selector.select + os_sendfile = os.sendfile + try: + while True: + if timeout and not selector_select(timeout): + raise _socket.timeout('timed out') + if count: + blocksize = count - total_sent + if blocksize <= 0: + break + try: + sent = os_sendfile(sockno, fileno, offset, blocksize) + except BlockingIOError: + if not timeout: + # Block until the socket is ready to send some + # data; avoids hogging CPU resources. + selector_select() + continue + except OSError as err: + if total_sent == 0: + # We can get here for different reasons, the main + # one being 'file' is not a regular mmap(2)-like + # file, in which case we'll fall back on using + # plain send(). + raise _GiveupOnSendfile(err) + raise err from None + else: + if sent == 0: + break # EOF + offset += sent + total_sent += sent + return total_sent + finally: + if total_sent > 0 and hasattr(file, 'seek'): + file.seek(offset) + else: + def _sendfile_use_sendfile(self, file, offset=0, count=None): + raise _GiveupOnSendfile( + "os.sendfile() not available on this platform") + + def _sendfile_use_send(self, file, offset=0, count=None): + self._check_sendfile_params(file, offset, count) + if self.gettimeout() == 0: + raise ValueError("non-blocking sockets are not supported") + if offset: + file.seek(offset) + blocksize = min(count, 8192) if count else 8192 + total_sent = 0 + # localize variable access to minimize overhead + file_read = file.read + sock_send = self.send + try: + while True: + if count: + blocksize = min(count - total_sent, blocksize) + if blocksize <= 0: + break + data = memoryview(file_read(blocksize)) + if not data: + break # EOF + while True: + try: + sent = sock_send(data) + except BlockingIOError: + continue + else: + total_sent += sent + if sent < len(data): + data = data[sent:] + else: + break + return total_sent + finally: + if total_sent > 0 and hasattr(file, 'seek'): + file.seek(offset + total_sent) + + def _check_sendfile_params(self, file, offset, count): + if 'b' not in getattr(file, 'mode', 'b'): + raise ValueError("file should be opened in binary mode") + if not self.type & SOCK_STREAM: + raise ValueError("only SOCK_STREAM type sockets are supported") + if count is not None: + if not isinstance(count, int): + raise TypeError( + "count must be a positive integer (got {!r})".format(count)) + if count <= 0: + raise ValueError( + "count must be a positive integer (got {!r})".format(count)) + + def sendfile(self, file, offset=0, count=None): + """sendfile(file[, offset[, count]]) -> sent + + Send a file until EOF is reached by using high-performance + os.sendfile() and return the total number of bytes which + were sent. + *file* must be a regular file object opened in binary mode. + If os.sendfile() is not available (e.g. Windows) or file is + not a regular file socket.send() will be used instead. + *offset* tells from where to start reading the file. + If specified, *count* is the total number of bytes to transmit + as opposed to sending the file until EOF is reached. + File position is updated on return or also in case of error in + which case file.tell() can be used to figure out the number of + bytes which were sent. + The socket must be of SOCK_STREAM type. + Non-blocking sockets are not supported. + """ + try: + return self._sendfile_use_sendfile(file, offset, count) + except _GiveupOnSendfile: + return self._sendfile_use_send(file, offset, count) + def _decref_socketios(self): if self._io_refs > 0: self._io_refs -= 1 @@ -700,6 +700,16 @@ class SSLSocket(socket): else: return socket.sendall(self, data, flags) + def sendfile(self, file, offset=0, count=None): + """Send a file, possibly by using os.sendfile() if this is a + clear-text socket. Return the total number of bytes sent. + """ + if self._sslobj is None: + # os.sendfile() works with plain sockets only + return super().sendfile(file, offset, count) + else: + return self._sendfile_use_send(file, offset, count) + def recv(self, buflen=1024, flags=0): self._checkClosed() if self._sslobj: diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 86cbec7..82b79cd 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -19,6 +19,8 @@ import signal import math import pickle import struct +import random +import string try: import multiprocessing except ImportError: @@ -5077,6 +5079,275 @@ class TestSocketSharing(SocketTCPTest): source.close() +@unittest.skipUnless(thread, 'Threading required for this test.') +class SendfileUsingSendTest(ThreadedTCPSocketTest): + """ + Test the send() implementation of socket.sendfile(). + """ + + FILESIZE = (10 * 1024 * 1024) # 10MB + BUFSIZE = 8192 + FILEDATA = b"" + TIMEOUT = 2 + + @classmethod + def setUpClass(cls): + def chunks(total, step): + assert total >= step + while total > step: + yield step + total -= step + if total: + yield total + + chunk = b"".join([random.choice(string.ascii_letters).encode() + for i in range(cls.BUFSIZE)]) + with open(support.TESTFN, 'wb') as f: + for csize in chunks(cls.FILESIZE, cls.BUFSIZE): + f.write(chunk) + with open(support.TESTFN, 'rb') as f: + cls.FILEDATA = f.read() + assert len(cls.FILEDATA) == cls.FILESIZE + + @classmethod + def tearDownClass(cls): + support.unlink(support.TESTFN) + + def accept_conn(self): + self.serv.settimeout(self.TIMEOUT) + conn, addr = self.serv.accept() + conn.settimeout(self.TIMEOUT) + self.addCleanup(conn.close) + return conn + + def recv_data(self, conn): + received = [] + while True: + chunk = conn.recv(self.BUFSIZE) + if not chunk: + break + received.append(chunk) + return b''.join(received) + + def meth_from_sock(self, sock): + # Depending on the mixin class being run return either send() + # or sendfile() method implementation. + return getattr(sock, "_sendfile_use_send") + + # regular file + + def _testRegularFile(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file) + self.assertEqual(sent, self.FILESIZE) + self.assertEqual(file.tell(), self.FILESIZE) + + def testRegularFile(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # non regular file + + def _testNonRegularFile(self): + address = self.serv.getsockname() + file = io.BytesIO(self.FILEDATA) + with socket.create_connection(address) as sock, file as file: + sent = sock.sendfile(file) + self.assertEqual(sent, self.FILESIZE) + self.assertEqual(file.tell(), self.FILESIZE) + self.assertRaises(socket._GiveupOnSendfile, + sock._sendfile_use_sendfile, file) + + def testNonRegularFile(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # empty file + + def _testEmptyFileSend(self): + address = self.serv.getsockname() + filename = support.TESTFN + "2" + with open(filename, 'wb'): + self.addCleanup(support.unlink, filename) + file = open(filename, 'rb') + with socket.create_connection(address) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file) + self.assertEqual(sent, 0) + self.assertEqual(file.tell(), 0) + + def testEmptyFileSend(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(data, b"") + + # offset + + def _testOffset(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file, offset=5000) + self.assertEqual(sent, self.FILESIZE - 5000) + self.assertEqual(file.tell(), self.FILESIZE) + + def testOffset(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE - 5000) + self.assertEqual(data, self.FILEDATA[5000:]) + + # count + + def _testCount(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + count = 5000007 + meth = self.meth_from_sock(sock) + sent = meth(file, count=count) + self.assertEqual(sent, count) + self.assertEqual(file.tell(), count) + + def testCount(self): + count = 5000007 + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), count) + self.assertEqual(data, self.FILEDATA[:count]) + + # count small + + def _testCountSmall(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + count = 1 + meth = self.meth_from_sock(sock) + sent = meth(file, count=count) + self.assertEqual(sent, count) + self.assertEqual(file.tell(), count) + + def testCountSmall(self): + count = 1 + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), count) + self.assertEqual(data, self.FILEDATA[:count]) + + # count + offset + + def _testCountWithOffset(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + count = 100007 + meth = self.meth_from_sock(sock) + sent = meth(file, offset=2007, count=count) + self.assertEqual(sent, count) + self.assertEqual(file.tell(), count + 2007) + + def testCountWithOffset(self): + count = 100007 + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), count) + self.assertEqual(data, self.FILEDATA[2007:count+2007]) + + # non blocking sockets are not supposed to work + + def _testNonBlocking(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + sock.setblocking(False) + meth = self.meth_from_sock(sock) + self.assertRaises(ValueError, meth, file) + self.assertRaises(ValueError, sock.sendfile, file) + + def testNonBlocking(self): + conn = self.accept_conn() + if conn.recv(8192): + self.fail('was not supposed to receive any data') + + # timeout (non-triggered) + + def _testWithTimeout(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file) + self.assertEqual(sent, self.FILESIZE) + + def testWithTimeout(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # timeout (triggered) + + def _testWithTimeoutTriggeredSend(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=0.01) as sock, \ + file as file: + meth = self.meth_from_sock(sock) + self.assertRaises(socket.timeout, meth, file) + + def testWithTimeoutTriggeredSend(self): + conn = self.accept_conn() + conn.recv(88192) + + # errors + + def _test_errors(self): + pass + + def test_errors(self): + with open(support.TESTFN, 'rb') as file: + with socket.socket(type=socket.SOCK_DGRAM) as s: + meth = self.meth_from_sock(s) + self.assertRaisesRegex( + ValueError, "SOCK_STREAM", meth, file) + with open(support.TESTFN, 'rt') as file: + with socket.socket() as s: + meth = self.meth_from_sock(s) + self.assertRaisesRegex( + ValueError, "binary mode", meth, file) + with open(support.TESTFN, 'rb') as file: + with socket.socket() as s: + meth = self.meth_from_sock(s) + self.assertRaisesRegex(TypeError, "positive integer", + meth, file, count='2') + self.assertRaisesRegex(TypeError, "positive integer", + meth, file, count=0.1) + self.assertRaisesRegex(ValueError, "positive integer", + meth, file, count=0) + self.assertRaisesRegex(ValueError, "positive integer", + meth, file, count=-1) + + +@unittest.skipUnless(thread, 'Threading required for this test.') +@unittest.skipUnless(hasattr(os, "sendfile"), + 'os.sendfile() required for this test.') +class SendfileUsingSendfileTest(SendfileUsingSendTest): + """ + Test the sendfile() implementation of socket.sendfile(). + """ + def meth_from_sock(self, sock): + return getattr(sock, "_sendfile_use_sendfile") + + def test_main(): tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ] @@ -5129,6 +5400,8 @@ def test_main(): InterruptedRecvTimeoutTest, InterruptedSendTimeoutTest, TestSocketSharing, + SendfileUsingSendTest, + SendfileUsingSendfileTest, ]) thread_info = support.threading_setup() diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index f72fb15..bdde9ac 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -2957,6 +2957,23 @@ else: self.assertRaises(ValueError, s.read, 1024) self.assertRaises(ValueError, s.write, b'hello') + def test_sendfile(self): + TEST_DATA = b"x" * 512 + with open(support.TESTFN, 'wb') as f: + f.write(TEST_DATA) + self.addCleanup(support.unlink, support.TESTFN) + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = ThreadedEchoServer(context=context, chatty=False) + with server: + with context.wrap_socket(socket.socket()) as s: + s.connect((HOST, server.port)) + with open(support.TESTFN, 'rb') as file: + s.sendfile(file) + self.assertEqual(s.recv(1024), TEST_DATA) + def test_main(verbose=False): if support.verbose: @@ -92,6 +92,10 @@ Core and Builtins Library ------- +- Issue 17552: new socket.sendfile() method allowing to send a file over a + socket by using high-performance os.sendfile() on UNIX. + Patch by Giampaolo Rodola'. + - Issue #18039: dbm.dump.open() now always creates a new database when the flag has the value 'n'. Patch by Claudiu Popa. |