diff options
author | Raymond Hettinger <python@rcn.com> | 2011-03-28 01:52:10 (GMT) |
---|---|---|
committer | Raymond Hettinger <python@rcn.com> | 2011-03-28 01:52:10 (GMT) |
commit | 5d44613e3bfd3b5e657b8dea808ef958543be7e0 (patch) | |
tree | 45feb9409f64893ad357722ee36f27480a7d9652 /Modules/itertoolsmodule.c | |
parent | af88d866994a20d13b75ecad6ab5144ef5bcc9fc (diff) | |
download | cpython-5d44613e3bfd3b5e657b8dea808ef958543be7e0.zip cpython-5d44613e3bfd3b5e657b8dea808ef958543be7e0.tar.gz cpython-5d44613e3bfd3b5e657b8dea808ef958543be7e0.tar.bz2 |
Add optional *func* argument to itertools.accumulate().
Diffstat (limited to 'Modules/itertoolsmodule.c')
-rw-r--r-- | Modules/itertoolsmodule.c | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index b202e52..4f58d57 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -2590,6 +2590,7 @@ typedef struct { PyObject_HEAD PyObject *total; PyObject *it; + PyObject *binop; } accumulateobject; static PyTypeObject accumulate_type; @@ -2597,12 +2598,14 @@ static PyTypeObject accumulate_type; static PyObject * accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { - static char *kwargs[] = {"iterable", NULL}; + static char *kwargs[] = {"iterable", "func", NULL}; PyObject *iterable; PyObject *it; + PyObject *binop = NULL; accumulateobject *lz; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:accumulate", kwargs, &iterable)) + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:accumulate", + kwargs, &iterable, &binop)) return NULL; /* Get iterator. */ @@ -2617,6 +2620,8 @@ accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds) return NULL; } + Py_XINCREF(binop); + lz->binop = binop; lz->total = NULL; lz->it = it; return (PyObject *)lz; @@ -2626,6 +2631,7 @@ static void accumulate_dealloc(accumulateobject *lz) { PyObject_GC_UnTrack(lz); + Py_XDECREF(lz->binop); Py_XDECREF(lz->total); Py_XDECREF(lz->it); Py_TYPE(lz)->tp_free(lz); @@ -2634,6 +2640,7 @@ accumulate_dealloc(accumulateobject *lz) static int accumulate_traverse(accumulateobject *lz, visitproc visit, void *arg) { + Py_VISIT(lz->binop); Py_VISIT(lz->it); Py_VISIT(lz->total); return 0; @@ -2653,8 +2660,11 @@ accumulate_next(accumulateobject *lz) lz->total = val; return lz->total; } - - newtotal = PyNumber_Add(lz->total, val); + + if (lz->binop == NULL) + newtotal = PyNumber_Add(lz->total, val); + else + newtotal = PyObject_CallFunctionObjArgs(lz->binop, lz->total, val, NULL); Py_DECREF(val); if (newtotal == NULL) return NULL; |