From db5aed931f8a617f7b63e773f62db468fe9c5ca1 Mon Sep 17 00:00:00 2001 From: Christian Heimes Date: Wed, 27 May 2020 21:50:06 +0200 Subject: bpo-40791: Use CRYPTO_memcmp() for compare_digest (#20456) hashlib.compare_digest uses OpenSSL's CRYPTO_memcmp() function when OpenSSL is available. Note: The _operator module is a builtin module. I don't want to add libcrypto dependency to libpython. Therefore I duplicated the wrapper function and added a copy to _hashopenssl.c. --- Doc/library/hmac.rst | 5 + Lib/hmac.py | 3 +- Lib/test/test_hmac.py | 88 +++++++++------- .../2020-05-27-18-04-52.bpo-40791.IzpNor.rst | 2 + Modules/_hashopenssl.c | 116 +++++++++++++++++++++ Modules/_operator.c | 2 + Modules/clinic/_hashopenssl.c.h | 42 +++++++- 7 files changed, 221 insertions(+), 37 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2020-05-27-18-04-52.bpo-40791.IzpNor.rst diff --git a/Doc/library/hmac.rst b/Doc/library/hmac.rst index 5ad3484..6f1b59b 100644 --- a/Doc/library/hmac.rst +++ b/Doc/library/hmac.rst @@ -138,6 +138,11 @@ This module also provides the following helper function: .. versionadded:: 3.3 + .. versionchanged:: 3.10 + + The function uses OpenSSL's ``CRYPTO_memcmp()`` internally when + available. + .. seealso:: diff --git a/Lib/hmac.py b/Lib/hmac.py index 54a1ef9..180bc37 100644 --- a/Lib/hmac.py +++ b/Lib/hmac.py @@ -4,14 +4,15 @@ Implements the HMAC algorithm as described by RFC 2104. """ import warnings as _warnings -from _operator import _compare_digest as compare_digest try: import _hashlib as _hashopenssl except ImportError: _hashopenssl = None _openssl_md_meths = None + from _operator import _compare_digest as compare_digest else: _openssl_md_meths = frozenset(_hashopenssl.openssl_md_meth_names) + compare_digest = _hashopenssl.compare_digest import hashlib as _hashlib trans_5C = bytes((x ^ 0x5C) for x in range(256)) diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py index 7a52e39..6daf22c 100644 --- a/Lib/test/test_hmac.py +++ b/Lib/test/test_hmac.py @@ -8,12 +8,16 @@ import warnings from test.support import hashlib_helper +from _operator import _compare_digest as operator_compare_digest + try: from _hashlib import HMAC as C_HMAC from _hashlib import hmac_new as c_hmac_new + from _hashlib import compare_digest as openssl_compare_digest except ImportError: C_HMAC = None c_hmac_new = None + openssl_compare_digest = None def ignore_warning(func): @@ -505,87 +509,101 @@ class CopyTestCase(unittest.TestCase): class CompareDigestTestCase(unittest.TestCase): - def test_compare_digest(self): + def test_hmac_compare_digest(self): + self._test_compare_digest(hmac.compare_digest) + if openssl_compare_digest is not None: + self.assertIs(hmac.compare_digest, openssl_compare_digest) + else: + self.assertIs(hmac.compare_digest, operator_compare_digest) + + def test_operator_compare_digest(self): + self._test_compare_digest(operator_compare_digest) + + @unittest.skipIf(openssl_compare_digest is None, "test requires _hashlib") + def test_openssl_compare_digest(self): + self._test_compare_digest(openssl_compare_digest) + + def _test_compare_digest(self, compare_digest): # Testing input type exception handling a, b = 100, 200 - self.assertRaises(TypeError, hmac.compare_digest, a, b) + self.assertRaises(TypeError, compare_digest, a, b) a, b = 100, b"foobar" - self.assertRaises(TypeError, hmac.compare_digest, a, b) + self.assertRaises(TypeError, compare_digest, a, b) a, b = b"foobar", 200 - self.assertRaises(TypeError, hmac.compare_digest, a, b) + self.assertRaises(TypeError, compare_digest, a, b) a, b = "foobar", b"foobar" - self.assertRaises(TypeError, hmac.compare_digest, a, b) + self.assertRaises(TypeError, compare_digest, a, b) a, b = b"foobar", "foobar" - self.assertRaises(TypeError, hmac.compare_digest, a, b) + self.assertRaises(TypeError, compare_digest, a, b) # Testing bytes of different lengths a, b = b"foobar", b"foo" - self.assertFalse(hmac.compare_digest(a, b)) + self.assertFalse(compare_digest(a, b)) a, b = b"\xde\xad\xbe\xef", b"\xde\xad" - self.assertFalse(hmac.compare_digest(a, b)) + self.assertFalse(compare_digest(a, b)) # Testing bytes of same lengths, different values a, b = b"foobar", b"foobaz" - self.assertFalse(hmac.compare_digest(a, b)) + self.assertFalse(compare_digest(a, b)) a, b = b"\xde\xad\xbe\xef", b"\xab\xad\x1d\xea" - self.assertFalse(hmac.compare_digest(a, b)) + self.assertFalse(compare_digest(a, b)) # Testing bytes of same lengths, same values a, b = b"foobar", b"foobar" - self.assertTrue(hmac.compare_digest(a, b)) + self.assertTrue(compare_digest(a, b)) a, b = b"\xde\xad\xbe\xef", b"\xde\xad\xbe\xef" - self.assertTrue(hmac.compare_digest(a, b)) + self.assertTrue(compare_digest(a, b)) # Testing bytearrays of same lengths, same values a, b = bytearray(b"foobar"), bytearray(b"foobar") - self.assertTrue(hmac.compare_digest(a, b)) + self.assertTrue(compare_digest(a, b)) # Testing bytearrays of different lengths a, b = bytearray(b"foobar"), bytearray(b"foo") - self.assertFalse(hmac.compare_digest(a, b)) + self.assertFalse(compare_digest(a, b)) # Testing bytearrays of same lengths, different values a, b = bytearray(b"foobar"), bytearray(b"foobaz") - self.assertFalse(hmac.compare_digest(a, b)) + self.assertFalse(compare_digest(a, b)) # Testing byte and bytearray of same lengths, same values a, b = bytearray(b"foobar"), b"foobar" - self.assertTrue(hmac.compare_digest(a, b)) - self.assertTrue(hmac.compare_digest(b, a)) + self.assertTrue(compare_digest(a, b)) + self.assertTrue(compare_digest(b, a)) # Testing byte bytearray of different lengths a, b = bytearray(b"foobar"), b"foo" - self.assertFalse(hmac.compare_digest(a, b)) - self.assertFalse(hmac.compare_digest(b, a)) + self.assertFalse(compare_digest(a, b)) + self.assertFalse(compare_digest(b, a)) # Testing byte and bytearray of same lengths, different values a, b = bytearray(b"foobar"), b"foobaz" - self.assertFalse(hmac.compare_digest(a, b)) - self.assertFalse(hmac.compare_digest(b, a)) + self.assertFalse(compare_digest(a, b)) + self.assertFalse(compare_digest(b, a)) # Testing str of same lengths a, b = "foobar", "foobar" - self.assertTrue(hmac.compare_digest(a, b)) + self.assertTrue(compare_digest(a, b)) # Testing str of different lengths a, b = "foo", "foobar" - self.assertFalse(hmac.compare_digest(a, b)) + self.assertFalse(compare_digest(a, b)) # Testing bytes of same lengths, different values a, b = "foobar", "foobaz" - self.assertFalse(hmac.compare_digest(a, b)) + self.assertFalse(compare_digest(a, b)) # Testing error cases a, b = "foobar", b"foobar" - self.assertRaises(TypeError, hmac.compare_digest, a, b) + self.assertRaises(TypeError, compare_digest, a, b) a, b = b"foobar", "foobar" - self.assertRaises(TypeError, hmac.compare_digest, a, b) + self.assertRaises(TypeError, compare_digest, a, b) a, b = b"foobar", 1 - self.assertRaises(TypeError, hmac.compare_digest, a, b) + self.assertRaises(TypeError, compare_digest, a, b) a, b = 100, 200 - self.assertRaises(TypeError, hmac.compare_digest, a, b) + self.assertRaises(TypeError, compare_digest, a, b) a, b = "fooä", "fooä" - self.assertRaises(TypeError, hmac.compare_digest, a, b) + self.assertRaises(TypeError, compare_digest, a, b) # subclasses are supported by ignore __eq__ class mystr(str): @@ -593,22 +611,22 @@ class CompareDigestTestCase(unittest.TestCase): return False a, b = mystr("foobar"), mystr("foobar") - self.assertTrue(hmac.compare_digest(a, b)) + self.assertTrue(compare_digest(a, b)) a, b = mystr("foobar"), "foobar" - self.assertTrue(hmac.compare_digest(a, b)) + self.assertTrue(compare_digest(a, b)) a, b = mystr("foobar"), mystr("foobaz") - self.assertFalse(hmac.compare_digest(a, b)) + self.assertFalse(compare_digest(a, b)) class mybytes(bytes): def __eq__(self, other): return False a, b = mybytes(b"foobar"), mybytes(b"foobar") - self.assertTrue(hmac.compare_digest(a, b)) + self.assertTrue(compare_digest(a, b)) a, b = mybytes(b"foobar"), b"foobar" - self.assertTrue(hmac.compare_digest(a, b)) + self.assertTrue(compare_digest(a, b)) a, b = mybytes(b"foobar"), mybytes(b"foobaz") - self.assertFalse(hmac.compare_digest(a, b)) + self.assertFalse(compare_digest(a, b)) if __name__ == "__main__": diff --git a/Misc/NEWS.d/next/Library/2020-05-27-18-04-52.bpo-40791.IzpNor.rst b/Misc/NEWS.d/next/Library/2020-05-27-18-04-52.bpo-40791.IzpNor.rst new file mode 100644 index 0000000..b88f308 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2020-05-27-18-04-52.bpo-40791.IzpNor.rst @@ -0,0 +1,2 @@ +:func:`hashlib.compare_digest` uses OpenSSL's ``CRYPTO_memcmp()`` function +when OpenSSL is available. diff --git a/Modules/_hashopenssl.c b/Modules/_hashopenssl.c index 0b2ef95..adc8653 100644 --- a/Modules/_hashopenssl.c +++ b/Modules/_hashopenssl.c @@ -21,6 +21,7 @@ /* EVP is the preferred interface to hashing in OpenSSL */ #include #include +#include /* We use the object interface to discover what hashes OpenSSL supports. */ #include #include "openssl/err.h" @@ -1833,6 +1834,120 @@ _hashlib_get_fips_mode_impl(PyObject *module) #endif // !LIBRESSL_VERSION_NUMBER +static int +_tscmp(const unsigned char *a, const unsigned char *b, + Py_ssize_t len_a, Py_ssize_t len_b) +{ + /* loop count depends on length of b. Might leak very little timing + * information if sizes are different. + */ + Py_ssize_t length = len_b; + const void *left = a; + const void *right = b; + int result = 0; + + if (len_a != length) { + left = b; + result = 1; + } + + result |= CRYPTO_memcmp(left, right, length); + + return (result == 0); +} + +/* NOTE: Keep in sync with _operator.c implementation. */ + +/*[clinic input] +_hashlib.compare_digest + + a: object + b: object + / + +Return 'a == b'. + +This function uses an approach designed to prevent +timing analysis, making it appropriate for cryptography. + +a and b must both be of the same type: either str (ASCII only), +or any bytes-like object. + +Note: If a and b are of different lengths, or if an error occurs, +a timing attack could theoretically reveal information about the +types and lengths of a and b--but not their values. +[clinic start generated code]*/ + +static PyObject * +_hashlib_compare_digest_impl(PyObject *module, PyObject *a, PyObject *b) +/*[clinic end generated code: output=6f1c13927480aed9 input=9c40c6e566ca12f5]*/ +{ + int rc; + + /* ASCII unicode string */ + if(PyUnicode_Check(a) && PyUnicode_Check(b)) { + if (PyUnicode_READY(a) == -1 || PyUnicode_READY(b) == -1) { + return NULL; + } + if (!PyUnicode_IS_ASCII(a) || !PyUnicode_IS_ASCII(b)) { + PyErr_SetString(PyExc_TypeError, + "comparing strings with non-ASCII characters is " + "not supported"); + return NULL; + } + + rc = _tscmp(PyUnicode_DATA(a), + PyUnicode_DATA(b), + PyUnicode_GET_LENGTH(a), + PyUnicode_GET_LENGTH(b)); + } + /* fallback to buffer interface for bytes, bytesarray and other */ + else { + Py_buffer view_a; + Py_buffer view_b; + + if (PyObject_CheckBuffer(a) == 0 && PyObject_CheckBuffer(b) == 0) { + PyErr_Format(PyExc_TypeError, + "unsupported operand types(s) or combination of types: " + "'%.100s' and '%.100s'", + Py_TYPE(a)->tp_name, Py_TYPE(b)->tp_name); + return NULL; + } + + if (PyObject_GetBuffer(a, &view_a, PyBUF_SIMPLE) == -1) { + return NULL; + } + if (view_a.ndim > 1) { + PyErr_SetString(PyExc_BufferError, + "Buffer must be single dimension"); + PyBuffer_Release(&view_a); + return NULL; + } + + if (PyObject_GetBuffer(b, &view_b, PyBUF_SIMPLE) == -1) { + PyBuffer_Release(&view_a); + return NULL; + } + if (view_b.ndim > 1) { + PyErr_SetString(PyExc_BufferError, + "Buffer must be single dimension"); + PyBuffer_Release(&view_a); + PyBuffer_Release(&view_b); + return NULL; + } + + rc = _tscmp((const unsigned char*)view_a.buf, + (const unsigned char*)view_b.buf, + view_a.len, + view_b.len); + + PyBuffer_Release(&view_a); + PyBuffer_Release(&view_b); + } + + return PyBool_FromLong(rc); +} + /* List of functions exported by this module */ static struct PyMethodDef EVP_functions[] = { @@ -1840,6 +1955,7 @@ static struct PyMethodDef EVP_functions[] = { PBKDF2_HMAC_METHODDEF _HASHLIB_SCRYPT_METHODDEF _HASHLIB_GET_FIPS_MODE_METHODDEF + _HASHLIB_COMPARE_DIGEST_METHODDEF _HASHLIB_HMAC_SINGLESHOT_METHODDEF _HASHLIB_HMAC_NEW_METHODDEF _HASHLIB_OPENSSL_MD5_METHODDEF diff --git a/Modules/_operator.c b/Modules/_operator.c index 19026b6..8a54829 100644 --- a/Modules/_operator.c +++ b/Modules/_operator.c @@ -785,6 +785,8 @@ _operator_length_hint_impl(PyObject *module, PyObject *obj, return PyObject_LengthHint(obj, default_value); } +/* NOTE: Keep in sync with _hashopenssl.c implementation. */ + /*[clinic input] _operator._compare_digest = _operator.eq diff --git a/Modules/clinic/_hashopenssl.c.h b/Modules/clinic/_hashopenssl.c.h index 619cb1c..51ae240 100644 --- a/Modules/clinic/_hashopenssl.c.h +++ b/Modules/clinic/_hashopenssl.c.h @@ -1338,6 +1338,46 @@ exit: #endif /* !defined(LIBRESSL_VERSION_NUMBER) */ +PyDoc_STRVAR(_hashlib_compare_digest__doc__, +"compare_digest($module, a, b, /)\n" +"--\n" +"\n" +"Return \'a == b\'.\n" +"\n" +"This function uses an approach designed to prevent\n" +"timing analysis, making it appropriate for cryptography.\n" +"\n" +"a and b must both be of the same type: either str (ASCII only),\n" +"or any bytes-like object.\n" +"\n" +"Note: If a and b are of different lengths, or if an error occurs,\n" +"a timing attack could theoretically reveal information about the\n" +"types and lengths of a and b--but not their values."); + +#define _HASHLIB_COMPARE_DIGEST_METHODDEF \ + {"compare_digest", (PyCFunction)(void(*)(void))_hashlib_compare_digest, METH_FASTCALL, _hashlib_compare_digest__doc__}, + +static PyObject * +_hashlib_compare_digest_impl(PyObject *module, PyObject *a, PyObject *b); + +static PyObject * +_hashlib_compare_digest(PyObject *module, PyObject *const *args, Py_ssize_t nargs) +{ + PyObject *return_value = NULL; + PyObject *a; + PyObject *b; + + if (!_PyArg_CheckPositional("compare_digest", nargs, 2, 2)) { + goto exit; + } + a = args[0]; + b = args[1]; + return_value = _hashlib_compare_digest_impl(module, a, b); + +exit: + return return_value; +} + #ifndef EVPXOF_DIGEST_METHODDEF #define EVPXOF_DIGEST_METHODDEF #endif /* !defined(EVPXOF_DIGEST_METHODDEF) */ @@ -1377,4 +1417,4 @@ exit: #ifndef _HASHLIB_GET_FIPS_MODE_METHODDEF #define _HASHLIB_GET_FIPS_MODE_METHODDEF #endif /* !defined(_HASHLIB_GET_FIPS_MODE_METHODDEF) */ -/*[clinic end generated code: output=d8dddcd85fb11dde input=a9049054013a1b77]*/ +/*[clinic end generated code: output=95447a60132f039e input=a9049054013a1b77]*/ -- cgit v0.12