diff options
Diffstat (limited to 'Lib/test/test_socket.py')
-rw-r--r-- | Lib/test/test_socket.py | 376 |
1 files changed, 281 insertions, 95 deletions
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index b95e1ab..e995c99 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -2,7 +2,6 @@ import unittest from test import support -from unittest.case import _ExpectedFailure import errno import io @@ -24,13 +23,13 @@ import math import pickle import struct try: - import fcntl -except ImportError: - fcntl = False -try: import multiprocessing except ImportError: multiprocessing = False +try: + import fcntl +except ImportError: + fcntl = None HOST = support.HOST MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8') ## test unicode string and carriage return @@ -46,7 +45,7 @@ def _have_socket_can(): """Check whether CAN sockets are supported on this host.""" try: s = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) - except (AttributeError, socket.error, OSError): + except (AttributeError, OSError): return False else: s.close() @@ -121,12 +120,42 @@ class SocketCANTest(unittest.TestCase): interface = 'vcan0' bufsize = 128 + """The CAN frame structure is defined in <linux/can.h>: + + struct can_frame { + canid_t can_id; /* 32 bit CAN_ID + EFF/RTR/ERR flags */ + __u8 can_dlc; /* data length code: 0 .. 8 */ + __u8 data[8] __attribute__((aligned(8))); + }; + """ + can_frame_fmt = "=IB3x8s" + can_frame_size = struct.calcsize(can_frame_fmt) + + """The Broadcast Management Command frame structure is defined + in <linux/can/bcm.h>: + + struct bcm_msg_head { + __u32 opcode; + __u32 flags; + __u32 count; + struct timeval ival1, ival2; + canid_t can_id; + __u32 nframes; + struct can_frame frames[0]; + } + + `bcm_msg_head` must be 8 bytes aligned because of the `frames` member (see + `struct can_frame` definition). Must use native not standard types for packing. + """ + bcm_cmd_msg_fmt = "@3I4l2I" + bcm_cmd_msg_fmt += "x" * (struct.calcsize(bcm_cmd_msg_fmt) % 8) + def setUp(self): self.s = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) self.addCleanup(self.s.close) try: self.s.bind((self.interface,)) - except socket.error: + except OSError: self.skipTest('network interface `%s` does not exist' % self.interface) @@ -242,9 +271,6 @@ class ThreadableTest: raise TypeError("test_func must be a callable function") try: test_func() - except _ExpectedFailure: - # We deliberately ignore expected failures - pass except BaseException as e: self.queue.put(e) finally: @@ -295,7 +321,7 @@ class ThreadedCANSocketTest(SocketCANTest, ThreadableTest): self.cli = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) try: self.cli.bind((self.interface,)) - except socket.error: + except OSError: # skipTest should not be called here, and will be called in the # server instead pass @@ -543,11 +569,7 @@ class SCTPStreamBase(InetTestBase): class Inet6TestBase(InetTestBase): """Base class for IPv6 socket tests.""" - # Don't use "localhost" here - it may not have an IPv6 address - # assigned to it by default (e.g. in /etc/hosts), and if someone - # has assigned it an IPv4-mapped address, then it's unlikely to - # work with the full IPv6 API. - host = "::1" + host = support.HOSTv6 class UDP6TestBase(Inet6TestBase): """Base class for UDP-over-IPv6 tests.""" @@ -608,7 +630,7 @@ def requireSocket(*args): for obj in args] try: s = socket.socket(*callargs) - except socket.error as e: + except OSError as e: # XXX: check errno? err = str(e) else: @@ -626,8 +648,17 @@ class GeneralModuleTests(unittest.TestCase): def test_repr(self): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.addCleanup(s.close) - self.assertTrue(repr(s).startswith("<socket.socket object")) + with s: + self.assertIn('fd=%i' % s.fileno(), repr(s)) + self.assertIn('family=%i' % socket.AF_INET, repr(s)) + self.assertIn('type=%i' % socket.SOCK_STREAM, repr(s)) + self.assertIn('proto=0', repr(s)) + self.assertNotIn('raddr', repr(s)) + s.bind(('127.0.0.1', 0)) + self.assertIn('laddr', repr(s)) + self.assertIn(str(s.getsockname()), repr(s)) + self.assertIn('[closed]', repr(s)) + self.assertNotIn('laddr', repr(s)) def test_weakref(self): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -645,11 +676,11 @@ class GeneralModuleTests(unittest.TestCase): def testSocketError(self): # Testing socket module exceptions msg = "Error raising socket exception (%s)." - with self.assertRaises(socket.error, msg=msg % 'socket.error'): - raise socket.error - with self.assertRaises(socket.error, msg=msg % 'socket.herror'): + with self.assertRaises(OSError, msg=msg % 'OSError'): + raise OSError + with self.assertRaises(OSError, msg=msg % 'socket.herror'): raise socket.herror - with self.assertRaises(socket.error, msg=msg % 'socket.gaierror'): + with self.assertRaises(OSError, msg=msg % 'socket.gaierror'): raise socket.gaierror def testSendtoErrors(self): @@ -712,13 +743,13 @@ class GeneralModuleTests(unittest.TestCase): hostname = socket.gethostname() try: ip = socket.gethostbyname(hostname) - except socket.error: + except OSError: # Probably name lookup wasn't set up right; skip this test return self.assertTrue(ip.find('.') >= 0, "Error resolving host to ip.") try: hname, aliases, ipaddrs = socket.gethostbyaddr(ip) - except socket.error: + except OSError: # Probably a similar problem as above; skip this test return all_host_names = [hostname, hname] + aliases @@ -726,13 +757,27 @@ class GeneralModuleTests(unittest.TestCase): if not fqhn in all_host_names: self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names))) + def test_host_resolution(self): + for addr in ['0.1.1.~1', '1+.1.1.1', '::1q', '::1::2', + '1:1:1:1:1:1:1:1:1']: + self.assertRaises(OSError, socket.gethostbyname, addr) + self.assertRaises(OSError, socket.gethostbyaddr, addr) + + for addr in [support.HOST, '10.0.0.1', '255.255.255.255']: + self.assertEqual(socket.gethostbyname(addr), addr) + + # we don't test support.HOSTv6 because there's a chance it doesn't have + # a matching name entry (e.g. 'ip6-localhost') + for host in [support.HOST]: + self.assertIn(host, socket.gethostbyaddr(host)[2]) + @unittest.skipUnless(hasattr(socket, 'sethostname'), "test needs socket.sethostname()") @unittest.skipUnless(hasattr(socket, 'gethostname'), "test needs socket.gethostname()") def test_sethostname(self): oldhn = socket.gethostname() try: socket.sethostname('new') - except socket.error as e: + except OSError as e: if e.errno == errno.EPERM: self.skipTest("test should be run as root") else: @@ -766,8 +811,8 @@ class GeneralModuleTests(unittest.TestCase): 'socket.if_nameindex() not available.') def testInvalidInterfaceNameIndex(self): # test nonexistent interface index/name - self.assertRaises(socket.error, socket.if_indextoname, 0) - self.assertRaises(socket.error, socket.if_nametoindex, '_DEADBEEF') + self.assertRaises(OSError, socket.if_indextoname, 0) + self.assertRaises(OSError, socket.if_nametoindex, '_DEADBEEF') # test with invalid values self.assertRaises(TypeError, socket.if_nametoindex, 0) self.assertRaises(TypeError, socket.if_indextoname, '_DEADBEEF') @@ -788,7 +833,7 @@ class GeneralModuleTests(unittest.TestCase): try: # On some versions, this crashes the interpreter. socket.getnameinfo(('x', 0, 0, 0), 0) - except socket.error: + except OSError: pass def testNtoH(self): @@ -835,17 +880,17 @@ class GeneralModuleTests(unittest.TestCase): try: port = socket.getservbyname(service, 'tcp') break - except socket.error: + except OSError: pass else: - raise socket.error + raise OSError # Try same call with optional protocol omitted port2 = socket.getservbyname(service) eq(port, port2) # Try udp, but don't barf if it doesn't exist try: udpport = socket.getservbyname(service, 'udp') - except socket.error: + except OSError: udpport = None else: eq(udpport, port) @@ -901,7 +946,7 @@ class GeneralModuleTests(unittest.TestCase): g = lambda a: inet_pton(AF_INET, a) assertInvalid = lambda func,a: self.assertRaises( - (socket.error, ValueError), func, a + (OSError, ValueError), func, a ) self.assertEqual(b'\x00\x00\x00\x00', f('0.0.0.0')) @@ -936,7 +981,7 @@ class GeneralModuleTests(unittest.TestCase): return f = lambda a: inet_pton(AF_INET6, a) assertInvalid = lambda a: self.assertRaises( - (socket.error, ValueError), f, a + (OSError, ValueError), f, a ) self.assertEqual(b'\x00' * 16, f('::')) @@ -985,7 +1030,7 @@ class GeneralModuleTests(unittest.TestCase): from socket import inet_ntoa as f, inet_ntop, AF_INET g = lambda a: inet_ntop(AF_INET, a) assertInvalid = lambda func,a: self.assertRaises( - (socket.error, ValueError), func, a + (OSError, ValueError), func, a ) self.assertEqual('1.0.1.0', f(b'\x01\x00\x01\x00')) @@ -1014,7 +1059,7 @@ class GeneralModuleTests(unittest.TestCase): return f = lambda a: inet_ntop(AF_INET6, a) assertInvalid = lambda a: self.assertRaises( - (socket.error, ValueError), f, a + (OSError, ValueError), f, a ) self.assertEqual('::', f(b'\x00' * 16)) @@ -1042,7 +1087,7 @@ class GeneralModuleTests(unittest.TestCase): # At least for eCos. This is required for the S/390 to pass. try: my_ip_addr = socket.gethostbyname(socket.gethostname()) - except socket.error: + except OSError: # Probably name lookup wasn't set up right; skip this test return self.assertIn(name[0], ("0.0.0.0", my_ip_addr), '%s invalid' % name[0]) @@ -1069,13 +1114,19 @@ class GeneralModuleTests(unittest.TestCase): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(1) sock.close() - self.assertRaises(socket.error, sock.send, b"spam") + self.assertRaises(OSError, sock.send, b"spam") def testNewAttributes(self): # testing .family, .type and .protocol + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.assertEqual(sock.family, socket.AF_INET) - self.assertEqual(sock.type, socket.SOCK_STREAM) + if hasattr(socket, 'SOCK_CLOEXEC'): + self.assertIn(sock.type, + (socket.SOCK_STREAM | socket.SOCK_CLOEXEC, + socket.SOCK_STREAM)) + else: + self.assertEqual(sock.type, socket.SOCK_STREAM) self.assertEqual(sock.proto, 0) sock.close() @@ -1128,9 +1179,12 @@ class GeneralModuleTests(unittest.TestCase): socket.getaddrinfo(HOST, 80) socket.getaddrinfo(HOST, None) # test family and socktype filters - infos = socket.getaddrinfo(HOST, None, socket.AF_INET) - for family, _, _, _, _ in infos: + infos = socket.getaddrinfo(HOST, 80, socket.AF_INET, socket.SOCK_STREAM) + for family, type, _, _, _ in infos: self.assertEqual(family, socket.AF_INET) + self.assertEqual(str(family), 'AddressFamily.AF_INET') + self.assertEqual(type, socket.SOCK_STREAM) + self.assertEqual(str(type), 'SocketType.SOCK_STREAM') infos = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM) for _, socktype, _, _, _ in infos: self.assertEqual(socktype, socket.SOCK_STREAM) @@ -1172,7 +1226,7 @@ class GeneralModuleTests(unittest.TestCase): def test_getnameinfo(self): # only IP addresses are allowed - self.assertRaises(socket.error, socket.getnameinfo, ('mail.python.org',0), 0) + self.assertRaises(OSError, socket.getnameinfo, ('mail.python.org',0), 0) @unittest.skipUnless(support.is_resource_enabled('network'), 'network is not enabled') @@ -1284,10 +1338,31 @@ class GeneralModuleTests(unittest.TestCase): @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.') def test_flowinfo(self): self.assertRaises(OverflowError, socket.getnameinfo, - ('::1',0, 0xffffffff), 0) + (support.HOSTv6, 0, 0xffffffff), 0) with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: - self.assertRaises(OverflowError, s.bind, ('::1', 0, -10)) - + self.assertRaises(OverflowError, s.bind, (support.HOSTv6, 0, -10)) + + def test_str_for_enums(self): + # Make sure that the AF_* and SOCK_* constants have enum-like string + # reprs. + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + self.assertEqual(str(s.family), 'AddressFamily.AF_INET') + self.assertEqual(str(s.type), 'SocketType.SOCK_STREAM') + + @unittest.skipIf(os.name == 'nt', 'Will not work on Windows') + def test_uknown_socket_family_repr(self): + # Test that when created with a family that's not one of the known + # AF_*/SOCK_* constants, socket.family just returns the number. + # + # To do this we fool socket.socket into believing it already has an + # open fd because on this path it doesn't actually verify the family and + # type and populates the socket object. + # + # On Windows this trick won't work, so the test is skipped. + fd, _ = tempfile.mkstemp() + with socket.socket(family=42424, type=13331, fileno=fd) as s: + self.assertEqual(s.family, 42424) + self.assertEqual(s.type, 13331) @unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.') class BasicCANTest(unittest.TestCase): @@ -1297,10 +1372,35 @@ class BasicCANTest(unittest.TestCase): socket.PF_CAN socket.CAN_RAW + @unittest.skipUnless(hasattr(socket, "CAN_BCM"), + 'socket.CAN_BCM required for this test.') + def testBCMConstants(self): + socket.CAN_BCM + + # opcodes + socket.CAN_BCM_TX_SETUP # create (cyclic) transmission task + socket.CAN_BCM_TX_DELETE # remove (cyclic) transmission task + socket.CAN_BCM_TX_READ # read properties of (cyclic) transmission task + socket.CAN_BCM_TX_SEND # send one CAN frame + socket.CAN_BCM_RX_SETUP # create RX content filter subscription + socket.CAN_BCM_RX_DELETE # remove RX content filter subscription + socket.CAN_BCM_RX_READ # read properties of RX content filter subscription + socket.CAN_BCM_TX_STATUS # reply to TX_READ request + socket.CAN_BCM_TX_EXPIRED # notification on performed transmissions (count=0) + socket.CAN_BCM_RX_STATUS # reply to RX_READ request + socket.CAN_BCM_RX_TIMEOUT # cyclic message is absent + socket.CAN_BCM_RX_CHANGED # updated CAN frame (detected content change) + def testCreateSocket(self): with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s: pass + @unittest.skipUnless(hasattr(socket, "CAN_BCM"), + 'socket.CAN_BCM required for this test.') + def testCreateBCMSocket(self): + with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_BCM) as s: + pass + def testBindAny(self): with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s: s.bind(('', )) @@ -1308,7 +1408,7 @@ class BasicCANTest(unittest.TestCase): def testTooLongInterfaceName(self): # most systems limit IFNAMSIZ to 16, take 1024 to be sure with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s: - self.assertRaisesRegex(socket.error, 'interface name too long', + self.assertRaisesRegex(OSError, 'interface name too long', s.bind, ('x' * 1024,)) @unittest.skipUnless(hasattr(socket, "CAN_RAW_LOOPBACK"), @@ -1333,19 +1433,8 @@ class BasicCANTest(unittest.TestCase): @unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.') -@unittest.skipUnless(thread, 'Threading required for this test.') class CANTest(ThreadedCANSocketTest): - """The CAN frame structure is defined in <linux/can.h>: - - struct can_frame { - canid_t can_id; /* 32 bit CAN_ID + EFF/RTR/ERR flags */ - __u8 can_dlc; /* data length code: 0 .. 8 */ - __u8 data[8] __attribute__((aligned(8))); - }; - """ - can_frame_fmt = "=IB3x8s" - def __init__(self, methodName='runTest'): ThreadedCANSocketTest.__init__(self, methodName=methodName) @@ -1394,6 +1483,46 @@ class CANTest(ThreadedCANSocketTest): self.cf2 = self.build_can_frame(0x12, b'\x99\x22\x33') self.cli.send(self.cf2) + @unittest.skipUnless(hasattr(socket, "CAN_BCM"), + 'socket.CAN_BCM required for this test.') + def _testBCM(self): + cf, addr = self.cli.recvfrom(self.bufsize) + self.assertEqual(self.cf, cf) + can_id, can_dlc, data = self.dissect_can_frame(cf) + self.assertEqual(self.can_id, can_id) + self.assertEqual(self.data, data) + + @unittest.skipUnless(hasattr(socket, "CAN_BCM"), + 'socket.CAN_BCM required for this test.') + def testBCM(self): + bcm = socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_BCM) + self.addCleanup(bcm.close) + bcm.connect((self.interface,)) + self.can_id = 0x123 + self.data = bytes([0xc0, 0xff, 0xee]) + self.cf = self.build_can_frame(self.can_id, self.data) + opcode = socket.CAN_BCM_TX_SEND + flags = 0 + count = 0 + ival1_seconds = ival1_usec = ival2_seconds = ival2_usec = 0 + bcm_can_id = 0x0222 + nframes = 1 + assert len(self.cf) == 16 + header = struct.pack(self.bcm_cmd_msg_fmt, + opcode, + flags, + count, + ival1_seconds, + ival1_usec, + ival2_seconds, + ival2_usec, + bcm_can_id, + nframes, + ) + header_plus_frame = header + self.cf + bytes_sent = bcm.send(header_plus_frame) + self.assertEqual(bytes_sent, len(header_plus_frame)) + @unittest.skipUnless(HAVE_SOCKET_RDS, 'RDS sockets required for this test.') class BasicRDSTest(unittest.TestCase): @@ -1608,7 +1737,7 @@ class BasicTCPTest(SocketConnectedTest): self.assertEqual(f, fileno) # cli_conn cannot be used anymore... self.assertTrue(self.cli_conn._closed) - self.assertRaises(socket.error, self.cli_conn.recv, 1024) + self.assertRaises(OSError, self.cli_conn.recv, 1024) self.cli_conn.close() # ...but we can create another socket using the (still open) # file descriptor @@ -1977,7 +2106,7 @@ class SendmsgTests(SendrecvmsgServerTimeoutBase): def _testSendmsgExcessCmsgReject(self): if not hasattr(socket, "CMSG_SPACE"): # Can only send one item - with self.assertRaises(socket.error) as cm: + with self.assertRaises(OSError) as cm: self.sendmsgToServer([MSG], [(0, 0, b""), (0, 0, b"")]) self.assertIsNone(cm.exception.errno) self.sendToServer(b"done") @@ -1988,7 +2117,7 @@ class SendmsgTests(SendrecvmsgServerTimeoutBase): def _testSendmsgAfterClose(self): self.cli_sock.close() - self.assertRaises(socket.error, self.sendmsgToServer, [MSG]) + self.assertRaises(OSError, self.sendmsgToServer, [MSG]) class SendmsgStreamTests(SendmsgTests): @@ -2032,7 +2161,7 @@ class SendmsgStreamTests(SendmsgTests): @testSendmsgDontWait.client_skip def _testSendmsgDontWait(self): try: - with self.assertRaises(socket.error) as cm: + with self.assertRaises(OSError) as cm: while True: self.sendmsgToServer([b"a"*512], [], socket.MSG_DONTWAIT) self.assertIn(cm.exception.errno, @@ -2052,9 +2181,9 @@ class SendmsgConnectionlessTests(SendmsgTests): pass def _testSendmsgNoDestAddr(self): - self.assertRaises(socket.error, self.cli_sock.sendmsg, + self.assertRaises(OSError, self.cli_sock.sendmsg, [MSG]) - self.assertRaises(socket.error, self.cli_sock.sendmsg, + self.assertRaises(OSError, self.cli_sock.sendmsg, [MSG], [], 0, None) @@ -2140,7 +2269,7 @@ class RecvmsgGenericTests(SendrecvmsgBase): def testRecvmsgAfterClose(self): # Check that recvmsg[_into]() fails on a closed socket. self.serv_sock.close() - self.assertRaises(socket.error, self.doRecvmsg, self.serv_sock, 1024) + self.assertRaises(OSError, self.doRecvmsg, self.serv_sock, 1024) def _testRecvmsgAfterClose(self): pass @@ -2586,7 +2715,7 @@ class SCMRightsTest(SendrecvmsgServerTimeoutBase): # call fails, just send msg with no ancillary data. try: nbytes = self.sendmsgToServer([msg], ancdata) - except socket.error as e: + except OSError as e: # Check that it was the system call that failed self.assertIsInstance(e.errno, int) nbytes = self.sendmsgToServer([msg]) @@ -2964,7 +3093,7 @@ class RFC3542AncillaryTest(SendrecvmsgServerTimeoutBase): array.array("i", [self.traffic_class]).tobytes() + b"\x00"), (socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT, array.array("i", [self.hop_limit]))]) - except socket.error as e: + except OSError as e: self.assertIsInstance(e.errno, int) nbytes = self.sendmsgToServer( [MSG], @@ -3422,10 +3551,10 @@ class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase): self.serv.settimeout(self.timeout) def checkInterruptedRecv(self, func, *args, **kwargs): - # Check that func(*args, **kwargs) raises socket.error with an + # Check that func(*args, **kwargs) raises OSError with an # errno of EINTR when interrupted by a signal. self.setAlarm(self.alarm_time) - with self.assertRaises(socket.error) as cm: + with self.assertRaises(OSError) as cm: func(*args, **kwargs) self.assertNotIsInstance(cm.exception, socket.timeout) self.assertEqual(cm.exception.errno, errno.EINTR) @@ -3482,9 +3611,9 @@ class InterruptedSendTimeoutTest(InterruptedTimeoutBase, def checkInterruptedSend(self, func, *args, **kwargs): # Check that func(*args, **kwargs), run in a loop, raises - # socket.error with an errno of EINTR when interrupted by a + # OSError with an errno of EINTR when interrupted by a # signal. - with self.assertRaises(socket.error) as cm: + with self.assertRaises(OSError) as cm: while True: self.setAlarm(self.alarm_time) func(*args, **kwargs) @@ -3581,7 +3710,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): start = time.time() try: self.serv.accept() - except socket.error: + except OSError: pass end = time.time() self.assertTrue((end - start) < 1.0, "Error setting non-blocking mode.") @@ -3606,7 +3735,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): start = time.time() try: self.serv.accept() - except socket.error: + except OSError: pass end = time.time() self.assertTrue((end - start) < 1.0, "Error creating with non-blocking mode.") @@ -3636,7 +3765,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): self.serv.setblocking(0) try: conn, addr = self.serv.accept() - except socket.error: + except OSError: pass else: self.fail("Error trying to do non-blocking accept.") @@ -3666,7 +3795,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): conn.setblocking(0) try: msg = conn.recv(len(MSG)) - except socket.error: + except OSError: pass else: self.fail("Error trying to do non-blocking recv.") @@ -3749,7 +3878,7 @@ class FileObjectClassTestCase(SocketConnectedTest): # First read raises a timeout self.assertRaises(socket.timeout, self.read_file.read, 1) # Second read is disallowed - with self.assertRaises(IOError) as ctx: + with self.assertRaises(OSError) as ctx: self.read_file.read(1) self.assertIn("cannot read from timed out object", str(ctx.exception)) @@ -3841,7 +3970,7 @@ class FileObjectClassTestCase(SocketConnectedTest): self.read_file.close() self.assertRaises(ValueError, self.read_file.fileno) self.cli_conn.close() - self.assertRaises(socket.error, self.cli_conn.getsockname) + self.assertRaises(OSError, self.cli_conn.getsockname) def _testRealClose(self): pass @@ -3878,7 +4007,7 @@ class FileObjectInterruptedTestCase(unittest.TestCase): @staticmethod def _raise_eintr(): - raise socket.error(errno.EINTR, "interrupted") + raise OSError(errno.EINTR, "interrupted") def _textiowrap_mock_socket(self, mock, buffering=-1): raw = socket.SocketIO(mock, "r") @@ -3990,7 +4119,7 @@ class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase): self.assertEqual(msg, self.read_msg) # ...until the file is itself closed self.read_file.close() - self.assertRaises(socket.error, self.cli_conn.recv, 1024) + self.assertRaises(OSError, self.cli_conn.recv, 1024) def _testMakefileClose(self): self.write_file.write(self.write_msg) @@ -4139,7 +4268,7 @@ class NetworkConnectionNoServer(unittest.TestCase): port = support.find_unused_port() cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.addCleanup(cli.close) - with self.assertRaises(socket.error) as cm: + with self.assertRaises(OSError) as cm: cli.connect((HOST, port)) self.assertEqual(cm.exception.errno, errno.ECONNREFUSED) @@ -4147,7 +4276,7 @@ class NetworkConnectionNoServer(unittest.TestCase): # Issue #9792: errors raised by create_connection() should have # a proper errno attribute. port = support.find_unused_port() - with self.assertRaises(socket.error) as cm: + with self.assertRaises(OSError) as cm: socket.create_connection((HOST, port)) # Issue #16257: create_connection() calls getaddrinfo() against @@ -4295,7 +4424,7 @@ class TCPTimeoutTest(SocketTCPTest): foo = self.serv.accept() except socket.timeout: self.fail("caught timeout instead of error (TCP)") - except socket.error: + except OSError: ok = True except: self.fail("caught unexpected exception (TCP)") @@ -4352,7 +4481,7 @@ class UDPTimeoutTest(SocketUDPTest): foo = self.serv.recv(1024) except socket.timeout: self.fail("caught timeout instead of error (UDP)") - except socket.error: + except OSError: ok = True except: self.fail("caught unexpected exception (UDP)") @@ -4362,10 +4491,10 @@ class UDPTimeoutTest(SocketUDPTest): class TestExceptions(unittest.TestCase): def testExceptionTree(self): - self.assertTrue(issubclass(socket.error, Exception)) - self.assertTrue(issubclass(socket.herror, socket.error)) - self.assertTrue(issubclass(socket.gaierror, socket.error)) - self.assertTrue(issubclass(socket.timeout, socket.error)) + self.assertTrue(issubclass(OSError, Exception)) + self.assertTrue(issubclass(socket.herror, OSError)) + self.assertTrue(issubclass(socket.gaierror, OSError)) + self.assertTrue(issubclass(socket.timeout, OSError)) class TestLinuxAbstractNamespace(unittest.TestCase): @@ -4391,7 +4520,7 @@ class TestLinuxAbstractNamespace(unittest.TestCase): def testNameOverflow(self): address = "\x00" + "h" * self.UNIX_PATH_MAX with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: - self.assertRaises(socket.error, s.bind, address) + self.assertRaises(OSError, s.bind, address) def testStrName(self): # Check that an abstract name can be passed as a string. @@ -4630,7 +4759,7 @@ class ContextManagersTest(ThreadedTCPSocketTest): self.assertTrue(sock._closed) # exception inside with block with socket.socket() as sock: - self.assertRaises(socket.error, sock.sendall, b'foo') + self.assertRaises(OSError, sock.sendall, b'foo') self.assertTrue(sock._closed) def testCreateConnectionBase(self): @@ -4658,19 +4787,76 @@ class ContextManagersTest(ThreadedTCPSocketTest): with socket.create_connection(address) as sock: sock.close() self.assertTrue(sock._closed) - self.assertRaises(socket.error, sock.sendall, b'foo') + self.assertRaises(OSError, sock.sendall, b'foo') -@unittest.skipUnless(hasattr(socket, "SOCK_CLOEXEC"), - "SOCK_CLOEXEC not defined") -@unittest.skipUnless(fcntl, "module fcntl not available") -class CloexecConstantTest(unittest.TestCase): +class InheritanceTest(unittest.TestCase): + @unittest.skipUnless(hasattr(socket, "SOCK_CLOEXEC"), + "SOCK_CLOEXEC not defined") @support.requires_linux_version(2, 6, 28) def test_SOCK_CLOEXEC(self): with socket.socket(socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_CLOEXEC) as s: self.assertTrue(s.type & socket.SOCK_CLOEXEC) - self.assertTrue(fcntl.fcntl(s, fcntl.F_GETFD) & fcntl.FD_CLOEXEC) + self.assertFalse(s.get_inheritable()) + + def test_default_inheritable(self): + sock = socket.socket() + with sock: + self.assertEqual(sock.get_inheritable(), False) + + def test_dup(self): + sock = socket.socket() + with sock: + newsock = sock.dup() + sock.close() + with newsock: + self.assertEqual(newsock.get_inheritable(), False) + + def test_set_inheritable(self): + sock = socket.socket() + with sock: + sock.set_inheritable(True) + self.assertEqual(sock.get_inheritable(), True) + + sock.set_inheritable(False) + self.assertEqual(sock.get_inheritable(), False) + + @unittest.skipIf(fcntl is None, "need fcntl") + def test_get_inheritable_cloexec(self): + sock = socket.socket() + with sock: + fd = sock.fileno() + self.assertEqual(sock.get_inheritable(), False) + + # clear FD_CLOEXEC flag + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + flags &= ~fcntl.FD_CLOEXEC + fcntl.fcntl(fd, fcntl.F_SETFD, flags) + + self.assertEqual(sock.get_inheritable(), True) + + @unittest.skipIf(fcntl is None, "need fcntl") + def test_set_inheritable_cloexec(self): + sock = socket.socket() + with sock: + fd = sock.fileno() + self.assertEqual(fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC, + fcntl.FD_CLOEXEC) + + sock.set_inheritable(True) + self.assertEqual(fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC, + 0) + + + @unittest.skipUnless(hasattr(socket, "socketpair"), + "need socket.socketpair()") + def test_socketpair(self): + s1, s2 = socket.socketpair() + self.addCleanup(s1.close) + self.addCleanup(s2.close) + self.assertEqual(s1.get_inheritable(), False) + self.assertEqual(s2.get_inheritable(), False) @unittest.skipUnless(hasattr(socket, "SOCK_NONBLOCK"), @@ -4839,7 +5025,7 @@ def test_main(): NetworkConnectionAttributesTest, NetworkConnectionBehaviourTest, ContextManagersTest, - CloexecConstantTest, + InheritanceTest, NonblockConstantTest ]) if hasattr(socket, "socketpair"): |