From 2f0a338be66e276a700f86af51d5ef54a138cbfb Mon Sep 17 00:00:00 2001 From: Tim Peters Date: Tue, 7 May 2024 19:09:09 -0500 Subject: gh-118610: Centralize power caching in `_pylong.py` (#118611) A new `compute_powers()` function computes all and only the powers of the base the various base-conversion functions need, as efficiently as reasonably possible (turns out that invoking `**`is needed at most once). This typically gives a few % speedup, but the primary point is to simplify the base-conversion functions, which no longer need their own, ad hoc, and less efficient power-caching schemes. Co-authored-by: Serhiy Storchaka --- Lib/_pylong.py | 168 +++++++++++++++++++++++++++++++-------------------- Lib/test/test_int.py | 12 ++++ 2 files changed, 113 insertions(+), 67 deletions(-) diff --git a/Lib/_pylong.py b/Lib/_pylong.py index 30bee6f..4970eb3 100644 --- a/Lib/_pylong.py +++ b/Lib/_pylong.py @@ -19,6 +19,86 @@ try: except ImportError: _decimal = None +# A number of functions have this form, where `w` is a desired number of +# digits in base `base`: +# +# def inner(...w...): +# if w <= LIMIT: +# return something +# lo = w >> 1 +# hi = w - lo +# something involving base**lo, inner(...lo...), j, and inner(...hi...) +# figure out largest w needed +# result = inner(w) +# +# They all had some on-the-fly scheme to cache `base**lo` results for reuse. +# Power is costly. +# +# This routine aims to compute all amd only the needed powers in advance, as +# efficiently as reasonably possible. This isn't trivial, and all the +# on-the-fly methods did needless work in many cases. The driving code above +# changes to: +# +# figure out largest w needed +# mycache = compute_powers(w, base, LIMIT) +# result = inner(w) +# +# and `mycache[lo]` replaces `base**lo` in the inner function. +# +# While this does give minor speedups (a few percent at best), the primary +# intent is to simplify the functions using this, by eliminating the need for +# them to craft their own ad-hoc caching schemes. +def compute_powers(w, base, more_than, show=False): + seen = set() + need = set() + ws = {w} + while ws: + w = ws.pop() # any element is fine to use next + if w in seen or w <= more_than: + continue + seen.add(w) + lo = w >> 1 + # only _need_ lo here; some other path may, or may not, need hi + need.add(lo) + ws.add(lo) + if w & 1: + ws.add(lo + 1) + + d = {} + if not need: + return d + it = iter(sorted(need)) + first = next(it) + if show: + print("pow at", first) + d[first] = base ** first + for this in it: + if this - 1 in d: + if show: + print("* base at", this) + d[this] = d[this - 1] * base # cheap + else: + lo = this >> 1 + hi = this - lo + assert lo in d + if show: + print("square at", this) + # Multiplying a bigint by itself (same object!) is about twice + # as fast in CPython. + sq = d[lo] * d[lo] + if hi != lo: + assert hi == lo + 1 + if show: + print(" and * base") + sq *= base + d[this] = sq + return d + +_unbounded_dec_context = decimal.getcontext().copy() +_unbounded_dec_context.prec = decimal.MAX_PREC +_unbounded_dec_context.Emax = decimal.MAX_EMAX +_unbounded_dec_context.Emin = decimal.MIN_EMIN +_unbounded_dec_context.traps[decimal.Inexact] = 1 # sanity check def int_to_decimal(n): """Asymptotically fast conversion of an 'int' to Decimal.""" @@ -33,57 +113,32 @@ def int_to_decimal(n): # "clever" recursive way. If we want a string representation, we # apply str to _that_. - D = decimal.Decimal - D2 = D(2) - - BITLIM = 128 - - mem = {} - - def w2pow(w): - """Return D(2)**w and store the result. Also possibly save some - intermediate results. In context, these are likely to be reused - across various levels of the conversion to Decimal.""" - if (result := mem.get(w)) is None: - if w <= BITLIM: - result = D2**w - elif w - 1 in mem: - result = (t := mem[w - 1]) + t - else: - w2 = w >> 1 - # If w happens to be odd, w-w2 is one larger then w2 - # now. Recurse on the smaller first (w2), so that it's - # in the cache and the larger (w-w2) can be handled by - # the cheaper `w-1 in mem` branch instead. - result = w2pow(w2) * w2pow(w - w2) - mem[w] = result - return result + from decimal import Decimal as D + BITLIM = 200 + # Don't bother caching the "lo" mask in this; the time to compute it is + # tiny compared to the multiply. def inner(n, w): if w <= BITLIM: return D(n) w2 = w >> 1 hi = n >> w2 - lo = n - (hi << w2) - return inner(lo, w2) + inner(hi, w - w2) * w2pow(w2) - - with decimal.localcontext() as ctx: - ctx.prec = decimal.MAX_PREC - ctx.Emax = decimal.MAX_EMAX - ctx.Emin = decimal.MIN_EMIN - ctx.traps[decimal.Inexact] = 1 + lo = n & ((1 << w2) - 1) + return inner(lo, w2) + inner(hi, w - w2) * w2pow[w2] + with decimal.localcontext(_unbounded_dec_context): + nbits = n.bit_length() + w2pow = compute_powers(nbits, D(2), BITLIM) if n < 0: negate = True n = -n else: negate = False - result = inner(n, n.bit_length()) + result = inner(n, nbits) if negate: result = -result return result - def int_to_decimal_string(n): """Asymptotically fast conversion of an 'int' to a decimal string.""" w = n.bit_length() @@ -97,14 +152,13 @@ def int_to_decimal_string(n): # available. This algorithm is asymptotically worse than the algorithm # using the decimal module, but better than the quadratic time # implementation in longobject.c. + + DIGLIM = 1000 def inner(n, w): - if w <= 1000: + if w <= DIGLIM: return str(n) w2 = w >> 1 - d = pow10_cache.get(w2) - if d is None: - d = pow10_cache[w2] = 5**w2 << w2 # 10**i = (5*2)**i = 5**i * 2**i - hi, lo = divmod(n, d) + hi, lo = divmod(n, pow10[w2]) return inner(hi, w - w2) + inner(lo, w2).zfill(w2) # The estimation of the number of decimal digits. @@ -115,7 +169,9 @@ def int_to_decimal_string(n): # only if the number has way more than 10**15 digits, that exceeds # the 52-bit physical address limit in both Intel64 and AMD64. w = int(w * 0.3010299956639812 + 1) # log10(2) - pow10_cache = {} + pow10 = compute_powers(w, 5, DIGLIM) + for k, v in pow10.items(): + pow10[k] = v << k # 5**k << k == 5**k * 2**k == 10**k if n < 0: n = -n sign = '-' @@ -128,7 +184,6 @@ def int_to_decimal_string(n): s = s.lstrip('0') return sign + s - def _str_to_int_inner(s): """Asymptotically fast conversion of a 'str' to an 'int'.""" @@ -144,35 +199,15 @@ def _str_to_int_inner(s): DIGLIM = 2048 - mem = {} - - def w5pow(w): - """Return 5**w and store the result. - Also possibly save some intermediate results. In context, these - are likely to be reused across various levels of the conversion - to 'int'. - """ - if (result := mem.get(w)) is None: - if w <= DIGLIM: - result = 5**w - elif w - 1 in mem: - result = mem[w - 1] * 5 - else: - w2 = w >> 1 - # If w happens to be odd, w-w2 is one larger then w2 - # now. Recurse on the smaller first (w2), so that it's - # in the cache and the larger (w-w2) can be handled by - # the cheaper `w-1 in mem` branch instead. - result = w5pow(w2) * w5pow(w - w2) - mem[w] = result - return result - def inner(a, b): if b - a <= DIGLIM: return int(s[a:b]) mid = (a + b + 1) >> 1 - return inner(mid, b) + ((inner(a, mid) * w5pow(b - mid)) << (b - mid)) + return (inner(mid, b) + + ((inner(a, mid) * w5pow[b - mid]) + << (b - mid))) + w5pow = compute_powers(len(s), 5, DIGLIM) return inner(0, len(s)) @@ -186,7 +221,6 @@ def int_from_string(s): s = s.rstrip().replace('_', '') return _str_to_int_inner(s) - def str_to_int(s): """Asymptotically fast version of decimal string to 'int' conversion.""" # FIXME: this doesn't support the full syntax that int() supports. diff --git a/Lib/test/test_int.py b/Lib/test/test_int.py index c862639..8959ffb 100644 --- a/Lib/test/test_int.py +++ b/Lib/test/test_int.py @@ -906,6 +906,18 @@ class PyLongModuleTests(unittest.TestCase): with self.assertRaises(RuntimeError): int(big_value) + def test_pylong_roundtrip(self): + from random import randrange, getrandbits + bits = 5000 + while bits <= 1_000_000: + bits += randrange(-100, 101) # break bitlength patterns + hibit = 1 << (bits - 1) + n = hibit | getrandbits(bits - 1) + assert n.bit_length() == bits + sn = str(n) + self.assertFalse(sn.startswith('0')) + self.assertEqual(n, int(sn)) + bits <<= 1 if __name__ == "__main__": unittest.main() -- cgit v0.12