diff options
Diffstat (limited to 'Lib/test/test_ssl.py')
-rw-r--r-- | Lib/test/test_ssl.py | 246 |
1 files changed, 213 insertions, 33 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index eb4d00c..d786154 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -3,7 +3,9 @@ import sys import unittest from test import test_support +import asyncore import socket +import select import errno import subprocess import time @@ -97,8 +99,7 @@ class BasicTests(unittest.TestCase): if (d1 != d2): raise test_support.TestFailed("PEM-to-DER or DER-to-PEM translation failed") - -class NetworkTests(unittest.TestCase): +class NetworkedTests(unittest.TestCase): def testConnect(self): s = ssl.wrap_socket(socket.socket(socket.AF_INET), @@ -130,6 +131,31 @@ class NetworkTests(unittest.TestCase): finally: s.close() + + def testNonBlockingHandshake(self): + s = socket.socket(socket.AF_INET) + s.connect(("svn.python.org", 443)) + s.setblocking(False) + s = ssl.wrap_socket(s, + cert_reqs=ssl.CERT_NONE, + do_handshake_on_connect=False) + count = 0 + while True: + try: + count += 1 + s.do_handshake() + break + except ssl.SSLError, err: + if err.args[0] == ssl.SSL_ERROR_WANT_READ: + select.select([s], [], []) + elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: + select.select([], [s], []) + else: + raise + s.close() + if test_support.verbose: + sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count) + def testFetchServerCert(self): pem = ssl.get_server_certificate(("svn.python.org", 443)) @@ -176,6 +202,18 @@ else: threading.Thread.__init__(self) self.setDaemon(True) + def show_conn_details(self): + if self.server.certreqs == ssl.CERT_REQUIRED: + cert = self.sslconn.getpeercert() + if test_support.verbose and self.server.chatty: + sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n") + cert_binary = self.sslconn.getpeercert(True) + if test_support.verbose and self.server.chatty: + sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n") + cipher = self.sslconn.cipher() + if test_support.verbose and self.server.chatty: + sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n") + def wrap_conn (self): try: self.sslconn = ssl.wrap_socket(self.sock, server_side=True, @@ -187,6 +225,7 @@ else: if self.server.chatty: handle_error("\n server: bad connection attempt from " + str(self.sock.getpeername()) + ":\n") + self.close() if not self.server.expect_bad_connects: # here, we want to stop the server, because this shouldn't # happen in the context of our test case @@ -197,16 +236,6 @@ else: return False else: - if self.server.certreqs == ssl.CERT_REQUIRED: - cert = self.sslconn.getpeercert() - if test_support.verbose and self.server.chatty: - sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n") - cert_binary = self.sslconn.getpeercert(True) - if test_support.verbose and self.server.chatty: - sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n") - cipher = self.sslconn.cipher() - if test_support.verbose and self.server.chatty: - sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n") return True def read(self): @@ -225,13 +254,16 @@ else: if self.sslconn: self.sslconn.close() else: - self.sock.close() + self.sock._sock.close() def run (self): self.running = True if not self.server.starttls_server: - if not self.wrap_conn(): + if isinstance(self.sock, ssl.SSLSocket): + self.sslconn = self.sock + elif not self.wrap_conn(): return + self.show_conn_details() while self.running: try: msg = self.read() @@ -270,7 +302,9 @@ else: def __init__(self, certificate, ssl_version=None, certreqs=None, cacerts=None, expect_bad_connects=False, - chatty=True, connectionchatty=False, starttls_server=False): + chatty=True, connectionchatty=False, starttls_server=False, + wrap_accepting_socket=False): + if ssl_version is None: ssl_version = ssl.PROTOCOL_TLSv1 if certreqs is None: @@ -284,8 +318,16 @@ else: self.connectionchatty = connectionchatty self.starttls_server = starttls_server self.sock = socket.socket() - self.port = test_support.bind_port(self.sock) self.flag = None + if wrap_accepting_socket: + self.sock = ssl.wrap_socket(self.sock, server_side=True, + certfile=self.certificate, + cert_reqs = self.certreqs, + ca_certs = self.cacerts, + ssl_version = self.protocol) + if test_support.verbose and self.chatty: + sys.stdout.write(' server: wrapped server socket as %s\n' % str(self.sock)) + self.port = test_support.bind_port(self.sock) self.active = False threading.Thread.__init__(self) self.setDaemon(False) @@ -316,13 +358,86 @@ else: except: if self.chatty: handle_error("Test server failure:\n") + self.sock.close() def stop (self): self.active = False - self.sock.close() + class AsyncoreEchoServer(threading.Thread): + + class EchoServer (asyncore.dispatcher): + + class ConnectionHandler (asyncore.dispatcher_with_send): + + def __init__(self, conn, certfile): + asyncore.dispatcher_with_send.__init__(self, conn) + self.socket = ssl.wrap_socket(conn, server_side=True, + certfile=certfile, + do_handshake_on_connect=True) + + def readable(self): + if isinstance(self.socket, ssl.SSLSocket): + while self.socket.pending() > 0: + self.handle_read_event() + return True + + def handle_read(self): + data = self.recv(1024) + self.send(data.lower()) + + def handle_close(self): + if test_support.verbose: + sys.stdout.write(" server: closed connection %s\n" % self.socket) + + def handle_error(self): + raise + + def __init__(self, certfile): + self.certfile = certfile + asyncore.dispatcher.__init__(self) + self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + self.port = test_support.bind_port(self.socket) + self.listen(5) + + def handle_accept(self): + sock_obj, addr = self.accept() + if test_support.verbose: + sys.stdout.write(" server: new connection from %s:%s\n" %addr) + self.ConnectionHandler(sock_obj, self.certfile) + + def handle_error(self): + raise + + def __init__(self, certfile): + self.flag = None + self.active = False + self.server = self.EchoServer(certfile) + self.port = self.server.port + threading.Thread.__init__(self) + self.setDaemon(True) + + def __str__(self): + return "<%s %s>" % (self.__class__.__name__, self.server) + + def start (self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run (self): + self.active = True + if self.flag: + self.flag.set() + while self.active: + try: + asyncore.loop(1) + except: + pass + + def stop (self): + self.active = False + self.server.close() - class AsyncoreHTTPSServer(threading.Thread): + class SocketServerHTTPSServer(threading.Thread): class HTTPSServer(HTTPServer): @@ -335,6 +450,12 @@ else: self.active_lock = threading.Lock() self.allow_reuse_address = True + def __str__(self): + return ('<%s %s:%s>' % + (self.__class__.__name__, + self.server_name, + self.server_port)) + def get_request (self): # override this to wrap socket with SSL sock, addr = self.socket.accept() @@ -421,8 +542,8 @@ else: # we override this to suppress logging unless "verbose" if test_support.verbose: - sys.stdout.write(" server (%s, %d, %s):\n [%s] %s\n" % - (self.server.server_name, + sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" % + (self.server.server_address, self.server.server_port, self.request.cipher(), self.log_date_time_string(), @@ -440,9 +561,7 @@ else: self.setDaemon(True) def __str__(self): - return '<%s %s:%d>' % (self.__class__.__name__, - self.server.server_name, - self.server.server_port) + return "<%s %s>" % (self.__class__.__name__, self.server) def start (self, flag=None): self.flag = flag @@ -487,14 +606,16 @@ else: def serverParamsTest (certfile, protocol, certreqs, cacertsfile, client_certfile, client_protocol=None, indata="FOO\n", - chatty=True, connectionchatty=False): + chatty=True, connectionchatty=False, + wrap_accepting_socket=False): server = ThreadedEchoServer(certfile, certreqs=certreqs, ssl_version=protocol, cacerts=cacertsfile, chatty=chatty, - connectionchatty=connectionchatty) + connectionchatty=connectionchatty, + wrap_accepting_socket=wrap_accepting_socket) flag = threading.Event() server.start(flag) # wait for it to start @@ -572,7 +693,7 @@ else: ssl.get_protocol_name(server_protocol))) - class ConnectedTests(unittest.TestCase): + class ThreadedTests(unittest.TestCase): def testRudeShutdown(self): @@ -600,7 +721,7 @@ else: listener_gone.wait() try: ssl_sock = ssl.wrap_socket(s) - except socket.sslerror: + except IOError: pass else: raise test_support.TestFailed( @@ -680,6 +801,9 @@ else: def testMalformedCert(self): badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, "badcert.pem")) + def testWrongCert(self): + badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, + "wrongcert.pem")) def testMalformedKey(self): badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, "badkey.pem")) @@ -796,9 +920,9 @@ else: server.stop() server.join() - def testAsyncore(self): + def testSocketServer(self): - server = AsyncoreHTTPSServer(CERTFILE) + server = SocketServerHTTPSServer(CERTFILE) flag = threading.Event() server.start(flag) # wait for it to start @@ -810,8 +934,8 @@ else: d1 = open(CERTFILE, 'rb').read() d2 = '' # now fetch the same data from the HTTPS server - url = 'https://%s:%d/%s' % ( - HOST, server.port, os.path.split(CERTFILE)[1]) + url = 'https://127.0.0.1:%d/%s' % ( + server.port, os.path.split(CERTFILE)[1]) f = urllib.urlopen(url) dlen = f.info().getheader("content-length") if dlen and (int(dlen) > 0): @@ -834,6 +958,58 @@ else: server.stop() server.join() + def testWrappedAccept (self): + + if test_support.verbose: + sys.stdout.write("\n") + serverParamsTest(CERTFILE, ssl.PROTOCOL_SSLv23, ssl.CERT_REQUIRED, + CERTFILE, CERTFILE, ssl.PROTOCOL_SSLv23, + chatty=True, connectionchatty=True, + wrap_accepting_socket=True) + + + def testAsyncoreServer (self): + + indata = "TEST MESSAGE of mixed case\n" + + if test_support.verbose: + sys.stdout.write("\n") + server = AsyncoreEchoServer(CERTFILE) + flag = threading.Event() + server.start(flag) + # wait for it to start + flag.wait() + # try to connect + try: + try: + s = ssl.wrap_socket(socket.socket()) + s.connect(('127.0.0.1', server.port)) + except ssl.SSLError, x: + raise test_support.TestFailed("Unexpected SSL error: " + str(x)) + except Exception, x: + raise test_support.TestFailed("Unexpected exception: " + str(x)) + else: + if test_support.verbose: + sys.stdout.write( + " client: sending %s...\n" % (repr(indata))) + s.write(indata) + outdata = s.read() + if test_support.verbose: + sys.stdout.write(" client: read %s\n" % repr(outdata)) + if outdata != indata.lower(): + raise test_support.TestFailed( + "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" + % (outdata[:min(len(outdata),20)], len(outdata), + indata[:min(len(indata),20)].lower(), len(indata))) + s.write("over\n") + if test_support.verbose: + sys.stdout.write(" client: closing connection.\n") + s.close() + finally: + server.stop() + # wait for server thread to end + server.join() + def test_main(verbose=False): if skip_expected: @@ -850,15 +1026,19 @@ def test_main(verbose=False): not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT)): raise test_support.TestFailed("Can't read certificate files!") + TESTPORT = test_support.find_unused_port() + if not TESTPORT: + raise test_support.TestFailed("Can't find open port to test servers on!") + tests = [BasicTests] if test_support.is_resource_enabled('network'): - tests.append(NetworkTests) + tests.append(NetworkedTests) if _have_threads: thread_info = test_support.threading_setup() if thread_info and test_support.is_resource_enabled('network'): - tests.append(ConnectedTests) + tests.append(ThreadedTests) test_support.run_unittest(*tests) |