summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorRaymond Hettinger <python@rcn.com>2010-12-01 22:48:00 (GMT)
committerRaymond Hettinger <python@rcn.com>2010-12-01 22:48:00 (GMT)
commit482ba772456259633fa8f0433b8b4220c55c9cab (patch)
tree7c365089bd3189f8a211f88673622ebfd6cc2def
parent2f9a77a389c4182d4960b8b143c4c456a16ea5f3 (diff)
downloadcpython-482ba772456259633fa8f0433b8b4220c55c9cab.zip
cpython-482ba772456259633fa8f0433b8b4220c55c9cab.tar.gz
cpython-482ba772456259633fa8f0433b8b4220c55c9cab.tar.bz2
Add itertools.accumulate().
-rw-r--r--Lib/test/test_itertools.py35
-rw-r--r--Misc/NEWS2
-rw-r--r--Modules/itertoolsmodule.c142
3 files changed, 179 insertions, 0 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index fe6131a..8a67cff 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -56,6 +56,23 @@ def fact(n):
return prod(range(1, n+1))
class TestBasicOps(unittest.TestCase):
+
+ def test_accumulate(self):
+ self.assertEqual(list(accumulate(range(10))), # one positional arg
+ [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
+ self.assertEqual(list(accumulate(range(10), 100)), # two positional args
+ [100, 101, 103, 106, 110, 115, 121, 128, 136, 145])
+ self.assertEqual(list(accumulate(iterable=range(10), start=100)), # kw args
+ [100, 101, 103, 106, 110, 115, 121, 128, 136, 145])
+ for typ in int, complex, Decimal, Fraction: # multiple types
+ self.assertEqual(list(accumulate(range(10), typ(0))),
+ list(map(typ, [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])))
+ self.assertEqual(list(accumulate([])), []) # empty iterable
+ self.assertRaises(TypeError, accumulate, range(10), 0, 5) # too many args
+ self.assertRaises(TypeError, accumulate) # too few args
+ self.assertRaises(TypeError, accumulate, range(10), x=7) # unexpected kwd args
+ self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add
+
def test_chain(self):
def chain2(*iterables):
@@ -932,6 +949,9 @@ class TestBasicOps(unittest.TestCase):
class TestExamples(unittest.TestCase):
+ def test_accumlate(self):
+ self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15])
+
def test_chain(self):
self.assertEqual(''.join(chain('ABC', 'DEF')), 'ABCDEF')
@@ -1019,6 +1039,10 @@ class TestGC(unittest.TestCase):
next(iterator)
del container, iterator
+ def test_accumulate(self):
+ a = []
+ self.makecycle(accumulate([1,2,a,3]), a)
+
def test_chain(self):
a = []
self.makecycle(chain(a), a)
@@ -1188,6 +1212,17 @@ def L(seqn):
class TestVariousIteratorArgs(unittest.TestCase):
+ def test_accumulate(self):
+ s = [1,2,3,4,5]
+ r = [1,3,6,10,15]
+ n = len(s)
+ for g in (G, I, Ig, L, R):
+ self.assertEqual(list(accumulate(g(s))), r)
+ self.assertEqual(list(accumulate(S(s))), [])
+ self.assertRaises(TypeError, accumulate, X(s))
+ self.assertRaises(TypeError, accumulate, N(s))
+ self.assertRaises(ZeroDivisionError, list, accumulate(E(s)))
+
def test_chain(self):
for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
for g in (G, I, Ig, S, L, R):
diff --git a/Misc/NEWS b/Misc/NEWS
index 7d77b20..494087c 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -46,6 +46,8 @@ Core and Builtins
Library
-------
+- Added itertools.accumulate().
+
- Issue #4113: Added custom ``__repr__`` method to ``functools.partial``.
Original patch by Daniel Urban.
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index d5336f2..04bfffc 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -2584,6 +2584,146 @@ static PyTypeObject permutations_type = {
PyObject_GC_Del, /* tp_free */
};
+/* accumulate object ************************************************************/
+
+typedef struct {
+ PyObject_HEAD
+ PyObject *total;
+ PyObject *it;
+} accumulateobject;
+
+static PyTypeObject accumulate_type;
+
+static PyObject *
+accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+ static char *kwargs[] = {"iterable", "start", NULL};
+ PyObject *iterable;
+ PyObject *it;
+ PyObject *start = NULL;
+ accumulateobject *lz;
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:accumulate",
+ kwargs, &iterable, &start))
+ return NULL;
+
+ /* Get iterator. */
+ it = PyObject_GetIter(iterable);
+ if (it == NULL)
+ return NULL;
+
+ /* Default start value */
+ if (start == NULL) {
+ start = PyLong_FromLong(0);
+ if (start == NULL) {
+ Py_DECREF(it);
+ return NULL;
+ }
+ } else {
+ Py_INCREF(start);
+ }
+
+ /* create accumulateobject structure */
+ lz = (accumulateobject *)type->tp_alloc(type, 0);
+ if (lz == NULL) {
+ Py_DECREF(it);
+ Py_DECREF(start);
+ return NULL;
+ }
+
+ lz->total = start;
+ lz->it = it;
+ return (PyObject *)lz;
+}
+
+static void
+accumulate_dealloc(accumulateobject *lz)
+{
+ PyObject_GC_UnTrack(lz);
+ Py_XDECREF(lz->total);
+ Py_XDECREF(lz->it);
+ Py_TYPE(lz)->tp_free(lz);
+}
+
+static int
+accumulate_traverse(accumulateobject *lz, visitproc visit, void *arg)
+{
+ Py_VISIT(lz->it);
+ Py_VISIT(lz->total);
+ return 0;
+}
+
+static PyObject *
+accumulate_next(accumulateobject *lz)
+{
+ PyObject *val, *oldtotal, *newtotal;
+
+ val = PyIter_Next(lz->it);
+ if (val == NULL)
+ return NULL;
+
+ newtotal = PyNumber_Add(lz->total, val);
+ Py_DECREF(val);
+ if (newtotal == NULL)
+ return NULL;
+
+ oldtotal = lz->total;
+ lz->total = newtotal;
+ Py_DECREF(oldtotal);
+
+ Py_INCREF(newtotal);
+ return newtotal;
+}
+
+PyDoc_STRVAR(accumulate_doc,
+"accumulate(iterable, start=0) --> accumulate object\n\
+\n\
+Return series of accumulated sums.");
+
+static PyTypeObject accumulate_type = {
+ PyVarObject_HEAD_INIT(NULL, 0)
+ "itertools.accumulate", /* tp_name */
+ sizeof(accumulateobject), /* tp_basicsize */
+ 0, /* tp_itemsize */
+ /* methods */
+ (destructor)accumulate_dealloc, /* tp_dealloc */
+ 0, /* tp_print */
+ 0, /* tp_getattr */
+ 0, /* tp_setattr */
+ 0, /* tp_reserved */
+ 0, /* tp_repr */
+ 0, /* tp_as_number */
+ 0, /* tp_as_sequence */
+ 0, /* tp_as_mapping */
+ 0, /* tp_hash */
+ 0, /* tp_call */
+ 0, /* tp_str */
+ PyObject_GenericGetAttr, /* tp_getattro */
+ 0, /* tp_setattro */
+ 0, /* tp_as_buffer */
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
+ Py_TPFLAGS_BASETYPE, /* tp_flags */
+ accumulate_doc, /* tp_doc */
+ (traverseproc)accumulate_traverse, /* tp_traverse */
+ 0, /* tp_clear */
+ 0, /* tp_richcompare */
+ 0, /* tp_weaklistoffset */
+ PyObject_SelfIter, /* tp_iter */
+ (iternextfunc)accumulate_next, /* tp_iternext */
+ 0, /* tp_methods */
+ 0, /* tp_members */
+ 0, /* tp_getset */
+ 0, /* tp_base */
+ 0, /* tp_dict */
+ 0, /* tp_descr_get */
+ 0, /* tp_descr_set */
+ 0, /* tp_dictoffset */
+ 0, /* tp_init */
+ 0, /* tp_alloc */
+ accumulate_new, /* tp_new */
+ PyObject_GC_Del, /* tp_free */
+};
+
/* compress object ************************************************************/
@@ -3496,6 +3636,7 @@ cycle(p) --> p0, p1, ... plast, p0, p1, ...\n\
repeat(elem [,n]) --> elem, elem, elem, ... endlessly or up to n times\n\
\n\
Iterators terminating on the shortest input sequence:\n\
+accumulate(p, start=0) --> p0, p0+p1, p0+p1+p2\n\
chain(p, q, ...) --> p0, p1, ... plast, q0, q1, ... \n\
compress(data, selectors) --> (d[0] if s[0]), (d[1] if s[1]), ...\n\
dropwhile(pred, seq) --> seq[n], seq[n+1], starting when pred fails\n\
@@ -3541,6 +3682,7 @@ PyInit_itertools(void)
PyObject *m;
char *name;
PyTypeObject *typelist[] = {
+ &accumulate_type,
&combinations_type,
&cwr_type,
&cycle_type,