summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/ssl.py32
-rw-r--r--Lib/test/test_ssl.py62
2 files changed, 89 insertions, 5 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py
index b29b905..4c155ea 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -148,6 +148,7 @@ if sys.platform == "win32":
from _ssl import enum_certificates, enum_crls
from socket import getnameinfo as _getnameinfo
+from socket import SHUT_RDWR as _SHUT_RDWR
from socket import socket, AF_INET, SOCK_STREAM, create_connection
import base64 # for DER-to-PEM translation
import traceback
@@ -235,7 +236,9 @@ def match_hostname(cert, hostname):
returns nothing.
"""
if not cert:
- raise ValueError("empty or no certificate")
+ raise ValueError("empty or no certificate, match_hostname needs a "
+ "SSL socket or SSL context with either "
+ "CERT_OPTIONAL or CERT_REQUIRED")
dnsnames = []
san = cert.get('subjectAltName', ())
for key, value in san:
@@ -387,9 +390,10 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None,
context.options |= getattr(_ssl, "OP_NO_COMPRESSION", 0)
# disallow ciphers with known vulnerabilities
context.set_ciphers(_RESTRICTED_CIPHERS)
- # verify certs in client mode
+ # verify certs and host name in client mode
if purpose == Purpose.SERVER_AUTH:
context.verify_mode = CERT_REQUIRED
+ context.check_hostname = True
if cafile or capath or cadata:
context.load_verify_locations(cafile, capath, cadata)
elif context.verify_mode != CERT_NONE:
@@ -480,6 +484,13 @@ class SSLSocket(socket):
if server_side and server_hostname:
raise ValueError("server_hostname can only be specified "
"in client mode")
+ if self._context.check_hostname and not server_hostname:
+ if HAS_SNI:
+ raise ValueError("check_hostname requires server_hostname")
+ else:
+ raise ValueError("check_hostname requires server_hostname, "
+ "but it's not supported by your OpenSSL "
+ "library")
self.server_side = server_side
self.server_hostname = server_hostname
self.do_handshake_on_connect = do_handshake_on_connect
@@ -522,9 +533,9 @@ class SSLSocket(socket):
raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
self.do_handshake()
- except OSError as x:
+ except (OSError, ValueError):
self.close()
- raise x
+ raise
@property
def context(self):
@@ -751,6 +762,17 @@ class SSLSocket(socket):
finally:
self.settimeout(timeout)
+ if self.context.check_hostname:
+ try:
+ if not self.server_hostname:
+ raise ValueError("check_hostname needs server_hostname "
+ "argument")
+ match_hostname(self.getpeercert(), self.server_hostname)
+ except Exception:
+ self.shutdown(_SHUT_RDWR)
+ self.close()
+ raise
+
def _real_connect(self, addr, connect_ex):
if self.server_side:
raise ValueError("can't connect in server-side mode")
@@ -770,7 +792,7 @@ class SSLSocket(socket):
if self.do_handshake_on_connect:
self.do_handshake()
return rc
- except OSError:
+ except (OSError, ValueError):
self._sslobj = None
raise
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index afec72a..ed263c3 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -1003,6 +1003,7 @@ class ContextTests(unittest.TestCase):
ctx = ssl.create_default_context()
self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
+ self.assertTrue(ctx.check_hostname)
self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2)
with open(SIGNING_CA) as f:
@@ -1022,6 +1023,7 @@ class ContextTests(unittest.TestCase):
ctx = ssl._create_stdlib_context()
self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23)
self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
+ self.assertFalse(ctx.check_hostname)
self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2)
ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1)
@@ -1040,6 +1042,28 @@ class ContextTests(unittest.TestCase):
self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2)
+ def test_check_hostname(self):
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ self.assertFalse(ctx.check_hostname)
+
+ # Requires CERT_REQUIRED or CERT_OPTIONAL
+ with self.assertRaises(ValueError):
+ ctx.check_hostname = True
+ ctx.verify_mode = ssl.CERT_REQUIRED
+ self.assertFalse(ctx.check_hostname)
+ ctx.check_hostname = True
+ self.assertTrue(ctx.check_hostname)
+
+ ctx.verify_mode = ssl.CERT_OPTIONAL
+ ctx.check_hostname = True
+ self.assertTrue(ctx.check_hostname)
+
+ # Cannot set CERT_NONE with check_hostname enabled
+ with self.assertRaises(ValueError):
+ ctx.verify_mode = ssl.CERT_NONE
+ ctx.check_hostname = False
+ self.assertFalse(ctx.check_hostname)
+
class SSLErrorTests(unittest.TestCase):
@@ -1930,6 +1954,44 @@ else:
cert = s.getpeercert()
self.assertTrue(cert, "Can't get peer certificate.")
+ def test_check_hostname(self):
+ if support.verbose:
+ sys.stdout.write("\n")
+
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(SIGNED_CERTFILE)
+
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.check_hostname = True
+ context.load_verify_locations(SIGNING_CA)
+
+ # correct hostname should verify
+ server = ThreadedEchoServer(context=server_context, chatty=True)
+ with server:
+ with context.wrap_socket(socket.socket(),
+ server_hostname="localhost") as s:
+ s.connect((HOST, server.port))
+ cert = s.getpeercert()
+ self.assertTrue(cert, "Can't get peer certificate.")
+
+ # incorrect hostname should raise an exception
+ server = ThreadedEchoServer(context=server_context, chatty=True)
+ with server:
+ with context.wrap_socket(socket.socket(),
+ server_hostname="invalid") as s:
+ with self.assertRaisesRegex(ssl.CertificateError,
+ "hostname 'invalid' doesn't match 'localhost'"):
+ s.connect((HOST, server.port))
+
+ # missing server_hostname arg should cause an exception, too
+ server = ThreadedEchoServer(context=server_context, chatty=True)
+ with server:
+ with socket.socket() as s:
+ with self.assertRaisesRegex(ValueError,
+ "check_hostname requires server_hostname"):
+ context.wrap_socket(s)
+
def test_empty_cert(self):
"""Connecting with an empty cert file"""
bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,