summaryrefslogtreecommitdiffstats
path: root/Modules/_ssl.c
diff options
context:
space:
mode:
Diffstat (limited to 'Modules/_ssl.c')
-rw-r--r--Modules/_ssl.c383
1 files changed, 331 insertions, 52 deletions
diff --git a/Modules/_ssl.c b/Modules/_ssl.c
index abb48b9..88525c8 100644
--- a/Modules/_ssl.c
+++ b/Modules/_ssl.c
@@ -40,6 +40,61 @@
#endif
+/* Include symbols from _socket module */
+#include "socketmodule.h"
+
+static PySocketModule_APIObject PySocketModule;
+
+#if defined(HAVE_POLL_H)
+#include <poll.h>
+#elif defined(HAVE_SYS_POLL_H)
+#include <sys/poll.h>
+#endif
+
+/* Include OpenSSL header files */
+#include "openssl/rsa.h"
+#include "openssl/crypto.h"
+#include "openssl/x509.h"
+#include "openssl/x509v3.h"
+#include "openssl/pem.h"
+#include "openssl/ssl.h"
+#include "openssl/err.h"
+#include "openssl/rand.h"
+
+/* SSL error object */
+static PyObject *PySSLErrorObject;
+static PyObject *PySSLZeroReturnErrorObject;
+static PyObject *PySSLWantReadErrorObject;
+static PyObject *PySSLWantWriteErrorObject;
+static PyObject *PySSLSyscallErrorObject;
+static PyObject *PySSLEOFErrorObject;
+
+/* Error mappings */
+static PyObject *err_codes_to_names;
+static PyObject *err_names_to_codes;
+static PyObject *lib_codes_to_names;
+
+struct py_ssl_error_code {
+ const char *mnemonic;
+ int library, reason;
+};
+struct py_ssl_library_code {
+ const char *library;
+ int code;
+};
+
+/* Include generated data (error codes) */
+#include "_ssl_data.h"
+
+/* Openssl comes with TLSv1.1 and TLSv1.2 between 1.0.0h and 1.0.1
+ http://www.openssl.org/news/changelog.html
+ */
+#if OPENSSL_VERSION_NUMBER >= 0x10001000L
+# define HAVE_TLSv1_2 1
+#else
+# define HAVE_TLSv1_2 0
+#endif
+
enum py_ssl_error {
/* these mirror ssl.h */
PY_SSL_ERROR_NONE,
@@ -73,55 +128,14 @@ enum py_ssl_version {
#endif
PY_SSL_VERSION_SSL3=1,
PY_SSL_VERSION_SSL23,
+#if HAVE_TLSv1_2
+ PY_SSL_VERSION_TLS1,
+ PY_SSL_VERSION_TLS1_1,
+ PY_SSL_VERSION_TLS1_2
+#else
PY_SSL_VERSION_TLS1
-};
-
-struct py_ssl_error_code {
- const char *mnemonic;
- int library, reason;
-};
-
-struct py_ssl_library_code {
- const char *library;
- int code;
-};
-
-/* Include symbols from _socket module */
-#include "socketmodule.h"
-
-static PySocketModule_APIObject PySocketModule;
-
-#if defined(HAVE_POLL_H)
-#include <poll.h>
-#elif defined(HAVE_SYS_POLL_H)
-#include <sys/poll.h>
#endif
-
-/* Include OpenSSL header files */
-#include "openssl/rsa.h"
-#include "openssl/crypto.h"
-#include "openssl/x509.h"
-#include "openssl/x509v3.h"
-#include "openssl/pem.h"
-#include "openssl/ssl.h"
-#include "openssl/err.h"
-#include "openssl/rand.h"
-
-/* Include generated data (error codes) */
-#include "_ssl_data.h"
-
-/* SSL error object */
-static PyObject *PySSLErrorObject;
-static PyObject *PySSLZeroReturnErrorObject;
-static PyObject *PySSLWantReadErrorObject;
-static PyObject *PySSLWantWriteErrorObject;
-static PyObject *PySSLSyscallErrorObject;
-static PyObject *PySSLEOFErrorObject;
-
-/* Error mappings */
-static PyObject *err_codes_to_names;
-static PyObject *err_names_to_codes;
-static PyObject *lib_codes_to_names;
+};
#ifdef WITH_THREAD
@@ -181,12 +195,16 @@ typedef struct {
char *npn_protocols;
int npn_protocols_len;
#endif
+#ifndef OPENSSL_NO_TLSEXT
+ PyObject *set_hostname;
+#endif
} PySSLContext;
typedef struct {
PyObject_HEAD
PyObject *Socket; /* weakref to socket on which we're layered */
SSL *ssl;
+ PySSLContext *ctx; /* weakref to SSL context */
X509 *peer_cert;
int shutdown_seen_zero;
enum py_ssl_server_or_client socket_type;
@@ -437,11 +455,12 @@ _setSSLError (char *errstr, int errcode, char *filename, int lineno) {
*/
static PySSLSocket *
-newPySSLSocket(SSL_CTX *ctx, PySocketSockObject *sock,
+newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
enum py_ssl_server_or_client socket_type,
char *server_hostname)
{
PySSLSocket *self;
+ SSL_CTX *ctx = sslctx->ctx;
self = PyObject_New(PySSLSocket, &PySSLSocket_Type);
if (self == NULL)
@@ -450,6 +469,8 @@ newPySSLSocket(SSL_CTX *ctx, PySocketSockObject *sock,
self->peer_cert = NULL;
self->ssl = NULL;
self->Socket = NULL;
+ self->ctx = sslctx;
+ Py_INCREF(sslctx);
/* Make sure the SSL error state is initialized */
(void) ERR_get_state();
@@ -458,6 +479,7 @@ newPySSLSocket(SSL_CTX *ctx, PySocketSockObject *sock,
PySSL_BEGIN_ALLOW_THREADS
self->ssl = SSL_new(ctx);
PySSL_END_ALLOW_THREADS
+ SSL_set_app_data(self->ssl,self);
SSL_set_fd(self->ssl, sock->sock_fd);
#ifdef SSL_MODE_AUTO_RETRY
SSL_set_mode(self->ssl, SSL_MODE_AUTO_RETRY);
@@ -1164,6 +1186,38 @@ static PyObject *PySSL_compression(PySSLSocket *self) {
#endif
}
+static PySSLContext *PySSL_get_context(PySSLSocket *self, void *closure) {
+ Py_INCREF(self->ctx);
+ return self->ctx;
+}
+
+static int PySSL_set_context(PySSLSocket *self, PyObject *value,
+ void *closure) {
+
+ if (PyObject_TypeCheck(value, &PySSLContext_Type)) {
+
+ Py_INCREF(value);
+ Py_DECREF(self->ctx);
+ self->ctx = (PySSLContext *) value;
+ SSL_set_SSL_CTX(self->ssl, self->ctx->ctx);
+ } else {
+ PyErr_SetString(PyExc_TypeError, "The value must be a SSLContext");
+ return -1;
+ }
+
+ return 0;
+}
+
+PyDoc_STRVAR(PySSL_set_context_doc,
+"_setter_context(ctx)\n\
+\
+This changes the context associated with the SSLSocket. This is typically\n\
+used from within a callback function set by the set_servername_callback\n\
+on the SSLContext to change the certificate information associated with the\n\
+SSLSocket before the cryptographic exchange handshake messages\n");
+
+
+
static void PySSL_dealloc(PySSLSocket *self)
{
if (self->peer_cert) /* Possible not to have one? */
@@ -1171,6 +1225,7 @@ static void PySSL_dealloc(PySSLSocket *self)
if (self->ssl)
SSL_free(self->ssl);
Py_XDECREF(self->Socket);
+ Py_XDECREF(self->ctx);
PyObject_Del(self);
}
@@ -1606,6 +1661,12 @@ If the TLS handshake is not yet complete, None is returned");
#endif /* HAVE_OPENSSL_FINISHED */
+static PyGetSetDef ssl_getsetlist[] = {
+ {"context", (getter) PySSL_get_context,
+ (setter) PySSL_set_context, PySSL_set_context_doc},
+ {NULL}, /* sentinel */
+};
+
static PyMethodDef PySSLMethods[] = {
{"do_handshake", (PyCFunction)PySSL_SSLdo_handshake, METH_NOARGS},
{"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS,
@@ -1660,6 +1721,8 @@ static PyTypeObject PySSLSocket_Type = {
0, /*tp_iter*/
0, /*tp_iternext*/
PySSLMethods, /*tp_methods*/
+ 0, /*tp_members*/
+ ssl_getsetlist, /*tp_getset*/
};
@@ -1683,6 +1746,12 @@ context_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
PySSL_BEGIN_ALLOW_THREADS
if (proto_version == PY_SSL_VERSION_TLS1)
ctx = SSL_CTX_new(TLSv1_method());
+#if HAVE_TLSv1_2
+ else if (proto_version == PY_SSL_VERSION_TLS1_1)
+ ctx = SSL_CTX_new(TLSv1_1_method());
+ else if (proto_version == PY_SSL_VERSION_TLS1_2)
+ ctx = SSL_CTX_new(TLSv1_2_method());
+#endif
else if (proto_version == PY_SSL_VERSION_SSL3)
ctx = SSL_CTX_new(SSLv3_method());
#ifndef OPENSSL_NO_SSL2
@@ -1716,6 +1785,9 @@ context_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
#ifdef OPENSSL_NPN_NEGOTIATED
self->npn_protocols = NULL;
#endif
+#ifndef OPENSSL_NO_TLSEXT
+ self->set_hostname = NULL;
+#endif
/* Defaults */
SSL_CTX_set_verify(self->ctx, SSL_VERIFY_NONE, NULL);
SSL_CTX_set_options(self->ctx,
@@ -1729,9 +1801,28 @@ context_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
return (PyObject *)self;
}
+static int
+context_traverse(PySSLContext *self, visitproc visit, void *arg)
+{
+#ifndef OPENSSL_NO_TLSEXT
+ Py_VISIT(self->set_hostname);
+#endif
+ return 0;
+}
+
+static int
+context_clear(PySSLContext *self)
+{
+#ifndef OPENSSL_NO_TLSEXT
+ Py_CLEAR(self->set_hostname);
+#endif
+ return 0;
+}
+
static void
context_dealloc(PySSLContext *self)
{
+ context_clear(self);
SSL_CTX_free(self->ctx);
#ifdef OPENSSL_NPN_NEGOTIATED
PyMem_Free(self->npn_protocols);
@@ -2224,7 +2315,7 @@ context_wrap_socket(PySSLContext *self, PyObject *args, PyObject *kwds)
#endif
}
- res = (PyObject *) newPySSLSocket(self->ctx, sock, server_side,
+ res = (PyObject *) newPySSLSocket(self, sock, server_side,
hostname);
if (hostname != NULL)
PyMem_Free(hostname);
@@ -2309,6 +2400,137 @@ set_ecdh_curve(PySSLContext *self, PyObject *name)
}
#endif
+#ifndef OPENSSL_NO_TLSEXT
+static int
+_servername_callback(SSL *s, int *al, void *args)
+{
+ int ret;
+ PySSLContext *ssl_ctx = (PySSLContext *) args;
+ PySSLSocket *ssl;
+ PyObject *servername_o;
+ PyObject *servername_idna;
+ PyObject *result;
+ /* The high-level ssl.SSLSocket object */
+ PyObject *ssl_socket;
+ const char *servername = SSL_get_servername(s, TLSEXT_NAMETYPE_host_name);
+#ifdef WITH_THREAD
+ PyGILState_STATE gstate = PyGILState_Ensure();
+#endif
+
+ if (ssl_ctx->set_hostname == NULL) {
+ /* remove race condition in this the call back while if removing the
+ * callback is in progress */
+#ifdef WITH_THREAD
+ PyGILState_Release(gstate);
+#endif
+ return SSL_TLSEXT_ERR_OK;
+ }
+
+ ssl = SSL_get_app_data(s);
+ assert(PySSLSocket_Check(ssl));
+ ssl_socket = PyWeakref_GetObject(ssl->Socket);
+ Py_INCREF(ssl_socket);
+ if (ssl_socket == Py_None) {
+ goto error;
+ }
+
+ servername_o = PyBytes_FromString(servername);
+ if (servername_o == NULL) {
+ PyErr_WriteUnraisable((PyObject *) ssl_ctx);
+ goto error;
+ }
+ servername_idna = PyUnicode_FromEncodedObject(servername_o, "idna", NULL);
+ if (servername_idna == NULL) {
+ PyErr_WriteUnraisable(servername_o);
+ Py_DECREF(servername_o);
+ goto error;
+ }
+ Py_DECREF(servername_o);
+ result = PyObject_CallFunctionObjArgs(ssl_ctx->set_hostname, ssl_socket,
+ servername_idna, ssl_ctx, NULL);
+ Py_DECREF(ssl_socket);
+ Py_DECREF(servername_idna);
+
+ if (result == NULL) {
+ PyErr_WriteUnraisable(ssl_ctx->set_hostname);
+ *al = SSL_AD_HANDSHAKE_FAILURE;
+ ret = SSL_TLSEXT_ERR_ALERT_FATAL;
+ }
+ else {
+ if (result != Py_None) {
+ *al = (int) PyLong_AsLong(result);
+ if (PyErr_Occurred()) {
+ PyErr_WriteUnraisable(result);
+ *al = SSL_AD_INTERNAL_ERROR;
+ }
+ ret = SSL_TLSEXT_ERR_ALERT_FATAL;
+ }
+ else {
+ ret = SSL_TLSEXT_ERR_OK;
+ }
+ Py_DECREF(result);
+ }
+
+#ifdef WITH_THREAD
+ PyGILState_Release(gstate);
+#endif
+ return ret;
+
+error:
+ Py_DECREF(ssl_socket);
+ *al = SSL_AD_INTERNAL_ERROR;
+ ret = SSL_TLSEXT_ERR_ALERT_FATAL;
+#ifdef WITH_THREAD
+ PyGILState_Release(gstate);
+#endif
+ return ret;
+}
+
+PyDoc_STRVAR(PySSL_set_servername_callback_doc,
+"set_servername_callback(method)\n\
+\
+This sets a callback that will be called when a server name is provided by\n\
+the SSL/TLS client in the SNI extension.\n\
+\
+If the argument is None then the callback is disabled. The method is called\n\
+with the SSLSocket, the server name as a string, and the SSLContext object.\n\
+See RFC 6066 for details of the SNI");
+#endif
+
+static PyObject *
+set_servername_callback(PySSLContext *self, PyObject *args)
+{
+#ifndef OPENSSL_NO_TLSEXT
+ PyObject *cb;
+
+ if (!PyArg_ParseTuple(args, "O", &cb))
+ return NULL;
+
+ Py_CLEAR(self->set_hostname);
+ if (cb == Py_None) {
+ SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL);
+ }
+ else {
+ if (!PyCallable_Check(cb)) {
+ SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL);
+ PyErr_SetString(PyExc_TypeError,
+ "not a callable object");
+ return NULL;
+ }
+ Py_INCREF(cb);
+ self->set_hostname = cb;
+ SSL_CTX_set_tlsext_servername_callback(self->ctx, _servername_callback);
+ SSL_CTX_set_tlsext_servername_arg(self->ctx, self);
+ }
+ Py_RETURN_NONE;
+#else
+ PyErr_SetString(PyExc_NotImplementedError,
+ "The TLS extension servername callback, "
+ "SSL_CTX_set_tlsext_servername_callback, "
+ "is not in the current OpenSSL library.");
+#endif
+}
+
static PyGetSetDef context_getsetlist[] = {
{"options", (getter) get_options,
(setter) set_options, NULL},
@@ -2338,6 +2560,8 @@ static struct PyMethodDef context_methods[] = {
{"set_ecdh_curve", (PyCFunction) set_ecdh_curve,
METH_O, NULL},
#endif
+ {"set_servername_callback", (PyCFunction) set_servername_callback,
+ METH_VARARGS, PySSL_set_servername_callback_doc},
{NULL, NULL} /* sentinel */
};
@@ -2361,10 +2585,10 @@ static PyTypeObject PySSLContext_Type = {
0, /*tp_getattro*/
0, /*tp_setattro*/
0, /*tp_as_buffer*/
- Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /*tp_flags*/
0, /*tp_doc*/
- 0, /*tp_traverse*/
- 0, /*tp_clear*/
+ (traverseproc) context_traverse, /*tp_traverse*/
+ (inquiry) context_clear, /*tp_clear*/
0, /*tp_richcompare*/
0, /*tp_weaklistoffset*/
0, /*tp_iter*/
@@ -2744,6 +2968,51 @@ PyInit__ssl(void)
PyModule_AddIntConstant(m, "CERT_REQUIRED",
PY_SSL_CERT_REQUIRED);
+ /* Alert Descriptions from ssl.h */
+ /* note RESERVED constants no longer intended for use have been removed */
+ /* http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6 */
+
+#define ADD_AD_CONSTANT(s) \
+ PyModule_AddIntConstant(m, "ALERT_DESCRIPTION_"#s, \
+ SSL_AD_##s)
+
+ ADD_AD_CONSTANT(CLOSE_NOTIFY);
+ ADD_AD_CONSTANT(UNEXPECTED_MESSAGE);
+ ADD_AD_CONSTANT(BAD_RECORD_MAC);
+ ADD_AD_CONSTANT(RECORD_OVERFLOW);
+ ADD_AD_CONSTANT(DECOMPRESSION_FAILURE);
+ ADD_AD_CONSTANT(HANDSHAKE_FAILURE);
+ ADD_AD_CONSTANT(BAD_CERTIFICATE);
+ ADD_AD_CONSTANT(UNSUPPORTED_CERTIFICATE);
+ ADD_AD_CONSTANT(CERTIFICATE_REVOKED);
+ ADD_AD_CONSTANT(CERTIFICATE_EXPIRED);
+ ADD_AD_CONSTANT(CERTIFICATE_UNKNOWN);
+ ADD_AD_CONSTANT(ILLEGAL_PARAMETER);
+ ADD_AD_CONSTANT(UNKNOWN_CA);
+ ADD_AD_CONSTANT(ACCESS_DENIED);
+ ADD_AD_CONSTANT(DECODE_ERROR);
+ ADD_AD_CONSTANT(DECRYPT_ERROR);
+ ADD_AD_CONSTANT(PROTOCOL_VERSION);
+ ADD_AD_CONSTANT(INSUFFICIENT_SECURITY);
+ ADD_AD_CONSTANT(INTERNAL_ERROR);
+ ADD_AD_CONSTANT(USER_CANCELLED);
+ ADD_AD_CONSTANT(NO_RENEGOTIATION);
+ ADD_AD_CONSTANT(UNSUPPORTED_EXTENSION);
+ ADD_AD_CONSTANT(CERTIFICATE_UNOBTAINABLE);
+ ADD_AD_CONSTANT(UNRECOGNIZED_NAME);
+ /* Not all constants are in old OpenSSL versions */
+#ifdef SSL_AD_BAD_CERTIFICATE_STATUS_RESPONSE
+ ADD_AD_CONSTANT(BAD_CERTIFICATE_STATUS_RESPONSE);
+#endif
+#ifdef SSL_AD_BAD_CERTIFICATE_HASH_VALUE
+ ADD_AD_CONSTANT(BAD_CERTIFICATE_HASH_VALUE);
+#endif
+#ifdef SSL_AD_UNKNOWN_PSK_IDENTITY
+ ADD_AD_CONSTANT(UNKNOWN_PSK_IDENTITY);
+#endif
+
+#undef ADD_AD_CONSTANT
+
/* protocol versions */
#ifndef OPENSSL_NO_SSL2
PyModule_AddIntConstant(m, "PROTOCOL_SSLv2",
@@ -2755,6 +3024,12 @@ PyInit__ssl(void)
PY_SSL_VERSION_SSL23);
PyModule_AddIntConstant(m, "PROTOCOL_TLSv1",
PY_SSL_VERSION_TLS1);
+#if HAVE_TLSv1_2
+ PyModule_AddIntConstant(m, "PROTOCOL_TLSv1_1",
+ PY_SSL_VERSION_TLS1_1);
+ PyModule_AddIntConstant(m, "PROTOCOL_TLSv1_2",
+ PY_SSL_VERSION_TLS1_2);
+#endif
/* protocol options */
PyModule_AddIntConstant(m, "OP_ALL",
@@ -2762,6 +3037,10 @@ PyInit__ssl(void)
PyModule_AddIntConstant(m, "OP_NO_SSLv2", SSL_OP_NO_SSLv2);
PyModule_AddIntConstant(m, "OP_NO_SSLv3", SSL_OP_NO_SSLv3);
PyModule_AddIntConstant(m, "OP_NO_TLSv1", SSL_OP_NO_TLSv1);
+#if HAVE_TLSv1_2
+ PyModule_AddIntConstant(m, "OP_NO_TLSv1_1", SSL_OP_NO_TLSv1_1);
+ PyModule_AddIntConstant(m, "OP_NO_TLSv1_2", SSL_OP_NO_TLSv1_2);
+#endif
PyModule_AddIntConstant(m, "OP_CIPHER_SERVER_PREFERENCE",
SSL_OP_CIPHER_SERVER_PREFERENCE);
PyModule_AddIntConstant(m, "OP_SINGLE_DH_USE", SSL_OP_SINGLE_DH_USE);