diff options
author | Giampaolo Rodola <g.rodola@gmail.com> | 2019-04-08 22:34:02 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-04-08 22:34:02 (GMT) |
commit | eb7e29f2a9d075accc1ab3faf3612ac44f5e2183 (patch) | |
tree | 6d4d31556465bc34e12de0ebc98dd751cb0fc09a /Lib/test | |
parent | 58721a903074d28151d008d8990c98fc31d1e798 (diff) | |
download | cpython-eb7e29f2a9d075accc1ab3faf3612ac44f5e2183.zip cpython-eb7e29f2a9d075accc1ab3faf3612ac44f5e2183.tar.gz cpython-eb7e29f2a9d075accc1ab3faf3612ac44f5e2183.tar.bz2 |
bpo-35934: Add socket.create_server() utility function (GH-11784)
Diffstat (limited to 'Lib/test')
-rw-r--r-- | Lib/test/_test_multiprocessing.py | 8 | ||||
-rw-r--r-- | Lib/test/eintrdata/eintr_tester.py | 5 | ||||
-rw-r--r-- | Lib/test/test_asyncio/functional.py | 10 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_events.py | 12 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_streams.py | 11 | ||||
-rw-r--r-- | Lib/test/test_epoll.py | 4 | ||||
-rw-r--r-- | Lib/test/test_ftplib.py | 9 | ||||
-rw-r--r-- | Lib/test/test_httplib.py | 5 | ||||
-rw-r--r-- | Lib/test/test_kqueue.py | 4 | ||||
-rw-r--r-- | Lib/test/test_socket.py | 126 | ||||
-rw-r--r-- | Lib/test/test_ssl.py | 11 | ||||
-rw-r--r-- | Lib/test/test_support.py | 6 |
12 files changed, 146 insertions, 65 deletions
diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index f4239ba..553ab81 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -3334,9 +3334,7 @@ class _TestPicklingConnections(BaseTestCase): new_conn.close() l.close() - l = socket.socket() - l.bind((test.support.HOST, 0)) - l.listen() + l = socket.create_server((test.support.HOST, 0)) conn.send(l.getsockname()) new_conn, addr = l.accept() conn.send(new_conn) @@ -4345,9 +4343,7 @@ class TestWait(unittest.TestCase): def test_wait_socket(self, slow=False): from multiprocessing.connection import wait - l = socket.socket() - l.bind((test.support.HOST, 0)) - l.listen() + l = socket.create_server((test.support.HOST, 0)) addr = l.getsockname() readers = [] procs = [] diff --git a/Lib/test/eintrdata/eintr_tester.py b/Lib/test/eintrdata/eintr_tester.py index 5f956b5..404934c 100644 --- a/Lib/test/eintrdata/eintr_tester.py +++ b/Lib/test/eintrdata/eintr_tester.py @@ -285,12 +285,9 @@ class SocketEINTRTest(EINTRBaseTest): self._test_send(lambda sock, data: sock.sendmsg([data])) def test_accept(self): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock = socket.create_server((support.HOST, 0)) self.addCleanup(sock.close) - - sock.bind((support.HOST, 0)) port = sock.getsockname()[1] - sock.listen() code = '\n'.join(( 'import socket, time', diff --git a/Lib/test/test_asyncio/functional.py b/Lib/test/test_asyncio/functional.py index 6b5b3cc..70cd140 100644 --- a/Lib/test/test_asyncio/functional.py +++ b/Lib/test/test_asyncio/functional.py @@ -60,21 +60,13 @@ class FunctionalTestCaseMixin: else: addr = ('127.0.0.1', 0) - sock = socket.socket(family, socket.SOCK_STREAM) - + sock = socket.create_server(addr, family=family, backlog=backlog) if timeout is None: raise RuntimeError('timeout is required') if timeout <= 0: raise RuntimeError('only blocking sockets are supported') sock.settimeout(timeout) - try: - sock.bind(addr) - sock.listen(backlog) - except OSError as ex: - sock.close() - raise ex - return TestThreadedServer( self, sock, server_prog, timeout, max_clients) diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index a2b954e..b46b614 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -667,9 +667,7 @@ class EventLoopTestsMixin: super().data_received(data) self.transport.write(expected_response) - lsock = socket.socket() - lsock.bind(('127.0.0.1', 0)) - lsock.listen(1) + lsock = socket.create_server(('127.0.0.1', 0), backlog=1) addr = lsock.getsockname() message = b'test data' @@ -1118,9 +1116,7 @@ class EventLoopTestsMixin: super().connection_made(transport) proto.set_result(self) - sock_ob = socket.socket(type=socket.SOCK_STREAM) - sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock_ob.bind(('0.0.0.0', 0)) + sock_ob = socket.create_server(('0.0.0.0', 0)) f = self.loop.create_server(TestMyProto, sock=sock_ob) server = self.loop.run_until_complete(f) @@ -1136,9 +1132,7 @@ class EventLoopTestsMixin: server.close() def test_create_server_addr_in_use(self): - sock_ob = socket.socket(type=socket.SOCK_STREAM) - sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock_ob.bind(('0.0.0.0', 0)) + sock_ob = socket.create_server(('0.0.0.0', 0)) f = self.loop.create_server(MyProto, sock=sock_ob) server = self.loop.run_until_complete(f) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 043fac7..630f91d 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -592,8 +592,7 @@ class StreamTests(test_utils.TestCase): await client_writer.wait_closed() def start(self): - sock = socket.socket() - sock.bind(('127.0.0.1', 0)) + sock = socket.create_server(('127.0.0.1', 0)) self.server = self.loop.run_until_complete( asyncio.start_server(self.handle_client, sock=sock, @@ -605,8 +604,7 @@ class StreamTests(test_utils.TestCase): client_writer)) def start_callback(self): - sock = socket.socket() - sock.bind(('127.0.0.1', 0)) + sock = socket.create_server(('127.0.0.1', 0)) addr = sock.getsockname() sock.close() self.server = self.loop.run_until_complete( @@ -796,10 +794,7 @@ os.close(fd) def server(): # Runs in a separate thread. - sock = socket.socket() - with sock: - sock.bind(('localhost', 0)) - sock.listen(1) + with socket.create_server(('localhost', 0)) as sock: addr = sock.getsockname() q.put(addr) clt, _ = sock.accept() diff --git a/Lib/test/test_epoll.py b/Lib/test/test_epoll.py index 53ce1d5..8ac0f31 100644 --- a/Lib/test/test_epoll.py +++ b/Lib/test/test_epoll.py @@ -41,9 +41,7 @@ except OSError as e: class TestEPoll(unittest.TestCase): def setUp(self): - self.serverSocket = socket.socket() - self.serverSocket.bind(('127.0.0.1', 0)) - self.serverSocket.listen() + self.serverSocket = socket.create_server(('127.0.0.1', 0)) self.connections = [self.serverSocket] def tearDown(self): diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py index da8ba32..b0e4641 100644 --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -132,9 +132,7 @@ class DummyFTPHandler(asynchat.async_chat): self.push('200 active data connection established') def cmd_pasv(self, arg): - with socket.socket() as sock: - sock.bind((self.socket.getsockname()[0], 0)) - sock.listen() + with socket.create_server((self.socket.getsockname()[0], 0)) as sock: sock.settimeout(TIMEOUT) ip, port = sock.getsockname()[:2] ip = ip.replace('.', ','); p1 = port / 256; p2 = port % 256 @@ -150,9 +148,8 @@ class DummyFTPHandler(asynchat.async_chat): self.push('200 active data connection established') def cmd_epsv(self, arg): - with socket.socket(socket.AF_INET6) as sock: - sock.bind((self.socket.getsockname()[0], 0)) - sock.listen() + with socket.create_server((self.socket.getsockname()[0], 0), + family=socket.AF_INET6) as sock: sock.settimeout(TIMEOUT) port = sock.getsockname()[1] self.push('229 entering extended passive mode (|||%d|)' %port) diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py index 4755f8b..6591461 100644 --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -1118,11 +1118,8 @@ class BasicTest(TestCase): def test_response_fileno(self): # Make sure fd returned by fileno is valid. - serv = socket.socket( - socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) + serv = socket.create_server((HOST, 0)) self.addCleanup(serv.close) - serv.bind((HOST, 0)) - serv.listen() result = None def run_server(): diff --git a/Lib/test/test_kqueue.py b/Lib/test/test_kqueue.py index 1099c75..998fd9d 100644 --- a/Lib/test/test_kqueue.py +++ b/Lib/test/test_kqueue.py @@ -110,9 +110,7 @@ class TestKQueue(unittest.TestCase): def test_queue_event(self): - serverSocket = socket.socket() - serverSocket.bind(('127.0.0.1', 0)) - serverSocket.listen() + serverSocket = socket.create_server(('127.0.0.1', 0)) client = socket.socket() client.setblocking(False) try: diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 8a990ea..b0bdb11 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -6068,9 +6068,133 @@ class TestMSWindowsTCPFlags(unittest.TestCase): self.assertEqual([], unknown, "New TCP flags were discovered. See bpo-32394 for more information") + +class CreateServerTest(unittest.TestCase): + + def test_address(self): + port = support.find_unused_port() + with socket.create_server(("127.0.0.1", port)) as sock: + self.assertEqual(sock.getsockname()[0], "127.0.0.1") + self.assertEqual(sock.getsockname()[1], port) + if support.IPV6_ENABLED: + with socket.create_server(("::1", port), + family=socket.AF_INET6) as sock: + self.assertEqual(sock.getsockname()[0], "::1") + self.assertEqual(sock.getsockname()[1], port) + + def test_family_and_type(self): + with socket.create_server(("127.0.0.1", 0)) as sock: + self.assertEqual(sock.family, socket.AF_INET) + self.assertEqual(sock.type, socket.SOCK_STREAM) + if support.IPV6_ENABLED: + with socket.create_server(("::1", 0), family=socket.AF_INET6) as s: + self.assertEqual(s.family, socket.AF_INET6) + self.assertEqual(sock.type, socket.SOCK_STREAM) + + def test_reuse_port(self): + if not hasattr(socket, "SO_REUSEPORT"): + with self.assertRaises(ValueError): + socket.create_server(("localhost", 0), reuse_port=True) + else: + with socket.create_server(("localhost", 0)) as sock: + opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) + self.assertEqual(opt, 0) + with socket.create_server(("localhost", 0), reuse_port=True) as sock: + opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) + self.assertNotEqual(opt, 0) + + @unittest.skipIf(not hasattr(_socket, 'IPPROTO_IPV6') or + not hasattr(_socket, 'IPV6_V6ONLY'), + "IPV6_V6ONLY option not supported") + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test') + def test_ipv6_only_default(self): + with socket.create_server(("::1", 0), family=socket.AF_INET6) as sock: + assert sock.getsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY) + + @unittest.skipIf(not socket.has_dualstack_ipv6(), + "dualstack_ipv6 not supported") + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test') + def test_dualstack_ipv6_family(self): + with socket.create_server(("::1", 0), family=socket.AF_INET6, + dualstack_ipv6=True) as sock: + self.assertEqual(sock.family, socket.AF_INET6) + + +class CreateServerFunctionalTest(unittest.TestCase): + timeout = 3 + + def setUp(self): + self.thread = None + + def tearDown(self): + if self.thread is not None: + self.thread.join(self.timeout) + + def echo_server(self, sock): + def run(sock): + with sock: + conn, _ = sock.accept() + with conn: + event.wait(self.timeout) + msg = conn.recv(1024) + if not msg: + return + conn.sendall(msg) + + event = threading.Event() + sock.settimeout(self.timeout) + self.thread = threading.Thread(target=run, args=(sock, )) + self.thread.start() + event.set() + + def echo_client(self, addr, family): + with socket.socket(family=family) as sock: + sock.settimeout(self.timeout) + sock.connect(addr) + sock.sendall(b'foo') + self.assertEqual(sock.recv(1024), b'foo') + + def test_tcp4(self): + port = support.find_unused_port() + with socket.create_server(("", port)) as sock: + self.echo_server(sock) + self.echo_client(("127.0.0.1", port), socket.AF_INET) + + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test') + def test_tcp6(self): + port = support.find_unused_port() + with socket.create_server(("", port), + family=socket.AF_INET6) as sock: + self.echo_server(sock) + self.echo_client(("::1", port), socket.AF_INET6) + + # --- dual stack tests + + @unittest.skipIf(not socket.has_dualstack_ipv6(), + "dualstack_ipv6 not supported") + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test') + def test_dual_stack_client_v4(self): + port = support.find_unused_port() + with socket.create_server(("", port), family=socket.AF_INET6, + dualstack_ipv6=True) as sock: + self.echo_server(sock) + self.echo_client(("127.0.0.1", port), socket.AF_INET) + + @unittest.skipIf(not socket.has_dualstack_ipv6(), + "dualstack_ipv6 not supported") + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test') + def test_dual_stack_client_v6(self): + port = support.find_unused_port() + with socket.create_server(("", port), family=socket.AF_INET6, + dualstack_ipv6=True) as sock: + self.echo_server(sock) + self.echo_client(("::1", port), socket.AF_INET6) + + def test_main(): tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, - TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ] + TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, + UDPTimeoutTest, CreateServerTest, CreateServerFunctionalTest] tests.extend([ NonBlockingTCPTests, diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 5571822..4444e94 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -765,9 +765,7 @@ class BasicSocketTests(unittest.TestCase): def test_unknown_channel_binding(self): # should raise ValueError for unknown type - s = socket.socket(socket.AF_INET) - s.bind(('127.0.0.1', 0)) - s.listen() + s = socket.create_server(('127.0.0.1', 0)) c = socket.socket(socket.AF_INET) c.connect(s.getsockname()) with test_wrap_socket(c, do_handshake_on_connect=False) as ss: @@ -1663,11 +1661,8 @@ class SSLErrorTests(unittest.TestCase): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - s.listen() - c = socket.socket() - c.connect(s.getsockname()) + with socket.create_server(("127.0.0.1", 0)) as s: + c = socket.create_connection(s.getsockname()) c.setblocking(False) with ctx.wrap_socket(c, False, do_handshake_on_connect=False) as c: with self.assertRaises(ssl.SSLWantReadError) as cm: diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index 4a8f3c5..cb664ba 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -91,14 +91,12 @@ class TestSupport(unittest.TestCase): support.rmtree('__pycache__') def test_HOST(self): - s = socket.socket() - s.bind((support.HOST, 0)) + s = socket.create_server((support.HOST, 0)) s.close() def test_find_unused_port(self): port = support.find_unused_port() - s = socket.socket() - s.bind((support.HOST, port)) + s = socket.create_server((support.HOST, port)) s.close() def test_bind_port(self): |