diff options
author | Giampaolo Rodola <g.rodola@gmail.com> | 2019-04-08 22:34:02 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-04-08 22:34:02 (GMT) |
commit | eb7e29f2a9d075accc1ab3faf3612ac44f5e2183 (patch) | |
tree | 6d4d31556465bc34e12de0ebc98dd751cb0fc09a /Lib | |
parent | 58721a903074d28151d008d8990c98fc31d1e798 (diff) | |
download | cpython-eb7e29f2a9d075accc1ab3faf3612ac44f5e2183.zip cpython-eb7e29f2a9d075accc1ab3faf3612ac44f5e2183.tar.gz cpython-eb7e29f2a9d075accc1ab3faf3612ac44f5e2183.tar.bz2 |
bpo-35934: Add socket.create_server() utility function (GH-11784)
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/ftplib.py | 21 | ||||
-rw-r--r-- | Lib/socket.py | 87 | ||||
-rw-r--r-- | Lib/test/_test_multiprocessing.py | 8 | ||||
-rw-r--r-- | Lib/test/eintrdata/eintr_tester.py | 5 | ||||
-rw-r--r-- | Lib/test/test_asyncio/functional.py | 10 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_events.py | 12 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_streams.py | 11 | ||||
-rw-r--r-- | Lib/test/test_epoll.py | 4 | ||||
-rw-r--r-- | Lib/test/test_ftplib.py | 9 | ||||
-rw-r--r-- | Lib/test/test_httplib.py | 5 | ||||
-rw-r--r-- | Lib/test/test_kqueue.py | 4 | ||||
-rw-r--r-- | Lib/test/test_socket.py | 126 | ||||
-rw-r--r-- | Lib/test/test_ssl.py | 11 | ||||
-rw-r--r-- | Lib/test/test_support.py | 6 |
14 files changed, 232 insertions, 87 deletions
diff --git a/Lib/ftplib.py b/Lib/ftplib.py index 9611282..a9b1aee 100644 --- a/Lib/ftplib.py +++ b/Lib/ftplib.py @@ -302,26 +302,7 @@ class FTP: def makeport(self): '''Create a new socket and send a PORT command for it.''' - err = None - sock = None - for res in socket.getaddrinfo(None, 0, self.af, socket.SOCK_STREAM, 0, socket.AI_PASSIVE): - af, socktype, proto, canonname, sa = res - try: - sock = socket.socket(af, socktype, proto) - sock.bind(sa) - except OSError as _: - err = _ - if sock: - sock.close() - sock = None - continue - break - if sock is None: - if err is not None: - raise err - else: - raise OSError("getaddrinfo returns an empty list") - sock.listen(1) + sock = socket.create_server(("", 0), family=self.af, backlog=1) port = sock.getsockname()[1] # Get proper port host = self.sock.getsockname()[0] # Get proper host if self.af == socket.AF_INET: diff --git a/Lib/socket.py b/Lib/socket.py index 772b9e1..2e51cd1 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -60,8 +60,8 @@ EBADF = getattr(errno, 'EBADF', 9) EAGAIN = getattr(errno, 'EAGAIN', 11) EWOULDBLOCK = getattr(errno, 'EWOULDBLOCK', 11) -__all__ = ["fromfd", "getfqdn", "create_connection", - "AddressFamily", "SocketKind"] +__all__ = ["fromfd", "getfqdn", "create_connection", "create_server", + "has_dualstack_ipv6", "AddressFamily", "SocketKind"] __all__.extend(os._get_exports_list(_socket)) # Set up the socket.AF_* socket.SOCK_* constants as members of IntEnums for @@ -728,6 +728,89 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, else: raise error("getaddrinfo returns an empty list") + +def has_dualstack_ipv6(): + """Return True if the platform supports creating a SOCK_STREAM socket + which can handle both AF_INET and AF_INET6 (IPv4 / IPv6) connections. + """ + if not has_ipv6 \ + or not hasattr(_socket, 'IPPROTO_IPV6') \ + or not hasattr(_socket, 'IPV6_V6ONLY'): + return False + try: + with socket(AF_INET6, SOCK_STREAM) as sock: + sock.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 0) + return True + except error: + return False + + +def create_server(address, *, family=AF_INET, backlog=0, reuse_port=False, + dualstack_ipv6=False): + """Convenience function which creates a SOCK_STREAM type socket + bound to *address* (a 2-tuple (host, port)) and return the socket + object. + + *family* should be either AF_INET or AF_INET6. + *backlog* is the queue size passed to socket.listen(). + *reuse_port* dictates whether to use the SO_REUSEPORT socket option. + *dualstack_ipv6*: if true and the platform supports it, it will + create an AF_INET6 socket able to accept both IPv4 or IPv6 + connections. When false it will explicitly disable this option on + platforms that enable it by default (e.g. Linux). + + >>> with create_server((None, 8000)) as server: + ... while True: + ... conn, addr = server.accept() + ... # handle new connection + """ + if reuse_port and not hasattr(_socket, "SO_REUSEPORT"): + raise ValueError("SO_REUSEPORT not supported on this platform") + if dualstack_ipv6: + if not has_dualstack_ipv6(): + raise ValueError("dualstack_ipv6 not supported on this platform") + if family != AF_INET6: + raise ValueError("dualstack_ipv6 requires AF_INET6 family") + sock = socket(family, SOCK_STREAM) + try: + # Note about Windows. We don't set SO_REUSEADDR because: + # 1) It's unnecessary: bind() will succeed even in case of a + # previous closed socket on the same address and still in + # TIME_WAIT state. + # 2) If set, another socket is free to bind() on the same + # address, effectively preventing this one from accepting + # connections. Also, it may set the process in a state where + # it'll no longer respond to any signals or graceful kills. + # See: msdn2.microsoft.com/en-us/library/ms740621(VS.85).aspx + if os.name not in ('nt', 'cygwin') and \ + hasattr(_socket, 'SO_REUSEADDR'): + try: + sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) + except error: + # Fail later on bind(), for platforms which may not + # support this option. + pass + if reuse_port: + sock.setsockopt(SOL_SOCKET, SO_REUSEPORT, 1) + if has_ipv6 and family == AF_INET6: + if dualstack_ipv6: + sock.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 0) + elif hasattr(_socket, "IPV6_V6ONLY") and \ + hasattr(_socket, "IPPROTO_IPV6"): + sock.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1) + try: + sock.bind(address) + except error as err: + msg = '%s (while attempting to bind on address %r)' % \ + (err.strerror, address) + raise error(err.errno, msg) from None + sock.listen(backlog) + return sock + except error: + sock.close() + raise + + def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): """Resolve host and port into list of address info entries. diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index f4239ba..553ab81 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -3334,9 +3334,7 @@ class _TestPicklingConnections(BaseTestCase): new_conn.close() l.close() - l = socket.socket() - l.bind((test.support.HOST, 0)) - l.listen() + l = socket.create_server((test.support.HOST, 0)) conn.send(l.getsockname()) new_conn, addr = l.accept() conn.send(new_conn) @@ -4345,9 +4343,7 @@ class TestWait(unittest.TestCase): def test_wait_socket(self, slow=False): from multiprocessing.connection import wait - l = socket.socket() - l.bind((test.support.HOST, 0)) - l.listen() + l = socket.create_server((test.support.HOST, 0)) addr = l.getsockname() readers = [] procs = [] diff --git a/Lib/test/eintrdata/eintr_tester.py b/Lib/test/eintrdata/eintr_tester.py index 5f956b5..404934c 100644 --- a/Lib/test/eintrdata/eintr_tester.py +++ b/Lib/test/eintrdata/eintr_tester.py @@ -285,12 +285,9 @@ class SocketEINTRTest(EINTRBaseTest): self._test_send(lambda sock, data: sock.sendmsg([data])) def test_accept(self): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock = socket.create_server((support.HOST, 0)) self.addCleanup(sock.close) - - sock.bind((support.HOST, 0)) port = sock.getsockname()[1] - sock.listen() code = '\n'.join(( 'import socket, time', diff --git a/Lib/test/test_asyncio/functional.py b/Lib/test/test_asyncio/functional.py index 6b5b3cc..70cd140 100644 --- a/Lib/test/test_asyncio/functional.py +++ b/Lib/test/test_asyncio/functional.py @@ -60,21 +60,13 @@ class FunctionalTestCaseMixin: else: addr = ('127.0.0.1', 0) - sock = socket.socket(family, socket.SOCK_STREAM) - + sock = socket.create_server(addr, family=family, backlog=backlog) if timeout is None: raise RuntimeError('timeout is required') if timeout <= 0: raise RuntimeError('only blocking sockets are supported') sock.settimeout(timeout) - try: - sock.bind(addr) - sock.listen(backlog) - except OSError as ex: - sock.close() - raise ex - return TestThreadedServer( self, sock, server_prog, timeout, max_clients) diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index a2b954e..b46b614 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -667,9 +667,7 @@ class EventLoopTestsMixin: super().data_received(data) self.transport.write(expected_response) - lsock = socket.socket() - lsock.bind(('127.0.0.1', 0)) - lsock.listen(1) + lsock = socket.create_server(('127.0.0.1', 0), backlog=1) addr = lsock.getsockname() message = b'test data' @@ -1118,9 +1116,7 @@ class EventLoopTestsMixin: super().connection_made(transport) proto.set_result(self) - sock_ob = socket.socket(type=socket.SOCK_STREAM) - sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock_ob.bind(('0.0.0.0', 0)) + sock_ob = socket.create_server(('0.0.0.0', 0)) f = self.loop.create_server(TestMyProto, sock=sock_ob) server = self.loop.run_until_complete(f) @@ -1136,9 +1132,7 @@ class EventLoopTestsMixin: server.close() def test_create_server_addr_in_use(self): - sock_ob = socket.socket(type=socket.SOCK_STREAM) - sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock_ob.bind(('0.0.0.0', 0)) + sock_ob = socket.create_server(('0.0.0.0', 0)) f = self.loop.create_server(MyProto, sock=sock_ob) server = self.loop.run_until_complete(f) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 043fac7..630f91d 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -592,8 +592,7 @@ class StreamTests(test_utils.TestCase): await client_writer.wait_closed() def start(self): - sock = socket.socket() - sock.bind(('127.0.0.1', 0)) + sock = socket.create_server(('127.0.0.1', 0)) self.server = self.loop.run_until_complete( asyncio.start_server(self.handle_client, sock=sock, @@ -605,8 +604,7 @@ class StreamTests(test_utils.TestCase): client_writer)) def start_callback(self): - sock = socket.socket() - sock.bind(('127.0.0.1', 0)) + sock = socket.create_server(('127.0.0.1', 0)) addr = sock.getsockname() sock.close() self.server = self.loop.run_until_complete( @@ -796,10 +794,7 @@ os.close(fd) def server(): # Runs in a separate thread. - sock = socket.socket() - with sock: - sock.bind(('localhost', 0)) - sock.listen(1) + with socket.create_server(('localhost', 0)) as sock: addr = sock.getsockname() q.put(addr) clt, _ = sock.accept() diff --git a/Lib/test/test_epoll.py b/Lib/test/test_epoll.py index 53ce1d5..8ac0f31 100644 --- a/Lib/test/test_epoll.py +++ b/Lib/test/test_epoll.py @@ -41,9 +41,7 @@ except OSError as e: class TestEPoll(unittest.TestCase): def setUp(self): - self.serverSocket = socket.socket() - self.serverSocket.bind(('127.0.0.1', 0)) - self.serverSocket.listen() + self.serverSocket = socket.create_server(('127.0.0.1', 0)) self.connections = [self.serverSocket] def tearDown(self): diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py index da8ba32..b0e4641 100644 --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -132,9 +132,7 @@ class DummyFTPHandler(asynchat.async_chat): self.push('200 active data connection established') def cmd_pasv(self, arg): - with socket.socket() as sock: - sock.bind((self.socket.getsockname()[0], 0)) - sock.listen() + with socket.create_server((self.socket.getsockname()[0], 0)) as sock: sock.settimeout(TIMEOUT) ip, port = sock.getsockname()[:2] ip = ip.replace('.', ','); p1 = port / 256; p2 = port % 256 @@ -150,9 +148,8 @@ class DummyFTPHandler(asynchat.async_chat): self.push('200 active data connection established') def cmd_epsv(self, arg): - with socket.socket(socket.AF_INET6) as sock: - sock.bind((self.socket.getsockname()[0], 0)) - sock.listen() + with socket.create_server((self.socket.getsockname()[0], 0), + family=socket.AF_INET6) as sock: sock.settimeout(TIMEOUT) port = sock.getsockname()[1] self.push('229 entering extended passive mode (|||%d|)' %port) diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py index 4755f8b..6591461 100644 --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -1118,11 +1118,8 @@ class BasicTest(TestCase): def test_response_fileno(self): # Make sure fd returned by fileno is valid. - serv = socket.socket( - socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) + serv = socket.create_server((HOST, 0)) self.addCleanup(serv.close) - serv.bind((HOST, 0)) - serv.listen() result = None def run_server(): diff --git a/Lib/test/test_kqueue.py b/Lib/test/test_kqueue.py index 1099c75..998fd9d 100644 --- a/Lib/test/test_kqueue.py +++ b/Lib/test/test_kqueue.py @@ -110,9 +110,7 @@ class TestKQueue(unittest.TestCase): def test_queue_event(self): - serverSocket = socket.socket() - serverSocket.bind(('127.0.0.1', 0)) - serverSocket.listen() + serverSocket = socket.create_server(('127.0.0.1', 0)) client = socket.socket() client.setblocking(False) try: diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 8a990ea..b0bdb11 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -6068,9 +6068,133 @@ class TestMSWindowsTCPFlags(unittest.TestCase): self.assertEqual([], unknown, "New TCP flags were discovered. See bpo-32394 for more information") + +class CreateServerTest(unittest.TestCase): + + def test_address(self): + port = support.find_unused_port() + with socket.create_server(("127.0.0.1", port)) as sock: + self.assertEqual(sock.getsockname()[0], "127.0.0.1") + self.assertEqual(sock.getsockname()[1], port) + if support.IPV6_ENABLED: + with socket.create_server(("::1", port), + family=socket.AF_INET6) as sock: + self.assertEqual(sock.getsockname()[0], "::1") + self.assertEqual(sock.getsockname()[1], port) + + def test_family_and_type(self): + with socket.create_server(("127.0.0.1", 0)) as sock: + self.assertEqual(sock.family, socket.AF_INET) + self.assertEqual(sock.type, socket.SOCK_STREAM) + if support.IPV6_ENABLED: + with socket.create_server(("::1", 0), family=socket.AF_INET6) as s: + self.assertEqual(s.family, socket.AF_INET6) + self.assertEqual(sock.type, socket.SOCK_STREAM) + + def test_reuse_port(self): + if not hasattr(socket, "SO_REUSEPORT"): + with self.assertRaises(ValueError): + socket.create_server(("localhost", 0), reuse_port=True) + else: + with socket.create_server(("localhost", 0)) as sock: + opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) + self.assertEqual(opt, 0) + with socket.create_server(("localhost", 0), reuse_port=True) as sock: + opt = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) + self.assertNotEqual(opt, 0) + + @unittest.skipIf(not hasattr(_socket, 'IPPROTO_IPV6') or + not hasattr(_socket, 'IPV6_V6ONLY'), + "IPV6_V6ONLY option not supported") + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test') + def test_ipv6_only_default(self): + with socket.create_server(("::1", 0), family=socket.AF_INET6) as sock: + assert sock.getsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY) + + @unittest.skipIf(not socket.has_dualstack_ipv6(), + "dualstack_ipv6 not supported") + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test') + def test_dualstack_ipv6_family(self): + with socket.create_server(("::1", 0), family=socket.AF_INET6, + dualstack_ipv6=True) as sock: + self.assertEqual(sock.family, socket.AF_INET6) + + +class CreateServerFunctionalTest(unittest.TestCase): + timeout = 3 + + def setUp(self): + self.thread = None + + def tearDown(self): + if self.thread is not None: + self.thread.join(self.timeout) + + def echo_server(self, sock): + def run(sock): + with sock: + conn, _ = sock.accept() + with conn: + event.wait(self.timeout) + msg = conn.recv(1024) + if not msg: + return + conn.sendall(msg) + + event = threading.Event() + sock.settimeout(self.timeout) + self.thread = threading.Thread(target=run, args=(sock, )) + self.thread.start() + event.set() + + def echo_client(self, addr, family): + with socket.socket(family=family) as sock: + sock.settimeout(self.timeout) + sock.connect(addr) + sock.sendall(b'foo') + self.assertEqual(sock.recv(1024), b'foo') + + def test_tcp4(self): + port = support.find_unused_port() + with socket.create_server(("", port)) as sock: + self.echo_server(sock) + self.echo_client(("127.0.0.1", port), socket.AF_INET) + + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test') + def test_tcp6(self): + port = support.find_unused_port() + with socket.create_server(("", port), + family=socket.AF_INET6) as sock: + self.echo_server(sock) + self.echo_client(("::1", port), socket.AF_INET6) + + # --- dual stack tests + + @unittest.skipIf(not socket.has_dualstack_ipv6(), + "dualstack_ipv6 not supported") + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test') + def test_dual_stack_client_v4(self): + port = support.find_unused_port() + with socket.create_server(("", port), family=socket.AF_INET6, + dualstack_ipv6=True) as sock: + self.echo_server(sock) + self.echo_client(("127.0.0.1", port), socket.AF_INET) + + @unittest.skipIf(not socket.has_dualstack_ipv6(), + "dualstack_ipv6 not supported") + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test') + def test_dual_stack_client_v6(self): + port = support.find_unused_port() + with socket.create_server(("", port), family=socket.AF_INET6, + dualstack_ipv6=True) as sock: + self.echo_server(sock) + self.echo_client(("::1", port), socket.AF_INET6) + + def test_main(): tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, - TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ] + TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, + UDPTimeoutTest, CreateServerTest, CreateServerFunctionalTest] tests.extend([ NonBlockingTCPTests, diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 5571822..4444e94 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -765,9 +765,7 @@ class BasicSocketTests(unittest.TestCase): def test_unknown_channel_binding(self): # should raise ValueError for unknown type - s = socket.socket(socket.AF_INET) - s.bind(('127.0.0.1', 0)) - s.listen() + s = socket.create_server(('127.0.0.1', 0)) c = socket.socket(socket.AF_INET) c.connect(s.getsockname()) with test_wrap_socket(c, do_handshake_on_connect=False) as ss: @@ -1663,11 +1661,8 @@ class SSLErrorTests(unittest.TestCase): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - s.listen() - c = socket.socket() - c.connect(s.getsockname()) + with socket.create_server(("127.0.0.1", 0)) as s: + c = socket.create_connection(s.getsockname()) c.setblocking(False) with ctx.wrap_socket(c, False, do_handshake_on_connect=False) as c: with self.assertRaises(ssl.SSLWantReadError) as cm: diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index 4a8f3c5..cb664ba 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -91,14 +91,12 @@ class TestSupport(unittest.TestCase): support.rmtree('__pycache__') def test_HOST(self): - s = socket.socket() - s.bind((support.HOST, 0)) + s = socket.create_server((support.HOST, 0)) s.close() def test_find_unused_port(self): port = support.find_unused_port() - s = socket.socket() - s.bind((support.HOST, port)) + s = socket.create_server((support.HOST, port)) s.close() def test_bind_port(self): |