diff options
Diffstat (limited to 'Lib/test/test_socket.py')
-rw-r--r-- | Lib/test/test_socket.py | 676 |
1 files changed, 604 insertions, 72 deletions
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 6a9497b..baca4c1 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -1,23 +1,27 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 import unittest from test import support import errno +import io import socket import select -import _thread as thread -import threading import time import traceback import queue import sys import os import array +import platform import contextlib from weakref import proxy import signal import math +try: + import fcntl +except ImportError: + fcntl = False def try_address(host, port=0, family=socket.AF_INET): """Try to bind a socket on the given host:port and return True @@ -31,10 +35,25 @@ def try_address(host, port=0, family=socket.AF_INET): sock.close() return True +def linux_version(): + try: + # platform.release() is something like '2.6.33.7-desktop-2mnb' + version_string = platform.release().split('-')[0] + return tuple(map(int, version_string.split('.'))) + except ValueError: + return 0, 0, 0 + HOST = support.HOST MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf8') ## test unicode string and carriage return SUPPORTS_IPV6 = socket.has_ipv6 and try_address('::1', family=socket.AF_INET6) +try: + import _thread as thread + import threading +except ImportError: + thread = None + threading = None + class SocketTCPTest(unittest.TestCase): def setUp(self): @@ -132,8 +151,8 @@ class ThreadableTest: self.done.wait() if self.queue.qsize(): - msg = self.queue.get() - self.fail(msg) + exc = self.queue.get() + raise exc def clientRun(self, test_func): self.server_ready.wait() @@ -143,9 +162,10 @@ class ThreadableTest: raise TypeError("test_func must be a callable function") try: test_func() - except Exception as strerror: - self.queue.put(strerror) - self.clientTearDown() + except BaseException as e: + self.queue.put(e) + finally: + self.clientTearDown() def clientSetUp(self): raise NotImplementedError("clientSetUp must be implemented.") @@ -244,6 +264,7 @@ class GeneralModuleTests(unittest.TestCase): def test_repr(self): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.addCleanup(s.close) self.assertTrue(repr(s).startswith("<socket.socket object")) def test_weakref(self): @@ -281,28 +302,42 @@ class GeneralModuleTests(unittest.TestCase): s.bind(('', 0)) sockname = s.getsockname() # 2 args - with self.assertRaises(TypeError): + with self.assertRaises(TypeError) as cm: s.sendto('\u2620', sockname) - with self.assertRaises(TypeError): + self.assertEqual(str(cm.exception), + "'str' does not support the buffer interface") + with self.assertRaises(TypeError) as cm: s.sendto(5j, sockname) - with self.assertRaises(TypeError): + self.assertEqual(str(cm.exception), + "'complex' does not support the buffer interface") + with self.assertRaises(TypeError) as cm: s.sendto(b'foo', None) + self.assertIn('not NoneType',str(cm.exception)) # 3 args - with self.assertRaises(TypeError): + with self.assertRaises(TypeError) as cm: s.sendto('\u2620', 0, sockname) - with self.assertRaises(TypeError): + self.assertEqual(str(cm.exception), + "'str' does not support the buffer interface") + with self.assertRaises(TypeError) as cm: s.sendto(5j, 0, sockname) - with self.assertRaises(TypeError): + self.assertEqual(str(cm.exception), + "'complex' does not support the buffer interface") + with self.assertRaises(TypeError) as cm: s.sendto(b'foo', 0, None) - with self.assertRaises(TypeError): + self.assertIn('not NoneType', str(cm.exception)) + with self.assertRaises(TypeError) as cm: s.sendto(b'foo', 'bar', sockname) - with self.assertRaises(TypeError): + self.assertIn('an integer is required', str(cm.exception)) + with self.assertRaises(TypeError) as cm: s.sendto(b'foo', None, None) + self.assertIn('an integer is required', str(cm.exception)) # wrong number of args - with self.assertRaises(TypeError): + with self.assertRaises(TypeError) as cm: s.sendto(b'foo') - with self.assertRaises(TypeError): + self.assertIn('(1 given)', str(cm.exception)) + with self.assertRaises(TypeError) as cm: s.sendto(b'foo', 0, sockname, 4) + self.assertIn('(4 given)', str(cm.exception)) def testCrucialConstants(self): # Testing for mission critical constants @@ -529,24 +564,11 @@ class GeneralModuleTests(unittest.TestCase): # XXX The following don't test module-level functionality... - def _get_unused_port(self, bind_address='0.0.0.0'): - """Use a temporary socket to elicit an unused ephemeral port. - - Args: - bind_address: Hostname or IP address to search for a port on. - - Returns: A most likely to be unused port. - """ - tempsock = socket.socket() - tempsock.bind((bind_address, 0)) - host, port = tempsock.getsockname() - tempsock.close() - return port - def testSockName(self): # Testing getsockname() - port = self._get_unused_port() + port = support.find_unused_port() sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.addCleanup(sock.close) sock.bind(("0.0.0.0", port)) name = sock.getsockname() # XXX(nnorwitz): http://tinyurl.com/os5jz seems to indicate @@ -557,19 +579,21 @@ class GeneralModuleTests(unittest.TestCase): except socket.error: # Probably name lookup wasn't set up right; skip this test return - self.assertTrue(name[0] in ("0.0.0.0", my_ip_addr), '%s invalid' % name[0]) + self.assertIn(name[0], ("0.0.0.0", my_ip_addr), '%s invalid' % name[0]) self.assertEqual(name[1], port) def testGetSockOpt(self): # Testing getsockopt() # We know a socket should start without reuse==0 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.addCleanup(sock.close) reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) self.assertFalse(reuse != 0, "initial mode is reuse") def testSetSockOpt(self): # Testing setsockopt() sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.addCleanup(sock.close) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) self.assertFalse(reuse == 0, "failed to set reuse mode") @@ -591,7 +615,7 @@ class GeneralModuleTests(unittest.TestCase): def test_getsockaddrarg(self): host = '0.0.0.0' - port = self._get_unused_port(bind_address=host) + port = support.find_unused_port() big_port = port + 65536 neg_port = port - 65536 sock = socket.socket() @@ -602,13 +626,17 @@ class GeneralModuleTests(unittest.TestCase): finally: sock.close() + @unittest.skipUnless(os.name == "nt", "Windows specific") def test_sock_ioctl(self): - if os.name != "nt": - return self.assertTrue(hasattr(socket.socket, 'ioctl')) self.assertTrue(hasattr(socket, 'SIO_RCVALL')) self.assertTrue(hasattr(socket, 'RCVALL_ON')) self.assertTrue(hasattr(socket, 'RCVALL_OFF')) + self.assertTrue(hasattr(socket, 'SIO_KEEPALIVE_VALS')) + s = socket.socket() + self.addCleanup(s.close) + self.assertRaises(ValueError, s.ioctl, -1, None) + s.ioctl(socket.SIO_KEEPALIVE_VALS, (1, 100, 100)) def testGetaddrinfo(self): try: @@ -647,7 +675,46 @@ class GeneralModuleTests(unittest.TestCase): # usually do this socket.getaddrinfo(None, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE) - + # test keyword arguments + a = socket.getaddrinfo(HOST, None) + b = socket.getaddrinfo(host=HOST, port=None) + self.assertEqual(a, b) + a = socket.getaddrinfo(HOST, None, socket.AF_INET) + b = socket.getaddrinfo(HOST, None, family=socket.AF_INET) + self.assertEqual(a, b) + a = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM) + b = socket.getaddrinfo(HOST, None, type=socket.SOCK_STREAM) + self.assertEqual(a, b) + a = socket.getaddrinfo(HOST, None, 0, 0, socket.SOL_TCP) + b = socket.getaddrinfo(HOST, None, proto=socket.SOL_TCP) + self.assertEqual(a, b) + a = socket.getaddrinfo(HOST, None, 0, 0, 0, socket.AI_PASSIVE) + b = socket.getaddrinfo(HOST, None, flags=socket.AI_PASSIVE) + self.assertEqual(a, b) + a = socket.getaddrinfo(None, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, + socket.AI_PASSIVE) + b = socket.getaddrinfo(host=None, port=0, family=socket.AF_UNSPEC, + type=socket.SOCK_STREAM, proto=0, + flags=socket.AI_PASSIVE) + self.assertEqual(a, b) + # Issue #6697. + self.assertRaises(UnicodeEncodeError, socket.getaddrinfo, 'localhost', '\uD800') + + def test_getnameinfo(self): + # only IP addresses are allowed + self.assertRaises(socket.error, socket.getnameinfo, ('mail.python.org',0), 0) + + @unittest.skipUnless(support.is_resource_enabled('network'), + 'network is not enabled') + def test_idna(self): + support.requires('network') + # these should all be successful + socket.gethostbyname('испытание.python.org') + socket.gethostbyname_ex('испытание.python.org') + socket.getaddrinfo('испытание.python.org',0,socket.AF_UNSPEC,socket.SOCK_STREAM) + # this may not work if the forward lookup choses the IPv6 address, as that doesn't + # have a reverse entry yet + # socket.gethostbyaddr('испытание.python.org') def check_sendall_interrupted(self, with_timeout): # socketpair() is not stricly required, but it makes things easier. @@ -684,7 +751,38 @@ class GeneralModuleTests(unittest.TestCase): def test_sendall_interrupted_with_timeout(self): self.check_sendall_interrupted(True) - + def test_dealloc_warn(self): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + r = repr(sock) + with self.assertWarns(ResourceWarning) as cm: + sock = None + support.gc_collect() + self.assertIn(r, str(cm.warning.args[0])) + # An open socket file object gets dereferenced after the socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + f = sock.makefile('rb') + r = repr(sock) + sock = None + support.gc_collect() + with self.assertWarns(ResourceWarning): + f = None + support.gc_collect() + + def test_name_closed_socketio(self): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + fp = sock.makefile("rb") + fp.close() + self.assertEqual(repr(fp), "<_io.BufferedReader name=-1>") + + def testListenBacklog0(self): + srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + srv.bind((HOST, 0)) + # backlog = 0 + srv.listen(0) + srv.close() + + +@unittest.skipUnless(thread, 'Threading required for this test.') class BasicTCPTest(SocketConnectedTest): def __init__(self, methodName='runTest'): @@ -742,10 +840,10 @@ class BasicTCPTest(SocketConnectedTest): def testFromFd(self): # Testing fromfd() - if not hasattr(socket, "fromfd"): - return # On Windows, this doesn't exist fd = self.cli_conn.fileno() sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) + self.addCleanup(sock.close) + self.assertIsInstance(sock, socket.socket) msg = sock.recv(1024) self.assertEqual(msg, MSG) @@ -755,6 +853,7 @@ class BasicTCPTest(SocketConnectedTest): def testDup(self): # Testing dup() sock = self.cli_conn.dup() + self.addCleanup(sock.close) msg = sock.recv(1024) self.assertEqual(msg, MSG) @@ -774,6 +873,25 @@ class BasicTCPTest(SocketConnectedTest): self.serv_conn.send(MSG) self.serv_conn.shutdown(2) + def testDetach(self): + # Testing detach() + fileno = self.cli_conn.fileno() + f = self.cli_conn.detach() + self.assertEqual(f, fileno) + # cli_conn cannot be used anymore... + self.assertRaises(socket.error, self.cli_conn.recv, 1024) + self.cli_conn.close() + # ...but we can create another socket using the (still open) + # file descriptor + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=f) + self.addCleanup(sock.close) + msg = sock.recv(1024) + self.assertEqual(msg, MSG) + + def _testDetach(self): + self.serv_conn.send(MSG) + +@unittest.skipUnless(thread, 'Threading required for this test.') class BasicUDPTest(ThreadedUDPSocketTest): def __init__(self, methodName='runTest'): @@ -802,6 +920,7 @@ class BasicUDPTest(ThreadedUDPSocketTest): def _testRecvFromNegative(self): self.cli.sendto(MSG, 0, (HOST, self.port)) +@unittest.skipUnless(thread, 'Threading required for this test.') class TCPCloserTest(ThreadedTCPSocketTest): def testClose(self): @@ -821,11 +940,27 @@ class TCPCloserTest(ThreadedTCPSocketTest): self.cli.connect((HOST, self.port)) time.sleep(1.0) +@unittest.skipUnless(thread, 'Threading required for this test.') class BasicSocketPairTest(SocketPairTest): def __init__(self, methodName='runTest'): SocketPairTest.__init__(self, methodName=methodName) + def _check_defaults(self, sock): + self.assertIsInstance(sock, socket.socket) + if hasattr(socket, 'AF_UNIX'): + self.assertEqual(sock.family, socket.AF_UNIX) + else: + self.assertEqual(sock.family, socket.AF_INET) + self.assertEqual(sock.type, socket.SOCK_STREAM) + self.assertEqual(sock.proto, 0) + + def _testDefaults(self): + self._check_defaults(self.cli) + + def testDefaults(self): + self._check_defaults(self.serv) + def testRecv(self): msg = self.serv.recv(1024) self.assertEqual(msg, MSG) @@ -840,6 +975,7 @@ class BasicSocketPairTest(SocketPairTest): msg = self.cli.recv(1024) self.assertEqual(msg, MSG) +@unittest.skipUnless(thread, 'Threading required for this test.') class NonBlockingTCPTests(ThreadedTCPSocketTest): def __init__(self, methodName='runTest'): @@ -859,6 +995,47 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): def _testSetBlocking(self): pass + if hasattr(socket, "SOCK_NONBLOCK"): + def testInitNonBlocking(self): + v = linux_version() + if v < (2, 6, 28): + self.skipTest("Linux kernel 2.6.28 or higher required, not %s" + % ".".join(map(str, v))) + # reinit server socket + self.serv.close() + self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM | + socket.SOCK_NONBLOCK) + self.port = support.bind_port(self.serv) + self.serv.listen(1) + # actual testing + start = time.time() + try: + self.serv.accept() + except socket.error: + pass + end = time.time() + self.assertTrue((end - start) < 1.0, "Error creating with non-blocking mode.") + + def _testInitNonBlocking(self): + pass + + def testInheritFlags(self): + # Issue #7995: when calling accept() on a listening socket with a + # timeout, the resulting socket should not be non-blocking. + self.serv.settimeout(10) + try: + conn, addr = self.serv.accept() + message = conn.recv(len(MSG)) + finally: + conn.close() + self.serv.settimeout(None) + + def _testInheritFlags(self): + time.sleep(0.1) + self.cli.connect((HOST, self.port)) + time.sleep(0.5) + self.cli.send(MSG) + def testAccept(self): # Testing non-blocking accept self.serv.setblocking(0) @@ -871,6 +1048,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): read, write, err = select.select([self.serv], [], []) if self.serv in read: conn, addr = self.serv.accept() + conn.close() else: self.fail("Error trying to do accept after select.") @@ -881,6 +1059,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): def testConnect(self): # Testing non-blocking connect conn, addr = self.serv.accept() + conn.close() def _testConnect(self): self.cli.settimeout(10) @@ -899,6 +1078,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): read, write, err = select.select([conn], [], []) if conn in read: msg = conn.recv(len(MSG)) + conn.close() self.assertEqual(msg, MSG) else: self.fail("Error during select call to non-blocking socket.") @@ -908,6 +1088,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): time.sleep(0.1) self.cli.send(MSG) +@unittest.skipUnless(thread, 'Threading required for this test.') class FileObjectClassTestCase(SocketConnectedTest): """Unit tests for the object returned by socket.makefile() @@ -934,6 +1115,8 @@ class FileObjectClassTestCase(SocketConnectedTest): SocketConnectedTest.__init__(self, methodName=methodName) def setUp(self): + self.evt1, self.evt2, self.serv_finished, self.cli_finished = [ + threading.Event() for i in range(4)] SocketConnectedTest.setUp(self) self.read_file = self.cli_conn.makefile( self.read_mode, self.bufsize, @@ -942,6 +1125,7 @@ class FileObjectClassTestCase(SocketConnectedTest): newline = self.newline) def tearDown(self): + self.serv_finished.set() self.read_file.close() self.assertTrue(self.read_file.closed) self.read_file = None @@ -956,11 +1140,29 @@ class FileObjectClassTestCase(SocketConnectedTest): newline = self.newline) def clientTearDown(self): + self.cli_finished.set() self.write_file.close() self.assertTrue(self.write_file.closed) self.write_file = None SocketConnectedTest.clientTearDown(self) + def testReadAfterTimeout(self): + # Issue #7322: A file object must disallow further reads + # after a timeout has occurred. + self.cli_conn.settimeout(1) + self.read_file.read(3) + # First read raises a timeout + self.assertRaises(socket.timeout, self.read_file.read, 1) + # Second read is disallowed + with self.assertRaises(IOError) as ctx: + self.read_file.read(1) + self.assertIn("cannot read from timed out object", str(ctx.exception)) + + def _testReadAfterTimeout(self): + self.write_file.write(self.write_msg[0:3]) + self.write_file.flush() + self.serv_finished.wait() + def testSmallRead(self): # Performing small file read test first_seg = self.read_file.read(len(self.read_msg)-3) @@ -1050,6 +1252,117 @@ class FileObjectClassTestCase(SocketConnectedTest): pass +class FileObjectInterruptedTestCase(unittest.TestCase): + """Test that the file object correctly handles EINTR internally.""" + + class MockSocket(object): + def __init__(self, recv_funcs=()): + # A generator that returns callables that we'll call for each + # call to recv(). + self._recv_step = iter(recv_funcs) + + def recv_into(self, buffer): + data = next(self._recv_step)() + assert len(buffer) >= len(data) + buffer[:len(data)] = data + return len(data) + + def _decref_socketios(self): + pass + + def _textiowrap_for_test(self, buffering=-1): + raw = socket.SocketIO(self, "r") + if buffering < 0: + buffering = io.DEFAULT_BUFFER_SIZE + if buffering == 0: + return raw + buffer = io.BufferedReader(raw, buffering) + text = io.TextIOWrapper(buffer, None, None) + text.mode = "rb" + return text + + @staticmethod + def _raise_eintr(): + raise socket.error(errno.EINTR) + + def _textiowrap_mock_socket(self, mock, buffering=-1): + raw = socket.SocketIO(mock, "r") + if buffering < 0: + buffering = io.DEFAULT_BUFFER_SIZE + if buffering == 0: + return raw + buffer = io.BufferedReader(raw, buffering) + text = io.TextIOWrapper(buffer, None, None) + text.mode = "rb" + return text + + def _test_readline(self, size=-1, buffering=-1): + mock_sock = self.MockSocket(recv_funcs=[ + lambda : b"This is the first line\nAnd the sec", + self._raise_eintr, + lambda : b"ond line is here\n", + lambda : b"", + lambda : b"", # XXX(gps): io library does an extra EOF read + ]) + fo = mock_sock._textiowrap_for_test(buffering=buffering) + self.assertEqual(fo.readline(size), "This is the first line\n") + self.assertEqual(fo.readline(size), "And the second line is here\n") + + def _test_read(self, size=-1, buffering=-1): + mock_sock = self.MockSocket(recv_funcs=[ + lambda : b"This is the first line\nAnd the sec", + self._raise_eintr, + lambda : b"ond line is here\n", + lambda : b"", + lambda : b"", # XXX(gps): io library does an extra EOF read + ]) + expecting = (b"This is the first line\n" + b"And the second line is here\n") + fo = mock_sock._textiowrap_for_test(buffering=buffering) + if buffering == 0: + data = b'' + else: + data = '' + expecting = expecting.decode('utf8') + while len(data) != len(expecting): + part = fo.read(size) + if not part: + break + data += part + self.assertEqual(data, expecting) + + def test_default(self): + self._test_readline() + self._test_readline(size=100) + self._test_read() + self._test_read(size=100) + + def test_with_1k_buffer(self): + self._test_readline(buffering=1024) + self._test_readline(size=100, buffering=1024) + self._test_read(buffering=1024) + self._test_read(size=100, buffering=1024) + + def _test_readline_no_buffer(self, size=-1): + mock_sock = self.MockSocket(recv_funcs=[ + lambda : b"a", + lambda : b"\n", + lambda : b"B", + self._raise_eintr, + lambda : b"b", + lambda : b"", + ]) + fo = mock_sock._textiowrap_for_test(buffering=0) + self.assertEqual(fo.readline(size), b"a\n") + self.assertEqual(fo.readline(size), b"Bb") + + def test_no_buffer(self): + self._test_readline_no_buffer() + self._test_readline_no_buffer(size=4) + self._test_read(buffering=0) + self._test_read(size=100, buffering=0) + + class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase): """Repeat the tests from FileObjectClassTestCase with bufsize==0. @@ -1097,6 +1410,66 @@ class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase): def _testMakefileCloseSocketDestroy(self): pass + # Non-blocking ops + # NOTE: to set `read_file` as non-blocking, we must call + # `cli_conn.setblocking` and vice-versa (see setUp / clientSetUp). + + def testSmallReadNonBlocking(self): + self.cli_conn.setblocking(False) + self.assertEqual(self.read_file.readinto(bytearray(10)), None) + self.assertEqual(self.read_file.read(len(self.read_msg) - 3), None) + self.evt1.set() + self.evt2.wait(1.0) + first_seg = self.read_file.read(len(self.read_msg) - 3) + if first_seg is None: + # Data not arrived (can happen under Windows), wait a bit + time.sleep(0.5) + first_seg = self.read_file.read(len(self.read_msg) - 3) + buf = bytearray(10) + n = self.read_file.readinto(buf) + self.assertEqual(n, 3) + msg = first_seg + buf[:n] + self.assertEqual(msg, self.read_msg) + self.assertEqual(self.read_file.readinto(bytearray(16)), None) + self.assertEqual(self.read_file.read(1), None) + + def _testSmallReadNonBlocking(self): + self.evt1.wait(1.0) + self.write_file.write(self.write_msg) + self.write_file.flush() + self.evt2.set() + # Avoid cloding the socket before the server test has finished, + # otherwise system recv() will return 0 instead of EWOULDBLOCK. + self.serv_finished.wait(5.0) + + def testWriteNonBlocking(self): + self.cli_finished.wait(5.0) + # The client thread can't skip directly - the SkipTest exception + # would appear as a failure. + if self.serv_skipped: + self.skipTest(self.serv_skipped) + + def _testWriteNonBlocking(self): + self.serv_skipped = None + self.serv_conn.setblocking(False) + # Try to saturate the socket buffer pipe with repeated large writes. + BIG = b"x" * (1024 ** 2) + LIMIT = 10 + # The first write() succeeds since a chunk of data can be buffered + n = self.write_file.write(BIG) + self.assertGreater(n, 0) + for i in range(LIMIT): + n = self.write_file.write(BIG) + if n is None: + # Succeeded + break + self.assertGreater(n, 0) + else: + # Let us know that this test didn't manage to establish + # the expected conditions. This is not a failure in itself but, + # if it happens repeatedly, the test should be fixed. + self.serv_skipped = "failed to saturate the socket buffer" + class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase): @@ -1170,23 +1543,18 @@ class NetworkConnectionNoServer(unittest.TestCase): def test_connect(self): port = support.find_unused_port() cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: + self.addCleanup(cli.close) + with self.assertRaises(socket.error) as cm: cli.connect((HOST, port)) - except socket.error as err: - self.assertEqual(err.errno, errno.ECONNREFUSED) - else: - self.fail("socket.error not raised") + self.assertEqual(cm.exception.errno, errno.ECONNREFUSED) def test_create_connection(self): # Issue #9792: errors raised by create_connection() should have # a proper errno attribute. port = support.find_unused_port() - try: + with self.assertRaises(socket.error) as cm: socket.create_connection((HOST, port)) - except socket.error as err: - self.assertEqual(err.errno, errno.ECONNREFUSED) - else: - self.fail("socket.error not raised") + self.assertEqual(cm.exception.errno, errno.ECONNREFUSED) def test_create_connection_timeout(self): # Issue #9792: create_connection() should not recast timeout errors @@ -1196,6 +1564,7 @@ class NetworkConnectionNoServer(unittest.TestCase): socket.create_connection((HOST, 1234)) +@unittest.skipUnless(thread, 'Threading required for this test.') class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest): def __init__(self, methodName='runTest'): @@ -1203,7 +1572,7 @@ class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest): ThreadableTest.__init__(self) def clientSetUp(self): - pass + self.source_port = support.find_unused_port() def clientTearDown(self): self.cli.close() @@ -1212,12 +1581,23 @@ class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest): def _justAccept(self): conn, addr = self.serv.accept() + conn.close() testFamily = _justAccept def _testFamily(self): self.cli = socket.create_connection((HOST, self.port), timeout=30) + self.addCleanup(self.cli.close) self.assertEqual(self.cli.family, 2) + testSourceAddress = _justAccept + def _testSourceAddress(self): + self.cli = socket.create_connection((HOST, self.port), timeout=30, + source_address=('', self.source_port)) + self.addCleanup(self.cli.close) + self.assertEqual(self.cli.getsockname()[1], self.source_port) + # The port number being used is sufficient to show that the bind() + # call happened. + testTimeoutDefault = _justAccept def _testTimeoutDefault(self): # passing no explicit timeout uses socket's global default @@ -1225,6 +1605,7 @@ class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest): socket.setdefaulttimeout(42) try: self.cli = socket.create_connection((HOST, self.port)) + self.addCleanup(self.cli.close) finally: socket.setdefaulttimeout(None) self.assertEqual(self.cli.gettimeout(), 42) @@ -1236,6 +1617,7 @@ class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest): socket.setdefaulttimeout(30) try: self.cli = socket.create_connection((HOST, self.port), timeout=None) + self.addCleanup(self.cli.close) finally: socket.setdefaulttimeout(None) self.assertEqual(self.cli.gettimeout(), None) @@ -1248,8 +1630,10 @@ class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest): testTimeoutValueNonamed = _justAccept def _testTimeoutValueNonamed(self): self.cli = socket.create_connection((HOST, self.port), 30) + self.addCleanup(self.cli.close) self.assertEqual(self.cli.gettimeout(), 30) +@unittest.skipUnless(thread, 'Threading required for this test.') class NetworkConnectionBehaviourTest(SocketTCPTest, ThreadableTest): def __init__(self, methodName='runTest'): @@ -1266,6 +1650,7 @@ class NetworkConnectionBehaviourTest(SocketTCPTest, ThreadableTest): def testInsideTimeout(self): conn, addr = self.serv.accept() + self.addCleanup(conn.close) time.sleep(3) conn.send(b"done!") testOutsideTimeout = testInsideTimeout @@ -1374,27 +1759,28 @@ class TestLinuxAbstractNamespace(unittest.TestCase): def testLinuxAbstractNamespace(self): address = b"\x00python-test-hello\x00\xff" - s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - s1.bind(address) - s1.listen(1) - s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - s2.connect(s1.getsockname()) - s1.accept() - self.assertEqual(s1.getsockname(), address) - self.assertEqual(s2.getpeername(), address) + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s1: + s1.bind(address) + s1.listen(1) + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s2: + s2.connect(s1.getsockname()) + with s1.accept()[0] as s3: + self.assertEqual(s1.getsockname(), address) + self.assertEqual(s2.getpeername(), address) def testMaxName(self): address = b"\x00" + b"h" * (self.UNIX_PATH_MAX - 1) - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - s.bind(address) - self.assertEqual(s.getsockname(), address) + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + s.bind(address) + self.assertEqual(s.getsockname(), address) def testNameOverflow(self): address = "\x00" + "h" * self.UNIX_PATH_MAX - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.assertRaises(socket.error, s.bind, address) + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + self.assertRaises(socket.error, s.bind, address) +@unittest.skipUnless(thread, 'Threading required for this test.') class BufferIOTest(SocketConnectedTest): """ Test the buffer versions of socket.recv() and socket.send(). @@ -1402,28 +1788,64 @@ class BufferIOTest(SocketConnectedTest): def __init__(self, methodName='runTest'): SocketConnectedTest.__init__(self, methodName=methodName) - def testRecvInto(self): + def testRecvIntoArray(self): buf = bytearray(1024) nbytes = self.cli_conn.recv_into(buf) self.assertEqual(nbytes, len(MSG)) msg = buf[:len(MSG)] self.assertEqual(msg, MSG) - def _testRecvInto(self): + def _testRecvIntoArray(self): buf = bytes(MSG) self.serv_conn.send(buf) - def testRecvFromInto(self): + def testRecvIntoBytearray(self): + buf = bytearray(1024) + nbytes = self.cli_conn.recv_into(buf) + self.assertEqual(nbytes, len(MSG)) + msg = buf[:len(MSG)] + self.assertEqual(msg, MSG) + + _testRecvIntoBytearray = _testRecvIntoArray + + def testRecvIntoMemoryview(self): + buf = bytearray(1024) + nbytes = self.cli_conn.recv_into(memoryview(buf)) + self.assertEqual(nbytes, len(MSG)) + msg = buf[:len(MSG)] + self.assertEqual(msg, MSG) + + _testRecvIntoMemoryview = _testRecvIntoArray + + def testRecvFromIntoArray(self): buf = bytearray(1024) nbytes, addr = self.cli_conn.recvfrom_into(buf) self.assertEqual(nbytes, len(MSG)) msg = buf[:len(MSG)] self.assertEqual(msg, MSG) - def _testRecvFromInto(self): + def _testRecvFromIntoArray(self): buf = bytes(MSG) self.serv_conn.send(buf) + def testRecvFromIntoBytearray(self): + buf = bytearray(1024) + nbytes, addr = self.cli_conn.recvfrom_into(buf) + self.assertEqual(nbytes, len(MSG)) + msg = buf[:len(MSG)] + self.assertEqual(msg, MSG) + + _testRecvFromIntoBytearray = _testRecvFromIntoArray + + def testRecvFromIntoMemoryview(self): + buf = bytearray(1024) + nbytes, addr = self.cli_conn.recvfrom_into(memoryview(buf)) + self.assertEqual(nbytes, len(MSG)) + msg = buf[:len(MSG)] + self.assertEqual(msg, MSG) + + _testRecvFromIntoMemoryview = _testRecvFromIntoArray + TIPC_STYPE = 2000 TIPC_LOWER = 200 @@ -1503,15 +1925,122 @@ class TIPCThreadableTest (unittest.TestCase, ThreadableTest): self.cli.close() +@unittest.skipUnless(thread, 'Threading required for this test.') +class ContextManagersTest(ThreadedTCPSocketTest): + + def _testSocketClass(self): + # base test + with socket.socket() as sock: + self.assertFalse(sock._closed) + self.assertTrue(sock._closed) + # close inside with block + with socket.socket() as sock: + sock.close() + self.assertTrue(sock._closed) + # exception inside with block + with socket.socket() as sock: + self.assertRaises(socket.error, sock.sendall, b'foo') + self.assertTrue(sock._closed) + + def testCreateConnectionBase(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + data = conn.recv(1024) + conn.sendall(data) + + def _testCreateConnectionBase(self): + address = self.serv.getsockname() + with socket.create_connection(address) as sock: + self.assertFalse(sock._closed) + sock.sendall(b'foo') + self.assertEqual(sock.recv(1024), b'foo') + self.assertTrue(sock._closed) + + def testCreateConnectionClose(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + data = conn.recv(1024) + conn.sendall(data) + + def _testCreateConnectionClose(self): + address = self.serv.getsockname() + with socket.create_connection(address) as sock: + sock.close() + self.assertTrue(sock._closed) + self.assertRaises(socket.error, sock.sendall, b'foo') + + +@unittest.skipUnless(hasattr(socket, "SOCK_CLOEXEC"), + "SOCK_CLOEXEC not defined") +@unittest.skipUnless(fcntl, "module fcntl not available") +class CloexecConstantTest(unittest.TestCase): + def test_SOCK_CLOEXEC(self): + v = linux_version() + if v < (2, 6, 28): + self.skipTest("Linux kernel 2.6.28 or higher required, not %s" + % ".".join(map(str, v))) + with socket.socket(socket.AF_INET, + socket.SOCK_STREAM | socket.SOCK_CLOEXEC) as s: + self.assertTrue(s.type & socket.SOCK_CLOEXEC) + self.assertTrue(fcntl.fcntl(s, fcntl.F_GETFD) & fcntl.FD_CLOEXEC) + + +@unittest.skipUnless(hasattr(socket, "SOCK_NONBLOCK"), + "SOCK_NONBLOCK not defined") +class NonblockConstantTest(unittest.TestCase): + def checkNonblock(self, s, nonblock=True, timeout=0.0): + if nonblock: + self.assertTrue(s.type & socket.SOCK_NONBLOCK) + self.assertEqual(s.gettimeout(), timeout) + else: + self.assertFalse(s.type & socket.SOCK_NONBLOCK) + self.assertEqual(s.gettimeout(), None) + + def test_SOCK_NONBLOCK(self): + v = linux_version() + if v < (2, 6, 28): + self.skipTest("Linux kernel 2.6.28 or higher required, not %s" + % ".".join(map(str, v))) + # a lot of it seems silly and redundant, but I wanted to test that + # changing back and forth worked ok + with socket.socket(socket.AF_INET, + socket.SOCK_STREAM | socket.SOCK_NONBLOCK) as s: + self.checkNonblock(s) + s.setblocking(1) + self.checkNonblock(s, False) + s.setblocking(0) + self.checkNonblock(s) + s.settimeout(None) + self.checkNonblock(s, False) + s.settimeout(2.0) + self.checkNonblock(s, timeout=2.0) + s.setblocking(1) + self.checkNonblock(s, False) + # defaulttimeout + t = socket.getdefaulttimeout() + socket.setdefaulttimeout(0.0) + with socket.socket() as s: + self.checkNonblock(s) + socket.setdefaulttimeout(None) + with socket.socket() as s: + self.checkNonblock(s, False) + socket.setdefaulttimeout(2.0) + with socket.socket() as s: + self.checkNonblock(s, timeout=2.0) + socket.setdefaulttimeout(None) + with socket.socket() as s: + self.checkNonblock(s, False) + socket.setdefaulttimeout(t) + + def test_main(): tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, - TestExceptions, BufferIOTest, BasicTCPTest2] - if sys.platform != 'mac': - tests.extend([ BasicUDPTest, UDPTimeoutTest ]) + TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ] tests.extend([ NonBlockingTCPTests, FileObjectClassTestCase, + FileObjectInterruptedTestCase, UnbufferedFileObjectClassTestCase, LineBufferedFileObjectClassTestCase, SmallBufferedFileObjectClassTestCase, @@ -1521,6 +2050,9 @@ def test_main(): NetworkConnectionNoServer, NetworkConnectionAttributesTest, NetworkConnectionBehaviourTest, + ContextManagersTest, + CloexecConstantTest, + NonblockConstantTest ]) if hasattr(socket, "socketpair"): tests.append(BasicSocketPairTest) |