diff options
author | Antoine Pitrou <solipsis@pitrou.net> | 2011-12-21 15:52:40 (GMT) |
---|---|---|
committer | Antoine Pitrou <solipsis@pitrou.net> | 2011-12-21 15:52:40 (GMT) |
commit | 5b95eb90a7167285b6544b50865227c584943c9a (patch) | |
tree | a9780a14d897166471dc5c7516a844f161db34bb | |
parent | 17c07134a9619d110ad53f7a202612bfc304864e (diff) | |
download | cpython-5b95eb90a7167285b6544b50865227c584943c9a.zip cpython-5b95eb90a7167285b6544b50865227c584943c9a.tar.gz cpython-5b95eb90a7167285b6544b50865227c584943c9a.tar.bz2 |
Use context managers in test_ssl to simplify test writing.
-rw-r--r-- | Lib/test/test_ssl.py | 102 |
1 files changed, 38 insertions, 64 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index ba1d868..e5addf8 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -532,6 +532,14 @@ else: threading.Thread.__init__(self) self.daemon = True + def __enter__(self): + self.start(threading.Event()) + self.flag.wait() + + def __exit__(self, *args): + self.stop() + self.join() + def start(self, flag=None): self.flag = flag threading.Thread.start(self) @@ -638,6 +646,20 @@ else: def __str__(self): return "<%s %s>" % (self.__class__.__name__, self.server) + def __enter__(self): + self.start(threading.Event()) + self.flag.wait() + + def __exit__(self, *args): + if test_support.verbose: + sys.stdout.write(" cleanup: stopping server.\n") + self.stop() + if test_support.verbose: + sys.stdout.write(" cleanup: joining server thread.\n") + self.join() + if test_support.verbose: + sys.stdout.write(" cleanup: successfully joined.\n") + def start(self, flag=None): self.flag = flag threading.Thread.start(self) @@ -752,12 +774,7 @@ else: server = ThreadedEchoServer(CERTFILE, certreqs=ssl.CERT_REQUIRED, cacerts=CERTFILE, chatty=False) - flag = threading.Event() - server.start(flag) - # wait for it to start - flag.wait() - # try to connect - try: + with server: try: s = ssl.wrap_socket(socket.socket(), certfile=certfile, @@ -771,9 +788,6 @@ else: sys.stdout.write("\nsocket.error is %s\n" % x[1]) else: raise AssertionError("Use of invalid cert should have failed!") - finally: - server.stop() - server.join() def server_params_test(certfile, protocol, certreqs, cacertsfile, client_certfile, client_protocol=None, indata="FOO\n", @@ -791,14 +805,10 @@ else: chatty=chatty, connectionchatty=connectionchatty, wrap_accepting_socket=wrap_accepting_socket) - flag = threading.Event() - server.start(flag) - # wait for it to start - flag.wait() - # try to connect - if client_protocol is None: - client_protocol = protocol - try: + with server: + # try to connect + if client_protocol is None: + client_protocol = protocol s = ssl.wrap_socket(socket.socket(), certfile=client_certfile, ca_certs=cacertsfile, @@ -826,9 +836,6 @@ else: if test_support.verbose: sys.stdout.write(" client: closing connection.\n") s.close() - finally: - server.stop() - server.join() def try_protocol_combo(server_protocol, client_protocol, @@ -930,12 +937,7 @@ else: ssl_version=ssl.PROTOCOL_SSLv23, cacerts=CERTFILE, chatty=False) - flag = threading.Event() - server.start(flag) - # wait for it to start - flag.wait() - # try to connect - try: + with server: s = ssl.wrap_socket(socket.socket(), certfile=CERTFILE, ca_certs=CERTFILE, @@ -957,9 +959,6 @@ else: "Missing or invalid 'organizationName' field in certificate subject; " "should be 'Python Software Foundation'.") s.close() - finally: - server.stop() - server.join() def test_empty_cert(self): """Connecting with an empty cert file""" @@ -1042,13 +1041,8 @@ else: starttls_server=True, chatty=True, connectionchatty=True) - flag = threading.Event() - server.start(flag) - # wait for it to start - flag.wait() - # try to connect wrapped = False - try: + with server: s = socket.socket() s.setblocking(1) s.connect((HOST, server.port)) @@ -1093,9 +1087,6 @@ else: else: s.send("over\n") s.close() - finally: - server.stop() - server.join() def test_socketserver(self): """Using a SocketServer to create and manage SSL connections.""" @@ -1145,12 +1136,7 @@ else: if test_support.verbose: sys.stdout.write("\n") server = AsyncoreEchoServer(CERTFILE) - flag = threading.Event() - server.start(flag) - # wait for it to start - flag.wait() - # try to connect - try: + with server: s = ssl.wrap_socket(socket.socket()) s.connect(('127.0.0.1', server.port)) if test_support.verbose: @@ -1169,10 +1155,6 @@ else: if test_support.verbose: sys.stdout.write(" client: closing connection.\n") s.close() - finally: - server.stop() - # wait for server thread to end - server.join() def test_recv_send(self): """Test recv(), send() and friends.""" @@ -1185,19 +1167,14 @@ else: cacerts=CERTFILE, chatty=True, connectionchatty=False) - flag = threading.Event() - server.start(flag) - # wait for it to start - flag.wait() - # try to connect - s = ssl.wrap_socket(socket.socket(), - server_side=False, - certfile=CERTFILE, - ca_certs=CERTFILE, - cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) - s.connect((HOST, server.port)) - try: + with server: + s = ssl.wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) # helper methods for standardising recv* method signatures def _recv_into(): b = bytearray("\0"*100) @@ -1285,9 +1262,6 @@ else: s.write("over\n".encode("ASCII", "strict")) s.close() - finally: - server.stop() - server.join() def test_handshake_timeout(self): # Issue #5103: SSL handshake must respect the socket timeout |