summaryrefslogtreecommitdiffstats
path: root/Lib/ssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r--Lib/ssl.py138
1 files changed, 67 insertions, 71 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py
index 94ea35e..75ebcc1 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -390,24 +390,24 @@ class SSLContext(_SSLContext):
server_hostname=None, session=None):
# SSLSocket class handles server_hostname encoding before it calls
# ctx._wrap_socket()
- return self.sslsocket_class(
+ return self.sslsocket_class._create(
sock=sock,
server_side=server_side,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
server_hostname=server_hostname,
- _context=self,
- _session=session
+ context=self,
+ session=session
)
def wrap_bio(self, incoming, outgoing, server_side=False,
server_hostname=None, session=None):
# Need to encode server_hostname here because _wrap_bio() can only
# handle ASCII str.
- return self.sslobject_class(
+ return self.sslobject_class._create(
incoming, outgoing, server_side=server_side,
server_hostname=self._encode_hostname(server_hostname),
- session=session, _context=self,
+ session=session, context=self,
)
def set_npn_protocols(self, npn_protocols):
@@ -612,14 +612,23 @@ class SSLObject:
* Any form of network IO incluging methods such as ``recv`` and ``send``.
* The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
"""
+ def __init__(self, *args, **kwargs):
+ raise TypeError(
+ f"{self.__class__.__name__} does not have a public "
+ f"constructor. Instances are returned by SSLContext.wrap_bio()."
+ )
- def __init__(self, incoming, outgoing, server_side=False,
- server_hostname=None, session=None, _context=None):
- self._sslobj = _context._wrap_bio(
+ @classmethod
+ def _create(cls, incoming, outgoing, server_side=False,
+ server_hostname=None, session=None, context=None):
+ self = cls.__new__(cls)
+ sslobj = context._wrap_bio(
incoming, outgoing, server_side=server_side,
server_hostname=server_hostname,
owner=self, session=session
)
+ self._sslobj = sslobj
+ return self
@property
def context(self):
@@ -741,72 +750,48 @@ class SSLObject:
class SSLSocket(socket):
"""This class implements a subtype of socket.socket that wraps
the underlying OS socket in an SSL context when necessary, and
- provides read and write methods over that channel."""
-
- def __init__(self, sock=None, keyfile=None, certfile=None,
- server_side=False, cert_reqs=CERT_NONE,
- ssl_version=PROTOCOL_TLS, ca_certs=None,
- do_handshake_on_connect=True,
- family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
- suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
- server_hostname=None,
- _context=None, _session=None):
-
- if _context:
- self._context = _context
- else:
- if server_side and not certfile:
- raise ValueError("certfile must be specified for server-side "
- "operations")
- if keyfile and not certfile:
- raise ValueError("certfile must be specified")
- if certfile and not keyfile:
- keyfile = certfile
- self._context = SSLContext(ssl_version)
- self._context.verify_mode = cert_reqs
- if ca_certs:
- self._context.load_verify_locations(ca_certs)
- if certfile:
- self._context.load_cert_chain(certfile, keyfile)
- if npn_protocols:
- self._context.set_npn_protocols(npn_protocols)
- if ciphers:
- self._context.set_ciphers(ciphers)
- self.keyfile = keyfile
- self.certfile = certfile
- self.cert_reqs = cert_reqs
- self.ssl_version = ssl_version
- self.ca_certs = ca_certs
- self.ciphers = ciphers
- # Can't use sock.type as other flags (such as SOCK_NONBLOCK) get
- # mixed in.
+ provides read and write methods over that channel. """
+
+ def __init__(self, *args, **kwargs):
+ raise TypeError(
+ f"{self.__class__.__name__} does not have a public "
+ f"constructor. Instances are returned by "
+ f"SSLContext.wrap_socket()."
+ )
+
+ @classmethod
+ def _create(cls, sock, server_side=False, do_handshake_on_connect=True,
+ suppress_ragged_eofs=True, server_hostname=None,
+ context=None, session=None):
if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
raise NotImplementedError("only stream sockets are supported")
if server_side:
if server_hostname:
raise ValueError("server_hostname can only be specified "
"in client mode")
- if _session is not None:
+ if session is not None:
raise ValueError("session can only be specified in "
"client mode")
- if self._context.check_hostname and not server_hostname:
+ if context.check_hostname and not server_hostname:
raise ValueError("check_hostname requires server_hostname")
- self._session = _session
+
+ kwargs = dict(
+ family=sock.family, type=sock.type, proto=sock.proto,
+ fileno=sock.fileno()
+ )
+ self = cls.__new__(cls, **kwargs)
+ super(SSLSocket, self).__init__(**kwargs)
+ self.settimeout(sock.gettimeout())
+ sock.detach()
+
+ self._context = context
+ self._session = session
+ self._closed = False
+ self._sslobj = None
self.server_side = server_side
- self.server_hostname = self._context._encode_hostname(server_hostname)
+ self.server_hostname = context._encode_hostname(server_hostname)
self.do_handshake_on_connect = do_handshake_on_connect
self.suppress_ragged_eofs = suppress_ragged_eofs
- if sock is not None:
- super().__init__(family=sock.family,
- type=sock.type,
- proto=sock.proto,
- fileno=sock.fileno())
- self.settimeout(sock.gettimeout())
- sock.detach()
- elif fileno is not None:
- super().__init__(fileno=fileno)
- else:
- super().__init__(family=family, type=type, proto=proto)
# See if we are connected
try:
@@ -818,8 +803,6 @@ class SSLSocket(socket):
else:
connected = True
- self._closed = False
- self._sslobj = None
self._connected = connected
if connected:
# create the SSL object
@@ -834,10 +817,10 @@ class SSLSocket(socket):
# non-blocking
raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
self.do_handshake()
-
except (OSError, ValueError):
self.close()
raise
+ return self
@property
def context(self):
@@ -1184,12 +1167,25 @@ def wrap_socket(sock, keyfile=None, certfile=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
ciphers=None):
- return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,
- server_side=server_side, cert_reqs=cert_reqs,
- ssl_version=ssl_version, ca_certs=ca_certs,
- do_handshake_on_connect=do_handshake_on_connect,
- suppress_ragged_eofs=suppress_ragged_eofs,
- ciphers=ciphers)
+
+ if server_side and not certfile:
+ raise ValueError("certfile must be specified for server-side "
+ "operations")
+ if keyfile and not certfile:
+ raise ValueError("certfile must be specified")
+ context = SSLContext(ssl_version)
+ context.verify_mode = cert_reqs
+ if ca_certs:
+ context.load_verify_locations(ca_certs)
+ if certfile:
+ context.load_cert_chain(certfile, keyfile)
+ if ciphers:
+ context.set_ciphers(ciphers)
+ return context.wrap_socket(
+ sock=sock, server_side=server_side,
+ do_handshake_on_connect=do_handshake_on_connect,
+ suppress_ragged_eofs=suppress_ragged_eofs
+ )
# some utility functions