summaryrefslogtreecommitdiffstats
path: root/Modules/mathmodule.c
diff options
context:
space:
mode:
authorSerhiy Storchaka <storchaka@gmail.com>2021-12-05 20:26:10 (GMT)
committerGitHub <noreply@github.com>2021-12-05 20:26:10 (GMT)
commit60c320c38e4e95877cde0b1d8562ebd6bc02ac61 (patch)
treec22e03b50c431a0c46d84d03f20a8ff9f471aaf0 /Modules/mathmodule.c
parent628abe4463ed40cd54ca952a2b4cc2d6e74073f7 (diff)
downloadcpython-60c320c38e4e95877cde0b1d8562ebd6bc02ac61.zip
cpython-60c320c38e4e95877cde0b1d8562ebd6bc02ac61.tar.gz
cpython-60c320c38e4e95877cde0b1d8562ebd6bc02ac61.tar.bz2
bpo-37295: Optimize math.comb() and math.perm() (GH-29090)
For very large numbers use divide-and-conquer algorithm for getting benefit of Karatsuba multiplication of large numbers. Do calculations completely in C unsigned long long instead of Python integers if possible.
Diffstat (limited to 'Modules/mathmodule.c')
-rw-r--r--Modules/mathmodule.c285
1 files changed, 192 insertions, 93 deletions
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c
index 64ce4e6..84b5b95 100644
--- a/Modules/mathmodule.c
+++ b/Modules/mathmodule.c
@@ -3221,6 +3221,138 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
}
+/* Number of permutations and combinations.
+ * P(n, k) = n! / (n-k)!
+ * C(n, k) = P(n, k) / k!
+ */
+
+/* Calculate C(n, k) for n in the 63-bit range. */
+static PyObject *
+perm_comb_small(unsigned long long n, unsigned long long k, int iscomb)
+{
+ /* long long is at least 64 bit */
+ static const unsigned long long fast_comb_limits[] = {
+ 0, ULLONG_MAX, 4294967296ULL, 3329022, 102570, 13467, 3612, 1449, // 0-7
+ 746, 453, 308, 227, 178, 147, 125, 110, // 8-15
+ 99, 90, 84, 79, 75, 72, 69, 68, // 16-23
+ 66, 65, 64, 63, 63, 62, 62, 62, // 24-31
+ };
+ static const unsigned long long fast_perm_limits[] = {
+ 0, ULLONG_MAX, 4294967296ULL, 2642246, 65537, 7133, 1627, 568, // 0-7
+ 259, 142, 88, 61, 45, 36, 30, // 8-14
+ };
+
+ if (k == 0) {
+ return PyLong_FromLong(1);
+ }
+
+ /* For small enough n and k the result fits in the 64-bit range and can
+ * be calculated without allocating intermediate PyLong objects. */
+ if (iscomb
+ ? (k < Py_ARRAY_LENGTH(fast_comb_limits)
+ && n <= fast_comb_limits[k])
+ : (k < Py_ARRAY_LENGTH(fast_perm_limits)
+ && n <= fast_perm_limits[k]))
+ {
+ unsigned long long result = n;
+ if (iscomb) {
+ for (unsigned long long i = 1; i < k;) {
+ result *= --n;
+ result /= ++i;
+ }
+ }
+ else {
+ for (unsigned long long i = 1; i < k;) {
+ result *= --n;
+ ++i;
+ }
+ }
+ return PyLong_FromUnsignedLongLong(result);
+ }
+
+ /* For larger n use recursive formula. */
+ /* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
+ unsigned long long j = k / 2;
+ PyObject *a, *b;
+ a = perm_comb_small(n, j, iscomb);
+ if (a == NULL) {
+ return NULL;
+ }
+ b = perm_comb_small(n - j, k - j, iscomb);
+ if (b == NULL) {
+ goto error;
+ }
+ Py_SETREF(a, PyNumber_Multiply(a, b));
+ Py_DECREF(b);
+ if (iscomb && a != NULL) {
+ b = perm_comb_small(k, j, 1);
+ if (b == NULL) {
+ goto error;
+ }
+ Py_SETREF(a, PyNumber_FloorDivide(a, b));
+ Py_DECREF(b);
+ }
+ return a;
+
+error:
+ Py_DECREF(a);
+ return NULL;
+}
+
+/* Calculate P(n, k) or C(n, k) using recursive formulas.
+ * It is more efficient than sequential multiplication thanks to
+ * Karatsuba multiplication.
+ */
+static PyObject *
+perm_comb(PyObject *n, unsigned long long k, int iscomb)
+{
+ if (k == 0) {
+ return PyLong_FromLong(1);
+ }
+ if (k == 1) {
+ Py_INCREF(n);
+ return n;
+ }
+
+ /* P(n, k) = P(n, j) * P(n-j, k-j) */
+ /* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */
+ unsigned long long j = k / 2;
+ PyObject *a, *b;
+ a = perm_comb(n, j, iscomb);
+ if (a == NULL) {
+ return NULL;
+ }
+ PyObject *t = PyLong_FromUnsignedLongLong(j);
+ if (t == NULL) {
+ goto error;
+ }
+ n = PyNumber_Subtract(n, t);
+ Py_DECREF(t);
+ if (n == NULL) {
+ goto error;
+ }
+ b = perm_comb(n, k - j, iscomb);
+ Py_DECREF(n);
+ if (b == NULL) {
+ goto error;
+ }
+ Py_SETREF(a, PyNumber_Multiply(a, b));
+ Py_DECREF(b);
+ if (iscomb && a != NULL) {
+ b = perm_comb_small(k, j, 1);
+ if (b == NULL) {
+ goto error;
+ }
+ Py_SETREF(a, PyNumber_FloorDivide(a, b));
+ Py_DECREF(b);
+ }
+ return a;
+
+error:
+ Py_DECREF(a);
+ return NULL;
+}
+
/*[clinic input]
math.perm
@@ -3244,9 +3376,9 @@ static PyObject *
math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
/*[clinic end generated code: output=e021a25469653e23 input=5311c5a00f359b53]*/
{
- PyObject *result = NULL, *factor = NULL;
+ PyObject *result = NULL;
int overflow, cmp;
- long long i, factors;
+ long long ki, ni;
if (k == Py_None) {
return math_factorial(module, n);
@@ -3260,6 +3392,7 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
Py_DECREF(n);
return NULL;
}
+ assert(PyLong_CheckExact(n) && PyLong_CheckExact(k));
if (Py_SIZE(n) < 0) {
PyErr_SetString(PyExc_ValueError,
@@ -3281,42 +3414,26 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
goto error;
}
- factors = PyLong_AsLongLongAndOverflow(k, &overflow);
+ ki = PyLong_AsLongLongAndOverflow(k, &overflow);
+ assert(overflow >= 0 && !PyErr_Occurred());
if (overflow > 0) {
PyErr_Format(PyExc_OverflowError,
"k must not exceed %lld",
LLONG_MAX);
goto error;
}
- else if (factors == -1) {
- /* k is nonnegative, so a return value of -1 can only indicate error */
- goto error;
- }
+ assert(ki >= 0);
- if (factors == 0) {
- result = PyLong_FromLong(1);
- goto done;
+ ni = PyLong_AsLongLongAndOverflow(n, &overflow);
+ assert(overflow >= 0 && !PyErr_Occurred());
+ if (!overflow && ki > 1) {
+ assert(ni >= 0);
+ result = perm_comb_small((unsigned long long)ni,
+ (unsigned long long)ki, 0);
}
-
- result = n;
- Py_INCREF(result);
- if (factors == 1) {
- goto done;
- }
-
- factor = Py_NewRef(n);
- PyObject *one = _PyLong_GetOne(); // borrowed ref
- for (i = 1; i < factors; ++i) {
- Py_SETREF(factor, PyNumber_Subtract(factor, one));
- if (factor == NULL) {
- goto error;
- }
- Py_SETREF(result, PyNumber_Multiply(result, factor));
- if (result == NULL) {
- goto error;
- }
+ else {
+ result = perm_comb(n, (unsigned long long)ki, 0);
}
- Py_DECREF(factor);
done:
Py_DECREF(n);
@@ -3324,14 +3441,11 @@ done:
return result;
error:
- Py_XDECREF(factor);
- Py_XDECREF(result);
Py_DECREF(n);
Py_DECREF(k);
return NULL;
}
-
/*[clinic input]
math.comb
@@ -3357,9 +3471,9 @@ static PyObject *
math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
/*[clinic end generated code: output=bd2cec8d854f3493 input=9a05315af2518709]*/
{
- PyObject *result = NULL, *factor = NULL, *temp;
+ PyObject *result = NULL, *temp;
int overflow, cmp;
- long long i, factors;
+ long long ki, ni;
n = PyNumber_Index(n);
if (n == NULL) {
@@ -3370,6 +3484,7 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
Py_DECREF(n);
return NULL;
}
+ assert(PyLong_CheckExact(n) && PyLong_CheckExact(k));
if (Py_SIZE(n) < 0) {
PyErr_SetString(PyExc_ValueError,
@@ -3382,73 +3497,59 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
goto error;
}
- /* k = min(k, n - k) */
- temp = PyNumber_Subtract(n, k);
- if (temp == NULL) {
- goto error;
- }
- if (Py_SIZE(temp) < 0) {
- Py_DECREF(temp);
- result = PyLong_FromLong(0);
- goto done;
- }
- cmp = PyObject_RichCompareBool(temp, k, Py_LT);
- if (cmp > 0) {
- Py_SETREF(k, temp);
+ ni = PyLong_AsLongLongAndOverflow(n, &overflow);
+ assert(overflow >= 0 && !PyErr_Occurred());
+ if (!overflow) {
+ assert(ni >= 0);
+ ki = PyLong_AsLongLongAndOverflow(k, &overflow);
+ assert(overflow >= 0 && !PyErr_Occurred());
+ if (overflow || ki > ni) {
+ result = PyLong_FromLong(0);
+ goto done;
+ }
+ assert(ki >= 0);
+ ki = Py_MIN(ki, ni - ki);
+ if (ki > 1) {
+ result = perm_comb_small((unsigned long long)ni,
+ (unsigned long long)ki, 1);
+ goto done;
+ }
+ /* For k == 1 just return the original n in perm_comb(). */
}
else {
- Py_DECREF(temp);
- if (cmp < 0) {
+ /* k = min(k, n - k) */
+ temp = PyNumber_Subtract(n, k);
+ if (temp == NULL) {
goto error;
}
- }
-
- factors = PyLong_AsLongLongAndOverflow(k, &overflow);
- if (overflow > 0) {
- PyErr_Format(PyExc_OverflowError,
- "min(n - k, k) must not exceed %lld",
- LLONG_MAX);
- goto error;
- }
- if (factors == -1) {
- /* k is nonnegative, so a return value of -1 can only indicate error */
- goto error;
- }
-
- if (factors == 0) {
- result = PyLong_FromLong(1);
- goto done;
- }
-
- result = n;
- Py_INCREF(result);
- if (factors == 1) {
- goto done;
- }
-
- factor = Py_NewRef(n);
- PyObject *one = _PyLong_GetOne(); // borrowed ref
- for (i = 1; i < factors; ++i) {
- Py_SETREF(factor, PyNumber_Subtract(factor, one));
- if (factor == NULL) {
- goto error;
+ if (Py_SIZE(temp) < 0) {
+ Py_DECREF(temp);
+ result = PyLong_FromLong(0);
+ goto done;
}
- Py_SETREF(result, PyNumber_Multiply(result, factor));
- if (result == NULL) {
- goto error;
+ cmp = PyObject_RichCompareBool(temp, k, Py_LT);
+ if (cmp > 0) {
+ Py_SETREF(k, temp);
}
-
- temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1);
- if (temp == NULL) {
- goto error;
+ else {
+ Py_DECREF(temp);
+ if (cmp < 0) {
+ goto error;
+ }
}
- Py_SETREF(result, PyNumber_FloorDivide(result, temp));
- Py_DECREF(temp);
- if (result == NULL) {
+
+ ki = PyLong_AsLongLongAndOverflow(k, &overflow);
+ assert(overflow >= 0 && !PyErr_Occurred());
+ if (overflow) {
+ PyErr_Format(PyExc_OverflowError,
+ "min(n - k, k) must not exceed %lld",
+ LLONG_MAX);
goto error;
}
+ assert(ki >= 0);
}
- Py_DECREF(factor);
+
+ result = perm_comb(n, (unsigned long long)ki, 1);
done:
Py_DECREF(n);
@@ -3456,8 +3557,6 @@ done:
return result;
error:
- Py_XDECREF(factor);
- Py_XDECREF(result);
Py_DECREF(n);
Py_DECREF(k);
return NULL;