summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNick Coghlan <ncoghlan@gmail.com>2011-08-22 01:55:57 (GMT)
committerNick Coghlan <ncoghlan@gmail.com>2011-08-22 01:55:57 (GMT)
commit96fe56abec36d2cb82c56c9ddafea0096f4f6c7e (patch)
treefa5e6293d1fc1cc00b41c57cd192df49a6a6b7ff
parent8983729dc08a05d5419d35ec3f431c7b442401a6 (diff)
downloadcpython-96fe56abec36d2cb82c56c9ddafea0096f4f6c7e.zip
cpython-96fe56abec36d2cb82c56c9ddafea0096f4f6c7e.tar.gz
cpython-96fe56abec36d2cb82c56c9ddafea0096f4f6c7e.tar.bz2
Add support for the send/recvmsg API to the socket module. Patch by David Watson and Heiko Wundram. (Closes #6560)
-rw-r--r--Doc/library/socket.rst182
-rw-r--r--Doc/whatsnew/3.3.rst12
-rw-r--r--Lib/ssl.py24
-rw-r--r--Lib/test/test_socket.py2120
-rw-r--r--Lib/test/test_ssl.py16
-rw-r--r--Misc/NEWS4
-rw-r--r--Modules/socketmodule.c809
7 files changed, 3167 insertions, 0 deletions
diff --git a/Doc/library/socket.rst b/Doc/library/socket.rst
index f587977..196fe8b 100644
--- a/Doc/library/socket.rst
+++ b/Doc/library/socket.rst
@@ -198,6 +198,7 @@ The module :mod:`socket` exports the following constants and functions:
SOMAXCONN
MSG_*
SOL_*
+ SCM_*
IPPROTO_*
IPPORT_*
INADDR_*
@@ -511,6 +512,49 @@ The module :mod:`socket` exports the following constants and functions:
Availability: Unix (maybe not all platforms).
+..
+ XXX: Are sendmsg(), recvmsg() and CMSG_*() available on any
+ non-Unix platforms? The old (obsolete?) 4.2BSD form of the
+ interface, in which struct msghdr has no msg_control or
+ msg_controllen members, is not currently supported.
+
+.. function:: CMSG_LEN(length)
+
+ Return the total length, without trailing padding, of an ancillary
+ data item with associated data of the given *length*. This value
+ can often be used as the buffer size for :meth:`~socket.recvmsg` to
+ receive a single item of ancillary data, but :rfc:`3542` requires
+ portable applications to use :func:`CMSG_SPACE` and thus include
+ space for padding, even when the item will be the last in the
+ buffer. Raises :exc:`OverflowError` if *length* is outside the
+ permissible range of values.
+
+ Availability: most Unix platforms, possibly others.
+
+ .. versionadded:: 3.3
+
+
+.. function:: CMSG_SPACE(length)
+
+ Return the buffer size needed for :meth:`~socket.recvmsg` to
+ receive an ancillary data item with associated data of the given
+ *length*, along with any trailing padding. The buffer space needed
+ to receive multiple items is the sum of the :func:`CMSG_SPACE`
+ values for their associated data lengths. Raises
+ :exc:`OverflowError` if *length* is outside the permissible range
+ of values.
+
+ Note that some systems might support ancillary data without
+ providing this function. Also note that setting the buffer size
+ using the results of this function may not precisely limit the
+ amount of ancillary data that can be received, since additional
+ data may be able to fit into the padding area.
+
+ Availability: most Unix platforms, possibly others.
+
+ .. versionadded:: 3.3
+
+
.. function:: getdefaulttimeout()
Return the default timeout in seconds (float) for new socket objects. A value
@@ -742,6 +786,109 @@ correspond to Unix system calls applicable to sockets.
to zero. (The format of *address* depends on the address family --- see above.)
+.. method:: socket.recvmsg(bufsize[, ancbufsize[, flags]])
+
+ Receive normal data (up to *bufsize* bytes) and ancillary data from
+ the socket. The *ancbufsize* argument sets the size in bytes of
+ the internal buffer used to receive the ancillary data; it defaults
+ to 0, meaning that no ancillary data will be received. Appropriate
+ buffer sizes for ancillary data can be calculated using
+ :func:`CMSG_SPACE` or :func:`CMSG_LEN`, and items which do not fit
+ into the buffer might be truncated or discarded. The *flags*
+ argument defaults to 0 and has the same meaning as for
+ :meth:`recv`.
+
+ The return value is a 4-tuple: ``(data, ancdata, msg_flags,
+ address)``. The *data* item is a :class:`bytes` object holding the
+ non-ancillary data received. The *ancdata* item is a list of zero
+ or more tuples ``(cmsg_level, cmsg_type, cmsg_data)`` representing
+ the ancillary data (control messages) received: *cmsg_level* and
+ *cmsg_type* are integers specifying the protocol level and
+ protocol-specific type respectively, and *cmsg_data* is a
+ :class:`bytes` object holding the associated data. The *msg_flags*
+ item is the bitwise OR of various flags indicating conditions on
+ the received message; see your system documentation for details.
+ If the receiving socket is unconnected, *address* is the address of
+ the sending socket, if available; otherwise, its value is
+ unspecified.
+
+ On some systems, :meth:`sendmsg` and :meth:`recvmsg` can be used to
+ pass file descriptors between processes over an :const:`AF_UNIX`
+ socket. When this facility is used (it is often restricted to
+ :const:`SOCK_STREAM` sockets), :meth:`recvmsg` will return, in its
+ ancillary data, items of the form ``(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS, fds)``, where *fds* is a :class:`bytes` object
+ representing the new file descriptors as a binary array of the
+ native C :c:type:`int` type. If :meth:`recvmsg` raises an
+ exception after the system call returns, it will first attempt to
+ close any file descriptors received via this mechanism.
+
+ Some systems do not indicate the truncated length of ancillary data
+ items which have been only partially received. If an item appears
+ to extend beyond the end of the buffer, :meth:`recvmsg` will issue
+ a :exc:`RuntimeWarning`, and will return the part of it which is
+ inside the buffer provided it has not been truncated before the
+ start of its associated data.
+
+ On systems which support the :const:`SCM_RIGHTS` mechanism, the
+ following function will receive up to *maxfds* file descriptors,
+ returning the message data and a list containing the descriptors
+ (while ignoring unexpected conditions such as unrelated control
+ messages being received). See also :meth:`sendmsg`. ::
+
+ import socket, array
+
+ def recv_fds(sock, msglen, maxfds):
+ fds = array.array("i") # Array of ints
+ msg, ancdata, flags, addr = sock.recvmsg(msglen, socket.CMSG_LEN(maxfds * fds.itemsize))
+ for cmsg_level, cmsg_type, cmsg_data in ancdata:
+ if (cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS):
+ # Append data, ignoring any truncated integers at the end.
+ fds.fromstring(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
+ return msg, list(fds)
+
+ Availability: most Unix platforms, possibly others.
+
+ .. versionadded:: 3.3
+
+
+.. method:: socket.recvmsg_into(buffers[, ancbufsize[, flags]])
+
+ Receive normal data and ancillary data from the socket, behaving as
+ :meth:`recvmsg` would, but scatter the non-ancillary data into a
+ series of buffers instead of returning a new bytes object. The
+ *buffers* argument must be an iterable of objects that export
+ writable buffers (e.g. :class:`bytearray` objects); these will be
+ filled with successive chunks of the non-ancillary data until it
+ has all been written or there are no more buffers. The operating
+ system may set a limit (:func:`~os.sysconf` value ``SC_IOV_MAX``)
+ on the number of buffers that can be used. The *ancbufsize* and
+ *flags* arguments have the same meaning as for :meth:`recvmsg`.
+
+ The return value is a 4-tuple: ``(nbytes, ancdata, msg_flags,
+ address)``, where *nbytes* is the total number of bytes of
+ non-ancillary data written into the buffers, and *ancdata*,
+ *msg_flags* and *address* are the same as for :meth:`recvmsg`.
+
+ Example::
+
+ >>> import socket
+ >>> s1, s2 = socket.socketpair()
+ >>> b1 = bytearray(b'----')
+ >>> b2 = bytearray(b'0123456789')
+ >>> b3 = bytearray(b'--------------')
+ >>> s1.send(b'Mary had a little lamb')
+ 22
+ >>> s2.recvmsg_into([b1, memoryview(b2)[2:9], b3])
+ (22, [], 0, None)
+ >>> [b1, b2, b3]
+ [bytearray(b'Mary'), bytearray(b'01 had a 9'), bytearray(b'little lamb---')]
+
+ Availability: most Unix platforms, possibly others.
+
+ .. versionadded:: 3.3
+
+
.. method:: socket.recvfrom_into(buffer[, nbytes[, flags]])
Receive data from the socket, writing it into *buffer* instead of creating a
@@ -789,6 +936,41 @@ correspond to Unix system calls applicable to sockets.
above.)
+.. method:: socket.sendmsg(buffers[, ancdata[, flags[, address]]])
+
+ Send normal and ancillary data to the socket, gathering the
+ non-ancillary data from a series of buffers and concatenating it
+ into a single message. The *buffers* argument specifies the
+ non-ancillary data as an iterable of buffer-compatible objects
+ (e.g. :class:`bytes` objects); the operating system may set a limit
+ (:func:`~os.sysconf` value ``SC_IOV_MAX``) on the number of buffers
+ that can be used. The *ancdata* argument specifies the ancillary
+ data (control messages) as an iterable of zero or more tuples
+ ``(cmsg_level, cmsg_type, cmsg_data)``, where *cmsg_level* and
+ *cmsg_type* are integers specifying the protocol level and
+ protocol-specific type respectively, and *cmsg_data* is a
+ buffer-compatible object holding the associated data. Note that
+ some systems (in particular, systems without :func:`CMSG_SPACE`)
+ might support sending only one control message per call. The
+ *flags* argument defaults to 0 and has the same meaning as for
+ :meth:`send`. If *address* is supplied and not ``None``, it sets a
+ destination address for the message. The return value is the
+ number of bytes of non-ancillary data sent.
+
+ The following function sends the list of file descriptors *fds*
+ over an :const:`AF_UNIX` socket, on systems which support the
+ :const:`SCM_RIGHTS` mechanism. See also :meth:`recvmsg`. ::
+
+ import socket, array
+
+ def send_fds(sock, msg, fds):
+ return sock.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))])
+
+ Availability: most Unix platforms, possibly others.
+
+ .. versionadded:: 3.3
+
+
.. method:: socket.setblocking(flag)
Set blocking or non-blocking mode of the socket: if *flag* is false, the
diff --git a/Doc/whatsnew/3.3.rst b/Doc/whatsnew/3.3.rst
index 0edf5ff..3036a3e 100644
--- a/Doc/whatsnew/3.3.rst
+++ b/Doc/whatsnew/3.3.rst
@@ -212,6 +212,18 @@ signal
* :func:`signal.signal` and :func:`signal.siginterrupt` raise an OSError,
instead of a RuntimeError: OSError has an errno attribute.
+socket
+------
+
+The :class:`~socket.socket` class now exposes addititonal methods to
+process ancillary data when supported by the underlying platform:
+
+* :func:`~socket.socket.sendmsg`
+* :func:`~socket.socket.recvmsg`
+* :func:`~socket.socket.recvmsg_into`
+
+(Contributed by David Watson in :issue:`6560`, based on an earlier patch
+by Heiko Wundram)
ssl
---
diff --git a/Lib/ssl.py b/Lib/ssl.py
index 914e749..1b7416e 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -355,6 +355,14 @@ class SSLSocket(socket):
else:
return socket.sendto(self, data, flags_or_addr, addr)
+ def sendmsg(self, *args, **kwargs):
+ self._checkClosed()
+ if self._sslobj:
+ raise ValueError("sendmsg not allowed on instances of %s" %
+ self.__class__)
+ else:
+ return socket.sendmsg(self, *args, **kwargs)
+
def sendall(self, data, flags=0):
self._checkClosed()
if self._sslobj:
@@ -413,6 +421,22 @@ class SSLSocket(socket):
else:
return socket.recvfrom_into(self, buffer, nbytes, flags)
+ def recvmsg(self, *args, **kwargs):
+ self._checkClosed()
+ if self._sslobj:
+ raise ValueError("recvmsg not allowed on instances of %s" %
+ self.__class__)
+ else:
+ return socket.recvmsg(self, *args, **kwargs)
+
+ def recvmsg_into(self, *args, **kwargs):
+ self._checkClosed()
+ if self._sslobj:
+ raise ValueError("recvmsg_into not allowed on instances of %s" %
+ self.__class__)
+ else:
+ return socket.recvmsg_into(self, *args, **kwargs)
+
def pending(self):
self._checkClosed()
if self._sslobj:
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index 4e5085e..bbc9b78 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -7,6 +7,8 @@ import errno
import io
import socket
import select
+import tempfile
+import _testcapi
import time
import traceback
import queue
@@ -34,6 +36,9 @@ except ImportError:
thread = None
threading = None
+# Size in bytes of the int type
+SIZEOF_INT = array.array("i").itemsize
+
class SocketTCPTest(unittest.TestCase):
def setUp(self):
@@ -55,6 +60,26 @@ class SocketUDPTest(unittest.TestCase):
self.serv.close()
self.serv = None
+class ThreadSafeCleanupTestCase(unittest.TestCase):
+ """Subclass of unittest.TestCase with thread-safe cleanup methods.
+
+ This subclass protects the addCleanup() and doCleanups() methods
+ with a recursive lock.
+ """
+
+ if threading:
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._cleanup_lock = threading.RLock()
+
+ def addCleanup(self, *args, **kwargs):
+ with self._cleanup_lock:
+ return super().addCleanup(*args, **kwargs)
+
+ def doCleanups(self, *args, **kwargs):
+ with self._cleanup_lock:
+ return super().doCleanups(*args, **kwargs)
+
class ThreadableTest:
"""Threadable Test class
@@ -237,6 +262,243 @@ class SocketPairTest(unittest.TestCase, ThreadableTest):
ThreadableTest.clientTearDown(self)
+# The following classes are used by the sendmsg()/recvmsg() tests.
+# Combining, for instance, ConnectedStreamTestMixin and TCPTestBase
+# gives a drop-in replacement for SocketConnectedTest, but different
+# address families can be used, and the attributes serv_addr and
+# cli_addr will be set to the addresses of the endpoints.
+
+class SocketTestBase(unittest.TestCase):
+ """A base class for socket tests.
+
+ Subclasses must provide methods newSocket() to return a new socket
+ and bindSock(sock) to bind it to an unused address.
+
+ Creates a socket self.serv and sets self.serv_addr to its address.
+ """
+
+ def setUp(self):
+ self.serv = self.newSocket()
+ self.bindServer()
+
+ def bindServer(self):
+ """Bind server socket and set self.serv_addr to its address."""
+ self.bindSock(self.serv)
+ self.serv_addr = self.serv.getsockname()
+
+ def tearDown(self):
+ self.serv.close()
+ self.serv = None
+
+
+class SocketListeningTestMixin(SocketTestBase):
+ """Mixin to listen on the server socket."""
+
+ def setUp(self):
+ super().setUp()
+ self.serv.listen(1)
+
+
+class ThreadedSocketTestMixin(ThreadSafeCleanupTestCase, SocketTestBase,
+ ThreadableTest):
+ """Mixin to add client socket and allow client/server tests.
+
+ Client socket is self.cli and its address is self.cli_addr. See
+ ThreadableTest for usage information.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ ThreadableTest.__init__(self)
+
+ def clientSetUp(self):
+ self.cli = self.newClientSocket()
+ self.bindClient()
+
+ def newClientSocket(self):
+ """Return a new socket for use as client."""
+ return self.newSocket()
+
+ def bindClient(self):
+ """Bind client socket and set self.cli_addr to its address."""
+ self.bindSock(self.cli)
+ self.cli_addr = self.cli.getsockname()
+
+ def clientTearDown(self):
+ self.cli.close()
+ self.cli = None
+ ThreadableTest.clientTearDown(self)
+
+
+class ConnectedStreamTestMixin(SocketListeningTestMixin,
+ ThreadedSocketTestMixin):
+ """Mixin to allow client/server stream tests with connected client.
+
+ Server's socket representing connection to client is self.cli_conn
+ and client's connection to server is self.serv_conn. (Based on
+ SocketConnectedTest.)
+ """
+
+ def setUp(self):
+ super().setUp()
+ # Indicate explicitly we're ready for the client thread to
+ # proceed and then perform the blocking call to accept
+ self.serverExplicitReady()
+ conn, addr = self.serv.accept()
+ self.cli_conn = conn
+
+ def tearDown(self):
+ self.cli_conn.close()
+ self.cli_conn = None
+ super().tearDown()
+
+ def clientSetUp(self):
+ super().clientSetUp()
+ self.cli.connect(self.serv_addr)
+ self.serv_conn = self.cli
+
+ def clientTearDown(self):
+ self.serv_conn.close()
+ self.serv_conn = None
+ super().clientTearDown()
+
+
+class UnixSocketTestBase(SocketTestBase):
+ """Base class for Unix-domain socket tests."""
+
+ # This class is used for file descriptor passing tests, so we
+ # create the sockets in a private directory so that other users
+ # can't send anything that might be problematic for a privileged
+ # user running the tests.
+
+ def setUp(self):
+ self.dir_path = tempfile.mkdtemp()
+ self.addCleanup(os.rmdir, self.dir_path)
+ super().setUp()
+
+ def bindSock(self, sock):
+ path = tempfile.mktemp(dir=self.dir_path)
+ sock.bind(path)
+ self.addCleanup(support.unlink, path)
+
+class UnixStreamBase(UnixSocketTestBase):
+ """Base class for Unix-domain SOCK_STREAM tests."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+
+
+class InetTestBase(SocketTestBase):
+ """Base class for IPv4 socket tests."""
+
+ host = HOST
+
+ def setUp(self):
+ super().setUp()
+ self.port = self.serv_addr[1]
+
+ def bindSock(self, sock):
+ support.bind_port(sock, host=self.host)
+
+class TCPTestBase(InetTestBase):
+ """Base class for TCP-over-IPv4 tests."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+class UDPTestBase(InetTestBase):
+ """Base class for UDP-over-IPv4 tests."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+
+class SCTPStreamBase(InetTestBase):
+ """Base class for SCTP tests in one-to-one (SOCK_STREAM) mode."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_INET, socket.SOCK_STREAM,
+ socket.IPPROTO_SCTP)
+
+
+class Inet6TestBase(InetTestBase):
+ """Base class for IPv6 socket tests."""
+
+ # Don't use "localhost" here - it may not have an IPv6 address
+ # assigned to it by default (e.g. in /etc/hosts), and if someone
+ # has assigned it an IPv4-mapped address, then it's unlikely to
+ # work with the full IPv6 API.
+ host = "::1"
+
+class UDP6TestBase(Inet6TestBase):
+ """Base class for UDP-over-IPv6 tests."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+
+
+# Test-skipping decorators for use with ThreadableTest.
+
+def skipWithClientIf(condition, reason):
+ """Skip decorated test if condition is true, add client_skip decorator.
+
+ If the decorated object is not a class, sets its attribute
+ "client_skip" to a decorator which will return an empty function
+ if the test is to be skipped, or the original function if it is
+ not. This can be used to avoid running the client part of a
+ skipped test when using ThreadableTest.
+ """
+ def client_pass(*args, **kwargs):
+ pass
+ def skipdec(obj):
+ retval = unittest.skip(reason)(obj)
+ if not isinstance(obj, type):
+ retval.client_skip = lambda f: client_pass
+ return retval
+ def noskipdec(obj):
+ if not (isinstance(obj, type) or hasattr(obj, "client_skip")):
+ obj.client_skip = lambda f: f
+ return obj
+ return skipdec if condition else noskipdec
+
+
+def requireAttrs(obj, *attributes):
+ """Skip decorated test if obj is missing any of the given attributes.
+
+ Sets client_skip attribute as skipWithClientIf() does.
+ """
+ missing = [name for name in attributes if not hasattr(obj, name)]
+ return skipWithClientIf(
+ missing, "don't have " + ", ".join(name for name in missing))
+
+
+def requireSocket(*args):
+ """Skip decorated test if a socket cannot be created with given arguments.
+
+ When an argument is given as a string, will use the value of that
+ attribute of the socket module, or skip the test if it doesn't
+ exist. Sets client_skip attribute as skipWithClientIf() does.
+ """
+ err = None
+ missing = [obj for obj in args if
+ isinstance(obj, str) and not hasattr(socket, obj)]
+ if missing:
+ err = "don't have " + ", ".join(name for name in missing)
+ else:
+ callargs = [getattr(socket, obj) if isinstance(obj, str) else obj
+ for obj in args]
+ try:
+ s = socket.socket(*callargs)
+ except socket.error as e:
+ # XXX: check errno?
+ err = str(e)
+ else:
+ s.close()
+ return skipWithClientIf(
+ err is not None,
+ "can't create socket({0}): {1}".format(
+ ", ".join(str(o) for o in args), err))
+
+
#######################################################################
## Begin Tests
@@ -945,6 +1207,1839 @@ class BasicUDPTest(ThreadedUDPSocketTest):
def _testRecvFromNegative(self):
self.cli.sendto(MSG, 0, (HOST, self.port))
+
+# Tests for the sendmsg()/recvmsg() interface. Where possible, the
+# same test code is used with different families and types of socket
+# (e.g. stream, datagram), and tests using recvmsg() are repeated
+# using recvmsg_into().
+#
+# The generic test classes such as SendmsgTests and
+# RecvmsgGenericTests inherit from SendrecvmsgBase and expect to be
+# supplied with sockets cli_sock and serv_sock representing the
+# client's and the server's end of the connection respectively, and
+# attributes cli_addr and serv_addr holding their (numeric where
+# appropriate) addresses.
+#
+# The final concrete test classes combine these with subclasses of
+# SocketTestBase which set up client and server sockets of a specific
+# type, and with subclasses of SendrecvmsgBase such as
+# SendrecvmsgDgramBase and SendrecvmsgConnectedBase which map these
+# sockets to cli_sock and serv_sock and override the methods and
+# attributes of SendrecvmsgBase to fill in destination addresses if
+# needed when sending, check for specific flags in msg_flags, etc.
+#
+# RecvmsgIntoMixin provides a version of doRecvmsg() implemented using
+# recvmsg_into().
+
+# XXX: like the other datagram (UDP) tests in this module, the code
+# here assumes that datagram delivery on the local machine will be
+# reliable.
+
+class SendrecvmsgBase(ThreadSafeCleanupTestCase):
+ # Base class for sendmsg()/recvmsg() tests.
+
+ # Time in seconds to wait before considering a test failed, or
+ # None for no timeout. Not all tests actually set a timeout.
+ fail_timeout = 3.0
+
+ def setUp(self):
+ self.misc_event = threading.Event()
+ super().setUp()
+
+ def sendToServer(self, msg):
+ # Send msg to the server.
+ return self.cli_sock.send(msg)
+
+ # Tuple of alternative default arguments for sendmsg() when called
+ # via sendmsgToServer() (e.g. to include a destination address).
+ sendmsg_to_server_defaults = ()
+
+ def sendmsgToServer(self, *args):
+ # Call sendmsg() on self.cli_sock with the given arguments,
+ # filling in any arguments which are not supplied with the
+ # corresponding items of self.sendmsg_to_server_defaults, if
+ # any.
+ return self.cli_sock.sendmsg(
+ *(args + self.sendmsg_to_server_defaults[len(args):]))
+
+ def doRecvmsg(self, sock, bufsize, *args):
+ # Call recvmsg() on sock with given arguments and return its
+ # result. Should be used for tests which can use either
+ # recvmsg() or recvmsg_into() - RecvmsgIntoMixin overrides
+ # this method with one which emulates it using recvmsg_into(),
+ # thus allowing the same test to be used for both methods.
+ result = sock.recvmsg(bufsize, *args)
+ self.registerRecvmsgResult(result)
+ return result
+
+ def registerRecvmsgResult(self, result):
+ # Called by doRecvmsg() with the return value of recvmsg() or
+ # recvmsg_into(). Can be overridden to arrange cleanup based
+ # on the returned ancillary data, for instance.
+ pass
+
+ def checkRecvmsgAddress(self, addr1, addr2):
+ # Called to compare the received address with the address of
+ # the peer.
+ self.assertEqual(addr1, addr2)
+
+ # Flags that are normally unset in msg_flags
+ msg_flags_common_unset = 0
+ for name in ("MSG_CTRUNC", "MSG_OOB"):
+ msg_flags_common_unset |= getattr(socket, name, 0)
+
+ # Flags that are normally set
+ msg_flags_common_set = 0
+
+ # Flags set when a complete record has been received (e.g. MSG_EOR
+ # for SCTP)
+ msg_flags_eor_indicator = 0
+
+ # Flags set when a complete record has not been received
+ # (e.g. MSG_TRUNC for datagram sockets)
+ msg_flags_non_eor_indicator = 0
+
+ def checkFlags(self, flags, eor=None, checkset=0, checkunset=0, ignore=0):
+ # Method to check the value of msg_flags returned by recvmsg[_into]().
+ #
+ # Checks that all bits in msg_flags_common_set attribute are
+ # set in "flags" and all bits in msg_flags_common_unset are
+ # unset.
+ #
+ # The "eor" argument specifies whether the flags should
+ # indicate that a full record (or datagram) has been received.
+ # If "eor" is None, no checks are done; otherwise, checks
+ # that:
+ #
+ # * if "eor" is true, all bits in msg_flags_eor_indicator are
+ # set and all bits in msg_flags_non_eor_indicator are unset
+ #
+ # * if "eor" is false, all bits in msg_flags_non_eor_indicator
+ # are set and all bits in msg_flags_eor_indicator are unset
+ #
+ # If "checkset" and/or "checkunset" are supplied, they require
+ # the given bits to be set or unset respectively, overriding
+ # what the attributes require for those bits.
+ #
+ # If any bits are set in "ignore", they will not be checked,
+ # regardless of the other inputs.
+ #
+ # Will raise Exception if the inputs require a bit to be both
+ # set and unset, and it is not ignored.
+
+ defaultset = self.msg_flags_common_set
+ defaultunset = self.msg_flags_common_unset
+
+ if eor:
+ defaultset |= self.msg_flags_eor_indicator
+ defaultunset |= self.msg_flags_non_eor_indicator
+ elif eor is not None:
+ defaultset |= self.msg_flags_non_eor_indicator
+ defaultunset |= self.msg_flags_eor_indicator
+
+ # Function arguments override defaults
+ defaultset &= ~checkunset
+ defaultunset &= ~checkset
+
+ # Merge arguments with remaining defaults, and check for conflicts
+ checkset |= defaultset
+ checkunset |= defaultunset
+ inboth = checkset & checkunset & ~ignore
+ if inboth:
+ raise Exception("contradictory set, unset requirements for flags "
+ "{0:#x}".format(inboth))
+
+ # Compare with given msg_flags value
+ mask = (checkset | checkunset) & ~ignore
+ self.assertEqual(flags & mask, checkset & mask)
+
+
+class RecvmsgIntoMixin(SendrecvmsgBase):
+ # Mixin to implement doRecvmsg() using recvmsg_into().
+
+ def doRecvmsg(self, sock, bufsize, *args):
+ buf = bytearray(bufsize)
+ result = sock.recvmsg_into([buf], *args)
+ self.registerRecvmsgResult(result)
+ self.assertGreaterEqual(result[0], 0)
+ self.assertLessEqual(result[0], bufsize)
+ return (bytes(buf[:result[0]]),) + result[1:]
+
+
+class SendrecvmsgDgramFlagsBase(SendrecvmsgBase):
+ # Defines flags to be checked in msg_flags for datagram sockets.
+
+ @property
+ def msg_flags_non_eor_indicator(self):
+ return super().msg_flags_non_eor_indicator | socket.MSG_TRUNC
+
+
+class SendrecvmsgSCTPFlagsBase(SendrecvmsgBase):
+ # Defines flags to be checked in msg_flags for SCTP sockets.
+
+ @property
+ def msg_flags_eor_indicator(self):
+ return super().msg_flags_eor_indicator | socket.MSG_EOR
+
+
+class SendrecvmsgConnectionlessBase(SendrecvmsgBase):
+ # Base class for tests on connectionless-mode sockets. Users must
+ # supply sockets on attributes cli and serv to be mapped to
+ # cli_sock and serv_sock respectively.
+
+ @property
+ def serv_sock(self):
+ return self.serv
+
+ @property
+ def cli_sock(self):
+ return self.cli
+
+ @property
+ def sendmsg_to_server_defaults(self):
+ return ([], [], 0, self.serv_addr)
+
+ def sendToServer(self, msg):
+ return self.cli_sock.sendto(msg, self.serv_addr)
+
+
+class SendrecvmsgConnectedBase(SendrecvmsgBase):
+ # Base class for tests on connected sockets. Users must supply
+ # sockets on attributes serv_conn and cli_conn (representing the
+ # connections *to* the server and the client), to be mapped to
+ # cli_sock and serv_sock respectively.
+
+ @property
+ def serv_sock(self):
+ return self.cli_conn
+
+ @property
+ def cli_sock(self):
+ return self.serv_conn
+
+ def checkRecvmsgAddress(self, addr1, addr2):
+ # Address is currently "unspecified" for a connected socket,
+ # so we don't examine it
+ pass
+
+
+class SendrecvmsgServerTimeoutBase(SendrecvmsgBase):
+ # Base class to set a timeout on server's socket.
+
+ def setUp(self):
+ super().setUp()
+ self.serv_sock.settimeout(self.fail_timeout)
+
+
+class SendmsgTests(SendrecvmsgServerTimeoutBase):
+ # Tests for sendmsg() which can use any socket type and do not
+ # involve recvmsg() or recvmsg_into().
+
+ def testSendmsg(self):
+ # Send a simple message with sendmsg().
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsg(self):
+ self.assertEqual(self.sendmsgToServer([MSG]), len(MSG))
+
+ def testSendmsgDataGenerator(self):
+ # Send from buffer obtained from a generator (not a sequence).
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsgDataGenerator(self):
+ self.assertEqual(self.sendmsgToServer((o for o in [MSG])),
+ len(MSG))
+
+ def testSendmsgAncillaryGenerator(self):
+ # Gather (empty) ancillary data from a generator.
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsgAncillaryGenerator(self):
+ self.assertEqual(self.sendmsgToServer([MSG], (o for o in [])),
+ len(MSG))
+
+ def testSendmsgArray(self):
+ # Send data from an array instead of the usual bytes object.
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsgArray(self):
+ self.assertEqual(self.sendmsgToServer([array.array("B", MSG)]),
+ len(MSG))
+
+ def testSendmsgGather(self):
+ # Send message data from more than one buffer (gather write).
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsgGather(self):
+ self.assertEqual(self.sendmsgToServer([MSG[:3], MSG[3:]]), len(MSG))
+
+ def testSendmsgBadArgs(self):
+ # Check that sendmsg() rejects invalid arguments.
+ self.assertEqual(self.serv_sock.recv(1000), b"done")
+
+ def _testSendmsgBadArgs(self):
+ self.assertRaises(TypeError, self.cli_sock.sendmsg)
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ b"not in an iterable")
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ object())
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [object()])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG, object()])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], object())
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [], object())
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [], 0, object())
+ self.sendToServer(b"done")
+
+ def testSendmsgBadCmsg(self):
+ # Check that invalid ancillary data items are rejected.
+ self.assertEqual(self.serv_sock.recv(1000), b"done")
+
+ def _testSendmsgBadCmsg(self):
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [object()])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(object(), 0, b"data")])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(0, object(), b"data")])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(0, 0, object())])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(0, 0)])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(0, 0, b"data", 42)])
+ self.sendToServer(b"done")
+
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testSendmsgBadMultiCmsg(self):
+ # Check that invalid ancillary data items are rejected when
+ # more than one item is present.
+ self.assertEqual(self.serv_sock.recv(1000), b"done")
+
+ @testSendmsgBadMultiCmsg.client_skip
+ def _testSendmsgBadMultiCmsg(self):
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [0, 0, b""])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(0, 0, b""), object()])
+ self.sendToServer(b"done")
+
+ def testSendmsgExcessCmsgReject(self):
+ # Check that sendmsg() rejects excess ancillary data items
+ # when the number that can be sent is limited.
+ self.assertEqual(self.serv_sock.recv(1000), b"done")
+
+ def _testSendmsgExcessCmsgReject(self):
+ if not hasattr(socket, "CMSG_SPACE"):
+ # Can only send one item
+ with self.assertRaises(socket.error) as cm:
+ self.sendmsgToServer([MSG], [(0, 0, b""), (0, 0, b"")])
+ self.assertIsNone(cm.exception.errno)
+ self.sendToServer(b"done")
+
+ def testSendmsgAfterClose(self):
+ # Check that sendmsg() fails on a closed socket.
+ pass
+
+ def _testSendmsgAfterClose(self):
+ self.cli_sock.close()
+ self.assertRaises(socket.error, self.sendmsgToServer, [MSG])
+
+
+class SendmsgStreamTests(SendmsgTests):
+ # Tests for sendmsg() which require a stream socket and do not
+ # involve recvmsg() or recvmsg_into().
+
+ def testSendmsgExplicitNoneAddr(self):
+ # Check that peer address can be specified as None.
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsgExplicitNoneAddr(self):
+ self.assertEqual(self.sendmsgToServer([MSG], [], 0, None), len(MSG))
+
+ def testSendmsgTimeout(self):
+ # Check that timeout works with sendmsg().
+ self.assertEqual(self.serv_sock.recv(512), b"a"*512)
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+
+ def _testSendmsgTimeout(self):
+ try:
+ self.cli_sock.settimeout(0.03)
+ with self.assertRaises(socket.timeout):
+ while True:
+ self.sendmsgToServer([b"a"*512])
+ finally:
+ self.misc_event.set()
+
+ # XXX: would be nice to have more tests for sendmsg flags argument.
+
+ # Linux supports MSG_DONTWAIT when sending, but in general, it
+ # only works when receiving. Could add other platforms if they
+ # support it too.
+ @skipWithClientIf(sys.platform not in {"linux2"},
+ "MSG_DONTWAIT not known to work on this platform when "
+ "sending")
+ def testSendmsgDontWait(self):
+ # Check that MSG_DONTWAIT in flags causes non-blocking behaviour.
+ self.assertEqual(self.serv_sock.recv(512), b"a"*512)
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+
+ @testSendmsgDontWait.client_skip
+ def _testSendmsgDontWait(self):
+ try:
+ with self.assertRaises(socket.error) as cm:
+ while True:
+ self.sendmsgToServer([b"a"*512], [], socket.MSG_DONTWAIT)
+ self.assertIn(cm.exception.errno,
+ (errno.EAGAIN, errno.EWOULDBLOCK))
+ finally:
+ self.misc_event.set()
+
+
+class SendmsgConnectionlessTests(SendmsgTests):
+ # Tests for sendmsg() which require a connectionless-mode
+ # (e.g. datagram) socket, and do not involve recvmsg() or
+ # recvmsg_into().
+
+ def testSendmsgNoDestAddr(self):
+ # Check that sendmsg() fails when no destination address is
+ # given for unconnected socket.
+ pass
+
+ def _testSendmsgNoDestAddr(self):
+ self.assertRaises(socket.error, self.cli_sock.sendmsg,
+ [MSG])
+ self.assertRaises(socket.error, self.cli_sock.sendmsg,
+ [MSG], [], 0, None)
+
+
+class RecvmsgGenericTests(SendrecvmsgBase):
+ # Tests for recvmsg() which can also be emulated using
+ # recvmsg_into(), and can use any socket type.
+
+ def testRecvmsg(self):
+ # Receive a simple message with recvmsg[_into]().
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG))
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsg(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgExplicitDefaults(self):
+ # Test recvmsg[_into]() with default arguments provided explicitly.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 0, 0)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgExplicitDefaults(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgShorter(self):
+ # Receive a message smaller than buffer.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG) + 42)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgShorter(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgTrunc(self):
+ # Receive part of message, check for truncation indicators.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG) - 3)
+ self.assertEqual(msg, MSG[:-3])
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=False)
+
+ def _testRecvmsgTrunc(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgShortAncillaryBuf(self):
+ # Test ancillary data buffer too small to hold any ancillary data.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 1)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgShortAncillaryBuf(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgLongAncillaryBuf(self):
+ # Test large ancillary data buffer.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgLongAncillaryBuf(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgAfterClose(self):
+ # Check that recvmsg[_into]() fails on a closed socket.
+ self.serv_sock.close()
+ self.assertRaises(socket.error, self.doRecvmsg, self.serv_sock, 1024)
+
+ def _testRecvmsgAfterClose(self):
+ pass
+
+ def testRecvmsgTimeout(self):
+ # Check that timeout works.
+ try:
+ self.serv_sock.settimeout(0.03)
+ self.assertRaises(socket.timeout,
+ self.doRecvmsg, self.serv_sock, len(MSG))
+ finally:
+ self.misc_event.set()
+
+ def _testRecvmsgTimeout(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+
+ @requireAttrs(socket, "MSG_PEEK")
+ def testRecvmsgPeek(self):
+ # Check that MSG_PEEK in flags enables examination of pending
+ # data without consuming it.
+
+ # Receive part of data with MSG_PEEK.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG) - 3, 0,
+ socket.MSG_PEEK)
+ self.assertEqual(msg, MSG[:-3])
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ # Ignoring MSG_TRUNC here (so this test is the same for stream
+ # and datagram sockets). Some wording in POSIX seems to
+ # suggest that it needn't be set when peeking, but that may
+ # just be a slip.
+ self.checkFlags(flags, eor=False,
+ ignore=getattr(socket, "MSG_TRUNC", 0))
+
+ # Receive all data with MSG_PEEK.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 0,
+ socket.MSG_PEEK)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ # Check that the same data can still be received normally.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG))
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ @testRecvmsgPeek.client_skip
+ def _testRecvmsgPeek(self):
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket.socket, "sendmsg")
+ def testRecvmsgFromSendmsg(self):
+ # Test receiving with recvmsg[_into]() when message is sent
+ # using sendmsg().
+ self.serv_sock.settimeout(self.fail_timeout)
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG))
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ @testRecvmsgFromSendmsg.client_skip
+ def _testRecvmsgFromSendmsg(self):
+ self.assertEqual(self.sendmsgToServer([MSG[:3], MSG[3:]]), len(MSG))
+
+
+class RecvmsgGenericStreamTests(RecvmsgGenericTests):
+ # Tests which require a stream socket and can use either recvmsg()
+ # or recvmsg_into().
+
+ def testRecvmsgEOF(self):
+ # Receive end-of-stream indicator (b"", peer socket closed).
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, 1024)
+ self.assertEqual(msg, b"")
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=None) # Might not have end-of-record marker
+
+ def _testRecvmsgEOF(self):
+ self.cli_sock.close()
+
+ def testRecvmsgOverflow(self):
+ # Receive a message in more than one chunk.
+ seg1, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG) - 3)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=False)
+
+ seg2, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, 1024)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ msg = seg1 + seg2
+ self.assertEqual(msg, MSG)
+
+ def _testRecvmsgOverflow(self):
+ self.sendToServer(MSG)
+
+
+class RecvmsgTests(RecvmsgGenericTests):
+ # Tests for recvmsg() which can use any socket type.
+
+ def testRecvmsgBadArgs(self):
+ # Check that recvmsg() rejects invalid arguments.
+ self.assertRaises(TypeError, self.serv_sock.recvmsg)
+ self.assertRaises(ValueError, self.serv_sock.recvmsg,
+ -1, 0, 0)
+ self.assertRaises(ValueError, self.serv_sock.recvmsg,
+ len(MSG), -1, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg,
+ [bytearray(10)], 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg,
+ object(), 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg,
+ len(MSG), object(), 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg,
+ len(MSG), 0, object())
+
+ msg, ancdata, flags, addr = self.serv_sock.recvmsg(len(MSG), 0, 0)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgBadArgs(self):
+ self.sendToServer(MSG)
+
+
+class RecvmsgIntoTests(RecvmsgIntoMixin, RecvmsgGenericTests):
+ # Tests for recvmsg_into() which can use any socket type.
+
+ def testRecvmsgIntoBadArgs(self):
+ # Check that recvmsg_into() rejects invalid arguments.
+ buf = bytearray(len(MSG))
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ len(MSG), 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ buf, 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ [object()], 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ [b"I'm not writable"], 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ [buf, object()], 0, 0)
+ self.assertRaises(ValueError, self.serv_sock.recvmsg_into,
+ [buf], -1, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ [buf], object(), 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ [buf], 0, object())
+
+ nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into([buf], 0, 0)
+ self.assertEqual(nbytes, len(MSG))
+ self.assertEqual(buf, bytearray(MSG))
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgIntoBadArgs(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgIntoGenerator(self):
+ # Receive into buffer obtained from a generator (not a sequence).
+ buf = bytearray(len(MSG))
+ nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into(
+ (o for o in [buf]))
+ self.assertEqual(nbytes, len(MSG))
+ self.assertEqual(buf, bytearray(MSG))
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgIntoGenerator(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgIntoArray(self):
+ # Receive into an array rather than the usual bytearray.
+ buf = array.array("B", [0] * len(MSG))
+ nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into([buf])
+ self.assertEqual(nbytes, len(MSG))
+ self.assertEqual(buf.tostring(), MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgIntoArray(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgIntoScatter(self):
+ # Receive into multiple buffers (scatter write).
+ b1 = bytearray(b"----")
+ b2 = bytearray(b"0123456789")
+ b3 = bytearray(b"--------------")
+ nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into(
+ [b1, memoryview(b2)[2:9], b3])
+ self.assertEqual(nbytes, len(b"Mary had a little lamb"))
+ self.assertEqual(b1, bytearray(b"Mary"))
+ self.assertEqual(b2, bytearray(b"01 had a 9"))
+ self.assertEqual(b3, bytearray(b"little lamb---"))
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgIntoScatter(self):
+ self.sendToServer(b"Mary had a little lamb")
+
+
+class CmsgMacroTests(unittest.TestCase):
+ # Test the functions CMSG_LEN() and CMSG_SPACE(). Tests
+ # assumptions used by sendmsg() and recvmsg[_into](), which share
+ # code with these functions.
+
+ # Match the definition in socketmodule.c
+ socklen_t_limit = min(0x7fffffff, _testcapi.INT_MAX)
+
+ @requireAttrs(socket, "CMSG_LEN")
+ def testCMSG_LEN(self):
+ # Test CMSG_LEN() with various valid and invalid values,
+ # checking the assumptions used by recvmsg() and sendmsg().
+ toobig = self.socklen_t_limit - socket.CMSG_LEN(0) + 1
+ values = list(range(257)) + list(range(toobig - 257, toobig))
+
+ # struct cmsghdr has at least three members, two of which are ints
+ self.assertGreater(socket.CMSG_LEN(0), array.array("i").itemsize * 2)
+ for n in values:
+ ret = socket.CMSG_LEN(n)
+ # This is how recvmsg() calculates the data size
+ self.assertEqual(ret - socket.CMSG_LEN(0), n)
+ self.assertLessEqual(ret, self.socklen_t_limit)
+
+ self.assertRaises(OverflowError, socket.CMSG_LEN, -1)
+ # sendmsg() shares code with these functions, and requires
+ # that it reject values over the limit.
+ self.assertRaises(OverflowError, socket.CMSG_LEN, toobig)
+ self.assertRaises(OverflowError, socket.CMSG_LEN, sys.maxsize)
+
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testCMSG_SPACE(self):
+ # Test CMSG_SPACE() with various valid and invalid values,
+ # checking the assumptions used by sendmsg().
+ toobig = self.socklen_t_limit - socket.CMSG_SPACE(1) + 1
+ values = list(range(257)) + list(range(toobig - 257, toobig))
+
+ last = socket.CMSG_SPACE(0)
+ # struct cmsghdr has at least three members, two of which are ints
+ self.assertGreater(last, array.array("i").itemsize * 2)
+ for n in values:
+ ret = socket.CMSG_SPACE(n)
+ self.assertGreaterEqual(ret, last)
+ self.assertGreaterEqual(ret, socket.CMSG_LEN(n))
+ self.assertGreaterEqual(ret, n + socket.CMSG_LEN(0))
+ self.assertLessEqual(ret, self.socklen_t_limit)
+ last = ret
+
+ self.assertRaises(OverflowError, socket.CMSG_SPACE, -1)
+ # sendmsg() shares code with these functions, and requires
+ # that it reject values over the limit.
+ self.assertRaises(OverflowError, socket.CMSG_SPACE, toobig)
+ self.assertRaises(OverflowError, socket.CMSG_SPACE, sys.maxsize)
+
+
+class SCMRightsTest(SendrecvmsgServerTimeoutBase):
+ # Tests for file descriptor passing on Unix-domain sockets.
+
+ # Invalid file descriptor value that's unlikely to evaluate to a
+ # real FD even if one of its bytes is replaced with a different
+ # value (which shouldn't actually happen).
+ badfd = -0x5555
+
+ def newFDs(self, n):
+ # Return a list of n file descriptors for newly-created files
+ # containing their list indices as ASCII numbers.
+ fds = []
+ for i in range(n):
+ fd, path = tempfile.mkstemp()
+ self.addCleanup(os.unlink, path)
+ self.addCleanup(os.close, fd)
+ os.write(fd, str(i).encode())
+ fds.append(fd)
+ return fds
+
+ def checkFDs(self, fds):
+ # Check that the file descriptors in the given list contain
+ # their correct list indices as ASCII numbers.
+ for n, fd in enumerate(fds):
+ os.lseek(fd, 0, os.SEEK_SET)
+ self.assertEqual(os.read(fd, 1024), str(n).encode())
+
+ def registerRecvmsgResult(self, result):
+ self.addCleanup(self.closeRecvmsgFDs, result)
+
+ def closeRecvmsgFDs(self, recvmsg_result):
+ # Close all file descriptors specified in the ancillary data
+ # of the given return value from recvmsg() or recvmsg_into().
+ for cmsg_level, cmsg_type, cmsg_data in recvmsg_result[1]:
+ if (cmsg_level == socket.SOL_SOCKET and
+ cmsg_type == socket.SCM_RIGHTS):
+ fds = array.array("i")
+ fds.fromstring(cmsg_data[:
+ len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
+ for fd in fds:
+ os.close(fd)
+
+ def createAndSendFDs(self, n):
+ # Send n new file descriptors created by newFDs() to the
+ # server, with the constant MSG as the non-ancillary data.
+ self.assertEqual(
+ self.sendmsgToServer([MSG],
+ [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", self.newFDs(n)))]),
+ len(MSG))
+
+ def checkRecvmsgFDs(self, numfds, result, maxcmsgs=1, ignoreflags=0):
+ # Check that constant MSG was received with numfds file
+ # descriptors in a maximum of maxcmsgs control messages (which
+ # must contain only complete integers). By default, check
+ # that MSG_CTRUNC is unset, but ignore any flags in
+ # ignoreflags.
+ msg, ancdata, flags, addr = result
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+
+ self.assertIsInstance(ancdata, list)
+ self.assertLessEqual(len(ancdata), maxcmsgs)
+ fds = array.array("i")
+ for item in ancdata:
+ self.assertIsInstance(item, tuple)
+ cmsg_level, cmsg_type, cmsg_data = item
+ self.assertEqual(cmsg_level, socket.SOL_SOCKET)
+ self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
+ self.assertIsInstance(cmsg_data, bytes)
+ self.assertEqual(len(cmsg_data) % SIZEOF_INT, 0)
+ fds.fromstring(cmsg_data)
+
+ self.assertEqual(len(fds), numfds)
+ self.checkFDs(fds)
+
+ def testFDPassSimple(self):
+ # Pass a single FD (array read from bytes object).
+ self.checkRecvmsgFDs(1, self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240))
+
+ def _testFDPassSimple(self):
+ self.assertEqual(
+ self.sendmsgToServer(
+ [MSG],
+ [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", self.newFDs(1)).tostring())]),
+ len(MSG))
+
+ def testMultipleFDPass(self):
+ # Pass multiple FDs in a single array.
+ self.checkRecvmsgFDs(4, self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240))
+
+ def _testMultipleFDPass(self):
+ self.createAndSendFDs(4)
+
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testFDPassCMSG_SPACE(self):
+ # Test using CMSG_SPACE() to calculate ancillary buffer size.
+ self.checkRecvmsgFDs(
+ 4, self.doRecvmsg(self.serv_sock, len(MSG),
+ socket.CMSG_SPACE(4 * SIZEOF_INT)))
+
+ @testFDPassCMSG_SPACE.client_skip
+ def _testFDPassCMSG_SPACE(self):
+ self.createAndSendFDs(4)
+
+ def testFDPassCMSG_LEN(self):
+ # Test using CMSG_LEN() to calculate ancillary buffer size.
+ self.checkRecvmsgFDs(1,
+ self.doRecvmsg(self.serv_sock, len(MSG),
+ socket.CMSG_LEN(4 * SIZEOF_INT)),
+ # RFC 3542 says implementations may set
+ # MSG_CTRUNC if there isn't enough space
+ # for trailing padding.
+ ignoreflags=socket.MSG_CTRUNC)
+
+ def _testFDPassCMSG_LEN(self):
+ self.createAndSendFDs(1)
+
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testFDPassSeparate(self):
+ # Pass two FDs in two separate arrays. Arrays may be combined
+ # into a single control message by the OS.
+ self.checkRecvmsgFDs(2,
+ self.doRecvmsg(self.serv_sock, len(MSG), 10240),
+ maxcmsgs=2)
+
+ @testFDPassSeparate.client_skip
+ def _testFDPassSeparate(self):
+ fd0, fd1 = self.newFDs(2)
+ self.assertEqual(
+ self.sendmsgToServer([MSG], [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd0])),
+ (socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd1]))]),
+ len(MSG))
+
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testFDPassSeparateMinSpace(self):
+ # Pass two FDs in two separate arrays, receiving them into the
+ # minimum space for two arrays.
+ self.checkRecvmsgFDs(2,
+ self.doRecvmsg(self.serv_sock, len(MSG),
+ socket.CMSG_SPACE(SIZEOF_INT) +
+ socket.CMSG_LEN(SIZEOF_INT)),
+ maxcmsgs=2, ignoreflags=socket.MSG_CTRUNC)
+
+ @testFDPassSeparateMinSpace.client_skip
+ def _testFDPassSeparateMinSpace(self):
+ fd0, fd1 = self.newFDs(2)
+ self.assertEqual(
+ self.sendmsgToServer([MSG], [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd0])),
+ (socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd1]))]),
+ len(MSG))
+
+ def sendAncillaryIfPossible(self, msg, ancdata):
+ # Try to send msg and ancdata to server, but if the system
+ # call fails, just send msg with no ancillary data.
+ try:
+ nbytes = self.sendmsgToServer([msg], ancdata)
+ except socket.error as e:
+ # Check that it was the system call that failed
+ self.assertIsInstance(e.errno, int)
+ nbytes = self.sendmsgToServer([msg])
+ self.assertEqual(nbytes, len(msg))
+
+ def testFDPassEmpty(self):
+ # Try to pass an empty FD array. Can receive either no array
+ # or an empty array.
+ self.checkRecvmsgFDs(0, self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240),
+ ignoreflags=socket.MSG_CTRUNC)
+
+ def _testFDPassEmpty(self):
+ self.sendAncillaryIfPossible(MSG, [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ b"")])
+
+ def testFDPassPartialInt(self):
+ # Try to pass a truncated FD array.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, ignore=socket.MSG_CTRUNC)
+ self.assertLessEqual(len(ancdata), 1)
+ for cmsg_level, cmsg_type, cmsg_data in ancdata:
+ self.assertEqual(cmsg_level, socket.SOL_SOCKET)
+ self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
+ self.assertLess(len(cmsg_data), SIZEOF_INT)
+
+ def _testFDPassPartialInt(self):
+ self.sendAncillaryIfPossible(
+ MSG,
+ [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [self.badfd]).tostring()[:-1])])
+
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testFDPassPartialIntInMiddle(self):
+ # Try to pass two FD arrays, the first of which is truncated.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, ignore=socket.MSG_CTRUNC)
+ self.assertLessEqual(len(ancdata), 2)
+ fds = array.array("i")
+ # Arrays may have been combined in a single control message
+ for cmsg_level, cmsg_type, cmsg_data in ancdata:
+ self.assertEqual(cmsg_level, socket.SOL_SOCKET)
+ self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
+ fds.fromstring(cmsg_data[:
+ len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
+ self.assertLessEqual(len(fds), 2)
+ self.checkFDs(fds)
+
+ @testFDPassPartialIntInMiddle.client_skip
+ def _testFDPassPartialIntInMiddle(self):
+ fd0, fd1 = self.newFDs(2)
+ self.sendAncillaryIfPossible(
+ MSG,
+ [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd0, self.badfd]).tostring()[:-1]),
+ (socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd1]))])
+
+ def checkTruncatedHeader(self, result, ignoreflags=0):
+ # Check that no ancillary data items are returned when data is
+ # truncated inside the cmsghdr structure.
+ msg, ancdata, flags, addr = result
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+
+ def testCmsgTruncNoBufSize(self):
+ # Check that no ancillary data is received when no buffer size
+ # is specified.
+ self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG)),
+ # BSD seems to set MSG_CTRUNC only
+ # if an item has been partially
+ # received.
+ ignoreflags=socket.MSG_CTRUNC)
+
+ def _testCmsgTruncNoBufSize(self):
+ self.createAndSendFDs(1)
+
+ def testCmsgTrunc0(self):
+ # Check that no ancillary data is received when buffer size is 0.
+ self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), 0),
+ ignoreflags=socket.MSG_CTRUNC)
+
+ def _testCmsgTrunc0(self):
+ self.createAndSendFDs(1)
+
+ # Check that no ancillary data is returned for various non-zero
+ # (but still too small) buffer sizes.
+
+ def testCmsgTrunc1(self):
+ self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), 1))
+
+ def _testCmsgTrunc1(self):
+ self.createAndSendFDs(1)
+
+ def testCmsgTrunc2Int(self):
+ # The cmsghdr structure has at least three members, two of
+ # which are ints, so we still shouldn't see any ancillary
+ # data.
+ self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG),
+ SIZEOF_INT * 2))
+
+ def _testCmsgTrunc2Int(self):
+ self.createAndSendFDs(1)
+
+ def testCmsgTruncLen0Minus1(self):
+ self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG),
+ socket.CMSG_LEN(0) - 1))
+
+ def _testCmsgTruncLen0Minus1(self):
+ self.createAndSendFDs(1)
+
+ # The following tests try to truncate the control message in the
+ # middle of the FD array.
+
+ def checkTruncatedArray(self, ancbuf, maxdata, mindata=0):
+ # Check that file descriptor data is truncated to between
+ # mindata and maxdata bytes when received with buffer size
+ # ancbuf, and that any complete file descriptor numbers are
+ # valid.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), ancbuf)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)
+
+ if mindata == 0 and ancdata == []:
+ return
+ self.assertEqual(len(ancdata), 1)
+ cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+ self.assertEqual(cmsg_level, socket.SOL_SOCKET)
+ self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
+ self.assertGreaterEqual(len(cmsg_data), mindata)
+ self.assertLessEqual(len(cmsg_data), maxdata)
+ fds = array.array("i")
+ fds.fromstring(cmsg_data[:
+ len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
+ self.checkFDs(fds)
+
+ def testCmsgTruncLen0(self):
+ self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(0), maxdata=0)
+
+ def _testCmsgTruncLen0(self):
+ self.createAndSendFDs(1)
+
+ def testCmsgTruncLen0Plus1(self):
+ self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(0) + 1, maxdata=1)
+
+ def _testCmsgTruncLen0Plus1(self):
+ self.createAndSendFDs(2)
+
+ def testCmsgTruncLen1(self):
+ self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(SIZEOF_INT),
+ maxdata=SIZEOF_INT)
+
+ def _testCmsgTruncLen1(self):
+ self.createAndSendFDs(2)
+
+ def testCmsgTruncLen2Minus1(self):
+ self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(2 * SIZEOF_INT) - 1,
+ maxdata=(2 * SIZEOF_INT) - 1)
+
+ def _testCmsgTruncLen2Minus1(self):
+ self.createAndSendFDs(2)
+
+
+class RFC3542AncillaryTest(SendrecvmsgServerTimeoutBase):
+ # Test sendmsg() and recvmsg[_into]() using the ancillary data
+ # features of the RFC 3542 Advanced Sockets API for IPv6.
+ # Currently we can only handle certain data items (e.g. traffic
+ # class, hop limit, MTU discovery and fragmentation settings)
+ # without resorting to unportable means such as the struct module,
+ # but the tests here are aimed at testing the ancillary data
+ # handling in sendmsg() and recvmsg() rather than the IPv6 API
+ # itself.
+
+ # Test value to use when setting hop limit of packet
+ hop_limit = 2
+
+ # Test value to use when setting traffic class of packet.
+ # -1 means "use kernel default".
+ traffic_class = -1
+
+ def ancillaryMapping(self, ancdata):
+ # Given ancillary data list ancdata, return a mapping from
+ # pairs (cmsg_level, cmsg_type) to corresponding cmsg_data.
+ # Check that no (level, type) pair appears more than once.
+ d = {}
+ for cmsg_level, cmsg_type, cmsg_data in ancdata:
+ self.assertNotIn((cmsg_level, cmsg_type), d)
+ d[(cmsg_level, cmsg_type)] = cmsg_data
+ return d
+
+ def checkHopLimit(self, ancbufsize, maxhop=255, ignoreflags=0):
+ # Receive hop limit into ancbufsize bytes of ancillary data
+ # space. Check that data is MSG, ancillary data is not
+ # truncated (but ignore any flags in ignoreflags), and hop
+ # limit is between 0 and maxhop inclusive.
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.misc_event.set()
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), ancbufsize)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+
+ self.assertEqual(len(ancdata), 1)
+ self.assertIsInstance(ancdata[0], tuple)
+ cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+ self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
+ self.assertEqual(cmsg_type, socket.IPV6_HOPLIMIT)
+ self.assertIsInstance(cmsg_data, bytes)
+ self.assertEqual(len(cmsg_data), SIZEOF_INT)
+ a = array.array("i")
+ a.fromstring(cmsg_data)
+ self.assertGreaterEqual(a[0], 0)
+ self.assertLessEqual(a[0], maxhop)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testRecvHopLimit(self):
+ # Test receiving the packet hop limit as ancillary data.
+ self.checkHopLimit(ancbufsize=10240)
+
+ @testRecvHopLimit.client_skip
+ def _testRecvHopLimit(self):
+ # Need to wait until server has asked to receive ancillary
+ # data, as implementations are not required to buffer it
+ # otherwise.
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testRecvHopLimitCMSG_SPACE(self):
+ # Test receiving hop limit, using CMSG_SPACE to calculate buffer size.
+ self.checkHopLimit(ancbufsize=socket.CMSG_SPACE(SIZEOF_INT))
+
+ @testRecvHopLimitCMSG_SPACE.client_skip
+ def _testRecvHopLimitCMSG_SPACE(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ # Could test receiving into buffer sized using CMSG_LEN, but RFC
+ # 3542 says portable applications must provide space for trailing
+ # padding. Implementations may set MSG_CTRUNC if there isn't
+ # enough space for the padding.
+
+ @requireAttrs(socket.socket, "sendmsg")
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSetHopLimit(self):
+ # Test setting hop limit on outgoing packet and receiving it
+ # at the other end.
+ self.checkHopLimit(ancbufsize=10240, maxhop=self.hop_limit)
+
+ @testSetHopLimit.client_skip
+ def _testSetHopLimit(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.assertEqual(
+ self.sendmsgToServer([MSG],
+ [(socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
+ array.array("i", [self.hop_limit]))]),
+ len(MSG))
+
+ def checkTrafficClassAndHopLimit(self, ancbufsize, maxhop=255,
+ ignoreflags=0):
+ # Receive traffic class and hop limit into ancbufsize bytes of
+ # ancillary data space. Check that data is MSG, ancillary
+ # data is not truncated (but ignore any flags in ignoreflags),
+ # and traffic class and hop limit are in range (hop limit no
+ # more than maxhop).
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVTCLASS, 1)
+ self.misc_event.set()
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), ancbufsize)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+ self.assertEqual(len(ancdata), 2)
+ ancmap = self.ancillaryMapping(ancdata)
+
+ tcdata = ancmap[(socket.IPPROTO_IPV6, socket.IPV6_TCLASS)]
+ self.assertEqual(len(tcdata), SIZEOF_INT)
+ a = array.array("i")
+ a.fromstring(tcdata)
+ self.assertGreaterEqual(a[0], 0)
+ self.assertLessEqual(a[0], 255)
+
+ hldata = ancmap[(socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT)]
+ self.assertEqual(len(hldata), SIZEOF_INT)
+ a = array.array("i")
+ a.fromstring(hldata)
+ self.assertGreaterEqual(a[0], 0)
+ self.assertLessEqual(a[0], maxhop)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testRecvTrafficClassAndHopLimit(self):
+ # Test receiving traffic class and hop limit as ancillary data.
+ self.checkTrafficClassAndHopLimit(ancbufsize=10240)
+
+ @testRecvTrafficClassAndHopLimit.client_skip
+ def _testRecvTrafficClassAndHopLimit(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testRecvTrafficClassAndHopLimitCMSG_SPACE(self):
+ # Test receiving traffic class and hop limit, using
+ # CMSG_SPACE() to calculate buffer size.
+ self.checkTrafficClassAndHopLimit(
+ ancbufsize=socket.CMSG_SPACE(SIZEOF_INT) * 2)
+
+ @testRecvTrafficClassAndHopLimitCMSG_SPACE.client_skip
+ def _testRecvTrafficClassAndHopLimitCMSG_SPACE(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket.socket, "sendmsg")
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSetTrafficClassAndHopLimit(self):
+ # Test setting traffic class and hop limit on outgoing packet,
+ # and receiving them at the other end.
+ self.checkTrafficClassAndHopLimit(ancbufsize=10240,
+ maxhop=self.hop_limit)
+
+ @testSetTrafficClassAndHopLimit.client_skip
+ def _testSetTrafficClassAndHopLimit(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.assertEqual(
+ self.sendmsgToServer([MSG],
+ [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS,
+ array.array("i", [self.traffic_class])),
+ (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
+ array.array("i", [self.hop_limit]))]),
+ len(MSG))
+
+ @requireAttrs(socket.socket, "sendmsg")
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testOddCmsgSize(self):
+ # Try to send ancillary data with first item one byte too
+ # long. Fall back to sending with correct size if this fails,
+ # and check that second item was handled correctly.
+ self.checkTrafficClassAndHopLimit(ancbufsize=10240,
+ maxhop=self.hop_limit)
+
+ @testOddCmsgSize.client_skip
+ def _testOddCmsgSize(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ try:
+ nbytes = self.sendmsgToServer(
+ [MSG],
+ [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS,
+ array.array("i", [self.traffic_class]).tostring() + b"\x00"),
+ (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
+ array.array("i", [self.hop_limit]))])
+ except socket.error as e:
+ self.assertIsInstance(e.errno, int)
+ nbytes = self.sendmsgToServer(
+ [MSG],
+ [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS,
+ array.array("i", [self.traffic_class])),
+ (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
+ array.array("i", [self.hop_limit]))])
+ self.assertEqual(nbytes, len(MSG))
+
+ # Tests for proper handling of truncated ancillary data
+
+ def checkHopLimitTruncatedHeader(self, ancbufsize, ignoreflags=0):
+ # Receive hop limit into ancbufsize bytes of ancillary data
+ # space, which should be too small to contain the ancillary
+ # data header (if ancbufsize is None, pass no second argument
+ # to recvmsg()). Check that data is MSG, MSG_CTRUNC is set
+ # (unless included in ignoreflags), and no ancillary data is
+ # returned.
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.misc_event.set()
+ args = () if ancbufsize is None else (ancbufsize,)
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), *args)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testCmsgTruncNoBufSize(self):
+ # Check that no ancillary data is received when no ancillary
+ # buffer size is provided.
+ self.checkHopLimitTruncatedHeader(ancbufsize=None,
+ # BSD seems to set
+ # MSG_CTRUNC only if an item
+ # has been partially
+ # received.
+ ignoreflags=socket.MSG_CTRUNC)
+
+ @testCmsgTruncNoBufSize.client_skip
+ def _testCmsgTruncNoBufSize(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSingleCmsgTrunc0(self):
+ # Check that no ancillary data is received when ancillary
+ # buffer size is zero.
+ self.checkHopLimitTruncatedHeader(ancbufsize=0,
+ ignoreflags=socket.MSG_CTRUNC)
+
+ @testSingleCmsgTrunc0.client_skip
+ def _testSingleCmsgTrunc0(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ # Check that no ancillary data is returned for various non-zero
+ # (but still too small) buffer sizes.
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSingleCmsgTrunc1(self):
+ self.checkHopLimitTruncatedHeader(ancbufsize=1)
+
+ @testSingleCmsgTrunc1.client_skip
+ def _testSingleCmsgTrunc1(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSingleCmsgTrunc2Int(self):
+ self.checkHopLimitTruncatedHeader(ancbufsize=2 * SIZEOF_INT)
+
+ @testSingleCmsgTrunc2Int.client_skip
+ def _testSingleCmsgTrunc2Int(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSingleCmsgTruncLen0Minus1(self):
+ self.checkHopLimitTruncatedHeader(ancbufsize=socket.CMSG_LEN(0) - 1)
+
+ @testSingleCmsgTruncLen0Minus1.client_skip
+ def _testSingleCmsgTruncLen0Minus1(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSingleCmsgTruncInData(self):
+ # Test truncation of a control message inside its associated
+ # data. The message may be returned with its data truncated,
+ # or not returned at all.
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.misc_event.set()
+ msg, ancdata, flags, addr = self.doRecvmsg(
+ self.serv_sock, len(MSG), socket.CMSG_LEN(SIZEOF_INT) - 1)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)
+
+ self.assertLessEqual(len(ancdata), 1)
+ if ancdata:
+ cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+ self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
+ self.assertEqual(cmsg_type, socket.IPV6_HOPLIMIT)
+ self.assertLess(len(cmsg_data), SIZEOF_INT)
+
+ @testSingleCmsgTruncInData.client_skip
+ def _testSingleCmsgTruncInData(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ def checkTruncatedSecondHeader(self, ancbufsize, ignoreflags=0):
+ # Receive traffic class and hop limit into ancbufsize bytes of
+ # ancillary data space, which should be large enough to
+ # contain the first item, but too small to contain the header
+ # of the second. Check that data is MSG, MSG_CTRUNC is set
+ # (unless included in ignoreflags), and only one ancillary
+ # data item is returned.
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVTCLASS, 1)
+ self.misc_event.set()
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), ancbufsize)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+
+ self.assertEqual(len(ancdata), 1)
+ cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+ self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
+ self.assertIn(cmsg_type, {socket.IPV6_TCLASS, socket.IPV6_HOPLIMIT})
+ self.assertEqual(len(cmsg_data), SIZEOF_INT)
+ a = array.array("i")
+ a.fromstring(cmsg_data)
+ self.assertGreaterEqual(a[0], 0)
+ self.assertLessEqual(a[0], 255)
+
+ # Try the above test with various buffer sizes.
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSecondCmsgTrunc0(self):
+ self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT),
+ ignoreflags=socket.MSG_CTRUNC)
+
+ @testSecondCmsgTrunc0.client_skip
+ def _testSecondCmsgTrunc0(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSecondCmsgTrunc1(self):
+ self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) + 1)
+
+ @testSecondCmsgTrunc1.client_skip
+ def _testSecondCmsgTrunc1(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSecondCmsgTrunc2Int(self):
+ self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) +
+ 2 * SIZEOF_INT)
+
+ @testSecondCmsgTrunc2Int.client_skip
+ def _testSecondCmsgTrunc2Int(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSecondCmsgTruncLen0Minus1(self):
+ self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) +
+ socket.CMSG_LEN(0) - 1)
+
+ @testSecondCmsgTruncLen0Minus1.client_skip
+ def _testSecondCmsgTruncLen0Minus1(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSecomdCmsgTruncInData(self):
+ # Test truncation of the second of two control messages inside
+ # its associated data.
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVTCLASS, 1)
+ self.misc_event.set()
+ msg, ancdata, flags, addr = self.doRecvmsg(
+ self.serv_sock, len(MSG),
+ socket.CMSG_SPACE(SIZEOF_INT) + socket.CMSG_LEN(SIZEOF_INT) - 1)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)
+
+ cmsg_types = {socket.IPV6_TCLASS, socket.IPV6_HOPLIMIT}
+
+ cmsg_level, cmsg_type, cmsg_data = ancdata.pop(0)
+ self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
+ cmsg_types.remove(cmsg_type)
+ self.assertEqual(len(cmsg_data), SIZEOF_INT)
+ a = array.array("i")
+ a.fromstring(cmsg_data)
+ self.assertGreaterEqual(a[0], 0)
+ self.assertLessEqual(a[0], 255)
+
+ if ancdata:
+ cmsg_level, cmsg_type, cmsg_data = ancdata.pop(0)
+ self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
+ cmsg_types.remove(cmsg_type)
+ self.assertLess(len(cmsg_data), SIZEOF_INT)
+
+ self.assertEqual(ancdata, [])
+
+ @testSecomdCmsgTruncInData.client_skip
+ def _testSecomdCmsgTruncInData(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+
+# Derive concrete test classes for different socket types.
+
+class SendrecvmsgUDPTestBase(SendrecvmsgDgramFlagsBase,
+ SendrecvmsgConnectionlessBase,
+ ThreadedSocketTestMixin, UDPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendmsgUDPTest(SendmsgConnectionlessTests, SendrecvmsgUDPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgUDPTest(RecvmsgTests, SendrecvmsgUDPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoUDPTest(RecvmsgIntoTests, SendrecvmsgUDPTestBase):
+ pass
+
+
+class SendrecvmsgUDP6TestBase(SendrecvmsgDgramFlagsBase,
+ SendrecvmsgConnectionlessBase,
+ ThreadedSocketTestMixin, UDP6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg")
+@unittest.skipUnless(socket.has_ipv6, "Python not built with IPv6 support")
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendmsgUDP6Test(SendmsgConnectionlessTests, SendrecvmsgUDP6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@unittest.skipUnless(socket.has_ipv6, "Python not built with IPv6 support")
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgUDP6Test(RecvmsgTests, SendrecvmsgUDP6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@unittest.skipUnless(socket.has_ipv6, "Python not built with IPv6 support")
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoUDP6Test(RecvmsgIntoTests, SendrecvmsgUDP6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@unittest.skipUnless(socket.has_ipv6, "Python not built with IPv6 support")
+@requireAttrs(socket, "IPPROTO_IPV6")
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgRFC3542AncillaryUDP6Test(RFC3542AncillaryTest,
+ SendrecvmsgUDP6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@unittest.skipUnless(socket.has_ipv6, "Python not built with IPv6 support")
+@requireAttrs(socket, "IPPROTO_IPV6")
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoRFC3542AncillaryUDP6Test(RecvmsgIntoMixin,
+ RFC3542AncillaryTest,
+ SendrecvmsgUDP6TestBase):
+ pass
+
+
+class SendrecvmsgTCPTestBase(SendrecvmsgConnectedBase,
+ ConnectedStreamTestMixin, TCPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendmsgTCPTest(SendmsgStreamTests, SendrecvmsgTCPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgTCPTest(RecvmsgTests, RecvmsgGenericStreamTests,
+ SendrecvmsgTCPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoTCPTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
+ SendrecvmsgTCPTestBase):
+ pass
+
+
+class SendrecvmsgSCTPStreamTestBase(SendrecvmsgSCTPFlagsBase,
+ SendrecvmsgConnectedBase,
+ ConnectedStreamTestMixin, SCTPStreamBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg")
+@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendmsgSCTPStreamTest(SendmsgStreamTests, SendrecvmsgSCTPStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgSCTPStreamTest(RecvmsgTests, RecvmsgGenericStreamTests,
+ SendrecvmsgSCTPStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoSCTPStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
+ SendrecvmsgSCTPStreamTestBase):
+ pass
+
+
+class SendrecvmsgUnixStreamTestBase(SendrecvmsgConnectedBase,
+ ConnectedStreamTestMixin, UnixStreamBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg")
+@requireAttrs(socket, "AF_UNIX")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendmsgUnixStreamTest(SendmsgStreamTests, SendrecvmsgUnixStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@requireAttrs(socket, "AF_UNIX")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgUnixStreamTest(RecvmsgTests, RecvmsgGenericStreamTests,
+ SendrecvmsgUnixStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@requireAttrs(socket, "AF_UNIX")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoUnixStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
+ SendrecvmsgUnixStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg", "recvmsg")
+@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgSCMRightsStreamTest(SCMRightsTest, SendrecvmsgUnixStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg", "recvmsg_into")
+@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoSCMRightsStreamTest(RecvmsgIntoMixin, SCMRightsTest,
+ SendrecvmsgUnixStreamTestBase):
+ pass
+
+
+# Test interrupting the interruptible send/receive methods with a
+# signal when a timeout is set. These tests avoid having multiple
+# threads alive during the test so that the OS cannot deliver the
+# signal to the wrong one.
+
+class InterruptedTimeoutBase(unittest.TestCase):
+ # Base class for interrupted send/receive tests. Installs an
+ # empty handler for SIGALRM and removes it on teardown, along with
+ # any scheduled alarms.
+
+ def setUp(self):
+ super().setUp()
+ orig_alrm_handler = signal.signal(signal.SIGALRM,
+ lambda signum, frame: None)
+ self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)
+ self.addCleanup(self.setAlarm, 0)
+
+ # Timeout for socket operations
+ timeout = 4.0
+
+ # Provide setAlarm() method to schedule delivery of SIGALRM after
+ # given number of seconds, or cancel it if zero, and an
+ # appropriate time value to use. Use setitimer() if available.
+ if hasattr(signal, "setitimer"):
+ alarm_time = 0.05
+
+ def setAlarm(self, seconds):
+ signal.setitimer(signal.ITIMER_REAL, seconds)
+ else:
+ # Old systems may deliver the alarm up to one second early
+ alarm_time = 2
+
+ def setAlarm(self, seconds):
+ signal.alarm(seconds)
+
+
+# Require siginterrupt() in order to ensure that system calls are
+# interrupted by default.
+@requireAttrs(signal, "siginterrupt")
+@unittest.skipUnless(hasattr(signal, "alarm") or hasattr(signal, "setitimer"),
+ "Don't have signal.alarm or signal.setitimer")
+class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase):
+ # Test interrupting the recv*() methods with signals when a
+ # timeout is set.
+
+ def setUp(self):
+ super().setUp()
+ self.serv.settimeout(self.timeout)
+
+ def checkInterruptedRecv(self, func, *args, **kwargs):
+ # Check that func(*args, **kwargs) raises socket.error with an
+ # errno of EINTR when interrupted by a signal.
+ self.setAlarm(self.alarm_time)
+ with self.assertRaises(socket.error) as cm:
+ func(*args, **kwargs)
+ self.assertNotIsInstance(cm.exception, socket.timeout)
+ self.assertEqual(cm.exception.errno, errno.EINTR)
+
+ def testInterruptedRecvTimeout(self):
+ self.checkInterruptedRecv(self.serv.recv, 1024)
+
+ def testInterruptedRecvIntoTimeout(self):
+ self.checkInterruptedRecv(self.serv.recv_into, bytearray(1024))
+
+ def testInterruptedRecvfromTimeout(self):
+ self.checkInterruptedRecv(self.serv.recvfrom, 1024)
+
+ def testInterruptedRecvfromIntoTimeout(self):
+ self.checkInterruptedRecv(self.serv.recvfrom_into, bytearray(1024))
+
+ @requireAttrs(socket.socket, "recvmsg")
+ def testInterruptedRecvmsgTimeout(self):
+ self.checkInterruptedRecv(self.serv.recvmsg, 1024)
+
+ @requireAttrs(socket.socket, "recvmsg_into")
+ def testInterruptedRecvmsgIntoTimeout(self):
+ self.checkInterruptedRecv(self.serv.recvmsg_into, [bytearray(1024)])
+
+
+# Require siginterrupt() in order to ensure that system calls are
+# interrupted by default.
+@requireAttrs(signal, "siginterrupt")
+@unittest.skipUnless(hasattr(signal, "alarm") or hasattr(signal, "setitimer"),
+ "Don't have signal.alarm or signal.setitimer")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class InterruptedSendTimeoutTest(InterruptedTimeoutBase,
+ ThreadSafeCleanupTestCase,
+ SocketListeningTestMixin, TCPTestBase):
+ # Test interrupting the interruptible send*() methods with signals
+ # when a timeout is set.
+
+ def setUp(self):
+ super().setUp()
+ self.serv_conn = self.newSocket()
+ self.addCleanup(self.serv_conn.close)
+ # Use a thread to complete the connection, but wait for it to
+ # terminate before running the test, so that there is only one
+ # thread to accept the signal.
+ cli_thread = threading.Thread(target=self.doConnect)
+ cli_thread.start()
+ self.cli_conn, addr = self.serv.accept()
+ self.addCleanup(self.cli_conn.close)
+ cli_thread.join()
+ self.serv_conn.settimeout(self.timeout)
+
+ def doConnect(self):
+ self.serv_conn.connect(self.serv_addr)
+
+ def checkInterruptedSend(self, func, *args, **kwargs):
+ # Check that func(*args, **kwargs), run in a loop, raises
+ # socket.error with an errno of EINTR when interrupted by a
+ # signal.
+ with self.assertRaises(socket.error) as cm:
+ while True:
+ self.setAlarm(self.alarm_time)
+ func(*args, **kwargs)
+ self.assertNotIsInstance(cm.exception, socket.timeout)
+ self.assertEqual(cm.exception.errno, errno.EINTR)
+
+ def testInterruptedSendTimeout(self):
+ self.checkInterruptedSend(self.serv_conn.send, b"a"*512)
+
+ def testInterruptedSendtoTimeout(self):
+ # Passing an actual address here as Python's wrapper for
+ # sendto() doesn't allow passing a zero-length one; POSIX
+ # requires that the address is ignored since the socket is
+ # connection-mode, however.
+ self.checkInterruptedSend(self.serv_conn.sendto, b"a"*512,
+ self.serv_addr)
+
+ @requireAttrs(socket.socket, "sendmsg")
+ def testInterruptedSendmsgTimeout(self):
+ self.checkInterruptedSend(self.serv_conn.sendmsg, [b"a"*512])
+
+
@unittest.skipUnless(thread, 'Threading required for this test.')
class TCPCloserTest(ThreadedTCPSocketTest):
@@ -2077,6 +4172,31 @@ def test_main():
if isTipcAvailable():
tests.append(TIPCTest)
tests.append(TIPCThreadableTest)
+ tests.extend([
+ CmsgMacroTests,
+ SendmsgUDPTest,
+ RecvmsgUDPTest,
+ RecvmsgIntoUDPTest,
+ SendmsgUDP6Test,
+ RecvmsgUDP6Test,
+ RecvmsgRFC3542AncillaryUDP6Test,
+ RecvmsgIntoRFC3542AncillaryUDP6Test,
+ RecvmsgIntoUDP6Test,
+ SendmsgTCPTest,
+ RecvmsgTCPTest,
+ RecvmsgIntoTCPTest,
+ SendmsgSCTPStreamTest,
+ RecvmsgSCTPStreamTest,
+ RecvmsgIntoSCTPStreamTest,
+ SendmsgUnixStreamTest,
+ RecvmsgUnixStreamTest,
+ RecvmsgIntoUnixStreamTest,
+ RecvmsgSCMRightsStreamTest,
+ RecvmsgIntoSCMRightsStreamTest,
+ # These are slow when setitimer() is not available
+ InterruptedRecvTimeoutTest,
+ InterruptedSendTimeoutTest,
+ ])
thread_info = support.threading_setup()
support.run_unittest(*tests)
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index f3f0c54..1d26c11 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -186,8 +186,11 @@ class BasicSocketTests(unittest.TestCase):
self.assertRaises(socket.error, ss.recv_into, bytearray(b'x'))
self.assertRaises(socket.error, ss.recvfrom, 1)
self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1)
+ self.assertRaises(socket.error, ss.recvmsg, 1)
+ self.assertRaises(socket.error, ss.recvmsg_into, [bytearray(b'x')])
self.assertRaises(socket.error, ss.send, b'x')
self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0))
+ self.assertRaises(socket.error, ss.sendmsg, [b'x'])
def test_timeout(self):
# Issue #8524: when creating an SSL socket, the timeout of the
@@ -1520,17 +1523,30 @@ else:
count, addr = s.recvfrom_into(b)
return b[:count]
+ def _recvmsg(*args, **kwargs):
+ return s.recvmsg(*args, **kwargs)[0]
+
+ def _recvmsg_into(bufsize, *args, **kwargs):
+ b = bytearray(bufsize)
+ return bytes(b[:s.recvmsg_into([b], *args, **kwargs)[0]])
+
+ def _sendmsg(msg, *args, **kwargs):
+ return s.sendmsg([msg])
+
# (name, method, whether to expect success, *args)
send_methods = [
('send', s.send, True, []),
('sendto', s.sendto, False, ["some.address"]),
+ ('sendmsg', _sendmsg, False, []),
('sendall', s.sendall, True, []),
]
recv_methods = [
('recv', s.recv, True, []),
('recvfrom', s.recvfrom, False, ["some.address"]),
+ ('recvmsg', _recvmsg, False, [100]),
('recv_into', _recv_into, True, []),
('recvfrom_into', _recvfrom_into, False, []),
+ ('recvmsg_into', _recvmsg_into, False, [100]),
]
data_prefix = "PREFIX_"
diff --git a/Misc/NEWS b/Misc/NEWS
index a2943a7..67e8218 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -265,6 +265,10 @@ Core and Builtins
Library
-------
+- Issue #6560: The sendmsg/recvmsg API is now exposed by the socket module
+ when provided by the underlying platform, supporting processing of
+ ancillary data in pure Python code.
+
- Issue #12326: On Linux, sys.platform doesn't contain the major version
anymore. It is now always 'linux', instead of 'linux2' or 'linux3' depending
on the Linux version used to build Python.
diff --git a/Modules/socketmodule.c b/Modules/socketmodule.c
index 8de84b7..75cde79 100644
--- a/Modules/socketmodule.c
+++ b/Modules/socketmodule.c
@@ -263,6 +263,7 @@ if_indextoname(index) -- return the corresponding interface name\n\
#ifdef HAVE_NET_IF_H
#include <net/if.h>
#endif
+#include <unistd.h>
/* Generic socket object definitions and includes */
#define PySocket_BUILDING_SOCKET
@@ -469,6 +470,17 @@ static PyTypeObject sock_type;
#include <sys/poll.h>
#endif
+/* Largest value to try to store in a socklen_t (used when handling
+ ancillary data). POSIX requires socklen_t to hold at least
+ (2**31)-1 and recommends against storing larger values, but
+ socklen_t was originally int in the BSD interface, so to be on the
+ safe side we use the smaller of (2**31)-1 and INT_MAX. */
+#if INT_MAX > 0x7fffffff
+#define SOCKLEN_T_LIMIT 0x7fffffff
+#else
+#define SOCKLEN_T_LIMIT INT_MAX
+#endif
+
#ifdef Py_SOCKET_FD_CAN_BE_GE_FD_SETSIZE
/* Platform can select file descriptors beyond FD_SETSIZE */
#define IS_SELECTABLE(s) 1
@@ -1678,6 +1690,117 @@ getsockaddrlen(PySocketSockObject *s, socklen_t *len_ret)
}
+/* Support functions for the sendmsg() and recvmsg[_into]() methods.
+ Currently, these methods are only compiled if the RFC 2292/3542
+ CMSG_LEN() macro is available. Older systems seem to have used
+ sizeof(struct cmsghdr) + (length) where CMSG_LEN() is used now, so
+ it may be possible to define CMSG_LEN() that way if it's not
+ provided. Some architectures might need extra padding after the
+ cmsghdr, however, and CMSG_LEN() would have to take account of
+ this. */
+#ifdef CMSG_LEN
+/* If length is in range, set *result to CMSG_LEN(length) and return
+ true; otherwise, return false. */
+static int
+get_CMSG_LEN(size_t length, size_t *result)
+{
+ size_t tmp;
+
+ if (length > (SOCKLEN_T_LIMIT - CMSG_LEN(0)))
+ return 0;
+ tmp = CMSG_LEN(length);
+ if (tmp > SOCKLEN_T_LIMIT || tmp < length)
+ return 0;
+ *result = tmp;
+ return 1;
+}
+
+#ifdef CMSG_SPACE
+/* If length is in range, set *result to CMSG_SPACE(length) and return
+ true; otherwise, return false. */
+static int
+get_CMSG_SPACE(size_t length, size_t *result)
+{
+ size_t tmp;
+
+ /* Use CMSG_SPACE(1) here in order to take account of the padding
+ necessary before *and* after the data. */
+ if (length > (SOCKLEN_T_LIMIT - CMSG_SPACE(1)))
+ return 0;
+ tmp = CMSG_SPACE(length);
+ if (tmp > SOCKLEN_T_LIMIT || tmp < length)
+ return 0;
+ *result = tmp;
+ return 1;
+}
+#endif
+
+/* Return true iff msg->msg_controllen is valid, cmsgh is a valid
+ pointer in msg->msg_control with at least "space" bytes after it,
+ and its cmsg_len member inside the buffer. */
+static int
+cmsg_min_space(struct msghdr *msg, struct cmsghdr *cmsgh, size_t space)
+{
+ size_t cmsg_offset;
+ static const size_t cmsg_len_end = (offsetof(struct cmsghdr, cmsg_len) +
+ sizeof(cmsgh->cmsg_len));
+
+ if (cmsgh == NULL || msg->msg_control == NULL || msg->msg_controllen < 0)
+ return 0;
+ if (space < cmsg_len_end)
+ space = cmsg_len_end;
+ cmsg_offset = (char *)cmsgh - (char *)msg->msg_control;
+ return (cmsg_offset <= (size_t)-1 - space &&
+ cmsg_offset + space <= msg->msg_controllen);
+}
+
+/* If pointer CMSG_DATA(cmsgh) is in buffer msg->msg_control, set
+ *space to number of bytes following it in the buffer and return
+ true; otherwise, return false. Assumes cmsgh, msg->msg_control and
+ msg->msg_controllen are valid. */
+static int
+get_cmsg_data_space(struct msghdr *msg, struct cmsghdr *cmsgh, size_t *space)
+{
+ size_t data_offset;
+ char *data_ptr;
+
+ if ((data_ptr = (char *)CMSG_DATA(cmsgh)) == NULL)
+ return 0;
+ data_offset = data_ptr - (char *)msg->msg_control;
+ if (data_offset > msg->msg_controllen)
+ return 0;
+ *space = msg->msg_controllen - data_offset;
+ return 1;
+}
+
+/* If cmsgh is invalid or not contained in the buffer pointed to by
+ msg->msg_control, return -1. If cmsgh is valid and its associated
+ data is entirely contained in the buffer, set *data_len to the
+ length of the associated data and return 0. If only part of the
+ associated data is contained in the buffer but cmsgh is otherwise
+ valid, set *data_len to the length contained in the buffer and
+ return 1. */
+static int
+get_cmsg_data_len(struct msghdr *msg, struct cmsghdr *cmsgh, size_t *data_len)
+{
+ size_t space, cmsg_data_len;
+
+ if (!cmsg_min_space(msg, cmsgh, CMSG_LEN(0)) ||
+ cmsgh->cmsg_len < CMSG_LEN(0))
+ return -1;
+ cmsg_data_len = cmsgh->cmsg_len - CMSG_LEN(0);
+ if (!get_cmsg_data_space(msg, cmsgh, &space))
+ return -1;
+ if (space >= cmsg_data_len) {
+ *data_len = cmsg_data_len;
+ return 0;
+ }
+ *data_len = space;
+ return 1;
+}
+#endif /* CMSG_LEN */
+
+
/* s._accept() -> (fd, address) */
static PyObject *
@@ -2631,6 +2754,333 @@ PyDoc_STRVAR(recvfrom_into_doc,
Like recv_into(buffer[, nbytes[, flags]]) but also return the sender's address info.");
+/* The sendmsg() and recvmsg[_into]() methods require a working
+ CMSG_LEN(). See the comment near get_CMSG_LEN(). */
+#ifdef CMSG_LEN
+/*
+ * Call recvmsg() with the supplied iovec structures, flags, and
+ * ancillary data buffer size (controllen). Returns the tuple return
+ * value for recvmsg() or recvmsg_into(), with the first item provided
+ * by the supplied makeval() function. makeval() will be called with
+ * the length read and makeval_data as arguments, and must return a
+ * new reference (which will be decrefed if there is a subsequent
+ * error). On error, closes any file descriptors received via
+ * SCM_RIGHTS.
+ */
+static PyObject *
+sock_recvmsg_guts(PySocketSockObject *s, struct iovec *iov, int iovlen,
+ int flags, Py_ssize_t controllen,
+ PyObject *(*makeval)(ssize_t, void *), void *makeval_data)
+{
+ ssize_t bytes_received = -1;
+ int timeout;
+ sock_addr_t addrbuf;
+ socklen_t addrbuflen;
+ static const struct msghdr msg_blank;
+ struct msghdr msg;
+ PyObject *cmsg_list = NULL, *retval = NULL;
+ void *controlbuf = NULL;
+ struct cmsghdr *cmsgh;
+ size_t cmsgdatalen = 0;
+ int cmsg_status;
+
+ /* XXX: POSIX says that msg_name and msg_namelen "shall be
+ ignored" when the socket is connected (Linux fills them in
+ anyway for AF_UNIX sockets at least). Normally msg_namelen
+ seems to be set to 0 if there's no address, but try to
+ initialize msg_name to something that won't be mistaken for a
+ real address if that doesn't happen. */
+ if (!getsockaddrlen(s, &addrbuflen))
+ return NULL;
+ memset(&addrbuf, 0, addrbuflen);
+ SAS2SA(&addrbuf)->sa_family = AF_UNSPEC;
+
+ if (controllen < 0 || controllen > SOCKLEN_T_LIMIT) {
+ PyErr_SetString(PyExc_ValueError,
+ "invalid ancillary data buffer length");
+ return NULL;
+ }
+ if (controllen > 0 && (controlbuf = PyMem_Malloc(controllen)) == NULL)
+ return PyErr_NoMemory();
+
+ /* Make the system call. */
+ if (!IS_SELECTABLE(s)) {
+ select_error();
+ goto finally;
+ }
+
+ BEGIN_SELECT_LOOP(s)
+ Py_BEGIN_ALLOW_THREADS;
+ msg = msg_blank; /* Set all members to 0 or NULL */
+ msg.msg_name = SAS2SA(&addrbuf);
+ msg.msg_namelen = addrbuflen;
+ msg.msg_iov = iov;
+ msg.msg_iovlen = iovlen;
+ msg.msg_control = controlbuf;
+ msg.msg_controllen = controllen;
+ timeout = internal_select_ex(s, 0, interval);
+ if (!timeout)
+ bytes_received = recvmsg(s->sock_fd, &msg, flags);
+ Py_END_ALLOW_THREADS;
+ if (timeout == 1) {
+ PyErr_SetString(socket_timeout, "timed out");
+ goto finally;
+ }
+ END_SELECT_LOOP(s)
+
+ if (bytes_received < 0) {
+ s->errorhandler();
+ goto finally;
+ }
+
+ /* Make list of (level, type, data) tuples from control messages. */
+ if ((cmsg_list = PyList_New(0)) == NULL)
+ goto err_closefds;
+ /* Check for empty ancillary data as old CMSG_FIRSTHDR()
+ implementations didn't do so. */
+ for (cmsgh = ((msg.msg_controllen > 0) ? CMSG_FIRSTHDR(&msg) : NULL);
+ cmsgh != NULL; cmsgh = CMSG_NXTHDR(&msg, cmsgh)) {
+ PyObject *bytes, *tuple;
+ int tmp;
+
+ cmsg_status = get_cmsg_data_len(&msg, cmsgh, &cmsgdatalen);
+ if (cmsg_status != 0) {
+ if (PyErr_WarnEx(PyExc_RuntimeWarning,
+ "received malformed or improperly-truncated "
+ "ancillary data", 1) == -1)
+ goto err_closefds;
+ }
+ if (cmsg_status < 0)
+ break;
+ if (cmsgdatalen > PY_SSIZE_T_MAX) {
+ PyErr_SetString(socket_error, "control message too long");
+ goto err_closefds;
+ }
+
+ bytes = PyBytes_FromStringAndSize((char *)CMSG_DATA(cmsgh),
+ cmsgdatalen);
+ tuple = Py_BuildValue("iiN", (int)cmsgh->cmsg_level,
+ (int)cmsgh->cmsg_type, bytes);
+ if (tuple == NULL)
+ goto err_closefds;
+ tmp = PyList_Append(cmsg_list, tuple);
+ Py_DECREF(tuple);
+ if (tmp != 0)
+ goto err_closefds;
+
+ if (cmsg_status != 0)
+ break;
+ }
+
+ retval = Py_BuildValue("NOiN",
+ (*makeval)(bytes_received, makeval_data),
+ cmsg_list,
+ (int)msg.msg_flags,
+ makesockaddr(s->sock_fd, SAS2SA(&addrbuf),
+ ((msg.msg_namelen > addrbuflen) ?
+ addrbuflen : msg.msg_namelen),
+ s->sock_proto));
+ if (retval == NULL)
+ goto err_closefds;
+
+finally:
+ Py_XDECREF(cmsg_list);
+ PyMem_Free(controlbuf);
+ return retval;
+
+err_closefds:
+#ifdef SCM_RIGHTS
+ /* Close all descriptors coming from SCM_RIGHTS, so they don't leak. */
+ for (cmsgh = ((msg.msg_controllen > 0) ? CMSG_FIRSTHDR(&msg) : NULL);
+ cmsgh != NULL; cmsgh = CMSG_NXTHDR(&msg, cmsgh)) {
+ cmsg_status = get_cmsg_data_len(&msg, cmsgh, &cmsgdatalen);
+ if (cmsg_status < 0)
+ break;
+ if (cmsgh->cmsg_level == SOL_SOCKET &&
+ cmsgh->cmsg_type == SCM_RIGHTS) {
+ size_t numfds;
+ int *fdp;
+
+ numfds = cmsgdatalen / sizeof(int);
+ fdp = (int *)CMSG_DATA(cmsgh);
+ while (numfds-- > 0)
+ close(*fdp++);
+ }
+ if (cmsg_status != 0)
+ break;
+ }
+#endif /* SCM_RIGHTS */
+ goto finally;
+}
+
+
+static PyObject *
+makeval_recvmsg(ssize_t received, void *data)
+{
+ PyObject **buf = data;
+
+ if (received < PyBytes_GET_SIZE(*buf))
+ _PyBytes_Resize(buf, received);
+ Py_XINCREF(*buf);
+ return *buf;
+}
+
+/* s.recvmsg(bufsize[, ancbufsize[, flags]]) method */
+
+static PyObject *
+sock_recvmsg(PySocketSockObject *s, PyObject *args)
+{
+ Py_ssize_t bufsize, ancbufsize = 0;
+ int flags = 0;
+ struct iovec iov;
+ PyObject *buf = NULL, *retval = NULL;
+
+ if (!PyArg_ParseTuple(args, "n|ni:recvmsg", &bufsize, &ancbufsize, &flags))
+ return NULL;
+
+ if (bufsize < 0) {
+ PyErr_SetString(PyExc_ValueError, "negative buffer size in recvmsg()");
+ return NULL;
+ }
+ if ((buf = PyBytes_FromStringAndSize(NULL, bufsize)) == NULL)
+ return NULL;
+ iov.iov_base = PyBytes_AS_STRING(buf);
+ iov.iov_len = bufsize;
+
+ /* Note that we're passing a pointer to *our pointer* to the bytes
+ object here (&buf); makeval_recvmsg() may incref the object, or
+ deallocate it and set our pointer to NULL. */
+ retval = sock_recvmsg_guts(s, &iov, 1, flags, ancbufsize,
+ &makeval_recvmsg, &buf);
+ Py_XDECREF(buf);
+ return retval;
+}
+
+PyDoc_STRVAR(recvmsg_doc,
+"recvmsg(bufsize[, ancbufsize[, flags]]) -> (data, ancdata, msg_flags, address)\n\
+\n\
+Receive normal data (up to bufsize bytes) and ancillary data from the\n\
+socket. The ancbufsize argument sets the size in bytes of the\n\
+internal buffer used to receive the ancillary data; it defaults to 0,\n\
+meaning that no ancillary data will be received. Appropriate buffer\n\
+sizes for ancillary data can be calculated using CMSG_SPACE() or\n\
+CMSG_LEN(), and items which do not fit into the buffer might be\n\
+truncated or discarded. The flags argument defaults to 0 and has the\n\
+same meaning as for recv().\n\
+\n\
+The return value is a 4-tuple: (data, ancdata, msg_flags, address).\n\
+The data item is a bytes object holding the non-ancillary data\n\
+received. The ancdata item is a list of zero or more tuples\n\
+(cmsg_level, cmsg_type, cmsg_data) representing the ancillary data\n\
+(control messages) received: cmsg_level and cmsg_type are integers\n\
+specifying the protocol level and protocol-specific type respectively,\n\
+and cmsg_data is a bytes object holding the associated data. The\n\
+msg_flags item is the bitwise OR of various flags indicating\n\
+conditions on the received message; see your system documentation for\n\
+details. If the receiving socket is unconnected, address is the\n\
+address of the sending socket, if available; otherwise, its value is\n\
+unspecified.\n\
+\n\
+If recvmsg() raises an exception after the system call returns, it\n\
+will first attempt to close any file descriptors received via the\n\
+SCM_RIGHTS mechanism.");
+
+
+static PyObject *
+makeval_recvmsg_into(ssize_t received, void *data)
+{
+ return PyLong_FromSsize_t(received);
+}
+
+/* s.recvmsg_into(buffers[, ancbufsize[, flags]]) method */
+
+static PyObject *
+sock_recvmsg_into(PySocketSockObject *s, PyObject *args)
+{
+ Py_ssize_t ancbufsize = 0;
+ int flags = 0;
+ struct iovec *iovs = NULL;
+ Py_ssize_t i, nitems, nbufs = 0;
+ Py_buffer *bufs = NULL;
+ PyObject *buffers_arg, *fast, *retval = NULL;
+
+ if (!PyArg_ParseTuple(args, "O|ni:recvmsg_into",
+ &buffers_arg, &ancbufsize, &flags))
+ return NULL;
+
+ if ((fast = PySequence_Fast(buffers_arg,
+ "recvmsg_into() argument 1 must be an "
+ "iterable")) == NULL)
+ return NULL;
+ nitems = PySequence_Fast_GET_SIZE(fast);
+ if (nitems > INT_MAX) {
+ PyErr_SetString(socket_error, "recvmsg_into() argument 1 is too long");
+ goto finally;
+ }
+
+ /* Fill in an iovec for each item, and save the Py_buffer
+ structs to release afterwards. */
+ if (nitems > 0 && ((iovs = PyMem_New(struct iovec, nitems)) == NULL ||
+ (bufs = PyMem_New(Py_buffer, nitems)) == NULL)) {
+ PyErr_NoMemory();
+ goto finally;
+ }
+ for (; nbufs < nitems; nbufs++) {
+ if (!PyArg_Parse(PySequence_Fast_GET_ITEM(fast, nbufs),
+ "w*;recvmsg_into() argument 1 must be an iterable "
+ "of single-segment read-write buffers",
+ &bufs[nbufs]))
+ goto finally;
+ iovs[nbufs].iov_base = bufs[nbufs].buf;
+ iovs[nbufs].iov_len = bufs[nbufs].len;
+ }
+
+ retval = sock_recvmsg_guts(s, iovs, nitems, flags, ancbufsize,
+ &makeval_recvmsg_into, NULL);
+finally:
+ for (i = 0; i < nbufs; i++)
+ PyBuffer_Release(&bufs[i]);
+ PyMem_Free(bufs);
+ PyMem_Free(iovs);
+ Py_DECREF(fast);
+ return retval;
+}
+
+PyDoc_STRVAR(recvmsg_into_doc,
+"recvmsg_into(buffers[, ancbufsize[, flags]]) -> (nbytes, ancdata, msg_flags, address)\n\
+\n\
+Receive normal data and ancillary data from the socket, scattering the\n\
+non-ancillary data into a series of buffers. The buffers argument\n\
+must be an iterable of objects that export writable buffers\n\
+(e.g. bytearray objects); these will be filled with successive chunks\n\
+of the non-ancillary data until it has all been written or there are\n\
+no more buffers. The ancbufsize argument sets the size in bytes of\n\
+the internal buffer used to receive the ancillary data; it defaults to\n\
+0, meaning that no ancillary data will be received. Appropriate\n\
+buffer sizes for ancillary data can be calculated using CMSG_SPACE()\n\
+or CMSG_LEN(), and items which do not fit into the buffer might be\n\
+truncated or discarded. The flags argument defaults to 0 and has the\n\
+same meaning as for recv().\n\
+\n\
+The return value is a 4-tuple: (nbytes, ancdata, msg_flags, address).\n\
+The nbytes item is the total number of bytes of non-ancillary data\n\
+written into the buffers. The ancdata item is a list of zero or more\n\
+tuples (cmsg_level, cmsg_type, cmsg_data) representing the ancillary\n\
+data (control messages) received: cmsg_level and cmsg_type are\n\
+integers specifying the protocol level and protocol-specific type\n\
+respectively, and cmsg_data is a bytes object holding the associated\n\
+data. The msg_flags item is the bitwise OR of various flags\n\
+indicating conditions on the received message; see your system\n\
+documentation for details. If the receiving socket is unconnected,\n\
+address is the address of the sending socket, if available; otherwise,\n\
+its value is unspecified.\n\
+\n\
+If recvmsg_into() raises an exception after the system call returns,\n\
+it will first attempt to close any file descriptors received via the\n\
+SCM_RIGHTS mechanism.");
+#endif /* CMSG_LEN */
+
+
/* s.send(data [,flags]) method */
static PyObject *
@@ -2826,6 +3276,237 @@ Like send(data, flags) but allows specifying the destination address.\n\
For IP sockets, the address is a pair (hostaddr, port).");
+/* The sendmsg() and recvmsg[_into]() methods require a working
+ CMSG_LEN(). See the comment near get_CMSG_LEN(). */
+#ifdef CMSG_LEN
+/* s.sendmsg(buffers[, ancdata[, flags[, address]]]) method */
+
+static PyObject *
+sock_sendmsg(PySocketSockObject *s, PyObject *args)
+{
+ Py_ssize_t i, ndataparts, ndatabufs = 0, ncmsgs, ncmsgbufs = 0;
+ Py_buffer *databufs = NULL;
+ struct iovec *iovs = NULL;
+ sock_addr_t addrbuf;
+ static const struct msghdr msg_blank;
+ struct msghdr msg;
+ struct cmsginfo {
+ int level;
+ int type;
+ Py_buffer data;
+ } *cmsgs = NULL;
+ void *controlbuf = NULL;
+ size_t controllen, controllen_last;
+ ssize_t bytes_sent = -1;
+ int addrlen, timeout, flags = 0;
+ PyObject *data_arg, *cmsg_arg = NULL, *addr_arg = NULL, *data_fast = NULL,
+ *cmsg_fast = NULL, *retval = NULL;
+
+ if (!PyArg_ParseTuple(args, "O|OiO:sendmsg",
+ &data_arg, &cmsg_arg, &flags, &addr_arg))
+ return NULL;
+
+ msg = msg_blank; /* Set all members to 0 or NULL */
+
+ /* Parse destination address. */
+ if (addr_arg != NULL && addr_arg != Py_None) {
+ if (!getsockaddrarg(s, addr_arg, SAS2SA(&addrbuf), &addrlen))
+ goto finally;
+ msg.msg_name = &addrbuf;
+ msg.msg_namelen = addrlen;
+ }
+
+ /* Fill in an iovec for each message part, and save the Py_buffer
+ structs to release afterwards. */
+ if ((data_fast = PySequence_Fast(data_arg,
+ "sendmsg() argument 1 must be an "
+ "iterable")) == NULL)
+ goto finally;
+ ndataparts = PySequence_Fast_GET_SIZE(data_fast);
+ if (ndataparts > INT_MAX) {
+ PyErr_SetString(socket_error, "sendmsg() argument 1 is too long");
+ goto finally;
+ }
+ msg.msg_iovlen = ndataparts;
+ if (ndataparts > 0 &&
+ ((msg.msg_iov = iovs = PyMem_New(struct iovec, ndataparts)) == NULL ||
+ (databufs = PyMem_New(Py_buffer, ndataparts)) == NULL)) {
+ PyErr_NoMemory();
+ goto finally;
+ }
+ for (; ndatabufs < ndataparts; ndatabufs++) {
+ if (!PyArg_Parse(PySequence_Fast_GET_ITEM(data_fast, ndatabufs),
+ "y*;sendmsg() argument 1 must be an iterable of "
+ "buffer-compatible objects",
+ &databufs[ndatabufs]))
+ goto finally;
+ iovs[ndatabufs].iov_base = databufs[ndatabufs].buf;
+ iovs[ndatabufs].iov_len = databufs[ndatabufs].len;
+ }
+
+ if (cmsg_arg == NULL)
+ ncmsgs = 0;
+ else {
+ if ((cmsg_fast = PySequence_Fast(cmsg_arg,
+ "sendmsg() argument 2 must be an "
+ "iterable")) == NULL)
+ goto finally;
+ ncmsgs = PySequence_Fast_GET_SIZE(cmsg_fast);
+ }
+
+#ifndef CMSG_SPACE
+ if (ncmsgs > 1) {
+ PyErr_SetString(socket_error,
+ "sending multiple control messages is not supported "
+ "on this system");
+ goto finally;
+ }
+#endif
+ /* Save level, type and Py_buffer for each control message,
+ and calculate total size. */
+ if (ncmsgs > 0 && (cmsgs = PyMem_New(struct cmsginfo, ncmsgs)) == NULL) {
+ PyErr_NoMemory();
+ goto finally;
+ }
+ controllen = controllen_last = 0;
+ while (ncmsgbufs < ncmsgs) {
+ size_t bufsize, space;
+
+ if (!PyArg_Parse(PySequence_Fast_GET_ITEM(cmsg_fast, ncmsgbufs),
+ "(iiy*):[sendmsg() ancillary data items]",
+ &cmsgs[ncmsgbufs].level,
+ &cmsgs[ncmsgbufs].type,
+ &cmsgs[ncmsgbufs].data))
+ goto finally;
+ bufsize = cmsgs[ncmsgbufs++].data.len;
+
+#ifdef CMSG_SPACE
+ if (!get_CMSG_SPACE(bufsize, &space)) {
+#else
+ if (!get_CMSG_LEN(bufsize, &space)) {
+#endif
+ PyErr_SetString(socket_error, "ancillary data item too large");
+ goto finally;
+ }
+ controllen += space;
+ if (controllen > SOCKLEN_T_LIMIT || controllen < controllen_last) {
+ PyErr_SetString(socket_error, "too much ancillary data");
+ goto finally;
+ }
+ controllen_last = controllen;
+ }
+
+ /* Construct ancillary data block from control message info. */
+ if (ncmsgbufs > 0) {
+ struct cmsghdr *cmsgh = NULL;
+
+ if ((msg.msg_control = controlbuf =
+ PyMem_Malloc(controllen)) == NULL) {
+ PyErr_NoMemory();
+ goto finally;
+ }
+ msg.msg_controllen = controllen;
+
+ /* Need to zero out the buffer as a workaround for glibc's
+ CMSG_NXTHDR() implementation. After getting the pointer to
+ the next header, it checks its (uninitialized) cmsg_len
+ member to see if the "message" fits in the buffer, and
+ returns NULL if it doesn't. Zero-filling the buffer
+ ensures that that doesn't happen. */
+ memset(controlbuf, 0, controllen);
+
+ for (i = 0; i < ncmsgbufs; i++) {
+ size_t msg_len, data_len = cmsgs[i].data.len;
+ int enough_space = 0;
+
+ cmsgh = (i == 0) ? CMSG_FIRSTHDR(&msg) : CMSG_NXTHDR(&msg, cmsgh);
+ if (cmsgh == NULL) {
+ PyErr_Format(PyExc_RuntimeError,
+ "unexpected NULL result from %s()",
+ (i == 0) ? "CMSG_FIRSTHDR" : "CMSG_NXTHDR");
+ goto finally;
+ }
+ if (!get_CMSG_LEN(data_len, &msg_len)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "item size out of range for CMSG_LEN()");
+ goto finally;
+ }
+ if (cmsg_min_space(&msg, cmsgh, msg_len)) {
+ size_t space;
+
+ cmsgh->cmsg_len = msg_len;
+ if (get_cmsg_data_space(&msg, cmsgh, &space))
+ enough_space = (space >= data_len);
+ }
+ if (!enough_space) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "ancillary data does not fit in calculated "
+ "space");
+ goto finally;
+ }
+ cmsgh->cmsg_level = cmsgs[i].level;
+ cmsgh->cmsg_type = cmsgs[i].type;
+ memcpy(CMSG_DATA(cmsgh), cmsgs[i].data.buf, data_len);
+ }
+ }
+
+ /* Make the system call. */
+ if (!IS_SELECTABLE(s)) {
+ select_error();
+ goto finally;
+ }
+
+ BEGIN_SELECT_LOOP(s)
+ Py_BEGIN_ALLOW_THREADS;
+ timeout = internal_select_ex(s, 1, interval);
+ if (!timeout)
+ bytes_sent = sendmsg(s->sock_fd, &msg, flags);
+ Py_END_ALLOW_THREADS;
+ if (timeout == 1) {
+ PyErr_SetString(socket_timeout, "timed out");
+ goto finally;
+ }
+ END_SELECT_LOOP(s)
+
+ if (bytes_sent < 0) {
+ s->errorhandler();
+ goto finally;
+ }
+ retval = PyLong_FromSsize_t(bytes_sent);
+
+finally:
+ PyMem_Free(controlbuf);
+ for (i = 0; i < ncmsgbufs; i++)
+ PyBuffer_Release(&cmsgs[i].data);
+ PyMem_Free(cmsgs);
+ Py_XDECREF(cmsg_fast);
+ for (i = 0; i < ndatabufs; i++)
+ PyBuffer_Release(&databufs[i]);
+ PyMem_Free(databufs);
+ PyMem_Free(iovs);
+ Py_XDECREF(data_fast);
+ return retval;
+}
+
+PyDoc_STRVAR(sendmsg_doc,
+"sendmsg(buffers[, ancdata[, flags[, address]]]) -> count\n\
+\n\
+Send normal and ancillary data to the socket, gathering the\n\
+non-ancillary data from a series of buffers and concatenating it into\n\
+a single message. The buffers argument specifies the non-ancillary\n\
+data as an iterable of buffer-compatible objects (e.g. bytes objects).\n\
+The ancdata argument specifies the ancillary data (control messages)\n\
+as an iterable of zero or more tuples (cmsg_level, cmsg_type,\n\
+cmsg_data), where cmsg_level and cmsg_type are integers specifying the\n\
+protocol level and protocol-specific type respectively, and cmsg_data\n\
+is a buffer-compatible object holding the associated data. The flags\n\
+argument defaults to 0 and has the same meaning as for send(). If\n\
+address is supplied and not None, it sets a destination address for\n\
+the message. The return value is the number of bytes of non-ancillary\n\
+data sent.");
+#endif /* CMSG_LEN */
+
+
/* s.shutdown(how) method */
static PyObject *
@@ -2952,6 +3633,14 @@ static PyMethodDef sock_methods[] = {
setsockopt_doc},
{"shutdown", (PyCFunction)sock_shutdown, METH_O,
shutdown_doc},
+#ifdef CMSG_LEN
+ {"recvmsg", (PyCFunction)sock_recvmsg, METH_VARARGS,
+ recvmsg_doc},
+ {"recvmsg_into", (PyCFunction)sock_recvmsg_into, METH_VARARGS,
+ recvmsg_into_doc,},
+ {"sendmsg", (PyCFunction)sock_sendmsg, METH_VARARGS,
+ sendmsg_doc},
+#endif
{NULL, NULL} /* sentinel */
};
@@ -4377,6 +5066,68 @@ Returns the interface name corresponding to the interface index if_index.");
#endif /* HAVE_IF_NAMEINDEX */
+#ifdef CMSG_LEN
+/* Python interface to CMSG_LEN(length). */
+
+static PyObject *
+socket_CMSG_LEN(PyObject *self, PyObject *args)
+{
+ Py_ssize_t length;
+ size_t result;
+
+ if (!PyArg_ParseTuple(args, "n:CMSG_LEN", &length))
+ return NULL;
+ if (length < 0 || !get_CMSG_LEN(length, &result)) {
+ PyErr_Format(PyExc_OverflowError, "CMSG_LEN() argument out of range");
+ return NULL;
+ }
+ return PyLong_FromSize_t(result);
+}
+
+PyDoc_STRVAR(CMSG_LEN_doc,
+"CMSG_LEN(length) -> control message length\n\
+\n\
+Return the total length, without trailing padding, of an ancillary\n\
+data item with associated data of the given length. This value can\n\
+often be used as the buffer size for recvmsg() to receive a single\n\
+item of ancillary data, but RFC 3542 requires portable applications to\n\
+use CMSG_SPACE() and thus include space for padding, even when the\n\
+item will be the last in the buffer. Raises OverflowError if length\n\
+is outside the permissible range of values.");
+
+
+#ifdef CMSG_SPACE
+/* Python interface to CMSG_SPACE(length). */
+
+static PyObject *
+socket_CMSG_SPACE(PyObject *self, PyObject *args)
+{
+ Py_ssize_t length;
+ size_t result;
+
+ if (!PyArg_ParseTuple(args, "n:CMSG_SPACE", &length))
+ return NULL;
+ if (length < 0 || !get_CMSG_SPACE(length, &result)) {
+ PyErr_SetString(PyExc_OverflowError,
+ "CMSG_SPACE() argument out of range");
+ return NULL;
+ }
+ return PyLong_FromSize_t(result);
+}
+
+PyDoc_STRVAR(CMSG_SPACE_doc,
+"CMSG_SPACE(length) -> buffer size\n\
+\n\
+Return the buffer size needed for recvmsg() to receive an ancillary\n\
+data item with associated data of the given length, along with any\n\
+trailing padding. The buffer space needed to receive multiple items\n\
+is the sum of the CMSG_SPACE() values for their associated data\n\
+lengths. Raises OverflowError if length is outside the permissible\n\
+range of values.");
+#endif /* CMSG_SPACE */
+#endif /* CMSG_LEN */
+
+
/* List of functions exported by this module. */
static PyMethodDef socket_methods[] = {
@@ -4440,6 +5191,14 @@ static PyMethodDef socket_methods[] = {
{"if_indextoname", socket_if_indextoname,
METH_O, if_indextoname_doc},
#endif
+#ifdef CMSG_LEN
+ {"CMSG_LEN", socket_CMSG_LEN,
+ METH_VARARGS, CMSG_LEN_doc},
+#ifdef CMSG_SPACE
+ {"CMSG_SPACE", socket_CMSG_SPACE,
+ METH_VARARGS, CMSG_SPACE_doc},
+#endif
+#endif
{NULL, NULL} /* Sentinel */
};
@@ -4927,6 +5686,15 @@ PyInit__socket(void)
#ifdef SO_SETFIB
PyModule_AddIntConstant(m, "SO_SETFIB", SO_SETFIB);
#endif
+#ifdef SO_PASSCRED
+ PyModule_AddIntConstant(m, "SO_PASSCRED", SO_PASSCRED);
+#endif
+#ifdef SO_PEERCRED
+ PyModule_AddIntConstant(m, "SO_PEERCRED", SO_PEERCRED);
+#endif
+#ifdef LOCAL_PEERCRED
+ PyModule_AddIntConstant(m, "LOCAL_PEERCRED", LOCAL_PEERCRED);
+#endif
/* Maximum number of connections for "listen" */
#ifdef SOMAXCONN
@@ -4935,6 +5703,17 @@ PyInit__socket(void)
PyModule_AddIntConstant(m, "SOMAXCONN", 5); /* Common value */
#endif
+ /* Ancilliary message types */
+#ifdef SCM_RIGHTS
+ PyModule_AddIntConstant(m, "SCM_RIGHTS", SCM_RIGHTS);
+#endif
+#ifdef SCM_CREDENTIALS
+ PyModule_AddIntConstant(m, "SCM_CREDENTIALS", SCM_CREDENTIALS);
+#endif
+#ifdef SCM_CREDS
+ PyModule_AddIntConstant(m, "SCM_CREDS", SCM_CREDS);
+#endif
+
/* Flags for send, recv */
#ifdef MSG_OOB
PyModule_AddIntConstant(m, "MSG_OOB", MSG_OOB);
@@ -4966,6 +5745,33 @@ PyInit__socket(void)
#ifdef MSG_ETAG
PyModule_AddIntConstant(m, "MSG_ETAG", MSG_ETAG);
#endif
+#ifdef MSG_NOSIGNAL
+ PyModule_AddIntConstant(m, "MSG_NOSIGNAL", MSG_NOSIGNAL);
+#endif
+#ifdef MSG_NOTIFICATION
+ PyModule_AddIntConstant(m, "MSG_NOTIFICATION", MSG_NOTIFICATION);
+#endif
+#ifdef MSG_CMSG_CLOEXEC
+ PyModule_AddIntConstant(m, "MSG_CMSG_CLOEXEC", MSG_CMSG_CLOEXEC);
+#endif
+#ifdef MSG_ERRQUEUE
+ PyModule_AddIntConstant(m, "MSG_ERRQUEUE", MSG_ERRQUEUE);
+#endif
+#ifdef MSG_CONFIRM
+ PyModule_AddIntConstant(m, "MSG_CONFIRM", MSG_CONFIRM);
+#endif
+#ifdef MSG_MORE
+ PyModule_AddIntConstant(m, "MSG_MORE", MSG_MORE);
+#endif
+#ifdef MSG_EOF
+ PyModule_AddIntConstant(m, "MSG_EOF", MSG_EOF);
+#endif
+#ifdef MSG_BCAST
+ PyModule_AddIntConstant(m, "MSG_BCAST", MSG_BCAST);
+#endif
+#ifdef MSG_MCAST
+ PyModule_AddIntConstant(m, "MSG_MCAST", MSG_MCAST);
+#endif
/* Protocol level and numbers, usable for [gs]etsockopt */
#ifdef SOL_SOCKET
@@ -5105,6 +5911,9 @@ PyInit__socket(void)
#ifdef IPPROTO_VRRP
PyModule_AddIntConstant(m, "IPPROTO_VRRP", IPPROTO_VRRP);
#endif
+#ifdef IPPROTO_SCTP
+ PyModule_AddIntConstant(m, "IPPROTO_SCTP", IPPROTO_SCTP);
+#endif
#ifdef IPPROTO_BIP
PyModule_AddIntConstant(m, "IPPROTO_BIP", IPPROTO_BIP);
#endif