summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_socket.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_socket.py')
-rw-r--r--Lib/test/test_socket.py2799
1 files changed, 2738 insertions, 61 deletions
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index e40b21e..0ce9729 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -2,11 +2,14 @@
import unittest
from test import support
+from unittest.case import _ExpectedFailure
import errno
import io
import socket
import select
+import tempfile
+import _testcapi
import time
import traceback
import queue
@@ -18,34 +21,19 @@ import contextlib
from weakref import proxy
import signal
import math
+import pickle
+import struct
try:
import fcntl
except ImportError:
fcntl = False
-
-def try_address(host, port=0, family=socket.AF_INET):
- """Try to bind a socket on the given host:port and return True
- if that has been possible."""
- try:
- sock = socket.socket(family, socket.SOCK_STREAM)
- sock.bind((host, port))
- except (socket.error, socket.gaierror):
- return False
- else:
- sock.close()
- return True
-
-def linux_version():
- try:
- # platform.release() is something like '2.6.33.7-desktop-2mnb'
- version_string = platform.release().split('-')[0]
- return tuple(map(int, version_string.split('.')))
- except ValueError:
- return 0, 0, 0
+try:
+ import multiprocessing
+except ImportError:
+ multiprocessing = False
HOST = support.HOST
-MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf8') ## test unicode string and carriage return
-SUPPORTS_IPV6 = socket.has_ipv6 and try_address('::1', family=socket.AF_INET6)
+MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8') ## test unicode string and carriage return
try:
import _thread as thread
@@ -54,6 +42,33 @@ except ImportError:
thread = None
threading = None
+def _have_socket_can():
+ """Check whether CAN sockets are supported on this host."""
+ try:
+ s = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)
+ except (AttributeError, socket.error, OSError):
+ return False
+ else:
+ s.close()
+ return True
+
+def _have_socket_rds():
+ """Check whether RDS sockets are supported on this host."""
+ try:
+ s = socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0)
+ except (AttributeError, OSError):
+ return False
+ else:
+ s.close()
+ return True
+
+HAVE_SOCKET_CAN = _have_socket_can()
+
+HAVE_SOCKET_RDS = _have_socket_rds()
+
+# Size in bytes of the int type
+SIZEOF_INT = array.array("i").itemsize
+
class SocketTCPTest(unittest.TestCase):
def setUp(self):
@@ -75,6 +90,63 @@ class SocketUDPTest(unittest.TestCase):
self.serv.close()
self.serv = None
+class ThreadSafeCleanupTestCase(unittest.TestCase):
+ """Subclass of unittest.TestCase with thread-safe cleanup methods.
+
+ This subclass protects the addCleanup() and doCleanups() methods
+ with a recursive lock.
+ """
+
+ if threading:
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._cleanup_lock = threading.RLock()
+
+ def addCleanup(self, *args, **kwargs):
+ with self._cleanup_lock:
+ return super().addCleanup(*args, **kwargs)
+
+ def doCleanups(self, *args, **kwargs):
+ with self._cleanup_lock:
+ return super().doCleanups(*args, **kwargs)
+
+class SocketCANTest(unittest.TestCase):
+
+ """To be able to run this test, a `vcan0` CAN interface can be created with
+ the following commands:
+ # modprobe vcan
+ # ip link add dev vcan0 type vcan
+ # ifconfig vcan0 up
+ """
+ interface = 'vcan0'
+ bufsize = 128
+
+ def setUp(self):
+ self.s = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)
+ self.addCleanup(self.s.close)
+ try:
+ self.s.bind((self.interface,))
+ except socket.error:
+ self.skipTest('network interface `%s` does not exist' %
+ self.interface)
+
+
+class SocketRDSTest(unittest.TestCase):
+
+ """To be able to run this test, the `rds` kernel module must be loaded:
+ # modprobe rds
+ """
+ bufsize = 8192
+
+ def setUp(self):
+ self.serv = socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0)
+ self.addCleanup(self.serv.close)
+ try:
+ self.port = support.bind_port(self.serv)
+ except OSError:
+ self.skipTest('unable to bind RDS socket')
+
+
class ThreadableTest:
"""Threadable Test class
@@ -132,6 +204,7 @@ class ThreadableTest:
self.client_ready = threading.Event()
self.done = threading.Event()
self.queue = queue.Queue(1)
+ self.server_crashed = False
# Do some munging to start the client test.
methodname = self.id()
@@ -141,8 +214,12 @@ class ThreadableTest:
self.client_thread = thread.start_new_thread(
self.clientRun, (test_method,))
- self.__setUp()
- if not self.server_ready.is_set():
+ try:
+ self.__setUp()
+ except:
+ self.server_crashed = True
+ raise
+ finally:
self.server_ready.set()
self.client_ready.wait()
@@ -158,10 +235,16 @@ class ThreadableTest:
self.server_ready.wait()
self.clientSetUp()
self.client_ready.set()
+ if self.server_crashed:
+ self.clientTearDown()
+ return
if not hasattr(test_func, '__call__'):
raise TypeError("test_func must be a callable function")
try:
test_func()
+ except _ExpectedFailure:
+ # We deliberately ignore expected failures
+ pass
except BaseException as e:
self.queue.put(e)
finally:
@@ -202,6 +285,48 @@ class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
self.cli = None
ThreadableTest.clientTearDown(self)
+class ThreadedCANSocketTest(SocketCANTest, ThreadableTest):
+
+ def __init__(self, methodName='runTest'):
+ SocketCANTest.__init__(self, methodName=methodName)
+ ThreadableTest.__init__(self)
+
+ def clientSetUp(self):
+ self.cli = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)
+ try:
+ self.cli.bind((self.interface,))
+ except socket.error:
+ # skipTest should not be called here, and will be called in the
+ # server instead
+ pass
+
+ def clientTearDown(self):
+ self.cli.close()
+ self.cli = None
+ ThreadableTest.clientTearDown(self)
+
+class ThreadedRDSSocketTest(SocketRDSTest, ThreadableTest):
+
+ def __init__(self, methodName='runTest'):
+ SocketRDSTest.__init__(self, methodName=methodName)
+ ThreadableTest.__init__(self)
+
+ def clientSetUp(self):
+ self.cli = socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0)
+ try:
+ # RDS sockets must be bound explicitly to send or receive data
+ self.cli.bind((HOST, 0))
+ self.cli_addr = self.cli.getsockname()
+ except OSError:
+ # skipTest should not be called here, and will be called in the
+ # server instead
+ pass
+
+ def clientTearDown(self):
+ self.cli.close()
+ self.cli = None
+ ThreadableTest.clientTearDown(self)
+
class SocketConnectedTest(ThreadedTCPSocketTest):
"""Socket tests for client-server connection.
@@ -257,6 +382,243 @@ class SocketPairTest(unittest.TestCase, ThreadableTest):
ThreadableTest.clientTearDown(self)
+# The following classes are used by the sendmsg()/recvmsg() tests.
+# Combining, for instance, ConnectedStreamTestMixin and TCPTestBase
+# gives a drop-in replacement for SocketConnectedTest, but different
+# address families can be used, and the attributes serv_addr and
+# cli_addr will be set to the addresses of the endpoints.
+
+class SocketTestBase(unittest.TestCase):
+ """A base class for socket tests.
+
+ Subclasses must provide methods newSocket() to return a new socket
+ and bindSock(sock) to bind it to an unused address.
+
+ Creates a socket self.serv and sets self.serv_addr to its address.
+ """
+
+ def setUp(self):
+ self.serv = self.newSocket()
+ self.bindServer()
+
+ def bindServer(self):
+ """Bind server socket and set self.serv_addr to its address."""
+ self.bindSock(self.serv)
+ self.serv_addr = self.serv.getsockname()
+
+ def tearDown(self):
+ self.serv.close()
+ self.serv = None
+
+
+class SocketListeningTestMixin(SocketTestBase):
+ """Mixin to listen on the server socket."""
+
+ def setUp(self):
+ super().setUp()
+ self.serv.listen(1)
+
+
+class ThreadedSocketTestMixin(ThreadSafeCleanupTestCase, SocketTestBase,
+ ThreadableTest):
+ """Mixin to add client socket and allow client/server tests.
+
+ Client socket is self.cli and its address is self.cli_addr. See
+ ThreadableTest for usage information.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ ThreadableTest.__init__(self)
+
+ def clientSetUp(self):
+ self.cli = self.newClientSocket()
+ self.bindClient()
+
+ def newClientSocket(self):
+ """Return a new socket for use as client."""
+ return self.newSocket()
+
+ def bindClient(self):
+ """Bind client socket and set self.cli_addr to its address."""
+ self.bindSock(self.cli)
+ self.cli_addr = self.cli.getsockname()
+
+ def clientTearDown(self):
+ self.cli.close()
+ self.cli = None
+ ThreadableTest.clientTearDown(self)
+
+
+class ConnectedStreamTestMixin(SocketListeningTestMixin,
+ ThreadedSocketTestMixin):
+ """Mixin to allow client/server stream tests with connected client.
+
+ Server's socket representing connection to client is self.cli_conn
+ and client's connection to server is self.serv_conn. (Based on
+ SocketConnectedTest.)
+ """
+
+ def setUp(self):
+ super().setUp()
+ # Indicate explicitly we're ready for the client thread to
+ # proceed and then perform the blocking call to accept
+ self.serverExplicitReady()
+ conn, addr = self.serv.accept()
+ self.cli_conn = conn
+
+ def tearDown(self):
+ self.cli_conn.close()
+ self.cli_conn = None
+ super().tearDown()
+
+ def clientSetUp(self):
+ super().clientSetUp()
+ self.cli.connect(self.serv_addr)
+ self.serv_conn = self.cli
+
+ def clientTearDown(self):
+ self.serv_conn.close()
+ self.serv_conn = None
+ super().clientTearDown()
+
+
+class UnixSocketTestBase(SocketTestBase):
+ """Base class for Unix-domain socket tests."""
+
+ # This class is used for file descriptor passing tests, so we
+ # create the sockets in a private directory so that other users
+ # can't send anything that might be problematic for a privileged
+ # user running the tests.
+
+ def setUp(self):
+ self.dir_path = tempfile.mkdtemp()
+ self.addCleanup(os.rmdir, self.dir_path)
+ super().setUp()
+
+ def bindSock(self, sock):
+ path = tempfile.mktemp(dir=self.dir_path)
+ sock.bind(path)
+ self.addCleanup(support.unlink, path)
+
+class UnixStreamBase(UnixSocketTestBase):
+ """Base class for Unix-domain SOCK_STREAM tests."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+
+
+class InetTestBase(SocketTestBase):
+ """Base class for IPv4 socket tests."""
+
+ host = HOST
+
+ def setUp(self):
+ super().setUp()
+ self.port = self.serv_addr[1]
+
+ def bindSock(self, sock):
+ support.bind_port(sock, host=self.host)
+
+class TCPTestBase(InetTestBase):
+ """Base class for TCP-over-IPv4 tests."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+class UDPTestBase(InetTestBase):
+ """Base class for UDP-over-IPv4 tests."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+
+class SCTPStreamBase(InetTestBase):
+ """Base class for SCTP tests in one-to-one (SOCK_STREAM) mode."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_INET, socket.SOCK_STREAM,
+ socket.IPPROTO_SCTP)
+
+
+class Inet6TestBase(InetTestBase):
+ """Base class for IPv6 socket tests."""
+
+ # Don't use "localhost" here - it may not have an IPv6 address
+ # assigned to it by default (e.g. in /etc/hosts), and if someone
+ # has assigned it an IPv4-mapped address, then it's unlikely to
+ # work with the full IPv6 API.
+ host = "::1"
+
+class UDP6TestBase(Inet6TestBase):
+ """Base class for UDP-over-IPv6 tests."""
+
+ def newSocket(self):
+ return socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+
+
+# Test-skipping decorators for use with ThreadableTest.
+
+def skipWithClientIf(condition, reason):
+ """Skip decorated test if condition is true, add client_skip decorator.
+
+ If the decorated object is not a class, sets its attribute
+ "client_skip" to a decorator which will return an empty function
+ if the test is to be skipped, or the original function if it is
+ not. This can be used to avoid running the client part of a
+ skipped test when using ThreadableTest.
+ """
+ def client_pass(*args, **kwargs):
+ pass
+ def skipdec(obj):
+ retval = unittest.skip(reason)(obj)
+ if not isinstance(obj, type):
+ retval.client_skip = lambda f: client_pass
+ return retval
+ def noskipdec(obj):
+ if not (isinstance(obj, type) or hasattr(obj, "client_skip")):
+ obj.client_skip = lambda f: f
+ return obj
+ return skipdec if condition else noskipdec
+
+
+def requireAttrs(obj, *attributes):
+ """Skip decorated test if obj is missing any of the given attributes.
+
+ Sets client_skip attribute as skipWithClientIf() does.
+ """
+ missing = [name for name in attributes if not hasattr(obj, name)]
+ return skipWithClientIf(
+ missing, "don't have " + ", ".join(name for name in missing))
+
+
+def requireSocket(*args):
+ """Skip decorated test if a socket cannot be created with given arguments.
+
+ When an argument is given as a string, will use the value of that
+ attribute of the socket module, or skip the test if it doesn't
+ exist. Sets client_skip attribute as skipWithClientIf() does.
+ """
+ err = None
+ missing = [obj for obj in args if
+ isinstance(obj, str) and not hasattr(socket, obj)]
+ if missing:
+ err = "don't have " + ", ".join(name for name in missing)
+ else:
+ callargs = [getattr(socket, obj) if isinstance(obj, str) else obj
+ for obj in args]
+ try:
+ s = socket.socket(*callargs)
+ except socket.error as e:
+ # XXX: check errno?
+ err = str(e)
+ else:
+ s.close()
+ return skipWithClientIf(
+ err is not None,
+ "can't create socket({0}): {1}".format(
+ ", ".join(str(o) for o in args), err))
+
+
#######################################################################
## Begin Tests
@@ -282,18 +644,13 @@ class GeneralModuleTests(unittest.TestCase):
def testSocketError(self):
# Testing socket module exceptions
- def raise_error(*args, **kwargs):
+ msg = "Error raising socket exception (%s)."
+ with self.assertRaises(socket.error, msg=msg % 'socket.error'):
raise socket.error
- def raise_herror(*args, **kwargs):
+ with self.assertRaises(socket.error, msg=msg % 'socket.herror'):
raise socket.herror
- def raise_gaierror(*args, **kwargs):
+ with self.assertRaises(socket.error, msg=msg % 'socket.gaierror'):
raise socket.gaierror
- self.assertRaises(socket.error, raise_error,
- "Error raising socket exception.")
- self.assertRaises(socket.error, raise_herror,
- "Error raising socket exception.")
- self.assertRaises(socket.error, raise_gaierror,
- "Error raising socket exception.")
def testSendtoErrors(self):
# Testing that sendto doens't masks failures. See #10169.
@@ -369,6 +726,52 @@ class GeneralModuleTests(unittest.TestCase):
if not fqhn in all_host_names:
self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
+ @unittest.skipUnless(hasattr(socket, 'sethostname'), "test needs socket.sethostname()")
+ @unittest.skipUnless(hasattr(socket, 'gethostname'), "test needs socket.gethostname()")
+ def test_sethostname(self):
+ oldhn = socket.gethostname()
+ try:
+ socket.sethostname('new')
+ except socket.error as e:
+ if e.errno == errno.EPERM:
+ self.skipTest("test should be run as root")
+ else:
+ raise
+ try:
+ # running test as root!
+ self.assertEqual(socket.gethostname(), 'new')
+ # Should work with bytes objects too
+ socket.sethostname(b'bar')
+ self.assertEqual(socket.gethostname(), 'bar')
+ finally:
+ socket.sethostname(oldhn)
+
+ @unittest.skipUnless(hasattr(socket, 'if_nameindex'),
+ 'socket.if_nameindex() not available.')
+ def testInterfaceNameIndex(self):
+ interfaces = socket.if_nameindex()
+ for index, name in interfaces:
+ self.assertIsInstance(index, int)
+ self.assertIsInstance(name, str)
+ # interface indices are non-zero integers
+ self.assertGreater(index, 0)
+ _index = socket.if_nametoindex(name)
+ self.assertIsInstance(_index, int)
+ self.assertEqual(index, _index)
+ _name = socket.if_indextoname(index)
+ self.assertIsInstance(_name, str)
+ self.assertEqual(name, _name)
+
+ @unittest.skipUnless(hasattr(socket, 'if_nameindex'),
+ 'socket.if_nameindex() not available.')
+ def testInvalidInterfaceNameIndex(self):
+ # test nonexistent interface index/name
+ self.assertRaises(socket.error, socket.if_indextoname, 0)
+ self.assertRaises(socket.error, socket.if_nametoindex, '_DEADBEEF')
+ # test with invalid values
+ self.assertRaises(TypeError, socket.if_nametoindex, 0)
+ self.assertRaises(TypeError, socket.if_indextoname, '_DEADBEEF')
+
def testRefCountGetNameInfo(self):
# Testing reference count for getnameinfo
if hasattr(sys, "getrefcount"):
@@ -421,10 +824,8 @@ class GeneralModuleTests(unittest.TestCase):
# Find one service that exists, then check all the related interfaces.
# I've ordered this by protocols that have both a tcp and udp
# protocol, at least for modern Linuxes.
- if (sys.platform.startswith('linux') or
- sys.platform.startswith('freebsd') or
- sys.platform.startswith('netbsd') or
- sys.platform == 'darwin'):
+ if (sys.platform.startswith(('freebsd', 'netbsd'))
+ or sys.platform in ('linux', 'darwin')):
# avoid the 'echo' service on this platform, as there is an
# assumption breaking non-standard port/protocol entry
services = ('daytime', 'qotd', 'domain')
@@ -719,7 +1120,7 @@ class GeneralModuleTests(unittest.TestCase):
socket.getaddrinfo('localhost', 80)
socket.getaddrinfo('127.0.0.1', 80)
socket.getaddrinfo(None, 80)
- if SUPPORTS_IPV6:
+ if support.IPV6_ENABLED:
socket.getaddrinfo('::1', 80)
# port can be a string service name such as "http", a numeric
# port number or None
@@ -772,7 +1173,12 @@ class GeneralModuleTests(unittest.TestCase):
@unittest.skipUnless(support.is_resource_enabled('network'),
'network is not enabled')
def test_idna(self):
- support.requires('network')
+ # Check for internet access before running test (issue #12804).
+ try:
+ socket.gethostbyname('python.org')
+ except socket.gaierror as e:
+ if e.errno == socket.EAI_NODATA:
+ self.skipTest('internet access required for this test')
# these should all be successful
socket.gethostbyname('испытание.python.org')
socket.gethostbyname_ex('испытание.python.org')
@@ -850,14 +1256,20 @@ class GeneralModuleTests(unittest.TestCase):
self.assertRaises(ValueError, fp.writable)
self.assertRaises(ValueError, fp.seekable)
- def testListenBacklog0(self):
+ def test_pickle(self):
+ sock = socket.socket()
+ with sock:
+ for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
+ self.assertRaises(TypeError, pickle.dumps, sock, protocol)
+
+ def test_listen_backlog0(self):
srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
srv.bind((HOST, 0))
# backlog = 0
srv.listen(0)
srv.close()
- @unittest.skipUnless(SUPPORTS_IPV6, 'IPv6 required for this test.')
+ @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
def test_flowinfo(self):
self.assertRaises(OverflowError, socket.getnameinfo,
('::1',0, 0xffffffff), 0)
@@ -865,6 +1277,222 @@ class GeneralModuleTests(unittest.TestCase):
self.assertRaises(OverflowError, s.bind, ('::1', 0, -10))
+@unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.')
+class BasicCANTest(unittest.TestCase):
+
+ def testCrucialConstants(self):
+ socket.AF_CAN
+ socket.PF_CAN
+ socket.CAN_RAW
+
+ def testCreateSocket(self):
+ with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
+ pass
+
+ def testBindAny(self):
+ with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
+ s.bind(('', ))
+
+ def testTooLongInterfaceName(self):
+ # most systems limit IFNAMSIZ to 16, take 1024 to be sure
+ with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
+ self.assertRaisesRegex(socket.error, 'interface name too long',
+ s.bind, ('x' * 1024,))
+
+ @unittest.skipUnless(hasattr(socket, "CAN_RAW_LOOPBACK"),
+ 'socket.CAN_RAW_LOOPBACK required for this test.')
+ def testLoopback(self):
+ with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
+ for loopback in (0, 1):
+ s.setsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_LOOPBACK,
+ loopback)
+ self.assertEqual(loopback,
+ s.getsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_LOOPBACK))
+
+ @unittest.skipUnless(hasattr(socket, "CAN_RAW_FILTER"),
+ 'socket.CAN_RAW_FILTER required for this test.')
+ def testFilter(self):
+ can_id, can_mask = 0x200, 0x700
+ can_filter = struct.pack("=II", can_id, can_mask)
+ with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
+ s.setsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_FILTER, can_filter)
+ self.assertEqual(can_filter,
+ s.getsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_FILTER, 8))
+
+
+@unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.')
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class CANTest(ThreadedCANSocketTest):
+
+ """The CAN frame structure is defined in <linux/can.h>:
+
+ struct can_frame {
+ canid_t can_id; /* 32 bit CAN_ID + EFF/RTR/ERR flags */
+ __u8 can_dlc; /* data length code: 0 .. 8 */
+ __u8 data[8] __attribute__((aligned(8)));
+ };
+ """
+ can_frame_fmt = "=IB3x8s"
+
+ def __init__(self, methodName='runTest'):
+ ThreadedCANSocketTest.__init__(self, methodName=methodName)
+
+ @classmethod
+ def build_can_frame(cls, can_id, data):
+ """Build a CAN frame."""
+ can_dlc = len(data)
+ data = data.ljust(8, b'\x00')
+ return struct.pack(cls.can_frame_fmt, can_id, can_dlc, data)
+
+ @classmethod
+ def dissect_can_frame(cls, frame):
+ """Dissect a CAN frame."""
+ can_id, can_dlc, data = struct.unpack(cls.can_frame_fmt, frame)
+ return (can_id, can_dlc, data[:can_dlc])
+
+ def testSendFrame(self):
+ cf, addr = self.s.recvfrom(self.bufsize)
+ self.assertEqual(self.cf, cf)
+ self.assertEqual(addr[0], self.interface)
+ self.assertEqual(addr[1], socket.AF_CAN)
+
+ def _testSendFrame(self):
+ self.cf = self.build_can_frame(0x00, b'\x01\x02\x03\x04\x05')
+ self.cli.send(self.cf)
+
+ def testSendMaxFrame(self):
+ cf, addr = self.s.recvfrom(self.bufsize)
+ self.assertEqual(self.cf, cf)
+
+ def _testSendMaxFrame(self):
+ self.cf = self.build_can_frame(0x00, b'\x07' * 8)
+ self.cli.send(self.cf)
+
+ def testSendMultiFrames(self):
+ cf, addr = self.s.recvfrom(self.bufsize)
+ self.assertEqual(self.cf1, cf)
+
+ cf, addr = self.s.recvfrom(self.bufsize)
+ self.assertEqual(self.cf2, cf)
+
+ def _testSendMultiFrames(self):
+ self.cf1 = self.build_can_frame(0x07, b'\x44\x33\x22\x11')
+ self.cli.send(self.cf1)
+
+ self.cf2 = self.build_can_frame(0x12, b'\x99\x22\x33')
+ self.cli.send(self.cf2)
+
+
+@unittest.skipUnless(HAVE_SOCKET_RDS, 'RDS sockets required for this test.')
+class BasicRDSTest(unittest.TestCase):
+
+ def testCrucialConstants(self):
+ socket.AF_RDS
+ socket.PF_RDS
+
+ def testCreateSocket(self):
+ with socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0) as s:
+ pass
+
+ def testSocketBufferSize(self):
+ bufsize = 16384
+ with socket.socket(socket.PF_RDS, socket.SOCK_SEQPACKET, 0) as s:
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, bufsize)
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, bufsize)
+
+
+@unittest.skipUnless(HAVE_SOCKET_RDS, 'RDS sockets required for this test.')
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RDSTest(ThreadedRDSSocketTest):
+
+ def __init__(self, methodName='runTest'):
+ ThreadedRDSSocketTest.__init__(self, methodName=methodName)
+
+ def setUp(self):
+ super().setUp()
+ self.evt = threading.Event()
+
+ def testSendAndRecv(self):
+ data, addr = self.serv.recvfrom(self.bufsize)
+ self.assertEqual(self.data, data)
+ self.assertEqual(self.cli_addr, addr)
+
+ def _testSendAndRecv(self):
+ self.data = b'spam'
+ self.cli.sendto(self.data, 0, (HOST, self.port))
+
+ def testPeek(self):
+ data, addr = self.serv.recvfrom(self.bufsize, socket.MSG_PEEK)
+ self.assertEqual(self.data, data)
+ data, addr = self.serv.recvfrom(self.bufsize)
+ self.assertEqual(self.data, data)
+
+ def _testPeek(self):
+ self.data = b'spam'
+ self.cli.sendto(self.data, 0, (HOST, self.port))
+
+ @requireAttrs(socket.socket, 'recvmsg')
+ def testSendAndRecvMsg(self):
+ data, ancdata, msg_flags, addr = self.serv.recvmsg(self.bufsize)
+ self.assertEqual(self.data, data)
+
+ @requireAttrs(socket.socket, 'sendmsg')
+ def _testSendAndRecvMsg(self):
+ self.data = b'hello ' * 10
+ self.cli.sendmsg([self.data], (), 0, (HOST, self.port))
+
+ def testSendAndRecvMulti(self):
+ data, addr = self.serv.recvfrom(self.bufsize)
+ self.assertEqual(self.data1, data)
+
+ data, addr = self.serv.recvfrom(self.bufsize)
+ self.assertEqual(self.data2, data)
+
+ def _testSendAndRecvMulti(self):
+ self.data1 = b'bacon'
+ self.cli.sendto(self.data1, 0, (HOST, self.port))
+
+ self.data2 = b'egg'
+ self.cli.sendto(self.data2, 0, (HOST, self.port))
+
+ def testSelect(self):
+ r, w, x = select.select([self.serv], [], [], 3.0)
+ self.assertIn(self.serv, r)
+ data, addr = self.serv.recvfrom(self.bufsize)
+ self.assertEqual(self.data, data)
+
+ def _testSelect(self):
+ self.data = b'select'
+ self.cli.sendto(self.data, 0, (HOST, self.port))
+
+ def testCongestion(self):
+ # wait until the sender is done
+ self.evt.wait()
+
+ def _testCongestion(self):
+ # test the behavior in case of congestion
+ self.data = b'fill'
+ self.cli.setblocking(False)
+ try:
+ # try to lower the receiver's socket buffer size
+ self.cli.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 16384)
+ except OSError:
+ pass
+ with self.assertRaises(OSError) as cm:
+ try:
+ # fill the receiver's socket buffer
+ while True:
+ self.cli.sendto(self.data, 0, (HOST, self.port))
+ finally:
+ # signal the receiver we're done
+ self.evt.set()
+ # sendto() should have failed with ENOBUFS
+ self.assertEqual(cm.exception.errno, errno.ENOBUFS)
+ # and we should have received a congestion notification through poll
+ r, w, x = select.select([self.serv], [], [], 3.0)
+ self.assertIn(self.serv, r)
+
+
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicTCPTest(SocketConnectedTest):
@@ -1004,6 +1632,1859 @@ class BasicUDPTest(ThreadedUDPSocketTest):
def _testRecvFromNegative(self):
self.cli.sendto(MSG, 0, (HOST, self.port))
+# Tests for the sendmsg()/recvmsg() interface. Where possible, the
+# same test code is used with different families and types of socket
+# (e.g. stream, datagram), and tests using recvmsg() are repeated
+# using recvmsg_into().
+#
+# The generic test classes such as SendmsgTests and
+# RecvmsgGenericTests inherit from SendrecvmsgBase and expect to be
+# supplied with sockets cli_sock and serv_sock representing the
+# client's and the server's end of the connection respectively, and
+# attributes cli_addr and serv_addr holding their (numeric where
+# appropriate) addresses.
+#
+# The final concrete test classes combine these with subclasses of
+# SocketTestBase which set up client and server sockets of a specific
+# type, and with subclasses of SendrecvmsgBase such as
+# SendrecvmsgDgramBase and SendrecvmsgConnectedBase which map these
+# sockets to cli_sock and serv_sock and override the methods and
+# attributes of SendrecvmsgBase to fill in destination addresses if
+# needed when sending, check for specific flags in msg_flags, etc.
+#
+# RecvmsgIntoMixin provides a version of doRecvmsg() implemented using
+# recvmsg_into().
+
+# XXX: like the other datagram (UDP) tests in this module, the code
+# here assumes that datagram delivery on the local machine will be
+# reliable.
+
+class SendrecvmsgBase(ThreadSafeCleanupTestCase):
+ # Base class for sendmsg()/recvmsg() tests.
+
+ # Time in seconds to wait before considering a test failed, or
+ # None for no timeout. Not all tests actually set a timeout.
+ fail_timeout = 3.0
+
+ def setUp(self):
+ self.misc_event = threading.Event()
+ super().setUp()
+
+ def sendToServer(self, msg):
+ # Send msg to the server.
+ return self.cli_sock.send(msg)
+
+ # Tuple of alternative default arguments for sendmsg() when called
+ # via sendmsgToServer() (e.g. to include a destination address).
+ sendmsg_to_server_defaults = ()
+
+ def sendmsgToServer(self, *args):
+ # Call sendmsg() on self.cli_sock with the given arguments,
+ # filling in any arguments which are not supplied with the
+ # corresponding items of self.sendmsg_to_server_defaults, if
+ # any.
+ return self.cli_sock.sendmsg(
+ *(args + self.sendmsg_to_server_defaults[len(args):]))
+
+ def doRecvmsg(self, sock, bufsize, *args):
+ # Call recvmsg() on sock with given arguments and return its
+ # result. Should be used for tests which can use either
+ # recvmsg() or recvmsg_into() - RecvmsgIntoMixin overrides
+ # this method with one which emulates it using recvmsg_into(),
+ # thus allowing the same test to be used for both methods.
+ result = sock.recvmsg(bufsize, *args)
+ self.registerRecvmsgResult(result)
+ return result
+
+ def registerRecvmsgResult(self, result):
+ # Called by doRecvmsg() with the return value of recvmsg() or
+ # recvmsg_into(). Can be overridden to arrange cleanup based
+ # on the returned ancillary data, for instance.
+ pass
+
+ def checkRecvmsgAddress(self, addr1, addr2):
+ # Called to compare the received address with the address of
+ # the peer.
+ self.assertEqual(addr1, addr2)
+
+ # Flags that are normally unset in msg_flags
+ msg_flags_common_unset = 0
+ for name in ("MSG_CTRUNC", "MSG_OOB"):
+ msg_flags_common_unset |= getattr(socket, name, 0)
+
+ # Flags that are normally set
+ msg_flags_common_set = 0
+
+ # Flags set when a complete record has been received (e.g. MSG_EOR
+ # for SCTP)
+ msg_flags_eor_indicator = 0
+
+ # Flags set when a complete record has not been received
+ # (e.g. MSG_TRUNC for datagram sockets)
+ msg_flags_non_eor_indicator = 0
+
+ def checkFlags(self, flags, eor=None, checkset=0, checkunset=0, ignore=0):
+ # Method to check the value of msg_flags returned by recvmsg[_into]().
+ #
+ # Checks that all bits in msg_flags_common_set attribute are
+ # set in "flags" and all bits in msg_flags_common_unset are
+ # unset.
+ #
+ # The "eor" argument specifies whether the flags should
+ # indicate that a full record (or datagram) has been received.
+ # If "eor" is None, no checks are done; otherwise, checks
+ # that:
+ #
+ # * if "eor" is true, all bits in msg_flags_eor_indicator are
+ # set and all bits in msg_flags_non_eor_indicator are unset
+ #
+ # * if "eor" is false, all bits in msg_flags_non_eor_indicator
+ # are set and all bits in msg_flags_eor_indicator are unset
+ #
+ # If "checkset" and/or "checkunset" are supplied, they require
+ # the given bits to be set or unset respectively, overriding
+ # what the attributes require for those bits.
+ #
+ # If any bits are set in "ignore", they will not be checked,
+ # regardless of the other inputs.
+ #
+ # Will raise Exception if the inputs require a bit to be both
+ # set and unset, and it is not ignored.
+
+ defaultset = self.msg_flags_common_set
+ defaultunset = self.msg_flags_common_unset
+
+ if eor:
+ defaultset |= self.msg_flags_eor_indicator
+ defaultunset |= self.msg_flags_non_eor_indicator
+ elif eor is not None:
+ defaultset |= self.msg_flags_non_eor_indicator
+ defaultunset |= self.msg_flags_eor_indicator
+
+ # Function arguments override defaults
+ defaultset &= ~checkunset
+ defaultunset &= ~checkset
+
+ # Merge arguments with remaining defaults, and check for conflicts
+ checkset |= defaultset
+ checkunset |= defaultunset
+ inboth = checkset & checkunset & ~ignore
+ if inboth:
+ raise Exception("contradictory set, unset requirements for flags "
+ "{0:#x}".format(inboth))
+
+ # Compare with given msg_flags value
+ mask = (checkset | checkunset) & ~ignore
+ self.assertEqual(flags & mask, checkset & mask)
+
+
+class RecvmsgIntoMixin(SendrecvmsgBase):
+ # Mixin to implement doRecvmsg() using recvmsg_into().
+
+ def doRecvmsg(self, sock, bufsize, *args):
+ buf = bytearray(bufsize)
+ result = sock.recvmsg_into([buf], *args)
+ self.registerRecvmsgResult(result)
+ self.assertGreaterEqual(result[0], 0)
+ self.assertLessEqual(result[0], bufsize)
+ return (bytes(buf[:result[0]]),) + result[1:]
+
+
+class SendrecvmsgDgramFlagsBase(SendrecvmsgBase):
+ # Defines flags to be checked in msg_flags for datagram sockets.
+
+ @property
+ def msg_flags_non_eor_indicator(self):
+ return super().msg_flags_non_eor_indicator | socket.MSG_TRUNC
+
+
+class SendrecvmsgSCTPFlagsBase(SendrecvmsgBase):
+ # Defines flags to be checked in msg_flags for SCTP sockets.
+
+ @property
+ def msg_flags_eor_indicator(self):
+ return super().msg_flags_eor_indicator | socket.MSG_EOR
+
+
+class SendrecvmsgConnectionlessBase(SendrecvmsgBase):
+ # Base class for tests on connectionless-mode sockets. Users must
+ # supply sockets on attributes cli and serv to be mapped to
+ # cli_sock and serv_sock respectively.
+
+ @property
+ def serv_sock(self):
+ return self.serv
+
+ @property
+ def cli_sock(self):
+ return self.cli
+
+ @property
+ def sendmsg_to_server_defaults(self):
+ return ([], [], 0, self.serv_addr)
+
+ def sendToServer(self, msg):
+ return self.cli_sock.sendto(msg, self.serv_addr)
+
+
+class SendrecvmsgConnectedBase(SendrecvmsgBase):
+ # Base class for tests on connected sockets. Users must supply
+ # sockets on attributes serv_conn and cli_conn (representing the
+ # connections *to* the server and the client), to be mapped to
+ # cli_sock and serv_sock respectively.
+
+ @property
+ def serv_sock(self):
+ return self.cli_conn
+
+ @property
+ def cli_sock(self):
+ return self.serv_conn
+
+ def checkRecvmsgAddress(self, addr1, addr2):
+ # Address is currently "unspecified" for a connected socket,
+ # so we don't examine it
+ pass
+
+
+class SendrecvmsgServerTimeoutBase(SendrecvmsgBase):
+ # Base class to set a timeout on server's socket.
+
+ def setUp(self):
+ super().setUp()
+ self.serv_sock.settimeout(self.fail_timeout)
+
+
+class SendmsgTests(SendrecvmsgServerTimeoutBase):
+ # Tests for sendmsg() which can use any socket type and do not
+ # involve recvmsg() or recvmsg_into().
+
+ def testSendmsg(self):
+ # Send a simple message with sendmsg().
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsg(self):
+ self.assertEqual(self.sendmsgToServer([MSG]), len(MSG))
+
+ def testSendmsgDataGenerator(self):
+ # Send from buffer obtained from a generator (not a sequence).
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsgDataGenerator(self):
+ self.assertEqual(self.sendmsgToServer((o for o in [MSG])),
+ len(MSG))
+
+ def testSendmsgAncillaryGenerator(self):
+ # Gather (empty) ancillary data from a generator.
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsgAncillaryGenerator(self):
+ self.assertEqual(self.sendmsgToServer([MSG], (o for o in [])),
+ len(MSG))
+
+ def testSendmsgArray(self):
+ # Send data from an array instead of the usual bytes object.
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsgArray(self):
+ self.assertEqual(self.sendmsgToServer([array.array("B", MSG)]),
+ len(MSG))
+
+ def testSendmsgGather(self):
+ # Send message data from more than one buffer (gather write).
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsgGather(self):
+ self.assertEqual(self.sendmsgToServer([MSG[:3], MSG[3:]]), len(MSG))
+
+ def testSendmsgBadArgs(self):
+ # Check that sendmsg() rejects invalid arguments.
+ self.assertEqual(self.serv_sock.recv(1000), b"done")
+
+ def _testSendmsgBadArgs(self):
+ self.assertRaises(TypeError, self.cli_sock.sendmsg)
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ b"not in an iterable")
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ object())
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [object()])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG, object()])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], object())
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [], object())
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [], 0, object())
+ self.sendToServer(b"done")
+
+ def testSendmsgBadCmsg(self):
+ # Check that invalid ancillary data items are rejected.
+ self.assertEqual(self.serv_sock.recv(1000), b"done")
+
+ def _testSendmsgBadCmsg(self):
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [object()])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(object(), 0, b"data")])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(0, object(), b"data")])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(0, 0, object())])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(0, 0)])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(0, 0, b"data", 42)])
+ self.sendToServer(b"done")
+
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testSendmsgBadMultiCmsg(self):
+ # Check that invalid ancillary data items are rejected when
+ # more than one item is present.
+ self.assertEqual(self.serv_sock.recv(1000), b"done")
+
+ @testSendmsgBadMultiCmsg.client_skip
+ def _testSendmsgBadMultiCmsg(self):
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [0, 0, b""])
+ self.assertRaises(TypeError, self.sendmsgToServer,
+ [MSG], [(0, 0, b""), object()])
+ self.sendToServer(b"done")
+
+ def testSendmsgExcessCmsgReject(self):
+ # Check that sendmsg() rejects excess ancillary data items
+ # when the number that can be sent is limited.
+ self.assertEqual(self.serv_sock.recv(1000), b"done")
+
+ def _testSendmsgExcessCmsgReject(self):
+ if not hasattr(socket, "CMSG_SPACE"):
+ # Can only send one item
+ with self.assertRaises(socket.error) as cm:
+ self.sendmsgToServer([MSG], [(0, 0, b""), (0, 0, b"")])
+ self.assertIsNone(cm.exception.errno)
+ self.sendToServer(b"done")
+
+ def testSendmsgAfterClose(self):
+ # Check that sendmsg() fails on a closed socket.
+ pass
+
+ def _testSendmsgAfterClose(self):
+ self.cli_sock.close()
+ self.assertRaises(socket.error, self.sendmsgToServer, [MSG])
+
+
+class SendmsgStreamTests(SendmsgTests):
+ # Tests for sendmsg() which require a stream socket and do not
+ # involve recvmsg() or recvmsg_into().
+
+ def testSendmsgExplicitNoneAddr(self):
+ # Check that peer address can be specified as None.
+ self.assertEqual(self.serv_sock.recv(len(MSG)), MSG)
+
+ def _testSendmsgExplicitNoneAddr(self):
+ self.assertEqual(self.sendmsgToServer([MSG], [], 0, None), len(MSG))
+
+ def testSendmsgTimeout(self):
+ # Check that timeout works with sendmsg().
+ self.assertEqual(self.serv_sock.recv(512), b"a"*512)
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+
+ def _testSendmsgTimeout(self):
+ try:
+ self.cli_sock.settimeout(0.03)
+ with self.assertRaises(socket.timeout):
+ while True:
+ self.sendmsgToServer([b"a"*512])
+ finally:
+ self.misc_event.set()
+
+ # XXX: would be nice to have more tests for sendmsg flags argument.
+
+ # Linux supports MSG_DONTWAIT when sending, but in general, it
+ # only works when receiving. Could add other platforms if they
+ # support it too.
+ @skipWithClientIf(sys.platform not in {"linux2"},
+ "MSG_DONTWAIT not known to work on this platform when "
+ "sending")
+ def testSendmsgDontWait(self):
+ # Check that MSG_DONTWAIT in flags causes non-blocking behaviour.
+ self.assertEqual(self.serv_sock.recv(512), b"a"*512)
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+
+ @testSendmsgDontWait.client_skip
+ def _testSendmsgDontWait(self):
+ try:
+ with self.assertRaises(socket.error) as cm:
+ while True:
+ self.sendmsgToServer([b"a"*512], [], socket.MSG_DONTWAIT)
+ self.assertIn(cm.exception.errno,
+ (errno.EAGAIN, errno.EWOULDBLOCK))
+ finally:
+ self.misc_event.set()
+
+
+class SendmsgConnectionlessTests(SendmsgTests):
+ # Tests for sendmsg() which require a connectionless-mode
+ # (e.g. datagram) socket, and do not involve recvmsg() or
+ # recvmsg_into().
+
+ def testSendmsgNoDestAddr(self):
+ # Check that sendmsg() fails when no destination address is
+ # given for unconnected socket.
+ pass
+
+ def _testSendmsgNoDestAddr(self):
+ self.assertRaises(socket.error, self.cli_sock.sendmsg,
+ [MSG])
+ self.assertRaises(socket.error, self.cli_sock.sendmsg,
+ [MSG], [], 0, None)
+
+
+class RecvmsgGenericTests(SendrecvmsgBase):
+ # Tests for recvmsg() which can also be emulated using
+ # recvmsg_into(), and can use any socket type.
+
+ def testRecvmsg(self):
+ # Receive a simple message with recvmsg[_into]().
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG))
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsg(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgExplicitDefaults(self):
+ # Test recvmsg[_into]() with default arguments provided explicitly.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 0, 0)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgExplicitDefaults(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgShorter(self):
+ # Receive a message smaller than buffer.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG) + 42)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgShorter(self):
+ self.sendToServer(MSG)
+
+ # FreeBSD < 8 doesn't always set the MSG_TRUNC flag when a truncated
+ # datagram is received (issue #13001).
+ @support.requires_freebsd_version(8)
+ def testRecvmsgTrunc(self):
+ # Receive part of message, check for truncation indicators.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG) - 3)
+ self.assertEqual(msg, MSG[:-3])
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=False)
+
+ @support.requires_freebsd_version(8)
+ def _testRecvmsgTrunc(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgShortAncillaryBuf(self):
+ # Test ancillary data buffer too small to hold any ancillary data.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 1)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgShortAncillaryBuf(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgLongAncillaryBuf(self):
+ # Test large ancillary data buffer.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgLongAncillaryBuf(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgAfterClose(self):
+ # Check that recvmsg[_into]() fails on a closed socket.
+ self.serv_sock.close()
+ self.assertRaises(socket.error, self.doRecvmsg, self.serv_sock, 1024)
+
+ def _testRecvmsgAfterClose(self):
+ pass
+
+ def testRecvmsgTimeout(self):
+ # Check that timeout works.
+ try:
+ self.serv_sock.settimeout(0.03)
+ self.assertRaises(socket.timeout,
+ self.doRecvmsg, self.serv_sock, len(MSG))
+ finally:
+ self.misc_event.set()
+
+ def _testRecvmsgTimeout(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+
+ @requireAttrs(socket, "MSG_PEEK")
+ def testRecvmsgPeek(self):
+ # Check that MSG_PEEK in flags enables examination of pending
+ # data without consuming it.
+
+ # Receive part of data with MSG_PEEK.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG) - 3, 0,
+ socket.MSG_PEEK)
+ self.assertEqual(msg, MSG[:-3])
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ # Ignoring MSG_TRUNC here (so this test is the same for stream
+ # and datagram sockets). Some wording in POSIX seems to
+ # suggest that it needn't be set when peeking, but that may
+ # just be a slip.
+ self.checkFlags(flags, eor=False,
+ ignore=getattr(socket, "MSG_TRUNC", 0))
+
+ # Receive all data with MSG_PEEK.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 0,
+ socket.MSG_PEEK)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ # Check that the same data can still be received normally.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG))
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ @testRecvmsgPeek.client_skip
+ def _testRecvmsgPeek(self):
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket.socket, "sendmsg")
+ def testRecvmsgFromSendmsg(self):
+ # Test receiving with recvmsg[_into]() when message is sent
+ # using sendmsg().
+ self.serv_sock.settimeout(self.fail_timeout)
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, len(MSG))
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ @testRecvmsgFromSendmsg.client_skip
+ def _testRecvmsgFromSendmsg(self):
+ self.assertEqual(self.sendmsgToServer([MSG[:3], MSG[3:]]), len(MSG))
+
+
+class RecvmsgGenericStreamTests(RecvmsgGenericTests):
+ # Tests which require a stream socket and can use either recvmsg()
+ # or recvmsg_into().
+
+ def testRecvmsgEOF(self):
+ # Receive end-of-stream indicator (b"", peer socket closed).
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, 1024)
+ self.assertEqual(msg, b"")
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=None) # Might not have end-of-record marker
+
+ def _testRecvmsgEOF(self):
+ self.cli_sock.close()
+
+ def testRecvmsgOverflow(self):
+ # Receive a message in more than one chunk.
+ seg1, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG) - 3)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=False)
+
+ seg2, ancdata, flags, addr = self.doRecvmsg(self.serv_sock, 1024)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ msg = seg1 + seg2
+ self.assertEqual(msg, MSG)
+
+ def _testRecvmsgOverflow(self):
+ self.sendToServer(MSG)
+
+
+class RecvmsgTests(RecvmsgGenericTests):
+ # Tests for recvmsg() which can use any socket type.
+
+ def testRecvmsgBadArgs(self):
+ # Check that recvmsg() rejects invalid arguments.
+ self.assertRaises(TypeError, self.serv_sock.recvmsg)
+ self.assertRaises(ValueError, self.serv_sock.recvmsg,
+ -1, 0, 0)
+ self.assertRaises(ValueError, self.serv_sock.recvmsg,
+ len(MSG), -1, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg,
+ [bytearray(10)], 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg,
+ object(), 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg,
+ len(MSG), object(), 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg,
+ len(MSG), 0, object())
+
+ msg, ancdata, flags, addr = self.serv_sock.recvmsg(len(MSG), 0, 0)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgBadArgs(self):
+ self.sendToServer(MSG)
+
+
+class RecvmsgIntoTests(RecvmsgIntoMixin, RecvmsgGenericTests):
+ # Tests for recvmsg_into() which can use any socket type.
+
+ def testRecvmsgIntoBadArgs(self):
+ # Check that recvmsg_into() rejects invalid arguments.
+ buf = bytearray(len(MSG))
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ len(MSG), 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ buf, 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ [object()], 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ [b"I'm not writable"], 0, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ [buf, object()], 0, 0)
+ self.assertRaises(ValueError, self.serv_sock.recvmsg_into,
+ [buf], -1, 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ [buf], object(), 0)
+ self.assertRaises(TypeError, self.serv_sock.recvmsg_into,
+ [buf], 0, object())
+
+ nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into([buf], 0, 0)
+ self.assertEqual(nbytes, len(MSG))
+ self.assertEqual(buf, bytearray(MSG))
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgIntoBadArgs(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgIntoGenerator(self):
+ # Receive into buffer obtained from a generator (not a sequence).
+ buf = bytearray(len(MSG))
+ nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into(
+ (o for o in [buf]))
+ self.assertEqual(nbytes, len(MSG))
+ self.assertEqual(buf, bytearray(MSG))
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgIntoGenerator(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgIntoArray(self):
+ # Receive into an array rather than the usual bytearray.
+ buf = array.array("B", [0] * len(MSG))
+ nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into([buf])
+ self.assertEqual(nbytes, len(MSG))
+ self.assertEqual(buf.tobytes(), MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgIntoArray(self):
+ self.sendToServer(MSG)
+
+ def testRecvmsgIntoScatter(self):
+ # Receive into multiple buffers (scatter write).
+ b1 = bytearray(b"----")
+ b2 = bytearray(b"0123456789")
+ b3 = bytearray(b"--------------")
+ nbytes, ancdata, flags, addr = self.serv_sock.recvmsg_into(
+ [b1, memoryview(b2)[2:9], b3])
+ self.assertEqual(nbytes, len(b"Mary had a little lamb"))
+ self.assertEqual(b1, bytearray(b"Mary"))
+ self.assertEqual(b2, bytearray(b"01 had a 9"))
+ self.assertEqual(b3, bytearray(b"little lamb---"))
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True)
+
+ def _testRecvmsgIntoScatter(self):
+ self.sendToServer(b"Mary had a little lamb")
+
+
+class CmsgMacroTests(unittest.TestCase):
+ # Test the functions CMSG_LEN() and CMSG_SPACE(). Tests
+ # assumptions used by sendmsg() and recvmsg[_into](), which share
+ # code with these functions.
+
+ # Match the definition in socketmodule.c
+ socklen_t_limit = min(0x7fffffff, _testcapi.INT_MAX)
+
+ @requireAttrs(socket, "CMSG_LEN")
+ def testCMSG_LEN(self):
+ # Test CMSG_LEN() with various valid and invalid values,
+ # checking the assumptions used by recvmsg() and sendmsg().
+ toobig = self.socklen_t_limit - socket.CMSG_LEN(0) + 1
+ values = list(range(257)) + list(range(toobig - 257, toobig))
+
+ # struct cmsghdr has at least three members, two of which are ints
+ self.assertGreater(socket.CMSG_LEN(0), array.array("i").itemsize * 2)
+ for n in values:
+ ret = socket.CMSG_LEN(n)
+ # This is how recvmsg() calculates the data size
+ self.assertEqual(ret - socket.CMSG_LEN(0), n)
+ self.assertLessEqual(ret, self.socklen_t_limit)
+
+ self.assertRaises(OverflowError, socket.CMSG_LEN, -1)
+ # sendmsg() shares code with these functions, and requires
+ # that it reject values over the limit.
+ self.assertRaises(OverflowError, socket.CMSG_LEN, toobig)
+ self.assertRaises(OverflowError, socket.CMSG_LEN, sys.maxsize)
+
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testCMSG_SPACE(self):
+ # Test CMSG_SPACE() with various valid and invalid values,
+ # checking the assumptions used by sendmsg().
+ toobig = self.socklen_t_limit - socket.CMSG_SPACE(1) + 1
+ values = list(range(257)) + list(range(toobig - 257, toobig))
+
+ last = socket.CMSG_SPACE(0)
+ # struct cmsghdr has at least three members, two of which are ints
+ self.assertGreater(last, array.array("i").itemsize * 2)
+ for n in values:
+ ret = socket.CMSG_SPACE(n)
+ self.assertGreaterEqual(ret, last)
+ self.assertGreaterEqual(ret, socket.CMSG_LEN(n))
+ self.assertGreaterEqual(ret, n + socket.CMSG_LEN(0))
+ self.assertLessEqual(ret, self.socklen_t_limit)
+ last = ret
+
+ self.assertRaises(OverflowError, socket.CMSG_SPACE, -1)
+ # sendmsg() shares code with these functions, and requires
+ # that it reject values over the limit.
+ self.assertRaises(OverflowError, socket.CMSG_SPACE, toobig)
+ self.assertRaises(OverflowError, socket.CMSG_SPACE, sys.maxsize)
+
+
+class SCMRightsTest(SendrecvmsgServerTimeoutBase):
+ # Tests for file descriptor passing on Unix-domain sockets.
+
+ # Invalid file descriptor value that's unlikely to evaluate to a
+ # real FD even if one of its bytes is replaced with a different
+ # value (which shouldn't actually happen).
+ badfd = -0x5555
+
+ def newFDs(self, n):
+ # Return a list of n file descriptors for newly-created files
+ # containing their list indices as ASCII numbers.
+ fds = []
+ for i in range(n):
+ fd, path = tempfile.mkstemp()
+ self.addCleanup(os.unlink, path)
+ self.addCleanup(os.close, fd)
+ os.write(fd, str(i).encode())
+ fds.append(fd)
+ return fds
+
+ def checkFDs(self, fds):
+ # Check that the file descriptors in the given list contain
+ # their correct list indices as ASCII numbers.
+ for n, fd in enumerate(fds):
+ os.lseek(fd, 0, os.SEEK_SET)
+ self.assertEqual(os.read(fd, 1024), str(n).encode())
+
+ def registerRecvmsgResult(self, result):
+ self.addCleanup(self.closeRecvmsgFDs, result)
+
+ def closeRecvmsgFDs(self, recvmsg_result):
+ # Close all file descriptors specified in the ancillary data
+ # of the given return value from recvmsg() or recvmsg_into().
+ for cmsg_level, cmsg_type, cmsg_data in recvmsg_result[1]:
+ if (cmsg_level == socket.SOL_SOCKET and
+ cmsg_type == socket.SCM_RIGHTS):
+ fds = array.array("i")
+ fds.frombytes(cmsg_data[:
+ len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
+ for fd in fds:
+ os.close(fd)
+
+ def createAndSendFDs(self, n):
+ # Send n new file descriptors created by newFDs() to the
+ # server, with the constant MSG as the non-ancillary data.
+ self.assertEqual(
+ self.sendmsgToServer([MSG],
+ [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", self.newFDs(n)))]),
+ len(MSG))
+
+ def checkRecvmsgFDs(self, numfds, result, maxcmsgs=1, ignoreflags=0):
+ # Check that constant MSG was received with numfds file
+ # descriptors in a maximum of maxcmsgs control messages (which
+ # must contain only complete integers). By default, check
+ # that MSG_CTRUNC is unset, but ignore any flags in
+ # ignoreflags.
+ msg, ancdata, flags, addr = result
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+
+ self.assertIsInstance(ancdata, list)
+ self.assertLessEqual(len(ancdata), maxcmsgs)
+ fds = array.array("i")
+ for item in ancdata:
+ self.assertIsInstance(item, tuple)
+ cmsg_level, cmsg_type, cmsg_data = item
+ self.assertEqual(cmsg_level, socket.SOL_SOCKET)
+ self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
+ self.assertIsInstance(cmsg_data, bytes)
+ self.assertEqual(len(cmsg_data) % SIZEOF_INT, 0)
+ fds.frombytes(cmsg_data)
+
+ self.assertEqual(len(fds), numfds)
+ self.checkFDs(fds)
+
+ def testFDPassSimple(self):
+ # Pass a single FD (array read from bytes object).
+ self.checkRecvmsgFDs(1, self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240))
+
+ def _testFDPassSimple(self):
+ self.assertEqual(
+ self.sendmsgToServer(
+ [MSG],
+ [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", self.newFDs(1)).tobytes())]),
+ len(MSG))
+
+ def testMultipleFDPass(self):
+ # Pass multiple FDs in a single array.
+ self.checkRecvmsgFDs(4, self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240))
+
+ def _testMultipleFDPass(self):
+ self.createAndSendFDs(4)
+
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testFDPassCMSG_SPACE(self):
+ # Test using CMSG_SPACE() to calculate ancillary buffer size.
+ self.checkRecvmsgFDs(
+ 4, self.doRecvmsg(self.serv_sock, len(MSG),
+ socket.CMSG_SPACE(4 * SIZEOF_INT)))
+
+ @testFDPassCMSG_SPACE.client_skip
+ def _testFDPassCMSG_SPACE(self):
+ self.createAndSendFDs(4)
+
+ def testFDPassCMSG_LEN(self):
+ # Test using CMSG_LEN() to calculate ancillary buffer size.
+ self.checkRecvmsgFDs(1,
+ self.doRecvmsg(self.serv_sock, len(MSG),
+ socket.CMSG_LEN(4 * SIZEOF_INT)),
+ # RFC 3542 says implementations may set
+ # MSG_CTRUNC if there isn't enough space
+ # for trailing padding.
+ ignoreflags=socket.MSG_CTRUNC)
+
+ def _testFDPassCMSG_LEN(self):
+ self.createAndSendFDs(1)
+
+ # Issue #12958: The following test has problems on Mac OS X
+ @support.anticipate_failure(sys.platform == "darwin")
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testFDPassSeparate(self):
+ # Pass two FDs in two separate arrays. Arrays may be combined
+ # into a single control message by the OS.
+ self.checkRecvmsgFDs(2,
+ self.doRecvmsg(self.serv_sock, len(MSG), 10240),
+ maxcmsgs=2)
+
+ @testFDPassSeparate.client_skip
+ @support.anticipate_failure(sys.platform == "darwin")
+ def _testFDPassSeparate(self):
+ fd0, fd1 = self.newFDs(2)
+ self.assertEqual(
+ self.sendmsgToServer([MSG], [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd0])),
+ (socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd1]))]),
+ len(MSG))
+
+ # Issue #12958: The following test has problems on Mac OS X
+ @support.anticipate_failure(sys.platform == "darwin")
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testFDPassSeparateMinSpace(self):
+ # Pass two FDs in two separate arrays, receiving them into the
+ # minimum space for two arrays.
+ self.checkRecvmsgFDs(2,
+ self.doRecvmsg(self.serv_sock, len(MSG),
+ socket.CMSG_SPACE(SIZEOF_INT) +
+ socket.CMSG_LEN(SIZEOF_INT)),
+ maxcmsgs=2, ignoreflags=socket.MSG_CTRUNC)
+
+ @testFDPassSeparateMinSpace.client_skip
+ @support.anticipate_failure(sys.platform == "darwin")
+ def _testFDPassSeparateMinSpace(self):
+ fd0, fd1 = self.newFDs(2)
+ self.assertEqual(
+ self.sendmsgToServer([MSG], [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd0])),
+ (socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd1]))]),
+ len(MSG))
+
+ def sendAncillaryIfPossible(self, msg, ancdata):
+ # Try to send msg and ancdata to server, but if the system
+ # call fails, just send msg with no ancillary data.
+ try:
+ nbytes = self.sendmsgToServer([msg], ancdata)
+ except socket.error as e:
+ # Check that it was the system call that failed
+ self.assertIsInstance(e.errno, int)
+ nbytes = self.sendmsgToServer([msg])
+ self.assertEqual(nbytes, len(msg))
+
+ def testFDPassEmpty(self):
+ # Try to pass an empty FD array. Can receive either no array
+ # or an empty array.
+ self.checkRecvmsgFDs(0, self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240),
+ ignoreflags=socket.MSG_CTRUNC)
+
+ def _testFDPassEmpty(self):
+ self.sendAncillaryIfPossible(MSG, [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ b"")])
+
+ def testFDPassPartialInt(self):
+ # Try to pass a truncated FD array.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, ignore=socket.MSG_CTRUNC)
+ self.assertLessEqual(len(ancdata), 1)
+ for cmsg_level, cmsg_type, cmsg_data in ancdata:
+ self.assertEqual(cmsg_level, socket.SOL_SOCKET)
+ self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
+ self.assertLess(len(cmsg_data), SIZEOF_INT)
+
+ def _testFDPassPartialInt(self):
+ self.sendAncillaryIfPossible(
+ MSG,
+ [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [self.badfd]).tobytes()[:-1])])
+
+ @requireAttrs(socket, "CMSG_SPACE")
+ def testFDPassPartialIntInMiddle(self):
+ # Try to pass two FD arrays, the first of which is truncated.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), 10240)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, ignore=socket.MSG_CTRUNC)
+ self.assertLessEqual(len(ancdata), 2)
+ fds = array.array("i")
+ # Arrays may have been combined in a single control message
+ for cmsg_level, cmsg_type, cmsg_data in ancdata:
+ self.assertEqual(cmsg_level, socket.SOL_SOCKET)
+ self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
+ fds.frombytes(cmsg_data[:
+ len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
+ self.assertLessEqual(len(fds), 2)
+ self.checkFDs(fds)
+
+ @testFDPassPartialIntInMiddle.client_skip
+ def _testFDPassPartialIntInMiddle(self):
+ fd0, fd1 = self.newFDs(2)
+ self.sendAncillaryIfPossible(
+ MSG,
+ [(socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd0, self.badfd]).tobytes()[:-1]),
+ (socket.SOL_SOCKET,
+ socket.SCM_RIGHTS,
+ array.array("i", [fd1]))])
+
+ def checkTruncatedHeader(self, result, ignoreflags=0):
+ # Check that no ancillary data items are returned when data is
+ # truncated inside the cmsghdr structure.
+ msg, ancdata, flags, addr = result
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+
+ def testCmsgTruncNoBufSize(self):
+ # Check that no ancillary data is received when no buffer size
+ # is specified.
+ self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG)),
+ # BSD seems to set MSG_CTRUNC only
+ # if an item has been partially
+ # received.
+ ignoreflags=socket.MSG_CTRUNC)
+
+ def _testCmsgTruncNoBufSize(self):
+ self.createAndSendFDs(1)
+
+ def testCmsgTrunc0(self):
+ # Check that no ancillary data is received when buffer size is 0.
+ self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), 0),
+ ignoreflags=socket.MSG_CTRUNC)
+
+ def _testCmsgTrunc0(self):
+ self.createAndSendFDs(1)
+
+ # Check that no ancillary data is returned for various non-zero
+ # (but still too small) buffer sizes.
+
+ def testCmsgTrunc1(self):
+ self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG), 1))
+
+ def _testCmsgTrunc1(self):
+ self.createAndSendFDs(1)
+
+ def testCmsgTrunc2Int(self):
+ # The cmsghdr structure has at least three members, two of
+ # which are ints, so we still shouldn't see any ancillary
+ # data.
+ self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG),
+ SIZEOF_INT * 2))
+
+ def _testCmsgTrunc2Int(self):
+ self.createAndSendFDs(1)
+
+ def testCmsgTruncLen0Minus1(self):
+ self.checkTruncatedHeader(self.doRecvmsg(self.serv_sock, len(MSG),
+ socket.CMSG_LEN(0) - 1))
+
+ def _testCmsgTruncLen0Minus1(self):
+ self.createAndSendFDs(1)
+
+ # The following tests try to truncate the control message in the
+ # middle of the FD array.
+
+ def checkTruncatedArray(self, ancbuf, maxdata, mindata=0):
+ # Check that file descriptor data is truncated to between
+ # mindata and maxdata bytes when received with buffer size
+ # ancbuf, and that any complete file descriptor numbers are
+ # valid.
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), ancbuf)
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)
+
+ if mindata == 0 and ancdata == []:
+ return
+ self.assertEqual(len(ancdata), 1)
+ cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+ self.assertEqual(cmsg_level, socket.SOL_SOCKET)
+ self.assertEqual(cmsg_type, socket.SCM_RIGHTS)
+ self.assertGreaterEqual(len(cmsg_data), mindata)
+ self.assertLessEqual(len(cmsg_data), maxdata)
+ fds = array.array("i")
+ fds.frombytes(cmsg_data[:
+ len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
+ self.checkFDs(fds)
+
+ def testCmsgTruncLen0(self):
+ self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(0), maxdata=0)
+
+ def _testCmsgTruncLen0(self):
+ self.createAndSendFDs(1)
+
+ def testCmsgTruncLen0Plus1(self):
+ self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(0) + 1, maxdata=1)
+
+ def _testCmsgTruncLen0Plus1(self):
+ self.createAndSendFDs(2)
+
+ def testCmsgTruncLen1(self):
+ self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(SIZEOF_INT),
+ maxdata=SIZEOF_INT)
+
+ def _testCmsgTruncLen1(self):
+ self.createAndSendFDs(2)
+
+ def testCmsgTruncLen2Minus1(self):
+ self.checkTruncatedArray(ancbuf=socket.CMSG_LEN(2 * SIZEOF_INT) - 1,
+ maxdata=(2 * SIZEOF_INT) - 1)
+
+ def _testCmsgTruncLen2Minus1(self):
+ self.createAndSendFDs(2)
+
+
+class RFC3542AncillaryTest(SendrecvmsgServerTimeoutBase):
+ # Test sendmsg() and recvmsg[_into]() using the ancillary data
+ # features of the RFC 3542 Advanced Sockets API for IPv6.
+ # Currently we can only handle certain data items (e.g. traffic
+ # class, hop limit, MTU discovery and fragmentation settings)
+ # without resorting to unportable means such as the struct module,
+ # but the tests here are aimed at testing the ancillary data
+ # handling in sendmsg() and recvmsg() rather than the IPv6 API
+ # itself.
+
+ # Test value to use when setting hop limit of packet
+ hop_limit = 2
+
+ # Test value to use when setting traffic class of packet.
+ # -1 means "use kernel default".
+ traffic_class = -1
+
+ def ancillaryMapping(self, ancdata):
+ # Given ancillary data list ancdata, return a mapping from
+ # pairs (cmsg_level, cmsg_type) to corresponding cmsg_data.
+ # Check that no (level, type) pair appears more than once.
+ d = {}
+ for cmsg_level, cmsg_type, cmsg_data in ancdata:
+ self.assertNotIn((cmsg_level, cmsg_type), d)
+ d[(cmsg_level, cmsg_type)] = cmsg_data
+ return d
+
+ def checkHopLimit(self, ancbufsize, maxhop=255, ignoreflags=0):
+ # Receive hop limit into ancbufsize bytes of ancillary data
+ # space. Check that data is MSG, ancillary data is not
+ # truncated (but ignore any flags in ignoreflags), and hop
+ # limit is between 0 and maxhop inclusive.
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.misc_event.set()
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), ancbufsize)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+
+ self.assertEqual(len(ancdata), 1)
+ self.assertIsInstance(ancdata[0], tuple)
+ cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+ self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
+ self.assertEqual(cmsg_type, socket.IPV6_HOPLIMIT)
+ self.assertIsInstance(cmsg_data, bytes)
+ self.assertEqual(len(cmsg_data), SIZEOF_INT)
+ a = array.array("i")
+ a.frombytes(cmsg_data)
+ self.assertGreaterEqual(a[0], 0)
+ self.assertLessEqual(a[0], maxhop)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testRecvHopLimit(self):
+ # Test receiving the packet hop limit as ancillary data.
+ self.checkHopLimit(ancbufsize=10240)
+
+ @testRecvHopLimit.client_skip
+ def _testRecvHopLimit(self):
+ # Need to wait until server has asked to receive ancillary
+ # data, as implementations are not required to buffer it
+ # otherwise.
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testRecvHopLimitCMSG_SPACE(self):
+ # Test receiving hop limit, using CMSG_SPACE to calculate buffer size.
+ self.checkHopLimit(ancbufsize=socket.CMSG_SPACE(SIZEOF_INT))
+
+ @testRecvHopLimitCMSG_SPACE.client_skip
+ def _testRecvHopLimitCMSG_SPACE(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ # Could test receiving into buffer sized using CMSG_LEN, but RFC
+ # 3542 says portable applications must provide space for trailing
+ # padding. Implementations may set MSG_CTRUNC if there isn't
+ # enough space for the padding.
+
+ @requireAttrs(socket.socket, "sendmsg")
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSetHopLimit(self):
+ # Test setting hop limit on outgoing packet and receiving it
+ # at the other end.
+ self.checkHopLimit(ancbufsize=10240, maxhop=self.hop_limit)
+
+ @testSetHopLimit.client_skip
+ def _testSetHopLimit(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.assertEqual(
+ self.sendmsgToServer([MSG],
+ [(socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
+ array.array("i", [self.hop_limit]))]),
+ len(MSG))
+
+ def checkTrafficClassAndHopLimit(self, ancbufsize, maxhop=255,
+ ignoreflags=0):
+ # Receive traffic class and hop limit into ancbufsize bytes of
+ # ancillary data space. Check that data is MSG, ancillary
+ # data is not truncated (but ignore any flags in ignoreflags),
+ # and traffic class and hop limit are in range (hop limit no
+ # more than maxhop).
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVTCLASS, 1)
+ self.misc_event.set()
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), ancbufsize)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkunset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+ self.assertEqual(len(ancdata), 2)
+ ancmap = self.ancillaryMapping(ancdata)
+
+ tcdata = ancmap[(socket.IPPROTO_IPV6, socket.IPV6_TCLASS)]
+ self.assertEqual(len(tcdata), SIZEOF_INT)
+ a = array.array("i")
+ a.frombytes(tcdata)
+ self.assertGreaterEqual(a[0], 0)
+ self.assertLessEqual(a[0], 255)
+
+ hldata = ancmap[(socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT)]
+ self.assertEqual(len(hldata), SIZEOF_INT)
+ a = array.array("i")
+ a.frombytes(hldata)
+ self.assertGreaterEqual(a[0], 0)
+ self.assertLessEqual(a[0], maxhop)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testRecvTrafficClassAndHopLimit(self):
+ # Test receiving traffic class and hop limit as ancillary data.
+ self.checkTrafficClassAndHopLimit(ancbufsize=10240)
+
+ @testRecvTrafficClassAndHopLimit.client_skip
+ def _testRecvTrafficClassAndHopLimit(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testRecvTrafficClassAndHopLimitCMSG_SPACE(self):
+ # Test receiving traffic class and hop limit, using
+ # CMSG_SPACE() to calculate buffer size.
+ self.checkTrafficClassAndHopLimit(
+ ancbufsize=socket.CMSG_SPACE(SIZEOF_INT) * 2)
+
+ @testRecvTrafficClassAndHopLimitCMSG_SPACE.client_skip
+ def _testRecvTrafficClassAndHopLimitCMSG_SPACE(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket.socket, "sendmsg")
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSetTrafficClassAndHopLimit(self):
+ # Test setting traffic class and hop limit on outgoing packet,
+ # and receiving them at the other end.
+ self.checkTrafficClassAndHopLimit(ancbufsize=10240,
+ maxhop=self.hop_limit)
+
+ @testSetTrafficClassAndHopLimit.client_skip
+ def _testSetTrafficClassAndHopLimit(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.assertEqual(
+ self.sendmsgToServer([MSG],
+ [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS,
+ array.array("i", [self.traffic_class])),
+ (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
+ array.array("i", [self.hop_limit]))]),
+ len(MSG))
+
+ @requireAttrs(socket.socket, "sendmsg")
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testOddCmsgSize(self):
+ # Try to send ancillary data with first item one byte too
+ # long. Fall back to sending with correct size if this fails,
+ # and check that second item was handled correctly.
+ self.checkTrafficClassAndHopLimit(ancbufsize=10240,
+ maxhop=self.hop_limit)
+
+ @testOddCmsgSize.client_skip
+ def _testOddCmsgSize(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ try:
+ nbytes = self.sendmsgToServer(
+ [MSG],
+ [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS,
+ array.array("i", [self.traffic_class]).tobytes() + b"\x00"),
+ (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
+ array.array("i", [self.hop_limit]))])
+ except socket.error as e:
+ self.assertIsInstance(e.errno, int)
+ nbytes = self.sendmsgToServer(
+ [MSG],
+ [(socket.IPPROTO_IPV6, socket.IPV6_TCLASS,
+ array.array("i", [self.traffic_class])),
+ (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
+ array.array("i", [self.hop_limit]))])
+ self.assertEqual(nbytes, len(MSG))
+
+ # Tests for proper handling of truncated ancillary data
+
+ def checkHopLimitTruncatedHeader(self, ancbufsize, ignoreflags=0):
+ # Receive hop limit into ancbufsize bytes of ancillary data
+ # space, which should be too small to contain the ancillary
+ # data header (if ancbufsize is None, pass no second argument
+ # to recvmsg()). Check that data is MSG, MSG_CTRUNC is set
+ # (unless included in ignoreflags), and no ancillary data is
+ # returned.
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.misc_event.set()
+ args = () if ancbufsize is None else (ancbufsize,)
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), *args)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.assertEqual(ancdata, [])
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testCmsgTruncNoBufSize(self):
+ # Check that no ancillary data is received when no ancillary
+ # buffer size is provided.
+ self.checkHopLimitTruncatedHeader(ancbufsize=None,
+ # BSD seems to set
+ # MSG_CTRUNC only if an item
+ # has been partially
+ # received.
+ ignoreflags=socket.MSG_CTRUNC)
+
+ @testCmsgTruncNoBufSize.client_skip
+ def _testCmsgTruncNoBufSize(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSingleCmsgTrunc0(self):
+ # Check that no ancillary data is received when ancillary
+ # buffer size is zero.
+ self.checkHopLimitTruncatedHeader(ancbufsize=0,
+ ignoreflags=socket.MSG_CTRUNC)
+
+ @testSingleCmsgTrunc0.client_skip
+ def _testSingleCmsgTrunc0(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ # Check that no ancillary data is returned for various non-zero
+ # (but still too small) buffer sizes.
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSingleCmsgTrunc1(self):
+ self.checkHopLimitTruncatedHeader(ancbufsize=1)
+
+ @testSingleCmsgTrunc1.client_skip
+ def _testSingleCmsgTrunc1(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSingleCmsgTrunc2Int(self):
+ self.checkHopLimitTruncatedHeader(ancbufsize=2 * SIZEOF_INT)
+
+ @testSingleCmsgTrunc2Int.client_skip
+ def _testSingleCmsgTrunc2Int(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSingleCmsgTruncLen0Minus1(self):
+ self.checkHopLimitTruncatedHeader(ancbufsize=socket.CMSG_LEN(0) - 1)
+
+ @testSingleCmsgTruncLen0Minus1.client_skip
+ def _testSingleCmsgTruncLen0Minus1(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT")
+ def testSingleCmsgTruncInData(self):
+ # Test truncation of a control message inside its associated
+ # data. The message may be returned with its data truncated,
+ # or not returned at all.
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.misc_event.set()
+ msg, ancdata, flags, addr = self.doRecvmsg(
+ self.serv_sock, len(MSG), socket.CMSG_LEN(SIZEOF_INT) - 1)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)
+
+ self.assertLessEqual(len(ancdata), 1)
+ if ancdata:
+ cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+ self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
+ self.assertEqual(cmsg_type, socket.IPV6_HOPLIMIT)
+ self.assertLess(len(cmsg_data), SIZEOF_INT)
+
+ @testSingleCmsgTruncInData.client_skip
+ def _testSingleCmsgTruncInData(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ def checkTruncatedSecondHeader(self, ancbufsize, ignoreflags=0):
+ # Receive traffic class and hop limit into ancbufsize bytes of
+ # ancillary data space, which should be large enough to
+ # contain the first item, but too small to contain the header
+ # of the second. Check that data is MSG, MSG_CTRUNC is set
+ # (unless included in ignoreflags), and only one ancillary
+ # data item is returned.
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVTCLASS, 1)
+ self.misc_event.set()
+ msg, ancdata, flags, addr = self.doRecvmsg(self.serv_sock,
+ len(MSG), ancbufsize)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC,
+ ignore=ignoreflags)
+
+ self.assertEqual(len(ancdata), 1)
+ cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+ self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
+ self.assertIn(cmsg_type, {socket.IPV6_TCLASS, socket.IPV6_HOPLIMIT})
+ self.assertEqual(len(cmsg_data), SIZEOF_INT)
+ a = array.array("i")
+ a.frombytes(cmsg_data)
+ self.assertGreaterEqual(a[0], 0)
+ self.assertLessEqual(a[0], 255)
+
+ # Try the above test with various buffer sizes.
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSecondCmsgTrunc0(self):
+ self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT),
+ ignoreflags=socket.MSG_CTRUNC)
+
+ @testSecondCmsgTrunc0.client_skip
+ def _testSecondCmsgTrunc0(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSecondCmsgTrunc1(self):
+ self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) + 1)
+
+ @testSecondCmsgTrunc1.client_skip
+ def _testSecondCmsgTrunc1(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSecondCmsgTrunc2Int(self):
+ self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) +
+ 2 * SIZEOF_INT)
+
+ @testSecondCmsgTrunc2Int.client_skip
+ def _testSecondCmsgTrunc2Int(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSecondCmsgTruncLen0Minus1(self):
+ self.checkTruncatedSecondHeader(socket.CMSG_SPACE(SIZEOF_INT) +
+ socket.CMSG_LEN(0) - 1)
+
+ @testSecondCmsgTruncLen0Minus1.client_skip
+ def _testSecondCmsgTruncLen0Minus1(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+ @requireAttrs(socket, "CMSG_SPACE", "IPV6_RECVHOPLIMIT", "IPV6_HOPLIMIT",
+ "IPV6_RECVTCLASS", "IPV6_TCLASS")
+ def testSecomdCmsgTruncInData(self):
+ # Test truncation of the second of two control messages inside
+ # its associated data.
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVHOPLIMIT, 1)
+ self.serv_sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_RECVTCLASS, 1)
+ self.misc_event.set()
+ msg, ancdata, flags, addr = self.doRecvmsg(
+ self.serv_sock, len(MSG),
+ socket.CMSG_SPACE(SIZEOF_INT) + socket.CMSG_LEN(SIZEOF_INT) - 1)
+
+ self.assertEqual(msg, MSG)
+ self.checkRecvmsgAddress(addr, self.cli_addr)
+ self.checkFlags(flags, eor=True, checkset=socket.MSG_CTRUNC)
+
+ cmsg_types = {socket.IPV6_TCLASS, socket.IPV6_HOPLIMIT}
+
+ cmsg_level, cmsg_type, cmsg_data = ancdata.pop(0)
+ self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
+ cmsg_types.remove(cmsg_type)
+ self.assertEqual(len(cmsg_data), SIZEOF_INT)
+ a = array.array("i")
+ a.frombytes(cmsg_data)
+ self.assertGreaterEqual(a[0], 0)
+ self.assertLessEqual(a[0], 255)
+
+ if ancdata:
+ cmsg_level, cmsg_type, cmsg_data = ancdata.pop(0)
+ self.assertEqual(cmsg_level, socket.IPPROTO_IPV6)
+ cmsg_types.remove(cmsg_type)
+ self.assertLess(len(cmsg_data), SIZEOF_INT)
+
+ self.assertEqual(ancdata, [])
+
+ @testSecomdCmsgTruncInData.client_skip
+ def _testSecomdCmsgTruncInData(self):
+ self.assertTrue(self.misc_event.wait(timeout=self.fail_timeout))
+ self.sendToServer(MSG)
+
+
+# Derive concrete test classes for different socket types.
+
+class SendrecvmsgUDPTestBase(SendrecvmsgDgramFlagsBase,
+ SendrecvmsgConnectionlessBase,
+ ThreadedSocketTestMixin, UDPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendmsgUDPTest(SendmsgConnectionlessTests, SendrecvmsgUDPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgUDPTest(RecvmsgTests, SendrecvmsgUDPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoUDPTest(RecvmsgIntoTests, SendrecvmsgUDPTestBase):
+ pass
+
+
+class SendrecvmsgUDP6TestBase(SendrecvmsgDgramFlagsBase,
+ SendrecvmsgConnectionlessBase,
+ ThreadedSocketTestMixin, UDP6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg")
+@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendmsgUDP6Test(SendmsgConnectionlessTests, SendrecvmsgUDP6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgUDP6Test(RecvmsgTests, SendrecvmsgUDP6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoUDP6Test(RecvmsgIntoTests, SendrecvmsgUDP6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
+@requireAttrs(socket, "IPPROTO_IPV6")
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgRFC3542AncillaryUDP6Test(RFC3542AncillaryTest,
+ SendrecvmsgUDP6TestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
+@requireAttrs(socket, "IPPROTO_IPV6")
+@requireSocket("AF_INET6", "SOCK_DGRAM")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoRFC3542AncillaryUDP6Test(RecvmsgIntoMixin,
+ RFC3542AncillaryTest,
+ SendrecvmsgUDP6TestBase):
+ pass
+
+
+class SendrecvmsgTCPTestBase(SendrecvmsgConnectedBase,
+ ConnectedStreamTestMixin, TCPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendmsgTCPTest(SendmsgStreamTests, SendrecvmsgTCPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgTCPTest(RecvmsgTests, RecvmsgGenericStreamTests,
+ SendrecvmsgTCPTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoTCPTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
+ SendrecvmsgTCPTestBase):
+ pass
+
+
+class SendrecvmsgSCTPStreamTestBase(SendrecvmsgSCTPFlagsBase,
+ SendrecvmsgConnectedBase,
+ ConnectedStreamTestMixin, SCTPStreamBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg")
+@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendmsgSCTPStreamTest(SendmsgStreamTests, SendrecvmsgSCTPStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgSCTPStreamTest(RecvmsgTests, RecvmsgGenericStreamTests,
+ SendrecvmsgSCTPStreamTestBase):
+
+ def testRecvmsgEOF(self):
+ try:
+ super(RecvmsgSCTPStreamTest, self).testRecvmsgEOF()
+ except OSError as e:
+ if e.errno != errno.ENOTCONN:
+ raise
+ self.skipTest("sporadic ENOTCONN (kernel issue?) - see issue #13876")
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoSCTPStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
+ SendrecvmsgSCTPStreamTestBase):
+ pass
+
+
+class SendrecvmsgUnixStreamTestBase(SendrecvmsgConnectedBase,
+ ConnectedStreamTestMixin, UnixStreamBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg")
+@requireAttrs(socket, "AF_UNIX")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class SendmsgUnixStreamTest(SendmsgStreamTests, SendrecvmsgUnixStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg")
+@requireAttrs(socket, "AF_UNIX")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgUnixStreamTest(RecvmsgTests, RecvmsgGenericStreamTests,
+ SendrecvmsgUnixStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "recvmsg_into")
+@requireAttrs(socket, "AF_UNIX")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoUnixStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
+ SendrecvmsgUnixStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg", "recvmsg")
+@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgSCMRightsStreamTest(SCMRightsTest, SendrecvmsgUnixStreamTestBase):
+ pass
+
+@requireAttrs(socket.socket, "sendmsg", "recvmsg_into")
+@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class RecvmsgIntoSCMRightsStreamTest(RecvmsgIntoMixin, SCMRightsTest,
+ SendrecvmsgUnixStreamTestBase):
+ pass
+
+
+# Test interrupting the interruptible send/receive methods with a
+# signal when a timeout is set. These tests avoid having multiple
+# threads alive during the test so that the OS cannot deliver the
+# signal to the wrong one.
+
+class InterruptedTimeoutBase(unittest.TestCase):
+ # Base class for interrupted send/receive tests. Installs an
+ # empty handler for SIGALRM and removes it on teardown, along with
+ # any scheduled alarms.
+
+ def setUp(self):
+ super().setUp()
+ orig_alrm_handler = signal.signal(signal.SIGALRM,
+ lambda signum, frame: None)
+ self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)
+ self.addCleanup(self.setAlarm, 0)
+
+ # Timeout for socket operations
+ timeout = 4.0
+
+ # Provide setAlarm() method to schedule delivery of SIGALRM after
+ # given number of seconds, or cancel it if zero, and an
+ # appropriate time value to use. Use setitimer() if available.
+ if hasattr(signal, "setitimer"):
+ alarm_time = 0.05
+
+ def setAlarm(self, seconds):
+ signal.setitimer(signal.ITIMER_REAL, seconds)
+ else:
+ # Old systems may deliver the alarm up to one second early
+ alarm_time = 2
+
+ def setAlarm(self, seconds):
+ signal.alarm(seconds)
+
+
+# Require siginterrupt() in order to ensure that system calls are
+# interrupted by default.
+@requireAttrs(signal, "siginterrupt")
+@unittest.skipUnless(hasattr(signal, "alarm") or hasattr(signal, "setitimer"),
+ "Don't have signal.alarm or signal.setitimer")
+class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase):
+ # Test interrupting the recv*() methods with signals when a
+ # timeout is set.
+
+ def setUp(self):
+ super().setUp()
+ self.serv.settimeout(self.timeout)
+
+ def checkInterruptedRecv(self, func, *args, **kwargs):
+ # Check that func(*args, **kwargs) raises socket.error with an
+ # errno of EINTR when interrupted by a signal.
+ self.setAlarm(self.alarm_time)
+ with self.assertRaises(socket.error) as cm:
+ func(*args, **kwargs)
+ self.assertNotIsInstance(cm.exception, socket.timeout)
+ self.assertEqual(cm.exception.errno, errno.EINTR)
+
+ def testInterruptedRecvTimeout(self):
+ self.checkInterruptedRecv(self.serv.recv, 1024)
+
+ def testInterruptedRecvIntoTimeout(self):
+ self.checkInterruptedRecv(self.serv.recv_into, bytearray(1024))
+
+ def testInterruptedRecvfromTimeout(self):
+ self.checkInterruptedRecv(self.serv.recvfrom, 1024)
+
+ def testInterruptedRecvfromIntoTimeout(self):
+ self.checkInterruptedRecv(self.serv.recvfrom_into, bytearray(1024))
+
+ @requireAttrs(socket.socket, "recvmsg")
+ def testInterruptedRecvmsgTimeout(self):
+ self.checkInterruptedRecv(self.serv.recvmsg, 1024)
+
+ @requireAttrs(socket.socket, "recvmsg_into")
+ def testInterruptedRecvmsgIntoTimeout(self):
+ self.checkInterruptedRecv(self.serv.recvmsg_into, [bytearray(1024)])
+
+
+# Require siginterrupt() in order to ensure that system calls are
+# interrupted by default.
+@requireAttrs(signal, "siginterrupt")
+@unittest.skipUnless(hasattr(signal, "alarm") or hasattr(signal, "setitimer"),
+ "Don't have signal.alarm or signal.setitimer")
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class InterruptedSendTimeoutTest(InterruptedTimeoutBase,
+ ThreadSafeCleanupTestCase,
+ SocketListeningTestMixin, TCPTestBase):
+ # Test interrupting the interruptible send*() methods with signals
+ # when a timeout is set.
+
+ def setUp(self):
+ super().setUp()
+ self.serv_conn = self.newSocket()
+ self.addCleanup(self.serv_conn.close)
+ # Use a thread to complete the connection, but wait for it to
+ # terminate before running the test, so that there is only one
+ # thread to accept the signal.
+ cli_thread = threading.Thread(target=self.doConnect)
+ cli_thread.start()
+ self.cli_conn, addr = self.serv.accept()
+ self.addCleanup(self.cli_conn.close)
+ cli_thread.join()
+ self.serv_conn.settimeout(self.timeout)
+
+ def doConnect(self):
+ self.serv_conn.connect(self.serv_addr)
+
+ def checkInterruptedSend(self, func, *args, **kwargs):
+ # Check that func(*args, **kwargs), run in a loop, raises
+ # socket.error with an errno of EINTR when interrupted by a
+ # signal.
+ with self.assertRaises(socket.error) as cm:
+ while True:
+ self.setAlarm(self.alarm_time)
+ func(*args, **kwargs)
+ self.assertNotIsInstance(cm.exception, socket.timeout)
+ self.assertEqual(cm.exception.errno, errno.EINTR)
+
+ # Issue #12958: The following tests have problems on Mac OS X
+ @support.anticipate_failure(sys.platform == "darwin")
+ def testInterruptedSendTimeout(self):
+ self.checkInterruptedSend(self.serv_conn.send, b"a"*512)
+
+ @support.anticipate_failure(sys.platform == "darwin")
+ def testInterruptedSendtoTimeout(self):
+ # Passing an actual address here as Python's wrapper for
+ # sendto() doesn't allow passing a zero-length one; POSIX
+ # requires that the address is ignored since the socket is
+ # connection-mode, however.
+ self.checkInterruptedSend(self.serv_conn.sendto, b"a"*512,
+ self.serv_addr)
+
+ @support.anticipate_failure(sys.platform == "darwin")
+ @requireAttrs(socket.socket, "sendmsg")
+ def testInterruptedSendmsgTimeout(self):
+ self.checkInterruptedSend(self.serv_conn.sendmsg, [b"a"*512])
+
+
@unittest.skipUnless(thread, 'Threading required for this test.')
class TCPCloserTest(ThreadedTCPSocketTest):
@@ -1080,11 +3561,8 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest):
pass
if hasattr(socket, "SOCK_NONBLOCK"):
+ @support.requires_linux_version(2, 6, 28)
def testInitNonBlocking(self):
- v = linux_version()
- if v < (2, 6, 28):
- self.skipTest("Linux kernel 2.6.28 or higher required, not %s"
- % ".".join(map(str, v)))
# reinit server socket
self.serv.close()
self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM |
@@ -1186,7 +3664,7 @@ class FileObjectClassTestCase(SocketConnectedTest):
"""
bufsize = -1 # Use default buffer size
- encoding = 'utf8'
+ encoding = 'utf-8'
errors = 'strict'
newline = None
@@ -1367,7 +3845,7 @@ class FileObjectInterruptedTestCase(unittest.TestCase):
@staticmethod
def _raise_eintr():
- raise socket.error(errno.EINTR)
+ raise socket.error(errno.EINTR, "interrupted")
def _textiowrap_mock_socket(self, mock, buffering=-1):
raw = socket.SocketIO(mock, "r")
@@ -1407,7 +3885,7 @@ class FileObjectInterruptedTestCase(unittest.TestCase):
data = b''
else:
data = ''
- expecting = expecting.decode('utf8')
+ expecting = expecting.decode('utf-8')
while len(data) != len(expecting):
part = fo.read(size)
if not part:
@@ -1569,7 +4047,7 @@ class UnicodeReadFileObjectClassTestCase(FileObjectClassTestCase):
"""Tests for socket.makefile() in text mode (rather than binary)"""
read_mode = 'r'
- read_msg = MSG.decode('utf8')
+ read_msg = MSG.decode('utf-8')
write_mode = 'wb'
write_msg = MSG
newline = ''
@@ -1581,7 +4059,7 @@ class UnicodeWriteFileObjectClassTestCase(FileObjectClassTestCase):
read_mode = 'rb'
read_msg = MSG
write_mode = 'w'
- write_msg = MSG.decode('utf8')
+ write_msg = MSG.decode('utf-8')
newline = ''
@@ -1589,9 +4067,9 @@ class UnicodeReadWriteFileObjectClassTestCase(FileObjectClassTestCase):
"""Tests for socket.makefile() in text mode (rather than binary)"""
read_mode = 'r'
- read_msg = MSG.decode('utf8')
+ read_msg = MSG.decode('utf-8')
write_mode = 'w'
- write_msg = MSG.decode('utf8')
+ write_msg = MSG.decode('utf-8')
newline = ''
@@ -1882,6 +4360,78 @@ class TestLinuxAbstractNamespace(unittest.TestCase):
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
self.assertRaises(socket.error, s.bind, address)
+ def testStrName(self):
+ # Check that an abstract name can be passed as a string.
+ s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ try:
+ s.bind("\x00python\x00test\x00")
+ self.assertEqual(s.getsockname(), b"\x00python\x00test\x00")
+ finally:
+ s.close()
+
+class TestUnixDomain(unittest.TestCase):
+
+ def setUp(self):
+ self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+
+ def tearDown(self):
+ self.sock.close()
+
+ def encoded(self, path):
+ # Return the given path encoded in the file system encoding,
+ # or skip the test if this is not possible.
+ try:
+ return os.fsencode(path)
+ except UnicodeEncodeError:
+ self.skipTest(
+ "Pathname {0!a} cannot be represented in file "
+ "system encoding {1!r}".format(
+ path, sys.getfilesystemencoding()))
+
+ def bind(self, sock, path):
+ # Bind the socket
+ try:
+ sock.bind(path)
+ except OSError as e:
+ if str(e) == "AF_UNIX path too long":
+ self.skipTest(
+ "Pathname {0!a} is too long to serve as a AF_UNIX path"
+ .format(path))
+ else:
+ raise
+
+ def testStrAddr(self):
+ # Test binding to and retrieving a normal string pathname.
+ path = os.path.abspath(support.TESTFN)
+ self.bind(self.sock, path)
+ self.addCleanup(support.unlink, path)
+ self.assertEqual(self.sock.getsockname(), path)
+
+ def testBytesAddr(self):
+ # Test binding to a bytes pathname.
+ path = os.path.abspath(support.TESTFN)
+ self.bind(self.sock, self.encoded(path))
+ self.addCleanup(support.unlink, path)
+ self.assertEqual(self.sock.getsockname(), path)
+
+ def testSurrogateescapeBind(self):
+ # Test binding to a valid non-ASCII pathname, with the
+ # non-ASCII bytes supplied using surrogateescape encoding.
+ path = os.path.abspath(support.TESTFN_UNICODE)
+ b = self.encoded(path)
+ self.bind(self.sock, b.decode("ascii", "surrogateescape"))
+ self.addCleanup(support.unlink, path)
+ self.assertEqual(self.sock.getsockname(), path)
+
+ def testUnencodableAddr(self):
+ # Test binding to a pathname that cannot be encoded in the
+ # file system encoding.
+ if support.TESTFN_UNENCODABLE is None:
+ self.skipTest("No unencodable filename available")
+ path = os.path.abspath(support.TESTFN_UNENCODABLE)
+ self.bind(self.sock, path)
+ self.addCleanup(support.unlink, path)
+ self.assertEqual(self.sock.getsockname(), path)
@unittest.skipUnless(thread, 'Threading required for this test.')
class BufferIOTest(SocketConnectedTest):
@@ -2082,11 +4632,8 @@ class ContextManagersTest(ThreadedTCPSocketTest):
"SOCK_CLOEXEC not defined")
@unittest.skipUnless(fcntl, "module fcntl not available")
class CloexecConstantTest(unittest.TestCase):
+ @support.requires_linux_version(2, 6, 28)
def test_SOCK_CLOEXEC(self):
- v = linux_version()
- if v < (2, 6, 28):
- self.skipTest("Linux kernel 2.6.28 or higher required, not %s"
- % ".".join(map(str, v)))
with socket.socket(socket.AF_INET,
socket.SOCK_STREAM | socket.SOCK_CLOEXEC) as s:
self.assertTrue(s.type & socket.SOCK_CLOEXEC)
@@ -2104,11 +4651,8 @@ class NonblockConstantTest(unittest.TestCase):
self.assertFalse(s.type & socket.SOCK_NONBLOCK)
self.assertEqual(s.gettimeout(), None)
+ @support.requires_linux_version(2, 6, 28)
def test_SOCK_NONBLOCK(self):
- v = linux_version()
- if v < (2, 6, 28):
- self.skipTest("Linux kernel 2.6.28 or higher required, not %s"
- % ".".join(map(str, v)))
# a lot of it seems silly and redundant, but I wanted to test that
# changing back and forth worked ok
with socket.socket(socket.AF_INET,
@@ -2141,6 +4685,109 @@ class NonblockConstantTest(unittest.TestCase):
socket.setdefaulttimeout(t)
+@unittest.skipUnless(os.name == "nt", "Windows specific")
+@unittest.skipUnless(multiprocessing, "need multiprocessing")
+class TestSocketSharing(SocketTCPTest):
+ # This must be classmethod and not staticmethod or multiprocessing
+ # won't be able to bootstrap it.
+ @classmethod
+ def remoteProcessServer(cls, q):
+ # Recreate socket from shared data
+ sdata = q.get()
+ message = q.get()
+
+ s = socket.fromshare(sdata)
+ s2, c = s.accept()
+
+ # Send the message
+ s2.sendall(message)
+ s2.close()
+ s.close()
+
+ def testShare(self):
+ # Transfer the listening server socket to another process
+ # and service it from there.
+
+ # Create process:
+ q = multiprocessing.Queue()
+ p = multiprocessing.Process(target=self.remoteProcessServer, args=(q,))
+ p.start()
+
+ # Get the shared socket data
+ data = self.serv.share(p.pid)
+
+ # Pass the shared socket to the other process
+ addr = self.serv.getsockname()
+ self.serv.close()
+ q.put(data)
+
+ # The data that the server will send us
+ message = b"slapmahfro"
+ q.put(message)
+
+ # Connect
+ s = socket.create_connection(addr)
+ # listen for the data
+ m = []
+ while True:
+ data = s.recv(100)
+ if not data:
+ break
+ m.append(data)
+ s.close()
+ received = b"".join(m)
+ self.assertEqual(received, message)
+ p.join()
+
+ def testShareLength(self):
+ data = self.serv.share(os.getpid())
+ self.assertRaises(ValueError, socket.fromshare, data[:-1])
+ self.assertRaises(ValueError, socket.fromshare, data+b"foo")
+
+ def compareSockets(self, org, other):
+ # socket sharing is expected to work only for blocking socket
+ # since the internal python timout value isn't transfered.
+ self.assertEqual(org.gettimeout(), None)
+ self.assertEqual(org.gettimeout(), other.gettimeout())
+
+ self.assertEqual(org.family, other.family)
+ self.assertEqual(org.type, other.type)
+ # If the user specified "0" for proto, then
+ # internally windows will have picked the correct value.
+ # Python introspection on the socket however will still return
+ # 0. For the shared socket, the python value is recreated
+ # from the actual value, so it may not compare correctly.
+ if org.proto != 0:
+ self.assertEqual(org.proto, other.proto)
+
+ def testShareLocal(self):
+ data = self.serv.share(os.getpid())
+ s = socket.fromshare(data)
+ try:
+ self.compareSockets(self.serv, s)
+ finally:
+ s.close()
+
+ def testTypes(self):
+ families = [socket.AF_INET, socket.AF_INET6]
+ types = [socket.SOCK_STREAM, socket.SOCK_DGRAM]
+ for f in families:
+ for t in types:
+ try:
+ source = socket.socket(f, t)
+ except OSError:
+ continue # This combination is not supported
+ try:
+ data = source.share(os.getpid())
+ shared = socket.fromshare(data)
+ try:
+ self.compareSockets(source, shared)
+ finally:
+ shared.close()
+ finally:
+ source.close()
+
+
def test_main():
tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ]
@@ -2164,11 +4811,41 @@ def test_main():
])
if hasattr(socket, "socketpair"):
tests.append(BasicSocketPairTest)
- if sys.platform == 'linux2':
+ if hasattr(socket, "AF_UNIX"):
+ tests.append(TestUnixDomain)
+ if sys.platform == 'linux':
tests.append(TestLinuxAbstractNamespace)
if isTipcAvailable():
tests.append(TIPCTest)
tests.append(TIPCThreadableTest)
+ tests.extend([BasicCANTest, CANTest])
+ tests.extend([BasicRDSTest, RDSTest])
+ tests.extend([
+ CmsgMacroTests,
+ SendmsgUDPTest,
+ RecvmsgUDPTest,
+ RecvmsgIntoUDPTest,
+ SendmsgUDP6Test,
+ RecvmsgUDP6Test,
+ RecvmsgRFC3542AncillaryUDP6Test,
+ RecvmsgIntoRFC3542AncillaryUDP6Test,
+ RecvmsgIntoUDP6Test,
+ SendmsgTCPTest,
+ RecvmsgTCPTest,
+ RecvmsgIntoTCPTest,
+ SendmsgSCTPStreamTest,
+ RecvmsgSCTPStreamTest,
+ RecvmsgIntoSCTPStreamTest,
+ SendmsgUnixStreamTest,
+ RecvmsgUnixStreamTest,
+ RecvmsgIntoUnixStreamTest,
+ RecvmsgSCMRightsStreamTest,
+ RecvmsgIntoSCMRightsStreamTest,
+ # These are slow when setitimer() is not available
+ InterruptedRecvTimeoutTest,
+ InterruptedSendTimeoutTest,
+ TestSocketSharing,
+ ])
thread_info = support.threading_setup()
support.run_unittest(*tests)