From 2b843ac0ae745026ce39514573c5d075137bef65 Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sat, 1 Jun 2019 22:09:02 +0300 Subject: bpo-35431: Refactor math.comb() implementation. (GH-13725) * Fixed some bugs. * Added support for index-likes objects. * Improved error messages. * Cleaned up and optimized the code. * Added more tests. --- Doc/library/math.rst | 4 +- Lib/test/test_math.py | 29 ++++++-- Modules/clinic/mathmodule.c.h | 24 ++----- Modules/mathmodule.c | 155 ++++++++++++++++++++++-------------------- 4 files changed, 111 insertions(+), 101 deletions(-) diff --git a/Doc/library/math.rst b/Doc/library/math.rst index 5243970..206b06e 100644 --- a/Doc/library/math.rst +++ b/Doc/library/math.rst @@ -238,11 +238,11 @@ Number-theoretic and representation functions and without order. Also called the binomial coefficient. It is mathematically equal to the expression - ``n! / (k! (n - k)!)``. It is equivalent to the coefficient of k-th term in + ``n! / (k! (n - k)!)``. It is equivalent to the coefficient of the *k*-th term in the polynomial expansion of the expression ``(1 + x) ** n``. Raises :exc:`TypeError` if the arguments not integers. - Raises :exc:`ValueError` if the arguments are negative or if k > n. + Raises :exc:`ValueError` if the arguments are negative or if *k* > *n*. .. versionadded:: 3.8 diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py index 9da7f7c..e27092e 100644 --- a/Lib/test/test_math.py +++ b/Lib/test/test_math.py @@ -1893,9 +1893,11 @@ class IsCloseTests(unittest.TestCase): # Raises TypeError if any argument is non-integer or argument count is # not 2 self.assertRaises(TypeError, comb, 10, 1.0) + self.assertRaises(TypeError, comb, 10, decimal.Decimal(1.0)) self.assertRaises(TypeError, comb, 10, "1") - self.assertRaises(TypeError, comb, "10", 1) self.assertRaises(TypeError, comb, 10.0, 1) + self.assertRaises(TypeError, comb, decimal.Decimal(10.0), 1) + self.assertRaises(TypeError, comb, "10", 1) self.assertRaises(TypeError, comb, 10) self.assertRaises(TypeError, comb, 10, 1, 3) @@ -1903,15 +1905,28 @@ class IsCloseTests(unittest.TestCase): # Raises Value error if not k or n are negative numbers self.assertRaises(ValueError, comb, -1, 1) - self.assertRaises(ValueError, comb, -10*10, 1) + self.assertRaises(ValueError, comb, -2**1000, 1) self.assertRaises(ValueError, comb, 1, -1) - self.assertRaises(ValueError, comb, 1, -10*10) + self.assertRaises(ValueError, comb, 1, -2**1000) # Raises value error if k is greater than n - self.assertRaises(ValueError, comb, 1, 10**10) - self.assertRaises(ValueError, comb, 0, 1) - - + self.assertRaises(ValueError, comb, 1, 2) + self.assertRaises(ValueError, comb, 1, 2**1000) + + n = 2**1000 + self.assertEqual(comb(n, 0), 1) + self.assertEqual(comb(n, 1), n) + self.assertEqual(comb(n, 2), n * (n-1) // 2) + self.assertEqual(comb(n, n), 1) + self.assertEqual(comb(n, n-1), n) + self.assertEqual(comb(n, n-2), n * (n-1) // 2) + self.assertRaises((OverflowError, MemoryError), comb, n, n//2) + + for n, k in (True, True), (True, False), (False, False): + self.assertEqual(comb(n, k), 1) + self.assertIs(type(comb(n, k)), int) + self.assertEqual(comb(MyIndexable(5), MyIndexable(2)), 10) + self.assertIs(type(comb(MyIndexable(5), MyIndexable(2))), int) def test_main(): diff --git a/Modules/clinic/mathmodule.c.h b/Modules/clinic/mathmodule.c.h index cba791e..92ec4be 100644 --- a/Modules/clinic/mathmodule.c.h +++ b/Modules/clinic/mathmodule.c.h @@ -639,10 +639,10 @@ exit: } PyDoc_STRVAR(math_comb__doc__, -"comb($module, /, n, k)\n" +"comb($module, n, k, /)\n" "--\n" "\n" -"Number of ways to choose *k* items from *n* items without repetition and without order.\n" +"Number of ways to choose k items from n items without repetition and without order.\n" "\n" "Also called the binomial coefficient. It is mathematically equal to the expression\n" "n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in\n" @@ -652,38 +652,26 @@ PyDoc_STRVAR(math_comb__doc__, "Raises ValueError if the arguments are negative or if k > n."); #define MATH_COMB_METHODDEF \ - {"comb", (PyCFunction)(void(*)(void))math_comb, METH_FASTCALL|METH_KEYWORDS, math_comb__doc__}, + {"comb", (PyCFunction)(void(*)(void))math_comb, METH_FASTCALL, math_comb__doc__}, static PyObject * math_comb_impl(PyObject *module, PyObject *n, PyObject *k); static PyObject * -math_comb(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) +math_comb(PyObject *module, PyObject *const *args, Py_ssize_t nargs) { PyObject *return_value = NULL; - static const char * const _keywords[] = {"n", "k", NULL}; - static _PyArg_Parser _parser = {NULL, _keywords, "comb", 0}; - PyObject *argsbuf[2]; PyObject *n; PyObject *k; - args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 2, 2, 0, argsbuf); - if (!args) { - goto exit; - } - if (!PyLong_Check(args[0])) { - _PyArg_BadArgument("comb", 1, "int", args[0]); + if (!_PyArg_CheckPositional("comb", nargs, 2, 2)) { goto exit; } n = args[0]; - if (!PyLong_Check(args[1])) { - _PyArg_BadArgument("comb", 2, "int", args[1]); - goto exit; - } k = args[1]; return_value = math_comb_impl(module, n, k); exit: return return_value; } -/*[clinic end generated code: output=00aa76356759617a input=a9049054013a1b77]*/ +/*[clinic end generated code: output=6709521e5e1d90ec input=a9049054013a1b77]*/ diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index 007a880..bea4607 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -3001,10 +3001,11 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start) /*[clinic input] math.comb - n: object(subclass_of='&PyLong_Type') - k: object(subclass_of='&PyLong_Type') + n: object + k: object + / -Number of ways to choose *k* items from *n* items without repetition and without order. +Number of ways to choose k items from n items without repetition and without order. Also called the binomial coefficient. It is mathematically equal to the expression n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in @@ -3017,103 +3018,109 @@ Raises ValueError if the arguments are negative or if k > n. static PyObject * math_comb_impl(PyObject *module, PyObject *n, PyObject *k) -/*[clinic end generated code: output=bd2cec8d854f3493 input=565f340f98efb5b5]*/ +/*[clinic end generated code: output=bd2cec8d854f3493 input=2f336ac9ec8242f9]*/ { - PyObject *val = NULL, - *temp_obj1 = NULL, - *temp_obj2 = NULL, - *dump_var = NULL; + PyObject *result = NULL, *factor = NULL, *temp; int overflow, cmp; - long long i, terms; + long long i, factors; - cmp = PyObject_RichCompareBool(n, k, Py_LT); - if (cmp < 0) { - goto fail_comb; + n = PyNumber_Index(n); + if (n == NULL) { + return NULL; } - else if (cmp > 0) { - PyErr_Format(PyExc_ValueError, - "n must be an integer greater than or equal to k"); - goto fail_comb; + k = PyNumber_Index(k); + if (k == NULL) { + Py_DECREF(n); + return NULL; } - /* b = min(b, a - b) */ - dump_var = PyNumber_Subtract(n, k); - if (dump_var == NULL) { - goto fail_comb; + if (Py_SIZE(n) < 0) { + PyErr_SetString(PyExc_ValueError, + "n must be a non-negative integer"); + goto error; } - cmp = PyObject_RichCompareBool(k, dump_var, Py_GT); - if (cmp < 0) { - goto fail_comb; + /* k = min(k, n - k) */ + temp = PyNumber_Subtract(n, k); + if (temp == NULL) { + goto error; } - else if (cmp > 0) { - k = dump_var; - dump_var = NULL; + if (Py_SIZE(temp) < 0) { + Py_DECREF(temp); + PyErr_SetString(PyExc_ValueError, + "k must be an integer less than or equal to n"); + goto error; + } + cmp = PyObject_RichCompareBool(k, temp, Py_GT); + if (cmp > 0) { + Py_SETREF(k, temp); } else { - Py_DECREF(dump_var); - dump_var = NULL; + Py_DECREF(temp); + if (cmp < 0) { + goto error; + } } - terms = PyLong_AsLongLongAndOverflow(k, &overflow); - if (terms < 0 && PyErr_Occurred()) { - goto fail_comb; - } - else if (overflow > 0) { + factors = PyLong_AsLongLongAndOverflow(k, &overflow); + if (overflow > 0) { PyErr_Format(PyExc_OverflowError, - "minimum(n - k, k) must not exceed %lld", + "min(n - k, k) must not exceed %lld", LLONG_MAX); - goto fail_comb; + goto error; } - else if (overflow < 0 || terms < 0) { - PyErr_Format(PyExc_ValueError, - "k must be a positive integer"); - goto fail_comb; + else if (overflow < 0 || factors < 0) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_ValueError, + "k must be a non-negative integer"); + } + goto error; } - if (terms == 0) { - return PyNumber_Long(_PyLong_One); + if (factors == 0) { + result = PyLong_FromLong(1); + goto done; } - val = PyNumber_Long(n); - for (i = 1; i < terms; ++i) { - temp_obj1 = PyLong_FromSsize_t(i); - if (temp_obj1 == NULL) { - goto fail_comb; - } - temp_obj2 = PyNumber_Subtract(n, temp_obj1); - if (temp_obj2 == NULL) { - goto fail_comb; + result = n; + Py_INCREF(result); + if (factors == 1) { + goto done; + } + + factor = n; + Py_INCREF(factor); + for (i = 1; i < factors; ++i) { + Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One)); + if (factor == NULL) { + goto error; } - dump_var = val; - val = PyNumber_Multiply(val, temp_obj2); - if (val == NULL) { - goto fail_comb; + Py_SETREF(result, PyNumber_Multiply(result, factor)); + if (result == NULL) { + goto error; } - Py_DECREF(dump_var); - dump_var = NULL; - Py_DECREF(temp_obj2); - temp_obj2 = PyLong_FromUnsignedLongLong((unsigned long long)(i + 1)); - if (temp_obj2 == NULL) { - goto fail_comb; + + temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1); + if (temp == NULL) { + goto error; } - dump_var = val; - val = PyNumber_FloorDivide(val, temp_obj2); - if (val == NULL) { - goto fail_comb; + Py_SETREF(result, PyNumber_FloorDivide(result, temp)); + Py_DECREF(temp); + if (result == NULL) { + goto error; } - Py_DECREF(dump_var); - Py_DECREF(temp_obj1); - Py_DECREF(temp_obj2); } + Py_DECREF(factor); - return val; - -fail_comb: - Py_XDECREF(val); - Py_XDECREF(dump_var); - Py_XDECREF(temp_obj1); - Py_XDECREF(temp_obj2); +done: + Py_DECREF(n); + Py_DECREF(k); + return result; +error: + Py_XDECREF(factor); + Py_XDECREF(result); + Py_DECREF(n); + Py_DECREF(k); return NULL; } -- cgit v0.12