diff options
author | Christian Heimes <christian@python.org> | 2016-09-05 21:19:05 (GMT) |
---|---|---|
committer | Christian Heimes <christian@python.org> | 2016-09-05 21:19:05 (GMT) |
commit | 598894ff48e9c1171cb2ec1c798235826a75c7e0 (patch) | |
tree | 8f94a9879770bf268cb8245702324de867490202 /Lib | |
parent | b3b7a5a16b4e3c5bdd8d378d95e0645ab16a9547 (diff) | |
download | cpython-598894ff48e9c1171cb2ec1c798235826a75c7e0.zip cpython-598894ff48e9c1171cb2ec1c798235826a75c7e0.tar.gz cpython-598894ff48e9c1171cb2ec1c798235826a75c7e0.tar.bz2 |
Issue #26470: Port ssl and hashlib module to OpenSSL 1.1.0.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/ssl.py | 18 | ||||
-rw-r--r-- | Lib/test/test_ssl.py | 89 |
2 files changed, 68 insertions, 39 deletions
@@ -51,6 +51,7 @@ The following constants identify various SSL protocol variants: PROTOCOL_SSLv2 PROTOCOL_SSLv3 PROTOCOL_SSLv23 +PROTOCOL_TLS PROTOCOL_TLSv1 PROTOCOL_TLSv1_1 PROTOCOL_TLSv1_2 @@ -128,9 +129,10 @@ from _ssl import _OPENSSL_API_VERSION _IntEnum._convert( '_SSLMethod', __name__, - lambda name: name.startswith('PROTOCOL_'), + lambda name: name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23', source=_ssl) +PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_TLS _PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()} try: @@ -357,13 +359,13 @@ class SSLContext(_SSLContext): __slots__ = ('protocol', '__weakref__') _windows_cert_stores = ("CA", "ROOT") - def __new__(cls, protocol, *args, **kwargs): + def __new__(cls, protocol=PROTOCOL_TLS, *args, **kwargs): self = _SSLContext.__new__(cls, protocol) if protocol != _SSLv2_IF_EXISTS: self.set_ciphers(_DEFAULT_CIPHERS) return self - def __init__(self, protocol): + def __init__(self, protocol=PROTOCOL_TLS): self.protocol = protocol def wrap_socket(self, sock, server_side=False, @@ -438,7 +440,7 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None, if not isinstance(purpose, _ASN1Object): raise TypeError(purpose) - context = SSLContext(PROTOCOL_SSLv23) + context = SSLContext(PROTOCOL_TLS) # SSLv2 considered harmful. context.options |= OP_NO_SSLv2 @@ -475,7 +477,7 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None, context.load_default_certs(purpose) return context -def _create_unverified_context(protocol=PROTOCOL_SSLv23, *, cert_reqs=None, +def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=None, check_hostname=False, purpose=Purpose.SERVER_AUTH, certfile=None, keyfile=None, cafile=None, capath=None, cadata=None): @@ -666,7 +668,7 @@ class SSLSocket(socket): def __init__(self, sock=None, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, - ssl_version=PROTOCOL_SSLv23, ca_certs=None, + ssl_version=PROTOCOL_TLS, ca_certs=None, do_handshake_on_connect=True, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None, suppress_ragged_eofs=True, npn_protocols=None, ciphers=None, @@ -1056,7 +1058,7 @@ class SSLSocket(socket): def wrap_socket(sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, - ssl_version=PROTOCOL_SSLv23, ca_certs=None, + ssl_version=PROTOCOL_TLS, ca_certs=None, do_handshake_on_connect=True, suppress_ragged_eofs=True, ciphers=None): @@ -1125,7 +1127,7 @@ def PEM_cert_to_DER_cert(pem_cert_string): d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)] return base64.decodebytes(d.encode('ASCII', 'strict')) -def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv23, ca_certs=None): +def get_server_certificate(addr, ssl_version=PROTOCOL_TLS, ca_certs=None): """Retrieve the certificate from the server at the specified address, and return it as a PEM-encoded string. If 'ca_certs' is specified, validate the server cert against it. diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index f48103e..79e26ba 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -23,6 +23,9 @@ ssl = support.import_module("ssl") PROTOCOLS = sorted(ssl._PROTOCOL_NAMES) HOST = support.HOST +IS_LIBRESSL = ssl.OPENSSL_VERSION.startswith('LibreSSL') +IS_OPENSSL_1_1 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0) + def data_file(*name): return os.path.join(os.path.dirname(__file__), *name) @@ -143,8 +146,8 @@ class BasicSocketTests(unittest.TestCase): def test_str_for_enums(self): # Make sure that the PROTOCOL_* constants have enum-like string # reprs. - proto = ssl.PROTOCOL_SSLv23 - self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_SSLv23') + proto = ssl.PROTOCOL_TLS + self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_TLS') ctx = ssl.SSLContext(proto) self.assertIs(ctx.protocol, proto) @@ -312,8 +315,8 @@ class BasicSocketTests(unittest.TestCase): self.assertGreaterEqual(status, 0) self.assertLessEqual(status, 15) # Version string as returned by {Open,Libre}SSL, the format might change - if "LibreSSL" in s: - self.assertTrue(s.startswith("LibreSSL {:d}.{:d}".format(major, minor)), + if IS_LIBRESSL: + self.assertTrue(s.startswith("LibreSSL {:d}".format(major)), (s, t, hex(n))) else: self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)), @@ -790,7 +793,8 @@ class ContextTests(unittest.TestCase): def test_constructor(self): for protocol in PROTOCOLS: ssl.SSLContext(protocol) - self.assertRaises(TypeError, ssl.SSLContext) + ctx = ssl.SSLContext() + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS) self.assertRaises(ValueError, ssl.SSLContext, -1) self.assertRaises(ValueError, ssl.SSLContext, 42) @@ -811,15 +815,15 @@ class ContextTests(unittest.TestCase): def test_options(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) # OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value - self.assertEqual(ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3, - ctx.options) + default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3) + if not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0): + default |= ssl.OP_NO_COMPRESSION + self.assertEqual(default, ctx.options) ctx.options |= ssl.OP_NO_TLSv1 - self.assertEqual(ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1, - ctx.options) + self.assertEqual(default | ssl.OP_NO_TLSv1, ctx.options) if can_clear_options(): - ctx.options = (ctx.options & ~ssl.OP_NO_SSLv2) | ssl.OP_NO_TLSv1 - self.assertEqual(ssl.OP_ALL | ssl.OP_NO_TLSv1 | ssl.OP_NO_SSLv3, - ctx.options) + ctx.options = (ctx.options & ~ssl.OP_NO_TLSv1) + self.assertEqual(default, ctx.options) ctx.options = 0 # Ubuntu has OP_NO_SSLv3 forced on by default self.assertEqual(0, ctx.options & ~ssl.OP_NO_SSLv3) @@ -1155,6 +1159,7 @@ class ContextTests(unittest.TestCase): self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH') @unittest.skipIf(sys.platform == "win32", "not-Windows specific") + @unittest.skipIf(IS_LIBRESSL, "LibreSSL doesn't support env vars") def test_load_default_certs_env(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) with support.EnvironmentVarGuard() as env: @@ -1750,13 +1755,13 @@ class NetworkedBIOTests(unittest.TestCase): sslobj = ctx.wrap_bio(incoming, outgoing, False, REMOTE_HOST) self.assertIs(sslobj._sslobj.owner, sslobj) self.assertIsNone(sslobj.cipher()) - self.assertIsNone(sslobj.shared_ciphers()) + self.assertIsNotNone(sslobj.shared_ciphers()) self.assertRaises(ValueError, sslobj.getpeercert) if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: self.assertIsNone(sslobj.get_channel_binding('tls-unique')) self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) self.assertTrue(sslobj.cipher()) - self.assertIsNone(sslobj.shared_ciphers()) + self.assertIsNotNone(sslobj.shared_ciphers()) self.assertTrue(sslobj.getpeercert()) if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: self.assertTrue(sslobj.get_channel_binding('tls-unique')) @@ -2993,7 +2998,7 @@ else: with context.wrap_socket(socket.socket()) as s: self.assertIs(s.version(), None) s.connect((HOST, server.port)) - self.assertEqual(s.version(), "TLSv1") + self.assertEqual(s.version(), 'TLSv1') self.assertIs(s.version(), None) @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL") @@ -3135,24 +3140,36 @@ else: (['http/3.0', 'http/4.0'], None) ] for client_protocols, expected in protocol_tests: - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) server_context.load_cert_chain(CERTFILE) server_context.set_alpn_protocols(server_protocols) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) client_context.load_cert_chain(CERTFILE) client_context.set_alpn_protocols(client_protocols) - stats = server_params_test(client_context, server_context, - chatty=True, connectionchatty=True) - msg = "failed trying %s (s) and %s (c).\n" \ - "was expecting %s, but got %%s from the %%s" \ - % (str(server_protocols), str(client_protocols), - str(expected)) - client_result = stats['client_alpn_protocol'] - self.assertEqual(client_result, expected, msg % (client_result, "client")) - server_result = stats['server_alpn_protocols'][-1] \ - if len(stats['server_alpn_protocols']) else 'nothing' - self.assertEqual(server_result, expected, msg % (server_result, "server")) + try: + stats = server_params_test(client_context, + server_context, + chatty=True, + connectionchatty=True) + except ssl.SSLError as e: + stats = e + + if expected is None and IS_OPENSSL_1_1: + # OpenSSL 1.1.0 raises handshake error + self.assertIsInstance(stats, ssl.SSLError) + else: + msg = "failed trying %s (s) and %s (c).\n" \ + "was expecting %s, but got %%s from the %%s" \ + % (str(server_protocols), str(client_protocols), + str(expected)) + client_result = stats['client_alpn_protocol'] + self.assertEqual(client_result, expected, + msg % (client_result, "client")) + server_result = stats['server_alpn_protocols'][-1] \ + if len(stats['server_alpn_protocols']) else 'nothing' + self.assertEqual(server_result, expected, + msg % (server_result, "server")) def test_selected_npn_protocol(self): # selected_npn_protocol() is None unless NPN is used @@ -3300,13 +3317,23 @@ else: client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) client_context.verify_mode = ssl.CERT_REQUIRED client_context.load_verify_locations(SIGNING_CA) - client_context.set_ciphers("RC4") - server_context.set_ciphers("AES:RC4") + if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2): + client_context.set_ciphers("AES128:AES256") + server_context.set_ciphers("AES256") + alg1 = "AES256" + alg2 = "AES-256" + else: + client_context.set_ciphers("AES:3DES") + server_context.set_ciphers("3DES") + alg1 = "3DES" + alg2 = "DES-CBC3" + stats = server_params_test(client_context, server_context) ciphers = stats['server_shared_ciphers'][0] self.assertGreater(len(ciphers), 0) for name, tls_version, bits in ciphers: - self.assertIn("RC4", name.split("-")) + if not alg1 in name.split("-") and alg2 not in name: + self.fail(name) def test_read_write_after_close_raises_valuerror(self): context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) |