summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAntoine Pitrou <solipsis@pitrou.net>2011-12-21 15:52:40 (GMT)
committerAntoine Pitrou <solipsis@pitrou.net>2011-12-21 15:52:40 (GMT)
commit5b95eb90a7167285b6544b50865227c584943c9a (patch)
treea9780a14d897166471dc5c7516a844f161db34bb
parent17c07134a9619d110ad53f7a202612bfc304864e (diff)
downloadcpython-5b95eb90a7167285b6544b50865227c584943c9a.zip
cpython-5b95eb90a7167285b6544b50865227c584943c9a.tar.gz
cpython-5b95eb90a7167285b6544b50865227c584943c9a.tar.bz2
Use context managers in test_ssl to simplify test writing.
-rw-r--r--Lib/test/test_ssl.py102
1 files changed, 38 insertions, 64 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index ba1d868..e5addf8 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -532,6 +532,14 @@ else:
threading.Thread.__init__(self)
self.daemon = True
+ def __enter__(self):
+ self.start(threading.Event())
+ self.flag.wait()
+
+ def __exit__(self, *args):
+ self.stop()
+ self.join()
+
def start(self, flag=None):
self.flag = flag
threading.Thread.start(self)
@@ -638,6 +646,20 @@ else:
def __str__(self):
return "<%s %s>" % (self.__class__.__name__, self.server)
+ def __enter__(self):
+ self.start(threading.Event())
+ self.flag.wait()
+
+ def __exit__(self, *args):
+ if test_support.verbose:
+ sys.stdout.write(" cleanup: stopping server.\n")
+ self.stop()
+ if test_support.verbose:
+ sys.stdout.write(" cleanup: joining server thread.\n")
+ self.join()
+ if test_support.verbose:
+ sys.stdout.write(" cleanup: successfully joined.\n")
+
def start(self, flag=None):
self.flag = flag
threading.Thread.start(self)
@@ -752,12 +774,7 @@ else:
server = ThreadedEchoServer(CERTFILE,
certreqs=ssl.CERT_REQUIRED,
cacerts=CERTFILE, chatty=False)
- flag = threading.Event()
- server.start(flag)
- # wait for it to start
- flag.wait()
- # try to connect
- try:
+ with server:
try:
s = ssl.wrap_socket(socket.socket(),
certfile=certfile,
@@ -771,9 +788,6 @@ else:
sys.stdout.write("\nsocket.error is %s\n" % x[1])
else:
raise AssertionError("Use of invalid cert should have failed!")
- finally:
- server.stop()
- server.join()
def server_params_test(certfile, protocol, certreqs, cacertsfile,
client_certfile, client_protocol=None, indata="FOO\n",
@@ -791,14 +805,10 @@ else:
chatty=chatty,
connectionchatty=connectionchatty,
wrap_accepting_socket=wrap_accepting_socket)
- flag = threading.Event()
- server.start(flag)
- # wait for it to start
- flag.wait()
- # try to connect
- if client_protocol is None:
- client_protocol = protocol
- try:
+ with server:
+ # try to connect
+ if client_protocol is None:
+ client_protocol = protocol
s = ssl.wrap_socket(socket.socket(),
certfile=client_certfile,
ca_certs=cacertsfile,
@@ -826,9 +836,6 @@ else:
if test_support.verbose:
sys.stdout.write(" client: closing connection.\n")
s.close()
- finally:
- server.stop()
- server.join()
def try_protocol_combo(server_protocol,
client_protocol,
@@ -930,12 +937,7 @@ else:
ssl_version=ssl.PROTOCOL_SSLv23,
cacerts=CERTFILE,
chatty=False)
- flag = threading.Event()
- server.start(flag)
- # wait for it to start
- flag.wait()
- # try to connect
- try:
+ with server:
s = ssl.wrap_socket(socket.socket(),
certfile=CERTFILE,
ca_certs=CERTFILE,
@@ -957,9 +959,6 @@ else:
"Missing or invalid 'organizationName' field in certificate subject; "
"should be 'Python Software Foundation'.")
s.close()
- finally:
- server.stop()
- server.join()
def test_empty_cert(self):
"""Connecting with an empty cert file"""
@@ -1042,13 +1041,8 @@ else:
starttls_server=True,
chatty=True,
connectionchatty=True)
- flag = threading.Event()
- server.start(flag)
- # wait for it to start
- flag.wait()
- # try to connect
wrapped = False
- try:
+ with server:
s = socket.socket()
s.setblocking(1)
s.connect((HOST, server.port))
@@ -1093,9 +1087,6 @@ else:
else:
s.send("over\n")
s.close()
- finally:
- server.stop()
- server.join()
def test_socketserver(self):
"""Using a SocketServer to create and manage SSL connections."""
@@ -1145,12 +1136,7 @@ else:
if test_support.verbose:
sys.stdout.write("\n")
server = AsyncoreEchoServer(CERTFILE)
- flag = threading.Event()
- server.start(flag)
- # wait for it to start
- flag.wait()
- # try to connect
- try:
+ with server:
s = ssl.wrap_socket(socket.socket())
s.connect(('127.0.0.1', server.port))
if test_support.verbose:
@@ -1169,10 +1155,6 @@ else:
if test_support.verbose:
sys.stdout.write(" client: closing connection.\n")
s.close()
- finally:
- server.stop()
- # wait for server thread to end
- server.join()
def test_recv_send(self):
"""Test recv(), send() and friends."""
@@ -1185,19 +1167,14 @@ else:
cacerts=CERTFILE,
chatty=True,
connectionchatty=False)
- flag = threading.Event()
- server.start(flag)
- # wait for it to start
- flag.wait()
- # try to connect
- s = ssl.wrap_socket(socket.socket(),
- server_side=False,
- certfile=CERTFILE,
- ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1)
- s.connect((HOST, server.port))
- try:
+ with server:
+ s = ssl.wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
# helper methods for standardising recv* method signatures
def _recv_into():
b = bytearray("\0"*100)
@@ -1285,9 +1262,6 @@ else:
s.write("over\n".encode("ASCII", "strict"))
s.close()
- finally:
- server.stop()
- server.join()
def test_handshake_timeout(self):
# Issue #5103: SSL handshake must respect the socket timeout