summaryrefslogtreecommitdiffstats
path: root/Modules
diff options
context:
space:
mode:
authorRaymond Hettinger <rhettinger@users.noreply.github.com>2023-01-07 18:46:35 (GMT)
committerGitHub <noreply@github.com>2023-01-07 18:46:35 (GMT)
commit47b9f83a83db288c652e43567c7b0f74d87a29be (patch)
treecb4fde0440b01f79852ef45e1a2a2fe22ba6daca /Modules
parentdeaf090699a7312cccb0637409f44de3f382389b (diff)
downloadcpython-47b9f83a83db288c652e43567c7b0f74d87a29be.zip
cpython-47b9f83a83db288c652e43567c7b0f74d87a29be.tar.gz
cpython-47b9f83a83db288c652e43567c7b0f74d87a29be.tar.bz2
GH-100485: Add math.sumprod() (GH-100677)
Diffstat (limited to 'Modules')
-rw-r--r--Modules/clinic/mathmodule.c.h39
-rw-r--r--Modules/mathmodule.c325
2 files changed, 363 insertions, 1 deletions
diff --git a/Modules/clinic/mathmodule.c.h b/Modules/clinic/mathmodule.c.h
index 9fac103..1f97258 100644
--- a/Modules/clinic/mathmodule.c.h
+++ b/Modules/clinic/mathmodule.c.h
@@ -333,6 +333,43 @@ exit:
return return_value;
}
+PyDoc_STRVAR(math_sumprod__doc__,
+"sumprod($module, p, q, /)\n"
+"--\n"
+"\n"
+"Return the sum of products of values from two iterables p and q.\n"
+"\n"
+"Roughly equivalent to:\n"
+"\n"
+" sum(itertools.starmap(operator.mul, zip(p, q, strict=True)))\n"
+"\n"
+"For float and mixed int/float inputs, the intermediate products\n"
+"and sums are computed with extended precision.");
+
+#define MATH_SUMPROD_METHODDEF \
+ {"sumprod", _PyCFunction_CAST(math_sumprod), METH_FASTCALL, math_sumprod__doc__},
+
+static PyObject *
+math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q);
+
+static PyObject *
+math_sumprod(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
+{
+ PyObject *return_value = NULL;
+ PyObject *p;
+ PyObject *q;
+
+ if (!_PyArg_CheckPositional("sumprod", nargs, 2, 2)) {
+ goto exit;
+ }
+ p = args[0];
+ q = args[1];
+ return_value = math_sumprod_impl(module, p, q);
+
+exit:
+ return return_value;
+}
+
PyDoc_STRVAR(math_pow__doc__,
"pow($module, x, y, /)\n"
"--\n"
@@ -917,4 +954,4 @@ math_ulp(PyObject *module, PyObject *arg)
exit:
return return_value;
}
-/*[clinic end generated code: output=c2c2f42452d63734 input=a9049054013a1b77]*/
+/*[clinic end generated code: output=899211ec70e4506c input=a9049054013a1b77]*/
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c
index 49c0293..0bcb336 100644
--- a/Modules/mathmodule.c
+++ b/Modules/mathmodule.c
@@ -68,6 +68,7 @@ raised for division by zero and mod by zero.
#include <float.h>
/* For _Py_log1p with workarounds for buggy handling of zeros. */
#include "_math.h"
+#include <stdbool.h>
#include "clinic/mathmodule.c.h"
@@ -2819,6 +2820,329 @@ For example, the hypotenuse of a 3/4/5 right triangle is:\n\
5.0\n\
");
+/** sumprod() ***************************************************************/
+
+/* Forward declaration */
+static inline int _check_long_mult_overflow(long a, long b);
+
+static inline bool
+long_add_would_overflow(long a, long b)
+{
+ return (a > 0) ? (b > LONG_MAX - a) : (b < LONG_MIN - a);
+}
+
+/*
+Double length extended precision floating point arithmetic
+based on ideas from three sources:
+
+ Improved Kahan–Babuška algorithm by Arnold Neumaier
+ https://www.mat.univie.ac.at/~neum/scan/01.pdf
+
+ A Floating-Point Technique for Extending the Available Precision
+ by T. J. Dekker
+ https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
+
+ Ultimately Fast Accurate Summation by Siegfried M. Rump
+ https://www.tuhh.de/ti3/paper/rump/Ru08b.pdf
+
+The double length routines allow for quite a bit of instruction
+level parallelism. On a 3.22 Ghz Apple M1 Max, the incremental
+cost of increasing the input vector size by one is 6.25 nsec.
+
+dl_zero() returns an extended precision zero
+dl_split() exactly splits a double into two half precision components.
+dl_add() performs compensated summation to keep a running total.
+dl_mul() implements lossless multiplication of doubles.
+dl_fma() implements an extended precision fused-multiply-add.
+dl_to_d() converts from extended precision to double precision.
+
+*/
+
+typedef struct{ double hi; double lo; } DoubleLength;
+
+static inline DoubleLength
+dl_zero()
+{
+ return (DoubleLength) {0.0, 0.0};
+}
+static inline DoubleLength
+dl_add(DoubleLength total, double x)
+{
+ double s = total.hi + x;
+ double c = total.lo;
+ if (fabs(total.hi) >= fabs(x)) {
+ c += (total.hi - s) + x;
+ } else {
+ c += (x - s) + total.hi;
+ }
+ return (DoubleLength) {s, c};
+}
+
+static inline DoubleLength
+dl_split(double x) {
+ double t = x * 134217729.0; /* Veltkamp constant = float(0x8000001) */
+ double hi = t - (t - x);
+ double lo = x - hi;
+ return (DoubleLength) {hi, lo};
+}
+
+static inline DoubleLength
+dl_mul(double x, double y)
+{
+ /* Dekker mul12(). Section (5.12) */
+ DoubleLength xx = dl_split(x);
+ DoubleLength yy = dl_split(y);
+ double p = xx.hi * yy.hi;
+ double q = xx.hi * yy.lo + xx.lo * yy.hi;
+ double z = p + q;
+ double zz = p - z + q + xx.lo * yy.lo;
+ return (DoubleLength) {z, zz};
+}
+
+static inline DoubleLength
+dl_fma(DoubleLength total, double p, double q)
+{
+ DoubleLength product = dl_mul(p, q);
+ total = dl_add(total, product.hi);
+ return dl_add(total, product.lo);
+}
+
+static inline double
+dl_to_d(DoubleLength total)
+{
+ return total.hi + total.lo;
+}
+
+/*[clinic input]
+math.sumprod
+
+ p: object
+ q: object
+ /
+
+Return the sum of products of values from two iterables p and q.
+
+Roughly equivalent to:
+
+ sum(itertools.starmap(operator.mul, zip(p, q, strict=True)))
+
+For float and mixed int/float inputs, the intermediate products
+and sums are computed with extended precision.
+[clinic start generated code]*/
+
+static PyObject *
+math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q)
+/*[clinic end generated code: output=6722dbfe60664554 input=82be54fe26f87e30]*/
+{
+ PyObject *p_i = NULL, *q_i = NULL, *term_i = NULL, *new_total = NULL;
+ PyObject *p_it, *q_it, *total;
+ iternextfunc p_next, q_next;
+ bool p_stopped = false, q_stopped = false;
+ bool int_path_enabled = true, int_total_in_use = false;
+ bool flt_path_enabled = true, flt_total_in_use = false;
+ long int_total = 0;
+ DoubleLength flt_total = dl_zero();
+
+ p_it = PyObject_GetIter(p);
+ if (p_it == NULL) {
+ return NULL;
+ }
+ q_it = PyObject_GetIter(q);
+ if (q_it == NULL) {
+ Py_DECREF(p_it);
+ return NULL;
+ }
+ total = PyLong_FromLong(0);
+ if (total == NULL) {
+ Py_DECREF(p_it);
+ Py_DECREF(q_it);
+ return NULL;
+ }
+ p_next = *Py_TYPE(p_it)->tp_iternext;
+ q_next = *Py_TYPE(q_it)->tp_iternext;
+ while (1) {
+ bool finished;
+
+ assert (p_i == NULL);
+ assert (q_i == NULL);
+ assert (term_i == NULL);
+ assert (new_total == NULL);
+
+ assert (p_it != NULL);
+ assert (q_it != NULL);
+ assert (total != NULL);
+
+ p_i = p_next(p_it);
+ if (p_i == NULL) {
+ if (PyErr_Occurred()) {
+ if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
+ goto err_exit;
+ }
+ PyErr_Clear();
+ }
+ p_stopped = true;
+ }
+ q_i = q_next(q_it);
+ if (q_i == NULL) {
+ if (PyErr_Occurred()) {
+ if (!PyErr_ExceptionMatches(PyExc_StopIteration)) {
+ goto err_exit;
+ }
+ PyErr_Clear();
+ }
+ q_stopped = true;
+ }
+ if (p_stopped != q_stopped) {
+ PyErr_Format(PyExc_ValueError, "Inputs are not the same length");
+ goto err_exit;
+ }
+ finished = p_stopped & q_stopped;
+
+ if (int_path_enabled) {
+
+ if (!finished && PyLong_CheckExact(p_i) & PyLong_CheckExact(q_i)) {
+ int overflow;
+ long int_p, int_q, int_prod;
+
+ int_p = PyLong_AsLongAndOverflow(p_i, &overflow);
+ if (overflow) {
+ goto finalize_int_path;
+ }
+ int_q = PyLong_AsLongAndOverflow(q_i, &overflow);
+ if (overflow) {
+ goto finalize_int_path;
+ }
+ if (_check_long_mult_overflow(int_p, int_q)) {
+ goto finalize_int_path;
+ }
+ int_prod = int_p * int_q;
+ if (long_add_would_overflow(int_total, int_prod)) {
+ goto finalize_int_path;
+ }
+ int_total += int_prod;
+ int_total_in_use = true;
+ Py_CLEAR(p_i);
+ Py_CLEAR(q_i);
+ continue;
+ }
+
+ finalize_int_path:
+ // # We're finished, overflowed, or have a non-int
+ int_path_enabled = false;
+ if (int_total_in_use) {
+ term_i = PyLong_FromLong(int_total);
+ if (term_i == NULL) {
+ goto err_exit;
+ }
+ new_total = PyNumber_Add(total, term_i);
+ if (new_total == NULL) {
+ goto err_exit;
+ }
+ Py_SETREF(total, new_total);
+ new_total = NULL;
+ Py_CLEAR(term_i);
+ int_total = 0; // An ounce of prevention, ...
+ int_total_in_use = false;
+ }
+ }
+
+ if (flt_path_enabled) {
+
+ if (!finished) {
+ double flt_p, flt_q;
+ bool p_type_float = PyFloat_CheckExact(p_i);
+ bool q_type_float = PyFloat_CheckExact(q_i);
+ if (p_type_float && q_type_float) {
+ flt_p = PyFloat_AS_DOUBLE(p_i);
+ flt_q = PyFloat_AS_DOUBLE(q_i);
+ } else if (p_type_float && (PyLong_CheckExact(q_i) || PyBool_Check(q_i))) {
+ /* We care about float/int pairs and int/float pairs because
+ they arise naturally in several use cases such as price
+ times quantity, measurements with integer weights, or
+ data selected by a vector of bools. */
+ flt_p = PyFloat_AS_DOUBLE(p_i);
+ flt_q = PyLong_AsDouble(q_i);
+ if (flt_q == -1.0 && PyErr_Occurred()) {
+ PyErr_Clear();
+ goto finalize_flt_path;
+ }
+ } else if (q_type_float && (PyLong_CheckExact(p_i) || PyBool_Check(q_i))) {
+ flt_q = PyFloat_AS_DOUBLE(q_i);
+ flt_p = PyLong_AsDouble(p_i);
+ if (flt_p == -1.0 && PyErr_Occurred()) {
+ PyErr_Clear();
+ goto finalize_flt_path;
+ }
+ } else {
+ goto finalize_flt_path;
+ }
+ DoubleLength new_flt_total = dl_fma(flt_total, flt_p, flt_q);
+ if (isfinite(new_flt_total.hi)) {
+ flt_total = new_flt_total;
+ flt_total_in_use = true;
+ Py_CLEAR(p_i);
+ Py_CLEAR(q_i);
+ continue;
+ }
+ }
+
+ finalize_flt_path:
+ // We're finished, overflowed, have a non-float, or got a non-finite value
+ flt_path_enabled = false;
+ if (flt_total_in_use) {
+ term_i = PyFloat_FromDouble(dl_to_d(flt_total));
+ if (term_i == NULL) {
+ goto err_exit;
+ }
+ new_total = PyNumber_Add(total, term_i);
+ if (new_total == NULL) {
+ goto err_exit;
+ }
+ Py_SETREF(total, new_total);
+ new_total = NULL;
+ Py_CLEAR(term_i);
+ flt_total = dl_zero();
+ flt_total_in_use = false;
+ }
+ }
+
+ assert(!int_total_in_use);
+ assert(!flt_total_in_use);
+ if (finished) {
+ goto normal_exit;
+ }
+ term_i = PyNumber_Multiply(p_i, q_i);
+ if (term_i == NULL) {
+ goto err_exit;
+ }
+ new_total = PyNumber_Add(total, term_i);
+ if (new_total == NULL) {
+ goto err_exit;
+ }
+ Py_SETREF(total, new_total);
+ new_total = NULL;
+ Py_CLEAR(p_i);
+ Py_CLEAR(q_i);
+ Py_CLEAR(term_i);
+ }
+
+ normal_exit:
+ Py_DECREF(p_it);
+ Py_DECREF(q_it);
+ return total;
+
+ err_exit:
+ Py_DECREF(p_it);
+ Py_DECREF(q_it);
+ Py_DECREF(total);
+ Py_XDECREF(p_i);
+ Py_XDECREF(q_i);
+ Py_XDECREF(term_i);
+ Py_XDECREF(new_total);
+ return NULL;
+}
+
+
/* pow can't use math_2, but needs its own wrapper: the problem is
that an infinite result can arise either as a result of overflow
(in which case OverflowError should be raised) or as a result of
@@ -3933,6 +4257,7 @@ static PyMethodDef math_methods[] = {
{"sqrt", math_sqrt, METH_O, math_sqrt_doc},
{"tan", math_tan, METH_O, math_tan_doc},
{"tanh", math_tanh, METH_O, math_tanh_doc},
+ MATH_SUMPROD_METHODDEF
MATH_TRUNC_METHODDEF
MATH_PROD_METHODDEF
MATH_PERM_METHODDEF