diff options
Diffstat (limited to 'Lib/test/test_ssl.py')
-rw-r--r-- | Lib/test/test_ssl.py | 803 |
1 files changed, 747 insertions, 56 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 9fc6027..0dc04c0 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -6,6 +6,7 @@ from test import support import socket import select import time +import datetime import gc import os import errno @@ -17,19 +18,15 @@ import asyncore import weakref import platform import functools +from unittest import mock ssl = support.import_module("ssl") -PROTOCOLS = [ - ssl.PROTOCOL_SSLv3, - ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1 -] -if hasattr(ssl, 'PROTOCOL_SSLv2'): - PROTOCOLS.append(ssl.PROTOCOL_SSLv2) - +PROTOCOLS = sorted(ssl._PROTOCOL_NAMES) HOST = support.HOST -data_file = lambda name: os.path.join(os.path.dirname(__file__), name) +def data_file(*name): + return os.path.join(os.path.dirname(__file__), *name) # The custom key and certificate files used in test_ssl are generated # using Lib/test/make_ssl_certs.py. @@ -47,6 +44,17 @@ ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem") KEY_PASSWORD = "somepass" CAPATH = data_file("capath") BYTES_CAPATH = os.fsencode(CAPATH) +CAFILE_NEURONIO = data_file("capath", "4e1295a3.0") +CAFILE_CACERT = data_file("capath", "5ed36f99.0") + + +# empty CRL +CRLFILE = data_file("revocation.crl") + +# Two keys and certs signed by the same CA (for SNI tests) +SIGNED_CERTFILE = data_file("keycert3.pem") +SIGNED_CERTFILE2 = data_file("keycert4.pem") +SIGNING_CA = data_file("pycacert.pem") SVN_PYTHON_ORG_ROOT_CERT = data_file("https_svn_python_org_root.pem") @@ -60,6 +68,7 @@ NULLBYTECERT = data_file("nullbytecert.pem") DHFILE = data_file("dh512.pem") BYTES_DHFILE = os.fsencode(DHFILE) + def handle_error(prefix): exc_format = ' '.join(traceback.format_exception(*sys.exc_info())) if support.verbose: @@ -73,6 +82,23 @@ def no_sslv2_implies_sslv3_hello(): # 0.9.7h or higher return ssl.OPENSSL_VERSION_INFO >= (0, 9, 7, 8, 15) +def have_verify_flags(): + # 0.9.8 or higher + return ssl.OPENSSL_VERSION_INFO >= (0, 9, 8, 0, 15) + +def asn1time(cert_time): + # Some versions of OpenSSL ignore seconds, see #18207 + # 0.9.8.i + if ssl._OPENSSL_API_VERSION == (0, 9, 8, 9, 15): + fmt = "%b %d %H:%M:%S %Y GMT" + dt = datetime.datetime.strptime(cert_time, fmt) + dt = dt.replace(second=0) + cert_time = dt.strftime(fmt) + # %d adds leading zero but ASN1_TIME_print() uses leading space + if cert_time[4] == "0": + cert_time = cert_time[:4] + " " + cert_time[5:] + + return cert_time # Issue #9415: Ubuntu hijacks their OpenSSL and forcefully disables SSLv2 def skip_if_broken_ubuntu_ssl(func): @@ -90,14 +116,12 @@ def skip_if_broken_ubuntu_ssl(func): else: return func +needs_sni = unittest.skipUnless(ssl.HAS_SNI, "SNI support needed for this test") + class BasicSocketTests(unittest.TestCase): def test_constants(self): - #ssl.PROTOCOL_SSLv2 - ssl.PROTOCOL_SSLv23 - ssl.PROTOCOL_SSLv3 - ssl.PROTOCOL_TLSv1 ssl.CERT_NONE ssl.CERT_OPTIONAL ssl.CERT_REQUIRED @@ -179,8 +203,9 @@ class BasicSocketTests(unittest.TestCase): (('organizationName', 'Python Software Foundation'),), (('commonName', 'localhost'),)) ) - self.assertEqual(p['notAfter'], 'Oct 5 23:01:56 2020 GMT') - self.assertEqual(p['notBefore'], 'Oct 8 23:01:56 2010 GMT') + # Note the next three asserts will fail if the keys are regenerated + self.assertEqual(p['notAfter'], asn1time('Oct 5 23:01:56 2020 GMT')) + self.assertEqual(p['notBefore'], asn1time('Oct 8 23:01:56 2010 GMT')) self.assertEqual(p['serialNumber'], 'D7C7381919AFC24E') self.assertEqual(p['subject'], ((('countryName', 'XY'),), @@ -198,6 +223,12 @@ class BasicSocketTests(unittest.TestCase): (('DNS', 'projects.developer.nokia.com'), ('DNS', 'projects.forum.nokia.com')) ) + # extra OCSP and AIA fields + self.assertEqual(p['OCSP'], ('http://ocsp.verisign.com',)) + self.assertEqual(p['caIssuers'], + ('http://SVRIntl-G3-aia.verisign.com/SVRIntlG3.cer',)) + self.assertEqual(p['crlDistributionPoints'], + ('http://SVRIntl-G3-crl.verisign.com/SVRIntlG3.crl',)) def test_parse_cert_CVE_2013_4238(self): p = ssl._ssl._test_decode_cert(NULLBYTECERT) @@ -280,15 +311,15 @@ class BasicSocketTests(unittest.TestCase): def test_wrapped_unconnected(self): # Methods on an unconnected SSLSocket propagate the original - # socket.error raise by the underlying socket object. + # OSError raise by the underlying socket object. s = socket.socket(socket.AF_INET) with ssl.wrap_socket(s) as ss: - self.assertRaises(socket.error, ss.recv, 1) - self.assertRaises(socket.error, ss.recv_into, bytearray(b'x')) - self.assertRaises(socket.error, ss.recvfrom, 1) - self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1) - self.assertRaises(socket.error, ss.send, b'x') - self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0)) + self.assertRaises(OSError, ss.recv, 1) + self.assertRaises(OSError, ss.recv_into, bytearray(b'x')) + self.assertRaises(OSError, ss.recvfrom, 1) + self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1) + self.assertRaises(OSError, ss.send, b'x') + self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0)) def test_timeout(self): # Issue #8524: when creating an SSL socket, the timeout of the @@ -313,15 +344,15 @@ class BasicSocketTests(unittest.TestCase): with ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE) as s: self.assertRaisesRegex(ValueError, "can't connect in server-side mode", s.connect, (HOST, 8080)) - with self.assertRaises(IOError) as cm: + with self.assertRaises(OSError) as cm: with socket.socket() as sock: ssl.wrap_socket(sock, certfile=WRONGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) - with self.assertRaises(IOError) as cm: + with self.assertRaises(OSError) as cm: with socket.socket() as sock: ssl.wrap_socket(sock, certfile=CERTFILE, keyfile=WRONGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) - with self.assertRaises(IOError) as cm: + with self.assertRaises(OSError) as cm: with socket.socket() as sock: ssl.wrap_socket(sock, certfile=WRONGCERT, keyfile=WRONGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) @@ -493,6 +524,114 @@ class BasicSocketTests(unittest.TestCase): support.gc_collect() self.assertIn(r, str(cm.warning.args[0])) + def test_get_default_verify_paths(self): + paths = ssl.get_default_verify_paths() + self.assertEqual(len(paths), 6) + self.assertIsInstance(paths, ssl.DefaultVerifyPaths) + + with support.EnvironmentVarGuard() as env: + env["SSL_CERT_DIR"] = CAPATH + env["SSL_CERT_FILE"] = CERTFILE + paths = ssl.get_default_verify_paths() + self.assertEqual(paths.cafile, CERTFILE) + self.assertEqual(paths.capath, CAPATH) + + @unittest.skipUnless(sys.platform == "win32", "Windows specific") + def test_enum_certificates(self): + self.assertTrue(ssl.enum_certificates("CA")) + self.assertTrue(ssl.enum_certificates("ROOT")) + + self.assertRaises(TypeError, ssl.enum_certificates) + self.assertRaises(WindowsError, ssl.enum_certificates, "") + + trust_oids = set() + for storename in ("CA", "ROOT"): + store = ssl.enum_certificates(storename) + self.assertIsInstance(store, list) + for element in store: + self.assertIsInstance(element, tuple) + self.assertEqual(len(element), 3) + cert, enc, trust = element + self.assertIsInstance(cert, bytes) + self.assertIn(enc, {"x509_asn", "pkcs_7_asn"}) + self.assertIsInstance(trust, (set, bool)) + if isinstance(trust, set): + trust_oids.update(trust) + + serverAuth = "1.3.6.1.5.5.7.3.1" + self.assertIn(serverAuth, trust_oids) + + @unittest.skipUnless(sys.platform == "win32", "Windows specific") + def test_enum_crls(self): + self.assertTrue(ssl.enum_crls("CA")) + self.assertRaises(TypeError, ssl.enum_crls) + self.assertRaises(WindowsError, ssl.enum_crls, "") + + crls = ssl.enum_crls("CA") + self.assertIsInstance(crls, list) + for element in crls: + self.assertIsInstance(element, tuple) + self.assertEqual(len(element), 2) + self.assertIsInstance(element[0], bytes) + self.assertIn(element[1], {"x509_asn", "pkcs_7_asn"}) + + + def test_asn1object(self): + expected = (129, 'serverAuth', 'TLS Web Server Authentication', + '1.3.6.1.5.5.7.3.1') + + val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1') + self.assertEqual(val, expected) + self.assertEqual(val.nid, 129) + self.assertEqual(val.shortname, 'serverAuth') + self.assertEqual(val.longname, 'TLS Web Server Authentication') + self.assertEqual(val.oid, '1.3.6.1.5.5.7.3.1') + self.assertIsInstance(val, ssl._ASN1Object) + self.assertRaises(ValueError, ssl._ASN1Object, 'serverAuth') + + val = ssl._ASN1Object.fromnid(129) + self.assertEqual(val, expected) + self.assertIsInstance(val, ssl._ASN1Object) + self.assertRaises(ValueError, ssl._ASN1Object.fromnid, -1) + with self.assertRaisesRegex(ValueError, "unknown NID 100000"): + ssl._ASN1Object.fromnid(100000) + for i in range(1000): + try: + obj = ssl._ASN1Object.fromnid(i) + except ValueError: + pass + else: + self.assertIsInstance(obj.nid, int) + self.assertIsInstance(obj.shortname, str) + self.assertIsInstance(obj.longname, str) + self.assertIsInstance(obj.oid, (str, type(None))) + + val = ssl._ASN1Object.fromname('TLS Web Server Authentication') + self.assertEqual(val, expected) + self.assertIsInstance(val, ssl._ASN1Object) + self.assertEqual(ssl._ASN1Object.fromname('serverAuth'), expected) + self.assertEqual(ssl._ASN1Object.fromname('1.3.6.1.5.5.7.3.1'), + expected) + with self.assertRaisesRegex(ValueError, "unknown object 'serverauth'"): + ssl._ASN1Object.fromname('serverauth') + + def test_purpose_enum(self): + val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1') + self.assertIsInstance(ssl.Purpose.SERVER_AUTH, ssl._ASN1Object) + self.assertEqual(ssl.Purpose.SERVER_AUTH, val) + self.assertEqual(ssl.Purpose.SERVER_AUTH.nid, 129) + self.assertEqual(ssl.Purpose.SERVER_AUTH.shortname, 'serverAuth') + self.assertEqual(ssl.Purpose.SERVER_AUTH.oid, + '1.3.6.1.5.5.7.3.1') + + val = ssl._ASN1Object('1.3.6.1.5.5.7.3.2') + self.assertIsInstance(ssl.Purpose.CLIENT_AUTH, ssl._ASN1Object) + self.assertEqual(ssl.Purpose.CLIENT_AUTH, val) + self.assertEqual(ssl.Purpose.CLIENT_AUTH.nid, 130) + self.assertEqual(ssl.Purpose.CLIENT_AUTH.shortname, 'clientAuth') + self.assertEqual(ssl.Purpose.CLIENT_AUTH.oid, + '1.3.6.1.5.5.7.3.2') + def test_unsupported_dtls(self): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.addCleanup(s.close) @@ -509,11 +648,8 @@ class ContextTests(unittest.TestCase): @skip_if_broken_ubuntu_ssl def test_constructor(self): - if hasattr(ssl, 'PROTOCOL_SSLv2'): - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv2) - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv3) - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + for protocol in PROTOCOLS: + ssl.SSLContext(protocol) self.assertRaises(TypeError, ssl.SSLContext) self.assertRaises(ValueError, ssl.SSLContext, -1) self.assertRaises(ValueError, ssl.SSLContext, 42) @@ -550,7 +686,7 @@ class ContextTests(unittest.TestCase): with self.assertRaises(ValueError): ctx.options = 0 - def test_verify(self): + def test_verify_mode(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) # Default value self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) @@ -565,13 +701,32 @@ class ContextTests(unittest.TestCase): with self.assertRaises(ValueError): ctx.verify_mode = 42 + @unittest.skipUnless(have_verify_flags(), + "verify_flags need OpenSSL > 0.9.8") + def test_verify_flags(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + # default value by OpenSSL + self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT) + ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF + self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_LEAF) + ctx.verify_flags = ssl.VERIFY_CRL_CHECK_CHAIN + self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_CHAIN) + ctx.verify_flags = ssl.VERIFY_DEFAULT + self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT) + # supports any value + ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT + self.assertEqual(ctx.verify_flags, + ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT) + with self.assertRaises(TypeError): + ctx.verify_flags = None + def test_load_cert_chain(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) # Combined key and cert in a single file ctx.load_cert_chain(CERTFILE) ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE) self.assertRaises(TypeError, ctx.load_cert_chain, keyfile=CERTFILE) - with self.assertRaises(IOError) as cm: + with self.assertRaises(OSError) as cm: ctx.load_cert_chain(WRONGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) with self.assertRaisesRegex(ssl.SSLError, "PEM lib"): @@ -655,8 +810,8 @@ class ContextTests(unittest.TestCase): ctx.load_verify_locations(BYTES_CERTFILE) ctx.load_verify_locations(cafile=BYTES_CERTFILE, capath=None) self.assertRaises(TypeError, ctx.load_verify_locations) - self.assertRaises(TypeError, ctx.load_verify_locations, None, None) - with self.assertRaises(IOError) as cm: + self.assertRaises(TypeError, ctx.load_verify_locations, None, None, None) + with self.assertRaises(OSError) as cm: ctx.load_verify_locations(WRONGCERT) self.assertEqual(cm.exception.errno, errno.ENOENT) with self.assertRaisesRegex(ssl.SSLError, "PEM lib"): @@ -667,6 +822,64 @@ class ContextTests(unittest.TestCase): # Issue #10989: crash if the second argument type is invalid self.assertRaises(TypeError, ctx.load_verify_locations, None, True) + def test_load_verify_cadata(self): + # test cadata + with open(CAFILE_CACERT) as f: + cacert_pem = f.read() + cacert_der = ssl.PEM_cert_to_DER_cert(cacert_pem) + with open(CAFILE_NEURONIO) as f: + neuronio_pem = f.read() + neuronio_der = ssl.PEM_cert_to_DER_cert(neuronio_pem) + + # test PEM + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 0) + ctx.load_verify_locations(cadata=cacert_pem) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 1) + ctx.load_verify_locations(cadata=neuronio_pem) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + # cert already in hash table + ctx.load_verify_locations(cadata=neuronio_pem) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + + # combined + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + combined = "\n".join((cacert_pem, neuronio_pem)) + ctx.load_verify_locations(cadata=combined) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + + # with junk around the certs + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + combined = ["head", cacert_pem, "other", neuronio_pem, "again", + neuronio_pem, "tail"] + ctx.load_verify_locations(cadata="\n".join(combined)) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + + # test DER + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.load_verify_locations(cadata=cacert_der) + ctx.load_verify_locations(cadata=neuronio_der) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + # cert already in hash table + ctx.load_verify_locations(cadata=cacert_der) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + + # combined + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + combined = b"".join((cacert_der, neuronio_der)) + ctx.load_verify_locations(cadata=combined) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + + # error cases + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertRaises(TypeError, ctx.load_verify_locations, cadata=object) + + with self.assertRaisesRegex(ssl.SSLError, "no start line"): + ctx.load_verify_locations(cadata="broken") + with self.assertRaisesRegex(ssl.SSLError, "not enough data"): + ctx.load_verify_locations(cadata=b"broken") + + def test_load_dh_params(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ctx.load_dh_params(DHFILE) @@ -714,6 +927,158 @@ class ContextTests(unittest.TestCase): self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo") self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo") + @needs_sni + def test_sni_callback(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + + # set_servername_callback expects a callable, or None + self.assertRaises(TypeError, ctx.set_servername_callback) + self.assertRaises(TypeError, ctx.set_servername_callback, 4) + self.assertRaises(TypeError, ctx.set_servername_callback, "") + self.assertRaises(TypeError, ctx.set_servername_callback, ctx) + + def dummycallback(sock, servername, ctx): + pass + ctx.set_servername_callback(None) + ctx.set_servername_callback(dummycallback) + + @needs_sni + def test_sni_callback_refcycle(self): + # Reference cycles through the servername callback are detected + # and cleared. + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + def dummycallback(sock, servername, ctx, cycle=ctx): + pass + ctx.set_servername_callback(dummycallback) + wr = weakref.ref(ctx) + del ctx, dummycallback + gc.collect() + self.assertIs(wr(), None) + + def test_cert_store_stats(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.cert_store_stats(), + {'x509_ca': 0, 'crl': 0, 'x509': 0}) + ctx.load_cert_chain(CERTFILE) + self.assertEqual(ctx.cert_store_stats(), + {'x509_ca': 0, 'crl': 0, 'x509': 0}) + ctx.load_verify_locations(CERTFILE) + self.assertEqual(ctx.cert_store_stats(), + {'x509_ca': 0, 'crl': 0, 'x509': 1}) + ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT) + self.assertEqual(ctx.cert_store_stats(), + {'x509_ca': 1, 'crl': 0, 'x509': 2}) + + def test_get_ca_certs(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.get_ca_certs(), []) + # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE + ctx.load_verify_locations(CERTFILE) + self.assertEqual(ctx.get_ca_certs(), []) + # but SVN_PYTHON_ORG_ROOT_CERT is a CA cert + ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT) + self.assertEqual(ctx.get_ca_certs(), + [{'issuer': ((('organizationName', 'Root CA'),), + (('organizationalUnitName', 'http://www.cacert.org'),), + (('commonName', 'CA Cert Signing Authority'),), + (('emailAddress', 'support@cacert.org'),)), + 'notAfter': asn1time('Mar 29 12:29:49 2033 GMT'), + 'notBefore': asn1time('Mar 30 12:29:49 2003 GMT'), + 'serialNumber': '00', + 'crlDistributionPoints': ('https://www.cacert.org/revoke.crl',), + 'subject': ((('organizationName', 'Root CA'),), + (('organizationalUnitName', 'http://www.cacert.org'),), + (('commonName', 'CA Cert Signing Authority'),), + (('emailAddress', 'support@cacert.org'),)), + 'version': 3}]) + + with open(SVN_PYTHON_ORG_ROOT_CERT) as f: + pem = f.read() + der = ssl.PEM_cert_to_DER_cert(pem) + self.assertEqual(ctx.get_ca_certs(True), [der]) + + def test_load_default_certs(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.load_default_certs() + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.load_default_certs(ssl.Purpose.SERVER_AUTH) + ctx.load_default_certs() + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH) + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertRaises(TypeError, ctx.load_default_certs, None) + self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH') + + def test_create_default_context(self): + ctx = ssl.create_default_context() + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) + self.assertTrue(ctx.check_hostname) + self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) + + with open(SIGNING_CA) as f: + cadata = f.read() + ctx = ssl.create_default_context(cafile=SIGNING_CA, capath=CAPATH, + cadata=cadata) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) + self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) + + ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) + self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) + + def test__create_stdlib_context(self): + ctx = ssl._create_stdlib_context() + self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23) + self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) + self.assertFalse(ctx.check_hostname) + self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) + + ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) + self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) + + ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1, + cert_reqs=ssl.CERT_REQUIRED, + check_hostname=True) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) + self.assertTrue(ctx.check_hostname) + self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) + + ctx = ssl._create_stdlib_context(purpose=ssl.Purpose.CLIENT_AUTH) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23) + self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) + self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2) + + def test_check_hostname(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertFalse(ctx.check_hostname) + + # Requires CERT_REQUIRED or CERT_OPTIONAL + with self.assertRaises(ValueError): + ctx.check_hostname = True + ctx.verify_mode = ssl.CERT_REQUIRED + self.assertFalse(ctx.check_hostname) + ctx.check_hostname = True + self.assertTrue(ctx.check_hostname) + + ctx.verify_mode = ssl.CERT_OPTIONAL + ctx.check_hostname = True + self.assertTrue(ctx.check_hostname) + + # Cannot set CERT_NONE with check_hostname enabled + with self.assertRaises(ValueError): + ctx.verify_mode = ssl.CERT_NONE + ctx.check_hostname = False + self.assertFalse(ctx.check_hostname) + class SSLErrorTests(unittest.TestCase): @@ -919,6 +1284,28 @@ class NetworkedTests(unittest.TestCase): finally: s.close() + def test_connect_cadata(self): + with open(CAFILE_CACERT) as f: + pem = f.read() + der = ssl.PEM_cert_to_DER_cert(pem) + with support.transient_internet("svn.python.org"): + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_verify_locations(cadata=pem) + with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: + s.connect(("svn.python.org", 443)) + cert = s.getpeercert() + self.assertTrue(cert) + + # same with DER + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_verify_locations(cadata=der) + with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: + s.connect(("svn.python.org", 443)) + cert = s.getpeercert() + self.assertTrue(cert) + @unittest.skipIf(os.name == "nt", "Can't use a socket as a file under Windows") def test_makefile_close(self): # Issue #5238: creating a file-like object with makefile() shouldn't @@ -1031,6 +1418,36 @@ class NetworkedTests(unittest.TestCase): finally: s.close() + def test_get_ca_certs_capath(self): + # capath certs are loaded on request + with support.transient_internet("svn.python.org"): + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_verify_locations(capath=CAPATH) + self.assertEqual(ctx.get_ca_certs(), []) + s = ctx.wrap_socket(socket.socket(socket.AF_INET)) + s.connect(("svn.python.org", 443)) + try: + cert = s.getpeercert() + self.assertTrue(cert) + finally: + s.close() + self.assertEqual(len(ctx.get_ca_certs()), 1) + + @needs_sni + def test_context_setget(self): + # Check that the context of a connected socket can be replaced. + with support.transient_internet("svn.python.org"): + ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + s = socket.socket(socket.AF_INET) + with ctx1.wrap_socket(s) as ss: + ss.connect(("svn.python.org", 443)) + self.assertIs(ss.context, ctx1) + self.assertIs(ss._sslobj.context, ctx1) + ss.context = ctx2 + self.assertIs(ss.context, ctx2) + self.assertIs(ss._sslobj.context, ctx2) try: import threading @@ -1158,7 +1575,7 @@ else: sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n" % (msg, ctype, msg.lower(), ctype)) self.write(msg.lower()) - except socket.error: + except OSError: if self.server.chatty: handle_error("Test server failure:\n") self.close() @@ -1268,7 +1685,7 @@ else: return self.handle_close() except ssl.SSLError: raise - except socket.error as err: + except OSError as err: if err.args[0] == errno.ECONNABORTED: return self.handle_close() else: @@ -1372,19 +1789,19 @@ else: except ssl.SSLError as x: if support.verbose: sys.stdout.write("\nSSLError is %s\n" % x.args[1]) - except socket.error as x: + except OSError as x: if support.verbose: - sys.stdout.write("\nsocket.error is %s\n" % x.args[1]) - except IOError as x: + sys.stdout.write("\nOSError is %s\n" % x.args[1]) + except OSError as x: if x.errno != errno.ENOENT: raise if support.verbose: - sys.stdout.write("\IOError is %s\n" % str(x)) + sys.stdout.write("\OSError is %s\n" % str(x)) else: raise AssertionError("Use of invalid cert should have failed!") def server_params_test(client_context, server_context, indata=b"FOO\n", - chatty=True, connectionchatty=False): + chatty=True, connectionchatty=False, sni_name=None): """ Launch a server, connect a client to it and try various reads and writes. @@ -1394,7 +1811,8 @@ else: chatty=chatty, connectionchatty=False) with server: - with client_context.wrap_socket(socket.socket()) as s: + with client_context.wrap_socket(socket.socket(), + server_hostname=sni_name) as s: s.connect((HOST, server.port)) for arg in [indata, bytearray(indata), memoryview(indata)]: if connectionchatty: @@ -1418,6 +1836,7 @@ else: stats.update({ 'compression': s.compression(), 'cipher': s.cipher(), + 'peercert': s.getpeercert(), 'client_npn_protocol': s.selected_npn_protocol() }) s.close() @@ -1443,12 +1862,15 @@ else: client_context.options |= client_options server_context = ssl.SSLContext(server_protocol) server_context.options |= server_options + + # NOTE: we must enable "ALL" ciphers on the client, otherwise an + # SSLv23 client will send an SSLv3 hello (rather than SSLv2) + # starting from OpenSSL 1.0.0 (see issue #8322). + if client_context.protocol == ssl.PROTOCOL_SSLv23: + client_context.set_ciphers("ALL") + for ctx in (client_context, server_context): ctx.verify_mode = certsreqs - # NOTE: we must enable "ALL" ciphers, otherwise an SSLv23 client - # will send an SSLv3 hello (rather than SSLv2) starting from - # OpenSSL 1.0.0 (see issue #8322). - ctx.set_ciphers("ALL") ctx.load_cert_chain(CERTFILE) ctx.load_verify_locations(CERTFILE) try: @@ -1459,7 +1881,7 @@ else: except ssl.SSLError: if expect_success: raise - except socket.error as e: + except OSError as e: if expect_success or e.errno != errno.ECONNRESET: raise else: @@ -1478,10 +1900,11 @@ else: if support.verbose: sys.stdout.write("\n") for protocol in PROTOCOLS: - context = ssl.SSLContext(protocol) - context.load_cert_chain(CERTFILE) - server_params_test(context, context, - chatty=True, connectionchatty=True) + with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]): + context = ssl.SSLContext(protocol) + context.load_cert_chain(CERTFILE) + server_params_test(context, context, + chatty=True, connectionchatty=True) def test_getpeercert(self): if support.verbose: @@ -1492,8 +1915,14 @@ else: context.load_cert_chain(CERTFILE) server = ThreadedEchoServer(context=context, chatty=False) with server: - s = context.wrap_socket(socket.socket()) + s = context.wrap_socket(socket.socket(), + do_handshake_on_connect=False) s.connect((HOST, server.port)) + # getpeercert() raise ValueError while the handshake isn't + # done. + with self.assertRaises(ValueError): + s.getpeercert() + s.do_handshake() cert = s.getpeercert() self.assertTrue(cert, "Can't get peer certificate.") cipher = s.cipher() @@ -1515,6 +1944,87 @@ else: self.assertLess(before, after) s.close() + @unittest.skipUnless(have_verify_flags(), + "verify_flags need OpenSSL > 0.9.8") + def test_crl_check(self): + if support.verbose: + sys.stdout.write("\n") + + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(SIGNING_CA) + self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT) + + # VERIFY_DEFAULT should pass + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket()) as s: + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + + # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails + context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF + + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket()) as s: + with self.assertRaisesRegex(ssl.SSLError, + "certificate verify failed"): + s.connect((HOST, server.port)) + + # now load a CRL file. The CRL file is signed by the CA. + context.load_verify_locations(CRLFILE) + + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket()) as s: + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + + @needs_sni + def test_check_hostname(self): + if support.verbose: + sys.stdout.write("\n") + + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True + context.load_verify_locations(SIGNING_CA) + + # correct hostname should verify + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket(), + server_hostname="localhost") as s: + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + + # incorrect hostname should raise an exception + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket(), + server_hostname="invalid") as s: + with self.assertRaisesRegex(ssl.CertificateError, + "hostname 'invalid' doesn't match 'localhost'"): + s.connect((HOST, server.port)) + + # missing server_hostname arg should cause an exception, too + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with socket.socket() as s: + with self.assertRaisesRegex(ValueError, + "check_hostname requires server_hostname"): + context.wrap_socket(s) + def test_empty_cert(self): """Connecting with an empty cert file""" bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir, @@ -1533,7 +2043,7 @@ else: "badkey.pem")) def test_rude_shutdown(self): - """A brutal shutdown of an SSL server should raise an IOError + """A brutal shutdown of an SSL server should raise an OSError in the client when attempting handshake. """ listener_ready = threading.Event() @@ -1561,7 +2071,7 @@ else: listener_gone.wait() try: ssl_sock = ssl.wrap_socket(c) - except IOError: + except OSError: pass else: self.fail('connecting to closed SSL socket should have failed') @@ -1604,7 +2114,7 @@ else: if hasattr(ssl, 'PROTOCOL_SSLv2'): try: try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True) - except (ssl.SSLError, socket.error) as x: + except OSError as x: # this fails on some older versions of OpenSSL (0.9.7l, for instance) if support.verbose: sys.stdout.write( @@ -1664,6 +2174,49 @@ else: try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False, client_options=ssl.OP_NO_TLSv1) + @skip_if_broken_ubuntu_ssl + @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"), + "TLS version 1.1 not supported.") + def test_protocol_tlsv1_1(self): + """Connecting to a TLSv1.1 server with various client options. + Testing against older TLS versions.""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, True) + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_TLSv1_1) + + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, True) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False) + + + @skip_if_broken_ubuntu_ssl + @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"), + "TLS version 1.2 not supported.") + def test_protocol_tlsv1_2(self): + """Connecting to a TLSv1.2 server with various client options. + Testing against older TLS versions.""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, True, + server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2, + client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,) + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_TLSv1_2) + + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, True) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False) + def test_starttls(self): """Switching from clear text to encrypted and back again.""" msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6") @@ -1724,7 +2277,7 @@ else: def test_socketserver(self): """Using a SocketServer to create and manage SSL connections.""" - server = make_https_server(self, CERTFILE) + server = make_https_server(self, certfile=CERTFILE) # try to connect if support.verbose: sys.stdout.write('\n') @@ -1980,6 +2533,20 @@ else: self.assertIsInstance(remote, ssl.SSLSocket) self.assertEqual(peer, client_addr) + def test_getpeercert_enotconn(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + with context.wrap_socket(socket.socket()) as sock: + with self.assertRaises(OSError) as cm: + sock.getpeercert() + self.assertEqual(cm.exception.errno, errno.ENOTCONN) + + def test_do_handshake_enotconn(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + with context.wrap_socket(socket.socket()) as sock: + with self.assertRaises(OSError) as cm: + sock.do_handshake() + self.assertEqual(cm.exception.errno, errno.ENOTCONN) + def test_default_ciphers(self): context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) try: @@ -1991,7 +2558,7 @@ else: ssl_version=ssl.PROTOCOL_SSLv23, chatty=False) as server: with context.wrap_socket(socket.socket()) as s: - with self.assertRaises((OSError, ssl.SSLError)): + with self.assertRaises(OSError): s.connect((HOST, server.port)) self.assertIn("no shared cipher", str(server.conn_errors[0])) @@ -2124,6 +2691,124 @@ else: if len(stats['server_npn_protocols']) else 'nothing' self.assertEqual(server_result, expected, msg % (server_result, "server")) + def sni_contexts(self): + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + other_context.load_cert_chain(SIGNED_CERTFILE2) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.verify_mode = ssl.CERT_REQUIRED + client_context.load_verify_locations(SIGNING_CA) + return server_context, other_context, client_context + + def check_common_name(self, stats, name): + cert = stats['peercert'] + self.assertIn((('commonName', name),), cert['subject']) + + @needs_sni + def test_sni_callback(self): + calls = [] + server_context, other_context, client_context = self.sni_contexts() + + def servername_cb(ssl_sock, server_name, initial_context): + calls.append((server_name, initial_context)) + if server_name is not None: + ssl_sock.context = other_context + server_context.set_servername_callback(servername_cb) + + stats = server_params_test(client_context, server_context, + chatty=True, + sni_name='supermessage') + # The hostname was fetched properly, and the certificate was + # changed for the connection. + self.assertEqual(calls, [("supermessage", server_context)]) + # CERTFILE4 was selected + self.check_common_name(stats, 'fakehostname') + + calls = [] + # The callback is called with server_name=None + stats = server_params_test(client_context, server_context, + chatty=True, + sni_name=None) + self.assertEqual(calls, [(None, server_context)]) + self.check_common_name(stats, 'localhost') + + # Check disabling the callback + calls = [] + server_context.set_servername_callback(None) + + stats = server_params_test(client_context, server_context, + chatty=True, + sni_name='notfunny') + # Certificate didn't change + self.check_common_name(stats, 'localhost') + self.assertEqual(calls, []) + + @needs_sni + def test_sni_callback_alert(self): + # Returning a TLS alert is reflected to the connecting client + server_context, other_context, client_context = self.sni_contexts() + + def cb_returning_alert(ssl_sock, server_name, initial_context): + return ssl.ALERT_DESCRIPTION_ACCESS_DENIED + server_context.set_servername_callback(cb_returning_alert) + + with self.assertRaises(ssl.SSLError) as cm: + stats = server_params_test(client_context, server_context, + chatty=False, + sni_name='supermessage') + self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED') + + @needs_sni + def test_sni_callback_raising(self): + # Raising fails the connection with a TLS handshake failure alert. + server_context, other_context, client_context = self.sni_contexts() + + def cb_raising(ssl_sock, server_name, initial_context): + 1/0 + server_context.set_servername_callback(cb_raising) + + with self.assertRaises(ssl.SSLError) as cm, \ + support.captured_stderr() as stderr: + stats = server_params_test(client_context, server_context, + chatty=False, + sni_name='supermessage') + self.assertEqual(cm.exception.reason, 'SSLV3_ALERT_HANDSHAKE_FAILURE') + self.assertIn("ZeroDivisionError", stderr.getvalue()) + + @needs_sni + def test_sni_callback_wrong_return_type(self): + # Returning the wrong return type terminates the TLS connection + # with an internal error alert. + server_context, other_context, client_context = self.sni_contexts() + + def cb_wrong_return_type(ssl_sock, server_name, initial_context): + return "foo" + server_context.set_servername_callback(cb_wrong_return_type) + + with self.assertRaises(ssl.SSLError) as cm, \ + support.captured_stderr() as stderr: + stats = server_params_test(client_context, server_context, + chatty=False, + sni_name='supermessage') + self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR') + self.assertIn("TypeError", stderr.getvalue()) + + def test_read_write_after_close_raises_valuerror(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = ThreadedEchoServer(context=context, chatty=False) + + with server: + s = context.wrap_socket(socket.socket()) + s.connect((HOST, server.port)) + s.close() + + self.assertRaises(ValueError, s.read, 1024) + self.assertRaises(ValueError, s.write, b'hello') + def test_main(verbose=False): if support.verbose: @@ -2143,10 +2828,16 @@ def test_main(verbose=False): (ssl.OPENSSL_VERSION, ssl.OPENSSL_VERSION_INFO)) print(" under %s" % plat) print(" HAS_SNI = %r" % ssl.HAS_SNI) + print(" OP_ALL = 0x%8x" % ssl.OP_ALL) + try: + print(" OP_NO_TLSv1_1 = 0x%8x" % ssl.OP_NO_TLSv1_1) + except AttributeError: + pass for filename in [ CERTFILE, SVN_PYTHON_ORG_ROOT_CERT, BYTES_CERTFILE, ONLYCERT, ONLYKEY, BYTES_ONLYCERT, BYTES_ONLYKEY, + SIGNED_CERTFILE, SIGNED_CERTFILE2, SIGNING_CA, BADCERT, BADKEY, EMPTYCERT]: if not os.path.exists(filename): raise support.TestFailed("Can't read certificate file %r" % filename) |