diff options
-rw-r--r-- | Doc/library/itertools.rst | 7 | ||||
-rw-r--r-- | Lib/test/test_itertools.py | 20 | ||||
-rw-r--r-- | Misc/NEWS.d/next/Library/2017-09-24-13-08-46.bpo-30346.Csse77.rst | 2 | ||||
-rw-r--r-- | Modules/itertoolsmodule.c | 8 |
4 files changed, 34 insertions, 3 deletions
diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst index c989e46..530c29d 100644 --- a/Doc/library/itertools.rst +++ b/Doc/library/itertools.rst @@ -401,13 +401,14 @@ loops that truncate the stream. def __iter__(self): return self def __next__(self): + self.id = object() while self.currkey == self.tgtkey: self.currvalue = next(self.it) # Exit on StopIteration self.currkey = self.keyfunc(self.currvalue) self.tgtkey = self.currkey - return (self.currkey, self._grouper(self.tgtkey)) - def _grouper(self, tgtkey): - while self.currkey == tgtkey: + return (self.currkey, self._grouper(self.tgtkey, self.id)) + def _grouper(self, tgtkey, id): + while self.id is id and self.currkey == tgtkey: yield self.currvalue try: self.currvalue = next(self.it) diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 50cf148..8353e68 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -751,6 +751,26 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(set(keys), expectedkeys) self.assertEqual(len(keys), len(expectedkeys)) + # Check case where inner iterator is used after advancing the groupby + # iterator + s = list(zip('AABBBAAAA', range(9))) + it = groupby(s, testR) + _, g1 = next(it) + _, g2 = next(it) + _, g3 = next(it) + self.assertEqual(list(g1), []) + self.assertEqual(list(g2), []) + self.assertEqual(next(g3), ('A', 5)) + list(it) # exhaust the groupby iterator + self.assertEqual(list(g3), []) + + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + it = groupby(s, testR) + _, g = next(it) + next(it) + next(it) + self.assertEqual(list(pickle.loads(pickle.dumps(g, proto))), []) + # Exercise pipes and filters style s = 'abracadabra' # sort s | uniq diff --git a/Misc/NEWS.d/next/Library/2017-09-24-13-08-46.bpo-30346.Csse77.rst b/Misc/NEWS.d/next/Library/2017-09-24-13-08-46.bpo-30346.Csse77.rst new file mode 100644 index 0000000..81ad053 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2017-09-24-13-08-46.bpo-30346.Csse77.rst @@ -0,0 +1,2 @@ +An iterator produced by itertools.groupby() iterator now becames exhausted +after advancing the groupby iterator. diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index 48e6c35..2ac5ab2 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -17,6 +17,7 @@ typedef struct { PyObject *tgtkey; PyObject *currkey; PyObject *currvalue; + const void *currgrouper; /* borrowed reference */ } groupbyobject; static PyTypeObject groupby_type; @@ -77,6 +78,7 @@ groupby_next(groupbyobject *gbo) { PyObject *newvalue, *newkey, *r, *grouper; + gbo->currgrouper = NULL; /* skip to next iteration group */ for (;;) { if (gbo->currkey == NULL) @@ -255,6 +257,7 @@ _grouper_create(groupbyobject *parent, PyObject *tgtkey) Py_INCREF(parent); igo->tgtkey = tgtkey; Py_INCREF(tgtkey); + parent->currgrouper = igo; /* borrowed reference */ PyObject_GC_Track(igo); return (PyObject *)igo; @@ -284,6 +287,8 @@ _grouper_next(_grouperobject *igo) PyObject *newvalue, *newkey, *r; int rcmp; + if (gbo->currgrouper != igo) + return NULL; if (gbo->currvalue == NULL) { newvalue = PyIter_Next(gbo->it); if (newvalue == NULL) @@ -321,6 +326,9 @@ _grouper_next(_grouperobject *igo) static PyObject * _grouper_reduce(_grouperobject *lz) { + if (((groupbyobject *)lz->parent)->currgrouper != lz) { + return Py_BuildValue("N(())", _PyObject_GetBuiltin("iter")); + } return Py_BuildValue("O(OO)", Py_TYPE(lz), lz->parent, lz->tgtkey); } |