diff options
author | Antoine Pitrou <solipsis@pitrou.net> | 2013-05-01 18:52:07 (GMT) |
---|---|---|
committer | Antoine Pitrou <solipsis@pitrou.net> | 2013-05-01 18:52:07 (GMT) |
commit | 242db728e2fcbf9004143517d240301334b02545 (patch) | |
tree | 9e26f188d67f7ecec49cbf495753d6bcbdb74d80 /Lib | |
parent | f6ca26fbffb689190da9dbe66df09c7d7e118616 (diff) | |
download | cpython-242db728e2fcbf9004143517d240301334b02545.zip cpython-242db728e2fcbf9004143517d240301334b02545.tar.gz cpython-242db728e2fcbf9004143517d240301334b02545.tar.bz2 |
Issue #13721: SSLSocket.getpeercert() and SSLSocket.do_handshake() now raise an OSError with ENOTCONN, instead of an AttributeError, when the SSLSocket is not connected.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/ssl.py | 34 | ||||
-rw-r--r-- | Lib/test/test_ssl.py | 15 |
2 files changed, 37 insertions, 12 deletions
@@ -299,7 +299,6 @@ class SSLSocket(socket): self.server_hostname = server_hostname self.do_handshake_on_connect = do_handshake_on_connect self.suppress_ragged_eofs = suppress_ragged_eofs - connected = False if sock is not None: socket.__init__(self, family=sock.family, @@ -307,20 +306,22 @@ class SSLSocket(socket): proto=sock.proto, fileno=sock.fileno()) self.settimeout(sock.gettimeout()) - # see if it's connected - try: - sock.getpeername() - except OSError as e: - if e.errno != errno.ENOTCONN: - raise - else: - connected = True sock.detach() elif fileno is not None: socket.__init__(self, fileno=fileno) else: socket.__init__(self, family=family, type=type, proto=proto) + # See if we are connected + try: + self.getpeername() + except OSError as e: + if e.errno != errno.ENOTCONN: + raise + connected = False + else: + connected = True + self._closed = False self._sslobj = None self._connected = connected @@ -339,6 +340,7 @@ class SSLSocket(socket): except OSError as x: self.close() raise x + @property def context(self): return self._context @@ -356,6 +358,14 @@ class SSLSocket(socket): # raise an exception here if you wish to check for spurious closes pass + def _check_connected(self): + if not self._connected: + # getpeername() will raise ENOTCONN if the socket is really + # not connected; note that we can be connected even without + # _connected being set, e.g. if connect() first returned + # EAGAIN. + self.getpeername() + def read(self, len=0, buffer=None): """Read up to LEN bytes and return them. Return zero-length string on EOF.""" @@ -390,6 +400,7 @@ class SSLSocket(socket): certificate was provided, but not validated.""" self._checkClosed() + self._check_connected() return self._sslobj.peer_certificate(binary_form) def selected_npn_protocol(self): @@ -538,12 +549,11 @@ class SSLSocket(socket): def _real_close(self): self._sslobj = None - # self._closed = True socket._real_close(self) def do_handshake(self, block=False): """Perform a TLS/SSL handshake.""" - + self._check_connected() timeout = self.gettimeout() try: if timeout == 0.0 and block: @@ -567,9 +577,9 @@ class SSLSocket(socket): rc = None socket.connect(self, addr) if not rc: + self._connected = True if self.do_handshake_on_connect: self.do_handshake() - self._connected = True return rc except OSError: self._sslobj = None diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 1f0f62a..d4c90cf 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -17,6 +17,7 @@ import asyncore import weakref import platform import functools +from unittest import mock ssl = support.import_module("ssl") @@ -1931,6 +1932,20 @@ else: self.assertIsInstance(remote, ssl.SSLSocket) self.assertEqual(peer, client_addr) + def test_getpeercert_enotconn(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + with context.wrap_socket(socket.socket()) as sock: + with self.assertRaises(OSError) as cm: + sock.getpeercert() + self.assertEqual(cm.exception.errno, errno.ENOTCONN) + + def test_do_handshake_enotconn(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + with context.wrap_socket(socket.socket()) as sock: + with self.assertRaises(OSError) as cm: + sock.do_handshake() + self.assertEqual(cm.exception.errno, errno.ENOTCONN) + def test_default_ciphers(self): context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) try: |