From 0411411c6b16a574144dfb59a7780b057ca8e750 Mon Sep 17 00:00:00 2001 From: Pablo Galindo Date: Sat, 9 Mar 2019 19:18:08 +0000 Subject: Rework integer overflow path in math.prod and add more tests (GH-11809) The overflow check was relying on undefined behaviour as it was using the result of the multiplication to do the check, and once the overflow has already happened, any operation on the result is undefined behaviour. Some extra checks that exercise code paths related to this are also added. --- Lib/test/test_math.py | 121 +++++++++++++++++++++++++++++++++++--------------- Modules/mathmodule.c | 56 ++++++++++++++++++++--- 2 files changed, 137 insertions(+), 40 deletions(-) diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py index 856b1e8..cb05dee 100644 --- a/Lib/test/test_math.py +++ b/Lib/test/test_math.py @@ -1595,6 +1595,92 @@ class MathTests(unittest.TestCase): self.fail('Failures in test_mtestfile:\n ' + '\n '.join(failures)) + def test_prod(self): + prod = math.prod + self.assertEqual(prod([]), 1) + self.assertEqual(prod([], start=5), 5) + self.assertEqual(prod(list(range(2,8))), 5040) + self.assertEqual(prod(iter(list(range(2,8)))), 5040) + self.assertEqual(prod(range(1, 10), start=10), 3628800) + + self.assertEqual(prod([1, 2, 3, 4, 5]), 120) + self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0) + self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0) + self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0) + + # Test overflow in fast-path for integers + self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32) + # Test overflow in fast-path for floats + self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32)) + + self.assertRaises(TypeError, prod) + self.assertRaises(TypeError, prod, 42) + self.assertRaises(TypeError, prod, ['a', 'b', 'c']) + self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '') + self.assertRaises(TypeError, prod, [b'a', b'c'], b'') + values = [bytearray(b'a'), bytearray(b'b')] + self.assertRaises(TypeError, prod, values, bytearray(b'')) + self.assertRaises(TypeError, prod, [[1], [2], [3]]) + self.assertRaises(TypeError, prod, [{2:3}]) + self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3}) + self.assertRaises(TypeError, prod, [[1], [2], [3]], []) + with self.assertRaises(TypeError): + prod([10, 20], [30, 40]) # start is a keyword-only argument + + self.assertEqual(prod([0, 1, 2, 3]), 0) + self.assertEqual(prod([1, 0, 2, 3]), 0) + self.assertEqual(prod([1, 2, 3, 0]), 0) + + def _naive_prod(iterable, start=1): + for elem in iterable: + start *= elem + return start + + # Big integers + + iterable = range(1, 10000) + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = range(-10000, -1) + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = range(-1000, 1000) + self.assertEqual(prod(iterable), 0) + + # Big floats + + iterable = [float(x) for x in range(1, 1000)] + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = [float(x) for x in range(-1000, -1)] + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = [float(x) for x in range(-1000, 1000)] + self.assertIsNaN(prod(iterable)) + + # Float tests + + self.assertIsNaN(prod([1, 2, 3, float("nan"), 2, 3])) + self.assertIsNaN(prod([1, 0, float("nan"), 2, 3])) + self.assertIsNaN(prod([1, float("nan"), 0, 3])) + self.assertIsNaN(prod([1, float("inf"), float("nan"),3])) + self.assertIsNaN(prod([1, float("-inf"), float("nan"),3])) + self.assertIsNaN(prod([1, float("nan"), float("inf"),3])) + self.assertIsNaN(prod([1, float("nan"), float("-inf"),3])) + + self.assertEqual(prod([1, 2, 3, float('inf'),-3,4]), float('-inf')) + self.assertEqual(prod([1, 2, 3, float('-inf'),-3,4]), float('inf')) + + self.assertIsNaN(prod([1,2,0,float('inf'), -3, 4])) + self.assertIsNaN(prod([1,2,0,float('-inf'), -3, 4])) + self.assertIsNaN(prod([1, 2, 3, float('inf'), -3, 0, 3])) + self.assertIsNaN(prod([1, 2, 3, float('-inf'), -3, 0, 2])) + + # Type preservation + + self.assertEqual(type(prod([1, 2, 3, 4, 5, 6])), int) + self.assertEqual(type(prod([1, 2.0, 3, 4, 5, 6])), float) + self.assertEqual(type(prod(range(1, 10000))), int) + self.assertEqual(type(prod(range(1, 10000), start=1.0)), float) + self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])), + decimal.Decimal) + # Custom assertions. def assertIsNaN(self, value): @@ -1724,41 +1810,6 @@ class IsCloseTests(unittest.TestCase): self.assertAllClose(fraction_examples, rel_tol=1e-8) self.assertAllNotClose(fraction_examples, rel_tol=1e-9) - def test_prod(self): - prod = math.prod - self.assertEqual(prod([]), 1) - self.assertEqual(prod([], start=5), 5) - self.assertEqual(prod(list(range(2,8))), 5040) - self.assertEqual(prod(iter(list(range(2,8)))), 5040) - self.assertEqual(prod(range(1, 10), start=10), 3628800) - - self.assertEqual(prod([1, 2, 3, 4, 5]), 120) - self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0) - self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0) - self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0) - - # Test overflow in fast-path for integers - self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32) - # Test overflow in fast-path for floats - self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32)) - - self.assertRaises(TypeError, prod) - self.assertRaises(TypeError, prod, 42) - self.assertRaises(TypeError, prod, ['a', 'b', 'c']) - self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '') - self.assertRaises(TypeError, prod, [b'a', b'c'], b'') - values = [bytearray(b'a'), bytearray(b'b')] - self.assertRaises(TypeError, prod, values, bytearray(b'')) - self.assertRaises(TypeError, prod, [[1], [2], [3]]) - self.assertRaises(TypeError, prod, [{2:3}]) - self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3}) - self.assertRaises(TypeError, prod, [[1], [2], [3]], []) - with self.assertRaises(TypeError): - prod([10, 20], [30, 40]) # start is a keyword-only argument - - self.assertEqual(prod([0, 1, 2, 3]), 0) - self.assertEqual(prod([1, 0, 2, 3]), 0) - self.assertEqual(prod(range(10)), 0) def test_main(): from doctest import DocFileSuite diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index fd0eb32..ba84232 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -2493,6 +2493,55 @@ math_isclose_impl(PyObject *module, double a, double b, double rel_tol, (diff <= abs_tol)); } +static inline int +_check_long_mult_overflow(long a, long b) { + + /* From Python2's int_mul code: + + Integer overflow checking for * is painful: Python tried a couple ways, but + they didn't work on all platforms, or failed in endcases (a product of + -sys.maxint-1 has been a particular pain). + + Here's another way: + + The native long product x*y is either exactly right or *way* off, being + just the last n bits of the true product, where n is the number of bits + in a long (the delivered product is the true product plus i*2**n for + some integer i). + + The native double product (double)x * (double)y is subject to three + rounding errors: on a sizeof(long)==8 box, each cast to double can lose + info, and even on a sizeof(long)==4 box, the multiplication can lose info. + But, unlike the native long product, it's not in *range* trouble: even + if sizeof(long)==32 (256-bit longs), the product easily fits in the + dynamic range of a double. So the leading 50 (or so) bits of the double + product are correct. + + We check these two ways against each other, and declare victory if they're + approximately the same. Else, because the native long product is the only + one that can lose catastrophic amounts of information, it's the native long + product that must have overflowed. + + */ + + long longprod = (long)((unsigned long)a * b); + double doubleprod = (double)a * (double)b; + double doubled_longprod = (double)longprod; + + if (doubled_longprod == doubleprod) { + return 0; + } + + const double diff = doubled_longprod - doubleprod; + const double absdiff = diff >= 0.0 ? diff : -diff; + const double absprod = doubleprod >= 0.0 ? doubleprod : -doubleprod; + + if (32.0 * absdiff <= absprod) { + return 0; + } + + return 1; +} /*[clinic input] math.prod @@ -2558,11 +2607,8 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start) } if (PyLong_CheckExact(item)) { long b = PyLong_AsLongAndOverflow(item, &overflow); - long x = i_result * b; - /* Continue if there is no overflow */ - if (overflow == 0 - && x < LONG_MAX && x > LONG_MIN - && !(b != 0 && x / b != i_result)) { + if (overflow == 0 && !_check_long_mult_overflow(i_result, b)) { + long x = i_result * b; i_result = x; Py_DECREF(item); continue; -- cgit v0.12