diff options
author | Christian Heimes <christian@python.org> | 2018-02-24 20:10:57 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-02-24 20:10:57 (GMT) |
commit | 141c5e8c2437a9fed95a04c81e400ef725592a17 (patch) | |
tree | 01d9c30cff72bfe58a95cf2013758581adcc7907 /Lib/ssl.py | |
parent | b18f8bc1a77193c372d79afa79b284028a2842d7 (diff) | |
download | cpython-141c5e8c2437a9fed95a04c81e400ef725592a17.zip cpython-141c5e8c2437a9fed95a04c81e400ef725592a17.tar.gz cpython-141c5e8c2437a9fed95a04c81e400ef725592a17.tar.bz2 |
bpo-24334: Cleanup SSLSocket (#5252)
* The SSLSocket is no longer implemented on top of SSLObject to
avoid an extra level of indirection.
* Owner and session are now handled in the internal constructor.
* _ssl._SSLSocket now uses the same method names as SSLSocket and
SSLObject.
* Channel binding type check is now handled in C code. Channel binding
is always available.
The patch also changes the signature of SSLObject.__init__(). In my
opinion it's fine. A SSLObject is not a user-constructable object.
SSLContext.wrap_bio() is the only valid factory.
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r-- | Lib/ssl.py | 116 |
1 files changed, 62 insertions, 54 deletions
@@ -166,10 +166,7 @@ import warnings socket_error = OSError # keep that public name in module namespace -if _ssl.HAS_TLS_UNIQUE: - CHANNEL_BINDING_TYPES = ['tls-unique'] -else: - CHANNEL_BINDING_TYPES = [] +CHANNEL_BINDING_TYPES = ['tls-unique'] HAS_NEVER_CHECK_COMMON_NAME = hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT') @@ -407,11 +404,11 @@ class SSLContext(_SSLContext): server_hostname=None, session=None): # Need to encode server_hostname here because _wrap_bio() can only # handle ASCII str. - sslobj = self._wrap_bio( + return self.sslobject_class( incoming, outgoing, server_side=server_side, - server_hostname=self._encode_hostname(server_hostname) + server_hostname=self._encode_hostname(server_hostname), + session=session, _context=self, ) - return self.sslobject_class(sslobj, session=session) def set_npn_protocols(self, npn_protocols): protos = bytearray() @@ -616,12 +613,13 @@ class SSLObject: * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery. """ - def __init__(self, sslobj, owner=None, session=None): - self._sslobj = sslobj - # Note: _sslobj takes a weak reference to owner - self._sslobj.owner = owner or self - if session is not None: - self._sslobj.session = session + def __init__(self, incoming, outgoing, server_side=False, + server_hostname=None, session=None, _context=None): + self._sslobj = _context._wrap_bio( + incoming, outgoing, server_side=server_side, + server_hostname=server_hostname, + owner=self, session=session + ) @property def context(self): @@ -684,7 +682,7 @@ class SSLObject: Return None if no certificate was provided, {} if a certificate was provided, but not validated. """ - return self._sslobj.peer_certificate(binary_form) + return self._sslobj.getpeercert(binary_form) def selected_npn_protocol(self): """Return the currently selected NPN protocol as a string, or ``None`` @@ -732,13 +730,7 @@ class SSLObject: """Get channel binding data for current connection. Raise ValueError if the requested `cb_type` is not supported. Return bytes of the data or None if the data is not available (e.g. before the handshake).""" - if cb_type not in CHANNEL_BINDING_TYPES: - raise ValueError("Unsupported channel binding type") - if cb_type != "tls-unique": - raise NotImplementedError( - "{0} channel binding type not implemented" - .format(cb_type)) - return self._sslobj.tls_unique_cb() + return self._sslobj.get_channel_binding(cb_type) def version(self): """Return a string identifying the protocol version used by the @@ -832,10 +824,10 @@ class SSLSocket(socket): if connected: # create the SSL object try: - sslobj = self._context._wrap_socket(self, server_side, - self.server_hostname) - self._sslobj = SSLObject(sslobj, owner=self, - session=self._session) + self._sslobj = self._context._wrap_socket( + self, server_side, self.server_hostname, + owner=self, session=self._session, + ) if do_handshake_on_connect: timeout = self.gettimeout() if timeout == 0.0: @@ -895,10 +887,13 @@ class SSLSocket(socket): Return zero-length string on EOF.""" self._checkClosed() - if not self._sslobj: + if self._sslobj is None: raise ValueError("Read on closed or unwrapped SSL socket.") try: - return self._sslobj.read(len, buffer) + if buffer is not None: + return self._sslobj.read(len, buffer) + else: + return self._sslobj.read(len) except SSLError as x: if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs: if buffer is not None: @@ -913,7 +908,7 @@ class SSLSocket(socket): number of bytes of DATA actually transmitted.""" self._checkClosed() - if not self._sslobj: + if self._sslobj is None: raise ValueError("Write on closed or unwrapped SSL socket.") return self._sslobj.write(data) @@ -929,41 +924,42 @@ class SSLSocket(socket): def selected_npn_protocol(self): self._checkClosed() - if not self._sslobj or not _ssl.HAS_NPN: + if self._sslobj is None or not _ssl.HAS_NPN: return None else: return self._sslobj.selected_npn_protocol() def selected_alpn_protocol(self): self._checkClosed() - if not self._sslobj or not _ssl.HAS_ALPN: + if self._sslobj is None or not _ssl.HAS_ALPN: return None else: return self._sslobj.selected_alpn_protocol() def cipher(self): self._checkClosed() - if not self._sslobj: + if self._sslobj is None: return None else: return self._sslobj.cipher() def shared_ciphers(self): self._checkClosed() - if not self._sslobj: + if self._sslobj is None: return None - return self._sslobj.shared_ciphers() + else: + return self._sslobj.shared_ciphers() def compression(self): self._checkClosed() - if not self._sslobj: + if self._sslobj is None: return None else: return self._sslobj.compression() def send(self, data, flags=0): self._checkClosed() - if self._sslobj: + if self._sslobj is not None: if flags != 0: raise ValueError( "non-zero flags not allowed in calls to send() on %s" % @@ -974,7 +970,7 @@ class SSLSocket(socket): def sendto(self, data, flags_or_addr, addr=None): self._checkClosed() - if self._sslobj: + if self._sslobj is not None: raise ValueError("sendto not allowed on instances of %s" % self.__class__) elif addr is None: @@ -990,7 +986,7 @@ class SSLSocket(socket): def sendall(self, data, flags=0): self._checkClosed() - if self._sslobj: + if self._sslobj is not None: if flags != 0: raise ValueError( "non-zero flags not allowed in calls to sendall() on %s" % @@ -1008,15 +1004,15 @@ class SSLSocket(socket): """Send a file, possibly by using os.sendfile() if this is a clear-text socket. Return the total number of bytes sent. """ - if self._sslobj is None: + if self._sslobj is not None: + return self._sendfile_use_send(file, offset, count) + else: # os.sendfile() works with plain sockets only return super().sendfile(file, offset, count) - else: - return self._sendfile_use_send(file, offset, count) def recv(self, buflen=1024, flags=0): self._checkClosed() - if self._sslobj: + if self._sslobj is not None: if flags != 0: raise ValueError( "non-zero flags not allowed in calls to recv() on %s" % @@ -1031,7 +1027,7 @@ class SSLSocket(socket): nbytes = len(buffer) elif nbytes is None: nbytes = 1024 - if self._sslobj: + if self._sslobj is not None: if flags != 0: raise ValueError( "non-zero flags not allowed in calls to recv_into() on %s" % @@ -1042,7 +1038,7 @@ class SSLSocket(socket): def recvfrom(self, buflen=1024, flags=0): self._checkClosed() - if self._sslobj: + if self._sslobj is not None: raise ValueError("recvfrom not allowed on instances of %s" % self.__class__) else: @@ -1050,7 +1046,7 @@ class SSLSocket(socket): def recvfrom_into(self, buffer, nbytes=None, flags=0): self._checkClosed() - if self._sslobj: + if self._sslobj is not None: raise ValueError("recvfrom_into not allowed on instances of %s" % self.__class__) else: @@ -1066,7 +1062,7 @@ class SSLSocket(socket): def pending(self): self._checkClosed() - if self._sslobj: + if self._sslobj is not None: return self._sslobj.pending() else: return 0 @@ -1078,7 +1074,7 @@ class SSLSocket(socket): def unwrap(self): if self._sslobj: - s = self._sslobj.unwrap() + s = self._sslobj.shutdown() self._sslobj = None return s else: @@ -1096,6 +1092,11 @@ class SSLSocket(socket): if timeout == 0.0 and block: self.settimeout(None) self._sslobj.do_handshake() + if self.context.check_hostname: + if not self.server_hostname: + raise ValueError("check_hostname needs server_hostname " + "argument") + match_hostname(self.getpeercert(), self.server_hostname) finally: self.settimeout(timeout) @@ -1104,11 +1105,12 @@ class SSLSocket(socket): raise ValueError("can't connect in server-side mode") # Here we assume that the socket is client-side, and not # connected at the time of the call. We connect it, then wrap it. - if self._connected: + if self._connected or self._sslobj is not None: raise ValueError("attempt to connect already-connected SSLSocket!") - sslobj = self.context._wrap_socket(self, False, self.server_hostname) - self._sslobj = SSLObject(sslobj, owner=self, - session=self._session) + self._sslobj = self.context._wrap_socket( + self, False, self.server_hostname, + owner=self, session=self._session + ) try: if connect_ex: rc = super().connect_ex(addr) @@ -1151,18 +1153,24 @@ class SSLSocket(socket): if the requested `cb_type` is not supported. Return bytes of the data or None if the data is not available (e.g. before the handshake). """ - if self._sslobj is None: + if self._sslobj is not None: + return self._sslobj.get_channel_binding(cb_type) + else: + if cb_type not in CHANNEL_BINDING_TYPES: + raise ValueError( + "{0} channel binding type not implemented".format(cb_type) + ) return None - return self._sslobj.get_channel_binding(cb_type) def version(self): """ Return a string identifying the protocol version used by the current SSL channel, or None if there is no established channel. """ - if self._sslobj is None: + if self._sslobj is not None: + return self._sslobj.version() + else: return None - return self._sslobj.version() # Python does not support forward declaration of types. |