diff options
Diffstat (limited to 'Modules/_ssl.c')
-rw-r--r-- | Modules/_ssl.c | 383 |
1 files changed, 331 insertions, 52 deletions
diff --git a/Modules/_ssl.c b/Modules/_ssl.c index abb48b9..88525c8 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -40,6 +40,61 @@ #endif +/* Include symbols from _socket module */ +#include "socketmodule.h" + +static PySocketModule_APIObject PySocketModule; + +#if defined(HAVE_POLL_H) +#include <poll.h> +#elif defined(HAVE_SYS_POLL_H) +#include <sys/poll.h> +#endif + +/* Include OpenSSL header files */ +#include "openssl/rsa.h" +#include "openssl/crypto.h" +#include "openssl/x509.h" +#include "openssl/x509v3.h" +#include "openssl/pem.h" +#include "openssl/ssl.h" +#include "openssl/err.h" +#include "openssl/rand.h" + +/* SSL error object */ +static PyObject *PySSLErrorObject; +static PyObject *PySSLZeroReturnErrorObject; +static PyObject *PySSLWantReadErrorObject; +static PyObject *PySSLWantWriteErrorObject; +static PyObject *PySSLSyscallErrorObject; +static PyObject *PySSLEOFErrorObject; + +/* Error mappings */ +static PyObject *err_codes_to_names; +static PyObject *err_names_to_codes; +static PyObject *lib_codes_to_names; + +struct py_ssl_error_code { + const char *mnemonic; + int library, reason; +}; +struct py_ssl_library_code { + const char *library; + int code; +}; + +/* Include generated data (error codes) */ +#include "_ssl_data.h" + +/* Openssl comes with TLSv1.1 and TLSv1.2 between 1.0.0h and 1.0.1 + http://www.openssl.org/news/changelog.html + */ +#if OPENSSL_VERSION_NUMBER >= 0x10001000L +# define HAVE_TLSv1_2 1 +#else +# define HAVE_TLSv1_2 0 +#endif + enum py_ssl_error { /* these mirror ssl.h */ PY_SSL_ERROR_NONE, @@ -73,55 +128,14 @@ enum py_ssl_version { #endif PY_SSL_VERSION_SSL3=1, PY_SSL_VERSION_SSL23, +#if HAVE_TLSv1_2 + PY_SSL_VERSION_TLS1, + PY_SSL_VERSION_TLS1_1, + PY_SSL_VERSION_TLS1_2 +#else PY_SSL_VERSION_TLS1 -}; - -struct py_ssl_error_code { - const char *mnemonic; - int library, reason; -}; - -struct py_ssl_library_code { - const char *library; - int code; -}; - -/* Include symbols from _socket module */ -#include "socketmodule.h" - -static PySocketModule_APIObject PySocketModule; - -#if defined(HAVE_POLL_H) -#include <poll.h> -#elif defined(HAVE_SYS_POLL_H) -#include <sys/poll.h> #endif - -/* Include OpenSSL header files */ -#include "openssl/rsa.h" -#include "openssl/crypto.h" -#include "openssl/x509.h" -#include "openssl/x509v3.h" -#include "openssl/pem.h" -#include "openssl/ssl.h" -#include "openssl/err.h" -#include "openssl/rand.h" - -/* Include generated data (error codes) */ -#include "_ssl_data.h" - -/* SSL error object */ -static PyObject *PySSLErrorObject; -static PyObject *PySSLZeroReturnErrorObject; -static PyObject *PySSLWantReadErrorObject; -static PyObject *PySSLWantWriteErrorObject; -static PyObject *PySSLSyscallErrorObject; -static PyObject *PySSLEOFErrorObject; - -/* Error mappings */ -static PyObject *err_codes_to_names; -static PyObject *err_names_to_codes; -static PyObject *lib_codes_to_names; +}; #ifdef WITH_THREAD @@ -181,12 +195,16 @@ typedef struct { char *npn_protocols; int npn_protocols_len; #endif +#ifndef OPENSSL_NO_TLSEXT + PyObject *set_hostname; +#endif } PySSLContext; typedef struct { PyObject_HEAD PyObject *Socket; /* weakref to socket on which we're layered */ SSL *ssl; + PySSLContext *ctx; /* weakref to SSL context */ X509 *peer_cert; int shutdown_seen_zero; enum py_ssl_server_or_client socket_type; @@ -437,11 +455,12 @@ _setSSLError (char *errstr, int errcode, char *filename, int lineno) { */ static PySSLSocket * -newPySSLSocket(SSL_CTX *ctx, PySocketSockObject *sock, +newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock, enum py_ssl_server_or_client socket_type, char *server_hostname) { PySSLSocket *self; + SSL_CTX *ctx = sslctx->ctx; self = PyObject_New(PySSLSocket, &PySSLSocket_Type); if (self == NULL) @@ -450,6 +469,8 @@ newPySSLSocket(SSL_CTX *ctx, PySocketSockObject *sock, self->peer_cert = NULL; self->ssl = NULL; self->Socket = NULL; + self->ctx = sslctx; + Py_INCREF(sslctx); /* Make sure the SSL error state is initialized */ (void) ERR_get_state(); @@ -458,6 +479,7 @@ newPySSLSocket(SSL_CTX *ctx, PySocketSockObject *sock, PySSL_BEGIN_ALLOW_THREADS self->ssl = SSL_new(ctx); PySSL_END_ALLOW_THREADS + SSL_set_app_data(self->ssl,self); SSL_set_fd(self->ssl, sock->sock_fd); #ifdef SSL_MODE_AUTO_RETRY SSL_set_mode(self->ssl, SSL_MODE_AUTO_RETRY); @@ -1164,6 +1186,38 @@ static PyObject *PySSL_compression(PySSLSocket *self) { #endif } +static PySSLContext *PySSL_get_context(PySSLSocket *self, void *closure) { + Py_INCREF(self->ctx); + return self->ctx; +} + +static int PySSL_set_context(PySSLSocket *self, PyObject *value, + void *closure) { + + if (PyObject_TypeCheck(value, &PySSLContext_Type)) { + + Py_INCREF(value); + Py_DECREF(self->ctx); + self->ctx = (PySSLContext *) value; + SSL_set_SSL_CTX(self->ssl, self->ctx->ctx); + } else { + PyErr_SetString(PyExc_TypeError, "The value must be a SSLContext"); + return -1; + } + + return 0; +} + +PyDoc_STRVAR(PySSL_set_context_doc, +"_setter_context(ctx)\n\ +\ +This changes the context associated with the SSLSocket. This is typically\n\ +used from within a callback function set by the set_servername_callback\n\ +on the SSLContext to change the certificate information associated with the\n\ +SSLSocket before the cryptographic exchange handshake messages\n"); + + + static void PySSL_dealloc(PySSLSocket *self) { if (self->peer_cert) /* Possible not to have one? */ @@ -1171,6 +1225,7 @@ static void PySSL_dealloc(PySSLSocket *self) if (self->ssl) SSL_free(self->ssl); Py_XDECREF(self->Socket); + Py_XDECREF(self->ctx); PyObject_Del(self); } @@ -1606,6 +1661,12 @@ If the TLS handshake is not yet complete, None is returned"); #endif /* HAVE_OPENSSL_FINISHED */ +static PyGetSetDef ssl_getsetlist[] = { + {"context", (getter) PySSL_get_context, + (setter) PySSL_set_context, PySSL_set_context_doc}, + {NULL}, /* sentinel */ +}; + static PyMethodDef PySSLMethods[] = { {"do_handshake", (PyCFunction)PySSL_SSLdo_handshake, METH_NOARGS}, {"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS, @@ -1660,6 +1721,8 @@ static PyTypeObject PySSLSocket_Type = { 0, /*tp_iter*/ 0, /*tp_iternext*/ PySSLMethods, /*tp_methods*/ + 0, /*tp_members*/ + ssl_getsetlist, /*tp_getset*/ }; @@ -1683,6 +1746,12 @@ context_new(PyTypeObject *type, PyObject *args, PyObject *kwds) PySSL_BEGIN_ALLOW_THREADS if (proto_version == PY_SSL_VERSION_TLS1) ctx = SSL_CTX_new(TLSv1_method()); +#if HAVE_TLSv1_2 + else if (proto_version == PY_SSL_VERSION_TLS1_1) + ctx = SSL_CTX_new(TLSv1_1_method()); + else if (proto_version == PY_SSL_VERSION_TLS1_2) + ctx = SSL_CTX_new(TLSv1_2_method()); +#endif else if (proto_version == PY_SSL_VERSION_SSL3) ctx = SSL_CTX_new(SSLv3_method()); #ifndef OPENSSL_NO_SSL2 @@ -1716,6 +1785,9 @@ context_new(PyTypeObject *type, PyObject *args, PyObject *kwds) #ifdef OPENSSL_NPN_NEGOTIATED self->npn_protocols = NULL; #endif +#ifndef OPENSSL_NO_TLSEXT + self->set_hostname = NULL; +#endif /* Defaults */ SSL_CTX_set_verify(self->ctx, SSL_VERIFY_NONE, NULL); SSL_CTX_set_options(self->ctx, @@ -1729,9 +1801,28 @@ context_new(PyTypeObject *type, PyObject *args, PyObject *kwds) return (PyObject *)self; } +static int +context_traverse(PySSLContext *self, visitproc visit, void *arg) +{ +#ifndef OPENSSL_NO_TLSEXT + Py_VISIT(self->set_hostname); +#endif + return 0; +} + +static int +context_clear(PySSLContext *self) +{ +#ifndef OPENSSL_NO_TLSEXT + Py_CLEAR(self->set_hostname); +#endif + return 0; +} + static void context_dealloc(PySSLContext *self) { + context_clear(self); SSL_CTX_free(self->ctx); #ifdef OPENSSL_NPN_NEGOTIATED PyMem_Free(self->npn_protocols); @@ -2224,7 +2315,7 @@ context_wrap_socket(PySSLContext *self, PyObject *args, PyObject *kwds) #endif } - res = (PyObject *) newPySSLSocket(self->ctx, sock, server_side, + res = (PyObject *) newPySSLSocket(self, sock, server_side, hostname); if (hostname != NULL) PyMem_Free(hostname); @@ -2309,6 +2400,137 @@ set_ecdh_curve(PySSLContext *self, PyObject *name) } #endif +#ifndef OPENSSL_NO_TLSEXT +static int +_servername_callback(SSL *s, int *al, void *args) +{ + int ret; + PySSLContext *ssl_ctx = (PySSLContext *) args; + PySSLSocket *ssl; + PyObject *servername_o; + PyObject *servername_idna; + PyObject *result; + /* The high-level ssl.SSLSocket object */ + PyObject *ssl_socket; + const char *servername = SSL_get_servername(s, TLSEXT_NAMETYPE_host_name); +#ifdef WITH_THREAD + PyGILState_STATE gstate = PyGILState_Ensure(); +#endif + + if (ssl_ctx->set_hostname == NULL) { + /* remove race condition in this the call back while if removing the + * callback is in progress */ +#ifdef WITH_THREAD + PyGILState_Release(gstate); +#endif + return SSL_TLSEXT_ERR_OK; + } + + ssl = SSL_get_app_data(s); + assert(PySSLSocket_Check(ssl)); + ssl_socket = PyWeakref_GetObject(ssl->Socket); + Py_INCREF(ssl_socket); + if (ssl_socket == Py_None) { + goto error; + } + + servername_o = PyBytes_FromString(servername); + if (servername_o == NULL) { + PyErr_WriteUnraisable((PyObject *) ssl_ctx); + goto error; + } + servername_idna = PyUnicode_FromEncodedObject(servername_o, "idna", NULL); + if (servername_idna == NULL) { + PyErr_WriteUnraisable(servername_o); + Py_DECREF(servername_o); + goto error; + } + Py_DECREF(servername_o); + result = PyObject_CallFunctionObjArgs(ssl_ctx->set_hostname, ssl_socket, + servername_idna, ssl_ctx, NULL); + Py_DECREF(ssl_socket); + Py_DECREF(servername_idna); + + if (result == NULL) { + PyErr_WriteUnraisable(ssl_ctx->set_hostname); + *al = SSL_AD_HANDSHAKE_FAILURE; + ret = SSL_TLSEXT_ERR_ALERT_FATAL; + } + else { + if (result != Py_None) { + *al = (int) PyLong_AsLong(result); + if (PyErr_Occurred()) { + PyErr_WriteUnraisable(result); + *al = SSL_AD_INTERNAL_ERROR; + } + ret = SSL_TLSEXT_ERR_ALERT_FATAL; + } + else { + ret = SSL_TLSEXT_ERR_OK; + } + Py_DECREF(result); + } + +#ifdef WITH_THREAD + PyGILState_Release(gstate); +#endif + return ret; + +error: + Py_DECREF(ssl_socket); + *al = SSL_AD_INTERNAL_ERROR; + ret = SSL_TLSEXT_ERR_ALERT_FATAL; +#ifdef WITH_THREAD + PyGILState_Release(gstate); +#endif + return ret; +} + +PyDoc_STRVAR(PySSL_set_servername_callback_doc, +"set_servername_callback(method)\n\ +\ +This sets a callback that will be called when a server name is provided by\n\ +the SSL/TLS client in the SNI extension.\n\ +\ +If the argument is None then the callback is disabled. The method is called\n\ +with the SSLSocket, the server name as a string, and the SSLContext object.\n\ +See RFC 6066 for details of the SNI"); +#endif + +static PyObject * +set_servername_callback(PySSLContext *self, PyObject *args) +{ +#ifndef OPENSSL_NO_TLSEXT + PyObject *cb; + + if (!PyArg_ParseTuple(args, "O", &cb)) + return NULL; + + Py_CLEAR(self->set_hostname); + if (cb == Py_None) { + SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL); + } + else { + if (!PyCallable_Check(cb)) { + SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL); + PyErr_SetString(PyExc_TypeError, + "not a callable object"); + return NULL; + } + Py_INCREF(cb); + self->set_hostname = cb; + SSL_CTX_set_tlsext_servername_callback(self->ctx, _servername_callback); + SSL_CTX_set_tlsext_servername_arg(self->ctx, self); + } + Py_RETURN_NONE; +#else + PyErr_SetString(PyExc_NotImplementedError, + "The TLS extension servername callback, " + "SSL_CTX_set_tlsext_servername_callback, " + "is not in the current OpenSSL library."); +#endif +} + static PyGetSetDef context_getsetlist[] = { {"options", (getter) get_options, (setter) set_options, NULL}, @@ -2338,6 +2560,8 @@ static struct PyMethodDef context_methods[] = { {"set_ecdh_curve", (PyCFunction) set_ecdh_curve, METH_O, NULL}, #endif + {"set_servername_callback", (PyCFunction) set_servername_callback, + METH_VARARGS, PySSL_set_servername_callback_doc}, {NULL, NULL} /* sentinel */ }; @@ -2361,10 +2585,10 @@ static PyTypeObject PySSLContext_Type = { 0, /*tp_getattro*/ 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /*tp_flags*/ 0, /*tp_doc*/ - 0, /*tp_traverse*/ - 0, /*tp_clear*/ + (traverseproc) context_traverse, /*tp_traverse*/ + (inquiry) context_clear, /*tp_clear*/ 0, /*tp_richcompare*/ 0, /*tp_weaklistoffset*/ 0, /*tp_iter*/ @@ -2744,6 +2968,51 @@ PyInit__ssl(void) PyModule_AddIntConstant(m, "CERT_REQUIRED", PY_SSL_CERT_REQUIRED); + /* Alert Descriptions from ssl.h */ + /* note RESERVED constants no longer intended for use have been removed */ + /* http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6 */ + +#define ADD_AD_CONSTANT(s) \ + PyModule_AddIntConstant(m, "ALERT_DESCRIPTION_"#s, \ + SSL_AD_##s) + + ADD_AD_CONSTANT(CLOSE_NOTIFY); + ADD_AD_CONSTANT(UNEXPECTED_MESSAGE); + ADD_AD_CONSTANT(BAD_RECORD_MAC); + ADD_AD_CONSTANT(RECORD_OVERFLOW); + ADD_AD_CONSTANT(DECOMPRESSION_FAILURE); + ADD_AD_CONSTANT(HANDSHAKE_FAILURE); + ADD_AD_CONSTANT(BAD_CERTIFICATE); + ADD_AD_CONSTANT(UNSUPPORTED_CERTIFICATE); + ADD_AD_CONSTANT(CERTIFICATE_REVOKED); + ADD_AD_CONSTANT(CERTIFICATE_EXPIRED); + ADD_AD_CONSTANT(CERTIFICATE_UNKNOWN); + ADD_AD_CONSTANT(ILLEGAL_PARAMETER); + ADD_AD_CONSTANT(UNKNOWN_CA); + ADD_AD_CONSTANT(ACCESS_DENIED); + ADD_AD_CONSTANT(DECODE_ERROR); + ADD_AD_CONSTANT(DECRYPT_ERROR); + ADD_AD_CONSTANT(PROTOCOL_VERSION); + ADD_AD_CONSTANT(INSUFFICIENT_SECURITY); + ADD_AD_CONSTANT(INTERNAL_ERROR); + ADD_AD_CONSTANT(USER_CANCELLED); + ADD_AD_CONSTANT(NO_RENEGOTIATION); + ADD_AD_CONSTANT(UNSUPPORTED_EXTENSION); + ADD_AD_CONSTANT(CERTIFICATE_UNOBTAINABLE); + ADD_AD_CONSTANT(UNRECOGNIZED_NAME); + /* Not all constants are in old OpenSSL versions */ +#ifdef SSL_AD_BAD_CERTIFICATE_STATUS_RESPONSE + ADD_AD_CONSTANT(BAD_CERTIFICATE_STATUS_RESPONSE); +#endif +#ifdef SSL_AD_BAD_CERTIFICATE_HASH_VALUE + ADD_AD_CONSTANT(BAD_CERTIFICATE_HASH_VALUE); +#endif +#ifdef SSL_AD_UNKNOWN_PSK_IDENTITY + ADD_AD_CONSTANT(UNKNOWN_PSK_IDENTITY); +#endif + +#undef ADD_AD_CONSTANT + /* protocol versions */ #ifndef OPENSSL_NO_SSL2 PyModule_AddIntConstant(m, "PROTOCOL_SSLv2", @@ -2755,6 +3024,12 @@ PyInit__ssl(void) PY_SSL_VERSION_SSL23); PyModule_AddIntConstant(m, "PROTOCOL_TLSv1", PY_SSL_VERSION_TLS1); +#if HAVE_TLSv1_2 + PyModule_AddIntConstant(m, "PROTOCOL_TLSv1_1", + PY_SSL_VERSION_TLS1_1); + PyModule_AddIntConstant(m, "PROTOCOL_TLSv1_2", + PY_SSL_VERSION_TLS1_2); +#endif /* protocol options */ PyModule_AddIntConstant(m, "OP_ALL", @@ -2762,6 +3037,10 @@ PyInit__ssl(void) PyModule_AddIntConstant(m, "OP_NO_SSLv2", SSL_OP_NO_SSLv2); PyModule_AddIntConstant(m, "OP_NO_SSLv3", SSL_OP_NO_SSLv3); PyModule_AddIntConstant(m, "OP_NO_TLSv1", SSL_OP_NO_TLSv1); +#if HAVE_TLSv1_2 + PyModule_AddIntConstant(m, "OP_NO_TLSv1_1", SSL_OP_NO_TLSv1_1); + PyModule_AddIntConstant(m, "OP_NO_TLSv1_2", SSL_OP_NO_TLSv1_2); +#endif PyModule_AddIntConstant(m, "OP_CIPHER_SERVER_PREFERENCE", SSL_OP_CIPHER_SERVER_PREFERENCE); PyModule_AddIntConstant(m, "OP_SINGLE_DH_USE", SSL_OP_SINGLE_DH_USE); |