diff options
author | Christian Heimes <christian@cheimes.de> | 2013-11-23 12:56:58 (GMT) |
---|---|---|
committer | Christian Heimes <christian@cheimes.de> | 2013-11-23 12:56:58 (GMT) |
commit | 72d28500b3c5e6f4051826432b2a801ce4e556f4 (patch) | |
tree | a9dea78f3f5f280297c4f419f5fd049c8e96f0bc /Lib | |
parent | a30d82f597927f0a7184d1b1018416d1739f4b11 (diff) | |
download | cpython-72d28500b3c5e6f4051826432b2a801ce4e556f4.zip cpython-72d28500b3c5e6f4051826432b2a801ce4e556f4.tar.gz cpython-72d28500b3c5e6f4051826432b2a801ce4e556f4.tar.bz2 |
Issue #19292: Add SSLContext.load_default_certs() to load default root CA
certificates from default stores or system stores. By default the method
loads CA certs for authentication of server certs.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/ssl.py | 28 | ||||
-rw-r--r-- | Lib/test/test_ssl.py | 32 |
2 files changed, 60 insertions, 0 deletions
@@ -92,6 +92,7 @@ import re import sys import os from collections import namedtuple +from enum import Enum as _Enum import _ssl # if we can't import it, let the error propagate @@ -298,11 +299,19 @@ class _ASN1Object(namedtuple("_ASN1Object", "nid shortname longname oid")): return super().__new__(cls, *_txt2obj(name, name=True)) +class Purpose(_ASN1Object, _Enum): + """SSLContext purpose flags with X509v3 Extended Key Usage objects + """ + SERVER_AUTH = '1.3.6.1.5.5.7.3.1' + CLIENT_AUTH = '1.3.6.1.5.5.7.3.2' + + class SSLContext(_SSLContext): """An SSLContext holds various SSL-related configuration options and data, such as certificates and possibly a private key.""" __slots__ = ('protocol', '__weakref__') + _windows_cert_stores = ("CA", "ROOT") def __new__(cls, protocol, *args, **kwargs): self = _SSLContext.__new__(cls, protocol) @@ -334,6 +343,25 @@ class SSLContext(_SSLContext): self._set_npn_protocols(protos) + def _load_windows_store_certs(self, storename, purpose): + certs = bytearray() + for cert, encoding, trust in enum_certificates(storename): + # CA certs are never PKCS#7 encoded + if encoding == "x509_asn": + if trust is True or purpose.oid in trust: + certs.extend(cert) + self.load_verify_locations(cadata=certs) + return certs + + def load_default_certs(self, purpose=Purpose.SERVER_AUTH): + if not isinstance(purpose, _ASN1Object): + raise TypeError(purpose) + if sys.platform == "win32": + for storename in self._windows_cert_stores: + self._load_windows_store_certs(storename, purpose) + else: + self.set_default_verify_paths() + class SSLSocket(socket): """This class implements a subtype of socket.socket that wraps diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index d6a7443..722d331 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -611,6 +611,23 @@ class BasicSocketTests(unittest.TestCase): with self.assertRaisesRegex(ValueError, "unknown object 'serverauth'"): ssl._ASN1Object.fromname('serverauth') + def test_purpose_enum(self): + val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1') + self.assertIsInstance(ssl.Purpose.SERVER_AUTH, ssl._ASN1Object) + self.assertEqual(ssl.Purpose.SERVER_AUTH, val) + self.assertEqual(ssl.Purpose.SERVER_AUTH.nid, 129) + self.assertEqual(ssl.Purpose.SERVER_AUTH.shortname, 'serverAuth') + self.assertEqual(ssl.Purpose.SERVER_AUTH.oid, + '1.3.6.1.5.5.7.3.1') + + val = ssl._ASN1Object('1.3.6.1.5.5.7.3.2') + self.assertIsInstance(ssl.Purpose.CLIENT_AUTH, ssl._ASN1Object) + self.assertEqual(ssl.Purpose.CLIENT_AUTH, val) + self.assertEqual(ssl.Purpose.CLIENT_AUTH.nid, 130) + self.assertEqual(ssl.Purpose.CLIENT_AUTH.shortname, 'clientAuth') + self.assertEqual(ssl.Purpose.CLIENT_AUTH.oid, + '1.3.6.1.5.5.7.3.2') + class ContextTests(unittest.TestCase): @@ -967,6 +984,21 @@ class ContextTests(unittest.TestCase): der = ssl.PEM_cert_to_DER_cert(pem) self.assertEqual(ctx.get_ca_certs(True), [der]) + def test_load_default_certs(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.load_default_certs() + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.load_default_certs(ssl.Purpose.SERVER_AUTH) + ctx.load_default_certs() + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH) + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertRaises(TypeError, ctx.load_default_certs, None) + self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH') + class SSLErrorTests(unittest.TestCase): |