diff options
author | Russell Keith-Magee <russell@keith-magee.com> | 2024-07-31 08:24:15 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-31 08:24:15 (GMT) |
commit | f071f01b7b7e19d7d6b3a4b0ec62f820ecb14660 (patch) | |
tree | d4f6a25a18fc0ca8de86794a78cc62646a390d7e | |
parent | d01fd240517e0c5fb679686a93119d0aa6b0fc0f (diff) | |
download | cpython-f071f01b7b7e19d7d6b3a4b0ec62f820ecb14660.zip cpython-f071f01b7b7e19d7d6b3a4b0ec62f820ecb14660.tar.gz cpython-f071f01b7b7e19d7d6b3a4b0ec62f820ecb14660.tar.bz2 |
gh-122133: Rework pure Python socketpair tests to avoid use of importlib.reload. (#122493)
Co-authored-by: Gregory P. Smith <greg@krypto.org>
-rw-r--r-- | Lib/socket.py | 121 | ||||
-rw-r--r-- | Lib/test/test_socket.py | 20 |
2 files changed, 64 insertions, 77 deletions
diff --git a/Lib/socket.py b/Lib/socket.py index 2e6043c..9207101 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -592,16 +592,65 @@ if hasattr(_socket.socket, "share"): return socket(0, 0, 0, info) __all__.append("fromshare") -if hasattr(_socket, "socketpair"): +# Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +# This is used if _socket doesn't natively provide socketpair. It's +# always defined so that it can be patched in for testing purposes. +def _fallback_socketpair(family=AF_INET, type=SOCK_STREAM, proto=0): + if family == AF_INET: + host = _LOCALHOST + elif family == AF_INET6: + host = _LOCALHOST_V6 + else: + raise ValueError("Only AF_INET and AF_INET6 socket address families " + "are supported") + if type != SOCK_STREAM: + raise ValueError("Only SOCK_STREAM socket type is supported") + if proto != 0: + raise ValueError("Only protocol zero is supported") + + # We create a connected TCP socket. Note the trick with + # setblocking(False) that prevents us from having to create a thread. + lsock = socket(family, type, proto) + try: + lsock.bind((host, 0)) + lsock.listen() + # On IPv6, ignore flow_info and scope_id + addr, port = lsock.getsockname()[:2] + csock = socket(family, type, proto) + try: + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + csock.setblocking(True) + ssock, _ = lsock.accept() + except: + csock.close() + raise + finally: + lsock.close() - def socketpair(family=None, type=SOCK_STREAM, proto=0): - """socketpair([family[, type[, proto]]]) -> (socket object, socket object) + # Authenticating avoids using a connection from something else + # able to connect to {host}:{port} instead of us. + # We expect only AF_INET and AF_INET6 families. + try: + if ( + ssock.getsockname() != csock.getpeername() + or csock.getsockname() != ssock.getpeername() + ): + raise ConnectionError("Unexpected peer connection") + except: + # getsockname() and getpeername() can fail + # if either socket isn't connected. + ssock.close() + csock.close() + raise - Create a pair of socket objects from the sockets returned by the platform - socketpair() function. - The arguments are the same as for socket() except the default family is - AF_UNIX if defined on the platform; otherwise, the default is AF_INET. - """ + return (ssock, csock) + +if hasattr(_socket, "socketpair"): + def socketpair(family=None, type=SOCK_STREAM, proto=0): if family is None: try: family = AF_UNIX @@ -613,61 +662,7 @@ if hasattr(_socket, "socketpair"): return a, b else: - - # Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. - def socketpair(family=AF_INET, type=SOCK_STREAM, proto=0): - if family == AF_INET: - host = _LOCALHOST - elif family == AF_INET6: - host = _LOCALHOST_V6 - else: - raise ValueError("Only AF_INET and AF_INET6 socket address families " - "are supported") - if type != SOCK_STREAM: - raise ValueError("Only SOCK_STREAM socket type is supported") - if proto != 0: - raise ValueError("Only protocol zero is supported") - - # We create a connected TCP socket. Note the trick with - # setblocking(False) that prevents us from having to create a thread. - lsock = socket(family, type, proto) - try: - lsock.bind((host, 0)) - lsock.listen() - # On IPv6, ignore flow_info and scope_id - addr, port = lsock.getsockname()[:2] - csock = socket(family, type, proto) - try: - csock.setblocking(False) - try: - csock.connect((addr, port)) - except (BlockingIOError, InterruptedError): - pass - csock.setblocking(True) - ssock, _ = lsock.accept() - except: - csock.close() - raise - finally: - lsock.close() - - # Authenticating avoids using a connection from something else - # able to connect to {host}:{port} instead of us. - # We expect only AF_INET and AF_INET6 families. - try: - if ( - ssock.getsockname() != csock.getpeername() - or csock.getsockname() != ssock.getpeername() - ): - raise ConnectionError("Unexpected peer connection") - except: - # getsockname() and getpeername() can fail - # if either socket isn't connected. - ssock.close() - csock.close() - raise - - return (ssock, csock) + socketpair = _fallback_socketpair __all__.append("socketpair") socketpair.__doc__ = """socketpair([family[, type[, proto]]]) -> (socket object, socket object) diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index bb65c3c..7c607a8 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -4861,7 +4861,6 @@ class BasicSocketPairTest(SocketPairTest): class PurePythonSocketPairTest(SocketPairTest): - # Explicitly use socketpair AF_INET or AF_INET6 to ensure that is the # code path we're using regardless platform is the pure python one where # `_socket.socketpair` does not exist. (AF_INET does not work with @@ -4876,28 +4875,21 @@ class PurePythonSocketPairTest(SocketPairTest): # Local imports in this class make for easy security fix backporting. def setUp(self): - import _socket - self._orig_sp = getattr(_socket, 'socketpair', None) - if self._orig_sp is not None: + if hasattr(_socket, "socketpair"): + self._orig_sp = socket.socketpair # This forces the version using the non-OS provided socketpair # emulation via an AF_INET socket in Lib/socket.py. - del _socket.socketpair - import importlib - global socket - socket = importlib.reload(socket) + socket.socketpair = socket._fallback_socketpair else: - pass # This platform already uses the non-OS provided version. + # This platform already uses the non-OS provided version. + self._orig_sp = None super().setUp() def tearDown(self): super().tearDown() - import _socket if self._orig_sp is not None: # Restore the default socket.socketpair definition. - _socket.socketpair = self._orig_sp - import importlib - global socket - socket = importlib.reload(socket) + socket.socketpair = self._orig_sp def test_recv(self): msg = self.serv.recv(1024) |