diff options
author | Serhiy Storchaka <storchaka@gmail.com> | 2021-12-05 20:26:10 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-12-05 20:26:10 (GMT) |
commit | 60c320c38e4e95877cde0b1d8562ebd6bc02ac61 (patch) | |
tree | c22e03b50c431a0c46d84d03f20a8ff9f471aaf0 /Modules/mathmodule.c | |
parent | 628abe4463ed40cd54ca952a2b4cc2d6e74073f7 (diff) | |
download | cpython-60c320c38e4e95877cde0b1d8562ebd6bc02ac61.zip cpython-60c320c38e4e95877cde0b1d8562ebd6bc02ac61.tar.gz cpython-60c320c38e4e95877cde0b1d8562ebd6bc02ac61.tar.bz2 |
bpo-37295: Optimize math.comb() and math.perm() (GH-29090)
For very large numbers use divide-and-conquer algorithm for getting
benefit of Karatsuba multiplication of large numbers.
Do calculations completely in C unsigned long long instead of Python
integers if possible.
Diffstat (limited to 'Modules/mathmodule.c')
-rw-r--r-- | Modules/mathmodule.c | 285 |
1 files changed, 192 insertions, 93 deletions
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index 64ce4e6..84b5b95 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -3221,6 +3221,138 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start) } +/* Number of permutations and combinations. + * P(n, k) = n! / (n-k)! + * C(n, k) = P(n, k) / k! + */ + +/* Calculate C(n, k) for n in the 63-bit range. */ +static PyObject * +perm_comb_small(unsigned long long n, unsigned long long k, int iscomb) +{ + /* long long is at least 64 bit */ + static const unsigned long long fast_comb_limits[] = { + 0, ULLONG_MAX, 4294967296ULL, 3329022, 102570, 13467, 3612, 1449, // 0-7 + 746, 453, 308, 227, 178, 147, 125, 110, // 8-15 + 99, 90, 84, 79, 75, 72, 69, 68, // 16-23 + 66, 65, 64, 63, 63, 62, 62, 62, // 24-31 + }; + static const unsigned long long fast_perm_limits[] = { + 0, ULLONG_MAX, 4294967296ULL, 2642246, 65537, 7133, 1627, 568, // 0-7 + 259, 142, 88, 61, 45, 36, 30, // 8-14 + }; + + if (k == 0) { + return PyLong_FromLong(1); + } + + /* For small enough n and k the result fits in the 64-bit range and can + * be calculated without allocating intermediate PyLong objects. */ + if (iscomb + ? (k < Py_ARRAY_LENGTH(fast_comb_limits) + && n <= fast_comb_limits[k]) + : (k < Py_ARRAY_LENGTH(fast_perm_limits) + && n <= fast_perm_limits[k])) + { + unsigned long long result = n; + if (iscomb) { + for (unsigned long long i = 1; i < k;) { + result *= --n; + result /= ++i; + } + } + else { + for (unsigned long long i = 1; i < k;) { + result *= --n; + ++i; + } + } + return PyLong_FromUnsignedLongLong(result); + } + + /* For larger n use recursive formula. */ + /* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */ + unsigned long long j = k / 2; + PyObject *a, *b; + a = perm_comb_small(n, j, iscomb); + if (a == NULL) { + return NULL; + } + b = perm_comb_small(n - j, k - j, iscomb); + if (b == NULL) { + goto error; + } + Py_SETREF(a, PyNumber_Multiply(a, b)); + Py_DECREF(b); + if (iscomb && a != NULL) { + b = perm_comb_small(k, j, 1); + if (b == NULL) { + goto error; + } + Py_SETREF(a, PyNumber_FloorDivide(a, b)); + Py_DECREF(b); + } + return a; + +error: + Py_DECREF(a); + return NULL; +} + +/* Calculate P(n, k) or C(n, k) using recursive formulas. + * It is more efficient than sequential multiplication thanks to + * Karatsuba multiplication. + */ +static PyObject * +perm_comb(PyObject *n, unsigned long long k, int iscomb) +{ + if (k == 0) { + return PyLong_FromLong(1); + } + if (k == 1) { + Py_INCREF(n); + return n; + } + + /* P(n, k) = P(n, j) * P(n-j, k-j) */ + /* C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) */ + unsigned long long j = k / 2; + PyObject *a, *b; + a = perm_comb(n, j, iscomb); + if (a == NULL) { + return NULL; + } + PyObject *t = PyLong_FromUnsignedLongLong(j); + if (t == NULL) { + goto error; + } + n = PyNumber_Subtract(n, t); + Py_DECREF(t); + if (n == NULL) { + goto error; + } + b = perm_comb(n, k - j, iscomb); + Py_DECREF(n); + if (b == NULL) { + goto error; + } + Py_SETREF(a, PyNumber_Multiply(a, b)); + Py_DECREF(b); + if (iscomb && a != NULL) { + b = perm_comb_small(k, j, 1); + if (b == NULL) { + goto error; + } + Py_SETREF(a, PyNumber_FloorDivide(a, b)); + Py_DECREF(b); + } + return a; + +error: + Py_DECREF(a); + return NULL; +} + /*[clinic input] math.perm @@ -3244,9 +3376,9 @@ static PyObject * math_perm_impl(PyObject *module, PyObject *n, PyObject *k) /*[clinic end generated code: output=e021a25469653e23 input=5311c5a00f359b53]*/ { - PyObject *result = NULL, *factor = NULL; + PyObject *result = NULL; int overflow, cmp; - long long i, factors; + long long ki, ni; if (k == Py_None) { return math_factorial(module, n); @@ -3260,6 +3392,7 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k) Py_DECREF(n); return NULL; } + assert(PyLong_CheckExact(n) && PyLong_CheckExact(k)); if (Py_SIZE(n) < 0) { PyErr_SetString(PyExc_ValueError, @@ -3281,42 +3414,26 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k) goto error; } - factors = PyLong_AsLongLongAndOverflow(k, &overflow); + ki = PyLong_AsLongLongAndOverflow(k, &overflow); + assert(overflow >= 0 && !PyErr_Occurred()); if (overflow > 0) { PyErr_Format(PyExc_OverflowError, "k must not exceed %lld", LLONG_MAX); goto error; } - else if (factors == -1) { - /* k is nonnegative, so a return value of -1 can only indicate error */ - goto error; - } + assert(ki >= 0); - if (factors == 0) { - result = PyLong_FromLong(1); - goto done; + ni = PyLong_AsLongLongAndOverflow(n, &overflow); + assert(overflow >= 0 && !PyErr_Occurred()); + if (!overflow && ki > 1) { + assert(ni >= 0); + result = perm_comb_small((unsigned long long)ni, + (unsigned long long)ki, 0); } - - result = n; - Py_INCREF(result); - if (factors == 1) { - goto done; - } - - factor = Py_NewRef(n); - PyObject *one = _PyLong_GetOne(); // borrowed ref - for (i = 1; i < factors; ++i) { - Py_SETREF(factor, PyNumber_Subtract(factor, one)); - if (factor == NULL) { - goto error; - } - Py_SETREF(result, PyNumber_Multiply(result, factor)); - if (result == NULL) { - goto error; - } + else { + result = perm_comb(n, (unsigned long long)ki, 0); } - Py_DECREF(factor); done: Py_DECREF(n); @@ -3324,14 +3441,11 @@ done: return result; error: - Py_XDECREF(factor); - Py_XDECREF(result); Py_DECREF(n); Py_DECREF(k); return NULL; } - /*[clinic input] math.comb @@ -3357,9 +3471,9 @@ static PyObject * math_comb_impl(PyObject *module, PyObject *n, PyObject *k) /*[clinic end generated code: output=bd2cec8d854f3493 input=9a05315af2518709]*/ { - PyObject *result = NULL, *factor = NULL, *temp; + PyObject *result = NULL, *temp; int overflow, cmp; - long long i, factors; + long long ki, ni; n = PyNumber_Index(n); if (n == NULL) { @@ -3370,6 +3484,7 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k) Py_DECREF(n); return NULL; } + assert(PyLong_CheckExact(n) && PyLong_CheckExact(k)); if (Py_SIZE(n) < 0) { PyErr_SetString(PyExc_ValueError, @@ -3382,73 +3497,59 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k) goto error; } - /* k = min(k, n - k) */ - temp = PyNumber_Subtract(n, k); - if (temp == NULL) { - goto error; - } - if (Py_SIZE(temp) < 0) { - Py_DECREF(temp); - result = PyLong_FromLong(0); - goto done; - } - cmp = PyObject_RichCompareBool(temp, k, Py_LT); - if (cmp > 0) { - Py_SETREF(k, temp); + ni = PyLong_AsLongLongAndOverflow(n, &overflow); + assert(overflow >= 0 && !PyErr_Occurred()); + if (!overflow) { + assert(ni >= 0); + ki = PyLong_AsLongLongAndOverflow(k, &overflow); + assert(overflow >= 0 && !PyErr_Occurred()); + if (overflow || ki > ni) { + result = PyLong_FromLong(0); + goto done; + } + assert(ki >= 0); + ki = Py_MIN(ki, ni - ki); + if (ki > 1) { + result = perm_comb_small((unsigned long long)ni, + (unsigned long long)ki, 1); + goto done; + } + /* For k == 1 just return the original n in perm_comb(). */ } else { - Py_DECREF(temp); - if (cmp < 0) { + /* k = min(k, n - k) */ + temp = PyNumber_Subtract(n, k); + if (temp == NULL) { goto error; } - } - - factors = PyLong_AsLongLongAndOverflow(k, &overflow); - if (overflow > 0) { - PyErr_Format(PyExc_OverflowError, - "min(n - k, k) must not exceed %lld", - LLONG_MAX); - goto error; - } - if (factors == -1) { - /* k is nonnegative, so a return value of -1 can only indicate error */ - goto error; - } - - if (factors == 0) { - result = PyLong_FromLong(1); - goto done; - } - - result = n; - Py_INCREF(result); - if (factors == 1) { - goto done; - } - - factor = Py_NewRef(n); - PyObject *one = _PyLong_GetOne(); // borrowed ref - for (i = 1; i < factors; ++i) { - Py_SETREF(factor, PyNumber_Subtract(factor, one)); - if (factor == NULL) { - goto error; + if (Py_SIZE(temp) < 0) { + Py_DECREF(temp); + result = PyLong_FromLong(0); + goto done; } - Py_SETREF(result, PyNumber_Multiply(result, factor)); - if (result == NULL) { - goto error; + cmp = PyObject_RichCompareBool(temp, k, Py_LT); + if (cmp > 0) { + Py_SETREF(k, temp); } - - temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1); - if (temp == NULL) { - goto error; + else { + Py_DECREF(temp); + if (cmp < 0) { + goto error; + } } - Py_SETREF(result, PyNumber_FloorDivide(result, temp)); - Py_DECREF(temp); - if (result == NULL) { + + ki = PyLong_AsLongLongAndOverflow(k, &overflow); + assert(overflow >= 0 && !PyErr_Occurred()); + if (overflow) { + PyErr_Format(PyExc_OverflowError, + "min(n - k, k) must not exceed %lld", + LLONG_MAX); goto error; } + assert(ki >= 0); } - Py_DECREF(factor); + + result = perm_comb(n, (unsigned long long)ki, 1); done: Py_DECREF(n); @@ -3456,8 +3557,6 @@ done: return result; error: - Py_XDECREF(factor); - Py_XDECREF(result); Py_DECREF(n); Py_DECREF(k); return NULL; |