diff options
Diffstat (limited to 'Modules/_ssl.c')
-rw-r--r-- | Modules/_ssl.c | 78 |
1 files changed, 65 insertions, 13 deletions
diff --git a/Modules/_ssl.c b/Modules/_ssl.c index 0a42fe7..cfcb8a5 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -168,6 +168,8 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file) PySSLObject *self; char *errstr = NULL; int ret; + int err; + int timedout; self = PyObject_New(PySSLObject, &PySSL_Type); /* Create new object */ if (self == NULL){ @@ -220,14 +222,38 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file) self->ssl = SSL_new(self->ctx); /* New ssl struct */ Py_END_ALLOW_THREADS SSL_set_fd(self->ssl, Sock->sock_fd); /* Set the socket for SSL */ + + /* If the socket is is non-blocking mode or timeout mode, set the BIO + * to non-blocking mode (blocking is the default) + */ + if (Sock->sock_timeout >= 0.0) { + /* Set both the read and write BIO's to non-blocking mode */ + BIO_set_nbio(SSL_get_rbio(self->ssl), 1); + BIO_set_nbio(SSL_get_wbio(self->ssl), 1); + } + Py_BEGIN_ALLOW_THREADS SSL_set_connect_state(self->ssl); - + Py_END_ALLOW_THREADS /* Actually negotiate SSL connection */ /* XXX If SSL_connect() returns 0, it's also a failure. */ - ret = SSL_connect(self->ssl); - Py_END_ALLOW_THREADS + timedout = 0; + do { + Py_BEGIN_ALLOW_THREADS + ret = SSL_connect(self->ssl); + err = SSL_get_error(self->ssl, ret); + Py_END_ALLOW_THREADS + if (err == SSL_ERROR_WANT_READ) { + timedout = wait_for_timeout(Sock, 0); + } else if (err == SSL_ERROR_WANT_WRITE) { + timedout = wait_for_timeout(Sock, 1); + } + if (timedout) { + PyErr_SetString(PySSLErrorObject, "The connect operation timed out"); + return NULL; + } + } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); if (ret <= 0) { PySSL_SetError(self, ret); goto fail; @@ -328,10 +354,12 @@ wait_for_timeout(PySocketSockObject *s, int writing) FD_SET(s->sock_fd, &fds); /* See if the socket is ready */ + Py_BEGIN_ALLOW_THREADS if (writing) rc = select(s->sock_fd+1, NULL, &fds, NULL, &tv); else rc = select(s->sock_fd+1, &fds, NULL, NULL, &tv); + Py_END_ALLOW_THREADS /* Return 1 on timeout, 0 otherwise */ return rc == 0; @@ -342,20 +370,32 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args) char *data; int len; int timedout; + int err; if (!PyArg_ParseTuple(args, "s#:write", &data, &len)) return NULL; - Py_BEGIN_ALLOW_THREADS timedout = wait_for_timeout(self->Socket, 1); - Py_END_ALLOW_THREADS if (timedout) { PyErr_SetString(PySSLErrorObject, "The write operation timed out"); return NULL; } - Py_BEGIN_ALLOW_THREADS - len = SSL_write(self->ssl, data, len); - Py_END_ALLOW_THREADS + do { + err = 0; + Py_BEGIN_ALLOW_THREADS + len = SSL_write(self->ssl, data, len); + err = SSL_get_error(self->ssl, len); + Py_END_ALLOW_THREADS + if (err == SSL_ERROR_WANT_READ) { + timedout = wait_for_timeout(self->Socket, 0); + } else if (err == SSL_ERROR_WANT_WRITE) { + timedout = wait_for_timeout(self->Socket, 1); + } + if (timedout) { + PyErr_SetString(PySSLErrorObject, "The write operation timed out"); + return NULL; + } + } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); if (len > 0) return PyInt_FromLong(len); else @@ -374,6 +414,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) int count = 0; int len = 1024; int timedout; + int err; if (!PyArg_ParseTuple(args, "|i:read", &len)) return NULL; @@ -381,16 +422,27 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) if (!(buf = PyString_FromStringAndSize((char *) 0, len))) return NULL; - Py_BEGIN_ALLOW_THREADS timedout = wait_for_timeout(self->Socket, 0); - Py_END_ALLOW_THREADS if (timedout) { PyErr_SetString(PySSLErrorObject, "The read operation timed out"); return NULL; } - Py_BEGIN_ALLOW_THREADS - count = SSL_read(self->ssl, PyString_AsString(buf), len); - Py_END_ALLOW_THREADS + do { + err = 0; + Py_BEGIN_ALLOW_THREADS + count = SSL_read(self->ssl, PyString_AsString(buf), len); + err = SSL_get_error(self->ssl, count); + Py_END_ALLOW_THREADS + if (err == SSL_ERROR_WANT_READ) { + timedout = wait_for_timeout(self->Socket, 0); + } else if (err == SSL_ERROR_WANT_WRITE) { + timedout = wait_for_timeout(self->Socket, 1); + } + if (timedout) { + PyErr_SetString(PySSLErrorObject, "The read operation timed out"); + return NULL; + } + } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); if (count <= 0) { Py_DECREF(buf); return PySSL_SetError(self, count); |