diff options
author | Bill Janssen <janssen@parc.com> | 2007-11-15 22:23:56 (GMT) |
---|---|---|
committer | Bill Janssen <janssen@parc.com> | 2007-11-15 22:23:56 (GMT) |
commit | 6e027dba9339887feeb947fa409e18a6f44e210b (patch) | |
tree | c01007b372bbcc467bd00bb97e065a2ff1ed218c /Lib/test/test_ssl.py | |
parent | f83088aefedd3c6ee41171ec7c0b5b354df11e63 (diff) | |
download | cpython-6e027dba9339887feeb947fa409e18a6f44e210b.zip cpython-6e027dba9339887feeb947fa409e18a6f44e210b.tar.gz cpython-6e027dba9339887feeb947fa409e18a6f44e210b.tar.bz2 |
get SSL support to work again
Diffstat (limited to 'Lib/test/test_ssl.py')
-rw-r--r-- | Lib/test/test_ssl.py | 317 |
1 files changed, 171 insertions, 146 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 81c9c7a..18df3f4 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -4,6 +4,7 @@ import sys import unittest from test import test_support import socket +import select import errno import subprocess import time @@ -36,27 +37,6 @@ def handle_error(prefix): class BasicTests(unittest.TestCase): - def testSSLconnect(self): - import os - s = ssl.wrap_socket(socket.socket(socket.AF_INET), - cert_reqs=ssl.CERT_NONE) - s.connect(("svn.python.org", 443)) - c = s.getpeercert() - if c: - raise test_support.TestFailed("Peer cert %s shouldn't be here!") - s.close() - - # this should fail because we have no verification certs - s = ssl.wrap_socket(socket.socket(socket.AF_INET), - cert_reqs=ssl.CERT_REQUIRED) - try: - s.connect(("svn.python.org", 443)) - except ssl.SSLError: - pass - finally: - s.close() - - def testCrucialConstants(self): ssl.PROTOCOL_SSLv2 ssl.PROTOCOL_SSLv23 @@ -97,11 +77,31 @@ class BasicTests(unittest.TestCase): if (d1 != d2): raise test_support.TestFailed("PEM-to-DER or DER-to-PEM translation failed") +class NetworkedTests(unittest.TestCase): + + def testFetchServerCert(self): + + pem = ssl.get_server_certificate(("svn.python.org", 443)) + if not pem: + raise test_support.TestFailed("No server certificate on svn.python.org:443!") -class NetworkTests(unittest.TestCase): + try: + pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE) + except ssl.SSLError as x: + #should fail + if test_support.verbose: + sys.stdout.write("%s\n" % x) + else: + raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem) + + pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT) + if not pem: + raise test_support.TestFailed("No server certificate on svn.python.org:443!") + if test_support.verbose: + sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem) def testConnect(self): - import os + s = ssl.wrap_socket(socket.socket(socket.AF_INET), cert_reqs=ssl.CERT_NONE) s.connect(("svn.python.org", 443)) @@ -131,25 +131,29 @@ class NetworkTests(unittest.TestCase): finally: s.close() - def testFetchServerCert(self): - - pem = ssl.get_server_certificate(("svn.python.org", 443)) - if not pem: - raise test_support.TestFailed("No server certificate on svn.python.org:443!") - - try: - pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE) - except ssl.SSLError: - #should fail - pass - else: - raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem) - - pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT) - if not pem: - raise test_support.TestFailed("No server certificate on svn.python.org:443!") + def testNonBlockingHandshake(self): + s = socket.socket(socket.AF_INET) + s.connect(("svn.python.org", 443)) + s.setblocking(False) + s = ssl.wrap_socket(s, + cert_reqs=ssl.CERT_NONE, + do_handshake_on_connect=False) + count = 0 + while True: + try: + count += 1 + s.do_handshake() + break + except ssl.SSLError as err: + if err.args[0] == ssl.SSL_ERROR_WANT_READ: + select.select([s], [], []) + elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: + select.select([], [s], []) + else: + raise + s.close() if test_support.verbose: - sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem) + sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count) try: @@ -168,10 +172,11 @@ else: with and without the SSL wrapper around the socket connection, so that we can test the STARTTLS functionality.""" - def __init__(self, server, connsock): + def __init__(self, server, connsock, addr): self.server = server self.running = False self.sock = connsock + self.addr = addr self.sock.setblocking(1) self.sslconn = None threading.Thread.__init__(self) @@ -186,8 +191,7 @@ else: cert_reqs=self.server.certreqs) except: if self.server.chatty: - handle_error("\n server: bad connection attempt from " + - str(self.sock.getpeername()) + ":\n") + handle_error("\n server: bad connection attempt from " + repr(self.addr) + ":\n") if not self.server.expect_bad_connects: # here, we want to stop the server, because this shouldn't # happen in the context of our test case @@ -195,6 +199,7 @@ else: # normally, we'd just stop here, but for the test # harness, we want to stop the server self.server.stop() + self.close() return False else: @@ -236,19 +241,21 @@ else: while self.running: try: msg = self.read() + amsg = (msg and str(msg, 'ASCII', 'strict')) or '' if not msg: # eof, so quit this handler self.running = False self.close() - elif msg.strip() == 'over': + elif amsg.strip() == 'over': if test_support.verbose and self.server.connectionchatty: sys.stdout.write(" server: client closed connection\n") self.close() return - elif self.server.starttls_server and msg.strip() == 'STARTTLS': + elif (self.server.starttls_server and + amsg.strip() == 'STARTTLS'): if test_support.verbose and self.server.connectionchatty: sys.stdout.write(" server: read STARTTLS from client, sending OK...\n") - self.write("OK\n") + self.write("OK\n".encode("ASCII", "strict")) if not self.wrap_conn(): return else: @@ -257,8 +264,8 @@ else: ctype = (self.sslconn and "encrypted") or "unencrypted" sys.stdout.write(" server: read %s (%s), sending back %s (%s)...\n" % (repr(msg), ctype, repr(msg.lower()), ctype)) - self.write(msg.lower()) - except ssl.SSLError: + self.write(amsg.lower().encode('ASCII', 'strict')) + except socket.error: if self.server.chatty: handle_error("Test server failure:\n") self.close() @@ -311,8 +318,8 @@ else: newconn, connaddr = self.sock.accept() if test_support.verbose and self.chatty: sys.stdout.write(' server: new connection from ' - + str(connaddr) + '\n') - handler = self.ConnectionHandler(self, newconn) + + repr(connaddr) + '\n') + handler = self.ConnectionHandler(self, newconn, connaddr) handler.start() except socket.timeout: pass @@ -321,11 +328,10 @@ else: except: if self.chatty: handle_error("Test server failure:\n") + self.sock.close() def stop (self): self.active = False - self.sock.close() - class AsyncoreHTTPSServer(threading.Thread): @@ -339,6 +345,12 @@ else: self.active = False self.allow_reuse_address = True + def __str__(self): + return ('<%s %s:%s>' % + (self.__class__.__name__, + self.server_name, + self.server_port)) + def get_request (self): # override this to wrap socket with SSL sock, addr = self.socket.accept() @@ -415,8 +427,8 @@ else: # we override this to suppress logging unless "verbose" if test_support.verbose: - sys.stdout.write(" server (%s, %d, %s):\n [%s] %s\n" % - (self.server.server_name, + sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" % + (self.server.server_address, self.server.server_port, self.request.cipher(), self.log_date_time_string(), @@ -433,9 +445,7 @@ else: self.setDaemon(True) def __str__(self): - return '<%s %s:%d>' % (self.__class__.__name__, - self.server.server_name, - self.server.server_port) + return "<%s %s>" % (self.__class__.__name__, self.server) def start (self, flag=None): self.flag = flag @@ -456,7 +466,8 @@ else: def badCertTest (certfile): server = ThreadedEchoServer(TESTPORT, CERTFILE, certreqs=ssl.CERT_REQUIRED, - cacerts=CERTFILE, chatty=False) + cacerts=CERTFILE, chatty=False, + connectionchatty=False) flag = threading.Event() server.start(flag) # wait for it to start @@ -470,7 +481,7 @@ else: s.connect(('127.0.0.1', TESTPORT)) except ssl.SSLError as x: if test_support.verbose: - sys.stdout.write("\nSSLError is %s\n" % x[1]) + sys.stdout.write("\nSSLError is %s\n" % x) else: raise test_support.TestFailed( "Use of invalid cert should have failed!") @@ -479,15 +490,16 @@ else: server.join() def serverParamsTest (certfile, protocol, certreqs, cacertsfile, - client_certfile, client_protocol=None, indata="FOO\n", - chatty=True, connectionchatty=False): + client_certfile, client_protocol=None, + indata="FOO\n", + chatty=False, connectionchatty=False): server = ThreadedEchoServer(TESTPORT, certfile, certreqs=certreqs, ssl_version=protocol, cacerts=cacertsfile, chatty=chatty, - connectionchatty=connectionchatty) + connectionchatty=False) flag = threading.Event() server.start(flag) # wait for it to start @@ -496,37 +508,37 @@ else: if client_protocol is None: client_protocol = protocol try: - try: - s = ssl.wrap_socket(socket.socket(), - certfile=client_certfile, - ca_certs=cacertsfile, - cert_reqs=certreqs, - ssl_version=client_protocol) - s.connect(('127.0.0.1', TESTPORT)) - except ssl.SSLError as x: - raise test_support.TestFailed("Unexpected SSL error: " + str(x)) - except Exception as x: - raise test_support.TestFailed("Unexpected exception: " + str(x)) - else: - if connectionchatty: - if test_support.verbose: - sys.stdout.write( - " client: sending %s...\n" % (repr(indata))) - s.write(indata) - outdata = s.read() - if connectionchatty: - if test_support.verbose: - sys.stdout.write(" client: read %s\n" % repr(outdata)) - if outdata != indata.lower(): - raise test_support.TestFailed( - "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" - % (outdata[:min(len(outdata),20)], len(outdata), - indata[:min(len(indata),20)].lower(), len(indata))) - s.write("over\n") - if connectionchatty: - if test_support.verbose: - sys.stdout.write(" client: closing connection.\n") - s.close() + s = ssl.wrap_socket(socket.socket(), + certfile=client_certfile, + ca_certs=cacertsfile, + cert_reqs=certreqs, + ssl_version=client_protocol) + s.connect(('127.0.0.1', TESTPORT)) + except ssl.SSLError as x: + raise test_support.TestFailed("Unexpected SSL error: " + str(x)) + except Exception as x: + raise test_support.TestFailed("Unexpected exception: " + str(x)) + else: + if connectionchatty: + if test_support.verbose: + sys.stdout.write( + " client: sending %s...\n" % (repr(indata))) + s.write(indata.encode('ASCII', 'strict')) + outdata = s.read() + if connectionchatty: + if test_support.verbose: + sys.stdout.write(" client: read %s\n" % repr(outdata)) + outdata = str(outdata, 'ASCII', 'strict') + if outdata != indata.lower(): + raise test_support.TestFailed( + "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" + % (repr(outdata[:min(len(outdata),20)]), len(outdata), + repr(indata[:min(len(indata),20)].lower()), len(indata))) + s.write("over\n".encode("ASCII", "strict")) + if connectionchatty: + if test_support.verbose: + sys.stdout.write(" client: closing connection.\n") + s.close() finally: server.stop() server.join() @@ -553,7 +565,8 @@ else: certtype)) try: serverParamsTest(CERTFILE, server_protocol, certsreqs, - CERTFILE, CERTFILE, client_protocol, chatty=False) + CERTFILE, CERTFILE, client_protocol, + chatty=False, connectionchatty=False) except test_support.TestFailed: if expectedToWork: raise @@ -565,47 +578,7 @@ else: ssl.get_protocol_name(server_protocol))) - class ConnectedTests(unittest.TestCase): - - def testRudeShutdown(self): - - listener_ready = threading.Event() - listener_gone = threading.Event() - - # `listener` runs in a thread. It opens a socket listening on - # PORT, and sits in an accept() until the main thread connects. - # Then it rudely closes the socket, and sets Event `listener_gone` - # to let the main thread know the socket is gone. - def listener(): - s = socket.socket() - if hasattr(socket, 'SO_REUSEADDR'): - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if hasattr(socket, 'SO_REUSEPORT'): - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - s.bind(('127.0.0.1', TESTPORT)) - s.listen(5) - listener_ready.set() - s.accept() - s = None # reclaim the socket object, which also closes it - listener_gone.set() - - def connector(): - listener_ready.wait() - s = socket.socket() - s.connect(('127.0.0.1', TESTPORT)) - listener_gone.wait() - try: - ssl_sock = ssl.wrap_socket(s) - except socket.sslerror: - pass - else: - raise test_support.TestFailed( - 'connecting to closed SSL socket should have failed') - - t = threading.Thread(target=listener) - t.start() - connector() - t.join() + class ThreadedTests(unittest.TestCase): def testEcho (self): @@ -656,7 +629,7 @@ else: if test_support.verbose: sys.stdout.write(pprint.pformat(cert) + '\n') sys.stdout.write("Connection cipher is " + str(cipher) + '.\n') - if not cert.has_key('subject'): + if 'subject' not in cert: raise test_support.TestFailed( "No subject field in certificate: %s." % pprint.pformat(cert)) @@ -680,6 +653,46 @@ else: badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, "badkey.pem")) + def testRudeShutdown(self): + + listener_ready = threading.Event() + listener_gone = threading.Event() + + # `listener` runs in a thread. It opens a socket listening on + # PORT, and sits in an accept() until the main thread connects. + # Then it rudely closes the socket, and sets Event `listener_gone` + # to let the main thread know the socket is gone. + def listener(): + s = socket.socket() + if hasattr(socket, 'SO_REUSEADDR'): + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if hasattr(socket, 'SO_REUSEPORT'): + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + s.bind(('127.0.0.1', TESTPORT)) + s.listen(5) + listener_ready.set() + s.accept() + s = None # reclaim the socket object, which also closes it + listener_gone.set() + + def connector(): + listener_ready.wait() + s = socket.socket() + s.connect(('127.0.0.1', TESTPORT)) + listener_gone.wait() + try: + ssl_sock = ssl.wrap_socket(s) + except IOError: + pass + else: + raise test_support.TestFailed( + 'connecting to closed SSL socket should have failed') + + t = threading.Thread(target=listener) + t.start() + connector() + t.join() + def testProtocolSSL2(self): if test_support.verbose: sys.stdout.write("\n") @@ -759,39 +772,47 @@ else: if test_support.verbose: sys.stdout.write("\n") for indata in msgs: + msg = indata.encode('ASCII', 'replace') if test_support.verbose: sys.stdout.write( - " client: sending %s...\n" % repr(indata)) + " client: sending %s...\n" % repr(msg)) if wrapped: - conn.write(indata) + conn.write(msg) outdata = conn.read() else: - s.send(indata) + s.send(msg) outdata = s.recv(1024) if (indata == "STARTTLS" and - outdata.strip().lower().startswith("ok")): + str(outdata, 'ASCII', 'replace').strip().lower().startswith("ok")): if test_support.verbose: + msg = str(outdata, 'ASCII', 'replace') sys.stdout.write( " client: read %s from server, starting TLS...\n" - % repr(outdata)) + % repr(msg)) conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1) wrapped = True else: if test_support.verbose: + msg = str(outdata, 'ASCII', 'replace') sys.stdout.write( - " client: read %s from server\n" % repr(outdata)) + " client: read %s from server\n" % repr(msg)) if test_support.verbose: sys.stdout.write(" client: closing connection.\n") if wrapped: - conn.write("over\n") + conn.write("over\n".encode("ASCII", "strict")) else: s.send("over\n") + if wrapped: + conn.close() + else: s.close() finally: server.stop() server.join() + class AsyncoreTests(unittest.TestCase): + def testAsyncore(self): server = AsyncoreHTTPSServer(TESTPORT, CERTFILE) @@ -824,6 +845,8 @@ else: raise test_support.TestFailed(msg) else: if not (d1 == d2): + print("d1 is", len(d1), repr(d1)) + print("d2 is", len(d2), repr(d2)) raise test_support.TestFailed( "Couldn't fetch data from HTTPS server") finally: @@ -863,6 +886,7 @@ def test_main(verbose=False): if (not os.path.exists(CERTFILE) or not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT)): raise test_support.TestFailed("Can't read certificate files!") + TESTPORT = findtestsocket(10025, 12000) if not TESTPORT: raise test_support.TestFailed("Can't find open port to test servers on!") @@ -870,12 +894,13 @@ def test_main(verbose=False): tests = [BasicTests] if test_support.is_resource_enabled('network'): - tests.append(NetworkTests) + tests.append(NetworkedTests) if _have_threads: thread_info = test_support.threading_setup() if thread_info and test_support.is_resource_enabled('network'): - tests.append(ConnectedTests) + tests.append(ThreadedTests) + tests.append(AsyncoreTests) test_support.run_unittest(*tests) |