summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_math.py121
-rw-r--r--Modules/mathmodule.c56
2 files changed, 137 insertions, 40 deletions
diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py
index 856b1e8..cb05dee 100644
--- a/Lib/test/test_math.py
+++ b/Lib/test/test_math.py
@@ -1595,6 +1595,92 @@ class MathTests(unittest.TestCase):
self.fail('Failures in test_mtestfile:\n ' +
'\n '.join(failures))
+ def test_prod(self):
+ prod = math.prod
+ self.assertEqual(prod([]), 1)
+ self.assertEqual(prod([], start=5), 5)
+ self.assertEqual(prod(list(range(2,8))), 5040)
+ self.assertEqual(prod(iter(list(range(2,8)))), 5040)
+ self.assertEqual(prod(range(1, 10), start=10), 3628800)
+
+ self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
+ self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
+ self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
+ self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
+
+ # Test overflow in fast-path for integers
+ self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
+ # Test overflow in fast-path for floats
+ self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
+
+ self.assertRaises(TypeError, prod)
+ self.assertRaises(TypeError, prod, 42)
+ self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
+ self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
+ self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
+ values = [bytearray(b'a'), bytearray(b'b')]
+ self.assertRaises(TypeError, prod, values, bytearray(b''))
+ self.assertRaises(TypeError, prod, [[1], [2], [3]])
+ self.assertRaises(TypeError, prod, [{2:3}])
+ self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
+ self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
+ with self.assertRaises(TypeError):
+ prod([10, 20], [30, 40]) # start is a keyword-only argument
+
+ self.assertEqual(prod([0, 1, 2, 3]), 0)
+ self.assertEqual(prod([1, 0, 2, 3]), 0)
+ self.assertEqual(prod([1, 2, 3, 0]), 0)
+
+ def _naive_prod(iterable, start=1):
+ for elem in iterable:
+ start *= elem
+ return start
+
+ # Big integers
+
+ iterable = range(1, 10000)
+ self.assertEqual(prod(iterable), _naive_prod(iterable))
+ iterable = range(-10000, -1)
+ self.assertEqual(prod(iterable), _naive_prod(iterable))
+ iterable = range(-1000, 1000)
+ self.assertEqual(prod(iterable), 0)
+
+ # Big floats
+
+ iterable = [float(x) for x in range(1, 1000)]
+ self.assertEqual(prod(iterable), _naive_prod(iterable))
+ iterable = [float(x) for x in range(-1000, -1)]
+ self.assertEqual(prod(iterable), _naive_prod(iterable))
+ iterable = [float(x) for x in range(-1000, 1000)]
+ self.assertIsNaN(prod(iterable))
+
+ # Float tests
+
+ self.assertIsNaN(prod([1, 2, 3, float("nan"), 2, 3]))
+ self.assertIsNaN(prod([1, 0, float("nan"), 2, 3]))
+ self.assertIsNaN(prod([1, float("nan"), 0, 3]))
+ self.assertIsNaN(prod([1, float("inf"), float("nan"),3]))
+ self.assertIsNaN(prod([1, float("-inf"), float("nan"),3]))
+ self.assertIsNaN(prod([1, float("nan"), float("inf"),3]))
+ self.assertIsNaN(prod([1, float("nan"), float("-inf"),3]))
+
+ self.assertEqual(prod([1, 2, 3, float('inf'),-3,4]), float('-inf'))
+ self.assertEqual(prod([1, 2, 3, float('-inf'),-3,4]), float('inf'))
+
+ self.assertIsNaN(prod([1,2,0,float('inf'), -3, 4]))
+ self.assertIsNaN(prod([1,2,0,float('-inf'), -3, 4]))
+ self.assertIsNaN(prod([1, 2, 3, float('inf'), -3, 0, 3]))
+ self.assertIsNaN(prod([1, 2, 3, float('-inf'), -3, 0, 2]))
+
+ # Type preservation
+
+ self.assertEqual(type(prod([1, 2, 3, 4, 5, 6])), int)
+ self.assertEqual(type(prod([1, 2.0, 3, 4, 5, 6])), float)
+ self.assertEqual(type(prod(range(1, 10000))), int)
+ self.assertEqual(type(prod(range(1, 10000), start=1.0)), float)
+ self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])),
+ decimal.Decimal)
+
# Custom assertions.
def assertIsNaN(self, value):
@@ -1724,41 +1810,6 @@ class IsCloseTests(unittest.TestCase):
self.assertAllClose(fraction_examples, rel_tol=1e-8)
self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
- def test_prod(self):
- prod = math.prod
- self.assertEqual(prod([]), 1)
- self.assertEqual(prod([], start=5), 5)
- self.assertEqual(prod(list(range(2,8))), 5040)
- self.assertEqual(prod(iter(list(range(2,8)))), 5040)
- self.assertEqual(prod(range(1, 10), start=10), 3628800)
-
- self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
- self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
- self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
- self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
-
- # Test overflow in fast-path for integers
- self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
- # Test overflow in fast-path for floats
- self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
-
- self.assertRaises(TypeError, prod)
- self.assertRaises(TypeError, prod, 42)
- self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
- self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
- self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
- values = [bytearray(b'a'), bytearray(b'b')]
- self.assertRaises(TypeError, prod, values, bytearray(b''))
- self.assertRaises(TypeError, prod, [[1], [2], [3]])
- self.assertRaises(TypeError, prod, [{2:3}])
- self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
- self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
- with self.assertRaises(TypeError):
- prod([10, 20], [30, 40]) # start is a keyword-only argument
-
- self.assertEqual(prod([0, 1, 2, 3]), 0)
- self.assertEqual(prod([1, 0, 2, 3]), 0)
- self.assertEqual(prod(range(10)), 0)
def test_main():
from doctest import DocFileSuite
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c
index fd0eb32..ba84232 100644
--- a/Modules/mathmodule.c
+++ b/Modules/mathmodule.c
@@ -2493,6 +2493,55 @@ math_isclose_impl(PyObject *module, double a, double b, double rel_tol,
(diff <= abs_tol));
}
+static inline int
+_check_long_mult_overflow(long a, long b) {
+
+ /* From Python2's int_mul code:
+
+ Integer overflow checking for * is painful: Python tried a couple ways, but
+ they didn't work on all platforms, or failed in endcases (a product of
+ -sys.maxint-1 has been a particular pain).
+
+ Here's another way:
+
+ The native long product x*y is either exactly right or *way* off, being
+ just the last n bits of the true product, where n is the number of bits
+ in a long (the delivered product is the true product plus i*2**n for
+ some integer i).
+
+ The native double product (double)x * (double)y is subject to three
+ rounding errors: on a sizeof(long)==8 box, each cast to double can lose
+ info, and even on a sizeof(long)==4 box, the multiplication can lose info.
+ But, unlike the native long product, it's not in *range* trouble: even
+ if sizeof(long)==32 (256-bit longs), the product easily fits in the
+ dynamic range of a double. So the leading 50 (or so) bits of the double
+ product are correct.
+
+ We check these two ways against each other, and declare victory if they're
+ approximately the same. Else, because the native long product is the only
+ one that can lose catastrophic amounts of information, it's the native long
+ product that must have overflowed.
+
+ */
+
+ long longprod = (long)((unsigned long)a * b);
+ double doubleprod = (double)a * (double)b;
+ double doubled_longprod = (double)longprod;
+
+ if (doubled_longprod == doubleprod) {
+ return 0;
+ }
+
+ const double diff = doubled_longprod - doubleprod;
+ const double absdiff = diff >= 0.0 ? diff : -diff;
+ const double absprod = doubleprod >= 0.0 ? doubleprod : -doubleprod;
+
+ if (32.0 * absdiff <= absprod) {
+ return 0;
+ }
+
+ return 1;
+}
/*[clinic input]
math.prod
@@ -2558,11 +2607,8 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
}
if (PyLong_CheckExact(item)) {
long b = PyLong_AsLongAndOverflow(item, &overflow);
- long x = i_result * b;
- /* Continue if there is no overflow */
- if (overflow == 0
- && x < LONG_MAX && x > LONG_MIN
- && !(b != 0 && x / b != i_result)) {
+ if (overflow == 0 && !_check_long_mult_overflow(i_result, b)) {
+ long x = i_result * b;
i_result = x;
Py_DECREF(item);
continue;