diff options
Diffstat (limited to 'Modules/_ssl.c')
-rw-r--r-- | Modules/_ssl.c | 86 |
1 files changed, 60 insertions, 26 deletions
diff --git a/Modules/_ssl.c b/Modules/_ssl.c index bee1040..76cd7dd 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -64,10 +64,19 @@ typedef struct { static PyTypeObject PySSL_Type; static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args); static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args); -static int wait_for_timeout(PySocketSockObject *s, int writing); +static int check_socket_and_wait_for_timeout(PySocketSockObject *s, + int writing); #define PySSLObject_Check(v) ((v)->ob_type == &PySSL_Type) +typedef enum { + SOCKET_IS_NONBLOCKING, + SOCKET_IS_BLOCKING, + SOCKET_HAS_TIMED_OUT, + SOCKET_HAS_BEEN_CLOSED, + SOCKET_OPERATION_OK +} timeout_state; + /* XXX It might be helpful to augment the error message generated below with the name of the SSL function that generated the error. I expect it's obvious most of the time. @@ -170,7 +179,7 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file) char *errstr = NULL; int ret; int err; - int timedout; + int sockstate; self = PyObject_New(PySSLObject, &PySSL_Type); /* Create new object */ if (self == NULL){ @@ -239,7 +248,7 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file) /* Actually negotiate SSL connection */ /* XXX If SSL_connect() returns 0, it's also a failure. */ - timedout = 0; + sockstate = 0; do { Py_BEGIN_ALLOW_THREADS ret = SSL_connect(self->ssl); @@ -249,13 +258,20 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file) goto fail; } if (err == SSL_ERROR_WANT_READ) { - timedout = wait_for_timeout(Sock, 0); + sockstate = check_socket_and_wait_for_timeout(Sock, 0); } else if (err == SSL_ERROR_WANT_WRITE) { - timedout = wait_for_timeout(Sock, 1); + sockstate = check_socket_and_wait_for_timeout(Sock, 1); + } else { + sockstate = SOCKET_OPERATION_OK; } - if (timedout) { - errstr = "The connect operation timed out"; + if (sockstate == SOCKET_HAS_TIMED_OUT) { + PyErr_SetString(PySSLErrorObject, "The connect operation timed out"); goto fail; + } else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { + PyErr_SetString(PySSLErrorObject, "Underlying socket has been closed."); + goto fail; + } else if (sockstate == SOCKET_IS_NONBLOCKING) { + break; } } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); if (ret <= 0) { @@ -334,22 +350,25 @@ static void PySSL_dealloc(PySSLObject *self) /* If the socket has a timeout, do a select() on the socket. The argument writing indicates the direction. - Return non-zero if the socket timed out, zero otherwise. + Returns one of the possibilities in the timeout_state enum (above). */ + static int -wait_for_timeout(PySocketSockObject *s, int writing) +check_socket_and_wait_for_timeout(PySocketSockObject *s, int writing) { fd_set fds; struct timeval tv; int rc; /* Nothing to do unless we're in timeout mode (not non-blocking) */ - if (s->sock_timeout <= 0.0) - return 0; + if (s->sock_timeout < 0.0) + return SOCKET_IS_BLOCKING; + else if (s->sock_timeout == 0.0) + return SOCKET_IS_NONBLOCKING; /* Guard against closed socket */ if (s->sock_fd < 0) - return 0; + return SOCKET_HAS_BEEN_CLOSED; /* Construct the arguments to select */ tv.tv_sec = (int)s->sock_timeout; @@ -365,8 +384,9 @@ wait_for_timeout(PySocketSockObject *s, int writing) rc = select(s->sock_fd+1, &fds, NULL, NULL, &tv); Py_END_ALLOW_THREADS - /* Return 1 on timeout, 0 otherwise */ - return rc == 0; + /* Return SOCKET_TIMED_OUT on timeout, SOCKET_OPERATION_OK otherwise + (when we are able to write or when there's something to read) */ + return rc == 0 ? SOCKET_HAS_TIMED_OUT : SOCKET_OPERATION_OK; } static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args) @@ -374,16 +394,19 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args) char *data; int len; int count; - int timedout; + int sockstate; int err; if (!PyArg_ParseTuple(args, "s#:write", &data, &count)) return NULL; - timedout = wait_for_timeout(self->Socket, 1); - if (timedout) { + sockstate = check_socket_and_wait_for_timeout(self->Socket, 1); + if (sockstate == SOCKET_HAS_TIMED_OUT) { PyErr_SetString(PySSLErrorObject, "The write operation timed out"); return NULL; + } else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { + PyErr_SetString(PySSLErrorObject, "Underlying socket has been closed."); + return NULL; } do { err = 0; @@ -395,13 +418,20 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args) return NULL; } if (err == SSL_ERROR_WANT_READ) { - timedout = wait_for_timeout(self->Socket, 0); + sockstate = check_socket_and_wait_for_timeout(self->Socket, 0); } else if (err == SSL_ERROR_WANT_WRITE) { - timedout = wait_for_timeout(self->Socket, 1); + sockstate = check_socket_and_wait_for_timeout(self->Socket, 1); + } else { + sockstate = SOCKET_OPERATION_OK; } - if (timedout) { + if (sockstate == SOCKET_HAS_TIMED_OUT) { PyErr_SetString(PySSLErrorObject, "The write operation timed out"); return NULL; + } else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { + PyErr_SetString(PySSLErrorObject, "Underlying socket has been closed."); + return NULL; + } else if (sockstate == SOCKET_IS_NONBLOCKING) { + break; } } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); if (len > 0) @@ -421,7 +451,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) PyObject *buf; int count = 0; int len = 1024; - int timedout; + int sockstate; int err; if (!PyArg_ParseTuple(args, "|i:read", &len)) @@ -430,8 +460,8 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) if (!(buf = PyString_FromStringAndSize((char *) 0, len))) return NULL; - timedout = wait_for_timeout(self->Socket, 0); - if (timedout) { + sockstate = check_socket_and_wait_for_timeout(self->Socket, 0); + if (sockstate == SOCKET_HAS_TIMED_OUT) { PyErr_SetString(PySSLErrorObject, "The read operation timed out"); Py_DECREF(buf); return NULL; @@ -447,14 +477,18 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args) return NULL; } if (err == SSL_ERROR_WANT_READ) { - timedout = wait_for_timeout(self->Socket, 0); + sockstate = check_socket_and_wait_for_timeout(self->Socket, 0); } else if (err == SSL_ERROR_WANT_WRITE) { - timedout = wait_for_timeout(self->Socket, 1); + sockstate = check_socket_and_wait_for_timeout(self->Socket, 1); + } else { + sockstate = SOCKET_OPERATION_OK; } - if (timedout) { + if (sockstate == SOCKET_HAS_TIMED_OUT) { PyErr_SetString(PySSLErrorObject, "The read operation timed out"); Py_DECREF(buf); return NULL; + } else if (sockstate == SOCKET_IS_NONBLOCKING) { + break; } } while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); if (count <= 0) { |