From 745ab3807e0c5f141376f6b9e1b6111e806c31d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Giampaolo=20Rodol=C3=A0?= Date: Sun, 29 Aug 2010 19:25:49 +0000 Subject: Fix issue issue9706: provides a better error handling for various SSL operations --- Lib/ssl.py | 8 ++++++-- Lib/test/test_ssl.py | 20 ++++++++++++++++++-- Modules/_ssl.c | 23 ++++++++++++++++++++--- 3 files changed, 44 insertions(+), 7 deletions(-) diff --git a/Lib/ssl.py b/Lib/ssl.py index af1cc84..a634442 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -122,6 +122,9 @@ class SSLSocket(socket): if _context: self.context = _context else: + if server_side and not certfile: + raise ValueError("certfile must be specified for server-side " + "operations") if certfile and not keyfile: keyfile = certfile self.context = SSLContext(ssl_version) @@ -138,7 +141,7 @@ class SSLSocket(socket): self.ssl_version = ssl_version self.ca_certs = ca_certs self.ciphers = ciphers - + self.server_side = server_side self.do_handshake_on_connect = do_handshake_on_connect self.suppress_ragged_eofs = suppress_ragged_eofs connected = False @@ -358,7 +361,8 @@ class SSLSocket(socket): def connect(self, addr): """Connects to remote ADDR, and then wraps the connection in an SSL channel.""" - + if self.server_side: + raise ValueError("can't connect in server-side mode") # Here we assume that the socket is client-side, and not # connected at the time of the call. We connect it, then wrap it. if self._sslobj: diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 71ba9e1..b485605 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -172,6 +172,19 @@ class BasicSocketTests(unittest.TestCase): ss = ssl.wrap_socket(s) self.assertEqual(timeout, ss.gettimeout()) + def test_errors(self): + sock = socket.socket() + with self.assertRaisesRegexp(ValueError, "certfile must be specified"): + ssl.wrap_socket(sock, server_side=True) + ssl.wrap_socket(sock, server_side=True, certfile="") + s = ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE) + self.assertRaisesRegexp(ValueError, "can't connect in server-side mode", + s.connect, (HOST, 8080)) + with self.assertRaisesRegexp(IOError, "No such file"): + ssl.wrap_socket(sock, certfile=WRONGCERT) + ssl.wrap_socket(sock, keyfile=WRONGCERT) + ssl.wrap_socket(sock, certfile=WRONGCERT, keyfile=WRONGCERT) + class ContextTests(unittest.TestCase): @@ -240,7 +253,7 @@ class ContextTests(unittest.TestCase): ctx.load_cert_chain(CERTFILE) ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE) self.assertRaises(TypeError, ctx.load_cert_chain, keyfile=CERTFILE) - with self.assertRaisesRegexp(ssl.SSLError, "system lib"): + with self.assertRaisesRegexp(IOError, "No such file"): ctx.load_cert_chain(WRONGCERT) with self.assertRaisesRegexp(ssl.SSLError, "PEM lib"): ctx.load_cert_chain(BADCERT) @@ -270,7 +283,7 @@ class ContextTests(unittest.TestCase): ctx.load_verify_locations(cafile=BYTES_CERTFILE, capath=None) self.assertRaises(TypeError, ctx.load_verify_locations) self.assertRaises(TypeError, ctx.load_verify_locations, None, None) - with self.assertRaisesRegexp(ssl.SSLError, "system lib"): + with self.assertRaisesRegexp(IOError, "No such file"): ctx.load_verify_locations(WRONGCERT) with self.assertRaisesRegexp(ssl.SSLError, "PEM lib"): ctx.load_verify_locations(BADCERT) @@ -863,6 +876,9 @@ else: except socket.error as x: if support.verbose: sys.stdout.write("\nsocket.error is %s\n" % x[1]) + except IOError as x: + if support.verbose: + sys.stdout.write("\nsocket.error is %s\n" % str(x)) else: raise AssertionError("Use of invalid cert should have failed!") finally: diff --git a/Modules/_ssl.c b/Modules/_ssl.c index f8428c4..0008691 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -1580,6 +1580,7 @@ load_cert_chain(PySSLContext *self, PyObject *args, PyObject *kwds) PyObject *certfile_bytes = NULL, *keyfile_bytes = NULL; int r; + errno = 0; if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:load_cert_chain", kwlist, &certfile, &keyfile)) @@ -1601,7 +1602,12 @@ load_cert_chain(PySSLContext *self, PyObject *args, PyObject *kwds) PyBytes_AS_STRING(certfile_bytes)); PySSL_END_ALLOW_THREADS if (r != 1) { - _setSSLError(NULL, 0, __FILE__, __LINE__); + if (errno != 0) { + PyErr_SetFromErrno(PyExc_IOError); + } + else { + _setSSLError(NULL, 0, __FILE__, __LINE__); + } goto error; } PySSL_BEGIN_ALLOW_THREADS @@ -1612,7 +1618,12 @@ load_cert_chain(PySSLContext *self, PyObject *args, PyObject *kwds) Py_XDECREF(keyfile_bytes); Py_XDECREF(certfile_bytes); if (r != 1) { - _setSSLError(NULL, 0, __FILE__, __LINE__); + if (errno != 0) { + PyErr_SetFromErrno(PyExc_IOError); + } + else { + _setSSLError(NULL, 0, __FILE__, __LINE__); + } return NULL; } PySSL_BEGIN_ALLOW_THREADS @@ -1639,6 +1650,7 @@ load_verify_locations(PySSLContext *self, PyObject *args, PyObject *kwds) const char *cafile_buf = NULL, *capath_buf = NULL; int r; + errno = 0; if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO:load_verify_locations", kwlist, &cafile, &capath)) @@ -1673,7 +1685,12 @@ load_verify_locations(PySSLContext *self, PyObject *args, PyObject *kwds) Py_XDECREF(cafile_bytes); Py_XDECREF(capath_bytes); if (r != 1) { - _setSSLError(NULL, 0, __FILE__, __LINE__); + if (errno != 0) { + PyErr_SetFromErrno(PyExc_IOError); + } + else { + _setSSLError(NULL, 0, __FILE__, __LINE__); + } return NULL; } Py_RETURN_NONE; -- cgit v0.12