diff options
-rw-r--r-- | Lib/test/test_math.py | 121 | ||||
-rw-r--r-- | Modules/mathmodule.c | 56 |
2 files changed, 137 insertions, 40 deletions
diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py index 856b1e8..cb05dee 100644 --- a/Lib/test/test_math.py +++ b/Lib/test/test_math.py @@ -1595,6 +1595,92 @@ class MathTests(unittest.TestCase): self.fail('Failures in test_mtestfile:\n ' + '\n '.join(failures)) + def test_prod(self): + prod = math.prod + self.assertEqual(prod([]), 1) + self.assertEqual(prod([], start=5), 5) + self.assertEqual(prod(list(range(2,8))), 5040) + self.assertEqual(prod(iter(list(range(2,8)))), 5040) + self.assertEqual(prod(range(1, 10), start=10), 3628800) + + self.assertEqual(prod([1, 2, 3, 4, 5]), 120) + self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0) + self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0) + self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0) + + # Test overflow in fast-path for integers + self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32) + # Test overflow in fast-path for floats + self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32)) + + self.assertRaises(TypeError, prod) + self.assertRaises(TypeError, prod, 42) + self.assertRaises(TypeError, prod, ['a', 'b', 'c']) + self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '') + self.assertRaises(TypeError, prod, [b'a', b'c'], b'') + values = [bytearray(b'a'), bytearray(b'b')] + self.assertRaises(TypeError, prod, values, bytearray(b'')) + self.assertRaises(TypeError, prod, [[1], [2], [3]]) + self.assertRaises(TypeError, prod, [{2:3}]) + self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3}) + self.assertRaises(TypeError, prod, [[1], [2], [3]], []) + with self.assertRaises(TypeError): + prod([10, 20], [30, 40]) # start is a keyword-only argument + + self.assertEqual(prod([0, 1, 2, 3]), 0) + self.assertEqual(prod([1, 0, 2, 3]), 0) + self.assertEqual(prod([1, 2, 3, 0]), 0) + + def _naive_prod(iterable, start=1): + for elem in iterable: + start *= elem + return start + + # Big integers + + iterable = range(1, 10000) + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = range(-10000, -1) + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = range(-1000, 1000) + self.assertEqual(prod(iterable), 0) + + # Big floats + + iterable = [float(x) for x in range(1, 1000)] + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = [float(x) for x in range(-1000, -1)] + self.assertEqual(prod(iterable), _naive_prod(iterable)) + iterable = [float(x) for x in range(-1000, 1000)] + self.assertIsNaN(prod(iterable)) + + # Float tests + + self.assertIsNaN(prod([1, 2, 3, float("nan"), 2, 3])) + self.assertIsNaN(prod([1, 0, float("nan"), 2, 3])) + self.assertIsNaN(prod([1, float("nan"), 0, 3])) + self.assertIsNaN(prod([1, float("inf"), float("nan"),3])) + self.assertIsNaN(prod([1, float("-inf"), float("nan"),3])) + self.assertIsNaN(prod([1, float("nan"), float("inf"),3])) + self.assertIsNaN(prod([1, float("nan"), float("-inf"),3])) + + self.assertEqual(prod([1, 2, 3, float('inf'),-3,4]), float('-inf')) + self.assertEqual(prod([1, 2, 3, float('-inf'),-3,4]), float('inf')) + + self.assertIsNaN(prod([1,2,0,float('inf'), -3, 4])) + self.assertIsNaN(prod([1,2,0,float('-inf'), -3, 4])) + self.assertIsNaN(prod([1, 2, 3, float('inf'), -3, 0, 3])) + self.assertIsNaN(prod([1, 2, 3, float('-inf'), -3, 0, 2])) + + # Type preservation + + self.assertEqual(type(prod([1, 2, 3, 4, 5, 6])), int) + self.assertEqual(type(prod([1, 2.0, 3, 4, 5, 6])), float) + self.assertEqual(type(prod(range(1, 10000))), int) + self.assertEqual(type(prod(range(1, 10000), start=1.0)), float) + self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])), + decimal.Decimal) + # Custom assertions. def assertIsNaN(self, value): @@ -1724,41 +1810,6 @@ class IsCloseTests(unittest.TestCase): self.assertAllClose(fraction_examples, rel_tol=1e-8) self.assertAllNotClose(fraction_examples, rel_tol=1e-9) - def test_prod(self): - prod = math.prod - self.assertEqual(prod([]), 1) - self.assertEqual(prod([], start=5), 5) - self.assertEqual(prod(list(range(2,8))), 5040) - self.assertEqual(prod(iter(list(range(2,8)))), 5040) - self.assertEqual(prod(range(1, 10), start=10), 3628800) - - self.assertEqual(prod([1, 2, 3, 4, 5]), 120) - self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0) - self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0) - self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0) - - # Test overflow in fast-path for integers - self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32) - # Test overflow in fast-path for floats - self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32)) - - self.assertRaises(TypeError, prod) - self.assertRaises(TypeError, prod, 42) - self.assertRaises(TypeError, prod, ['a', 'b', 'c']) - self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '') - self.assertRaises(TypeError, prod, [b'a', b'c'], b'') - values = [bytearray(b'a'), bytearray(b'b')] - self.assertRaises(TypeError, prod, values, bytearray(b'')) - self.assertRaises(TypeError, prod, [[1], [2], [3]]) - self.assertRaises(TypeError, prod, [{2:3}]) - self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3}) - self.assertRaises(TypeError, prod, [[1], [2], [3]], []) - with self.assertRaises(TypeError): - prod([10, 20], [30, 40]) # start is a keyword-only argument - - self.assertEqual(prod([0, 1, 2, 3]), 0) - self.assertEqual(prod([1, 0, 2, 3]), 0) - self.assertEqual(prod(range(10)), 0) def test_main(): from doctest import DocFileSuite diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index fd0eb32..ba84232 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -2493,6 +2493,55 @@ math_isclose_impl(PyObject *module, double a, double b, double rel_tol, (diff <= abs_tol)); } +static inline int +_check_long_mult_overflow(long a, long b) { + + /* From Python2's int_mul code: + + Integer overflow checking for * is painful: Python tried a couple ways, but + they didn't work on all platforms, or failed in endcases (a product of + -sys.maxint-1 has been a particular pain). + + Here's another way: + + The native long product x*y is either exactly right or *way* off, being + just the last n bits of the true product, where n is the number of bits + in a long (the delivered product is the true product plus i*2**n for + some integer i). + + The native double product (double)x * (double)y is subject to three + rounding errors: on a sizeof(long)==8 box, each cast to double can lose + info, and even on a sizeof(long)==4 box, the multiplication can lose info. + But, unlike the native long product, it's not in *range* trouble: even + if sizeof(long)==32 (256-bit longs), the product easily fits in the + dynamic range of a double. So the leading 50 (or so) bits of the double + product are correct. + + We check these two ways against each other, and declare victory if they're + approximately the same. Else, because the native long product is the only + one that can lose catastrophic amounts of information, it's the native long + product that must have overflowed. + + */ + + long longprod = (long)((unsigned long)a * b); + double doubleprod = (double)a * (double)b; + double doubled_longprod = (double)longprod; + + if (doubled_longprod == doubleprod) { + return 0; + } + + const double diff = doubled_longprod - doubleprod; + const double absdiff = diff >= 0.0 ? diff : -diff; + const double absprod = doubleprod >= 0.0 ? doubleprod : -doubleprod; + + if (32.0 * absdiff <= absprod) { + return 0; + } + + return 1; +} /*[clinic input] math.prod @@ -2558,11 +2607,8 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start) } if (PyLong_CheckExact(item)) { long b = PyLong_AsLongAndOverflow(item, &overflow); - long x = i_result * b; - /* Continue if there is no overflow */ - if (overflow == 0 - && x < LONG_MAX && x > LONG_MIN - && !(b != 0 && x / b != i_result)) { + if (overflow == 0 && !_check_long_mult_overflow(i_result, b)) { + long x = i_result * b; i_result = x; Py_DECREF(item); continue; |