diff options
Diffstat (limited to 'Modules/_ssl.c')
-rw-r--r-- | Modules/_ssl.c | 132 |
1 files changed, 110 insertions, 22 deletions
diff --git a/Modules/_ssl.c b/Modules/_ssl.c index 5969663..2e19439 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -109,6 +109,11 @@ struct py_ssl_library_code { # define HAVE_SNI 0 #endif +/* ALPN added in OpenSSL 1.0.2 */ +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) +# define HAVE_ALPN +#endif + enum py_ssl_error { /* these mirror ssl.h */ PY_SSL_ERROR_NONE, @@ -180,9 +185,13 @@ typedef struct { PyObject_HEAD SSL_CTX *ctx; #ifdef OPENSSL_NPN_NEGOTIATED - char *npn_protocols; + unsigned char *npn_protocols; int npn_protocols_len; #endif +#ifdef HAVE_ALPN + unsigned char *alpn_protocols; + int alpn_protocols_len; +#endif #ifndef OPENSSL_NO_TLSEXT PyObject *set_hostname; #endif @@ -1460,7 +1469,20 @@ static PyObject *PySSL_selected_npn_protocol(PySSLSocket *self) { if (out == NULL) Py_RETURN_NONE; - return PyUnicode_FromStringAndSize((char *) out, outlen); + return PyUnicode_FromStringAndSize((char *)out, outlen); +} +#endif + +#ifdef HAVE_ALPN +static PyObject *PySSL_selected_alpn_protocol(PySSLSocket *self) { + const unsigned char *out; + unsigned int outlen; + + SSL_get0_alpn_selected(self->ssl, &out, &outlen); + + if (out == NULL) + Py_RETURN_NONE; + return PyUnicode_FromStringAndSize((char *)out, outlen); } #endif @@ -2054,6 +2076,9 @@ static PyMethodDef PySSLMethods[] = { #ifdef OPENSSL_NPN_NEGOTIATED {"selected_npn_protocol", (PyCFunction)PySSL_selected_npn_protocol, METH_NOARGS}, #endif +#ifdef HAVE_ALPN + {"selected_alpn_protocol", (PyCFunction)PySSL_selected_alpn_protocol, METH_NOARGS}, +#endif {"compression", (PyCFunction)PySSL_compression, METH_NOARGS}, {"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS, PySSL_SSLshutdown_doc}, @@ -2159,6 +2184,9 @@ context_new(PyTypeObject *type, PyObject *args, PyObject *kwds) #ifdef OPENSSL_NPN_NEGOTIATED self->npn_protocols = NULL; #endif +#ifdef HAVE_ALPN + self->alpn_protocols = NULL; +#endif #ifndef OPENSSL_NO_TLSEXT self->set_hostname = NULL; #endif @@ -2218,7 +2246,10 @@ context_dealloc(PySSLContext *self) context_clear(self); SSL_CTX_free(self->ctx); #ifdef OPENSSL_NPN_NEGOTIATED - PyMem_Free(self->npn_protocols); + PyMem_FREE(self->npn_protocols); +#endif +#ifdef HAVE_ALPN + PyMem_FREE(self->alpn_protocols); #endif Py_TYPE(self)->tp_free(self); } @@ -2244,6 +2275,23 @@ set_ciphers(PySSLContext *self, PyObject *args) Py_RETURN_NONE; } +static int +do_protocol_selection(unsigned char **out, unsigned char *outlen, + const unsigned char *remote_protocols, unsigned int remote_protocols_len, + unsigned char *our_protocols, unsigned int our_protocols_len) +{ + if (our_protocols == NULL) { + our_protocols = (unsigned char*)""; + our_protocols_len = 0; + } + + SSL_select_next_proto(out, outlen, + remote_protocols, remote_protocols_len, + our_protocols, our_protocols_len); + + return SSL_TLSEXT_ERR_OK; +} + #ifdef OPENSSL_NPN_NEGOTIATED /* this callback gets passed to SSL_CTX_set_next_protos_advertise_cb */ static int @@ -2254,10 +2302,10 @@ _advertiseNPN_cb(SSL *s, PySSLContext *ssl_ctx = (PySSLContext *) args; if (ssl_ctx->npn_protocols == NULL) { - *data = (unsigned char *) ""; + *data = (unsigned char *)""; *len = 0; } else { - *data = (unsigned char *) ssl_ctx->npn_protocols; + *data = ssl_ctx->npn_protocols; *len = ssl_ctx->npn_protocols_len; } @@ -2270,23 +2318,9 @@ _selectNPN_cb(SSL *s, const unsigned char *server, unsigned int server_len, void *args) { - PySSLContext *ssl_ctx = (PySSLContext *) args; - - unsigned char *client = (unsigned char *) ssl_ctx->npn_protocols; - int client_len; - - if (client == NULL) { - client = (unsigned char *) ""; - client_len = 0; - } else { - client_len = ssl_ctx->npn_protocols_len; - } - - SSL_select_next_proto(out, outlen, - server, server_len, - client, client_len); - - return SSL_TLSEXT_ERR_OK; + PySSLContext *ctx = (PySSLContext *)args; + return do_protocol_selection(out, outlen, server, server_len, + ctx->npn_protocols, ctx->npn_protocols_len); } #endif @@ -2329,6 +2363,50 @@ _set_npn_protocols(PySSLContext *self, PyObject *args) #endif } +#ifdef HAVE_ALPN +static int +_selectALPN_cb(SSL *s, + const unsigned char **out, unsigned char *outlen, + const unsigned char *client_protocols, unsigned int client_protocols_len, + void *args) +{ + PySSLContext *ctx = (PySSLContext *)args; + return do_protocol_selection((unsigned char **)out, outlen, + client_protocols, client_protocols_len, + ctx->alpn_protocols, ctx->alpn_protocols_len); +} +#endif + +static PyObject * +_set_alpn_protocols(PySSLContext *self, PyObject *args) +{ +#ifdef HAVE_ALPN + Py_buffer protos; + + if (!PyArg_ParseTuple(args, "y*:set_npn_protocols", &protos)) + return NULL; + + PyMem_FREE(self->alpn_protocols); + self->alpn_protocols = PyMem_Malloc(protos.len); + if (!self->alpn_protocols) + return PyErr_NoMemory(); + memcpy(self->alpn_protocols, protos.buf, protos.len); + self->alpn_protocols_len = protos.len; + PyBuffer_Release(&protos); + + if (SSL_CTX_set_alpn_protos(self->ctx, self->alpn_protocols, self->alpn_protocols_len)) + return PyErr_NoMemory(); + SSL_CTX_set_alpn_select_cb(self->ctx, _selectALPN_cb, self); + + PyBuffer_Release(&protos); + Py_RETURN_NONE; +#else + PyErr_SetString(PyExc_NotImplementedError, + "The ALPN extension requires OpenSSL 1.0.2 or later."); + return NULL; +#endif +} + static PyObject * get_verify_mode(PySSLContext *self, void *c) { @@ -3307,6 +3385,8 @@ static struct PyMethodDef context_methods[] = { METH_VARARGS | METH_KEYWORDS, NULL}, {"set_ciphers", (PyCFunction) set_ciphers, METH_VARARGS, NULL}, + {"_set_alpn_protocols", (PyCFunction) _set_alpn_protocols, + METH_VARARGS, NULL}, {"_set_npn_protocols", (PyCFunction) _set_npn_protocols, METH_VARARGS, NULL}, {"load_cert_chain", (PyCFunction) load_cert_chain, @@ -4502,6 +4582,14 @@ PyInit__ssl(void) Py_INCREF(r); PyModule_AddObject(m, "HAS_NPN", r); +#ifdef HAVE_ALPN + r = Py_True; +#else + r = Py_False; +#endif + Py_INCREF(r); + PyModule_AddObject(m, "HAS_ALPN", r); + /* Mappings for error codes */ err_codes_to_names = PyDict_New(); err_names_to_codes = PyDict_New(); |