diff options
author | John Belmonte <john@neggie.net> | 2024-02-12 11:17:51 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-12 11:17:51 (GMT) |
commit | 72340d15cdfdfa4796fdd7c702094c852c2b32d2 (patch) | |
tree | e5236dae006b23712ec07ac39a416d08371dc8b8 /Modules | |
parent | 235cacff81931a68e8c400bb3919ae6e55462fb5 (diff) | |
download | cpython-72340d15cdfdfa4796fdd7c702094c852c2b32d2.zip cpython-72340d15cdfdfa4796fdd7c702094c852c2b32d2.tar.gz cpython-72340d15cdfdfa4796fdd7c702094c852c2b32d2.tar.bz2 |
gh-114563: C decimal falls back to pydecimal for unsupported format strings (GH-114879)
Immediate merits:
* eliminate complex workarounds for 'z' format support
(NOTE: mpdecimal recently added 'z' support, so once that support is
available 'z' specs are parsed natively and skip the fallback, making this efficient in the long term.)
* fix 'z' format memory leak
* fix 'z' format applied to 'F'
* fix missing '#' format support
Suggested and prototyped by Stefan Krah.
Fixes gh-114563, gh-91060
Co-authored-by: Stefan Krah <skrah@bytereef.org>
Diffstat (limited to 'Modules')
-rw-r--r-- | Modules/_decimal/_decimal.c | 184 |
1 file changed, 62 insertions, 122 deletions
diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c index 127f5f2..5b053c7 100644 --- a/Modules/_decimal/_decimal.c +++ b/Modules/_decimal/_decimal.c @@ -82,6 +82,9 @@ typedef struct { /* Convert rationals for comparison */ PyObject *Rational; + /* Invariant: NULL or pointer to _pydecimal.Decimal */ + PyObject *PyDecimal; + PyObject *SignalTuple; struct DecCondMap *signal_map; @@ -3336,56 +3339,6 @@ dotsep_as_utf8(const char *s) return utf8; } -/* copy of libmpdec _mpd_round() */ -static void -_mpd_round(mpd_t *result, const mpd_t *a, mpd_ssize_t prec, - const mpd_context_t *ctx, uint32_t *status) -{ - mpd_ssize_t exp = a->exp + a->digits - prec; - - if (prec <= 0) { - mpd_seterror(result, MPD_Invalid_operation, status); - return; - } - if (mpd_isspecial(a) || mpd_iszero(a)) { - mpd_qcopy(result, a, status); - return; - } - - mpd_qrescale_fmt(result, a, exp, ctx, status); - if (result->digits > prec) { - mpd_qrescale_fmt(result, result, exp+1, ctx, status); - } -} - -/* Locate negative zero "z" option within a UTF-8 format spec string. - * Returns pointer to "z", else NULL. - * The portion of the spec we're working with is [[fill]align][sign][z] */ -static const char * -format_spec_z_search(char const *fmt, Py_ssize_t size) { - char const *pos = fmt; - char const *fmt_end = fmt + size; - /* skip over [[fill]align] (fill may be multi-byte character) */ - pos += 1; - while (pos < fmt_end && *pos & 0x80) { - pos += 1; - } - if (pos < fmt_end && strchr("<>=^", *pos) != NULL) { - pos += 1; - } else { - /* fill not present-- skip over [align] */ - pos = fmt; - if (pos < fmt_end && strchr("<>=^", *pos) != NULL) { - pos += 1; - } - } - /* skip over [sign] */ - if (pos < fmt_end && strchr("+- ", *pos) != NULL) { - pos += 1; - } - return pos < fmt_end && *pos == 'z' ? 
pos : NULL; -} - static int dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const char **valuestr) { @@ -3411,6 +3364,48 @@ dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const return 0; } +/* + * Fallback _pydecimal formatting for new format specifiers that mpdecimal does + * not yet support. As documented, libmpdec follows the PEP-3101 format language: + * https://www.bytereef.org/mpdecimal/doc/libmpdec/assign-convert.html#to-string + */ +static PyObject * +pydec_format(PyObject *dec, PyObject *context, PyObject *fmt, decimal_state *state) +{ + PyObject *result; + PyObject *pydec; + PyObject *u; + + if (state->PyDecimal == NULL) { + state->PyDecimal = _PyImport_GetModuleAttrString("_pydecimal", "Decimal"); + if (state->PyDecimal == NULL) { + return NULL; + } + } + + u = dec_str(dec); + if (u == NULL) { + return NULL; + } + + pydec = PyObject_CallOneArg(state->PyDecimal, u); + Py_DECREF(u); + if (pydec == NULL) { + return NULL; + } + + result = PyObject_CallMethod(pydec, "__format__", "(OO)", fmt, context); + Py_DECREF(pydec); + + if (result == NULL && PyErr_ExceptionMatches(PyExc_ValueError)) { + /* Do not confuse users with the _pydecimal exception */ + PyErr_Clear(); + PyErr_SetString(PyExc_ValueError, "invalid format string"); + } + + return result; +} + /* Formatted representation of a PyDecObject. 
*/ static PyObject * dec_format(PyObject *dec, PyObject *args) @@ -3423,16 +3418,11 @@ dec_format(PyObject *dec, PyObject *args) PyObject *fmtarg; PyObject *context; mpd_spec_t spec; - char const *fmt; - char *fmt_copy = NULL; + char *fmt; char *decstring = NULL; uint32_t status = 0; int replace_fillchar = 0; - int no_neg_0 = 0; Py_ssize_t size; - mpd_t *mpd = MPD(dec); - mpd_uint_t dt[MPD_MINALLOC_MAX]; - mpd_t tmp = {MPD_STATIC|MPD_STATIC_DATA,0,0,0,MPD_MINALLOC_MAX,dt}; decimal_state *state = get_module_state_by_def(Py_TYPE(dec)); @@ -3442,7 +3432,7 @@ dec_format(PyObject *dec, PyObject *args) } if (PyUnicode_Check(fmtarg)) { - fmt = PyUnicode_AsUTF8AndSize(fmtarg, &size); + fmt = (char *)PyUnicode_AsUTF8AndSize(fmtarg, &size); if (fmt == NULL) { return NULL; } @@ -3454,35 +3444,15 @@ dec_format(PyObject *dec, PyObject *args) } } - /* NOTE: If https://github.com/python/cpython/pull/29438 lands, the - * format string manipulation below can be eliminated by enhancing - * the forked mpd_parse_fmt_str(). */ if (size > 0 && fmt[0] == '\0') { /* NUL fill character: must be replaced with a valid UTF-8 char before calling mpd_parse_fmt_str(). */ replace_fillchar = 1; - fmt = fmt_copy = dec_strdup(fmt, size); - if (fmt_copy == NULL) { + fmt = dec_strdup(fmt, size); + if (fmt == NULL) { return NULL; } - fmt_copy[0] = '_'; - } - /* Strip 'z' option, which isn't understood by mpd_parse_fmt_str(). - * NOTE: fmt is always null terminated by PyUnicode_AsUTF8AndSize() */ - char const *z_position = format_spec_z_search(fmt, size); - if (z_position != NULL) { - no_neg_0 = 1; - size_t z_index = z_position - fmt; - if (fmt_copy == NULL) { - fmt = fmt_copy = dec_strdup(fmt, size); - if (fmt_copy == NULL) { - return NULL; - } - } - /* Shift characters (including null terminator) left, - overwriting the 'z' option. 
*/ - memmove(fmt_copy + z_index, fmt_copy + z_index + 1, size - z_index); - size -= 1; + fmt[0] = '_'; } } else { @@ -3492,10 +3462,13 @@ dec_format(PyObject *dec, PyObject *args) } if (!mpd_parse_fmt_str(&spec, fmt, CtxCaps(context))) { - PyErr_SetString(PyExc_ValueError, - "invalid format string"); - goto finish; + if (replace_fillchar) { + PyMem_Free(fmt); + } + + return pydec_format(dec, context, fmtarg, state); } + if (replace_fillchar) { /* In order to avoid clobbering parts of UTF-8 thousands separators or decimal points when the substitution is reversed later, the actual @@ -3548,45 +3521,8 @@ dec_format(PyObject *dec, PyObject *args) } } - if (no_neg_0 && mpd_isnegative(mpd) && !mpd_isspecial(mpd)) { - /* Round into a temporary (carefully mirroring the rounding - of mpd_qformat_spec()), and check if the result is negative zero. - If so, clear the sign and format the resulting positive zero. */ - mpd_ssize_t prec; - mpd_qcopy(&tmp, mpd, &status); - if (spec.prec >= 0) { - switch (spec.type) { - case 'f': - mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status); - break; - case '%': - tmp.exp += 2; - mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status); - break; - case 'g': - prec = (spec.prec == 0) ? 
1 : spec.prec; - if (tmp.digits > prec) { - _mpd_round(&tmp, &tmp, prec, CTX(context), &status); - } - break; - case 'e': - if (!mpd_iszero(&tmp)) { - _mpd_round(&tmp, &tmp, spec.prec+1, CTX(context), &status); - } - break; - } - } - if (status & MPD_Errors) { - PyErr_SetString(PyExc_ValueError, "unexpected error when rounding"); - goto finish; - } - if (mpd_iszero(&tmp)) { - mpd_set_positive(&tmp); - mpd = &tmp; - } - } - decstring = mpd_qformat_spec(mpd, &spec, CTX(context), &status); + decstring = mpd_qformat_spec(MPD(dec), &spec, CTX(context), &status); if (decstring == NULL) { if (status & MPD_Malloc_error) { PyErr_NoMemory(); @@ -3609,7 +3545,7 @@ finish: Py_XDECREF(grouping); Py_XDECREF(sep); Py_XDECREF(dot); - if (fmt_copy) PyMem_Free(fmt_copy); + if (replace_fillchar) PyMem_Free(fmt); if (decstring) mpd_free(decstring); return result; } @@ -5987,6 +5923,9 @@ _decimal_exec(PyObject *m) Py_CLEAR(collections_abc); Py_CLEAR(MutableMapping); + /* For format specifiers not yet supported by libmpdec */ + state->PyDecimal = NULL; + /* Add types to the module */ CHECK_INT(PyModule_AddType(m, state->PyDec_Type)); CHECK_INT(PyModule_AddType(m, state->PyDecContext_Type)); @@ -6192,6 +6131,7 @@ decimal_clear(PyObject *module) Py_CLEAR(state->extended_context_template); Py_CLEAR(state->Rational); Py_CLEAR(state->SignalTuple); + Py_CLEAR(state->PyDecimal); PyMem_Free(state->signal_map); PyMem_Free(state->cond_map); |