diff options
author | Antoine Pitrou <solipsis@pitrou.net> | 2013-01-05 20:20:29 (GMT) |
---|---|---|
committer | Antoine Pitrou <solipsis@pitrou.net> | 2013-01-05 20:20:29 (GMT) |
commit | 58ddc9d743d09ee93d5cf46a4de62eab30dad79d (patch) | |
tree | 777a03261e792820c18c4a8ffd3923e2cc284adb /Modules/_ssl.c | |
parent | 3c9850aad7ef6d88d26081e062c11635af5dea8e (diff) | |
download | cpython-58ddc9d743d09ee93d5cf46a4de62eab30dad79d.zip cpython-58ddc9d743d09ee93d5cf46a4de62eab30dad79d.tar.gz cpython-58ddc9d743d09ee93d5cf46a4de62eab30dad79d.tar.bz2 |
Issue #8109: The ssl module now has support for server-side SNI, thanks to a :meth:`SSLContext.set_servername_callback` method.
Patch by Daniel Black.
Diffstat (limited to 'Modules/_ssl.c')
-rw-r--r-- | Modules/_ssl.c | 253 |
1 files changed, 248 insertions, 5 deletions
diff --git a/Modules/_ssl.c b/Modules/_ssl.c index ad22c52..16fa076 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -181,12 +181,16 @@ typedef struct { char *npn_protocols; int npn_protocols_len; #endif +#ifndef OPENSSL_NO_TLSEXT + PyObject *set_hostname; +#endif } PySSLContext; typedef struct { PyObject_HEAD PyObject *Socket; /* weakref to socket on which we're layered */ SSL *ssl; + PySSLContext *ctx; /* weakref to SSL context */ X509 *peer_cert; int shutdown_seen_zero; enum py_ssl_server_or_client socket_type; @@ -437,11 +441,12 @@ _setSSLError (char *errstr, int errcode, char *filename, int lineno) { */ static PySSLSocket * -newPySSLSocket(SSL_CTX *ctx, PySocketSockObject *sock, +newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock, enum py_ssl_server_or_client socket_type, char *server_hostname) { PySSLSocket *self; + SSL_CTX *ctx = sslctx->ctx; self = PyObject_New(PySSLSocket, &PySSLSocket_Type); if (self == NULL) @@ -450,6 +455,8 @@ newPySSLSocket(SSL_CTX *ctx, PySocketSockObject *sock, self->peer_cert = NULL; self->ssl = NULL; self->Socket = NULL; + self->ctx = sslctx; + Py_INCREF(sslctx); /* Make sure the SSL error state is initialized */ (void) ERR_get_state(); @@ -458,6 +465,7 @@ newPySSLSocket(SSL_CTX *ctx, PySocketSockObject *sock, PySSL_BEGIN_ALLOW_THREADS self->ssl = SSL_new(ctx); PySSL_END_ALLOW_THREADS + SSL_set_app_data(self->ssl,self); SSL_set_fd(self->ssl, sock->sock_fd); #ifdef SSL_MODE_AUTO_RETRY SSL_set_mode(self->ssl, SSL_MODE_AUTO_RETRY); @@ -1164,6 +1172,38 @@ static PyObject *PySSL_compression(PySSLSocket *self) { #endif } +static PySSLContext *PySSL_get_context(PySSLSocket *self, void *closure) { + Py_INCREF(self->ctx); + return self->ctx; +} + +static int PySSL_set_context(PySSLSocket *self, PyObject *value, + void *closure) { + + if (PyObject_TypeCheck(value, &PySSLContext_Type)) { + + Py_INCREF(value); + Py_DECREF(self->ctx); + self->ctx = (PySSLContext *) value; + SSL_set_SSL_CTX(self->ssl, self->ctx->ctx); + } else { + PyErr_SetString(PyExc_TypeError, "The value must be a SSLContext"); + return -1; + } + + return 0; +} + +PyDoc_STRVAR(PySSL_set_context_doc, +"_setter_context(ctx)\n\ +\ +This changes the context associated with the SSLSocket. This is typically\n\ +used from within a callback function set by the set_servername_callback\n\ +on the SSLContext to change the certificate information associated with the\n\ +SSLSocket before the cryptographic exchange handshake messages\n"); + + + static void PySSL_dealloc(PySSLSocket *self) { if (self->peer_cert) /* Possible not to have one? */ @@ -1171,6 +1211,7 @@ static void PySSL_dealloc(PySSLSocket *self) if (self->ssl) SSL_free(self->ssl); Py_XDECREF(self->Socket); + Py_XDECREF(self->ctx); PyObject_Del(self); } @@ -1606,6 +1647,12 @@ If the TLS handshake is not yet complete, None is returned"); #endif /* HAVE_OPENSSL_FINISHED */ +static PyGetSetDef ssl_getsetlist[] = { + {"context", (getter) PySSL_get_context, + (setter) PySSL_set_context, PySSL_set_context_doc}, + {NULL}, /* sentinel */ +}; + static PyMethodDef PySSLMethods[] = { {"do_handshake", (PyCFunction)PySSL_SSLdo_handshake, METH_NOARGS}, {"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS, @@ -1660,6 +1707,8 @@ static PyTypeObject PySSLSocket_Type = { 0, /*tp_iter*/ 0, /*tp_iternext*/ PySSLMethods, /*tp_methods*/ + 0, /*tp_members*/ + ssl_getsetlist, /*tp_getset*/ }; @@ -1716,6 +1765,9 @@ context_new(PyTypeObject *type, PyObject *args, PyObject *kwds) #ifdef OPENSSL_NPN_NEGOTIATED self->npn_protocols = NULL; #endif +#ifndef OPENSSL_NO_TLSEXT + self->set_hostname = NULL; +#endif /* Defaults */ SSL_CTX_set_verify(self->ctx, SSL_VERIFY_NONE, NULL); SSL_CTX_set_options(self->ctx, @@ -1729,9 +1781,28 @@ context_new(PyTypeObject *type, PyObject *args, PyObject *kwds) return (PyObject *)self; } +static int +context_traverse(PySSLContext *self, visitproc visit, void *arg) +{ +#ifndef OPENSSL_NO_TLSEXT + Py_VISIT(self->set_hostname); +#endif + return 0; +} + +static int +context_clear(PySSLContext *self) +{ +#ifndef OPENSSL_NO_TLSEXT + Py_CLEAR(self->set_hostname); +#endif + return 0; +} + static void context_dealloc(PySSLContext *self) { + context_clear(self); SSL_CTX_free(self->ctx); #ifdef OPENSSL_NPN_NEGOTIATED PyMem_Free(self->npn_protocols); @@ -2223,7 +2294,7 @@ context_wrap_socket(PySSLContext *self, PyObject *args, PyObject *kwds) #endif } - res = (PyObject *) newPySSLSocket(self->ctx, sock, server_side, + res = (PyObject *) newPySSLSocket(self, sock, server_side, hostname); if (hostname != NULL) PyMem_Free(hostname); @@ -2308,6 +2379,131 @@ set_ecdh_curve(PySSLContext *self, PyObject *name) } #endif +#ifndef OPENSSL_NO_TLSEXT +static int +_servername_callback(SSL *s, int *al, void *args) +{ + int ret; + PySSLContext *ssl_ctx = (PySSLContext *) args; + PySSLSocket *ssl; + PyObject *servername_o; + PyObject *servername_idna; + PyObject *result; + /* The high-level ssl.SSLSocket object */ + PyObject *ssl_socket; + PyGILState_STATE gstate; + const char *servername = SSL_get_servername(s, TLSEXT_NAMETYPE_host_name); + + gstate = PyGILState_Ensure(); + + if (ssl_ctx->set_hostname == NULL) { + /* remove race condition in this the call back while if removing the + * callback is in progress */ + PyGILState_Release(gstate); + return ret; + } + + ssl = SSL_get_app_data(s); + assert(PySSLSocket_Check(ssl)); + ssl_socket = PyWeakref_GetObject(ssl->Socket); + Py_INCREF(ssl_socket); + if (ssl_socket == Py_None) { + goto error; + } + + servername_o = PyBytes_FromString(servername); + if (servername_o == NULL) { + PyErr_WriteUnraisable((PyObject *) ssl_ctx); + goto error; + } + servername_idna = PyUnicode_FromEncodedObject(servername_o, "idna", NULL); + if (servername_idna == NULL) { + PyErr_WriteUnraisable(servername_o); + Py_DECREF(servername_o); + goto error; + } + Py_DECREF(servername_o); + result = PyObject_CallFunctionObjArgs(ssl_ctx->set_hostname, ssl_socket, + servername_idna, ssl_ctx, NULL); + Py_DECREF(ssl_socket); + Py_DECREF(servername_idna); + + if (result == NULL) { + PyErr_WriteUnraisable(ssl_ctx->set_hostname); + *al = SSL_AD_HANDSHAKE_FAILURE; + ret = SSL_TLSEXT_ERR_ALERT_FATAL; + } + else { + if (result != Py_None) { + *al = (int) PyLong_AsLong(result); + if (PyErr_Occurred()) { + PyErr_WriteUnraisable(result); + *al = SSL_AD_INTERNAL_ERROR; + } + ret = SSL_TLSEXT_ERR_ALERT_FATAL; + } + else { + ret = SSL_TLSEXT_ERR_OK; + } + Py_DECREF(result); + } + + PyGILState_Release(gstate); + return ret; + +error: + Py_DECREF(ssl_socket); + *al = SSL_AD_INTERNAL_ERROR; + ret = SSL_TLSEXT_ERR_ALERT_FATAL; + PyGILState_Release(gstate); + return ret; +} + +PyDoc_STRVAR(PySSL_set_servername_callback_doc, +"set_servername_callback(method)\n\ +\ +This sets a callback that will be called when a server name is provided by\n\ +the SSL/TLS client in the SNI extension.\n\ +\ +If the argument is None then the callback is disabled. The method is called\n\ +with the SSLSocket, the server name as a string, and the SSLContext object.\n\ +See RFC 6066 for details of the SNI"); +#endif + +static PyObject * +set_servername_callback(PySSLContext *self, PyObject *args) +{ +#ifndef OPENSSL_NO_TLSEXT + PyObject *cb; + + if (!PyArg_ParseTuple(args, "O", &cb)) + return NULL; + + Py_CLEAR(self->set_hostname); + if (cb == Py_None) { + SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL); + } + else { + if (!PyCallable_Check(cb)) { + SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL); + PyErr_SetString(PyExc_TypeError, + "not a callable object"); + return NULL; + } + Py_INCREF(cb); + self->set_hostname = cb; + SSL_CTX_set_tlsext_servername_callback(self->ctx, _servername_callback); + SSL_CTX_set_tlsext_servername_arg(self->ctx, self); + } + Py_RETURN_NONE; +#else + PyErr_SetString(PyExc_NotImplementedError, + "The TLS extension servername callback, " + "SSL_CTX_set_tlsext_servername_callback, " + "is not in the current OpenSSL library."); +#endif +} + static PyGetSetDef context_getsetlist[] = { {"options", (getter) get_options, (setter) set_options, NULL}, @@ -2337,6 +2533,8 @@ static struct PyMethodDef context_methods[] = { {"set_ecdh_curve", (PyCFunction) set_ecdh_curve, METH_O, NULL}, #endif + {"set_servername_callback", (PyCFunction) set_servername_callback, + METH_VARARGS, PySSL_set_servername_callback_doc}, {NULL, NULL} /* sentinel */ }; @@ -2360,10 +2558,10 @@ static PyTypeObject PySSLContext_Type = { 0, /*tp_getattro*/ 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /*tp_flags*/ 0, /*tp_doc*/ - 0, /*tp_traverse*/ - 0, /*tp_clear*/ + (traverseproc) context_traverse, /*tp_traverse*/ + (inquiry) context_clear, /*tp_clear*/ 0, /*tp_richcompare*/ 0, /*tp_weaklistoffset*/ 0, /*tp_iter*/ @@ -2743,6 +2941,51 @@ PyInit__ssl(void) PyModule_AddIntConstant(m, "CERT_REQUIRED", PY_SSL_CERT_REQUIRED); + /* Alert Descriptions from ssl.h */ + /* note RESERVED constants no longer intended for use have been removed */ + /* http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6 */ + +#define ADD_AD_CONSTANT(s) \ + PyModule_AddIntConstant(m, "ALERT_DESCRIPTION_"#s, \ + SSL_AD_##s) + + ADD_AD_CONSTANT(CLOSE_NOTIFY); + ADD_AD_CONSTANT(UNEXPECTED_MESSAGE); + ADD_AD_CONSTANT(BAD_RECORD_MAC); + ADD_AD_CONSTANT(RECORD_OVERFLOW); + ADD_AD_CONSTANT(DECOMPRESSION_FAILURE); + ADD_AD_CONSTANT(HANDSHAKE_FAILURE); + ADD_AD_CONSTANT(BAD_CERTIFICATE); + ADD_AD_CONSTANT(UNSUPPORTED_CERTIFICATE); + ADD_AD_CONSTANT(CERTIFICATE_REVOKED); + ADD_AD_CONSTANT(CERTIFICATE_EXPIRED); + ADD_AD_CONSTANT(CERTIFICATE_UNKNOWN); + ADD_AD_CONSTANT(ILLEGAL_PARAMETER); + ADD_AD_CONSTANT(UNKNOWN_CA); + ADD_AD_CONSTANT(ACCESS_DENIED); + ADD_AD_CONSTANT(DECODE_ERROR); + ADD_AD_CONSTANT(DECRYPT_ERROR); + ADD_AD_CONSTANT(PROTOCOL_VERSION); + ADD_AD_CONSTANT(INSUFFICIENT_SECURITY); + ADD_AD_CONSTANT(INTERNAL_ERROR); + ADD_AD_CONSTANT(USER_CANCELLED); + ADD_AD_CONSTANT(NO_RENEGOTIATION); + ADD_AD_CONSTANT(UNSUPPORTED_EXTENSION); + ADD_AD_CONSTANT(CERTIFICATE_UNOBTAINABLE); + ADD_AD_CONSTANT(UNRECOGNIZED_NAME); + /* Not all constants are in old OpenSSL versions */ +#ifdef SSL_AD_BAD_CERTIFICATE_STATUS_RESPONSE + ADD_AD_CONSTANT(BAD_CERTIFICATE_STATUS_RESPONSE); +#endif +#ifdef SSL_AD_BAD_CERTIFICATE_HASH_VALUE + ADD_AD_CONSTANT(BAD_CERTIFICATE_HASH_VALUE); +#endif +#ifdef SSL_AD_UNKNOWN_PSK_IDENTITY + ADD_AD_CONSTANT(UNKNOWN_PSK_IDENTITY); +#endif + +#undef ADD_AD_CONSTANT + /* protocol versions */ #ifndef OPENSSL_NO_SSL2 PyModule_AddIntConstant(m, "PROTOCOL_SSLv2", |