From 50986cc45bfdfd23fd49cd46148b42ea763cfefd Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Fri, 22 Feb 2008 03:16:42 +0000 Subject: First draft for itertools.product(). Docs and other updates forthcoming. --- Lib/test/test_itertools.py | 28 ++++++ Modules/itertoolsmodule.c | 213 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 240 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 9d19228..e65bba7 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -253,6 +253,28 @@ class TestBasicOps(unittest.TestCase): ids = map(id, list(izip_longest('abc', 'def'))) self.assertEqual(len(dict.fromkeys(ids)), len(ids)) + def test_product(self): + for args, result in [ + ([], []), # zero iterables ??? is this correct + (['ab'], [('a',), ('b',)]), # one iterable + ([range(2), range(3)], [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)]), # two iterables + ([range(0), range(2), range(3)], []), # first iterable with zero length + ([range(2), range(0), range(3)], []), # middle iterable with zero length + ([range(2), range(3), range(0)], []), # last iterable with zero length + ]: + self.assertEqual(list(product(*args)), result) + self.assertEqual(len(list(product(*[range(7)]*6))), 7**6) + self.assertRaises(TypeError, product, range(6), None) + argtypes = ['', 'abc', '', xrange(0), xrange(4), dict(a=1, b=2, c=3), + set('abcdefg'), range(11), tuple(range(13))] + for i in range(100): + args = [random.choice(argtypes) for j in range(random.randrange(5))] + n = reduce(operator.mul, map(len, args), 1) if args else 0 + self.assertEqual(len(list(product(*args))), n) + args = map(iter, args) + self.assertEqual(len(list(product(*args))), n) + + def test_repeat(self): self.assertEqual(zip(xrange(3),repeat('a')), [(0, 'a'), (1, 'a'), (2, 'a')]) @@ -623,6 +645,12 @@ class TestVariousIteratorArgs(unittest.TestCase): self.assertRaises(TypeError, list, chain(N(s))) self.assertRaises(ZeroDivisionError, list, chain(E(s))) + def test_product(self): + for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)): + self.assertRaises(TypeError, product, X(s)) + self.assertRaises(TypeError, product, N(s)) + self.assertRaises(ZeroDivisionError, product, E(s)) + def test_cycle(self): for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)): for g in (G, I, Ig, S, L, R): diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index 430313e..8929309 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -1741,6 +1741,216 @@ static PyTypeObject chain_type = { }; +/* product object ************************************************************/ + +typedef struct { + PyObject_HEAD + PyObject *pools; /* tuple of pool tuples */ + Py_ssize_t *maxvec; + Py_ssize_t *indices; + PyObject *result; + int stopped; +} productobject; + +static PyTypeObject product_type; + +static PyObject * +product_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + productobject *lz; + Py_ssize_t npools; + PyObject *pools = NULL; + Py_ssize_t *maxvec = NULL; + Py_ssize_t *indices = NULL; + Py_ssize_t i; + + if (type == &product_type && !_PyArg_NoKeywords("product()", kwds)) + return NULL; + + assert(PyTuple_Check(args)); + npools = PyTuple_GET_SIZE(args); + + maxvec = PyMem_Malloc(npools * sizeof(Py_ssize_t)); + indices = PyMem_Malloc(npools * sizeof(Py_ssize_t)); + if (maxvec == NULL || indices == NULL) { + PyErr_NoMemory(); + goto error; + } + + pools = PyTuple_New(npools); + if (pools == NULL) + goto error; + + for (i=0; i < npools; ++i) { + PyObject *item = PyTuple_GET_ITEM(args, i); + PyObject *pool = PySequence_Tuple(item); + if (pool == NULL) + goto error; + + PyTuple_SET_ITEM(pools, i, pool); + maxvec[i] = PyTuple_GET_SIZE(pool); + indices[i] = 0; + } + + /* create productobject structure */ + lz = (productobject *)type->tp_alloc(type, 0); + if (lz == NULL) { + Py_DECREF(pools); + return NULL; + } + + lz->pools = pools; + lz->maxvec = maxvec; + lz->indices = indices; + lz->result = NULL; + lz->stopped = 0; + + return (PyObject *)lz; + +error: + if (maxvec != NULL) + PyMem_Free(maxvec); + if (indices != NULL) + PyMem_Free(indices); + Py_XDECREF(pools); + return NULL; +} + +static void +product_dealloc(productobject *lz) +{ + PyObject_GC_UnTrack(lz); + Py_XDECREF(lz->pools); + Py_XDECREF(lz->result); + PyMem_Free(lz->maxvec); + PyMem_Free(lz->indices); + Py_TYPE(lz)->tp_free(lz); +} + +static int +product_traverse(productobject *lz, visitproc visit, void *arg) +{ + Py_VISIT(lz->pools); + Py_VISIT(lz->result); + return 0; +} + +static PyObject * +product_next(productobject *lz) +{ + PyObject *pool; + PyObject *elem; + PyObject *tuple_result; + PyObject *pools = lz->pools; + PyObject *result = lz->result; + Py_ssize_t npools = PyTuple_GET_SIZE(pools); + Py_ssize_t i; + + if (lz->stopped) + return NULL; + if (result == NULL) { + if (npools == 0) + goto empty; + result = PyList_New(npools); + if (result == NULL) + goto empty; + lz->result = result; + for (i=0; i < npools; i++) { + pool = PyTuple_GET_ITEM(pools, i); + if (PyTuple_GET_SIZE(pool) == 0) + goto empty; + elem = PyTuple_GET_ITEM(pool, 0); + Py_INCREF(elem); + PyList_SET_ITEM(result, i, elem); + } + } else { + Py_ssize_t *indices = lz->indices; + Py_ssize_t *maxvec = lz->maxvec; + for (i=npools-1 ; i >= 0 ; i--) { + pool = PyTuple_GET_ITEM(pools, i); + indices[i]++; + if (indices[i] == maxvec[i]) { + indices[i] = 0; + elem = PyTuple_GET_ITEM(pool, 0); + Py_INCREF(elem); + PyList_SetItem(result, i, elem); + } else { + elem = PyTuple_GET_ITEM(pool, indices[i]); + Py_INCREF(elem); + PyList_SetItem(result, i, elem); + break; + } + } + if (i < 0) + return NULL; + } + + tuple_result = PySequence_Tuple(result); + if (tuple_result == NULL) + lz->stopped = 1; + return tuple_result; + +empty: + lz->stopped = 1; + return NULL; +} + +PyDoc_STRVAR(product_doc, +"product(*iterables) --> product object\n\ +\n\ +Cartesian product of input interables. Equivalent to nested for-loops.\n\n\ +For example, product(A, B) returns the same as: ((x,y) for x in A for y in B).\n\ +The leftmost iterators are in the outermost for-loop, so the output tuples\n\ +cycle in a manner similar to an odometer (with the rightmost element changing\n\ +on every iteration).\n\n\ +product('ab', range(3)) --> ('a',0) ('a',1) ('a',2) ('b',0) ('b',1) ('b',2)\n\ +product((0,1), (0,1), (0,1)) --> (0,0,0) (0,0,1) (0,1,0) (0,1,1) (1,0,0) ..."); + +static PyTypeObject product_type = { + PyVarObject_HEAD_INIT(NULL, 0) + "itertools.product", /* tp_name */ + sizeof(productobject), /* tp_basicsize */ + 0, /* tp_itemsize */ + /* methods */ + (destructor)product_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + product_doc, /* tp_doc */ + (traverseproc)product_traverse, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + PyObject_SelfIter, /* tp_iter */ + (iternextfunc)product_next, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + product_new, /* tp_new */ + PyObject_GC_Del, /* tp_free */ +}; + + /* ifilter object ************************************************************/ typedef struct { @@ -2796,7 +3006,8 @@ inititertools(void) &ifilterfalse_type, &count_type, &izip_type, - &iziplongest_type, + &iziplongest_type, + &product_type, &repeat_type, &groupby_type, NULL -- cgit v0.12