diff options
-rw-r--r-- | Lib/_pylong.py | 295 | ||||
-rw-r--r-- | Lib/test/test_int.py | 47 | ||||
-rw-r--r-- | Misc/NEWS.d/next/Core and Builtins/2022-09-09-16-32-58.gh-issue-90716.z4yuYq.rst | 3 | ||||
-rw-r--r-- | Objects/longobject.c | 170 | ||||
-rw-r--r-- | Python/stdlib_module_names.h | 1 | ||||
-rw-r--r-- | Tools/scripts/divmod_threshold.py | 56 |
6 files changed, 572 insertions, 0 deletions
diff --git a/Lib/_pylong.py b/Lib/_pylong.py new file mode 100644 index 0000000..e0d4e90 --- /dev/null +++ b/Lib/_pylong.py @@ -0,0 +1,295 @@ +"""Python implementations of some algorithms for use by longobject.c. +The goal is to provide asymptotically faster algorithms that can be +used for operations on integers with many digits. In those cases, the +performance overhead of the Python implementation is not significant +since the asymptotic behavior is what dominates runtime. Functions +provided by this module should be considered private and not part of any +public API. + +Note: for ease of maintainability, please prefer clear code and avoid +"micro-optimizations". This module will only be imported and used for +integers with a huge number of digits. Saving a few microseconds with +tricky or non-obvious code is not worth it. For people looking for +maximum performance, they should use something like gmpy2.""" + +import sys +import re +import decimal + +_DEBUG = False + + +def int_to_decimal(n): + """Asymptotically fast conversion of an 'int' to Decimal.""" + + # Function due to Tim Peters. See GH issue #90716 for details. + # https://github.com/python/cpython/issues/90716 + # + # The implementation in longobject.c of base conversion algorithms + # between power-of-2 and non-power-of-2 bases are quadratic time. + # This function implements a divide-and-conquer algorithm that is + # faster for large numbers. Builds an equal decimal.Decimal in a + # "clever" recursive way. If we want a string representation, we + # apply str to _that_. + + if _DEBUG: + print('int_to_decimal', n.bit_length(), file=sys.stderr) + + D = decimal.Decimal + D2 = D(2) + + BITLIM = 128 + + mem = {} + + def w2pow(w): + """Return D(2)**w and store the result. Also possibly save some + intermediate results. In context, these are likely to be reused + across various levels of the conversion to Decimal.""" + if (result := mem.get(w)) is None: + if w <= BITLIM: + result = D2**w + elif w - 1 in mem: + result = (t := mem[w - 1]) + t + else: + w2 = w >> 1 + # If w happens to be odd, w-w2 is one larger then w2 + # now. Recurse on the smaller first (w2), so that it's + # in the cache and the larger (w-w2) can be handled by + # the cheaper `w-1 in mem` branch instead. + result = w2pow(w2) * w2pow(w - w2) + mem[w] = result + return result + + def inner(n, w): + if w <= BITLIM: + return D(n) + w2 = w >> 1 + hi = n >> w2 + lo = n - (hi << w2) + return inner(lo, w2) + inner(hi, w - w2) * w2pow(w2) + + with decimal.localcontext() as ctx: + ctx.prec = decimal.MAX_PREC + ctx.Emax = decimal.MAX_EMAX + ctx.Emin = decimal.MIN_EMIN + ctx.traps[decimal.Inexact] = 1 + + if n < 0: + negate = True + n = -n + else: + negate = False + result = inner(n, n.bit_length()) + if negate: + result = -result + return result + + +def int_to_decimal_string(n): + """Asymptotically fast conversion of an 'int' to a decimal string.""" + return str(int_to_decimal(n)) + + +def _str_to_int_inner(s): + """Asymptotically fast conversion of a 'str' to an 'int'.""" + + # Function due to Bjorn Martinsson. See GH issue #90716 for details. + # https://github.com/python/cpython/issues/90716 + # + # The implementation in longobject.c of base conversion algorithms + # between power-of-2 and non-power-of-2 bases are quadratic time. + # This function implements a divide-and-conquer algorithm making use + # of Python's built in big int multiplication. Since Python uses the + # Karatsuba algorithm for multiplication, the time complexity + # of this function is O(len(s)**1.58). + + DIGLIM = 2048 + + mem = {} + + def w5pow(w): + """Return 5**w and store the result. + Also possibly save some intermediate results. In context, these + are likely to be reused across various levels of the conversion + to 'int'. + """ + if (result := mem.get(w)) is None: + if w <= DIGLIM: + result = 5**w + elif w - 1 in mem: + result = mem[w - 1] * 5 + else: + w2 = w >> 1 + # If w happens to be odd, w-w2 is one larger then w2 + # now. Recurse on the smaller first (w2), so that it's + # in the cache and the larger (w-w2) can be handled by + # the cheaper `w-1 in mem` branch instead. + result = w5pow(w2) * w5pow(w - w2) + mem[w] = result + return result + + def inner(a, b): + if b - a <= DIGLIM: + return int(s[a:b]) + mid = (a + b + 1) >> 1 + return inner(mid, b) + ((inner(a, mid) * w5pow(b - mid)) << (b - mid)) + + return inner(0, len(s)) + + +def int_from_string(s): + """Asymptotically fast version of PyLong_FromString(), conversion + of a string of decimal digits into an 'int'.""" + if _DEBUG: + print('int_from_string', len(s), file=sys.stderr) + # PyLong_FromString() has already removed leading +/-, checked for invalid + # use of underscore characters, checked that string consists of only digits + # and underscores, and stripped leading whitespace. The input can still + # contain underscores and have trailing whitespace. + s = s.rstrip().replace('_', '') + return _str_to_int_inner(s) + + +def str_to_int(s): + """Asymptotically fast version of decimal string to 'int' conversion.""" + # FIXME: this doesn't support the full syntax that int() supports. + m = re.match(r'\s*([+-]?)([0-9_]+)\s*', s) + if not m: + raise ValueError('invalid literal for int() with base 10') + v = int_from_string(m.group(2)) + if m.group(1) == '-': + v = -v + return v + + +# Fast integer division, based on code from Mark Dickinson, fast_div.py +# GH-47701. Additional refinements and optimizations by Bjorn Martinsson. The +# algorithm is due to Burnikel and Ziegler, in their paper "Fast Recursive +# Division". + +_DIV_LIMIT = 4000 + + +def _div2n1n(a, b, n): + """Divide a 2n-bit nonnegative integer a by an n-bit positive integer + b, using a recursive divide-and-conquer algorithm. + + Inputs: + n is a positive integer + b is a positive integer with exactly n bits + a is a nonnegative integer such that a < 2**n * b + + Output: + (q, r) such that a = b*q+r and 0 <= r < b. + + """ + if a.bit_length() - n <= _DIV_LIMIT: + return divmod(a, b) + pad = n & 1 + if pad: + a <<= 1 + b <<= 1 + n += 1 + half_n = n >> 1 + mask = (1 << half_n) - 1 + b1, b2 = b >> half_n, b & mask + q1, r = _div3n2n(a >> n, (a >> half_n) & mask, b, b1, b2, half_n) + q2, r = _div3n2n(r, a & mask, b, b1, b2, half_n) + if pad: + r >>= 1 + return q1 << half_n | q2, r + + +def _div3n2n(a12, a3, b, b1, b2, n): + """Helper function for _div2n1n; not intended to be called directly.""" + if a12 >> n == b1: + q, r = (1 << n) - 1, a12 - (b1 << n) + b1 + else: + q, r = _div2n1n(a12, b1, n) + r = (r << n | a3) - q * b2 + while r < 0: + q -= 1 + r += b + return q, r + + +def _int2digits(a, n): + """Decompose non-negative int a into base 2**n + + Input: + a is a non-negative integer + + Output: + List of the digits of a in base 2**n in little-endian order, + meaning the most significant digit is last. The most + significant digit is guaranteed to be non-zero. + If a is 0 then the output is an empty list. + + """ + a_digits = [0] * ((a.bit_length() + n - 1) // n) + + def inner(x, L, R): + if L + 1 == R: + a_digits[L] = x + return + mid = (L + R) >> 1 + shift = (mid - L) * n + upper = x >> shift + lower = x ^ (upper << shift) + inner(lower, L, mid) + inner(upper, mid, R) + + if a: + inner(a, 0, len(a_digits)) + return a_digits + + +def _digits2int(digits, n): + """Combine base-2**n digits into an int. This function is the + inverse of `_int2digits`. For more details, see _int2digits. + """ + + def inner(L, R): + if L + 1 == R: + return digits[L] + mid = (L + R) >> 1 + shift = (mid - L) * n + return (inner(mid, R) << shift) + inner(L, mid) + + return inner(0, len(digits)) if digits else 0 + + +def _divmod_pos(a, b): + """Divide a non-negative integer a by a positive integer b, giving + quotient and remainder.""" + # Use grade-school algorithm in base 2**n, n = nbits(b) + n = b.bit_length() + a_digits = _int2digits(a, n) + + r = 0 + q_digits = [] + for a_digit in reversed(a_digits): + q_digit, r = _div2n1n((r << n) + a_digit, b, n) + q_digits.append(q_digit) + q_digits.reverse() + q = _digits2int(q_digits, n) + return q, r + + +def int_divmod(a, b): + """Asymptotically fast replacement for divmod, for 'int'. + Its time complexity is O(n**1.58), where n = #bits(a) + #bits(b). + """ + if _DEBUG: + print('int_divmod', a.bit_length(), b.bit_length(), file=sys.stderr) + if b == 0: + raise ZeroDivisionError + elif b < 0: + q, r = int_divmod(-a, -b) + return q, -r + elif a < 0: + q, r = int_divmod(~a, b) + return ~q, b + ~r + else: + return _divmod_pos(a, b) diff --git a/Lib/test/test_int.py b/Lib/test/test_int.py index 625c388..f484c59 100644 --- a/Lib/test/test_int.py +++ b/Lib/test/test_int.py @@ -795,5 +795,52 @@ class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests): int_class = IntSubclass +class PyLongModuleTests(unittest.TestCase): + # Tests of the functions in _pylong.py. Those get used when the + # number of digits in the input values are large enough. + + def setUp(self): + super().setUp() + self._previous_limit = sys.get_int_max_str_digits() + sys.set_int_max_str_digits(0) + + def tearDown(self): + sys.set_int_max_str_digits(self._previous_limit) + super().tearDown() + + def test_pylong_int_to_decimal(self): + n = (1 << 100_000) - 1 + suffix = '9883109375' + s = str(n) + assert s[-10:] == suffix + s = str(-n) + assert s[-10:] == suffix + s = '%d' % n + assert s[-10:] == suffix + s = b'%d' % n + assert s[-10:] == suffix.encode('ascii') + + def test_pylong_int_divmod(self): + n = (1 << 100_000) + a, b = divmod(n*3 + 1, n) + assert a == 3 and b == 1 + + def test_pylong_str_to_int(self): + v1 = 1 << 100_000 + s = str(v1) + v2 = int(s) + assert v1 == v2 + v3 = int(' -' + s) + assert -v1 == v3 + v4 = int(' +' + s + ' ') + assert v1 == v4 + with self.assertRaises(ValueError) as err: + int(s + 'z') + with self.assertRaises(ValueError) as err: + int(s + '_') + with self.assertRaises(ValueError) as err: + int('_' + s) + + if __name__ == "__main__": unittest.main() diff --git a/Misc/NEWS.d/next/Core and Builtins/2022-09-09-16-32-58.gh-issue-90716.z4yuYq.rst b/Misc/NEWS.d/next/Core and Builtins/2022-09-09-16-32-58.gh-issue-90716.z4yuYq.rst new file mode 100644 index 0000000..d04e2c0 --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2022-09-09-16-32-58.gh-issue-90716.z4yuYq.rst @@ -0,0 +1,3 @@ +Add _pylong.py module. It includes asymptotically faster algorithms that +can be used for operations on integers with many digits. It is used by +longobject.c to speed up some operations. diff --git a/Objects/longobject.c b/Objects/longobject.c index 304fabf..cf859cc 100644 --- a/Objects/longobject.c +++ b/Objects/longobject.c @@ -39,6 +39,9 @@ medium_value(PyLongObject *x) #define _MAX_STR_DIGITS_ERROR_FMT_TO_INT "Exceeds the limit (%d) for integer string conversion: value has %zd digits; use sys.set_int_max_str_digits() to increase the limit" #define _MAX_STR_DIGITS_ERROR_FMT_TO_STR "Exceeds the limit (%d) for integer string conversion; use sys.set_int_max_str_digits() to increase the limit" +/* If defined, use algorithms from the _pylong.py module */ +#define WITH_PYLONG_MODULE 1 + static inline void _Py_DECREF_INT(PyLongObject *op) { @@ -1732,6 +1735,69 @@ rem1(PyLongObject *a, digit n) ); } +#ifdef WITH_PYLONG_MODULE +/* asymptotically faster long_to_decimal_string, using _pylong.py */ +static int +pylong_int_to_decimal_string(PyObject *aa, + PyObject **p_output, + _PyUnicodeWriter *writer, + _PyBytesWriter *bytes_writer, + char **bytes_str) +{ + PyObject *s = NULL; + PyObject *mod = PyImport_ImportModule("_pylong"); + if (mod == NULL) { + return -1; + } + s = PyObject_CallMethod(mod, "int_to_decimal_string", "O", aa); + if (s == NULL) { + goto error; + } + assert(PyUnicode_Check(s)); + if (writer) { + Py_ssize_t size = PyUnicode_GET_LENGTH(s); + if (_PyUnicodeWriter_Prepare(writer, size, '9') == -1) { + goto error; + } + if (_PyUnicodeWriter_WriteStr(writer, s) < 0) { + goto error; + } + goto success; + } + else if (bytes_writer) { + Py_ssize_t size = PyUnicode_GET_LENGTH(s); + const void *data = PyUnicode_DATA(s); + int kind = PyUnicode_KIND(s); + *bytes_str = _PyBytesWriter_Prepare(bytes_writer, *bytes_str, size); + if (*bytes_str == NULL) { + goto error; + } + char *p = *bytes_str; + for (Py_ssize_t i=0; i < size; i++) { + Py_UCS4 ch = PyUnicode_READ(kind, data, i); + *p++ = (char) ch; + } + (*bytes_str) = p; + goto success; + } + else { + *p_output = (PyObject *)s; + Py_INCREF(s); + goto success; + } + +error: + Py_DECREF(mod); + Py_XDECREF(s); + return -1; + +success: + Py_DECREF(mod); + Py_DECREF(s); + return 0; +} +#endif /* WITH_PYLONG_MODULE */ + /* Convert an integer to a base 10 string. Returns a new non-shared string. (Return value is non-shared so that callers can modify the returned value if necessary.) */ @@ -1776,6 +1842,17 @@ long_to_decimal_string_internal(PyObject *aa, } } +#if WITH_PYLONG_MODULE + if (size_a > 1000) { + /* Switch to _pylong.int_to_decimal_string(). */ + return pylong_int_to_decimal_string(aa, + p_output, + writer, + bytes_writer, + bytes_str); + } +#endif + /* quick and dirty upper bound for the number of digits required to express a in base _PyLong_DECIMAL_BASE: @@ -2272,6 +2349,39 @@ long_from_binary_base(const char *start, const char *end, Py_ssize_t digits, int return 0; } +static PyObject *long_neg(PyLongObject *v); + +#ifdef WITH_PYLONG_MODULE +/* asymptotically faster str-to-long conversion for base 10, using _pylong.py */ +static int +pylong_int_from_string(const char *start, const char *end, PyLongObject **res) +{ + PyObject *mod = PyImport_ImportModule("_pylong"); + if (mod == NULL) { + goto error; + } + PyObject *s = PyUnicode_FromStringAndSize(start, end-start); + if (s == NULL) { + goto error; + } + PyObject *result = PyObject_CallMethod(mod, "int_from_string", "O", s); + Py_DECREF(s); + Py_DECREF(mod); + if (result == NULL) { + goto error; + } + if (!PyLong_Check(result)) { + PyErr_SetString(PyExc_TypeError, "an integer is required"); + goto error; + } + *res = (PyLongObject *)result; + return 0; +error: + *res = NULL; + return 0; +} +#endif /* WITH_PYLONG_MODULE */ + /*** long_from_non_binary_base: parameters and return values are the same as long_from_binary_base. @@ -2586,6 +2696,12 @@ long_from_string_base(const char **str, int base, PyLongObject **res) return 0; } } +#if WITH_PYLONG_MODULE + if (digits > 6000 && base == 10) { + /* Switch to _pylong.int_from_string() */ + return pylong_int_from_string(start, end, res); + } +#endif /* Use the quadratic algorithm for non binary bases. */ return long_from_non_binary_base(start, end, digits, base, res); } @@ -3913,6 +4029,48 @@ fast_floor_div(PyLongObject *a, PyLongObject *b) return PyLong_FromLong(div); } +#ifdef WITH_PYLONG_MODULE +/* asymptotically faster divmod, using _pylong.py */ +static int +pylong_int_divmod(PyLongObject *v, PyLongObject *w, + PyLongObject **pdiv, PyLongObject **pmod) +{ + PyObject *mod = PyImport_ImportModule("_pylong"); + if (mod == NULL) { + return -1; + } + PyObject *result = PyObject_CallMethod(mod, "int_divmod", "OO", v, w); + Py_DECREF(mod); + if (result == NULL) { + return -1; + } + if (!PyTuple_Check(result)) { + Py_DECREF(result); + PyErr_SetString(PyExc_ValueError, + "tuple is required from int_divmod()"); + return -1; + } + PyObject *q = PyTuple_GET_ITEM(result, 0); + PyObject *r = PyTuple_GET_ITEM(result, 1); + if (!PyLong_Check(q) || !PyLong_Check(r)) { + Py_DECREF(result); + PyErr_SetString(PyExc_ValueError, + "tuple of int is required from int_divmod()"); + return -1; + } + if (pdiv != NULL) { + Py_INCREF(q); + *pdiv = (PyLongObject *)q; + } + if (pmod != NULL) { + Py_INCREF(r); + *pmod = (PyLongObject *)r; + } + Py_DECREF(result); + return 0; +} +#endif /* WITH_PYLONG_MODULE */ + /* The / and % operators are now defined in terms of divmod(). The expression a mod b has the value a - b*floor(a/b). The long_divrem function gives the remainder after division of @@ -3964,6 +4122,18 @@ l_divmod(PyLongObject *v, PyLongObject *w, } return 0; } +#if WITH_PYLONG_MODULE + Py_ssize_t size_v = Py_ABS(Py_SIZE(v)); /* digits in numerator */ + Py_ssize_t size_w = Py_ABS(Py_SIZE(w)); /* digits in denominator */ + if (size_w > 300 && (size_v - size_w) > 150) { + /* Switch to _pylong.int_divmod(). If the quotient is small then + "schoolbook" division is linear-time so don't use in that case. + These limits are empirically determined and should be slightly + conservative so that _pylong is used in cases it is likely + to be faster. See Tools/scripts/divmod_threshold.py. */ + return pylong_int_divmod(v, w, pdiv, pmod); + } +#endif if (long_divrem(v, w, &div, &mod) < 0) return -1; if ((Py_SIZE(mod) < 0 && Py_SIZE(w) > 0) || diff --git a/Python/stdlib_module_names.h b/Python/stdlib_module_names.h index 12827e7..c31bf1e 100644 --- a/Python/stdlib_module_names.h +++ b/Python/stdlib_module_names.h @@ -58,6 +58,7 @@ static const char* _Py_stdlib_module_names[] = { "_py_abc", "_pydecimal", "_pyio", +"_pylong", "_queue", "_random", "_scproxy", diff --git a/Tools/scripts/divmod_threshold.py b/Tools/scripts/divmod_threshold.py new file mode 100644 index 0000000..69819fe --- /dev/null +++ b/Tools/scripts/divmod_threshold.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# +# Determine threshold for switching from longobject.c divmod to +# _pylong.int_divmod(). + +from random import randrange +from time import perf_counter as now +from _pylong import int_divmod as divmod_fast + +BITS_PER_DIGIT = 30 + + +def rand_digits(n): + top = 1 << (n * BITS_PER_DIGIT) + return randrange(top >> 1, top) + + +def probe_den(nd): + den = rand_digits(nd) + count = 0 + for nn in range(nd, nd + 3000): + num = rand_digits(nn) + t0 = now() + e1, e2 = divmod(num, den) + t1 = now() + f1, f2 = divmod_fast(num, den) + t2 = now() + s1 = t1 - t0 + s2 = t2 - t1 + assert e1 == f1 + assert e2 == f2 + if s2 < s1: + count += 1 + if count >= 3: + print( + "for", + nd, + "denom digits,", + nn - nd, + "extra num digits is enough", + ) + break + else: + count = 0 + else: + print("for", nd, "denom digits, no num seems big enough") + + +def main(): + for nd in range(30): + nd = (nd + 1) * 100 + probe_den(nd) + + +if __name__ == '__main__': + main() |