diff options
author | Mark Dickinson <dickinsm@gmail.com> | 2008-05-23 01:35:30 (GMT) |
---|---|---|
committer | Mark Dickinson <dickinsm@gmail.com> | 2008-05-23 01:35:30 (GMT) |
commit | 99dfe927590ef95b9ce42f8334e3126c5960ad6f (patch) | |
tree | 72a56f54ed7b3b8e9153fb585ee268b64a3f1cda /Modules | |
parent | cc858ccc500cf1606f64fc6cc47ab8af230c89e6 (diff) | |
download | cpython-99dfe927590ef95b9ce42f8334e3126c5960ad6f.zip cpython-99dfe927590ef95b9ce42f8334e3126c5960ad6f.tar.gz cpython-99dfe927590ef95b9ce42f8334e3126c5960ad6f.tar.bz2 |
Issue #2819: Add math.sum, a function that sums a sequence of floats
efficiently but with no intermediate loss of precision. Based on
Raymond Hettinger's ASPN recipe. Thanks Jean Brouwers for the patch.
Diffstat (limited to 'Modules')
-rw-r--r-- | Modules/mathmodule.c | 223 |
1 files changed, 223 insertions, 0 deletions
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index c4ac69a..19d6f43 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -307,6 +307,228 @@ FUNC1(tan, tan, 0, FUNC1(tanh, tanh, 0, "tanh(x)\n\nReturn the hyperbolic tangent of x.") +/* Precision summation function as msum() by Raymond Hettinger in + <http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/393090>, + enhanced with the exact partials sum and roundoff from Mark + Dickinson's post at <http://bugs.python.org/file10357/msum4.py>. + + See both of those for more details, proofs and other references. + + Note 1: IEEE 754 floating point format and semantics are assumed, but not + explicitly maintained. The following rules may not apply: + + 1. if the summands include a NaN, return a NaN, + + 2. if the summands include infinities of both signs, raise ValueError, + + 3. if the summands include infinities of only one sign, return infinity + with that sign, + + 4. otherwise (all summands are finite) if the result is infinite, raise + OverflowError. The result can never be a NaN if all summands are + finite. + + Note 2: the implementation below not include the intermediate overflow + handling from Mark Dickinson's msum(). Therefore, sum([1e+308, 1e-308, + 1e+308]) returns result 1e+308, however sum([1e+308, 1e+308, 1e-308]) + raises an OverflowError due to intermediate overflow of the first + partial sum. + + Note 3: aggressively optimizing compilers may eliminate the roundoff + expressions critical for accurate summation. For example, the compiler + may optimize the following expressions + + hi = x + y; + lo = y - (hi - x); + to + hi = x + y; + lo = 0.0; + + defeating the whole purpose. Using volatile variables and/or explicit + assignment of critical subexpressions to a volatile variable should + remedy the problem + + volatile double v; // Deter compiler from algebraically optimizing + // this critical, intermediate value away + hi = x + y; + v = hi - x; + lo = y - v; + + by forcing the compiler to compute the value for v. This may also help + when subexpression are not computed with the full double precision. + + Note 4. the same summation functions may be in ./cmathmodule.c. Make + sure to update both when making changes. +*/ + +#define NUM_PARTIALS 32 /* initial partials array size, on stack */ + +/* Extend the partials array p[] by doubling its size. + */ +static int /* non-zero on error */ +_sum_realloc(double **p_ptr, Py_ssize_t n, + double *ps, Py_ssize_t *m_ptr) +{ + void *v = NULL; + Py_ssize_t m = *m_ptr; + + m += m; /* double */ + if (n < m && m < (PY_SSIZE_T_MAX / sizeof(double))) { + double *p = *p_ptr; + if (p == ps) { + v = PyMem_Malloc(sizeof(double) * m); + if (v != NULL) + memcpy(v, ps, sizeof(double) * n); + } + else + v = PyMem_Realloc(p, sizeof(double) * m); + } + if (v == NULL) { /* size overflow or no memory */ + PyErr_SetString(PyExc_MemoryError, "math sum partials"); + return 1; + } + *p_ptr = (double*) v; + *m_ptr = m; + return 0; +} + +/* Full precision summation of a sequence of floats. + + def msum(iterable): + partials = [] # sorted, non-overlapping partial sums + for x in iterable: + i = 0 + for y in partials: + if abs(x) < abs(y): + x, y = y, x + hi = x + y + lo = y - (hi - x) + if lo: + partials[i] = lo + i += 1 + x = hi + partials[i:] = [x] + return sum_exact(partials) + + Rounded x+y stored in hi with the roundoff stored in lo. Together hi+lo + are exactly equal to x+y. The inner loop applies hi/lo summation to each + partial so that the list of partial sums remains exact. + + Sum_exact() adds the partial sums exactly and correctly rounds the final + result (using the round-half-to-even rule). The items in partials remain + non-zero, non-special, non-overlapping and strictly increasing in + magnitude, but possibly not all having the same sign. + + Depends on IEEE 754 arithmetic guarantees. + */ +static PyObject* +math_sum(PyObject *self, PyObject *seq) +{ + PyObject *item, *iter, *sum = NULL; + Py_ssize_t i, j, n = 0, m = NUM_PARTIALS; + double x, y, hi, lo=0.0, ps[NUM_PARTIALS], *p = ps; + + iter = PyObject_GetIter(seq); + if (iter == NULL) + return NULL; + + PyFPE_START_PROTECT("sum", Py_DECREF(iter); return NULL) + + for(;;) { /* for x in iterable */ + /* some invariants */ + assert(0 <= n && n <= m); + assert((m == NUM_PARTIALS && p == ps) || + (m > NUM_PARTIALS && p != NULL)); + + item = PyIter_Next(iter); + if (item == NULL) { + if (PyErr_Occurred()) + goto _sum_error; + else + break; + } + x = PyFloat_AsDouble(item); + Py_DECREF(item); + if (PyErr_Occurred()) + goto _sum_error; + + for (i = j = 0; j < n; j++) { /* for y in partials */ + y = p[j]; + hi = x + y; + lo = fabs(x) < fabs(y) + ? x - (hi - y) /* volatile */ + : y - (hi - x); /* volatile */ + if (lo != 0.0) + p[i++] = lo; + x = hi; + } + /* ps[i:] = [x] */ + n = i; + if (x != 0.0) { + /* if non-finite, reset partials, effectively + adding subsequent items without roundoff + and yielding correct non-finite results, + provided IEEE 754 rules are observed */ + if (! Py_IS_FINITE(x)) + n = 0; + else if (n >= m && _sum_realloc(&p, n, ps, &m)) + goto _sum_error; + p[n++] = x; + } + } + assert(n <= m); + + if (n > 0) { + hi = p[--n]; + if (Py_IS_FINITE(hi)) { + /* sum_exact(ps, hi) from the top, stop + as soon as the sum becomes inexact */ + while (n > 0) { + x = p[--n]; + y = hi; + hi = x + y; + assert(fabs(x) < fabs(y)); + lo = x - (hi - y); /* volatile */ + if (lo != 0.0) + break; + } + /* round correctly if necessary */ + if (n > 0 && ((lo < 0.0 && p[n-1] < 0.0) || + (lo > 0.0 && p[n-1] > 0.0))) { + y = lo * 2.0; + x = hi + y; /* volatile */ + if (y == (x - hi)) + hi = x; + } + } + else { /* raise corresponding error */ + errno = Py_IS_NAN(hi) ? EDOM : ERANGE; + if (is_error(hi)) + goto _sum_error; + } + } + else /* default */ + hi = 0.0; + sum = PyFloat_FromDouble(hi); + +_sum_error: + PyFPE_END_PROTECT(hi) + + Py_DECREF(iter); + if (p != ps) + PyMem_Free(p); + return sum; +} + +#undef NUM_PARTIALS + +PyDoc_STRVAR(math_sum_doc, +"sum(sequence)\n\n\ +Return the full precision sum of a sequence of numbers.\n\ +When the sequence is empty, return zero.\n\n\ +For accurate results, IEEE 754 floating point format\n\ +and semantics and floating point radix 2 are required."); + static PyObject * math_trunc(PyObject *self, PyObject *number) { @@ -760,6 +982,7 @@ static PyMethodDef math_methods[] = { {"sin", math_sin, METH_O, math_sin_doc}, {"sinh", math_sinh, METH_O, math_sinh_doc}, {"sqrt", math_sqrt, METH_O, math_sqrt_doc}, + {"sum", math_sum, METH_O, math_sum_doc}, {"tan", math_tan, METH_O, math_tan_doc}, {"tanh", math_tanh, METH_O, math_tanh_doc}, {"trunc", math_trunc, METH_O, math_trunc_doc}, |