From d2867642c038a932590958494c96ce0171898283 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Wed, 28 Apr 2010 21:12:43 +0000 Subject: Merged revisions 80596 via svnmerge from svn+ssh://pythondev@svn.python.org/python/trunk ........ r80596 | antoine.pitrou | 2010-04-28 23:11:01 +0200 (mer., 28 avril 2010) | 3 lines Fix style issues in test_ssl ........ --- Lib/test/test_ssl.py | 280 ++++++++++++++++++++++++++------------------------- 1 file changed, 143 insertions(+), 137 deletions(-) diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index e5f4a9e..0e86457 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -40,7 +40,7 @@ def handle_error(prefix): class BasicTests(unittest.TestCase): - def testSimpleSSLwrap(self): + def test_sslwrap_simple(self): # A crude test for the legacy API try: ssl.sslwrap_simple(socket.socket(socket.AF_INET)) @@ -57,7 +57,7 @@ class BasicTests(unittest.TestCase): else: raise - def testSSLconnect(self): + def test_connect(self): if not test_support.is_resource_enabled('network'): return s = ssl.wrap_socket(socket.socket(socket.AF_INET), @@ -78,7 +78,7 @@ class BasicTests(unittest.TestCase): finally: s.close() - def testCrucialConstants(self): + def test_constants(self): ssl.PROTOCOL_SSLv2 ssl.PROTOCOL_SSLv23 ssl.PROTOCOL_SSLv3 @@ -87,7 +87,7 @@ class BasicTests(unittest.TestCase): ssl.CERT_OPTIONAL ssl.CERT_REQUIRED - def testRAND(self): + def test_random(self): v = ssl.RAND_status() if test_support.verbose: sys.stdout.write("\n RAND_status is %d (%s)\n" @@ -101,7 +101,7 @@ class BasicTests(unittest.TestCase): print "didn't raise TypeError" ssl.RAND_add("this is a random string", 75.0) - def testParseCert(self): + def test_parse_cert(self): # note that this uses an 'unofficial' function in _ssl.c, # provided solely for this test, to exercise the certificate # parsing code @@ -109,9 +109,9 @@ class BasicTests(unittest.TestCase): if test_support.verbose: sys.stdout.write("\n" + pprint.pformat(p) + "\n") - def testDERtoPEM(self): - - pem = open(SVN_PYTHON_ORG_ROOT_CERT, 'r').read() + def test_DER_to_PEM(self): + with open(SVN_PYTHON_ORG_ROOT_CERT, 'r') as f: + pem = f.read() d1 = ssl.PEM_cert_to_DER_cert(pem) p2 = ssl.DER_cert_to_PEM_cert(d1) d2 = ssl.PEM_cert_to_DER_cert(p2) @@ -133,7 +133,7 @@ class BasicTests(unittest.TestCase): class NetworkedTests(unittest.TestCase): - def testConnect(self): + def test_connect(self): s = ssl.wrap_socket(socket.socket(socket.AF_INET), cert_reqs=ssl.CERT_NONE) s.connect(("svn.python.org", 443)) @@ -186,7 +186,7 @@ class NetworkedTests(unittest.TestCase): else: self.fail("OSError wasn't raised") - def testNonBlockingHandshake(self): + def test_non_blocking_handshake(self): s = socket.socket(socket.AF_INET) s.connect(("svn.python.org", 443)) s.setblocking(False) @@ -210,8 +210,7 @@ class NetworkedTests(unittest.TestCase): if test_support.verbose: sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count) - def testFetchServerCert(self): - + def test_get_server_certificate(self): pem = ssl.get_server_certificate(("svn.python.org", 443)) if not pem: self.fail("No server certificate on svn.python.org:443!") @@ -261,7 +260,6 @@ try: except ImportError: _have_threads = False else: - _have_threads = True class ThreadedEchoServer(threading.Thread): @@ -293,7 +291,7 @@ else: if test_support.verbose and self.server.chatty: sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n") - def wrap_conn (self): + def wrap_conn(self): try: self.sslconn = ssl.wrap_socket(self.sock, server_side=True, certfile=self.server.certificate, @@ -332,7 +330,7 @@ else: else: self.sock._sock.close() - def run (self): + def run(self): self.running = True if not self.server.starttls_server: if isinstance(self.sock, ssl.SSLSocket): @@ -413,11 +411,11 @@ else: threading.Thread.__init__(self) self.daemon = True - def start (self, flag=None): + def start(self, flag=None): self.flag = flag threading.Thread.start(self) - def run (self): + def run(self): self.sock.settimeout(0.05) self.sock.listen(5) self.active = True @@ -438,14 +436,14 @@ else: self.stop() self.sock.close() - def stop (self): + def stop(self): self.active = False class AsyncoreEchoServer(threading.Thread): - class EchoServer (asyncore.dispatcher): + class EchoServer(asyncore.dispatcher): - class ConnectionHandler (asyncore.dispatcher_with_send): + class ConnectionHandler(asyncore.dispatcher_with_send): def __init__(self, conn, certfile): asyncore.dispatcher_with_send.__init__(self, conn) @@ -519,18 +517,18 @@ else: def __str__(self): return "<%s %s>" % (self.__class__.__name__, self.server) - def start (self, flag=None): + def start(self, flag=None): self.flag = flag threading.Thread.start(self) - def run (self): + def run(self): self.active = True if self.flag: self.flag.set() while self.active: asyncore.loop(0.05) - def stop (self): + def stop(self): self.active = False self.server.close() @@ -539,12 +537,9 @@ else: class HTTPSServer(HTTPServer): def __init__(self, server_address, RequestHandlerClass, certfile): - HTTPServer.__init__(self, server_address, RequestHandlerClass) # we assume the certfile contains both private key and certificate self.certfile = certfile - self.active = False - self.active_lock = threading.Lock() self.allow_reuse_address = True def __str__(self): @@ -553,7 +548,7 @@ else: self.server_name, self.server_port)) - def get_request (self): + def get_request(self): # override this to wrap socket with SSL sock, addr = self.socket.accept() sslconn = ssl.wrap_socket(sock, server_side=True, @@ -561,7 +556,6 @@ else: return sslconn, addr class RootedHTTPRequestHandler(SimpleHTTPRequestHandler): - # need to override translate_path to get a known root, # instead of using os.curdir, since the test could be # run from anywhere @@ -606,7 +600,6 @@ else: def __init__(self, certfile): self.flag = None - self.active = False self.RootedHTTPRequestHandler.root = os.path.split(CERTFILE)[0] self.server = self.HTTPSServer( (HOST, 0), self.RootedHTTPRequestHandler, certfile) @@ -617,23 +610,24 @@ else: def __str__(self): return "<%s %s>" % (self.__class__.__name__, self.server) - def start (self, flag=None): + def start(self, flag=None): self.flag = flag threading.Thread.start(self) - def run (self): - self.active = True + def run(self): if self.flag: self.flag.set() self.server.serve_forever(0.05) - self.active = False - def stop (self): - self.active = False + def stop(self): self.server.shutdown() - def badCertTest (certfile): + def bad_cert_test(certfile): + """ + Launch a server with CERT_REQUIRED, and check that trying to + connect to it with the given client certificate fails. + """ server = ThreadedEchoServer(CERTFILE, certreqs=ssl.CERT_REQUIRED, cacerts=CERTFILE, chatty=False) @@ -660,11 +654,14 @@ else: server.stop() server.join() - def serverParamsTest (certfile, protocol, certreqs, cacertsfile, - client_certfile, client_protocol=None, indata="FOO\n", - chatty=True, connectionchatty=False, - wrap_accepting_socket=False): - + def server_params_test(certfile, protocol, certreqs, cacertsfile, + client_certfile, client_protocol=None, indata="FOO\n", + chatty=True, connectionchatty=False, + wrap_accepting_socket=False): + """ + Launch a server, connect a client to it and try various reads + and writes. + """ server = ThreadedEchoServer(certfile, certreqs=certreqs, ssl_version=protocol, @@ -709,39 +706,37 @@ else: server.stop() server.join() - def tryProtocolCombo (server_protocol, - client_protocol, - expectedToWork, - certsreqs=None): - + def try_protocol_combo(server_protocol, + client_protocol, + expect_success, + certsreqs=None): if certsreqs is None: certsreqs = ssl.CERT_NONE - - if certsreqs == ssl.CERT_NONE: - certtype = "CERT_NONE" - elif certsreqs == ssl.CERT_OPTIONAL: - certtype = "CERT_OPTIONAL" - elif certsreqs == ssl.CERT_REQUIRED: - certtype = "CERT_REQUIRED" + certtype = { + ssl.CERT_NONE: "CERT_NONE", + ssl.CERT_OPTIONAL: "CERT_OPTIONAL", + ssl.CERT_REQUIRED: "CERT_REQUIRED", + }[certsreqs] if test_support.verbose: - formatstr = (expectedToWork and " %s->%s %s\n") or " {%s->%s} %s\n" + formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n" sys.stdout.write(formatstr % (ssl.get_protocol_name(client_protocol), ssl.get_protocol_name(server_protocol), certtype)) try: - serverParamsTest(CERTFILE, server_protocol, certsreqs, - CERTFILE, CERTFILE, client_protocol, chatty=False) + server_params_test(CERTFILE, server_protocol, certsreqs, + CERTFILE, CERTFILE, client_protocol, + chatty=False) # Protocol mismatch can result in either an SSLError, or a # "Connection reset by peer" error. except ssl.SSLError: - if expectedToWork: + if expect_success: raise except socket.error as e: - if expectedToWork or e.errno != errno.ECONNRESET: + if expect_success or e.errno != errno.ECONNRESET: raise else: - if not expectedToWork: + if not expect_success: self.fail( "Client protocol %s succeeded with server protocol %s!" % (ssl.get_protocol_name(client_protocol), @@ -750,8 +745,10 @@ else: class ThreadedTests(unittest.TestCase): - def testRudeShutdown(self): - + def test_rude_shutdown(self): + """A brutal shutdown of an SSL server should raise an IOError + in the client when attempting handshake. + """ listener_ready = threading.Event() listener_gone = threading.Event() @@ -788,16 +785,15 @@ else: finally: t.join() - def testEcho (self): - + def test_echo(self): + """Basic test of an SSL client connecting to a server""" if test_support.verbose: sys.stdout.write("\n") - serverParamsTest(CERTFILE, ssl.PROTOCOL_TLSv1, ssl.CERT_NONE, - CERTFILE, CERTFILE, ssl.PROTOCOL_TLSv1, - chatty=True, connectionchatty=True) - - def testReadCert(self): + server_params_test(CERTFILE, ssl.PROTOCOL_TLSv1, ssl.CERT_NONE, + CERTFILE, CERTFILE, ssl.PROTOCOL_TLSv1, + chatty=True, connectionchatty=True) + def test_getpeercert(self): if test_support.verbose: sys.stdout.write("\n") s2 = socket.socket() @@ -837,74 +833,82 @@ else: server.stop() server.join() - def testNULLcert(self): - badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, - "nullcert.pem")) - def testMalformedCert(self): - badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, - "badcert.pem")) - def testWrongCert(self): - badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, - "wrongcert.pem")) - def testMalformedKey(self): - badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, - "badkey.pem")) - - def testProtocolSSL2(self): + def test_empty_cert(self): + """Connecting with an empty cert file""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "nullcert.pem")) + def test_malformed_cert(self): + """Connecting with a badly formatted certificate (syntax error)""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "badcert.pem")) + def test_nonexisting_cert(self): + """Connecting with a non-existing cert file""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "wrongcert.pem")) + def test_malformed_key(self): + """Connecting with a badly formatted key (syntax error)""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "badkey.pem")) + + def test_protocol_sslv2(self): + """Connecting to an SSLv2 server with various client options""" if test_support.verbose: sys.stdout.write("\n") - tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True) - tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL) - tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED) - tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True) - tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False) - tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False) - - def testProtocolSSL23(self): + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False) + + def test_protocol_sslv23(self): + """Connecting to an SSLv23 server with various client options""" if test_support.verbose: sys.stdout.write("\n") try: - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True) except (ssl.SSLError, socket.error), x: # this fails on some older versions of OpenSSL (0.9.7l, for instance) if test_support.verbose: sys.stdout.write( " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n" % str(x)) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) - def testProtocolSSL3(self): + def test_protocol_sslv3(self): + """Connecting to an SSLv3 server with various client options""" if test_support.verbose: sys.stdout.write("\n") - tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True) - tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) - tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) - tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) - tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False) - tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) - - def testProtocolTLS1(self): + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) + + def test_protocol_tlsv1(self): + """Connecting to a TLSv1 server with various client options""" if test_support.verbose: sys.stdout.write("\n") - tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True) - tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) - tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) - tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) - tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False) - tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False) - - def testSTARTTLS (self): - + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False) + + def test_starttls(self): + """Switching from clear text to encrypted and back again.""" msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", "msg 5", "msg 6") server = ThreadedEchoServer(CERTFILE, @@ -936,6 +940,7 @@ else: outdata = s.recv(1024) if (indata == "STARTTLS" and outdata.strip().lower().startswith("ok")): + # STARTTLS ok, switch to secure mode if test_support.verbose: sys.stdout.write( " client: read %s from server, starting TLS...\n" @@ -944,6 +949,7 @@ else: wrapped = True elif (indata == "ENDTLS" and outdata.strip().lower().startswith("ok")): + # ENDTLS ok, switch back to clear text if test_support.verbose: sys.stdout.write( " client: read %s from server, ending TLS...\n" @@ -965,8 +971,8 @@ else: server.stop() server.join() - def testSocketServer(self): - + def test_socketserver(self): + """Using a SocketServer to create and manage SSL connections.""" server = SocketServerHTTPSServer(CERTFILE) flag = threading.Event() server.start(flag) @@ -976,7 +982,8 @@ else: try: if test_support.verbose: sys.stdout.write('\n') - d1 = open(CERTFILE, 'rb').read() + with open(CERTFILE, 'rb') as f: + d1 = f.read() d2 = '' # now fetch the same data from the HTTPS server url = 'https://127.0.0.1:%d/%s' % ( @@ -995,18 +1002,17 @@ else: server.stop() server.join() - def testWrappedAccept (self): - + def test_wrapped_accept(self): + """Check the accept() method on SSL sockets.""" if test_support.verbose: sys.stdout.write("\n") - serverParamsTest(CERTFILE, ssl.PROTOCOL_SSLv23, ssl.CERT_REQUIRED, - CERTFILE, CERTFILE, ssl.PROTOCOL_SSLv23, - chatty=True, connectionchatty=True, - wrap_accepting_socket=True) - - - def testAsyncoreServer (self): + server_params_test(CERTFILE, ssl.PROTOCOL_SSLv23, ssl.CERT_REQUIRED, + CERTFILE, CERTFILE, ssl.PROTOCOL_SSLv23, + chatty=True, connectionchatty=True, + wrap_accepting_socket=True) + def test_asyncore_server(self): + """Check the example asyncore integration.""" indata = "TEST MESSAGE of mixed case\n" if test_support.verbose: @@ -1041,9 +1047,8 @@ else: # wait for server thread to end server.join() - - def testAllRecvAndSendMethods(self): - + def test_recv_send(self): + """Test recv(), send() and friends.""" if test_support.verbose: sys.stdout.write("\n") @@ -1238,10 +1243,11 @@ def test_main(verbose=False): if thread_info and test_support.is_resource_enabled('network'): tests.append(ThreadedTests) - test_support.run_unittest(*tests) - - if _have_threads: - test_support.threading_cleanup(*thread_info) + try: + test_support.run_unittest(*tests) + finally: + if _have_threads: + test_support.threading_cleanup(*thread_info) if __name__ == "__main__": test_main() -- cgit v0.12