diff options
-rw-r--r-- | Doc/library/itertools.rst | 30 | ||||
-rw-r--r-- | Lib/test/test_itertools.py | 6 | ||||
-rw-r--r-- | Misc/NEWS.d/next/Library/2018-09-16-17-04-16.bpo-34659.CWemzH.rst | 1 | ||||
-rw-r--r-- | Modules/clinic/itertoolsmodule.c.h | 15 | ||||
-rw-r--r-- | Modules/itertoolsmodule.c | 32 |
5 files changed, 65 insertions, 19 deletions
diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst index 959424f..b1513cd 100644 --- a/Doc/library/itertools.rst +++ b/Doc/library/itertools.rst @@ -86,29 +86,38 @@ The following module functions all construct and return iterators. Some provide streams of infinite length, so they should only be accessed by functions or loops that truncate the stream. -.. function:: accumulate(iterable[, func]) +.. function:: accumulate(iterable[, func, *, initial=None]) Make an iterator that returns accumulated sums, or accumulated results of other binary functions (specified via the optional - *func* argument). If *func* is supplied, it should be a function + *func* argument). + + If *func* is supplied, it should be a function of two arguments. Elements of the input *iterable* may be any type that can be accepted as arguments to *func*. (For example, with the default operation of addition, elements may be any addable type including :class:`~decimal.Decimal` or - :class:`~fractions.Fraction`.) If the input iterable is empty, the - output iterable will also be empty. + :class:`~fractions.Fraction`.) + + Usually, the number of elements output matches the input iterable. + However, if the keyword argument *initial* is provided, the + accumulation leads off with the *initial* value so that the output + has one more element than the input iterable. Roughly equivalent to:: - def accumulate(iterable, func=operator.add): + def accumulate(iterable, func=operator.add, *, initial=None): 'Return running totals' # accumulate([1,2,3,4,5]) --> 1 3 6 10 15 + # accumulate([1,2,3,4,5], initial=100) --> 100 101 103 106 110 115 # accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120 it = iter(iterable) - try: - total = next(it) - except StopIteration: - return + total = initial + if initial is None: + try: + total = next(it) + except StopIteration: + return yield total for element in it: total = func(total, element) @@ -152,6 +161,9 @@ loops that truncate the stream. .. versionchanged:: 3.3 Added the optional *func* parameter. + .. versionchanged:: 3.8 + Added the optional *initial* parameter. + .. function:: chain(*iterables) Make an iterator that returns elements from the first iterable until it is diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index cbbb4c4..ea060a9 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -147,6 +147,12 @@ class TestBasicOps(unittest.TestCase): list(accumulate(s, chr)) # unary-operation for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.pickletest(proto, accumulate(range(10))) # test pickling + self.pickletest(proto, accumulate(range(10), initial=7)) + self.assertEqual(list(accumulate([10, 5, 1], initial=None)), [10, 15, 16]) + self.assertEqual(list(accumulate([10, 5, 1], initial=100)), [100, 110, 115, 116]) + self.assertEqual(list(accumulate([], initial=100)), [100]) + with self.assertRaises(TypeError): + list(accumulate([10, 20], 100)) def test_chain(self): diff --git a/Misc/NEWS.d/next/Library/2018-09-16-17-04-16.bpo-34659.CWemzH.rst b/Misc/NEWS.d/next/Library/2018-09-16-17-04-16.bpo-34659.CWemzH.rst new file mode 100644 index 0000000..3b7925a --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-09-16-17-04-16.bpo-34659.CWemzH.rst @@ -0,0 +1 @@ +Add an optional *initial* argument to itertools.accumulate(). diff --git a/Modules/clinic/itertoolsmodule.c.h b/Modules/clinic/itertoolsmodule.c.h index 94df96c..476adc1 100644 --- a/Modules/clinic/itertoolsmodule.c.h +++ b/Modules/clinic/itertoolsmodule.c.h @@ -382,29 +382,30 @@ exit: } PyDoc_STRVAR(itertools_accumulate__doc__, -"accumulate(iterable, func=None)\n" +"accumulate(iterable, func=None, *, initial=None)\n" "--\n" "\n" "Return series of accumulated sums (or other binary function results)."); static PyObject * itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable, - PyObject *binop); + PyObject *binop, PyObject *initial); static PyObject * itertools_accumulate(PyTypeObject *type, PyObject *args, PyObject *kwargs) { PyObject *return_value = NULL; - static const char * const _keywords[] = {"iterable", "func", NULL}; - static _PyArg_Parser _parser = {"O|O:accumulate", _keywords, 0}; + static const char * const _keywords[] = {"iterable", "func", "initial", NULL}; + static _PyArg_Parser _parser = {"O|O$O:accumulate", _keywords, 0}; PyObject *iterable; PyObject *binop = Py_None; + PyObject *initial = Py_None; if (!_PyArg_ParseTupleAndKeywordsFast(args, kwargs, &_parser, - &iterable, &binop)) { + &iterable, &binop, &initial)) { goto exit; } - return_value = itertools_accumulate_impl(type, iterable, binop); + return_value = itertools_accumulate_impl(type, iterable, binop, initial); exit: return return_value; @@ -509,4 +510,4 @@ itertools_count(PyTypeObject *type, PyObject *args, PyObject *kwargs) exit: return return_value; } -/*[clinic end generated code: output=d9eb9601bd3296ef input=a9049054013a1b77]*/ +/*[clinic end generated code: output=c8c47b766deeffc3 input=a9049054013a1b77]*/ diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index ec8f0ae..89c0280 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -3475,6 +3475,7 @@ typedef struct { PyObject *total; PyObject *it; PyObject *binop; + PyObject *initial; } accumulateobject; static PyTypeObject accumulate_type; @@ -3484,18 +3485,19 @@ static PyTypeObject accumulate_type; itertools.accumulate.__new__ iterable: object func as binop: object = None + * + initial: object = None Return series of accumulated sums (or other binary function results). [clinic start generated code]*/ static PyObject * itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable, - PyObject *binop) -/*[clinic end generated code: output=514d0fb30ba14d55 input=6d9d16aaa1d3cbfc]*/ + PyObject *binop, PyObject *initial) +/*[clinic end generated code: output=66da2650627128f8 input=c4ce20ac59bf7ffd]*/ { PyObject *it; accumulateobject *lz; - /* Get iterator. */ it = PyObject_GetIter(iterable); if (it == NULL) @@ -3514,6 +3516,8 @@ itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable, } lz->total = NULL; lz->it = it; + Py_XINCREF(initial); + lz->initial = initial; return (PyObject *)lz; } @@ -3524,6 +3528,7 @@ accumulate_dealloc(accumulateobject *lz) Py_XDECREF(lz->binop); Py_XDECREF(lz->total); Py_XDECREF(lz->it); + Py_XDECREF(lz->initial); Py_TYPE(lz)->tp_free(lz); } @@ -3533,6 +3538,7 @@ accumulate_traverse(accumulateobject *lz, visitproc visit, void *arg) Py_VISIT(lz->binop); Py_VISIT(lz->it); Py_VISIT(lz->total); + Py_VISIT(lz->initial); return 0; } @@ -3541,6 +3547,13 @@ accumulate_next(accumulateobject *lz) { PyObject *val, *newtotal; + if (lz->initial != Py_None) { + lz->total = lz->initial; + Py_INCREF(Py_None); + lz->initial = Py_None; + Py_INCREF(lz->total); + return lz->total; + } val = (*Py_TYPE(lz->it)->tp_iternext)(lz->it); if (val == NULL) return NULL; @@ -3567,6 +3580,19 @@ accumulate_next(accumulateobject *lz) static PyObject * accumulate_reduce(accumulateobject *lz, PyObject *Py_UNUSED(ignored)) { + if (lz->initial != Py_None) { + PyObject *it; + + assert(lz->total == NULL); + if (PyType_Ready(&chain_type) < 0) + return NULL; + it = PyObject_CallFunction((PyObject *)&chain_type, "(O)O", + lz->initial, lz->it); + if (it == NULL) + return NULL; + return Py_BuildValue("O(NO)O", Py_TYPE(lz), + it, lz->binop?lz->binop:Py_None, Py_None); + } if (lz->total == Py_None) { PyObject *it; |