summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_math.py
diff options
context:
space:
mode:
authorRaymond Hettinger <rhettinger@users.noreply.github.com>2023-01-07 18:46:35 (GMT)
committerGitHub <noreply@github.com>2023-01-07 18:46:35 (GMT)
commit47b9f83a83db288c652e43567c7b0f74d87a29be (patch)
treecb4fde0440b01f79852ef45e1a2a2fe22ba6daca /Lib/test/test_math.py
parentdeaf090699a7312cccb0637409f44de3f382389b (diff)
downloadcpython-47b9f83a83db288c652e43567c7b0f74d87a29be.zip
cpython-47b9f83a83db288c652e43567c7b0f74d87a29be.tar.gz
cpython-47b9f83a83db288c652e43567c7b0f74d87a29be.tar.bz2
GH-100485: Add math.sumprod() (GH-100677)
Diffstat (limited to 'Lib/test/test_math.py')
-rw-r--r--Lib/test/test_math.py166
1 files changed, 166 insertions, 0 deletions
diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py
index bf0d0a5..65fe169 100644
--- a/Lib/test/test_math.py
+++ b/Lib/test/test_math.py
@@ -4,6 +4,7 @@
from test.support import verbose, requires_IEEE_754
from test import support
import unittest
+import fractions
import itertools
import decimal
import math
@@ -1202,6 +1203,171 @@ class MathTests(unittest.TestCase):
self.assertEqual(math.log(INF), INF)
self.assertTrue(math.isnan(math.log10(NAN)))
+ def testSumProd(self):
+ sumprod = math.sumprod
+ Decimal = decimal.Decimal
+ Fraction = fractions.Fraction
+
+ # Core functionality
+ self.assertEqual(sumprod(iter([10, 20, 30]), (1, 2, 3)), 140)
+ self.assertEqual(sumprod([1.5, 2.5], [3.5, 4.5]), 16.5)
+ self.assertEqual(sumprod([], []), 0)
+
+ # Type preservation and coercion
+ for v in [
+ (10, 20, 30),
+ (1.5, -2.5),
+ (Fraction(3, 5), Fraction(4, 5)),
+ (Decimal(3.5), Decimal(4.5)),
+ (2.5, 10), # float/int
+ (2.5, Fraction(3, 5)), # float/fraction
+ (25, Fraction(3, 5)), # int/fraction
+ (25, Decimal(4.5)), # int/decimal
+ ]:
+ for p, q in [(v, v), (v, v[::-1])]:
+ with self.subTest(p=p, q=q):
+ expected = sum(p_i * q_i for p_i, q_i in zip(p, q, strict=True))
+ actual = sumprod(p, q)
+ self.assertEqual(expected, actual)
+ self.assertEqual(type(expected), type(actual))
+
+ # Bad arguments
+ self.assertRaises(TypeError, sumprod) # No args
+ self.assertRaises(TypeError, sumprod, []) # One arg
+ self.assertRaises(TypeError, sumprod, [], [], []) # Three args
+ self.assertRaises(TypeError, sumprod, None, [10]) # Non-iterable
+ self.assertRaises(TypeError, sumprod, [10], None) # Non-iterable
+
+ # Uneven lengths
+ self.assertRaises(ValueError, sumprod, [10, 20], [30])
+ self.assertRaises(ValueError, sumprod, [10], [20, 30])
+
+ # Error in iterator
+ def raise_after(n):
+ for i in range(n):
+ yield i
+ raise RuntimeError
+ with self.assertRaises(RuntimeError):
+ sumprod(range(10), raise_after(5))
+ with self.assertRaises(RuntimeError):
+ sumprod(raise_after(5), range(10))
+
+ # Error in multiplication
+ class BadMultiply:
+ def __mul__(self, other):
+ raise RuntimeError
+ def __rmul__(self, other):
+ raise RuntimeError
+ with self.assertRaises(RuntimeError):
+ sumprod([10, BadMultiply(), 30], [1, 2, 3])
+ with self.assertRaises(RuntimeError):
+ sumprod([1, 2, 3], [10, BadMultiply(), 30])
+
+ # Error in addition
+ with self.assertRaises(TypeError):
+ sumprod(['abc', 3], [5, 10])
+ with self.assertRaises(TypeError):
+ sumprod([5, 10], ['abc', 3])
+
+ # Special values should give the same as the pure python recipe
+ self.assertEqual(sumprod([10.1, math.inf], [20.2, 30.3]), math.inf)
+ self.assertEqual(sumprod([10.1, math.inf], [math.inf, 30.3]), math.inf)
+ self.assertEqual(sumprod([10.1, math.inf], [math.inf, math.inf]), math.inf)
+ self.assertEqual(sumprod([10.1, -math.inf], [20.2, 30.3]), -math.inf)
+ self.assertTrue(math.isnan(sumprod([10.1, math.inf], [-math.inf, math.inf])))
+ self.assertTrue(math.isnan(sumprod([10.1, math.nan], [20.2, 30.3])))
+ self.assertTrue(math.isnan(sumprod([10.1, math.inf], [math.nan, 30.3])))
+ self.assertTrue(math.isnan(sumprod([10.1, math.inf], [20.3, math.nan])))
+
+ # Error cases that arose during development
+ args = ((-5, -5, 10), (1.5, 4611686018427387904, 2305843009213693952))
+ self.assertEqual(sumprod(*args), 0.0)
+
+
+ @requires_IEEE_754
+ @unittest.skipIf(HAVE_DOUBLE_ROUNDING,
+ "sumprod() accuracy not guaranteed on machines with double rounding")
+ @support.cpython_only # Other implementations may choose a different algorithm
+ def test_sumprod_accuracy(self):
+ sumprod = math.sumprod
+ self.assertEqual(sumprod([0.1] * 10, [1]*10), 1.0)
+ self.assertEqual(sumprod([0.1] * 20, [True, False] * 10), 1.0)
+ self.assertEqual(sumprod([1.0, 10E100, 1.0, -10E100], [1.0]*4), 2.0)
+
+ def test_sumprod_stress(self):
+ sumprod = math.sumprod
+ product = itertools.product
+ Decimal = decimal.Decimal
+ Fraction = fractions.Fraction
+
+ class Int(int):
+ def __add__(self, other):
+ return Int(int(self) + int(other))
+ def __mul__(self, other):
+ return Int(int(self) * int(other))
+ __radd__ = __add__
+ __rmul__ = __mul__
+ def __repr__(self):
+ return f'Int({int(self)})'
+
+ class Flt(float):
+ def __add__(self, other):
+ return Int(int(self) + int(other))
+ def __mul__(self, other):
+ return Int(int(self) * int(other))
+ __radd__ = __add__
+ __rmul__ = __mul__
+ def __repr__(self):
+ return f'Flt({int(self)})'
+
+ def baseline_sumprod(p, q):
+ """This defines the target behavior including expections and special values.
+ However, it is subject to rounding errors, so float inputs should be exactly
+ representable with only a few bits.
+ """
+ total = 0
+ for p_i, q_i in zip(p, q, strict=True):
+ total += p_i * q_i
+ return total
+
+ def run(func, *args):
+ "Make comparing functions easier. Returns error status, type, and result."
+ try:
+ result = func(*args)
+ except (AssertionError, NameError):
+ raise
+ except Exception as e:
+ return type(e), None, 'None'
+ return None, type(result), repr(result)
+
+ pools = [
+ (-5, 10, -2**20, 2**31, 2**40, 2**61, 2**62, 2**80, 1.5, Int(7)),
+ (5.25, -3.5, 4.75, 11.25, 400.5, 0.046875, 0.25, -1.0, -0.078125),
+ (-19.0*2**500, 11*2**1000, -3*2**1500, 17*2*333,
+ 5.25, -3.25, -3.0*2**(-333), 3, 2**513),
+ (3.75, 2.5, -1.5, float('inf'), -float('inf'), float('NaN'), 14,
+ 9, 3+4j, Flt(13), 0.0),
+ (13.25, -4.25, Decimal('10.5'), Decimal('-2.25'), Fraction(13, 8),
+ Fraction(-11, 16), 4.75 + 0.125j, 97, -41, Int(3)),
+ (Decimal('6.125'), Decimal('12.375'), Decimal('-2.75'), Decimal(0),
+ Decimal('Inf'), -Decimal('Inf'), Decimal('NaN'), 12, 13.5),
+ (-2.0 ** -1000, 11*2**1000, 3, 7, -37*2**32, -2*2**-537, -2*2**-538,
+ 2*2**-513),
+ (-7 * 2.0 ** -510, 5 * 2.0 ** -520, 17, -19.0, -6.25),
+ (11.25, -3.75, -0.625, 23.375, True, False, 7, Int(5)),
+ ]
+
+ for pool in pools:
+ for size in range(4):
+ for args1 in product(pool, repeat=size):
+ for args2 in product(pool, repeat=size):
+ args = (args1, args2)
+ self.assertEqual(
+ run(baseline_sumprod, *args),
+ run(sumprod, *args),
+ args,
+ )
+
def testModf(self):
self.assertRaises(TypeError, math.modf)