summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Doc/library/itertools.rst30
-rw-r--r--Lib/test/test_itertools.py6
-rw-r--r--Misc/NEWS.d/next/Library/2018-09-16-17-04-16.bpo-34659.CWemzH.rst1
-rw-r--r--Modules/clinic/itertoolsmodule.c.h15
-rw-r--r--Modules/itertoolsmodule.c32
5 files changed, 65 insertions, 19 deletions
diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst
index 959424f..b1513cd 100644
--- a/Doc/library/itertools.rst
+++ b/Doc/library/itertools.rst
@@ -86,29 +86,38 @@ The following module functions all construct and return iterators. Some provide
streams of infinite length, so they should only be accessed by functions or
loops that truncate the stream.
-.. function:: accumulate(iterable[, func])
+.. function:: accumulate(iterable[, func, *, initial=None])
Make an iterator that returns accumulated sums, or accumulated
results of other binary functions (specified via the optional
- *func* argument). If *func* is supplied, it should be a function
+ *func* argument).
+
+ If *func* is supplied, it should be a function
of two arguments. Elements of the input *iterable* may be any type
that can be accepted as arguments to *func*. (For example, with
the default operation of addition, elements may be any addable
type including :class:`~decimal.Decimal` or
- :class:`~fractions.Fraction`.) If the input iterable is empty, the
- output iterable will also be empty.
+ :class:`~fractions.Fraction`.)
+
+ Usually, the number of elements output matches the input iterable.
+ However, if the keyword argument *initial* is provided, the
+ accumulation leads off with the *initial* value so that the output
+ has one more element than the input iterable.
Roughly equivalent to::
- def accumulate(iterable, func=operator.add):
+ def accumulate(iterable, func=operator.add, *, initial=None):
'Return running totals'
# accumulate([1,2,3,4,5]) --> 1 3 6 10 15
+ # accumulate([1,2,3,4,5], initial=100) --> 100 101 103 106 110 115
# accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
it = iter(iterable)
- try:
- total = next(it)
- except StopIteration:
- return
+ total = initial
+ if initial is None:
+ try:
+ total = next(it)
+ except StopIteration:
+ return
yield total
for element in it:
total = func(total, element)
@@ -152,6 +161,9 @@ loops that truncate the stream.
.. versionchanged:: 3.3
Added the optional *func* parameter.
+ .. versionchanged:: 3.8
+ Added the optional *initial* parameter.
+
.. function:: chain(*iterables)
Make an iterator that returns elements from the first iterable until it is
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index cbbb4c4..ea060a9 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -147,6 +147,12 @@ class TestBasicOps(unittest.TestCase):
list(accumulate(s, chr)) # unary-operation
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
self.pickletest(proto, accumulate(range(10))) # test pickling
+ self.pickletest(proto, accumulate(range(10), initial=7))
+ self.assertEqual(list(accumulate([10, 5, 1], initial=None)), [10, 15, 16])
+ self.assertEqual(list(accumulate([10, 5, 1], initial=100)), [100, 110, 115, 116])
+ self.assertEqual(list(accumulate([], initial=100)), [100])
+ with self.assertRaises(TypeError):
+ list(accumulate([10, 20], 100))
def test_chain(self):
diff --git a/Misc/NEWS.d/next/Library/2018-09-16-17-04-16.bpo-34659.CWemzH.rst b/Misc/NEWS.d/next/Library/2018-09-16-17-04-16.bpo-34659.CWemzH.rst
new file mode 100644
index 0000000..3b7925a
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2018-09-16-17-04-16.bpo-34659.CWemzH.rst
@@ -0,0 +1 @@
+Add an optional *initial* argument to itertools.accumulate().
diff --git a/Modules/clinic/itertoolsmodule.c.h b/Modules/clinic/itertoolsmodule.c.h
index 94df96c..476adc1 100644
--- a/Modules/clinic/itertoolsmodule.c.h
+++ b/Modules/clinic/itertoolsmodule.c.h
@@ -382,29 +382,30 @@ exit:
}
PyDoc_STRVAR(itertools_accumulate__doc__,
-"accumulate(iterable, func=None)\n"
+"accumulate(iterable, func=None, *, initial=None)\n"
"--\n"
"\n"
"Return series of accumulated sums (or other binary function results).");
static PyObject *
itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable,
- PyObject *binop);
+ PyObject *binop, PyObject *initial);
static PyObject *
itertools_accumulate(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
PyObject *return_value = NULL;
- static const char * const _keywords[] = {"iterable", "func", NULL};
- static _PyArg_Parser _parser = {"O|O:accumulate", _keywords, 0};
+ static const char * const _keywords[] = {"iterable", "func", "initial", NULL};
+ static _PyArg_Parser _parser = {"O|O$O:accumulate", _keywords, 0};
PyObject *iterable;
PyObject *binop = Py_None;
+ PyObject *initial = Py_None;
if (!_PyArg_ParseTupleAndKeywordsFast(args, kwargs, &_parser,
- &iterable, &binop)) {
+ &iterable, &binop, &initial)) {
goto exit;
}
- return_value = itertools_accumulate_impl(type, iterable, binop);
+ return_value = itertools_accumulate_impl(type, iterable, binop, initial);
exit:
return return_value;
@@ -509,4 +510,4 @@ itertools_count(PyTypeObject *type, PyObject *args, PyObject *kwargs)
exit:
return return_value;
}
-/*[clinic end generated code: output=d9eb9601bd3296ef input=a9049054013a1b77]*/
+/*[clinic end generated code: output=c8c47b766deeffc3 input=a9049054013a1b77]*/
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index ec8f0ae..89c0280 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -3475,6 +3475,7 @@ typedef struct {
PyObject *total;
PyObject *it;
PyObject *binop;
+ PyObject *initial;
} accumulateobject;
static PyTypeObject accumulate_type;
@@ -3484,18 +3485,19 @@ static PyTypeObject accumulate_type;
itertools.accumulate.__new__
iterable: object
func as binop: object = None
+ *
+ initial: object = None
Return series of accumulated sums (or other binary function results).
[clinic start generated code]*/
static PyObject *
itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable,
- PyObject *binop)
-/*[clinic end generated code: output=514d0fb30ba14d55 input=6d9d16aaa1d3cbfc]*/
+ PyObject *binop, PyObject *initial)
+/*[clinic end generated code: output=66da2650627128f8 input=c4ce20ac59bf7ffd]*/
{
PyObject *it;
accumulateobject *lz;
-
/* Get iterator. */
it = PyObject_GetIter(iterable);
if (it == NULL)
@@ -3514,6 +3516,8 @@ itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable,
}
lz->total = NULL;
lz->it = it;
+ Py_XINCREF(initial);
+ lz->initial = initial;
return (PyObject *)lz;
}
@@ -3524,6 +3528,7 @@ accumulate_dealloc(accumulateobject *lz)
Py_XDECREF(lz->binop);
Py_XDECREF(lz->total);
Py_XDECREF(lz->it);
+ Py_XDECREF(lz->initial);
Py_TYPE(lz)->tp_free(lz);
}
@@ -3533,6 +3538,7 @@ accumulate_traverse(accumulateobject *lz, visitproc visit, void *arg)
Py_VISIT(lz->binop);
Py_VISIT(lz->it);
Py_VISIT(lz->total);
+ Py_VISIT(lz->initial);
return 0;
}
@@ -3541,6 +3547,13 @@ accumulate_next(accumulateobject *lz)
{
PyObject *val, *newtotal;
+ if (lz->initial != Py_None) {
+ lz->total = lz->initial;
+ Py_INCREF(Py_None);
+ lz->initial = Py_None;
+ Py_INCREF(lz->total);
+ return lz->total;
+ }
val = (*Py_TYPE(lz->it)->tp_iternext)(lz->it);
if (val == NULL)
return NULL;
@@ -3567,6 +3580,19 @@ accumulate_next(accumulateobject *lz)
static PyObject *
accumulate_reduce(accumulateobject *lz, PyObject *Py_UNUSED(ignored))
{
+ if (lz->initial != Py_None) {
+ PyObject *it;
+
+ assert(lz->total == NULL);
+ if (PyType_Ready(&chain_type) < 0)
+ return NULL;
+ it = PyObject_CallFunction((PyObject *)&chain_type, "(O)O",
+ lz->initial, lz->it);
+ if (it == NULL)
+ return NULL;
+ return Py_BuildValue("O(NO)O", Py_TYPE(lz),
+ it, lz->binop?lz->binop:Py_None, Py_None);
+ }
if (lz->total == Py_None) {
PyObject *it;