summaryrefslogtreecommitdiffstats
path: root/Modules/itertoolsmodule.c
diff options
context:
space:
mode:
authorLisa Roach <lisaroach14@gmail.com>2018-09-24 00:34:59 (GMT)
committerGitHub <noreply@github.com>2018-09-24 00:34:59 (GMT)
commit9718b59ee5f2416cdb8116ea5837b062faf0d9f8 (patch)
tree4cb5ba002be73a33680252e355c13800f4e9c42c /Modules/itertoolsmodule.c
parentc87d9f406bb23657c1b4cd63017bb7bd7693a1fb (diff)
downloadcpython-9718b59ee5f2416cdb8116ea5837b062faf0d9f8.zip
cpython-9718b59ee5f2416cdb8116ea5837b062faf0d9f8.tar.gz
cpython-9718b59ee5f2416cdb8116ea5837b062faf0d9f8.tar.bz2
bpo-34659: Adds initial kwarg to itertools.accumulate() (GH-9345)
Diffstat (limited to 'Modules/itertoolsmodule.c')
-rw-r--r--Modules/itertoolsmodule.c32
1 files changed, 29 insertions, 3 deletions
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index ec8f0ae..89c0280 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -3475,6 +3475,7 @@ typedef struct {
PyObject *total;
PyObject *it;
PyObject *binop;
+ PyObject *initial;
} accumulateobject;
static PyTypeObject accumulate_type;
@@ -3484,18 +3485,19 @@ static PyTypeObject accumulate_type;
itertools.accumulate.__new__
iterable: object
func as binop: object = None
+ *
+ initial: object = None
Return series of accumulated sums (or other binary function results).
[clinic start generated code]*/
static PyObject *
itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable,
- PyObject *binop)
-/*[clinic end generated code: output=514d0fb30ba14d55 input=6d9d16aaa1d3cbfc]*/
+ PyObject *binop, PyObject *initial)
+/*[clinic end generated code: output=66da2650627128f8 input=c4ce20ac59bf7ffd]*/
{
PyObject *it;
accumulateobject *lz;
-
/* Get iterator. */
it = PyObject_GetIter(iterable);
if (it == NULL)
@@ -3514,6 +3516,8 @@ itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable,
}
lz->total = NULL;
lz->it = it;
+ Py_XINCREF(initial);
+ lz->initial = initial;
return (PyObject *)lz;
}
@@ -3524,6 +3528,7 @@ accumulate_dealloc(accumulateobject *lz)
Py_XDECREF(lz->binop);
Py_XDECREF(lz->total);
Py_XDECREF(lz->it);
+ Py_XDECREF(lz->initial);
Py_TYPE(lz)->tp_free(lz);
}
@@ -3533,6 +3538,7 @@ accumulate_traverse(accumulateobject *lz, visitproc visit, void *arg)
Py_VISIT(lz->binop);
Py_VISIT(lz->it);
Py_VISIT(lz->total);
+ Py_VISIT(lz->initial);
return 0;
}
@@ -3541,6 +3547,13 @@ accumulate_next(accumulateobject *lz)
{
PyObject *val, *newtotal;
+ if (lz->initial != Py_None) {
+ lz->total = lz->initial;
+ Py_INCREF(Py_None);
+ lz->initial = Py_None;
+ Py_INCREF(lz->total);
+ return lz->total;
+ }
val = (*Py_TYPE(lz->it)->tp_iternext)(lz->it);
if (val == NULL)
return NULL;
@@ -3567,6 +3580,19 @@ accumulate_next(accumulateobject *lz)
static PyObject *
accumulate_reduce(accumulateobject *lz, PyObject *Py_UNUSED(ignored))
{
+ if (lz->initial != Py_None) {
+ PyObject *it;
+
+ assert(lz->total == NULL);
+ if (PyType_Ready(&chain_type) < 0)
+ return NULL;
+ it = PyObject_CallFunction((PyObject *)&chain_type, "(O)O",
+ lz->initial, lz->it);
+ if (it == NULL)
+ return NULL;
+ return Py_BuildValue("O(NO)O", Py_TYPE(lz),
+ it, lz->binop?lz->binop:Py_None, Py_None);
+ }
if (lz->total == Py_None) {
PyObject *it;