summaryrefslogtreecommitdiffstats
path: root/Modules/_ssl.c
diff options
context:
space:
mode:
authorAntoine Pitrou <solipsis@pitrou.net>2013-01-05 20:20:29 (GMT)
committerAntoine Pitrou <solipsis@pitrou.net>2013-01-05 20:20:29 (GMT)
commit58ddc9d743d09ee93d5cf46a4de62eab30dad79d (patch)
tree777a03261e792820c18c4a8ffd3923e2cc284adb /Modules/_ssl.c
parent3c9850aad7ef6d88d26081e062c11635af5dea8e (diff)
downloadcpython-58ddc9d743d09ee93d5cf46a4de62eab30dad79d.zip
cpython-58ddc9d743d09ee93d5cf46a4de62eab30dad79d.tar.gz
cpython-58ddc9d743d09ee93d5cf46a4de62eab30dad79d.tar.bz2
Issue #8109: The ssl module now has support for server-side SNI, thanks to a :meth:`SSLContext.set_servername_callback` method.
Patch by Daniel Black.
Diffstat (limited to 'Modules/_ssl.c')
-rw-r--r--Modules/_ssl.c253
1 files changed, 248 insertions, 5 deletions
diff --git a/Modules/_ssl.c b/Modules/_ssl.c
index ad22c52..16fa076 100644
--- a/Modules/_ssl.c
+++ b/Modules/_ssl.c
@@ -181,12 +181,16 @@ typedef struct {
char *npn_protocols;
int npn_protocols_len;
#endif
+#ifndef OPENSSL_NO_TLSEXT
+ PyObject *set_hostname;
+#endif
} PySSLContext;
typedef struct {
PyObject_HEAD
PyObject *Socket; /* weakref to socket on which we're layered */
SSL *ssl;
+ PySSLContext *ctx; /* weakref to SSL context */
X509 *peer_cert;
int shutdown_seen_zero;
enum py_ssl_server_or_client socket_type;
@@ -437,11 +441,12 @@ _setSSLError (char *errstr, int errcode, char *filename, int lineno) {
*/
static PySSLSocket *
-newPySSLSocket(SSL_CTX *ctx, PySocketSockObject *sock,
+newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock,
enum py_ssl_server_or_client socket_type,
char *server_hostname)
{
PySSLSocket *self;
+ SSL_CTX *ctx = sslctx->ctx;
self = PyObject_New(PySSLSocket, &PySSLSocket_Type);
if (self == NULL)
@@ -450,6 +455,8 @@ newPySSLSocket(SSL_CTX *ctx, PySocketSockObject *sock,
self->peer_cert = NULL;
self->ssl = NULL;
self->Socket = NULL;
+ self->ctx = sslctx;
+ Py_INCREF(sslctx);
/* Make sure the SSL error state is initialized */
(void) ERR_get_state();
@@ -458,6 +465,7 @@ newPySSLSocket(SSL_CTX *ctx, PySocketSockObject *sock,
PySSL_BEGIN_ALLOW_THREADS
self->ssl = SSL_new(ctx);
PySSL_END_ALLOW_THREADS
+ SSL_set_app_data(self->ssl,self);
SSL_set_fd(self->ssl, sock->sock_fd);
#ifdef SSL_MODE_AUTO_RETRY
SSL_set_mode(self->ssl, SSL_MODE_AUTO_RETRY);
@@ -1164,6 +1172,38 @@ static PyObject *PySSL_compression(PySSLSocket *self) {
#endif
}
+static PySSLContext *PySSL_get_context(PySSLSocket *self, void *closure) {
+ Py_INCREF(self->ctx);
+ return self->ctx;
+}
+
+static int PySSL_set_context(PySSLSocket *self, PyObject *value,
+ void *closure) {
+
+ if (PyObject_TypeCheck(value, &PySSLContext_Type)) {
+
+ Py_INCREF(value);
+ Py_DECREF(self->ctx);
+ self->ctx = (PySSLContext *) value;
+ SSL_set_SSL_CTX(self->ssl, self->ctx->ctx);
+ } else {
+ PyErr_SetString(PyExc_TypeError, "The value must be a SSLContext");
+ return -1;
+ }
+
+ return 0;
+}
+
+PyDoc_STRVAR(PySSL_set_context_doc,
+"_setter_context(ctx)\n\
+\
+This changes the context associated with the SSLSocket. This is typically\n\
+used from within a callback function set by the set_servername_callback\n\
+on the SSLContext to change the certificate information associated with the\n\
+SSLSocket before the cryptographic exchange handshake messages\n");
+
+
+
static void PySSL_dealloc(PySSLSocket *self)
{
if (self->peer_cert) /* Possible not to have one? */
@@ -1171,6 +1211,7 @@ static void PySSL_dealloc(PySSLSocket *self)
if (self->ssl)
SSL_free(self->ssl);
Py_XDECREF(self->Socket);
+ Py_XDECREF(self->ctx);
PyObject_Del(self);
}
@@ -1606,6 +1647,12 @@ If the TLS handshake is not yet complete, None is returned");
#endif /* HAVE_OPENSSL_FINISHED */
+static PyGetSetDef ssl_getsetlist[] = {
+ {"context", (getter) PySSL_get_context,
+ (setter) PySSL_set_context, PySSL_set_context_doc},
+ {NULL}, /* sentinel */
+};
+
static PyMethodDef PySSLMethods[] = {
{"do_handshake", (PyCFunction)PySSL_SSLdo_handshake, METH_NOARGS},
{"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS,
@@ -1660,6 +1707,8 @@ static PyTypeObject PySSLSocket_Type = {
0, /*tp_iter*/
0, /*tp_iternext*/
PySSLMethods, /*tp_methods*/
+ 0, /*tp_members*/
+ ssl_getsetlist, /*tp_getset*/
};
@@ -1716,6 +1765,9 @@ context_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
#ifdef OPENSSL_NPN_NEGOTIATED
self->npn_protocols = NULL;
#endif
+#ifndef OPENSSL_NO_TLSEXT
+ self->set_hostname = NULL;
+#endif
/* Defaults */
SSL_CTX_set_verify(self->ctx, SSL_VERIFY_NONE, NULL);
SSL_CTX_set_options(self->ctx,
@@ -1729,9 +1781,28 @@ context_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
return (PyObject *)self;
}
+static int
+context_traverse(PySSLContext *self, visitproc visit, void *arg)
+{
+#ifndef OPENSSL_NO_TLSEXT
+ Py_VISIT(self->set_hostname);
+#endif
+ return 0;
+}
+
+static int
+context_clear(PySSLContext *self)
+{
+#ifndef OPENSSL_NO_TLSEXT
+ Py_CLEAR(self->set_hostname);
+#endif
+ return 0;
+}
+
static void
context_dealloc(PySSLContext *self)
{
+ context_clear(self);
SSL_CTX_free(self->ctx);
#ifdef OPENSSL_NPN_NEGOTIATED
PyMem_Free(self->npn_protocols);
@@ -2223,7 +2294,7 @@ context_wrap_socket(PySSLContext *self, PyObject *args, PyObject *kwds)
#endif
}
- res = (PyObject *) newPySSLSocket(self->ctx, sock, server_side,
+ res = (PyObject *) newPySSLSocket(self, sock, server_side,
hostname);
if (hostname != NULL)
PyMem_Free(hostname);
@@ -2308,6 +2379,131 @@ set_ecdh_curve(PySSLContext *self, PyObject *name)
}
#endif
+#ifndef OPENSSL_NO_TLSEXT
+static int
+_servername_callback(SSL *s, int *al, void *args)
+{
+ int ret;
+ PySSLContext *ssl_ctx = (PySSLContext *) args;
+ PySSLSocket *ssl;
+ PyObject *servername_o;
+ PyObject *servername_idna;
+ PyObject *result;
+ /* The high-level ssl.SSLSocket object */
+ PyObject *ssl_socket;
+ PyGILState_STATE gstate;
+ const char *servername = SSL_get_servername(s, TLSEXT_NAMETYPE_host_name);
+
+ gstate = PyGILState_Ensure();
+
+ if (ssl_ctx->set_hostname == NULL) {
+ /* remove race condition in this the call back while if removing the
+ * callback is in progress */
+ PyGILState_Release(gstate);
+ return ret;
+ }
+
+ ssl = SSL_get_app_data(s);
+ assert(PySSLSocket_Check(ssl));
+ ssl_socket = PyWeakref_GetObject(ssl->Socket);
+ Py_INCREF(ssl_socket);
+ if (ssl_socket == Py_None) {
+ goto error;
+ }
+
+ servername_o = PyBytes_FromString(servername);
+ if (servername_o == NULL) {
+ PyErr_WriteUnraisable((PyObject *) ssl_ctx);
+ goto error;
+ }
+ servername_idna = PyUnicode_FromEncodedObject(servername_o, "idna", NULL);
+ if (servername_idna == NULL) {
+ PyErr_WriteUnraisable(servername_o);
+ Py_DECREF(servername_o);
+ goto error;
+ }
+ Py_DECREF(servername_o);
+ result = PyObject_CallFunctionObjArgs(ssl_ctx->set_hostname, ssl_socket,
+ servername_idna, ssl_ctx, NULL);
+ Py_DECREF(ssl_socket);
+ Py_DECREF(servername_idna);
+
+ if (result == NULL) {
+ PyErr_WriteUnraisable(ssl_ctx->set_hostname);
+ *al = SSL_AD_HANDSHAKE_FAILURE;
+ ret = SSL_TLSEXT_ERR_ALERT_FATAL;
+ }
+ else {
+ if (result != Py_None) {
+ *al = (int) PyLong_AsLong(result);
+ if (PyErr_Occurred()) {
+ PyErr_WriteUnraisable(result);
+ *al = SSL_AD_INTERNAL_ERROR;
+ }
+ ret = SSL_TLSEXT_ERR_ALERT_FATAL;
+ }
+ else {
+ ret = SSL_TLSEXT_ERR_OK;
+ }
+ Py_DECREF(result);
+ }
+
+ PyGILState_Release(gstate);
+ return ret;
+
+error:
+ Py_DECREF(ssl_socket);
+ *al = SSL_AD_INTERNAL_ERROR;
+ ret = SSL_TLSEXT_ERR_ALERT_FATAL;
+ PyGILState_Release(gstate);
+ return ret;
+}
+
+PyDoc_STRVAR(PySSL_set_servername_callback_doc,
+"set_servername_callback(method)\n\
+\
+This sets a callback that will be called when a server name is provided by\n\
+the SSL/TLS client in the SNI extension.\n\
+\
+If the argument is None then the callback is disabled. The method is called\n\
+with the SSLSocket, the server name as a string, and the SSLContext object.\n\
+See RFC 6066 for details of the SNI");
+#endif
+
+static PyObject *
+set_servername_callback(PySSLContext *self, PyObject *args)
+{
+#ifndef OPENSSL_NO_TLSEXT
+ PyObject *cb;
+
+ if (!PyArg_ParseTuple(args, "O", &cb))
+ return NULL;
+
+ Py_CLEAR(self->set_hostname);
+ if (cb == Py_None) {
+ SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL);
+ }
+ else {
+ if (!PyCallable_Check(cb)) {
+ SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL);
+ PyErr_SetString(PyExc_TypeError,
+ "not a callable object");
+ return NULL;
+ }
+ Py_INCREF(cb);
+ self->set_hostname = cb;
+ SSL_CTX_set_tlsext_servername_callback(self->ctx, _servername_callback);
+ SSL_CTX_set_tlsext_servername_arg(self->ctx, self);
+ }
+ Py_RETURN_NONE;
+#else
+ PyErr_SetString(PyExc_NotImplementedError,
+ "The TLS extension servername callback, "
+ "SSL_CTX_set_tlsext_servername_callback, "
+ "is not in the current OpenSSL library.");
+#endif
+}
+
static PyGetSetDef context_getsetlist[] = {
{"options", (getter) get_options,
(setter) set_options, NULL},
@@ -2337,6 +2533,8 @@ static struct PyMethodDef context_methods[] = {
{"set_ecdh_curve", (PyCFunction) set_ecdh_curve,
METH_O, NULL},
#endif
+ {"set_servername_callback", (PyCFunction) set_servername_callback,
+ METH_VARARGS, PySSL_set_servername_callback_doc},
{NULL, NULL} /* sentinel */
};
@@ -2360,10 +2558,10 @@ static PyTypeObject PySSLContext_Type = {
0, /*tp_getattro*/
0, /*tp_setattro*/
0, /*tp_as_buffer*/
- Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /*tp_flags*/
0, /*tp_doc*/
- 0, /*tp_traverse*/
- 0, /*tp_clear*/
+ (traverseproc) context_traverse, /*tp_traverse*/
+ (inquiry) context_clear, /*tp_clear*/
0, /*tp_richcompare*/
0, /*tp_weaklistoffset*/
0, /*tp_iter*/
@@ -2743,6 +2941,51 @@ PyInit__ssl(void)
PyModule_AddIntConstant(m, "CERT_REQUIRED",
PY_SSL_CERT_REQUIRED);
+ /* Alert Descriptions from ssl.h */
+ /* note RESERVED constants no longer intended for use have been removed */
+ /* http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6 */
+
+#define ADD_AD_CONSTANT(s) \
+ PyModule_AddIntConstant(m, "ALERT_DESCRIPTION_"#s, \
+ SSL_AD_##s)
+
+ ADD_AD_CONSTANT(CLOSE_NOTIFY);
+ ADD_AD_CONSTANT(UNEXPECTED_MESSAGE);
+ ADD_AD_CONSTANT(BAD_RECORD_MAC);
+ ADD_AD_CONSTANT(RECORD_OVERFLOW);
+ ADD_AD_CONSTANT(DECOMPRESSION_FAILURE);
+ ADD_AD_CONSTANT(HANDSHAKE_FAILURE);
+ ADD_AD_CONSTANT(BAD_CERTIFICATE);
+ ADD_AD_CONSTANT(UNSUPPORTED_CERTIFICATE);
+ ADD_AD_CONSTANT(CERTIFICATE_REVOKED);
+ ADD_AD_CONSTANT(CERTIFICATE_EXPIRED);
+ ADD_AD_CONSTANT(CERTIFICATE_UNKNOWN);
+ ADD_AD_CONSTANT(ILLEGAL_PARAMETER);
+ ADD_AD_CONSTANT(UNKNOWN_CA);
+ ADD_AD_CONSTANT(ACCESS_DENIED);
+ ADD_AD_CONSTANT(DECODE_ERROR);
+ ADD_AD_CONSTANT(DECRYPT_ERROR);
+ ADD_AD_CONSTANT(PROTOCOL_VERSION);
+ ADD_AD_CONSTANT(INSUFFICIENT_SECURITY);
+ ADD_AD_CONSTANT(INTERNAL_ERROR);
+ ADD_AD_CONSTANT(USER_CANCELLED);
+ ADD_AD_CONSTANT(NO_RENEGOTIATION);
+ ADD_AD_CONSTANT(UNSUPPORTED_EXTENSION);
+ ADD_AD_CONSTANT(CERTIFICATE_UNOBTAINABLE);
+ ADD_AD_CONSTANT(UNRECOGNIZED_NAME);
+ /* Not all constants are in old OpenSSL versions */
+#ifdef SSL_AD_BAD_CERTIFICATE_STATUS_RESPONSE
+ ADD_AD_CONSTANT(BAD_CERTIFICATE_STATUS_RESPONSE);
+#endif
+#ifdef SSL_AD_BAD_CERTIFICATE_HASH_VALUE
+ ADD_AD_CONSTANT(BAD_CERTIFICATE_HASH_VALUE);
+#endif
+#ifdef SSL_AD_UNKNOWN_PSK_IDENTITY
+ ADD_AD_CONSTANT(UNKNOWN_PSK_IDENTITY);
+#endif
+
+#undef ADD_AD_CONSTANT
+
/* protocol versions */
#ifndef OPENSSL_NO_SSL2
PyModule_AddIntConstant(m, "PROTOCOL_SSLv2",