summaryrefslogtreecommitdiffstats
path: root/Modules/mathmodule.c
diff options
context:
space:
mode:
authorMark Dickinson <dickinsm@gmail.com>2010-05-15 17:02:38 (GMT)
committerMark Dickinson <dickinsm@gmail.com>2010-05-15 17:02:38 (GMT)
commit4c8a9a2df3c31b1c29d0b3cf74523e3c8b3dae72 (patch)
treef28718e84ae7a59ec3ec6780fa5fa2328362edf2 /Modules/mathmodule.c
parentae6265f8d06dbec7d08c73ca23dad0f040d09b8e (diff)
downloadcpython-4c8a9a2df3c31b1c29d0b3cf74523e3c8b3dae72.zip
cpython-4c8a9a2df3c31b1c29d0b3cf74523e3c8b3dae72.tar.gz
cpython-4c8a9a2df3c31b1c29d0b3cf74523e3c8b3dae72.tar.bz2
Issue #8692: Improve performance of math.factorial:
(1) use a different algorithm that roughly halves the total number of multiplications required and results in more balanced multiplications (2) use a lookup table for small arguments (3) fast accumulation of products in C integer arithmetic rather than PyLong arithmetic when possible. Typical speedup, from unscientific testing on a 64-bit laptop, is 4.5x to 6.5x for arguments in the range 100 - 10000. Patch by Daniel Stutzbach; extensive reviews by Alexander Belopolsky.
Diffstat (limited to 'Modules/mathmodule.c')
-rw-r--r--Modules/mathmodule.c260
1 files changed, 240 insertions, 20 deletions
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c
index 76d7906..d57ad90 100644
--- a/Modules/mathmodule.c
+++ b/Modules/mathmodule.c
@@ -1129,18 +1129,239 @@ PyDoc_STRVAR(math_fsum_doc,
Return an accurate floating point sum of values in the iterable.\n\
Assumes IEEE-754 floating point arithmetic.");
+/* Return the smallest integer k such that n < 2**k, or 0 if n == 0.
+ * Equivalent to floor(lg(x))+1. Also equivalent to: bitwidth_of_type -
+ * count_leading_zero_bits(x)
+ */
+
+/* XXX: This routine does more or less the same thing as
+ * bits_in_digit() in Objects/longobject.c. Someday it would be nice to
+ * consolidate them. On BSD, there's a library function called fls()
+ * that we could use, and GCC provides __builtin_clz().
+ */
+
+static unsigned long
+bit_length(unsigned long n)
+{
+ unsigned long len = 0;
+ while (n != 0) {
+ ++len;
+ n >>= 1;
+ }
+ return len;
+}
+
+static unsigned long
+count_set_bits(unsigned long n)
+{
+ unsigned long count = 0;
+ while (n != 0) {
+ ++count;
+ n &= n - 1; /* clear least significant bit */
+ }
+ return count;
+}
+
+/* Divide-and-conquer factorial algorithm
+ *
+ * Based on the formula and psuedo-code provided at:
+ * http://www.luschny.de/math/factorial/binarysplitfact.html
+ *
+ * Faster algorithms exist, but they're more complicated and depend on
+ * a fast prime factoriazation algorithm.
+ *
+ * Notes on the algorithm
+ * ----------------------
+ *
+ * factorial(n) is written in the form 2**k * m, with m odd. k and m are
+ * computed separately, and then combined using a left shift.
+ *
+ * The function factorial_odd_part computes the odd part m (i.e., the greatest
+ * odd divisor) of factorial(n), using the formula:
+ *
+ * factorial_odd_part(n) =
+ *
+ * product_{i >= 0} product_{0 < j <= n / 2**i, j odd} j
+ *
+ * Example: factorial_odd_part(20) =
+ *
+ * (1) *
+ * (1) *
+ * (1 * 3 * 5) *
+ * (1 * 3 * 5 * 7 * 9)
+ * (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19)
+ *
+ * Here i goes from large to small: the first term corresponds to i=4 (any
+ * larger i gives an empty product), and the last term corresponds to i=0.
+ * Each term can be computed from the last by multiplying by the extra odd
+ * numbers required: e.g., to get from the penultimate term to the last one,
+ * we multiply by (11 * 13 * 15 * 17 * 19).
+ *
+ * To see a hint of why this formula works, here are the same numbers as above
+ * but with the even parts (i.e., the appropriate powers of 2) included. For
+ * each subterm in the product for i, we multiply that subterm by 2**i:
+ *
+ * factorial(20) =
+ *
+ * (16) *
+ * (8) *
+ * (4 * 12 * 20) *
+ * (2 * 6 * 10 * 14 * 18) *
+ * (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19)
+ *
+ * The factorial_partial_product function computes the product of all odd j in
+ * range(start, stop) for given start and stop. It's used to compute the
+ * partial products like (11 * 13 * 15 * 17 * 19) in the example above. It
+ * operates recursively, repeatedly splitting the range into two roughly equal
+ * pieces until the subranges are small enough to be computed using only C
+ * integer arithmetic.
+ *
+ * The two-valuation k (i.e., the exponent of the largest power of 2 dividing
+ * the factorial) is computed independently in the main math_factorial
+ * function. By standard results, its value is:
+ *
+ * two_valuation = n//2 + n//4 + n//8 + ....
+ *
+ * It can be shown (e.g., by complete induction on n) that two_valuation is
+ * equal to n - count_set_bits(n), where count_set_bits(n) gives the number of
+ * '1'-bits in the binary expansion of n.
+ */
+
+/* factorial_partial_product: Compute product(range(start, stop, 2)) using
+ * divide and conquer. Assumes start and stop are odd and stop > start.
+ * max_bits must be >= bit_length(stop - 2). */
+
+static PyObject *
+factorial_partial_product(unsigned long start, unsigned long stop,
+ unsigned long max_bits)
+{
+ unsigned long midpoint, num_operands;
+ PyObject *left = NULL, *right = NULL, *result = NULL;
+
+ /* If the return value will fit an unsigned long, then we can
+ * multiply in a tight, fast loop where each multiply is O(1).
+ * Compute an upper bound on the number of bits required to store
+ * the answer.
+ *
+ * Storing some integer z requires floor(lg(z))+1 bits, which is
+ * conveniently the value returned by bit_length(z). The
+ * product x*y will require at most
+ * bit_length(x) + bit_length(y) bits to store, based
+ * on the idea that lg product = lg x + lg y.
+ *
+ * We know that stop - 2 is the largest number to be multiplied. From
+ * there, we have: bit_length(answer) <= num_operands *
+ * bit_length(stop - 2)
+ */
+
+ num_operands = (stop - start) / 2;
+ /* The "num_operands <= 8 * SIZEOF_LONG" check guards against the
+ * unlikely case of an overflow in num_operands * max_bits. */
+ if (num_operands <= 8 * SIZEOF_LONG &&
+ num_operands * max_bits <= 8 * SIZEOF_LONG) {
+ unsigned long j, total;
+ for (total = start, j = start + 2; j < stop; j += 2)
+ total *= j;
+ return PyLong_FromUnsignedLong(total);
+ }
+
+ /* find midpoint of range(start, stop), rounded up to next odd number. */
+ midpoint = (start + num_operands) | 1;
+ left = factorial_partial_product(start, midpoint,
+ bit_length(midpoint - 2));
+ if (left == NULL)
+ goto error;
+ right = factorial_partial_product(midpoint, stop, max_bits);
+ if (right == NULL)
+ goto error;
+ result = PyNumber_Multiply(left, right);
+
+ error:
+ Py_XDECREF(left);
+ Py_XDECREF(right);
+ return result;
+}
+
+/* factorial_odd_part: compute the odd part of factorial(n). */
+
+static PyObject *
+factorial_odd_part(unsigned long n)
+{
+ long i;
+ unsigned long v, lower, upper;
+ PyObject *partial, *tmp, *inner, *outer;
+
+ inner = PyLong_FromLong(1);
+ if (inner == NULL)
+ return NULL;
+ outer = inner;
+ Py_INCREF(outer);
+
+ upper = 3;
+ for (i = bit_length(n) - 2; i >= 0; i--) {
+ v = n >> i;
+ if (v <= 2)
+ continue;
+ lower = upper;
+ /* (v + 1) | 1 = least odd integer strictly larger than n / 2**i */
+ upper = (v + 1) | 1;
+ /* Here inner is the product of all odd integers j in the range (0,
+ n/2**(i+1)]. The factorial_partial_product call below gives the
+ product of all odd integers j in the range (n/2**(i+1), n/2**i]. */
+ partial = factorial_partial_product(lower, upper, bit_length(upper-2));
+ /* inner *= partial */
+ if (partial == NULL)
+ goto error;
+ tmp = PyNumber_Multiply(inner, partial);
+ Py_DECREF(partial);
+ if (tmp == NULL)
+ goto error;
+ Py_DECREF(inner);
+ inner = tmp;
+ /* Now inner is the product of all odd integers j in the range (0,
+ n/2**i], giving the inner product in the formula above. */
+
+ /* outer *= inner; */
+ tmp = PyNumber_Multiply(outer, inner);
+ if (tmp == NULL)
+ goto error;
+ Py_DECREF(outer);
+ outer = tmp;
+ }
+
+ goto done;
+
+ error:
+ Py_DECREF(outer);
+ done:
+ Py_DECREF(inner);
+ return outer;
+}
+
+/* Lookup table for small factorial values */
+
+static const unsigned long SmallFactorials[] = {
+ 1, 1, 2, 6, 24, 120, 720, 5040, 40320,
+ 362880, 3628800, 39916800, 479001600,
+#if SIZEOF_LONG >= 8
+ 6227020800, 87178291200, 1307674368000,
+ 20922789888000, 355687428096000, 6402373705728000,
+ 121645100408832000, 2432902008176640000
+#endif
+};
+
static PyObject *
math_factorial(PyObject *self, PyObject *arg)
{
- long i, x;
- PyObject *result, *iobj, *newresult;
+ long x;
+ PyObject *result, *odd_part, *two_valuation;
if (PyFloat_Check(arg)) {
PyObject *lx;
double dx = PyFloat_AS_DOUBLE((PyFloatObject *)arg);
if (!(Py_IS_FINITE(dx) && dx == floor(dx))) {
PyErr_SetString(PyExc_ValueError,
- "factorial() only accepts integral values");
+ "factorial() only accepts integral values");
return NULL;
}
lx = PyLong_FromDouble(dx);
@@ -1156,29 +1377,28 @@ math_factorial(PyObject *self, PyObject *arg)
return NULL;
if (x < 0) {
PyErr_SetString(PyExc_ValueError,
- "factorial() not defined for negative values");
+ "factorial() not defined for negative values");
return NULL;
}
- result = (PyObject *)PyLong_FromLong(1);
- if (result == NULL)
+ /* use lookup table if x is small */
+ if (x < (long)(sizeof(SmallFactorials)/sizeof(SmallFactorials[0])))
+ return PyLong_FromUnsignedLong(SmallFactorials[x]);
+
+ /* else express in the form odd_part * 2**two_valuation, and compute as
+ odd_part << two_valuation. */
+ odd_part = factorial_odd_part(x);
+ if (odd_part == NULL)
+ return NULL;
+ two_valuation = PyLong_FromLong(x - count_set_bits(x));
+ if (two_valuation == NULL) {
+ Py_DECREF(odd_part);
return NULL;
- for (i=1 ; i<=x ; i++) {
- iobj = (PyObject *)PyLong_FromLong(i);
- if (iobj == NULL)
- goto error;
- newresult = PyNumber_Multiply(result, iobj);
- Py_DECREF(iobj);
- if (newresult == NULL)
- goto error;
- Py_DECREF(result);
- result = newresult;
}
+ result = PyNumber_Lshift(odd_part, two_valuation);
+ Py_DECREF(two_valuation);
+ Py_DECREF(odd_part);
return result;
-
-error:
- Py_DECREF(result);
- return NULL;
}
PyDoc_STRVAR(math_factorial_doc,