summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author: Stefan Krah <skrah@bytereef.org>, 2020-02-21 00:52:47 (GMT)
committer: GitHub <noreply@github.com>, 2020-02-21 00:52:47 (GMT)
commit: 90930e65455f60216f09d175586139242dbba260 (patch)
tree: 04d12963cae3cfd86430ca5a0480727c1e70439c
parent: 6c444d0dab8f06cf304263b34beb299101cef3de (diff)
download: cpython-90930e65455f60216f09d175586139242dbba260.zip
download: cpython-90930e65455f60216f09d175586139242dbba260.tar.gz
download: cpython-90930e65455f60216f09d175586139242dbba260.tar.bz2
bpo-39576: Prevent memory error for overly optimistic precisions (GH-18581)
-rw-r--r-- Lib/test/test_decimal.py | 35
-rw-r--r-- Modules/_decimal/libmpdec/mpdecimal.c | 77
-rw-r--r-- Modules/_decimal/tests/deccheck.py | 139
3 files changed, 245 insertions, 6 deletions
diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index fe0cfc7..f1abd2a 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -5476,6 +5476,41 @@ class CWhitebox(unittest.TestCase):
self.assertEqual(Decimal.from_float(cls(101.1)),
Decimal.from_float(101.1))
+ def test_maxcontext_exact_arith(self):
+
+ # Make sure that exact operations do not raise MemoryError due
+ # to huge intermediate values when the context precision is very
+ # large.
+
+ # The following functions fill the available precision and are
+ # therefore not suitable for large precisions (by design of the
+ # specification).
+ MaxContextSkip = ['logical_invert', 'next_minus', 'next_plus',
+ 'logical_and', 'logical_or', 'logical_xor',
+ 'next_toward', 'rotate', 'shift']
+
+ Decimal = C.Decimal
+ Context = C.Context
+ localcontext = C.localcontext
+
+ # Here only some functions that are likely candidates for triggering a
+ # MemoryError are tested. deccheck.py has an exhaustive test.
+ maxcontext = Context(prec=C.MAX_PREC, Emin=C.MIN_EMIN, Emax=C.MAX_EMAX)
+ with localcontext(maxcontext):
+ self.assertEqual(Decimal(0).exp(), 1)
+ self.assertEqual(Decimal(1).ln(), 0)
+ self.assertEqual(Decimal(1).log10(), 0)
+ self.assertEqual(Decimal(10**2).log10(), 2)
+ self.assertEqual(Decimal(10**223).log10(), 223)
+ self.assertEqual(Decimal(10**19).logb(), 19)
+ self.assertEqual(Decimal(4).sqrt(), 2)
+ self.assertEqual(Decimal("40E9").sqrt(), Decimal('2.0E+5'))
+ self.assertEqual(divmod(Decimal(10), 3), (3, 1))
+ self.assertEqual(Decimal(10) // 3, 3)
+ self.assertEqual(Decimal(4) / 2, 2)
+ self.assertEqual(Decimal(400) ** -1, Decimal('0.0025'))
+
+
@requires_docstrings
@unittest.skipUnless(C, "test requires C version")
class SignatureTest(unittest.TestCase):
diff --git a/Modules/_decimal/libmpdec/mpdecimal.c b/Modules/_decimal/libmpdec/mpdecimal.c
index bfa8bb3..0986edb 100644
--- a/Modules/_decimal/libmpdec/mpdecimal.c
+++ b/Modules/_decimal/libmpdec/mpdecimal.c
@@ -3781,6 +3781,43 @@ mpd_qdiv(mpd_t *q, const mpd_t *a, const mpd_t *b,
const mpd_context_t *ctx, uint32_t *status)
{
_mpd_qdiv(SET_IDEAL_EXP, q, a, b, ctx, status);
+
+ if (*status & MPD_Malloc_error) {
+ /* Inexact quotients (the usual case) fill the entire context precision,
+ * which can lead to malloc() failures for very high precisions. Retry
+ * the operation with a lower precision in case the result is exact.
+ *
+ * We need an upper bound for the number of digits of a_coeff / b_coeff
+ * when the result is exact. If a_coeff' * 1 / b_coeff' is in lowest
+ * terms, then maxdigits(a_coeff') + maxdigits(1 / b_coeff') is a suitable
+ * bound.
+ *
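+ * 1 / b_coeff' is exact iff b_coeff' has no prime factors other than 2 and 5.
+ * The largest number of digits is generated if b_coeff' is a power of 2 or
+ * a power of 5. That number is at most log2(b_coeff') for a power of 2 and
+ * at most log5(b_coeff') for a power of 5; since log5(b_coeff') <=
+ * log2(b_coeff'), log2(b_coeff') bounds both cases.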
+ *
+ * We arrive at a total upper bound:
+ *
+ * maxdigits(a_coeff') + maxdigits(1 / b_coeff') <=
+ * a->digits + log2(b_coeff) =
+ * a->digits + log10(b_coeff) / log10(2) <=
+ * a->digits + b->digits * 4;
+ */
+ uint32_t workstatus = 0;
+ mpd_context_t workctx = *ctx;
+ workctx.prec = a->digits + b->digits * 4;
+ if (workctx.prec >= ctx->prec) {
+ return; /* No point in retrying, keep the original error. */
+ }
+
+ _mpd_qdiv(SET_IDEAL_EXP, q, a, b, &workctx, &workstatus);
+ if (workstatus == 0) { /* The result is exact, unrounded, normal etc. */
+ *status = 0;
+ return;
+ }
+
+ mpd_seterror(q, *status, status);
+ }
}
/* Internal function. */
@@ -7702,9 +7739,9 @@ mpd_qinvroot(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
/* END LIBMPDEC_ONLY */
/* Algorithm from decimal.py */
-void
-mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
- uint32_t *status)
+static void
+_mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
+ uint32_t *status)
{
mpd_context_t maxcontext;
MPD_NEW_STATIC(c,0,0,0,0);
@@ -7836,6 +7873,40 @@ malloc_error:
goto out;
}
+/*
+ * Set result to the square root of a, using the context ctx.
+ *
+ * The work is done by _mpd_qsqrt().  If that attempt reports
+ * MPD_Malloc_error or MPD_Division_impossible, the operation is retried
+ * once with a working precision of a->digits, which is sufficient whenever
+ * the square root is exact.
+ *
+ * 'status' is treated as an accumulator: conditions raised here are OR-ed
+ * into *status, and flags that the caller had already set are not cleared.
+ */
+void
+mpd_qsqrt(mpd_t *result, const mpd_t *a, const mpd_context_t *ctx,
+          uint32_t *status)
+{
+    MPD_NEW_STATIC(aa,0,0,0,0);
+    uint32_t xstatus = 0; /* conditions raised by the first attempt only */
+
+    if (result == a) {
+        /* A failed first attempt may leave result (and therefore the
+         * operand) modified.  Work on a private copy so that the retry
+         * below still sees the original input. */
+        if (!mpd_qcopy(&aa, a, status)) {
+            mpd_seterror(result, MPD_Malloc_error, status);
+            goto out;
+        }
+        a = &aa;
+    }
+
+    _mpd_qsqrt(result, a, ctx, &xstatus);
+
+    if (xstatus & (MPD_Malloc_error|MPD_Division_impossible)) {
+        /* The above conditions can occur at very high context precisions
+         * if intermediate values get too large. Retry the operation with
+         * a lower context precision in case the result is exact.
+         *
+         * If the result is exact, an upper bound for the number of digits
+         * is the number of digits in the input.
+         *
+         * NOTE: sqrt(40e9) = 2.0e+5 /\ digits(40e9) = digits(2.0e+5) = 2
+         */
+        uint32_t ystatus = 0; /* conditions raised by the retry only */
+        mpd_context_t workctx = *ctx;
+        workctx.prec = a->digits;
+
+        if (workctx.prec >= ctx->prec) {
+            /* No point in repeating this, keep the original error. */
+            *status |= xstatus;
+            goto out;
+        }
+
+        _mpd_qsqrt(result, a, &workctx, &ystatus);
+        if (ystatus != 0) {
+            /* The retry was not exact, unrounded and normal: report the
+             * conditions of the first attempt. */
+            mpd_seterror(result, xstatus, status);
+        }
+        /* Otherwise the exact result stands and no new flags are raised;
+         * *status is left as the caller passed it in. */
+    }
+    else {
+        *status |= xstatus;
+    }
+
+out:
+    mpd_del(&aa);
+}
+
/******************************************************************************/
/* Base conversions */
diff --git a/Modules/_decimal/tests/deccheck.py b/Modules/_decimal/tests/deccheck.py
index f907531..5cd5db5 100644
--- a/Modules/_decimal/tests/deccheck.py
+++ b/Modules/_decimal/tests/deccheck.py
@@ -125,6 +125,12 @@ ContextFunctions = {
'special': ('context.__reduce_ex__', 'context.create_decimal_from_float')
}
+# Functions that set no context flags but whose result can differ depending
+# on prec, Emin and Emax.  callfuncs() does not repeat an operation under
+# the maxcontext when its function name is listed here.
+MaxContextSkip = ['is_normal', 'is_subnormal', 'logical_invert', 'next_minus',
+ 'next_plus', 'number_class', 'logical_and', 'logical_or',
+ 'logical_xor', 'next_toward', 'rotate', 'shift']
+
# Functions that require a restricted exponent range for reasonable runtimes.
UnaryRestricted = [
'__ceil__', '__floor__', '__int__', '__trunc__',
@@ -344,6 +350,20 @@ class TestSet(object):
self.pex = RestrictedList() # Python exceptions for P.Decimal
self.presults = RestrictedList() # P.Decimal results
+ # If the above results are exact, unrounded and not clamped, repeat
+ # the operation with a maxcontext to ensure that huge intermediate
+ # values do not cause a MemoryError.
+ self.with_maxcontext = False
+ self.maxcontext = context.c.copy()
+ self.maxcontext.prec = C.MAX_PREC
+ self.maxcontext.Emax = C.MAX_EMAX
+ self.maxcontext.Emin = C.MIN_EMIN
+ self.maxcontext.clear_flags()
+
+ self.maxop = RestrictedList() # converted C.Decimal operands
+ self.maxex = RestrictedList() # Python exceptions for C.Decimal
+ self.maxresults = RestrictedList() # C.Decimal results
+
# ======================================================================
# SkipHandler: skip known discrepancies
@@ -545,13 +565,17 @@ def function_as_string(t):
if t.contextfunc:
cargs = t.cop
pargs = t.pop
+ maxargs = t.maxop
cfunc = "c_func: %s(" % t.funcname
pfunc = "p_func: %s(" % t.funcname
+ maxfunc = "max_func: %s(" % t.funcname
else:
cself, cargs = t.cop[0], t.cop[1:]
pself, pargs = t.pop[0], t.pop[1:]
+ maxself, maxargs = t.maxop[0], t.maxop[1:]
cfunc = "c_func: %s.%s(" % (repr(cself), t.funcname)
pfunc = "p_func: %s.%s(" % (repr(pself), t.funcname)
+ maxfunc = "max_func: %s.%s(" % (repr(maxself), t.funcname)
err = cfunc
for arg in cargs:
@@ -565,6 +589,14 @@ def function_as_string(t):
err = err.rstrip(", ")
err += ")"
+ if t.with_maxcontext:
+ err += "\n"
+ err += maxfunc
+ for arg in maxargs:
+ err += "%s, " % repr(arg)
+ err = err.rstrip(", ")
+ err += ")"
+
return err
def raise_error(t):
@@ -577,9 +609,24 @@ def raise_error(t):
err = "Error in %s:\n\n" % t.funcname
err += "input operands: %s\n\n" % (t.op,)
err += function_as_string(t)
- err += "\n\nc_result: %s\np_result: %s\n\n" % (t.cresults, t.presults)
- err += "c_exceptions: %s\np_exceptions: %s\n\n" % (t.cex, t.pex)
- err += "%s\n\n" % str(t.context)
+
+ err += "\n\nc_result: %s\np_result: %s\n" % (t.cresults, t.presults)
+ if t.with_maxcontext:
+ err += "max_result: %s\n\n" % (t.maxresults)
+ else:
+ err += "\n"
+
+ err += "c_exceptions: %s\np_exceptions: %s\n" % (t.cex, t.pex)
+ if t.with_maxcontext:
+ err += "max_exceptions: %s\n\n" % t.maxex
+ else:
+ err += "\n"
+
+ err += "%s\n" % str(t.context)
+ if t.with_maxcontext:
+ err += "%s\n" % str(t.maxcontext)
+ else:
+ err += "\n"
raise VerifyError(err)
@@ -603,6 +650,13 @@ def raise_error(t):
# are printed to stdout.
# ======================================================================
+def all_nan(a):
+ if isinstance(a, C.Decimal):
+ return a.is_nan()
+ elif isinstance(a, tuple):
+ return all(all_nan(v) for v in a)
+ return False
+
def convert(t, convstr=True):
""" t is the testset. At this stage the testset contains a tuple of
operands t.op of various types. For decimal methods the first
@@ -617,10 +671,12 @@ def convert(t, convstr=True):
for i, op in enumerate(t.op):
context.clear_status()
+ t.maxcontext.clear_flags()
if op in RoundModes:
t.cop.append(op)
t.pop.append(op)
+ t.maxop.append(op)
elif not t.contextfunc and i == 0 or \
convstr and isinstance(op, str):
@@ -638,11 +694,25 @@ def convert(t, convstr=True):
p = None
pex = e.__class__
+ try:
+ C.setcontext(t.maxcontext)
+ maxop = C.Decimal(op)
+ maxex = None
+ except (TypeError, ValueError, OverflowError) as e:
+ maxop = None
+ maxex = e.__class__
+ finally:
+ C.setcontext(context.c)
+
t.cop.append(c)
t.cex.append(cex)
+
t.pop.append(p)
t.pex.append(pex)
+ t.maxop.append(maxop)
+ t.maxex.append(maxex)
+
if cex is pex:
if str(c) != str(p) or not context.assert_eq_status():
raise_error(t)
@@ -652,14 +722,21 @@ def convert(t, convstr=True):
else:
raise_error(t)
+ # The exceptions in the maxcontext operation can legitimately
+ # differ; only check that an exception raised under the maxcontext
+ # (maxex) is the same exception raised in the regular context (cex):
+ if maxex is not None and cex is not maxex:
+ raise_error(t)
+
elif isinstance(op, Context):
t.context = op
t.cop.append(op.c)
t.pop.append(op.p)
+ t.maxop.append(t.maxcontext)
else:
t.cop.append(op)
t.pop.append(op)
+ t.maxop.append(op)
return 1
@@ -673,6 +750,7 @@ def callfuncs(t):
t.rc and t.rp are the results of the operation.
"""
context.clear_status()
+ t.maxcontext.clear_flags()
try:
if t.contextfunc:
@@ -700,6 +778,35 @@ def callfuncs(t):
t.rp = None
t.pex.append(e.__class__)
+ # If the above results are exact, unrounded, normal etc., repeat the
+ # operation with a maxcontext to ensure that huge intermediate values
+ # do not cause a MemoryError.
+ if (t.funcname not in MaxContextSkip and
+ not context.c.flags[C.InvalidOperation] and
+ not context.c.flags[C.Inexact] and
+ not context.c.flags[C.Rounded] and
+ not context.c.flags[C.Subnormal] and
+ not context.c.flags[C.Clamped] and
+ not context.clamp and # results are padded to context.prec if context.clamp==1.
+ not any(isinstance(v, C.Context) for v in t.cop)): # another context is used.
+ t.with_maxcontext = True
+ try:
+ if t.contextfunc:
+ maxargs = t.maxop
+ t.rmax = getattr(t.maxcontext, t.funcname)(*maxargs)
+ else:
+ maxself = t.maxop[0]
+ maxargs = t.maxop[1:]
+ try:
+ C.setcontext(t.maxcontext)
+ t.rmax = getattr(maxself, t.funcname)(*maxargs)
+ finally:
+ C.setcontext(context.c)
+ t.maxex.append(None)
+ except (TypeError, ValueError, OverflowError, MemoryError) as e:
+ t.rmax = None
+ t.maxex.append(e.__class__)
+
def verify(t, stat):
""" t is the testset. At this stage the testset contains the following
tuples:
@@ -714,6 +821,9 @@ def verify(t, stat):
"""
t.cresults.append(str(t.rc))
t.presults.append(str(t.rp))
+ if t.with_maxcontext:
+ t.maxresults.append(str(t.rmax))
+
if isinstance(t.rc, C.Decimal) and isinstance(t.rp, P.Decimal):
# General case: both results are Decimals.
t.cresults.append(t.rc.to_eng_string())
@@ -725,6 +835,12 @@ def verify(t, stat):
t.presults.append(str(t.rp.imag))
t.presults.append(str(t.rp.real))
+ if t.with_maxcontext and isinstance(t.rmax, C.Decimal):
+ t.maxresults.append(t.rmax.to_eng_string())
+ t.maxresults.append(t.rmax.as_tuple())
+ t.maxresults.append(str(t.rmax.imag))
+ t.maxresults.append(str(t.rmax.real))
+
nc = t.rc.number_class().lstrip('+-s')
stat[nc] += 1
else:
@@ -732,6 +848,9 @@ def verify(t, stat):
if not isinstance(t.rc, tuple) and not isinstance(t.rp, tuple):
if t.rc != t.rp:
raise_error(t)
+ if t.with_maxcontext and not isinstance(t.rmax, tuple):
+ if t.rmax != t.rc:
+ raise_error(t)
stat[type(t.rc).__name__] += 1
# The return value lists must be equal.
@@ -744,6 +863,20 @@ def verify(t, stat):
if not t.context.assert_eq_status():
raise_error(t)
+ if t.with_maxcontext:
+ # NaN payloads etc. depend on precision and clamp.
+ if all_nan(t.rc) and all_nan(t.rmax):
+ return
+ # The return value lists must be equal.
+ if t.maxresults != t.cresults:
+ raise_error(t)
+ # The Python exception lists (TypeError, etc.) must be equal.
+ if t.maxex != t.cex:
+ raise_error(t)
+ # The context flags must be equal.
+ if t.maxcontext.flags != t.context.c.flags:
+ raise_error(t)
+
# ======================================================================
# Main test loops