diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/test/test_ssl.py | 378 |
1 files changed, 188 insertions, 190 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 012b8bf..b03c45f 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -36,7 +36,7 @@ def handle_error(prefix): class BasicTests(unittest.TestCase): - def testSSLconnect(self): + def test_connect(self): if not support.is_resource_enabled('network'): return s = ssl.wrap_socket(socket.socket(socket.AF_INET), @@ -57,7 +57,7 @@ class BasicTests(unittest.TestCase): finally: s.close() - def testCrucialConstants(self): + def test_constants(self): ssl.PROTOCOL_SSLv2 ssl.PROTOCOL_SSLv23 ssl.PROTOCOL_SSLv3 @@ -66,7 +66,7 @@ class BasicTests(unittest.TestCase): ssl.CERT_OPTIONAL ssl.CERT_REQUIRED - def testRAND(self): + def test_random(self): v = ssl.RAND_status() if support.verbose: sys.stdout.write("\n RAND_status is %d (%s)\n" @@ -80,7 +80,7 @@ class BasicTests(unittest.TestCase): print("didn't raise TypeError") ssl.RAND_add("this is a random string", 75.0) - def testParseCert(self): + def test_parse_cert(self): # note that this uses an 'unofficial' function in _ssl.c, # provided solely for this test, to exercise the certificate # parsing code @@ -88,9 +88,9 @@ class BasicTests(unittest.TestCase): if support.verbose: sys.stdout.write("\n" + pprint.pformat(p) + "\n") - def testDERtoPEM(self): - - pem = open(SVN_PYTHON_ORG_ROOT_CERT, 'r').read() + def test_DER_to_PEM(self): + with open(SVN_PYTHON_ORG_ROOT_CERT, 'r') as f: + pem = f.read() d1 = ssl.PEM_cert_to_DER_cert(pem) p2 = ssl.DER_cert_to_PEM_cert(d1) d2 = ssl.PEM_cert_to_DER_cert(p2) @@ -166,7 +166,7 @@ class BasicTests(unittest.TestCase): class NetworkedTests(unittest.TestCase): - def testConnect(self): + def test_connect(self): s = ssl.wrap_socket(socket.socket(socket.AF_INET), cert_reqs=ssl.CERT_NONE) s.connect(("svn.python.org", 443)) @@ -213,7 +213,7 @@ class NetworkedTests(unittest.TestCase): os.read(fd, 0) self.assertEqual(e.exception.errno, errno.EBADF) - def testNonBlockingHandshake(self): + def test_non_blocking_handshake(self): s = socket.socket(socket.AF_INET) s.connect(("svn.python.org", 443)) s.setblocking(False) @@ -237,8 +237,7 @@ class NetworkedTests(unittest.TestCase): if support.verbose: sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count) - def testFetchServerCert(self): - + def test_get_server_certificate(self): pem = ssl.get_server_certificate(("svn.python.org", 443)) if not pem: self.fail("No server certificate on svn.python.org:443!") @@ -289,7 +288,6 @@ try: except ImportError: _have_threads = False else: - _have_threads = True class ThreadedEchoServer(threading.Thread): @@ -310,7 +308,7 @@ else: threading.Thread.__init__(self) self.daemon = True - def wrap_conn (self): + def wrap_conn(self): try: self.sslconn = ssl.wrap_socket(self.sock, server_side=True, certfile=self.server.certificate, @@ -359,7 +357,7 @@ else: else: self.sock.close() - def run (self): + def run(self): self.running = True if not self.server.starttls_server: if not self.wrap_conn(): @@ -367,28 +365,28 @@ else: while self.running: try: msg = self.read() - amsg = (msg and str(msg, 'ASCII', 'strict')) or '' - if not msg: + stripped = msg.strip() + if not stripped: # eof, so quit this handler self.running = False self.close() - elif amsg.strip() == 'over': + elif stripped == b'over': if support.verbose and self.server.connectionchatty: sys.stdout.write(" server: client closed connection\n") self.close() return elif (self.server.starttls_server and - amsg.strip() == 'STARTTLS'): + stripped == 'STARTTLS'): if support.verbose and self.server.connectionchatty: sys.stdout.write(" server: read STARTTLS from client, sending OK...\n") - self.write("OK\n".encode("ASCII", "strict")) + self.write(b"OK\n") if not self.wrap_conn(): return elif (self.server.starttls_server and self.sslconn - and amsg.strip() == 'ENDTLS'): + and stripped == 'ENDTLS'): if support.verbose and self.server.connectionchatty: sys.stdout.write(" server: read ENDTLS from client, sending OK...\n") - self.write("OK\n".encode("ASCII", "strict")) + self.write(b"OK\n") self.sock = self.sslconn.unwrap() self.sslconn = None if support.verbose and self.server.connectionchatty: @@ -397,9 +395,9 @@ else: if (support.verbose and self.server.connectionchatty): ctype = (self.sslconn and "encrypted") or "unencrypted" - sys.stdout.write(" server: read %s (%s), sending back %s (%s)...\n" - % (repr(msg), ctype, repr(msg.lower()), ctype)) - self.write(amsg.lower().encode('ASCII', 'strict')) + sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n" + % (msg, ctype, msg.lower(), ctype)) + self.write(msg.lower()) except socket.error: if self.server.chatty: handle_error("Test server failure:\n") @@ -432,11 +430,11 @@ else: threading.Thread.__init__(self) self.daemon = True - def start (self, flag=None): + def start(self, flag=None): self.flag = flag threading.Thread.start(self) - def run (self): + def run(self): self.sock.settimeout(0.05) self.sock.listen(5) self.active = True @@ -457,7 +455,7 @@ else: self.stop() self.sock.close() - def stop (self): + def stop(self): self.active = False class OurHTTPSServer(threading.Thread): @@ -467,12 +465,9 @@ else: class HTTPSServer(HTTPServer): def __init__(self, server_address, RequestHandlerClass, certfile): - HTTPServer.__init__(self, server_address, RequestHandlerClass) # we assume the certfile contains both private key and certificate self.certfile = certfile - self.active = False - self.active_lock = threading.Lock() self.allow_reuse_address = True def __str__(self): @@ -481,7 +476,7 @@ else: self.server_name, self.server_port)) - def get_request (self): + def get_request(self): # override this to wrap socket with SSL sock, addr = self.socket.accept() sslconn = ssl.wrap_socket(sock, server_side=True, @@ -489,7 +484,6 @@ else: return sslconn, addr class RootedHTTPRequestHandler(SimpleHTTPRequestHandler): - # need to override translate_path to get a known root, # instead of using os.curdir, since the test could be # run from anywhere @@ -520,7 +514,6 @@ else: return path def log_message(self, format, *args): - # we override this to suppress logging unless "verbose" if support.verbose: @@ -534,7 +527,6 @@ else: def __init__(self, certfile): self.flag = None - self.active = False self.RootedHTTPRequestHandler.root = os.path.split(CERTFILE)[0] self.server = self.HTTPSServer( (HOST, 0), self.RootedHTTPRequestHandler, certfile) @@ -545,19 +537,16 @@ else: def __str__(self): return "<%s %s>" % (self.__class__.__name__, self.server) - def start (self, flag=None): + def start(self, flag=None): self.flag = flag threading.Thread.start(self) - def run (self): - self.active = True + def run(self): if self.flag: self.flag.set() self.server.serve_forever(0.05) - self.active = False - def stop (self): - self.active = False + def stop(self): self.server.shutdown() @@ -609,7 +598,7 @@ else: if not data: self.close() else: - self.send(str(data, 'ASCII', 'strict').lower().encode('ASCII', 'strict')) + self.send(data.lower()) def handle_close(self): self.close() @@ -650,7 +639,7 @@ else: self.flag = flag threading.Thread.start(self) - def run (self): + def run(self): self.active = True if self.flag: self.flag.set() @@ -660,11 +649,15 @@ else: except: pass - def stop (self): + def stop(self): self.active = False self.server.close() - def badCertTest (certfile): + def bad_cert_test(certfile): + """ + Launch a server with CERT_REQUIRED, and check that trying to + connect to it with the given client certificate fails. + """ server = ThreadedEchoServer(CERTFILE, certreqs=ssl.CERT_REQUIRED, cacerts=CERTFILE, chatty=False, @@ -684,7 +677,7 @@ else: if support.verbose: sys.stdout.write("\nSSLError is %s\n" % x.args[1]) except socket.error as x: - if test_support.verbose: + if support.verbose: sys.stdout.write("\nsocket.error is %s\n" % x[1]) else: self.fail("Use of invalid cert should have failed!") @@ -692,11 +685,13 @@ else: server.stop() server.join() - def serverParamsTest (certfile, protocol, certreqs, cacertsfile, - client_certfile, client_protocol=None, - indata="FOO\n", - ciphers=None, chatty=False, connectionchatty=False): - + def server_params_test(certfile, protocol, certreqs, cacertsfile, + client_certfile, client_protocol=None, indata=b"FOO\n", + ciphers=None, chatty=True, connectionchatty=False): + """ + Launch a server, connect a client to it and try various reads + and writes. + """ server = ThreadedEchoServer(certfile, certreqs=certreqs, ssl_version=protocol, @@ -719,24 +714,22 @@ else: cert_reqs=certreqs, ssl_version=client_protocol) s.connect((HOST, server.port)) - bindata = indata.encode('ASCII', 'strict') - for arg in [bindata, bytearray(bindata), memoryview(bindata)]: + for arg in [indata, bytearray(indata), memoryview(indata)]: if connectionchatty: if support.verbose: sys.stdout.write( - " client: sending %s...\n" % (repr(indata))) + " client: sending %r...\n" % indata) s.write(arg) outdata = s.read() if connectionchatty: if support.verbose: - sys.stdout.write(" client: read %s\n" % repr(outdata)) - outdata = str(outdata, 'ASCII', 'strict') + sys.stdout.write(" client: read %r\n" % outdata) if outdata != indata.lower(): self.fail( - "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" - % (repr(outdata[:min(len(outdata),20)]), len(outdata), - repr(indata[:min(len(indata),20)].lower()), len(indata))) - s.write("over\n".encode("ASCII", "strict")) + "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n" + % (outdata[:20], len(outdata), + indata[:20].lower(), len(indata))) + s.write(b"over\n") if connectionchatty: if support.verbose: sys.stdout.write(" client: closing connection.\n") @@ -745,22 +738,19 @@ else: server.stop() server.join() - def tryProtocolCombo (server_protocol, - client_protocol, - expectedToWork, - certsreqs=None): - + def try_protocol_combo(server_protocol, + client_protocol, + expect_success, + certsreqs=None): if certsreqs is None: certsreqs = ssl.CERT_NONE - - if certsreqs == ssl.CERT_NONE: - certtype = "CERT_NONE" - elif certsreqs == ssl.CERT_OPTIONAL: - certtype = "CERT_OPTIONAL" - elif certsreqs == ssl.CERT_REQUIRED: - certtype = "CERT_REQUIRED" + certtype = { + ssl.CERT_NONE: "CERT_NONE", + ssl.CERT_OPTIONAL: "CERT_OPTIONAL", + ssl.CERT_REQUIRED: "CERT_REQUIRED", + }[certsreqs] if support.verbose: - formatstr = (expectedToWork and " %s->%s %s\n") or " {%s->%s} %s\n" + formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n" sys.stdout.write(formatstr % (ssl.get_protocol_name(client_protocol), ssl.get_protocol_name(server_protocol), @@ -769,20 +759,20 @@ else: # NOTE: we must enable "ALL" ciphers, otherwise an SSLv23 client # will send an SSLv3 hello (rather than SSLv2) starting from # OpenSSL 1.0.0 (see issue #8322). - serverParamsTest(CERTFILE, server_protocol, certsreqs, - CERTFILE, CERTFILE, client_protocol, - ciphers="ALL", - chatty=False, connectionchatty=False) + server_params_test(CERTFILE, server_protocol, certsreqs, + CERTFILE, CERTFILE, client_protocol, + ciphers="ALL", chatty=False, + connectionchatty=False) # Protocol mismatch can result in either an SSLError, or a # "Connection reset by peer" error. except ssl.SSLError: - if expectedToWork: + if expect_success: raise except socket.error as e: - if expectedToWork or e.errno != errno.ECONNRESET: + if expect_success or e.errno != errno.ECONNRESET: raise else: - if not expectedToWork: + if not expect_success: self.fail( "Client protocol %s succeeded with server protocol %s!" % (ssl.get_protocol_name(client_protocol), @@ -791,16 +781,15 @@ else: class ThreadedTests(unittest.TestCase): - def testEcho (self): - + def test_echo(self): + """Basic test of an SSL client connecting to a server""" if support.verbose: sys.stdout.write("\n") - serverParamsTest(CERTFILE, ssl.PROTOCOL_TLSv1, ssl.CERT_NONE, - CERTFILE, CERTFILE, ssl.PROTOCOL_TLSv1, - chatty=True, connectionchatty=True) - - def testReadCert(self): + server_params_test(CERTFILE, ssl.PROTOCOL_TLSv1, ssl.CERT_NONE, + CERTFILE, CERTFILE, ssl.PROTOCOL_TLSv1, + chatty=True, connectionchatty=True) + def test_getpeercert(self): if support.verbose: sys.stdout.write("\n") s2 = socket.socket() @@ -840,23 +829,30 @@ else: server.stop() server.join() - def testNULLcert(self): - badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, - "nullcert.pem")) - def testMalformedCert(self): - badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, - "badcert.pem")) - def testWrongCert(self): - badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, - "wrongcert.pem")) - def testMalformedKey(self): - badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir, - "badkey.pem")) - - def testRudeShutdown(self): - + def test_empty_cert(self): + """Connecting with an empty cert file""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "nullcert.pem")) + def test_malformed_cert(self): + """Connecting with a badly formatted certificate (syntax error)""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "badcert.pem")) + def test_nonexisting_cert(self): + """Connecting with a non-existing cert file""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "wrongcert.pem")) + def test_malformed_key(self): + """Connecting with a badly formatted key (syntax error)""" + bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, + "badkey.pem")) + + def test_rude_shutdown(self): + """A brutal shutdown of an SSL server should raise an IOError + in the client when attempting handshake. + """ listener_ready = threading.Event() listener_gone = threading.Event() + s = socket.socket() port = support.bind_port(s, HOST) @@ -890,62 +886,66 @@ else: finally: t.join() - def testProtocolSSL2(self): + def test_protocol_sslv2(self): + """Connecting to an SSLv2 server with various client options""" if support.verbose: sys.stdout.write("\n") - tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True) - tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL) - tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED) - tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True) - tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False) - tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False) - - def testProtocolSSL23(self): + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False) + + def test_protocol_sslv23(self): + """Connecting to an SSLv23 server with various client options""" if support.verbose: sys.stdout.write("\n") try: - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True) except (ssl.SSLError, socket.error) as x: # this fails on some older versions of OpenSSL (0.9.7l, for instance) if support.verbose: sys.stdout.write( " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n" % str(x)) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) - tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) - def testProtocolSSL3(self): + def test_protocol_sslv3(self): + """Connecting to an SSLv3 server with various client options""" if support.verbose: sys.stdout.write("\n") - tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True) - tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) - tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) - tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) - tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False) - tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) - - def testProtocolTLS1(self): + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) + + def test_protocol_tlsv1(self): + """Connecting to a TLSv1 server with various client options""" if support.verbose: sys.stdout.write("\n") - tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True) - tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) - tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) - tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) - tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False) - tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False) - - def testSTARTTLS (self): + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False) - msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", "msg 5", "msg 6") + def test_starttls(self): + """Switching from clear text to encrypted and back again.""" + msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6") server = ThreadedEchoServer(CERTFILE, ssl_version=ssl.PROTOCOL_TLSv1, @@ -965,45 +965,42 @@ else: if support.verbose: sys.stdout.write("\n") for indata in msgs: - msg = indata.encode('ASCII', 'replace') if support.verbose: sys.stdout.write( - " client: sending %s...\n" % repr(msg)) + " client: sending %r...\n" % indata) if wrapped: - conn.write(msg) + conn.write(indata) outdata = conn.read() else: - s.send(msg) + s.send(indata) outdata = s.recv(1024) - if (indata == "STARTTLS" and - str(outdata, 'ASCII', 'replace').strip().lower().startswith("ok")): + msg = outdata.strip().lower() + if indata == b"STARTTLS" and msg.startswith(b"ok"): + # STARTTLS ok, switch to secure mode if support.verbose: - msg = str(outdata, 'ASCII', 'replace') sys.stdout.write( - " client: read %s from server, starting TLS...\n" - % repr(msg)) + " client: read %r from server, starting TLS...\n" + % msg) conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1) wrapped = True - elif (indata == "ENDTLS" and - str(outdata, 'ASCII', 'replace').strip().lower().startswith("ok")): + elif indata == b"ENDTLS" and msg.startswith(b"ok"): + # ENDTLS ok, switch back to clear text if support.verbose: - msg = str(outdata, 'ASCII', 'replace') sys.stdout.write( - " client: read %s from server, ending TLS...\n" - % repr(msg)) + " client: read %r from server, ending TLS...\n" + % msg) s = conn.unwrap() wrapped = False else: if support.verbose: - msg = str(outdata, 'ASCII', 'replace') sys.stdout.write( - " client: read %s from server\n" % repr(msg)) + " client: read %r from server\n" % msg) if support.verbose: sys.stdout.write(" client: closing connection.\n") if wrapped: - conn.write("over\n".encode("ASCII", "strict")) + conn.write(b"over\n") else: - s.send("over\n".encode("ASCII", "strict")) + s.send(b"over\n") if wrapped: conn.close() else: @@ -1012,8 +1009,8 @@ else: server.stop() server.join() - def testSocketServer(self): - + def test_socketserver(self): + """Using a SocketServer to create and manage SSL connections.""" server = OurHTTPSServer(CERTFILE) flag = threading.Event() server.start(flag) @@ -1023,7 +1020,8 @@ else: try: if support.verbose: sys.stdout.write('\n') - d1 = open(CERTFILE, 'rb').read() + with open(CERTFILE, 'rb') as f: + d1 = f.read() d2 = '' # now fetch the same data from the HTTPS server url = 'https://%s:%d/%s' % ( @@ -1046,12 +1044,14 @@ else: sys.stdout.write('joining thread\n') server.join() - def testAsyncoreServer(self): + def test_asyncore_server(self): + """Check the example asyncore integration.""" + indata = "TEST MESSAGE of mixed case\n" if support.verbose: sys.stdout.write("\n") - indata="FOO\n" + indata = b"FOO\n" server = AsyncoreEchoServer(CERTFILE) flag = threading.Event() server.start(flag) @@ -1063,18 +1063,17 @@ else: s.connect(('127.0.0.1', server.port)) if support.verbose: sys.stdout.write( - " client: sending %s...\n" % (repr(indata))) - s.write(indata.encode('ASCII', 'strict')) + " client: sending %r...\n" % indata) + s.write(indata) outdata = s.read() if support.verbose: - sys.stdout.write(" client: read %s\n" % repr(outdata)) - outdata = str(outdata, 'ASCII', 'strict') + sys.stdout.write(" client: read %r\n" % outdata) if outdata != indata.lower(): self.fail( - "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" - % (outdata[:min(len(outdata),20)], len(outdata), - indata[:min(len(indata),20)].lower(), len(indata))) - s.write("over\n".encode("ASCII", "strict")) + "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n" + % (outdata[:20], len(outdata), + indata[:20].lower(), len(indata))) + s.write(b"over\n") if support.verbose: sys.stdout.write(" client: closing connection.\n") s.close() @@ -1082,8 +1081,8 @@ else: server.stop() server.join() - def testAllRecvAndSendMethods(self): - + def test_recv_send(self): + """Test recv(), send() and friends.""" if support.verbose: sys.stdout.write("\n") @@ -1132,19 +1131,18 @@ else: data_prefix = "PREFIX_" for meth_name, send_meth, expect_success, args in send_methods: - indata = data_prefix + meth_name + indata = (data_prefix + meth_name).encode('ascii') try: - send_meth(indata.encode('ASCII', 'strict'), *args) + send_meth(indata, *args) outdata = s.read() - outdata = str(outdata, 'ASCII', 'strict') if outdata != indata.lower(): self.fail( "While sending with <<{name:s}>> bad data " - "<<{outdata:s}>> ({nout:d}) received; " - "expected <<{indata:s}>> ({nin:d})\n".format( - name=meth_name, outdata=repr(outdata[:20]), + "<<{outdata:r}>> ({nout:d}) received; " + "expected <<{indata:r}>> ({nin:d})\n".format( + name=meth_name, outdata=outdata[:20], nout=len(outdata), - indata=repr(indata[:20]), nin=len(indata) + indata=indata[:20], nin=len(indata) ) ) except ValueError as e: @@ -1162,19 +1160,18 @@ else: ) for meth_name, recv_meth, expect_success, args in recv_methods: - indata = data_prefix + meth_name + indata = (data_prefix + meth_name).encode('ascii') try: - s.send(indata.encode('ASCII', 'strict')) + s.send(indata) outdata = recv_meth(*args) - outdata = str(outdata, 'ASCII', 'strict') if outdata != indata.lower(): self.fail( "While receiving with <<{name:s}>> bad data " - "<<{outdata:s}>> ({nout:d}) received; " - "expected <<{indata:s}>> ({nin:d})\n".format( - name=meth_name, outdata=repr(outdata[:20]), + "<<{outdata:r}>> ({nout:d}) received; " + "expected <<{indata:r}>> ({nin:d})\n".format( + name=meth_name, outdata=outdata[:20], nout=len(outdata), - indata=repr(indata[:20]), nin=len(indata) + indata=indata[:20], nin=len(indata) ) ) except ValueError as e: @@ -1193,7 +1190,7 @@ else: # consume data s.read() - s.write("over\n".encode("ASCII", "strict")) + s.write(b"over\n") s.close() finally: server.stop() @@ -1272,10 +1269,11 @@ def test_main(verbose=False): if thread_info and support.is_resource_enabled('network'): tests.append(ThreadedTests) - support.run_unittest(*tests) - - if _have_threads: - support.threading_cleanup(*thread_info) + try: + support.run_unittest(*tests) + finally: + if _have_threads: + support.threading_cleanup(*thread_info) if __name__ == "__main__": test_main() |