diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/ssl.py | 37 | ||||
-rw-r--r-- | Lib/test/test_ssl.py | 117 |
2 files changed, 151 insertions, 3 deletions
@@ -112,9 +112,11 @@ except ImportError: pass -from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_TLSv1_3 -from _ssl import _DEFAULT_CIPHERS -from _ssl import _OPENSSL_API_VERSION +from _ssl import ( + HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_SSLv2, HAS_SSLv3, HAS_TLSv1, + HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3 +) +from _ssl import _DEFAULT_CIPHERS, _OPENSSL_API_VERSION _IntEnum._convert( @@ -153,6 +155,16 @@ _PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items() _SSLv2_IF_EXISTS = getattr(_SSLMethod, 'PROTOCOL_SSLv2', None) +class TLSVersion(_IntEnum): + MINIMUM_SUPPORTED = _ssl.PROTO_MINIMUM_SUPPORTED + SSLv3 = _ssl.PROTO_SSLv3 + TLSv1 = _ssl.PROTO_TLSv1 + TLSv1_1 = _ssl.PROTO_TLSv1_1 + TLSv1_2 = _ssl.PROTO_TLSv1_2 + TLSv1_3 = _ssl.PROTO_TLSv1_3 + MAXIMUM_SUPPORTED = _ssl.PROTO_MAXIMUM_SUPPORTED + + if sys.platform == "win32": from _ssl import enum_certificates, enum_crls @@ -467,6 +479,25 @@ class SSLContext(_SSLContext): self._load_windows_store_certs(storename, purpose) self.set_default_verify_paths() + if hasattr(_SSLContext, 'minimum_version'): + @property + def minimum_version(self): + return TLSVersion(super().minimum_version) + + @minimum_version.setter + def minimum_version(self, value): + if value == TLSVersion.SSLv3: + self.options &= ~Options.OP_NO_SSLv3 + super(SSLContext, SSLContext).minimum_version.__set__(self, value) + + @property + def maximum_version(self): + return TLSVersion(super().maximum_version) + + @maximum_version.setter + def maximum_version(self, value): + super(SSLContext, SSLContext).maximum_version.__set__(self, value) + @property def options(self): return Options(super().options) diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index ca2357e..8d98b80 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -1077,6 +1077,69 @@ class ContextTests(unittest.TestCase): with self.assertRaises(AttributeError): ctx.hostname_checks_common_name = True + @unittest.skipUnless(hasattr(ssl.SSLContext, 'minimum_version'), + "required OpenSSL 1.1.0g") + def test_min_max_version(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + self.assertEqual( + ctx.minimum_version, ssl.TLSVersion.MINIMUM_SUPPORTED + ) + self.assertEqual( + ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED + ) + + ctx.minimum_version = ssl.TLSVersion.TLSv1_1 + ctx.maximum_version = ssl.TLSVersion.TLSv1_2 + self.assertEqual( + ctx.minimum_version, ssl.TLSVersion.TLSv1_1 + ) + self.assertEqual( + ctx.maximum_version, ssl.TLSVersion.TLSv1_2 + ) + + ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED + ctx.maximum_version = ssl.TLSVersion.TLSv1 + self.assertEqual( + ctx.minimum_version, ssl.TLSVersion.MINIMUM_SUPPORTED + ) + self.assertEqual( + ctx.maximum_version, ssl.TLSVersion.TLSv1 + ) + + ctx.maximum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED + self.assertEqual( + ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED + ) + + ctx.maximum_version = ssl.TLSVersion.MINIMUM_SUPPORTED + self.assertIn( + ctx.maximum_version, + {ssl.TLSVersion.TLSv1, ssl.TLSVersion.SSLv3} + ) + + ctx.minimum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED + self.assertIn( + ctx.minimum_version, + {ssl.TLSVersion.TLSv1_2, ssl.TLSVersion.TLSv1_3} + ) + + with self.assertRaises(ValueError): + ctx.minimum_version = 42 + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_1) + + self.assertEqual( + ctx.minimum_version, ssl.TLSVersion.MINIMUM_SUPPORTED + ) + self.assertEqual( + ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED + ) + with self.assertRaises(ValueError): + ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED + with self.assertRaises(ValueError): + ctx.maximum_version = ssl.TLSVersion.TLSv1 + + @unittest.skipUnless(have_verify_flags(), "verify_flags need OpenSSL > 0.9.8") def test_verify_flags(self): @@ -3457,6 +3520,60 @@ class ThreadedTests(unittest.TestCase): }) self.assertEqual(s.version(), 'TLSv1.3') + @unittest.skipUnless(hasattr(ssl.SSLContext, 'minimum_version'), + "required OpenSSL 1.1.0g") + def test_min_max_version(self): + client_context, server_context, hostname = testing_context() + # client TLSv1.0 to 1.2 + client_context.minimum_version = ssl.TLSVersion.TLSv1 + client_context.maximum_version = ssl.TLSVersion.TLSv1_2 + # server only TLSv1.2 + server_context.minimum_version = ssl.TLSVersion.TLSv1_2 + server_context.maximum_version = ssl.TLSVersion.TLSv1_2 + + with ThreadedEchoServer(context=server_context) as server: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: + s.connect((HOST, server.port)) + self.assertEqual(s.version(), 'TLSv1.2') + + # client 1.0 to 1.2, server 1.0 to 1.1 + server_context.minimum_version = ssl.TLSVersion.TLSv1 + server_context.maximum_version = ssl.TLSVersion.TLSv1_1 + + with ThreadedEchoServer(context=server_context) as server: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: + s.connect((HOST, server.port)) + self.assertEqual(s.version(), 'TLSv1.1') + + # client 1.0, server 1.2 (mismatch) + server_context.minimum_version = ssl.TLSVersion.TLSv1_2 + server_context.maximum_version = ssl.TLSVersion.TLSv1_2 + client_context.minimum_version = ssl.TLSVersion.TLSv1 + client_context.maximum_version = ssl.TLSVersion.TLSv1 + with ThreadedEchoServer(context=server_context) as server: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: + with self.assertRaises(ssl.SSLError) as e: + s.connect((HOST, server.port)) + self.assertIn("alert", str(e.exception)) + + + @unittest.skipUnless(hasattr(ssl.SSLContext, 'minimum_version'), + "required OpenSSL 1.1.0g") + @unittest.skipUnless(ssl.HAS_SSLv3, "requires SSLv3 support") + def test_min_max_version_sslv3(self): + client_context, server_context, hostname = testing_context() + server_context.minimum_version = ssl.TLSVersion.SSLv3 + client_context.minimum_version = ssl.TLSVersion.SSLv3 + client_context.maximum_version = ssl.TLSVersion.SSLv3 + with ThreadedEchoServer(context=server_context) as server: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: + s.connect((HOST, server.port)) + self.assertEqual(s.version(), 'SSLv3') + @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL") def test_default_ecdh_curve(self): # Issue #21015: elliptic curve-based Diffie Hellman key exchange |