summaryrefslogtreecommitdiffstats
path: root/Lib/ssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r--Lib/ssl.py32
1 files changed, 27 insertions, 5 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py
index b29b905..4c155ea 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -148,6 +148,7 @@ if sys.platform == "win32":
from _ssl import enum_certificates, enum_crls
from socket import getnameinfo as _getnameinfo
+from socket import SHUT_RDWR as _SHUT_RDWR
from socket import socket, AF_INET, SOCK_STREAM, create_connection
import base64 # for DER-to-PEM translation
import traceback
@@ -235,7 +236,9 @@ def match_hostname(cert, hostname):
returns nothing.
"""
if not cert:
- raise ValueError("empty or no certificate")
+ raise ValueError("empty or no certificate, match_hostname needs a "
+ "SSL socket or SSL context with either "
+ "CERT_OPTIONAL or CERT_REQUIRED")
dnsnames = []
san = cert.get('subjectAltName', ())
for key, value in san:
@@ -387,9 +390,10 @@ def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None,
context.options |= getattr(_ssl, "OP_NO_COMPRESSION", 0)
# disallow ciphers with known vulnerabilities
context.set_ciphers(_RESTRICTED_CIPHERS)
- # verify certs in client mode
+ # verify certs and host name in client mode
if purpose == Purpose.SERVER_AUTH:
context.verify_mode = CERT_REQUIRED
+ context.check_hostname = True
if cafile or capath or cadata:
context.load_verify_locations(cafile, capath, cadata)
elif context.verify_mode != CERT_NONE:
@@ -480,6 +484,13 @@ class SSLSocket(socket):
if server_side and server_hostname:
raise ValueError("server_hostname can only be specified "
"in client mode")
+ if self._context.check_hostname and not server_hostname:
+ if HAS_SNI:
+ raise ValueError("check_hostname requires server_hostname")
+ else:
+ raise ValueError("check_hostname requires server_hostname, "
+ "but it's not supported by your OpenSSL "
+ "library")
self.server_side = server_side
self.server_hostname = server_hostname
self.do_handshake_on_connect = do_handshake_on_connect
@@ -522,9 +533,9 @@ class SSLSocket(socket):
raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
self.do_handshake()
- except OSError as x:
+ except (OSError, ValueError):
self.close()
- raise x
+ raise
@property
def context(self):
@@ -751,6 +762,17 @@ class SSLSocket(socket):
finally:
self.settimeout(timeout)
+ if self.context.check_hostname:
+ try:
+ if not self.server_hostname:
+ raise ValueError("check_hostname needs server_hostname "
+ "argument")
+ match_hostname(self.getpeercert(), self.server_hostname)
+ except Exception:
+ self.shutdown(_SHUT_RDWR)
+ self.close()
+ raise
+
def _real_connect(self, addr, connect_ex):
if self.server_side:
raise ValueError("can't connect in server-side mode")
@@ -770,7 +792,7 @@ class SSLSocket(socket):
if self.do_handshake_on_connect:
self.do_handshake()
return rc
- except OSError:
+ except (OSError, ValueError):
self._sslobj = None
raise