diff options
author | Raymond Hettinger <python@rcn.com> | 2008-02-23 02:20:41 (GMT) |
---|---|---|
committer | Raymond Hettinger <python@rcn.com> | 2008-02-23 02:20:41 (GMT) |
commit | 73d796324242dc2164a0b5943bd08d6252a28651 (patch) | |
tree | 3c1770c2ea596e727e3c69604033016d02ea1574 | |
parent | c5705a823bb200b48417677f8ef3ca6833ace4bb (diff) | |
download | cpython-73d796324242dc2164a0b5943bd08d6252a28651.zip cpython-73d796324242dc2164a0b5943bd08d6252a28651.tar.gz cpython-73d796324242dc2164a0b5943bd08d6252a28651.tar.bz2 |
Improve the implementation of itertools.product()
* Fix-up issues pointed-out by Neal Norwitz.
* Add extensive comments.
* The lz->result variable is now a tuple instead of a list.
* Use fast macro getitem/setitem calls so most code is in-line.
* Re-use the result tuple if available (modify in-place instead of copy).
-rw-r--r-- | Lib/test/test_itertools.py | 3 | ||||
-rw-r--r-- | Modules/itertoolsmodule.c | 55 |
2 files changed, 46 insertions, 12 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index e65bba7..f5dd069 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -274,6 +274,9 @@ class TestBasicOps(unittest.TestCase): args = map(iter, args) self.assertEqual(len(list(product(*args))), n) + # Test implementation detail: tuple re-use + self.assertEqual(len(set(map(id, product('abc', 'def')))), 1) + self.assertNotEqual(len(set(map(id, list(product('abc', 'def'))))), 1) def test_repeat(self): self.assertEqual(zip(xrange(3),repeat('a')), diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index 8929309..5a3b03f 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -1796,7 +1796,7 @@ product_new(PyTypeObject *type, PyObject *args, PyObject *kwds) lz = (productobject *)type->tp_alloc(type, 0); if (lz == NULL) { Py_DECREF(pools); - return NULL; + goto error; } lz->pools = pools; @@ -1840,7 +1840,7 @@ product_next(productobject *lz) { PyObject *pool; PyObject *elem; - PyObject *tuple_result; + PyObject *oldelem; PyObject *pools = lz->pools; PyObject *result = lz->result; Py_ssize_t npools = PyTuple_GET_SIZE(pools); @@ -1848,10 +1848,14 @@ product_next(productobject *lz) if (lz->stopped) return NULL; + if (result == NULL) { + /* On the first pass, return an initial tuple filled with the + first element from each pool. If any pool is empty, then + whole product is empty and we're already done */ if (npools == 0) goto empty; - result = PyList_New(npools); + result = PyTuple_New(npools); if (result == NULL) goto empty; lz->result = result; @@ -1861,34 +1865,61 @@ product_next(productobject *lz) goto empty; elem = PyTuple_GET_ITEM(pool, 0); Py_INCREF(elem); - PyList_SET_ITEM(result, i, elem); + PyTuple_SET_ITEM(result, i, elem); } } else { Py_ssize_t *indices = lz->indices; Py_ssize_t *maxvec = lz->maxvec; + + /* Copy the previous result tuple or re-use it if available */ + if (Py_REFCNT(result) > 1) { + PyObject *old_result = result; + result = PyTuple_New(npools); + if (result == NULL) + goto empty; + lz->result = result; + for (i=0; i < npools; i++) { + elem = PyTuple_GET_ITEM(old_result, i); + Py_INCREF(elem); + PyTuple_SET_ITEM(result, i, elem); + } + Py_DECREF(old_result); + } + /* Now, we've got the only copy so we can update it in-place */ + assert (Py_REFCNT(result) == 1); + + /* Update the pool indices right-to-left. Only advance to the + next pool when the previous one rolls-over */ for (i=npools-1 ; i >= 0 ; i--) { pool = PyTuple_GET_ITEM(pools, i); indices[i]++; if (indices[i] == maxvec[i]) { + /* Roll-over and advance to next pool */ indices[i] = 0; elem = PyTuple_GET_ITEM(pool, 0); Py_INCREF(elem); - PyList_SetItem(result, i, elem); + oldelem = PyTuple_GET_ITEM(result, i); + PyTuple_SET_ITEM(result, i, elem); + Py_DECREF(oldelem); } else { + /* No rollover. Just increment and stop here. */ elem = PyTuple_GET_ITEM(pool, indices[i]); Py_INCREF(elem); - PyList_SetItem(result, i, elem); + oldelem = PyTuple_GET_ITEM(result, i); + PyTuple_SET_ITEM(result, i, elem); + Py_DECREF(oldelem); break; } } + + /* If i is negative, then the indices have all rolled-over + and we're done. */ if (i < 0) - return NULL; + goto empty; } - tuple_result = PySequence_Tuple(result); - if (tuple_result == NULL) - lz->stopped = 1; - return tuple_result; + Py_INCREF(result); + return result; empty: lz->stopped = 1; @@ -1898,7 +1929,7 @@ empty: PyDoc_STRVAR(product_doc, "product(*iterables) --> product object\n\ \n\ -Cartesian product of input interables. Equivalent to nested for-loops.\n\n\ +Cartesian product of input iterables. Equivalent to nested for-loops.\n\n\ For example, product(A, B) returns the same as: ((x,y) for x in A for y in B).\n\ The leftmost iterators are in the outermost for-loop, so the output tuples\n\ cycle in a manner similar to an odometer (with the rightmost element changing\n\ |