summaryrefslogtreecommitdiffstats
path: root/Modules
diff options
context:
space:
mode:
Diffstat (limited to 'Modules')
-rw-r--r--Modules/_decimal/_decimal.c126
1 files changed, 119 insertions, 7 deletions
diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c
index 35a1156..4637b8b 100644
--- a/Modules/_decimal/_decimal.c
+++ b/Modules/_decimal/_decimal.c
@@ -3183,6 +3183,56 @@ dotsep_as_utf8(const char *s)
return utf8;
}
+/* copy of libmpdec _mpd_round() */
+static void
+_mpd_round(mpd_t *result, const mpd_t *a, mpd_ssize_t prec,
+ const mpd_context_t *ctx, uint32_t *status)
+{
+ mpd_ssize_t exp = a->exp + a->digits - prec;
+
+ if (prec <= 0) {
+ mpd_seterror(result, MPD_Invalid_operation, status);
+ return;
+ }
+ if (mpd_isspecial(a) || mpd_iszero(a)) {
+ mpd_qcopy(result, a, status);
+ return;
+ }
+
+ mpd_qrescale_fmt(result, a, exp, ctx, status);
+ if (result->digits > prec) {
+ mpd_qrescale_fmt(result, result, exp+1, ctx, status);
+ }
+}
+
+/* Locate negative zero "z" option within a UTF-8 format spec string.
+ * Returns pointer to "z", else NULL.
+ * The portion of the spec we're working with is [[fill]align][sign][z] */
+static const char *
+format_spec_z_search(char const *fmt, Py_ssize_t size) {
+ char const *pos = fmt;
+ char const *fmt_end = fmt + size;
+ /* skip over [[fill]align] (fill may be multi-byte character) */
+ pos += 1;
+ while (pos < fmt_end && *pos & 0x80) {
+ pos += 1;
+ }
+ if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
+ pos += 1;
+ } else {
+ /* fill not present-- skip over [align] */
+ pos = fmt;
+ if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
+ pos += 1;
+ }
+ }
+ /* skip over [sign] */
+ if (pos < fmt_end && strchr("+- ", *pos) != NULL) {
+ pos += 1;
+ }
+ return pos < fmt_end && *pos == 'z' ? pos : NULL;
+}
+
static int
dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const char **valuestr)
{
@@ -3220,11 +3270,16 @@ dec_format(PyObject *dec, PyObject *args)
PyObject *fmtarg;
PyObject *context;
mpd_spec_t spec;
- char *fmt;
+ char const *fmt;
+ char *fmt_copy = NULL;
char *decstring = NULL;
uint32_t status = 0;
int replace_fillchar = 0;
+ int no_neg_0 = 0;
Py_ssize_t size;
+ mpd_t *mpd = MPD(dec);
+ mpd_uint_t dt[MPD_MINALLOC_MAX];
+ mpd_t tmp = {MPD_STATIC|MPD_STATIC_DATA,0,0,0,MPD_MINALLOC_MAX,dt};
CURRENT_CONTEXT(context);
@@ -3233,19 +3288,39 @@ dec_format(PyObject *dec, PyObject *args)
}
if (PyUnicode_Check(fmtarg)) {
- fmt = (char *)PyUnicode_AsUTF8AndSize(fmtarg, &size);
+ fmt = PyUnicode_AsUTF8AndSize(fmtarg, &size);
if (fmt == NULL) {
return NULL;
}
+ /* NOTE: If https://github.com/python/cpython/pull/29438 lands, the
+ * format string manipulation below can be eliminated by enhancing
+ * the forked mpd_parse_fmt_str(). */
if (size > 0 && fmt[0] == '\0') {
/* NUL fill character: must be replaced with a valid UTF-8 char
before calling mpd_parse_fmt_str(). */
replace_fillchar = 1;
- fmt = dec_strdup(fmt, size);
- if (fmt == NULL) {
+ fmt = fmt_copy = dec_strdup(fmt, size);
+ if (fmt_copy == NULL) {
return NULL;
}
- fmt[0] = '_';
+ fmt_copy[0] = '_';
+ }
+ /* Strip 'z' option, which isn't understood by mpd_parse_fmt_str().
+ * NOTE: fmt is always null terminated by PyUnicode_AsUTF8AndSize() */
+ char const *z_position = format_spec_z_search(fmt, size);
+ if (z_position != NULL) {
+ no_neg_0 = 1;
+ size_t z_index = z_position - fmt;
+ if (fmt_copy == NULL) {
+ fmt = fmt_copy = dec_strdup(fmt, size);
+ if (fmt_copy == NULL) {
+ return NULL;
+ }
+ }
+ /* Shift characters (including null terminator) left,
+ overwriting the 'z' option. */
+ memmove(fmt_copy + z_index, fmt_copy + z_index + 1, size - z_index);
+ size -= 1;
}
}
else {
@@ -3311,8 +3386,45 @@ dec_format(PyObject *dec, PyObject *args)
}
}
+ if (no_neg_0 && mpd_isnegative(mpd) && !mpd_isspecial(mpd)) {
+ /* Round into a temporary (carefully mirroring the rounding
+ of mpd_qformat_spec()), and check if the result is negative zero.
+ If so, clear the sign and format the resulting positive zero. */
+ mpd_ssize_t prec;
+ mpd_qcopy(&tmp, mpd, &status);
+ if (spec.prec >= 0) {
+ switch (spec.type) {
+ case 'f':
+ mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
+ break;
+ case '%':
+ tmp.exp += 2;
+ mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
+ break;
+ case 'g':
+ prec = (spec.prec == 0) ? 1 : spec.prec;
+ if (tmp.digits > prec) {
+ _mpd_round(&tmp, &tmp, prec, CTX(context), &status);
+ }
+ break;
+ case 'e':
+ if (!mpd_iszero(&tmp)) {
+ _mpd_round(&tmp, &tmp, spec.prec+1, CTX(context), &status);
+ }
+ break;
+ }
+ }
+ if (status & MPD_Errors) {
+ PyErr_SetString(PyExc_ValueError, "unexpected error when rounding");
+ goto finish;
+ }
+ if (mpd_iszero(&tmp)) {
+ mpd_set_positive(&tmp);
+ mpd = &tmp;
+ }
+ }
- decstring = mpd_qformat_spec(MPD(dec), &spec, CTX(context), &status);
+ decstring = mpd_qformat_spec(mpd, &spec, CTX(context), &status);
if (decstring == NULL) {
if (status & MPD_Malloc_error) {
PyErr_NoMemory();
@@ -3335,7 +3447,7 @@ finish:
Py_XDECREF(grouping);
Py_XDECREF(sep);
Py_XDECREF(dot);
- if (replace_fillchar) PyMem_Free(fmt);
+ if (fmt_copy) PyMem_Free(fmt_copy);
if (decstring) mpd_free(decstring);
return result;
}