diff options
author | Raymond Hettinger <python@rcn.com> | 2003-12-06 16:23:06 (GMT) |
---|---|---|
committer | Raymond Hettinger <python@rcn.com> | 2003-12-06 16:23:06 (GMT) |
commit | d25c1c635164daa5c300342ac99c0810fd9b575c (patch) | |
tree | df412ba3ffaa8fee35e2e12f96aab0beecdaaec0 | |
parent | b8d5f245b7077d869121835ed72656ac14962ef0 (diff) | |
download | cpython-d25c1c635164daa5c300342ac99c0810fd9b575c.zip cpython-d25c1c635164daa5c300342ac99c0810fd9b575c.tar.gz cpython-d25c1c635164daa5c300342ac99c0810fd9b575c.tar.bz2 |
Implement itertools.groupby()
Original idea by Guido van Rossum.
Idea for skipable inner iterators by Raymond Hettinger.
Idea for argument order and identity function default by Alex Martelli.
Implementation by Hye-Shik Chang (with tweaks by Raymond Hettinger).
-rw-r--r-- | Doc/lib/libitertools.tex | 60 | ||||
-rw-r--r-- | Lib/test/test_itertools.py | 108 | ||||
-rw-r--r-- | Misc/NEWS | 5 | ||||
-rw-r--r-- | Modules/itertoolsmodule.c | 322 |
4 files changed, 493 insertions, 2 deletions
diff --git a/Doc/lib/libitertools.tex b/Doc/lib/libitertools.tex index 6f9f5c6..82912b0 100644 --- a/Doc/lib/libitertools.tex +++ b/Doc/lib/libitertools.tex @@ -130,6 +130,54 @@ by functions or loops that truncate the stream. \end{verbatim} \end{funcdesc} +\begin{funcdesc}{groupby}{iterable\optional{, key}} + Make an iterator that returns consecutive keys and groups from the + \var{iterable}. \var{key} is function computing a key value for each + element. If not specified or is \code{None}, \var{key} defaults to an + identity function (returning the element unchanged). Generally, the + iterable needs to already be sorted on the same key function. + + The returned group is itself an iterator that shares the underlying + iterable with \function{groupby()}. Because the source is shared, when + the \function{groupby} object is advanced, the previous group is no + longer visible. So, if that data is needed later, it should be stored + as a list: + + \begin{verbatim} + groups = [] + uniquekeys = [] + for k, g in groupby(data, keyfunc): + groups.append(list(g)) # Store group iterator as a list + uniquekeys.append(k) + \end{verbatim} + + \function{groupby()} is equivalent to: + + \begin{verbatim} + class groupby(object): + def __init__(self, iterable, key=None): + if key is None: + key = lambda x: x + self.keyfunc = key + self.it = iter(iterable) + self.tgtkey = self.currkey = self.currvalue = xrange(0) + def __iter__(self): + return self + def next(self): + while self.currkey == self.tgtkey: + self.currvalue = self.it.next() # Exit on StopIteration + self.currkey = self.keyfunc(self.currvalue) + self.tgtkey = self.currkey + return (self.currkey, self._grouper(self.tgtkey)) + def _grouper(self, tgtkey): + while self.currkey == tgtkey: + yield self.currvalue + self.currvalue = self.it.next() # Exit on StopIteration + self.currkey = self.keyfunc(self.currvalue) + \end{verbatim} + \versionadded{2.4} +\end{funcdesc} + \begin{funcdesc}{ifilter}{predicate, iterable} Make an iterator that filters elements from iterable returning only those for which the predicate is \code{True}. @@ -346,6 +394,18 @@ Martin Walter Samuele +# Show a dictionary sorted and grouped by value +>>> from operator import itemgetter +>>> d = dict(a=1, b=2, c=1, d=2, e=1, f=2, g=3) +>>> di = list.sorted(d.iteritems(), key=itemgetter(1)) +>>> for k, g in groupby(di, key=itemgetter(1)): +... print k, map(itemgetter(0), g) +... +1 ['a', 'c', 'e'] +2 ['b', 'd', 'f'] +3 ['g'] + + \end{verbatim} This section shows how itertools can be combined to create other more diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 543acc1..b4c0a8b 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -61,6 +61,94 @@ class TestBasicOps(unittest.TestCase): self.assertRaises(TypeError, cycle, 5) self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0]) + def test_groupby(self): + # Check whether it accepts arguments correctly + self.assertEqual([], list(groupby([]))) + self.assertEqual([], list(groupby([], key=id))) + self.assertRaises(TypeError, list, groupby('abc', [])) + self.assertRaises(TypeError, groupby, None) + + # Check normal input + s = [(0, 10, 20), (0, 11,21), (0,12,21), (1,13,21), (1,14,22), + (2,15,22), (3,16,23), (3,17,23)] + dup = [] + for k, g in groupby(s, lambda r:r[0]): + for elem in g: + self.assertEqual(k, elem[0]) + dup.append(elem) + self.assertEqual(s, dup) + + # Check nested case + dup = [] + for k, g in groupby(s, lambda r:r[0]): + for ik, ig in groupby(g, lambda r:r[2]): + for elem in ig: + self.assertEqual(k, elem[0]) + self.assertEqual(ik, elem[2]) + dup.append(elem) + self.assertEqual(s, dup) + + # Check case where inner iterator is not used + keys = [k for k, g in groupby(s, lambda r:r[0])] + expectedkeys = set([r[0] for r in s]) + self.assertEqual(set(keys), expectedkeys) + self.assertEqual(len(keys), len(expectedkeys)) + + # Exercise pipes and filters style + s = 'abracadabra' + # sort s | uniq + r = [k for k, g in groupby(list.sorted(s))] + self.assertEqual(r, ['a', 'b', 'c', 'd', 'r']) + # sort s | uniq -d + r = [k for k, g in groupby(list.sorted(s)) if list(islice(g,1,2))] + self.assertEqual(r, ['a', 'b', 'r']) + # sort s | uniq -c + r = [(len(list(g)), k) for k, g in groupby(list.sorted(s))] + self.assertEqual(r, [(5, 'a'), (2, 'b'), (1, 'c'), (1, 'd'), (2, 'r')]) + # sort s | uniq -c | sort -rn | head -3 + r = list.sorted([(len(list(g)) , k) for k, g in groupby(list.sorted(s))], reverse=True)[:3] + self.assertEqual(r, [(5, 'a'), (2, 'r'), (2, 'b')]) + + # iter.next failure + class ExpectedError(Exception): + pass + def delayed_raise(n=0): + for i in range(n): + yield 'yo' + raise ExpectedError + def gulp(iterable, keyp=None, func=list): + return [func(g) for k, g in groupby(iterable, keyp)] + + # iter.next failure on outer object + self.assertRaises(ExpectedError, gulp, delayed_raise(0)) + # iter.next failure on inner object + self.assertRaises(ExpectedError, gulp, delayed_raise(1)) + + # __cmp__ failure + class DummyCmp: + def __cmp__(self, dst): + raise ExpectedError + s = [DummyCmp(), DummyCmp(), None] + + # __cmp__ failure on outer object + self.assertRaises(ExpectedError, gulp, s, func=id) + # __cmp__ failure on inner object + self.assertRaises(ExpectedError, gulp, s) + + # keyfunc failure + def keyfunc(obj): + if keyfunc.skip > 0: + keyfunc.skip -= 1 + return obj + else: + raise ExpectedError + + # keyfunc failure on outer object + keyfunc.skip = 0 + self.assertRaises(ExpectedError, gulp, [None], keyfunc) + keyfunc.skip = 1 + self.assertRaises(ExpectedError, gulp, [None, None], keyfunc) + def test_ifilter(self): self.assertEqual(list(ifilter(isEven, range(6))), [0,2,4]) self.assertEqual(list(ifilter(None, [0,1,0,2,0])), [1,2]) @@ -268,7 +356,7 @@ class TestBasicOps(unittest.TestCase): def test_StopIteration(self): self.assertRaises(StopIteration, izip().next) - for f in (chain, cycle, izip): + for f in (chain, cycle, izip, groupby): self.assertRaises(StopIteration, f([]).next) self.assertRaises(StopIteration, f(StopNow()).next) @@ -426,6 +514,14 @@ class TestVariousIteratorArgs(unittest.TestCase): self.assertRaises(TypeError, list, cycle(N(s))) self.assertRaises(ZeroDivisionError, list, cycle(E(s))) + def test_groupby(self): + for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)): + for g in (G, I, Ig, S, L, R): + self.assertEqual([k for k, sb in groupby(g(s))], list(g(s))) + self.assertRaises(TypeError, groupby, X(s)) + self.assertRaises(TypeError, list, groupby(N(s))) + self.assertRaises(ZeroDivisionError, list, groupby(E(s))) + def test_ifilter(self): for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)): for g in (G, I, Ig, S, L, R): @@ -571,6 +667,16 @@ Martin Walter Samuele +>>> from operator import itemgetter +>>> d = dict(a=1, b=2, c=1, d=2, e=1, f=2, g=3) +>>> di = list.sorted(d.iteritems(), key=itemgetter(1)) +>>> for k, g in groupby(di, itemgetter(1)): +... print k, map(itemgetter(0), g) +... +1 ['a', 'c', 'e'] +2 ['b', 'd', 'f'] +3 ['g'] + >>> def take(n, seq): ... return list(islice(seq, n)) @@ -164,6 +164,11 @@ Extension modules SF bug #812202). Generators that do not define genrandbits() now issue a warning when randrange() is called with a range that large. +- itertools has a new function, groupby() for aggregating iterables + into groups sharing the same key (as determined by a key function). + It offers some of functionality of SQL's groupby keyword and of + the Unix uniq filter. + - itertools now has a new function, tee() which produces two independent iterators from a single iterable. diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index a341a66..387133c 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -7,6 +7,323 @@ All rights reserved. */ + +/* groupby object ***********************************************************/ + +typedef struct { + PyObject_HEAD + PyObject *it; + PyObject *keyfunc; + PyObject *tgtkey; + PyObject *currkey; + PyObject *currvalue; +} groupbyobject; + +static PyTypeObject groupby_type; +static PyObject *_grouper_create(groupbyobject *, PyObject *); + +static PyObject * +groupby_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + static char *kwargs[] = {"iterable", "key", NULL}; + groupbyobject *gbo; + PyObject *it, *keyfunc = Py_None; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:groupby", kwargs, + &it, &keyfunc)) + return NULL; + + gbo = (groupbyobject *)type->tp_alloc(type, 0); + if (gbo == NULL) + return NULL; + gbo->tgtkey = NULL; + gbo->currkey = NULL; + gbo->currvalue = NULL; + gbo->keyfunc = keyfunc; + Py_INCREF(keyfunc); + gbo->it = PyObject_GetIter(it); + if (gbo->it == NULL) { + Py_DECREF(gbo); + return NULL; + } + return (PyObject *)gbo; +} + +static void +groupby_dealloc(groupbyobject *gbo) +{ + PyObject_GC_UnTrack(gbo); + Py_XDECREF(gbo->it); + Py_XDECREF(gbo->keyfunc); + Py_XDECREF(gbo->tgtkey); + Py_XDECREF(gbo->currkey); + Py_XDECREF(gbo->currvalue); + gbo->ob_type->tp_free(gbo); +} + +static int +groupby_traverse(groupbyobject *gbo, visitproc visit, void *arg) +{ + int err; + + if (gbo->it) { + err = visit(gbo->it, arg); + if (err) + return err; + } + if (gbo->keyfunc) { + err = visit(gbo->keyfunc, arg); + if (err) + return err; + } + if (gbo->tgtkey) { + err = visit(gbo->tgtkey, arg); + if (err) + return err; + } + if (gbo->currkey) { + err = visit(gbo->currkey, arg); + if (err) + return err; + } + if (gbo->currvalue) { + err = visit(gbo->currvalue, arg); + if (err) + return err; + } + return 0; +} + +static PyObject * +groupby_next(groupbyobject *gbo) +{ + PyObject *newvalue, *newkey, *r, *grouper; + + /* skip to next iteration group */ + for (;;) { + if (gbo->currkey == NULL) + /* pass */; + else if (gbo->tgtkey == NULL) + break; + else { + int rcmp; + + rcmp = PyObject_RichCompareBool(gbo->tgtkey, + gbo->currkey, Py_EQ); + if (rcmp == -1) + return NULL; + else if (rcmp == 0) + break; + } + + newvalue = PyIter_Next(gbo->it); + if (newvalue == NULL) + return NULL; + + if (gbo->keyfunc == Py_None) { + newkey = newvalue; + Py_INCREF(newvalue); + } else { + newkey = PyObject_CallFunctionObjArgs(gbo->keyfunc, + newvalue, NULL); + if (newkey == NULL) { + Py_DECREF(newvalue); + return NULL; + } + } + + Py_XDECREF(gbo->currkey); + gbo->currkey = newkey; + Py_XDECREF(gbo->currvalue); + gbo->currvalue = newvalue; + } + + Py_XDECREF(gbo->tgtkey); + gbo->tgtkey = gbo->currkey; + Py_INCREF(gbo->currkey); + + grouper = _grouper_create(gbo, gbo->tgtkey); + if (grouper == NULL) + return NULL; + + r = PyTuple_Pack(2, gbo->currkey, grouper); + Py_DECREF(grouper); + return r; +} + +PyDoc_STRVAR(groupby_doc, +"groupby(iterable[, keyfunc]) -> create an iterator which returns\n\ +(key, sub-iterator) grouped by each value of key(value).\n"); + +static PyTypeObject groupby_type = { + PyObject_HEAD_INIT(NULL) + 0, /* ob_size */ + "itertools.groupby", /* tp_name */ + sizeof(groupbyobject), /* tp_basicsize */ + 0, /* tp_itemsize */ + /* methods */ + (destructor)groupby_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + groupby_doc, /* tp_doc */ + (traverseproc)groupby_traverse, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + PyObject_SelfIter, /* tp_iter */ + (iternextfunc)groupby_next, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + groupby_new, /* tp_new */ + PyObject_GC_Del, /* tp_free */ +}; + + +/* _grouper object (internal) ************************************************/ + +typedef struct { + PyObject_HEAD + PyObject *parent; + PyObject *tgtkey; +} _grouperobject; + +static PyTypeObject _grouper_type; + +static PyObject * +_grouper_create(groupbyobject *parent, PyObject *tgtkey) +{ + _grouperobject *igo; + + igo = PyObject_New(_grouperobject, &_grouper_type); + if (igo == NULL) + return NULL; + igo->parent = (PyObject *)parent; + Py_INCREF(parent); + igo->tgtkey = tgtkey; + Py_INCREF(tgtkey); + + return (PyObject *)igo; +} + +static void +_grouper_dealloc(_grouperobject *igo) +{ + Py_DECREF(igo->parent); + Py_DECREF(igo->tgtkey); + PyObject_Del(igo); +} + +static PyObject * +_grouper_next(_grouperobject *igo) +{ + groupbyobject *gbo = (groupbyobject *)igo->parent; + PyObject *newvalue, *newkey, *r; + int rcmp; + + if (gbo->currvalue == NULL) { + newvalue = PyIter_Next(gbo->it); + if (newvalue == NULL) + return NULL; + + if (gbo->keyfunc == Py_None) { + newkey = newvalue; + Py_INCREF(newvalue); + } else { + newkey = PyObject_CallFunctionObjArgs(gbo->keyfunc, + newvalue, NULL); + if (newkey == NULL) { + Py_DECREF(newvalue); + return NULL; + } + } + + assert(gbo->currkey == NULL); + gbo->currkey = newkey; + gbo->currvalue = newvalue; + } + + assert(gbo->currkey != NULL); + rcmp = PyObject_RichCompareBool(igo->tgtkey, gbo->currkey, Py_EQ); + if (rcmp <= 0) + /* got any error or current group is end */ + return NULL; + + r = gbo->currvalue; + gbo->currvalue = NULL; + Py_DECREF(gbo->currkey); + gbo->currkey = NULL; + + return r; +} + +static PyTypeObject _grouper_type = { + PyObject_HEAD_INIT(NULL) + 0, /* ob_size */ + "itertools._grouper", /* tp_name */ + sizeof(_grouperobject), /* tp_basicsize */ + 0, /* tp_itemsize */ + /* methods */ + (destructor)_grouper_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + 0, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + PyObject_SelfIter, /* tp_iter */ + (iternextfunc)_grouper_next, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + 0, /* tp_new */ + PyObject_Del, /* tp_free */ +}; + + + /* tee object and with supporting function and objects ***************/ /* The teedataobject pre-allocates space for LINKCELLS number of objects. @@ -2103,6 +2420,7 @@ tee(it, n=2) --> (it1, it2 , ... itn) splits one iterator into n\n\ chain(p, q, ...) --> p0, p1, ... plast, q0, q1, ... \n\ takewhile(pred, seq) --> seq[0], seq[1], until pred fails\n\ dropwhile(pred, seq) --> seq[n], seq[n+1], starting when pred fails\n\ +groupby(iterable[, keyfunc]) --> sub-iterators grouped by value of keyfunc(v)\n\ "); @@ -2130,6 +2448,7 @@ inititertools(void) &count_type, &izip_type, &repeat_type, + &groupby_type, NULL }; @@ -2148,5 +2467,6 @@ inititertools(void) return; if (PyType_Ready(&tee_type) < 0) return; - + if (PyType_Ready(&_grouper_type) < 0) + return; } |