summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorMark Dickinson <dickinsm@gmail.com>2010-05-15 17:02:38 (GMT)
committerMark Dickinson <dickinsm@gmail.com>2010-05-15 17:02:38 (GMT)
commit4c8a9a2df3c31b1c29d0b3cf74523e3c8b3dae72 (patch)
treef28718e84ae7a59ec3ec6780fa5fa2328362edf2 /Lib
parentae6265f8d06dbec7d08c73ca23dad0f040d09b8e (diff)
downloadcpython-4c8a9a2df3c31b1c29d0b3cf74523e3c8b3dae72.zip
cpython-4c8a9a2df3c31b1c29d0b3cf74523e3c8b3dae72.tar.gz
cpython-4c8a9a2df3c31b1c29d0b3cf74523e3c8b3dae72.tar.bz2
Issue #8692: Improve performance of math.factorial:
(1) use a different algorithm that roughly halves the total number of multiplications required and results in more balanced multiplications (2) use a lookup table for small arguments (3) fast accumulation of products in C integer arithmetic rather than PyLong arithmetic when possible. Typical speedup, from unscientific testing on a 64-bit laptop, is 4.5x to 6.5x for arguments in the range 100 - 10000. Patch by Daniel Stutzbach; extensive reviews by Alexander Belopolsky.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/test/test_math.py71
1 files changed, 61 insertions, 10 deletions
diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py
index 48d9b1a..6c44435 100644
--- a/Lib/test/test_math.py
+++ b/Lib/test/test_math.py
@@ -60,6 +60,56 @@ def ulps_check(expected, got, ulps=20):
return "error = {} ulps; permitted error = {} ulps".format(ulps_error,
ulps)
+# Here's a pure Python version of the math.factorial algorithm, for
+# documentation and comparison purposes.
+#
+# Formula:
+#
+# factorial(n) = factorial_odd_part(n) << (n - count_set_bits(n))
+#
+# where
+#
+# factorial_odd_part(n) = product_{i >= 0} product_{0 < j <= n >> i; j odd} j
+#
+# The outer product above is an infinite product, but once i >= n.bit_length(),
+# (n >> i) < 1 and the corresponding term of the product is empty. So only the
+# finitely many terms for 0 <= i < n.bit_length() contribute anything.
+#
+# We iterate downwards from i == n.bit_length() - 1 to i == 0. The inner
+# product in the formula above starts at 1 for i == n.bit_length(); for each i
+# < n.bit_length() we get the inner product for i from that for i + 1 by
+# multiplying by all j in {n >> i+1 < j <= n >> i; j odd}. In Python terms,
+# this set is range((n >> i+1) + 1 | 1, (n >> i) + 1 | 1, 2).
+
+def count_set_bits(n):
+ """Number of '1' bits in binary expansion of a nonnegative integer."""
+ # n & (n - 1) clears the lowest set bit, so each recursive step counts one bit.
+ return 1 + count_set_bits(n & n - 1) if n else 0
+
+def partial_product(start, stop):
+ """Product of integers in range(start, stop, 2), computed recursively.
+ start and stop should both be odd, with start <= stop.
+
+ """
+ # Number of factors in range(start, stop, 2) when start and stop are both odd.
+ numfactors = (stop - start) >> 1
+ if not numfactors:
+ # Empty range: the empty product is 1.
+ return 1
+ elif numfactors == 1:
+ return start
+ else:
+ # Split roughly in the middle; "| 1" keeps the split point odd so both
+ # recursive calls again receive odd bounds.
+ mid = (start + numfactors) | 1
+ return partial_product(start, mid) * partial_product(mid, stop)
+
+def py_factorial(n):
+ """Factorial of nonnegative integer n, via "Binary Split Factorial Formula"
+ described at http://www.luschny.de/math/factorial/binarysplitfact.html
+
+ """
+ inner = outer = 1
+ # Iterate i downwards so that each inner product extends the one for i + 1.
+ for i in reversed(range(n.bit_length())):
+ # Multiply in the odd j with (n >> i+1) < j <= (n >> i).
+ inner *= partial_product((n >> i + 1) + 1 | 1, (n >> i) + 1 | 1)
+ outer *= inner
+ # outer now holds factorial_odd_part(n); the shift restores the power of two.
+ return outer << (n - count_set_bits(n))
+
def acc_check(expected, got, rel_err=2e-15, abs_err = 5e-323):
"""Determine whether non-NaN floats a and b are equal to within a
(small) rounding error. The default values for rel_err and
@@ -365,18 +415,19 @@ class MathTests(unittest.TestCase):
self.ftest('fabs(1)', math.fabs(1), 1)
def testFactorial(self):
- def fact(n):
- result = 1
- for i in range(1, int(n)+1):
- result *= i
- return result
- values = list(range(10)) + [50, 100, 500]
- random.shuffle(values)
- for x in values:
- for cast in (int, float):
- self.assertEqual(math.factorial(cast(x)), fact(x), (x, fact(x), math.factorial(x)))
+ self.assertEqual(math.factorial(0), 1)
+ self.assertEqual(math.factorial(0.0), 1)
+ total = 1
+ for i in range(1, 1000):
+ total *= i
+ self.assertEqual(math.factorial(i), total)
+ self.assertEqual(math.factorial(float(i)), total)
+ self.assertEqual(math.factorial(i), py_factorial(i))
self.assertRaises(ValueError, math.factorial, -1)
+ self.assertRaises(ValueError, math.factorial, -1.0)
self.assertRaises(ValueError, math.factorial, math.pi)
+ self.assertRaises(OverflowError, math.factorial, sys.maxsize+1)
+ self.assertRaises(OverflowError, math.factorial, 10e100)
def testFloor(self):
self.assertRaises(TypeError, math.floor)