summaryrefslogtreecommitdiffstats
path: root/Lib/test
diff options
context:
space:
mode:
authorPablo Galindo <Pablogsal@gmail.com>2019-03-09 19:18:08 (GMT)
committerGitHub <noreply@github.com>2019-03-09 19:18:08 (GMT)
commit0411411c6b16a574144dfb59a7780b057ca8e750 (patch)
treec776766542eb4e21b4e462dc8b84e12a7ab40a3d /Lib/test
parent62fa51f1216e788310d3118f4259f1b4b1e529fe (diff)
downloadcpython-0411411c6b16a574144dfb59a7780b057ca8e750.zip
cpython-0411411c6b16a574144dfb59a7780b057ca8e750.tar.gz
cpython-0411411c6b16a574144dfb59a7780b057ca8e750.tar.bz2
Rework integer overflow path in math.prod and add more tests (GH-11809)
The overflow check was relying on undefined behaviour as it was using the result of the multiplication to do the check, and once the overflow has already happened, any operation on the result is undefined behaviour. Some extra checks that exercise code paths related to this are also added.
Diffstat (limited to 'Lib/test')
-rw-r--r--Lib/test/test_math.py121
1 files changed, 86 insertions, 35 deletions
diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py
index 856b1e8..cb05dee 100644
--- a/Lib/test/test_math.py
+++ b/Lib/test/test_math.py
@@ -1595,6 +1595,92 @@ class MathTests(unittest.TestCase):
self.fail('Failures in test_mtestfile:\n ' +
'\n '.join(failures))
+ def test_prod(self):
+ prod = math.prod
+ self.assertEqual(prod([]), 1)
+ self.assertEqual(prod([], start=5), 5)
+ self.assertEqual(prod(list(range(2,8))), 5040)
+ self.assertEqual(prod(iter(list(range(2,8)))), 5040)
+ self.assertEqual(prod(range(1, 10), start=10), 3628800)
+
+ self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
+ self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
+ self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
+ self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
+
+ # Test overflow in fast-path for integers
+ self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
+ # Test overflow in fast-path for floats
+ self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
+
+ self.assertRaises(TypeError, prod)
+ self.assertRaises(TypeError, prod, 42)
+ self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
+ self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
+ self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
+ values = [bytearray(b'a'), bytearray(b'b')]
+ self.assertRaises(TypeError, prod, values, bytearray(b''))
+ self.assertRaises(TypeError, prod, [[1], [2], [3]])
+ self.assertRaises(TypeError, prod, [{2:3}])
+ self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
+ self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
+ with self.assertRaises(TypeError):
+ prod([10, 20], [30, 40]) # start is a keyword-only argument
+
+ self.assertEqual(prod([0, 1, 2, 3]), 0)
+ self.assertEqual(prod([1, 0, 2, 3]), 0)
+ self.assertEqual(prod([1, 2, 3, 0]), 0)
+
+ def _naive_prod(iterable, start=1):
+ for elem in iterable:
+ start *= elem
+ return start
+
+ # Big integers
+
+ iterable = range(1, 10000)
+ self.assertEqual(prod(iterable), _naive_prod(iterable))
+ iterable = range(-10000, -1)
+ self.assertEqual(prod(iterable), _naive_prod(iterable))
+ iterable = range(-1000, 1000)
+ self.assertEqual(prod(iterable), 0)
+
+ # Big floats
+
+ iterable = [float(x) for x in range(1, 1000)]
+ self.assertEqual(prod(iterable), _naive_prod(iterable))
+ iterable = [float(x) for x in range(-1000, -1)]
+ self.assertEqual(prod(iterable), _naive_prod(iterable))
+ iterable = [float(x) for x in range(-1000, 1000)]
+ self.assertIsNaN(prod(iterable))
+
+ # Float tests
+
+ self.assertIsNaN(prod([1, 2, 3, float("nan"), 2, 3]))
+ self.assertIsNaN(prod([1, 0, float("nan"), 2, 3]))
+ self.assertIsNaN(prod([1, float("nan"), 0, 3]))
+ self.assertIsNaN(prod([1, float("inf"), float("nan"),3]))
+ self.assertIsNaN(prod([1, float("-inf"), float("nan"),3]))
+ self.assertIsNaN(prod([1, float("nan"), float("inf"),3]))
+ self.assertIsNaN(prod([1, float("nan"), float("-inf"),3]))
+
+ self.assertEqual(prod([1, 2, 3, float('inf'),-3,4]), float('-inf'))
+ self.assertEqual(prod([1, 2, 3, float('-inf'),-3,4]), float('inf'))
+
+ self.assertIsNaN(prod([1,2,0,float('inf'), -3, 4]))
+ self.assertIsNaN(prod([1,2,0,float('-inf'), -3, 4]))
+ self.assertIsNaN(prod([1, 2, 3, float('inf'), -3, 0, 3]))
+ self.assertIsNaN(prod([1, 2, 3, float('-inf'), -3, 0, 2]))
+
+ # Type preservation
+
+ self.assertEqual(type(prod([1, 2, 3, 4, 5, 6])), int)
+ self.assertEqual(type(prod([1, 2.0, 3, 4, 5, 6])), float)
+ self.assertEqual(type(prod(range(1, 10000))), int)
+ self.assertEqual(type(prod(range(1, 10000), start=1.0)), float)
+ self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])),
+ decimal.Decimal)
+
# Custom assertions.
def assertIsNaN(self, value):
@@ -1724,41 +1810,6 @@ class IsCloseTests(unittest.TestCase):
self.assertAllClose(fraction_examples, rel_tol=1e-8)
self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
- def test_prod(self):
- prod = math.prod
- self.assertEqual(prod([]), 1)
- self.assertEqual(prod([], start=5), 5)
- self.assertEqual(prod(list(range(2,8))), 5040)
- self.assertEqual(prod(iter(list(range(2,8)))), 5040)
- self.assertEqual(prod(range(1, 10), start=10), 3628800)
-
- self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
- self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
- self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
- self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
-
- # Test overflow in fast-path for integers
- self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
- # Test overflow in fast-path for floats
- self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
-
- self.assertRaises(TypeError, prod)
- self.assertRaises(TypeError, prod, 42)
- self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
- self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
- self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
- values = [bytearray(b'a'), bytearray(b'b')]
- self.assertRaises(TypeError, prod, values, bytearray(b''))
- self.assertRaises(TypeError, prod, [[1], [2], [3]])
- self.assertRaises(TypeError, prod, [{2:3}])
- self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
- self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
- with self.assertRaises(TypeError):
- prod([10, 20], [30, 40]) # start is a keyword-only argument
-
- self.assertEqual(prod([0, 1, 2, 3]), 0)
- self.assertEqual(prod([1, 0, 2, 3]), 0)
- self.assertEqual(prod(range(10)), 0)
def test_main():
from doctest import DocFileSuite