diff options
Diffstat (limited to 'Lib/test/test_ssl.py')
-rw-r--r-- | Lib/test/test_ssl.py | 423 |
1 files changed, 371 insertions, 52 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 1c4aa7c..75dc202 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -6,6 +6,7 @@ from test import support import socket import select import time +import datetime import gc import os import errno @@ -17,16 +18,11 @@ import asyncore import weakref import platform import functools +from unittest import mock ssl = support.import_module("ssl") -PROTOCOLS = [ - ssl.PROTOCOL_SSLv3, - ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1 -] -if hasattr(ssl, 'PROTOCOL_SSLv2'): - PROTOCOLS.append(ssl.PROTOCOL_SSLv2) - +PROTOCOLS = sorted(ssl._PROTOCOL_NAMES) HOST = support.HOST data_file = lambda name: os.path.join(os.path.dirname(__file__), name) @@ -48,6 +44,11 @@ KEY_PASSWORD = "somepass" CAPATH = data_file("capath") BYTES_CAPATH = os.fsencode(CAPATH) +# Two keys and certs signed by the same CA (for SNI tests) +SIGNED_CERTFILE = data_file("keycert3.pem") +SIGNED_CERTFILE2 = data_file("keycert4.pem") +SIGNING_CA = data_file("pycacert.pem") + SVN_PYTHON_ORG_ROOT_CERT = data_file("https_svn_python_org_root.pem") EMPTYCERT = data_file("nullcert.pem") @@ -59,6 +60,7 @@ NOKIACERT = data_file("nokia.pem") DHFILE = data_file("dh512.pem") BYTES_DHFILE = os.fsencode(DHFILE) + def handle_error(prefix): exc_format = ' '.join(traceback.format_exception(*sys.exc_info())) if support.verbose: @@ -72,6 +74,19 @@ def no_sslv2_implies_sslv3_hello(): # 0.9.7h or higher return ssl.OPENSSL_VERSION_INFO >= (0, 9, 7, 8, 15) +def asn1time(cert_time): + # Some versions of OpenSSL ignore seconds, see #18207 + # 0.9.8.i + if ssl._OPENSSL_API_VERSION == (0, 9, 8, 9, 15): + fmt = "%b %d %H:%M:%S %Y GMT" + dt = datetime.datetime.strptime(cert_time, fmt) + dt = dt.replace(second=0) + cert_time = dt.strftime(fmt) + # %d adds leading zero but ASN1_TIME_print() uses leading space + if cert_time[4] == "0": + cert_time = cert_time[:4] + " " + cert_time[5:] + + return cert_time # Issue #9415: Ubuntu hijacks their OpenSSL and forcefully disables SSLv2 def skip_if_broken_ubuntu_ssl(func): @@ -89,14 +104,12 @@ def skip_if_broken_ubuntu_ssl(func): else: return func +needs_sni = unittest.skipUnless(ssl.HAS_SNI, "SNI support needed for this test") + class BasicSocketTests(unittest.TestCase): def test_constants(self): - #ssl.PROTOCOL_SSLv2 - ssl.PROTOCOL_SSLv23 - ssl.PROTOCOL_SSLv3 - ssl.PROTOCOL_TLSv1 ssl.CERT_NONE ssl.CERT_OPTIONAL ssl.CERT_REQUIRED @@ -142,8 +155,9 @@ class BasicSocketTests(unittest.TestCase): (('organizationName', 'Python Software Foundation'),), (('commonName', 'localhost'),)) ) - self.assertEqual(p['notAfter'], 'Oct 5 23:01:56 2020 GMT') - self.assertEqual(p['notBefore'], 'Oct 8 23:01:56 2010 GMT') + # Note the next three asserts will fail if the keys are regenerated + self.assertEqual(p['notAfter'], asn1time('Oct 5 23:01:56 2020 GMT')) + self.assertEqual(p['notBefore'], asn1time('Oct 8 23:01:56 2010 GMT')) self.assertEqual(p['serialNumber'], 'D7C7381919AFC24E') self.assertEqual(p['subject'], ((('countryName', 'XY'),), @@ -214,15 +228,15 @@ class BasicSocketTests(unittest.TestCase): def test_wrapped_unconnected(self): # Methods on an unconnected SSLSocket propagate the original - # socket.error raise by the underlying socket object. + # OSError raise by the underlying socket object. s = socket.socket(socket.AF_INET) with ssl.wrap_socket(s) as ss: - self.assertRaises(socket.error, ss.recv, 1) - self.assertRaises(socket.error, ss.recv_into, bytearray(b'x')) - self.assertRaises(socket.error, ss.recvfrom, 1) - self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1) - self.assertRaises(socket.error, ss.send, b'x') - self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0)) + self.assertRaises(OSError, ss.recv, 1) + self.assertRaises(OSError, ss.recv_into, bytearray(b'x')) + self.assertRaises(OSError, ss.recvfrom, 1) + self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1) + self.assertRaises(OSError, ss.send, b'x') + self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0)) def test_timeout(self): # Issue #8524: when creating an SSL socket, the timeout of the @@ -247,15 +261,15 @@ class BasicSocketTests(unittest.TestCase): with ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE) as s: self.assertRaisesRegex(ValueError, "can't connect in server-side mode", s.connect, (HOST, 8080)) - with self.assertRaises(IOError) as cm: + with self.assertRaises(OSError) as cm: with socket.socket() as sock: ssl.wrap_socket(sock, certfile=WRONGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) - with self.assertRaises(IOError) as cm: + with self.assertRaises(OSError) as cm: with socket.socket() as sock: ssl.wrap_socket(sock, certfile=CERTFILE, keyfile=WRONGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) - with self.assertRaises(IOError) as cm: + with self.assertRaises(OSError) as cm: with socket.socket() as sock: ssl.wrap_socket(sock, certfile=WRONGCERT, keyfile=WRONGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) @@ -394,15 +408,48 @@ class BasicSocketTests(unittest.TestCase): support.gc_collect() self.assertIn(r, str(cm.warning.args[0])) + def test_get_default_verify_paths(self): + paths = ssl.get_default_verify_paths() + self.assertEqual(len(paths), 6) + self.assertIsInstance(paths, ssl.DefaultVerifyPaths) + + with support.EnvironmentVarGuard() as env: + env["SSL_CERT_DIR"] = CAPATH + env["SSL_CERT_FILE"] = CERTFILE + paths = ssl.get_default_verify_paths() + self.assertEqual(paths.cafile, CERTFILE) + self.assertEqual(paths.capath, CAPATH) + + + @unittest.skipUnless(sys.platform == "win32", "Windows specific") + def test_enum_cert_store(self): + self.assertEqual(ssl.X509_ASN_ENCODING, 1) + self.assertEqual(ssl.PKCS_7_ASN_ENCODING, 0x00010000) + + self.assertEqual(ssl.enum_cert_store("CA"), + ssl.enum_cert_store("CA", "certificate")) + ssl.enum_cert_store("CA", "crl") + self.assertEqual(ssl.enum_cert_store("ROOT"), + ssl.enum_cert_store("ROOT", "certificate")) + ssl.enum_cert_store("ROOT", "crl") + + self.assertRaises(TypeError, ssl.enum_cert_store) + self.assertRaises(WindowsError, ssl.enum_cert_store, "") + self.assertRaises(ValueError, ssl.enum_cert_store, "CA", "wrong") + + ca = ssl.enum_cert_store("CA") + self.assertIsInstance(ca, list) + self.assertIsInstance(ca[0], tuple) + self.assertEqual(len(ca[0]), 2) + self.assertIsInstance(ca[0][0], bytes) + self.assertIsInstance(ca[0][1], int) + class ContextTests(unittest.TestCase): @skip_if_broken_ubuntu_ssl def test_constructor(self): - if hasattr(ssl, 'PROTOCOL_SSLv2'): - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv2) - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv3) - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + for protocol in PROTOCOLS: + ssl.SSLContext(protocol) self.assertRaises(TypeError, ssl.SSLContext) self.assertRaises(ValueError, ssl.SSLContext, -1) self.assertRaises(ValueError, ssl.SSLContext, 42) @@ -462,7 +509,7 @@ class ContextTests(unittest.TestCase): ctx.load_cert_chain(CERTFILE) ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE) self.assertRaises(TypeError, ctx.load_cert_chain, keyfile=CERTFILE) - with self.assertRaises(IOError) as cm: + with self.assertRaises(OSError) as cm: ctx.load_cert_chain(WRONGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) with self.assertRaisesRegex(ssl.SSLError, "PEM lib"): @@ -547,7 +594,7 @@ class ContextTests(unittest.TestCase): ctx.load_verify_locations(cafile=BYTES_CERTFILE, capath=None) self.assertRaises(TypeError, ctx.load_verify_locations) self.assertRaises(TypeError, ctx.load_verify_locations, None, None) - with self.assertRaises(IOError) as cm: + with self.assertRaises(OSError) as cm: ctx.load_verify_locations(WRONGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) with self.assertRaisesRegex(ssl.SSLError, "PEM lib"): @@ -605,6 +652,75 @@ class ContextTests(unittest.TestCase): self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo") self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo") + @needs_sni + def test_sni_callback(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + + # set_servername_callback expects a callable, or None + self.assertRaises(TypeError, ctx.set_servername_callback) + self.assertRaises(TypeError, ctx.set_servername_callback, 4) + self.assertRaises(TypeError, ctx.set_servername_callback, "") + self.assertRaises(TypeError, ctx.set_servername_callback, ctx) + + def dummycallback(sock, servername, ctx): + pass + ctx.set_servername_callback(None) + ctx.set_servername_callback(dummycallback) + + @needs_sni + def test_sni_callback_refcycle(self): + # Reference cycles through the servername callback are detected + # and cleared. + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + def dummycallback(sock, servername, ctx, cycle=ctx): + pass + ctx.set_servername_callback(dummycallback) + wr = weakref.ref(ctx) + del ctx, dummycallback + gc.collect() + self.assertIs(wr(), None) + + def test_cert_store_stats(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.cert_store_stats(), + {'x509_ca': 0, 'crl': 0, 'x509': 0}) + ctx.load_cert_chain(CERTFILE) + self.assertEqual(ctx.cert_store_stats(), + {'x509_ca': 0, 'crl': 0, 'x509': 0}) + ctx.load_verify_locations(CERTFILE) + self.assertEqual(ctx.cert_store_stats(), + {'x509_ca': 0, 'crl': 0, 'x509': 1}) + ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT) + self.assertEqual(ctx.cert_store_stats(), + {'x509_ca': 1, 'crl': 0, 'x509': 2}) + + def test_get_ca_certs(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.get_ca_certs(), []) + # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE + ctx.load_verify_locations(CERTFILE) + self.assertEqual(ctx.get_ca_certs(), []) + # but SVN_PYTHON_ORG_ROOT_CERT is a CA cert + ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT) + self.assertEqual(ctx.get_ca_certs(), + [{'issuer': ((('organizationName', 'Root CA'),), + (('organizationalUnitName', 'http://www.cacert.org'),), + (('commonName', 'CA Cert Signing Authority'),), + (('emailAddress', 'support@cacert.org'),)), + 'notAfter': asn1time('Mar 29 12:29:49 2033 GMT'), + 'notBefore': asn1time('Mar 30 12:29:49 2003 GMT'), + 'serialNumber': '00', + 'subject': ((('organizationName', 'Root CA'),), + (('organizationalUnitName', 'http://www.cacert.org'),), + (('commonName', 'CA Cert Signing Authority'),), + (('emailAddress', 'support@cacert.org'),)), + 'version': 3}]) + + with open(SVN_PYTHON_ORG_ROOT_CERT) as f: + pem = f.read() + der = ssl.PEM_cert_to_DER_cert(pem) + self.assertEqual(ctx.get_ca_certs(True), [der]) + class SSLErrorTests(unittest.TestCase): @@ -920,6 +1036,22 @@ class NetworkedTests(unittest.TestCase): finally: s.close() + def test_get_ca_certs_capath(self): + # capath certs are loaded on request + with support.transient_internet("svn.python.org"): + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_verify_locations(capath=CAPATH) + self.assertEqual(ctx.get_ca_certs(), []) + s = ctx.wrap_socket(socket.socket(socket.AF_INET)) + s.connect(("svn.python.org", 443)) + try: + cert = s.getpeercert() + self.assertTrue(cert) + finally: + s.close() + self.assertEqual(len(ctx.get_ca_certs()), 1) + try: import threading @@ -1047,7 +1179,7 @@ else: sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n" % (msg, ctype, msg.lower(), ctype)) self.write(msg.lower()) - except socket.error: + except OSError: if self.server.chatty: handle_error("Test server failure:\n") self.close() @@ -1157,7 +1289,7 @@ else: return self.handle_close() except ssl.SSLError: raise - except socket.error as err: + except OSError as err: if err.args[0] == errno.ECONNABORTED: return self.handle_close() else: @@ -1261,19 +1393,19 @@ else: except ssl.SSLError as x: if support.verbose: sys.stdout.write("\nSSLError is %s\n" % x.args[1]) - except socket.error as x: + except OSError as x: if support.verbose: - sys.stdout.write("\nsocket.error is %s\n" % x.args[1]) - except IOError as x: + sys.stdout.write("\nOSError is %s\n" % x.args[1]) + except OSError as x: if x.errno != errno.ENOENT: raise if support.verbose: - sys.stdout.write("\IOError is %s\n" % str(x)) + sys.stdout.write("\OSError is %s\n" % str(x)) else: raise AssertionError("Use of invalid cert should have failed!") def server_params_test(client_context, server_context, indata=b"FOO\n", - chatty=True, connectionchatty=False): + chatty=True, connectionchatty=False, sni_name=None): """ Launch a server, connect a client to it and try various reads and writes. @@ -1283,7 +1415,8 @@ else: chatty=chatty, connectionchatty=False) with server: - with client_context.wrap_socket(socket.socket()) as s: + with client_context.wrap_socket(socket.socket(), + server_hostname=sni_name) as s: s.connect((HOST, server.port)) for arg in [indata, bytearray(indata), memoryview(indata)]: if connectionchatty: @@ -1307,6 +1440,7 @@ else: stats.update({ 'compression': s.compression(), 'cipher': s.cipher(), + 'peercert': s.getpeercert(), 'client_npn_protocol': s.selected_npn_protocol() }) s.close() @@ -1332,12 +1466,15 @@ else: client_context.options = ssl.OP_ALL | client_options server_context = ssl.SSLContext(server_protocol) server_context.options = ssl.OP_ALL | server_options + + # NOTE: we must enable "ALL" ciphers on the client, otherwise an + # SSLv23 client will send an SSLv3 hello (rather than SSLv2) + # starting from OpenSSL 1.0.0 (see issue #8322). + if client_context.protocol == ssl.PROTOCOL_SSLv23: + client_context.set_ciphers("ALL") + for ctx in (client_context, server_context): ctx.verify_mode = certsreqs - # NOTE: we must enable "ALL" ciphers, otherwise an SSLv23 client - # will send an SSLv3 hello (rather than SSLv2) starting from - # OpenSSL 1.0.0 (see issue #8322). - ctx.set_ciphers("ALL") ctx.load_cert_chain(CERTFILE) ctx.load_verify_locations(CERTFILE) try: @@ -1348,7 +1485,7 @@ else: except ssl.SSLError: if expect_success: raise - except socket.error as e: + except OSError as e: if expect_success or e.errno != errno.ECONNRESET: raise else: @@ -1367,10 +1504,11 @@ else: if support.verbose: sys.stdout.write("\n") for protocol in PROTOCOLS: - context = ssl.SSLContext(protocol) - context.load_cert_chain(CERTFILE) - server_params_test(context, context, - chatty=True, connectionchatty=True) + with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]): + context = ssl.SSLContext(protocol) + context.load_cert_chain(CERTFILE) + server_params_test(context, context, + chatty=True, connectionchatty=True) def test_getpeercert(self): if support.verbose: @@ -1422,7 +1560,7 @@ else: "badkey.pem")) def test_rude_shutdown(self): - """A brutal shutdown of an SSL server should raise an IOError + """A brutal shutdown of an SSL server should raise an OSError in the client when attempting handshake. """ listener_ready = threading.Event() @@ -1450,7 +1588,7 @@ else: listener_gone.wait() try: ssl_sock = ssl.wrap_socket(c) - except IOError: + except OSError: pass else: self.fail('connecting to closed SSL socket should have failed') @@ -1493,7 +1631,7 @@ else: if hasattr(ssl, 'PROTOCOL_SSLv2'): try: try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True) - except (ssl.SSLError, socket.error) as x: + except OSError as x: # this fails on some older versions of OpenSSL (0.9.7l, for instance) if support.verbose: sys.stdout.write( @@ -1553,6 +1691,49 @@ else: try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False, client_options=ssl.OP_NO_TLSv1) + @skip_if_broken_ubuntu_ssl + @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"), + "TLS version 1.1 not supported.") + def test_protocol_tlsv1_1(self): + """Connecting to a TLSv1.1 server with various client options. + Testing against older TLS versions.""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, True) + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_TLSv1_1) + + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, True) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False) + + + @skip_if_broken_ubuntu_ssl + @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"), + "TLS version 1.2 not supported.") + def test_protocol_tlsv1_2(self): + """Connecting to a TLSv1.2 server with various client options. + Testing against older TLS versions.""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, True, + server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2, + client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,) + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_TLSv1_2) + + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, True) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False) + def test_starttls(self): """Switching from clear text to encrypted and back again.""" msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6") @@ -1613,7 +1794,7 @@ else: def test_socketserver(self): """Using a SocketServer to create and manage SSL connections.""" - server = make_https_server(self, CERTFILE) + server = make_https_server(self, certfile=CERTFILE) # try to connect if support.verbose: sys.stdout.write('\n') @@ -1869,6 +2050,20 @@ else: self.assertIsInstance(remote, ssl.SSLSocket) self.assertEqual(peer, client_addr) + def test_getpeercert_enotconn(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + with context.wrap_socket(socket.socket()) as sock: + with self.assertRaises(OSError) as cm: + sock.getpeercert() + self.assertEqual(cm.exception.errno, errno.ENOTCONN) + + def test_do_handshake_enotconn(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + with context.wrap_socket(socket.socket()) as sock: + with self.assertRaises(OSError) as cm: + sock.do_handshake() + self.assertEqual(cm.exception.errno, errno.ENOTCONN) + def test_default_ciphers(self): context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) try: @@ -1880,7 +2075,7 @@ else: ssl_version=ssl.PROTOCOL_SSLv23, chatty=False) as server: with context.wrap_socket(socket.socket()) as s: - with self.assertRaises((OSError, ssl.SSLError)): + with self.assertRaises(OSError): s.connect((HOST, server.port)) self.assertIn("no shared cipher", str(server.conn_errors[0])) @@ -2013,6 +2208,124 @@ else: if len(stats['server_npn_protocols']) else 'nothing' self.assertEqual(server_result, expected, msg % (server_result, "server")) + def sni_contexts(self): + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + other_context.load_cert_chain(SIGNED_CERTFILE2) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.verify_mode = ssl.CERT_REQUIRED + client_context.load_verify_locations(SIGNING_CA) + return server_context, other_context, client_context + + def check_common_name(self, stats, name): + cert = stats['peercert'] + self.assertIn((('commonName', name),), cert['subject']) + + @needs_sni + def test_sni_callback(self): + calls = [] + server_context, other_context, client_context = self.sni_contexts() + + def servername_cb(ssl_sock, server_name, initial_context): + calls.append((server_name, initial_context)) + if server_name is not None: + ssl_sock.context = other_context + server_context.set_servername_callback(servername_cb) + + stats = server_params_test(client_context, server_context, + chatty=True, + sni_name='supermessage') + # The hostname was fetched properly, and the certificate was + # changed for the connection. + self.assertEqual(calls, [("supermessage", server_context)]) + # CERTFILE4 was selected + self.check_common_name(stats, 'fakehostname') + + calls = [] + # The callback is called with server_name=None + stats = server_params_test(client_context, server_context, + chatty=True, + sni_name=None) + self.assertEqual(calls, [(None, server_context)]) + self.check_common_name(stats, 'localhost') + + # Check disabling the callback + calls = [] + server_context.set_servername_callback(None) + + stats = server_params_test(client_context, server_context, + chatty=True, + sni_name='notfunny') + # Certificate didn't change + self.check_common_name(stats, 'localhost') + self.assertEqual(calls, []) + + @needs_sni + def test_sni_callback_alert(self): + # Returning a TLS alert is reflected to the connecting client + server_context, other_context, client_context = self.sni_contexts() + + def cb_returning_alert(ssl_sock, server_name, initial_context): + return ssl.ALERT_DESCRIPTION_ACCESS_DENIED + server_context.set_servername_callback(cb_returning_alert) + + with self.assertRaises(ssl.SSLError) as cm: + stats = server_params_test(client_context, server_context, + chatty=False, + sni_name='supermessage') + self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED') + + @needs_sni + def test_sni_callback_raising(self): + # Raising fails the connection with a TLS handshake failure alert. + server_context, other_context, client_context = self.sni_contexts() + + def cb_raising(ssl_sock, server_name, initial_context): + 1/0 + server_context.set_servername_callback(cb_raising) + + with self.assertRaises(ssl.SSLError) as cm, \ + support.captured_stderr() as stderr: + stats = server_params_test(client_context, server_context, + chatty=False, + sni_name='supermessage') + self.assertEqual(cm.exception.reason, 'SSLV3_ALERT_HANDSHAKE_FAILURE') + self.assertIn("ZeroDivisionError", stderr.getvalue()) + + @needs_sni + def test_sni_callback_wrong_return_type(self): + # Returning the wrong return type terminates the TLS connection + # with an internal error alert. + server_context, other_context, client_context = self.sni_contexts() + + def cb_wrong_return_type(ssl_sock, server_name, initial_context): + return "foo" + server_context.set_servername_callback(cb_wrong_return_type) + + with self.assertRaises(ssl.SSLError) as cm, \ + support.captured_stderr() as stderr: + stats = server_params_test(client_context, server_context, + chatty=False, + sni_name='supermessage') + self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR') + self.assertIn("TypeError", stderr.getvalue()) + + def test_read_write_after_close_raises_valuerror(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = ThreadedEchoServer(context=context, chatty=False) + + with server: + s = context.wrap_socket(socket.socket()) + s.connect((HOST, server.port)) + s.close() + + self.assertRaises(ValueError, s.read, 1024) + self.assertRaises(ValueError, s.write, b'hello') + def test_main(verbose=False): if support.verbose: @@ -2032,10 +2345,16 @@ def test_main(verbose=False): (ssl.OPENSSL_VERSION, ssl.OPENSSL_VERSION_INFO)) print(" under %s" % plat) print(" HAS_SNI = %r" % ssl.HAS_SNI) + print(" OP_ALL = 0x%8x" % ssl.OP_ALL) + try: + print(" OP_NO_TLSv1_1 = 0x%8x" % ssl.OP_NO_TLSv1_1) + except AttributeError: + pass for filename in [ CERTFILE, SVN_PYTHON_ORG_ROOT_CERT, BYTES_CERTFILE, ONLYCERT, ONLYKEY, BYTES_ONLYCERT, BYTES_ONLYKEY, + SIGNED_CERTFILE, SIGNED_CERTFILE2, SIGNING_CA, BADCERT, BADKEY, EMPTYCERT]: if not os.path.exists(filename): raise support.TestFailed("Can't read certificate file %r" % filename) |