diff options
author | Raymond Hettinger <rhettinger@users.noreply.github.com> | 2023-01-07 18:46:35 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-07 18:46:35 (GMT) |
commit | 47b9f83a83db288c652e43567c7b0f74d87a29be (patch) | |
tree | cb4fde0440b01f79852ef45e1a2a2fe22ba6daca /Lib/test/test_math.py | |
parent | deaf090699a7312cccb0637409f44de3f382389b (diff) | |
download | cpython-47b9f83a83db288c652e43567c7b0f74d87a29be.zip cpython-47b9f83a83db288c652e43567c7b0f74d87a29be.tar.gz cpython-47b9f83a83db288c652e43567c7b0f74d87a29be.tar.bz2 |
GH-100485: Add math.sumprod() (GH-100677)
Diffstat (limited to 'Lib/test/test_math.py')
-rw-r--r-- | Lib/test/test_math.py | 166 |
1 files changed, 166 insertions, 0 deletions
diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py index bf0d0a5..65fe169 100644 --- a/Lib/test/test_math.py +++ b/Lib/test/test_math.py @@ -4,6 +4,7 @@ from test.support import verbose, requires_IEEE_754 from test import support import unittest +import fractions import itertools import decimal import math @@ -1202,6 +1203,171 @@ class MathTests(unittest.TestCase): self.assertEqual(math.log(INF), INF) self.assertTrue(math.isnan(math.log10(NAN))) + def testSumProd(self): + sumprod = math.sumprod + Decimal = decimal.Decimal + Fraction = fractions.Fraction + + # Core functionality + self.assertEqual(sumprod(iter([10, 20, 30]), (1, 2, 3)), 140) + self.assertEqual(sumprod([1.5, 2.5], [3.5, 4.5]), 16.5) + self.assertEqual(sumprod([], []), 0) + + # Type preservation and coercion + for v in [ + (10, 20, 30), + (1.5, -2.5), + (Fraction(3, 5), Fraction(4, 5)), + (Decimal(3.5), Decimal(4.5)), + (2.5, 10), # float/int + (2.5, Fraction(3, 5)), # float/fraction + (25, Fraction(3, 5)), # int/fraction + (25, Decimal(4.5)), # int/decimal + ]: + for p, q in [(v, v), (v, v[::-1])]: + with self.subTest(p=p, q=q): + expected = sum(p_i * q_i for p_i, q_i in zip(p, q, strict=True)) + actual = sumprod(p, q) + self.assertEqual(expected, actual) + self.assertEqual(type(expected), type(actual)) + + # Bad arguments + self.assertRaises(TypeError, sumprod) # No args + self.assertRaises(TypeError, sumprod, []) # One arg + self.assertRaises(TypeError, sumprod, [], [], []) # Three args + self.assertRaises(TypeError, sumprod, None, [10]) # Non-iterable + self.assertRaises(TypeError, sumprod, [10], None) # Non-iterable + + # Uneven lengths + self.assertRaises(ValueError, sumprod, [10, 20], [30]) + self.assertRaises(ValueError, sumprod, [10], [20, 30]) + + # Error in iterator + def raise_after(n): + for i in range(n): + yield i + raise RuntimeError + with self.assertRaises(RuntimeError): + sumprod(range(10), raise_after(5)) + with self.assertRaises(RuntimeError): + sumprod(raise_after(5), range(10)) + + # Error in multiplication + class BadMultiply: + def __mul__(self, other): + raise RuntimeError + def __rmul__(self, other): + raise RuntimeError + with self.assertRaises(RuntimeError): + sumprod([10, BadMultiply(), 30], [1, 2, 3]) + with self.assertRaises(RuntimeError): + sumprod([1, 2, 3], [10, BadMultiply(), 30]) + + # Error in addition + with self.assertRaises(TypeError): + sumprod(['abc', 3], [5, 10]) + with self.assertRaises(TypeError): + sumprod([5, 10], ['abc', 3]) + + # Special values should give the same as the pure python recipe + self.assertEqual(sumprod([10.1, math.inf], [20.2, 30.3]), math.inf) + self.assertEqual(sumprod([10.1, math.inf], [math.inf, 30.3]), math.inf) + self.assertEqual(sumprod([10.1, math.inf], [math.inf, math.inf]), math.inf) + self.assertEqual(sumprod([10.1, -math.inf], [20.2, 30.3]), -math.inf) + self.assertTrue(math.isnan(sumprod([10.1, math.inf], [-math.inf, math.inf]))) + self.assertTrue(math.isnan(sumprod([10.1, math.nan], [20.2, 30.3]))) + self.assertTrue(math.isnan(sumprod([10.1, math.inf], [math.nan, 30.3]))) + self.assertTrue(math.isnan(sumprod([10.1, math.inf], [20.3, math.nan]))) + + # Error cases that arose during development + args = ((-5, -5, 10), (1.5, 4611686018427387904, 2305843009213693952)) + self.assertEqual(sumprod(*args), 0.0) + + + @requires_IEEE_754 + @unittest.skipIf(HAVE_DOUBLE_ROUNDING, + "sumprod() accuracy not guaranteed on machines with double rounding") + @support.cpython_only # Other implementations may choose a different algorithm + def test_sumprod_accuracy(self): + sumprod = math.sumprod + self.assertEqual(sumprod([0.1] * 10, [1]*10), 1.0) + self.assertEqual(sumprod([0.1] * 20, [True, False] * 10), 1.0) + self.assertEqual(sumprod([1.0, 10E100, 1.0, -10E100], [1.0]*4), 2.0) + + def test_sumprod_stress(self): + sumprod = math.sumprod + product = itertools.product + Decimal = decimal.Decimal + Fraction = fractions.Fraction + + class Int(int): + def __add__(self, other): + return Int(int(self) + int(other)) + def __mul__(self, other): + return Int(int(self) * int(other)) + __radd__ = __add__ + __rmul__ = __mul__ + def __repr__(self): + return f'Int({int(self)})' + + class Flt(float): + def __add__(self, other): + return Int(int(self) + int(other)) + def __mul__(self, other): + return Int(int(self) * int(other)) + __radd__ = __add__ + __rmul__ = __mul__ + def __repr__(self): + return f'Flt({int(self)})' + + def baseline_sumprod(p, q): + """This defines the target behavior including expections and special values. + However, it is subject to rounding errors, so float inputs should be exactly + representable with only a few bits. + """ + total = 0 + for p_i, q_i in zip(p, q, strict=True): + total += p_i * q_i + return total + + def run(func, *args): + "Make comparing functions easier. Returns error status, type, and result." + try: + result = func(*args) + except (AssertionError, NameError): + raise + except Exception as e: + return type(e), None, 'None' + return None, type(result), repr(result) + + pools = [ + (-5, 10, -2**20, 2**31, 2**40, 2**61, 2**62, 2**80, 1.5, Int(7)), + (5.25, -3.5, 4.75, 11.25, 400.5, 0.046875, 0.25, -1.0, -0.078125), + (-19.0*2**500, 11*2**1000, -3*2**1500, 17*2*333, + 5.25, -3.25, -3.0*2**(-333), 3, 2**513), + (3.75, 2.5, -1.5, float('inf'), -float('inf'), float('NaN'), 14, + 9, 3+4j, Flt(13), 0.0), + (13.25, -4.25, Decimal('10.5'), Decimal('-2.25'), Fraction(13, 8), + Fraction(-11, 16), 4.75 + 0.125j, 97, -41, Int(3)), + (Decimal('6.125'), Decimal('12.375'), Decimal('-2.75'), Decimal(0), + Decimal('Inf'), -Decimal('Inf'), Decimal('NaN'), 12, 13.5), + (-2.0 ** -1000, 11*2**1000, 3, 7, -37*2**32, -2*2**-537, -2*2**-538, + 2*2**-513), + (-7 * 2.0 ** -510, 5 * 2.0 ** -520, 17, -19.0, -6.25), + (11.25, -3.75, -0.625, 23.375, True, False, 7, Int(5)), + ] + + for pool in pools: + for size in range(4): + for args1 in product(pool, repeat=size): + for args2 in product(pool, repeat=size): + args = (args1, args2) + self.assertEqual( + run(baseline_sumprod, *args), + run(sumprod, *args), + args, + ) + def testModf(self): self.assertRaises(TypeError, math.modf) |