diff options
Diffstat (limited to 'Modules')
-rw-r--r-- | Modules/_decimal/_decimal.c | 126 |
1 files changed, 119 insertions, 7 deletions
diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c index 35a1156..4637b8b 100644 --- a/Modules/_decimal/_decimal.c +++ b/Modules/_decimal/_decimal.c @@ -3183,6 +3183,56 @@ dotsep_as_utf8(const char *s) return utf8; } +/* copy of libmpdec _mpd_round() */ +static void +_mpd_round(mpd_t *result, const mpd_t *a, mpd_ssize_t prec, + const mpd_context_t *ctx, uint32_t *status) +{ + mpd_ssize_t exp = a->exp + a->digits - prec; + + if (prec <= 0) { + mpd_seterror(result, MPD_Invalid_operation, status); + return; + } + if (mpd_isspecial(a) || mpd_iszero(a)) { + mpd_qcopy(result, a, status); + return; + } + + mpd_qrescale_fmt(result, a, exp, ctx, status); + if (result->digits > prec) { + mpd_qrescale_fmt(result, result, exp+1, ctx, status); + } +} + +/* Locate negative zero "z" option within a UTF-8 format spec string. + * Returns pointer to "z", else NULL. + * The portion of the spec we're working with is [[fill]align][sign][z] */ +static const char * +format_spec_z_search(char const *fmt, Py_ssize_t size) { + char const *pos = fmt; + char const *fmt_end = fmt + size; + /* skip over [[fill]align] (fill may be multi-byte character) */ + pos += 1; + while (pos < fmt_end && *pos & 0x80) { + pos += 1; + } + if (pos < fmt_end && strchr("<>=^", *pos) != NULL) { + pos += 1; + } else { + /* fill not present-- skip over [align] */ + pos = fmt; + if (pos < fmt_end && strchr("<>=^", *pos) != NULL) { + pos += 1; + } + } + /* skip over [sign] */ + if (pos < fmt_end && strchr("+- ", *pos) != NULL) { + pos += 1; + } + return pos < fmt_end && *pos == 'z' ? pos : NULL; +} + static int dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const char **valuestr) { @@ -3220,11 +3270,16 @@ dec_format(PyObject *dec, PyObject *args) PyObject *fmtarg; PyObject *context; mpd_spec_t spec; - char *fmt; + char const *fmt; + char *fmt_copy = NULL; char *decstring = NULL; uint32_t status = 0; int replace_fillchar = 0; + int no_neg_0 = 0; Py_ssize_t size; + mpd_t *mpd = MPD(dec); + mpd_uint_t dt[MPD_MINALLOC_MAX]; + mpd_t tmp = {MPD_STATIC|MPD_STATIC_DATA,0,0,0,MPD_MINALLOC_MAX,dt}; CURRENT_CONTEXT(context); @@ -3233,19 +3288,39 @@ dec_format(PyObject *dec, PyObject *args) } if (PyUnicode_Check(fmtarg)) { - fmt = (char *)PyUnicode_AsUTF8AndSize(fmtarg, &size); + fmt = PyUnicode_AsUTF8AndSize(fmtarg, &size); if (fmt == NULL) { return NULL; } + /* NOTE: If https://github.com/python/cpython/pull/29438 lands, the + * format string manipulation below can be eliminated by enhancing + * the forked mpd_parse_fmt_str(). */ if (size > 0 && fmt[0] == '\0') { /* NUL fill character: must be replaced with a valid UTF-8 char before calling mpd_parse_fmt_str(). */ replace_fillchar = 1; - fmt = dec_strdup(fmt, size); - if (fmt == NULL) { + fmt = fmt_copy = dec_strdup(fmt, size); + if (fmt_copy == NULL) { return NULL; } - fmt[0] = '_'; + fmt_copy[0] = '_'; + } + /* Strip 'z' option, which isn't understood by mpd_parse_fmt_str(). + * NOTE: fmt is always null terminated by PyUnicode_AsUTF8AndSize() */ + char const *z_position = format_spec_z_search(fmt, size); + if (z_position != NULL) { + no_neg_0 = 1; + size_t z_index = z_position - fmt; + if (fmt_copy == NULL) { + fmt = fmt_copy = dec_strdup(fmt, size); + if (fmt_copy == NULL) { + return NULL; + } + } + /* Shift characters (including null terminator) left, + overwriting the 'z' option. */ + memmove(fmt_copy + z_index, fmt_copy + z_index + 1, size - z_index); + size -= 1; } } else { @@ -3311,8 +3386,45 @@ dec_format(PyObject *dec, PyObject *args) } } + if (no_neg_0 && mpd_isnegative(mpd) && !mpd_isspecial(mpd)) { + /* Round into a temporary (carefully mirroring the rounding + of mpd_qformat_spec()), and check if the result is negative zero. + If so, clear the sign and format the resulting positive zero. */ + mpd_ssize_t prec; + mpd_qcopy(&tmp, mpd, &status); + if (spec.prec >= 0) { + switch (spec.type) { + case 'f': + mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status); + break; + case '%': + tmp.exp += 2; + mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status); + break; + case 'g': + prec = (spec.prec == 0) ? 1 : spec.prec; + if (tmp.digits > prec) { + _mpd_round(&tmp, &tmp, prec, CTX(context), &status); + } + break; + case 'e': + if (!mpd_iszero(&tmp)) { + _mpd_round(&tmp, &tmp, spec.prec+1, CTX(context), &status); + } + break; + } + } + if (status & MPD_Errors) { + PyErr_SetString(PyExc_ValueError, "unexpected error when rounding"); + goto finish; + } + if (mpd_iszero(&tmp)) { + mpd_set_positive(&tmp); + mpd = &tmp; + } + } - decstring = mpd_qformat_spec(MPD(dec), &spec, CTX(context), &status); + decstring = mpd_qformat_spec(mpd, &spec, CTX(context), &status); if (decstring == NULL) { if (status & MPD_Malloc_error) { PyErr_NoMemory(); @@ -3335,7 +3447,7 @@ finish: Py_XDECREF(grouping); Py_XDECREF(sep); Py_XDECREF(dot); - if (replace_fillchar) PyMem_Free(fmt); + if (fmt_copy) PyMem_Free(fmt_copy); if (decstring) mpd_free(decstring); return result; } |