summaryrefslogtreecommitdiffstats
path: root/Modules/_ssl.c
diff options
context:
space:
mode:
Diffstat (limited to 'Modules/_ssl.c')
-rw-r--r--Modules/_ssl.c132
1 files changed, 110 insertions, 22 deletions
diff --git a/Modules/_ssl.c b/Modules/_ssl.c
index 5969663..2e19439 100644
--- a/Modules/_ssl.c
+++ b/Modules/_ssl.c
@@ -109,6 +109,11 @@ struct py_ssl_library_code {
# define HAVE_SNI 0
#endif
+/* ALPN added in OpenSSL 1.0.2 */
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+# define HAVE_ALPN
+#endif
+
enum py_ssl_error {
/* these mirror ssl.h */
PY_SSL_ERROR_NONE,
@@ -180,9 +185,13 @@ typedef struct {
PyObject_HEAD
SSL_CTX *ctx;
#ifdef OPENSSL_NPN_NEGOTIATED
- char *npn_protocols;
+ unsigned char *npn_protocols;
int npn_protocols_len;
#endif
+#ifdef HAVE_ALPN
+ unsigned char *alpn_protocols;
+ int alpn_protocols_len;
+#endif
#ifndef OPENSSL_NO_TLSEXT
PyObject *set_hostname;
#endif
@@ -1460,7 +1469,20 @@ static PyObject *PySSL_selected_npn_protocol(PySSLSocket *self) {
if (out == NULL)
Py_RETURN_NONE;
- return PyUnicode_FromStringAndSize((char *) out, outlen);
+ return PyUnicode_FromStringAndSize((char *)out, outlen);
+}
+#endif
+
+#ifdef HAVE_ALPN
+static PyObject *PySSL_selected_alpn_protocol(PySSLSocket *self) {
+ const unsigned char *out;
+ unsigned int outlen;
+
+ SSL_get0_alpn_selected(self->ssl, &out, &outlen);
+
+ if (out == NULL)
+ Py_RETURN_NONE;
+ return PyUnicode_FromStringAndSize((char *)out, outlen);
}
#endif
@@ -2054,6 +2076,9 @@ static PyMethodDef PySSLMethods[] = {
#ifdef OPENSSL_NPN_NEGOTIATED
{"selected_npn_protocol", (PyCFunction)PySSL_selected_npn_protocol, METH_NOARGS},
#endif
+#ifdef HAVE_ALPN
+ {"selected_alpn_protocol", (PyCFunction)PySSL_selected_alpn_protocol, METH_NOARGS},
+#endif
{"compression", (PyCFunction)PySSL_compression, METH_NOARGS},
{"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS,
PySSL_SSLshutdown_doc},
@@ -2159,6 +2184,9 @@ context_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
#ifdef OPENSSL_NPN_NEGOTIATED
self->npn_protocols = NULL;
#endif
+#ifdef HAVE_ALPN
+ self->alpn_protocols = NULL;
+#endif
#ifndef OPENSSL_NO_TLSEXT
self->set_hostname = NULL;
#endif
@@ -2218,7 +2246,10 @@ context_dealloc(PySSLContext *self)
context_clear(self);
SSL_CTX_free(self->ctx);
#ifdef OPENSSL_NPN_NEGOTIATED
- PyMem_Free(self->npn_protocols);
+ PyMem_FREE(self->npn_protocols);
+#endif
+#ifdef HAVE_ALPN
+ PyMem_FREE(self->alpn_protocols);
#endif
Py_TYPE(self)->tp_free(self);
}
@@ -2244,6 +2275,23 @@ set_ciphers(PySSLContext *self, PyObject *args)
Py_RETURN_NONE;
}
+static int
+do_protocol_selection(unsigned char **out, unsigned char *outlen,
+ const unsigned char *remote_protocols, unsigned int remote_protocols_len,
+ unsigned char *our_protocols, unsigned int our_protocols_len)
+{
+ if (our_protocols == NULL) {
+ our_protocols = (unsigned char*)"";
+ our_protocols_len = 0;
+ }
+
+ SSL_select_next_proto(out, outlen,
+ remote_protocols, remote_protocols_len,
+ our_protocols, our_protocols_len);
+
+ return SSL_TLSEXT_ERR_OK;
+}
+
#ifdef OPENSSL_NPN_NEGOTIATED
/* this callback gets passed to SSL_CTX_set_next_protos_advertise_cb */
static int
@@ -2254,10 +2302,10 @@ _advertiseNPN_cb(SSL *s,
PySSLContext *ssl_ctx = (PySSLContext *) args;
if (ssl_ctx->npn_protocols == NULL) {
- *data = (unsigned char *) "";
+ *data = (unsigned char *)"";
*len = 0;
} else {
- *data = (unsigned char *) ssl_ctx->npn_protocols;
+ *data = ssl_ctx->npn_protocols;
*len = ssl_ctx->npn_protocols_len;
}
@@ -2270,23 +2318,9 @@ _selectNPN_cb(SSL *s,
const unsigned char *server, unsigned int server_len,
void *args)
{
- PySSLContext *ssl_ctx = (PySSLContext *) args;
-
- unsigned char *client = (unsigned char *) ssl_ctx->npn_protocols;
- int client_len;
-
- if (client == NULL) {
- client = (unsigned char *) "";
- client_len = 0;
- } else {
- client_len = ssl_ctx->npn_protocols_len;
- }
-
- SSL_select_next_proto(out, outlen,
- server, server_len,
- client, client_len);
-
- return SSL_TLSEXT_ERR_OK;
+ PySSLContext *ctx = (PySSLContext *)args;
+ return do_protocol_selection(out, outlen, server, server_len,
+ ctx->npn_protocols, ctx->npn_protocols_len);
}
#endif
@@ -2329,6 +2363,50 @@ _set_npn_protocols(PySSLContext *self, PyObject *args)
#endif
}
+#ifdef HAVE_ALPN
+static int
+_selectALPN_cb(SSL *s,
+ const unsigned char **out, unsigned char *outlen,
+ const unsigned char *client_protocols, unsigned int client_protocols_len,
+ void *args)
+{
+ PySSLContext *ctx = (PySSLContext *)args;
+ return do_protocol_selection((unsigned char **)out, outlen,
+ client_protocols, client_protocols_len,
+ ctx->alpn_protocols, ctx->alpn_protocols_len);
+}
+#endif
+
+static PyObject *
+_set_alpn_protocols(PySSLContext *self, PyObject *args)
+{
+#ifdef HAVE_ALPN
+ Py_buffer protos;
+
+ if (!PyArg_ParseTuple(args, "y*:set_npn_protocols", &protos))
+ return NULL;
+
+ PyMem_FREE(self->alpn_protocols);
+ self->alpn_protocols = PyMem_Malloc(protos.len);
+ if (!self->alpn_protocols)
+ return PyErr_NoMemory();
+ memcpy(self->alpn_protocols, protos.buf, protos.len);
+ self->alpn_protocols_len = protos.len;
+ PyBuffer_Release(&protos);
+
+ if (SSL_CTX_set_alpn_protos(self->ctx, self->alpn_protocols, self->alpn_protocols_len))
+ return PyErr_NoMemory();
+ SSL_CTX_set_alpn_select_cb(self->ctx, _selectALPN_cb, self);
+
+ PyBuffer_Release(&protos);
+ Py_RETURN_NONE;
+#else
+ PyErr_SetString(PyExc_NotImplementedError,
+ "The ALPN extension requires OpenSSL 1.0.2 or later.");
+ return NULL;
+#endif
+}
+
static PyObject *
get_verify_mode(PySSLContext *self, void *c)
{
@@ -3307,6 +3385,8 @@ static struct PyMethodDef context_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL},
{"set_ciphers", (PyCFunction) set_ciphers,
METH_VARARGS, NULL},
+ {"_set_alpn_protocols", (PyCFunction) _set_alpn_protocols,
+ METH_VARARGS, NULL},
{"_set_npn_protocols", (PyCFunction) _set_npn_protocols,
METH_VARARGS, NULL},
{"load_cert_chain", (PyCFunction) load_cert_chain,
@@ -4502,6 +4582,14 @@ PyInit__ssl(void)
Py_INCREF(r);
PyModule_AddObject(m, "HAS_NPN", r);
+#ifdef HAVE_ALPN
+ r = Py_True;
+#else
+ r = Py_False;
+#endif
+ Py_INCREF(r);
+ PyModule_AddObject(m, "HAS_ALPN", r);
+
/* Mappings for error codes */
err_codes_to_names = PyDict_New();
err_names_to_codes = PyDict_New();