diff options
author | Serhiy Storchaka <storchaka@gmail.com> | 2019-06-01 19:09:02 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-06-01 19:09:02 (GMT) |
commit | 2b843ac0ae745026ce39514573c5d075137bef65 (patch) | |
tree | 8e176372e55d171590b4c798d6deaf9311cbef8c /Modules/mathmodule.c | |
parent | 9843bc110dc4241ba7cb05f3d3ef74ac6c77caf2 (diff) | |
download | cpython-2b843ac0ae745026ce39514573c5d075137bef65.zip cpython-2b843ac0ae745026ce39514573c5d075137bef65.tar.gz cpython-2b843ac0ae745026ce39514573c5d075137bef65.tar.bz2 |
bpo-35431: Refactor math.comb() implementation. (GH-13725)
* Fixed some bugs.
* Added support for index-likes objects.
* Improved error messages.
* Cleaned up and optimized the code.
* Added more tests.
Diffstat (limited to 'Modules/mathmodule.c')
-rw-r--r-- | Modules/mathmodule.c | 155 |
1 files changed, 81 insertions, 74 deletions
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index 007a880..bea4607 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -3001,10 +3001,11 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start) /*[clinic input] math.comb - n: object(subclass_of='&PyLong_Type') - k: object(subclass_of='&PyLong_Type') + n: object + k: object + / -Number of ways to choose *k* items from *n* items without repetition and without order. +Number of ways to choose k items from n items without repetition and without order. Also called the binomial coefficient. It is mathematically equal to the expression n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in @@ -3017,103 +3018,109 @@ Raises ValueError if the arguments are negative or if k > n. static PyObject * math_comb_impl(PyObject *module, PyObject *n, PyObject *k) -/*[clinic end generated code: output=bd2cec8d854f3493 input=565f340f98efb5b5]*/ +/*[clinic end generated code: output=bd2cec8d854f3493 input=2f336ac9ec8242f9]*/ { - PyObject *val = NULL, - *temp_obj1 = NULL, - *temp_obj2 = NULL, - *dump_var = NULL; + PyObject *result = NULL, *factor = NULL, *temp; int overflow, cmp; - long long i, terms; + long long i, factors; - cmp = PyObject_RichCompareBool(n, k, Py_LT); - if (cmp < 0) { - goto fail_comb; + n = PyNumber_Index(n); + if (n == NULL) { + return NULL; } - else if (cmp > 0) { - PyErr_Format(PyExc_ValueError, - "n must be an integer greater than or equal to k"); - goto fail_comb; + k = PyNumber_Index(k); + if (k == NULL) { + Py_DECREF(n); + return NULL; } - /* b = min(b, a - b) */ - dump_var = PyNumber_Subtract(n, k); - if (dump_var == NULL) { - goto fail_comb; + if (Py_SIZE(n) < 0) { + PyErr_SetString(PyExc_ValueError, + "n must be a non-negative integer"); + goto error; } - cmp = PyObject_RichCompareBool(k, dump_var, Py_GT); - if (cmp < 0) { - goto fail_comb; + /* k = min(k, n - k) */ + temp = PyNumber_Subtract(n, k); + if (temp == NULL) { + goto error; } - else if (cmp > 0) { - k = dump_var; - dump_var = NULL; + if (Py_SIZE(temp) < 0) { + Py_DECREF(temp); + PyErr_SetString(PyExc_ValueError, + "k must be an integer less than or equal to n"); + goto error; + } + cmp = PyObject_RichCompareBool(k, temp, Py_GT); + if (cmp > 0) { + Py_SETREF(k, temp); } else { - Py_DECREF(dump_var); - dump_var = NULL; + Py_DECREF(temp); + if (cmp < 0) { + goto error; + } } - terms = PyLong_AsLongLongAndOverflow(k, &overflow); - if (terms < 0 && PyErr_Occurred()) { - goto fail_comb; - } - else if (overflow > 0) { + factors = PyLong_AsLongLongAndOverflow(k, &overflow); + if (overflow > 0) { PyErr_Format(PyExc_OverflowError, - "minimum(n - k, k) must not exceed %lld", + "min(n - k, k) must not exceed %lld", LLONG_MAX); - goto fail_comb; + goto error; } - else if (overflow < 0 || terms < 0) { - PyErr_Format(PyExc_ValueError, - "k must be a positive integer"); - goto fail_comb; + else if (overflow < 0 || factors < 0) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_ValueError, + "k must be a non-negative integer"); + } + goto error; } - if (terms == 0) { - return PyNumber_Long(_PyLong_One); + if (factors == 0) { + result = PyLong_FromLong(1); + goto done; } - val = PyNumber_Long(n); - for (i = 1; i < terms; ++i) { - temp_obj1 = PyLong_FromSsize_t(i); - if (temp_obj1 == NULL) { - goto fail_comb; - } - temp_obj2 = PyNumber_Subtract(n, temp_obj1); - if (temp_obj2 == NULL) { - goto fail_comb; + result = n; + Py_INCREF(result); + if (factors == 1) { + goto done; + } + + factor = n; + Py_INCREF(factor); + for (i = 1; i < factors; ++i) { + Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One)); + if (factor == NULL) { + goto error; } - dump_var = val; - val = PyNumber_Multiply(val, temp_obj2); - if (val == NULL) { - goto fail_comb; + Py_SETREF(result, PyNumber_Multiply(result, factor)); + if (result == NULL) { + goto error; } - Py_DECREF(dump_var); - dump_var = NULL; - Py_DECREF(temp_obj2); - temp_obj2 = PyLong_FromUnsignedLongLong((unsigned long long)(i + 1)); - if (temp_obj2 == NULL) { - goto fail_comb; + + temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1); + if (temp == NULL) { + goto error; } - dump_var = val; - val = PyNumber_FloorDivide(val, temp_obj2); - if (val == NULL) { - goto fail_comb; + Py_SETREF(result, PyNumber_FloorDivide(result, temp)); + Py_DECREF(temp); + if (result == NULL) { + goto error; } - Py_DECREF(dump_var); - Py_DECREF(temp_obj1); - Py_DECREF(temp_obj2); } + Py_DECREF(factor); - return val; - -fail_comb: - Py_XDECREF(val); - Py_XDECREF(dump_var); - Py_XDECREF(temp_obj1); - Py_XDECREF(temp_obj2); +done: + Py_DECREF(n); + Py_DECREF(k); + return result; +error: + Py_XDECREF(factor); + Py_XDECREF(result); + Py_DECREF(n); + Py_DECREF(k); return NULL; } |