summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/ssl.py37
-rw-r--r--Lib/test/test_ssl.py117
2 files changed, 151 insertions, 3 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py
index 75ebcc1..2db8873 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -112,9 +112,11 @@ except ImportError:
pass
-from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_TLSv1_3
-from _ssl import _DEFAULT_CIPHERS
-from _ssl import _OPENSSL_API_VERSION
+from _ssl import (
+ HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_SSLv2, HAS_SSLv3, HAS_TLSv1,
+ HAS_TLSv1_1, HAS_TLSv1_2, HAS_TLSv1_3
+)
+from _ssl import _DEFAULT_CIPHERS, _OPENSSL_API_VERSION
_IntEnum._convert(
@@ -153,6 +155,16 @@ _PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()
_SSLv2_IF_EXISTS = getattr(_SSLMethod, 'PROTOCOL_SSLv2', None)
+class TLSVersion(_IntEnum):
+ MINIMUM_SUPPORTED = _ssl.PROTO_MINIMUM_SUPPORTED
+ SSLv3 = _ssl.PROTO_SSLv3
+ TLSv1 = _ssl.PROTO_TLSv1
+ TLSv1_1 = _ssl.PROTO_TLSv1_1
+ TLSv1_2 = _ssl.PROTO_TLSv1_2
+ TLSv1_3 = _ssl.PROTO_TLSv1_3
+ MAXIMUM_SUPPORTED = _ssl.PROTO_MAXIMUM_SUPPORTED
+
+
if sys.platform == "win32":
from _ssl import enum_certificates, enum_crls
@@ -467,6 +479,25 @@ class SSLContext(_SSLContext):
self._load_windows_store_certs(storename, purpose)
self.set_default_verify_paths()
+ if hasattr(_SSLContext, 'minimum_version'):
+ @property
+ def minimum_version(self):
+ return TLSVersion(super().minimum_version)
+
+ @minimum_version.setter
+ def minimum_version(self, value):
+ if value == TLSVersion.SSLv3:
+ self.options &= ~Options.OP_NO_SSLv3
+ super(SSLContext, SSLContext).minimum_version.__set__(self, value)
+
+ @property
+ def maximum_version(self):
+ return TLSVersion(super().maximum_version)
+
+ @maximum_version.setter
+ def maximum_version(self, value):
+ super(SSLContext, SSLContext).maximum_version.__set__(self, value)
+
@property
def options(self):
return Options(super().options)
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index ca2357e..8d98b80 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -1077,6 +1077,69 @@ class ContextTests(unittest.TestCase):
with self.assertRaises(AttributeError):
ctx.hostname_checks_common_name = True
+ @unittest.skipUnless(hasattr(ssl.SSLContext, 'minimum_version'),
+ "required OpenSSL 1.1.0g")
+ def test_min_max_version(self):
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ self.assertEqual(
+ ctx.minimum_version, ssl.TLSVersion.MINIMUM_SUPPORTED
+ )
+ self.assertEqual(
+ ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
+ )
+
+ ctx.minimum_version = ssl.TLSVersion.TLSv1_1
+ ctx.maximum_version = ssl.TLSVersion.TLSv1_2
+ self.assertEqual(
+ ctx.minimum_version, ssl.TLSVersion.TLSv1_1
+ )
+ self.assertEqual(
+ ctx.maximum_version, ssl.TLSVersion.TLSv1_2
+ )
+
+ ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
+ ctx.maximum_version = ssl.TLSVersion.TLSv1
+ self.assertEqual(
+ ctx.minimum_version, ssl.TLSVersion.MINIMUM_SUPPORTED
+ )
+ self.assertEqual(
+ ctx.maximum_version, ssl.TLSVersion.TLSv1
+ )
+
+ ctx.maximum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
+ self.assertEqual(
+ ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
+ )
+
+ ctx.maximum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
+ self.assertIn(
+ ctx.maximum_version,
+ {ssl.TLSVersion.TLSv1, ssl.TLSVersion.SSLv3}
+ )
+
+ ctx.minimum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
+ self.assertIn(
+ ctx.minimum_version,
+ {ssl.TLSVersion.TLSv1_2, ssl.TLSVersion.TLSv1_3}
+ )
+
+ with self.assertRaises(ValueError):
+ ctx.minimum_version = 42
+
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_1)
+
+ self.assertEqual(
+ ctx.minimum_version, ssl.TLSVersion.MINIMUM_SUPPORTED
+ )
+ self.assertEqual(
+ ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
+ )
+ with self.assertRaises(ValueError):
+ ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
+ with self.assertRaises(ValueError):
+ ctx.maximum_version = ssl.TLSVersion.TLSv1
+
+
@unittest.skipUnless(have_verify_flags(),
"verify_flags need OpenSSL > 0.9.8")
def test_verify_flags(self):
@@ -3457,6 +3520,60 @@ class ThreadedTests(unittest.TestCase):
})
self.assertEqual(s.version(), 'TLSv1.3')
+ @unittest.skipUnless(hasattr(ssl.SSLContext, 'minimum_version'),
+ "required OpenSSL 1.1.0g")
+ def test_min_max_version(self):
+ client_context, server_context, hostname = testing_context()
+ # client TLSv1.0 to 1.2
+ client_context.minimum_version = ssl.TLSVersion.TLSv1
+ client_context.maximum_version = ssl.TLSVersion.TLSv1_2
+ # server only TLSv1.2
+ server_context.minimum_version = ssl.TLSVersion.TLSv1_2
+ server_context.maximum_version = ssl.TLSVersion.TLSv1_2
+
+ with ThreadedEchoServer(context=server_context) as server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+ self.assertEqual(s.version(), 'TLSv1.2')
+
+ # client 1.0 to 1.2, server 1.0 to 1.1
+ server_context.minimum_version = ssl.TLSVersion.TLSv1
+ server_context.maximum_version = ssl.TLSVersion.TLSv1_1
+
+ with ThreadedEchoServer(context=server_context) as server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+ self.assertEqual(s.version(), 'TLSv1.1')
+
+ # client 1.0, server 1.2 (mismatch)
+ server_context.minimum_version = ssl.TLSVersion.TLSv1_2
+ server_context.maximum_version = ssl.TLSVersion.TLSv1_2
+ client_context.minimum_version = ssl.TLSVersion.TLSv1
+ client_context.maximum_version = ssl.TLSVersion.TLSv1
+ with ThreadedEchoServer(context=server_context) as server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ with self.assertRaises(ssl.SSLError) as e:
+ s.connect((HOST, server.port))
+ self.assertIn("alert", str(e.exception))
+
+
+ @unittest.skipUnless(hasattr(ssl.SSLContext, 'minimum_version'),
+ "required OpenSSL 1.1.0g")
+ @unittest.skipUnless(ssl.HAS_SSLv3, "requires SSLv3 support")
+ def test_min_max_version_sslv3(self):
+ client_context, server_context, hostname = testing_context()
+ server_context.minimum_version = ssl.TLSVersion.SSLv3
+ client_context.minimum_version = ssl.TLSVersion.SSLv3
+ client_context.maximum_version = ssl.TLSVersion.SSLv3
+ with ThreadedEchoServer(context=server_context) as server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+ self.assertEqual(s.version(), 'SSLv3')
+
@unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
def test_default_ecdh_curve(self):
# Issue #21015: elliptic curve-based Diffie Hellman key exchange