summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorAntoine Pitrou <solipsis@pitrou.net>2011-12-21 15:54:45 (GMT)
committerAntoine Pitrou <solipsis@pitrou.net>2011-12-21 15:54:45 (GMT)
commit6b15c90fd8892846444d175f60f59456a79d9848 (patch)
tree83d6576887a364c8b4f24404c168d1166ffd121d /Lib
parentf0a49a9e27075478af9eb57ce45ecb10c6a6d383 (diff)
parent65a3f4b8c57a761cfe0e6ee14565db421c50f4c0 (diff)
downloadcpython-6b15c90fd8892846444d175f60f59456a79d9848.zip
cpython-6b15c90fd8892846444d175f60f59456a79d9848.tar.gz
cpython-6b15c90fd8892846444d175f60f59456a79d9848.tar.bz2
Use context managers in test_ssl to simplify test writing.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/test/test_ssl.py126
1 files changed, 43 insertions, 83 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 1960e14..d549799 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -986,6 +986,14 @@ else:
threading.Thread.__init__(self)
self.daemon = True
+ def __enter__(self):
+ self.start(threading.Event())
+ self.flag.wait()
+
+ def __exit__(self, *args):
+ self.stop()
+ self.join()
+
def start(self, flag=None):
self.flag = flag
threading.Thread.start(self)
@@ -1097,6 +1105,20 @@ else:
def __str__(self):
return "<%s %s>" % (self.__class__.__name__, self.server)
+ def __enter__(self):
+ self.start(threading.Event())
+ self.flag.wait()
+
+ def __exit__(self, *args):
+ if support.verbose:
+ sys.stdout.write(" cleanup: stopping server.\n")
+ self.stop()
+ if support.verbose:
+ sys.stdout.write(" cleanup: joining server thread.\n")
+ self.join()
+ if support.verbose:
+ sys.stdout.write(" cleanup: successfully joined.\n")
+
def start (self, flag=None):
self.flag = flag
threading.Thread.start(self)
@@ -1124,12 +1146,7 @@ else:
certreqs=ssl.CERT_REQUIRED,
cacerts=CERTFILE, chatty=False,
connectionchatty=False)
- flag = threading.Event()
- server.start(flag)
- # wait for it to start
- flag.wait()
- # try to connect
- try:
+ with server:
try:
with socket.socket() as sock:
s = ssl.wrap_socket(sock,
@@ -1149,9 +1166,6 @@ else:
sys.stdout.write("\IOError is %s\n" % str(x))
else:
raise AssertionError("Use of invalid cert should have failed!")
- finally:
- server.stop()
- server.join()
def server_params_test(client_context, server_context, indata=b"FOO\n",
chatty=True, connectionchatty=False):
@@ -1162,12 +1176,7 @@ else:
server = ThreadedEchoServer(context=server_context,
chatty=chatty,
connectionchatty=False)
- flag = threading.Event()
- server.start(flag)
- # wait for it to start
- flag.wait()
- # try to connect
- try:
+ with server:
s = client_context.wrap_socket(socket.socket())
s.connect((HOST, server.port))
for arg in [indata, bytearray(indata), memoryview(indata)]:
@@ -1195,9 +1204,6 @@ else:
}
s.close()
return stats
- finally:
- server.stop()
- server.join()
def try_protocol_combo(server_protocol, client_protocol, expect_success,
certsreqs=None, server_options=0, client_options=0):
@@ -1266,12 +1272,7 @@ else:
context.load_verify_locations(CERTFILE)
context.load_cert_chain(CERTFILE)
server = ThreadedEchoServer(context=context, chatty=False)
- flag = threading.Event()
- server.start(flag)
- # wait for it to start
- flag.wait()
- # try to connect
- try:
+ with server:
s = context.wrap_socket(socket.socket())
s.connect((HOST, server.port))
cert = s.getpeercert()
@@ -1294,9 +1295,6 @@ else:
after = ssl.cert_time_to_seconds(cert['notAfter'])
self.assertLess(before, after)
s.close()
- finally:
- server.stop()
- server.join()
def test_empty_cert(self):
"""Connecting with an empty cert file"""
@@ -1456,13 +1454,8 @@ else:
starttls_server=True,
chatty=True,
connectionchatty=True)
- flag = threading.Event()
- server.start(flag)
- # wait for it to start
- flag.wait()
- # try to connect
wrapped = False
- try:
+ with server:
s = socket.socket()
s.setblocking(1)
s.connect((HOST, server.port))
@@ -1509,9 +1502,6 @@ else:
conn.close()
else:
s.close()
- finally:
- server.stop()
- server.join()
def test_socketserver(self):
"""Using a SocketServer to create and manage SSL connections."""
@@ -1547,12 +1537,7 @@ else:
indata = b"FOO\n"
server = AsyncoreEchoServer(CERTFILE)
- flag = threading.Event()
- server.start(flag)
- # wait for it to start
- flag.wait()
- # try to connect
- try:
+ with server:
s = ssl.wrap_socket(socket.socket())
s.connect(('127.0.0.1', server.port))
if support.verbose:
@@ -1573,15 +1558,6 @@ else:
s.close()
if support.verbose:
sys.stdout.write(" client: connection closed.\n")
- finally:
- if support.verbose:
- sys.stdout.write(" cleanup: stopping server.\n")
- server.stop()
- if support.verbose:
- sys.stdout.write(" cleanup: joining server thread.\n")
- server.join()
- if support.verbose:
- sys.stdout.write(" cleanup: successfully joined.\n")
def test_recv_send(self):
"""Test recv(), send() and friends."""
@@ -1594,19 +1570,14 @@ else:
cacerts=CERTFILE,
chatty=True,
connectionchatty=False)
- flag = threading.Event()
- server.start(flag)
- # wait for it to start
- flag.wait()
- # try to connect
- s = ssl.wrap_socket(socket.socket(),
- server_side=False,
- certfile=CERTFILE,
- ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1)
- s.connect((HOST, server.port))
- try:
+ with server:
+ s = ssl.wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
# helper methods for standardising recv* method signatures
def _recv_into():
b = bytearray(b"\0"*100)
@@ -1702,9 +1673,6 @@ else:
s.write(b"over\n")
s.close()
- finally:
- server.stop()
- server.join()
def test_handshake_timeout(self):
# Issue #5103: SSL handshake must respect the socket timeout
@@ -1768,19 +1736,14 @@ else:
cacerts=CERTFILE,
chatty=True,
connectionchatty=False)
- flag = threading.Event()
- server.start(flag)
- # wait for it to start
- flag.wait()
- # try to connect
- s = ssl.wrap_socket(socket.socket(),
- server_side=False,
- certfile=CERTFILE,
- ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1)
- s.connect((HOST, server.port))
- try:
+ with server:
+ s = ssl.wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
# get the data
cb_data = s.get_channel_binding("tls-unique")
if support.verbose:
@@ -1819,9 +1782,6 @@ else:
self.assertEqual(peer_data_repr,
repr(new_cb_data).encode("us-ascii"))
s.close()
- finally:
- server.stop()
- server.join()
def test_compression(self):
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)