From 05bf6338b810d01e7bba5503fbd01f2b6216ca59 Mon Sep 17 00:00:00 2001
From: Raymond Hettinger
Date: Thu, 28 Feb 2008 22:30:42 +0000
Subject: Have itertools.chain() consume its inputs lazily instead of building a tuple of iterators at the outset.

---
 Lib/test/test_itertools.py |   4 +-
 Modules/itertoolsmodule.c  | 102 ++++++++++++++++++++++-----------------------
 2 files changed, 53 insertions(+), 53 deletions(-)

diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index 79c4b3a..41e9362 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -50,7 +50,7 @@ class TestBasicOps(unittest.TestCase):
         self.assertEqual(list(chain('abc')), list('abc'))
         self.assertEqual(list(chain('')), [])
         self.assertEqual(take(4, chain('abc', 'def')), list('abcd'))
-        self.assertRaises(TypeError, chain, 2, 3)
+        self.assertRaises(TypeError, list,chain(2, 3))
 
     def test_combinations(self):
         self.assertRaises(TypeError, combinations, 'abc')  # missing r argument
@@ -670,7 +670,7 @@ class TestVariousIteratorArgs(unittest.TestCase):
             for g in (G, I, Ig, S, L, R):
                 self.assertEqual(list(chain(g(s))), list(g(s)))
                 self.assertEqual(list(chain(g(s), g(s))), list(g(s))+list(g(s)))
-            self.assertRaises(TypeError, chain, X(s))
+            self.assertRaises(TypeError, list, chain(X(s)))
             self.assertRaises(TypeError, list, chain(N(s)))
             self.assertRaises(ZeroDivisionError, list, chain(E(s)))
 
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index 2ee947d..3b8339c 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -1601,92 +1601,92 @@ static PyTypeObject imap_type = {
 
 typedef struct {
 	PyObject_HEAD
-	Py_ssize_t tuplesize;
-	Py_ssize_t iternum;		/* which iterator is active */
-	PyObject *ittuple;		/* tuple of iterators */
+	PyObject *source;		/* Iterator over input iterables */
+	PyObject *active;		/* Currently running input iterator */
 } chainobject;
 
 static PyTypeObject chain_type;
 
-static PyObject *
-chain_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+static PyObject *
+chain_new_internal(PyTypeObject *type, PyObject *source)
 {
 	chainobject *lz;
-	Py_ssize_t tuplesize = PySequence_Length(args);
-	Py_ssize_t i;
-	PyObject *ittuple;
-
-	if (type == &chain_type && !_PyArg_NoKeywords("chain()", kwds))
-		return NULL;
-
-	/* obtain iterators */
-	assert(PyTuple_Check(args));
-	ittuple = PyTuple_New(tuplesize);
-	if (ittuple == NULL)
-		return NULL;
-	for (i=0; i < tuplesize; ++i) {
-		PyObject *item = PyTuple_GET_ITEM(args, i);
-		PyObject *it = PyObject_GetIter(item);
-		if (it == NULL) {
-			if (PyErr_ExceptionMatches(PyExc_TypeError))
-				PyErr_Format(PyExc_TypeError,
-				    "chain argument #%zd must support iteration",
-				    i+1);
-			Py_DECREF(ittuple);
-			return NULL;
-		}
-		PyTuple_SET_ITEM(ittuple, i, it);
-	}
 
-	/* create chainobject structure */
 	lz = (chainobject *)type->tp_alloc(type, 0);
 	if (lz == NULL) {
-		Py_DECREF(ittuple);
+		Py_DECREF(source);
 		return NULL;
 	}
+
+	lz->source = source;
+	lz->active = NULL;
+	return (PyObject *)lz;
+}
 
-	lz->ittuple = ittuple;
-	lz->iternum = 0;
-	lz->tuplesize = tuplesize;
+static PyObject *
+chain_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+	PyObject *source;
 
-	return (PyObject *)lz;
+	if (type == &chain_type && !_PyArg_NoKeywords("chain()", kwds))
+		return NULL;
+
+	source = PyObject_GetIter(args);
+	if (source == NULL)
+		return NULL;
+
+	return chain_new_internal(type, source);
 }
 
 static void
 chain_dealloc(chainobject *lz)
 {
 	PyObject_GC_UnTrack(lz);
-	Py_XDECREF(lz->ittuple);
+	Py_XDECREF(lz->active);
+	Py_XDECREF(lz->source);
 	Py_TYPE(lz)->tp_free(lz);
 }
 
 static int
 chain_traverse(chainobject *lz, visitproc visit, void *arg)
 {
-	Py_VISIT(lz->ittuple);
+	Py_VISIT(lz->source);
+	Py_VISIT(lz->active);
 	return 0;
 }
 
 static PyObject *
 chain_next(chainobject *lz)
 {
-	PyObject *it;
 	PyObject *item;
 
-	while (lz->iternum < lz->tuplesize) {
-		it = PyTuple_GET_ITEM(lz->ittuple, lz->iternum);
-		item = PyIter_Next(it);
-		if (item != NULL)
-			return item;
-		if (PyErr_Occurred()) {
-			if (PyErr_ExceptionMatches(PyExc_StopIteration))
-				PyErr_Clear();
-			else
-				return NULL;
+	if (lz->source == NULL)
+		return NULL;			/* already stopped */
+
+	if (lz->active == NULL) {
+		PyObject *iterable = PyIter_Next(lz->source);
+		if (iterable == NULL) {
+			Py_CLEAR(lz->source);
+			return NULL;		/* no more input sources */
+		}
+		lz->active = PyObject_GetIter(iterable);
+		if (lz->active == NULL) {
+			Py_DECREF(iterable);
+			Py_CLEAR(lz->source);
+			return NULL;		/* input not iterable */
 		}
-		lz->iternum++;
 	}
-	return NULL;
+	item = PyIter_Next(lz->active);
+	if (item != NULL)
+		return item;
+	if (PyErr_Occurred()) {
+		if (PyErr_ExceptionMatches(PyExc_StopIteration))
+			PyErr_Clear();
+		else
+			return NULL;		/* input raised an exception */
+	}
+	Py_CLEAR(lz->active);
+	return chain_next(lz);			/* recurse and use next active */
 }
 
 PyDoc_STRVAR(chain_doc,
--
cgit v0.12