summaryrefslogtreecommitdiffstats
path: root/Lib/ssl.py
diff options
context:
space:
mode:
authorMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>2024-02-04 16:16:57 (GMT)
committerGitHub <noreply@github.com>2024-02-04 16:16:57 (GMT)
commitf8cba751e25e14aa59ea79df65f885c3aaba1865 (patch)
tree49fddbbe3625361d834241a8ab3efe9d6f8760a7 /Lib/ssl.py
parent55e5ae70b30b365515e156adfcf897fe054aaca1 (diff)
downloadcpython-f8cba751e25e14aa59ea79df65f885c3aaba1865.zip
cpython-f8cba751e25e14aa59ea79df65f885c3aaba1865.tar.gz
cpython-f8cba751e25e14aa59ea79df65f885c3aaba1865.tar.bz2
[3.11] gh-113280: Always close socket if SSLSocket creation failed (GH-114659) (GH-114996)
(cherry picked from commit 0ea366240b75380ed7568acbe95d72e481a734f7) Co-authored-by: Serhiy Storchaka <storchaka@gmail.com> Co-authored-by: Thomas Grainger <tagrain@gmail.com>
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r--Lib/ssl.py107
1 files changed, 53 insertions, 54 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py
index c0350a8..7825ccc 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -1031,71 +1031,67 @@ class SSLSocket(socket):
if context.check_hostname and not server_hostname:
raise ValueError("check_hostname requires server_hostname")
+ sock_timeout = sock.gettimeout()
kwargs = dict(
family=sock.family, type=sock.type, proto=sock.proto,
fileno=sock.fileno()
)
self = cls.__new__(cls, **kwargs)
super(SSLSocket, self).__init__(**kwargs)
- sock_timeout = sock.gettimeout()
sock.detach()
-
- self._context = context
- self._session = session
- self._closed = False
- self._sslobj = None
- self.server_side = server_side
- self.server_hostname = context._encode_hostname(server_hostname)
- self.do_handshake_on_connect = do_handshake_on_connect
- self.suppress_ragged_eofs = suppress_ragged_eofs
-
- # See if we are connected
+ # Now SSLSocket is responsible for closing the file descriptor.
try:
- self.getpeername()
- except OSError as e:
- if e.errno != errno.ENOTCONN:
- raise
- connected = False
- blocking = self.getblocking()
- self.setblocking(False)
+ self._context = context
+ self._session = session
+ self._closed = False
+ self._sslobj = None
+ self.server_side = server_side
+ self.server_hostname = context._encode_hostname(server_hostname)
+ self.do_handshake_on_connect = do_handshake_on_connect
+ self.suppress_ragged_eofs = suppress_ragged_eofs
+
+ # See if we are connected
try:
- # We are not connected so this is not supposed to block, but
- # testing revealed otherwise on macOS and Windows so we do
- # the non-blocking dance regardless. Our raise when any data
- # is found means consuming the data is harmless.
- notconn_pre_handshake_data = self.recv(1)
+ self.getpeername()
except OSError as e:
- # EINVAL occurs for recv(1) on non-connected on unix sockets.
- if e.errno not in (errno.ENOTCONN, errno.EINVAL):
+ if e.errno != errno.ENOTCONN:
raise
- notconn_pre_handshake_data = b''
- self.setblocking(blocking)
- if notconn_pre_handshake_data:
- # This prevents pending data sent to the socket before it was
- # closed from escaping to the caller who could otherwise
- # presume it came through a successful TLS connection.
- reason = "Closed before TLS handshake with data in recv buffer."
- notconn_pre_handshake_data_error = SSLError(e.errno, reason)
- # Add the SSLError attributes that _ssl.c always adds.
- notconn_pre_handshake_data_error.reason = reason
- notconn_pre_handshake_data_error.library = None
- try:
- self.close()
- except OSError:
- pass
+ connected = False
+ blocking = self.getblocking()
+ self.setblocking(False)
try:
- raise notconn_pre_handshake_data_error
- finally:
- # Explicitly break the reference cycle.
- notconn_pre_handshake_data_error = None
- else:
- connected = True
+ # We are not connected so this is not supposed to block, but
+ # testing revealed otherwise on macOS and Windows so we do
+ # the non-blocking dance regardless. Our raise when any data
+ # is found means consuming the data is harmless.
+ notconn_pre_handshake_data = self.recv(1)
+ except OSError as e:
+ # EINVAL occurs for recv(1) on non-connected on unix sockets.
+ if e.errno not in (errno.ENOTCONN, errno.EINVAL):
+ raise
+ notconn_pre_handshake_data = b''
+ self.setblocking(blocking)
+ if notconn_pre_handshake_data:
+ # This prevents pending data sent to the socket before it was
+ # closed from escaping to the caller who could otherwise
+ # presume it came through a successful TLS connection.
+ reason = "Closed before TLS handshake with data in recv buffer."
+ notconn_pre_handshake_data_error = SSLError(e.errno, reason)
+ # Add the SSLError attributes that _ssl.c always adds.
+ notconn_pre_handshake_data_error.reason = reason
+ notconn_pre_handshake_data_error.library = None
+ try:
+ raise notconn_pre_handshake_data_error
+ finally:
+ # Explicitly break the reference cycle.
+ notconn_pre_handshake_data_error = None
+ else:
+ connected = True
- self.settimeout(sock_timeout) # Must come after setblocking() calls.
- self._connected = connected
- if connected:
- # create the SSL object
- try:
+ self.settimeout(sock_timeout) # Must come after setblocking() calls.
+ self._connected = connected
+ if connected:
+ # create the SSL object
self._sslobj = self._context._wrap_socket(
self, server_side, self.server_hostname,
owner=self, session=self._session,
@@ -1106,9 +1102,12 @@ class SSLSocket(socket):
# non-blocking
raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
self.do_handshake()
- except (OSError, ValueError):
+ except:
+ try:
self.close()
- raise
+ except OSError:
+ pass
+ raise
return self
@property