diff options
-rw-r--r-- | Lib/test/test_itertools.py | 31 | ||||
-rw-r--r-- | Misc/NEWS | 2 | ||||
-rw-r--r-- | Modules/itertoolsmodule.c | 274 |
3 files changed, 293 insertions, 14 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 4197989..0692747 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -47,15 +47,6 @@ def fact(n): 'Factorial' return prod(range(1, n+1)) -def permutations(iterable, r=None): - # XXX use this until real permutations code is added - pool = tuple(iterable) - n = len(pool) - r = n if r is None else r - for indices in product(range(n), repeat=r): - if len(set(indices)) == r: - yield tuple(pool[i] for i in indices) - class TestBasicOps(unittest.TestCase): def test_chain(self): self.assertEqual(list(chain('abc', 'def')), list('abcdef')) @@ -117,6 +108,8 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(len(set(c)), r) # no duplicate elements self.assertEqual(list(c), sorted(c)) # keep original ordering self.assert_(all(e in values for e in c)) # elements taken from input iterable + self.assertEqual(list(c), + [e for e in values if e in c]) # comb is a subsequence of the input iterable self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version @@ -127,9 +120,10 @@ class TestBasicOps(unittest.TestCase): def test_permutations(self): self.assertRaises(TypeError, permutations) # too few arguments self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments -## self.assertRaises(TypeError, permutations, None) # pool is not iterable -## self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative -## self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big + self.assertRaises(TypeError, permutations, None) # pool is not iterable + self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative + self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big + self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None self.assertEqual(list(permutations(range(3), 2)), [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)]) @@ -182,7 +176,7 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(result, list(permutations(values))) # test default r # Test implementation detail: tuple re-use -## self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1) + self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1) self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1) def test_count(self): @@ -407,12 +401,23 @@ class TestBasicOps(unittest.TestCase): list(product(*args, **dict(repeat=r)))) self.assertEqual(len(list(product(*[range(7)]*6))), 7**6) self.assertRaises(TypeError, product, range(6), None) + + def product2(*args, **kwds): + 'Pure python version used in docs' + pools = map(tuple, args) * kwds.get('repeat', 1) + result = [[]] + for pool in pools: + result = [x+[y] for x in result for y in pool] + for prod in result: + yield tuple(prod) + argtypes = ['', 'abc', '', xrange(0), xrange(4), dict(a=1, b=2, c=3), set('abcdefg'), range(11), tuple(range(13))] for i in range(100): args = [random.choice(argtypes) for j in range(random.randrange(5))] expected_len = prod(map(len, args)) self.assertEqual(len(list(product(*args))), expected_len) + self.assertEqual(list(product(*args)), list(product2(*args))) args = map(iter, args) self.assertEqual(len(list(product(*args))), expected_len) @@ -699,7 +699,7 @@ Library - Added itertools.product() which forms the Cartesian product of the input iterables. -- Added itertools.combinations(). +- Added itertools.combinations() and itertools.permutations(). - Patch #1541463: optimize performance of cgi.FieldStorage operations. diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index d8e14d0..fe06ef4 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -2238,6 +2238,279 @@ static PyTypeObject combinations_type = { }; +/* permutations object ************************************************************ + +def permutations(iterable, r=None): + 'permutations(range(3), 2) --> (0,1) (0,2) (1,0) (1,2) (2,0) (2,1)' + pool = tuple(iterable) + n = len(pool) + r = n if r is None else r + indices = range(n) + cycles = range(n-r+1, n+1)[::-1] + yield tuple(pool[i] for i in indices[:r]) + while n: + for i in reversed(range(r)): + cycles[i] -= 1 + if cycles[i] == 0: + indices[i:] = indices[i+1:] + indices[i:i+1] + cycles[i] = n - i + else: + j = cycles[i] + indices[i], indices[-j] = indices[-j], indices[i] + yield tuple(pool[i] for i in indices[:r]) + break + else: + return +*/ + +typedef struct { + PyObject_HEAD + PyObject *pool; /* input converted to a tuple */ + Py_ssize_t *indices; /* one index per element in the pool */ + Py_ssize_t *cycles; /* one rollover counter per element in the result */ + PyObject *result; /* most recently returned result tuple */ + Py_ssize_t r; /* size of result tuple */ + int stopped; /* set to 1 when the permutations iterator is exhausted */ +} permutationsobject; + +static PyTypeObject permutations_type; + +static PyObject * +permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + permutationsobject *po; + Py_ssize_t n; + Py_ssize_t r; + PyObject *robj = Py_None; + PyObject *pool = NULL; + PyObject *iterable = NULL; + Py_ssize_t *indices = NULL; + Py_ssize_t *cycles = NULL; + Py_ssize_t i; + static char *kwargs[] = {"iterable", "r", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:permutations", kwargs, + &iterable, &robj)) + return NULL; + + pool = PySequence_Tuple(iterable); + if (pool == NULL) + goto error; + n = PyTuple_GET_SIZE(pool); + + r = n; + if (robj != Py_None) { + r = PyInt_AsSsize_t(robj); + if (r == -1 && PyErr_Occurred()) + goto error; + } + if (r < 0) { + PyErr_SetString(PyExc_ValueError, "r must be non-negative"); + goto error; + } + if (r > n) { + PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable"); + goto error; + } + + indices = PyMem_Malloc(n * sizeof(Py_ssize_t)); + cycles = PyMem_Malloc(r * sizeof(Py_ssize_t)); + if (indices == NULL || cycles == NULL) { + PyErr_NoMemory(); + goto error; + } + + for (i=0 ; i<n ; i++) + indices[i] = i; + for (i=0 ; i<r ; i++) + cycles[i] = n - i; + + /* create permutationsobject structure */ + po = (permutationsobject *)type->tp_alloc(type, 0); + if (po == NULL) + goto error; + + po->pool = pool; + po->indices = indices; + po->cycles = cycles; + po->result = NULL; + po->r = r; + po->stopped = 0; + + return (PyObject *)po; + +error: + if (indices != NULL) + PyMem_Free(indices); + if (cycles != NULL) + PyMem_Free(cycles); + Py_XDECREF(pool); + return NULL; +} + +static void +permutations_dealloc(permutationsobject *po) +{ + PyObject_GC_UnTrack(po); + Py_XDECREF(po->pool); + Py_XDECREF(po->result); + PyMem_Free(po->indices); + PyMem_Free(po->cycles); + Py_TYPE(po)->tp_free(po); +} + +static int +permutations_traverse(permutationsobject *po, visitproc visit, void *arg) +{ + if (po->pool != NULL) + Py_VISIT(po->pool); + if (po->result != NULL) + Py_VISIT(po->result); + return 0; +} + +static PyObject * +permutations_next(permutationsobject *po) +{ + PyObject *elem; + PyObject *oldelem; + PyObject *pool = po->pool; + Py_ssize_t *indices = po->indices; + Py_ssize_t *cycles = po->cycles; + PyObject *result = po->result; + Py_ssize_t n = PyTuple_GET_SIZE(pool); + Py_ssize_t r = po->r; + Py_ssize_t i, j, k, index; + + if (po->stopped) + return NULL; + + if (result == NULL) { + /* On the first pass, initialize result tuple using the indices */ + result = PyTuple_New(r); + if (result == NULL) + goto empty; + po->result = result; + for (i=0; i<r ; i++) { + index = indices[i]; + elem = PyTuple_GET_ITEM(pool, index); + Py_INCREF(elem); + PyTuple_SET_ITEM(result, i, elem); + } + } else { + if (n == 0) + goto empty; + + /* Copy the previous result tuple or re-use it if available */ + if (Py_REFCNT(result) > 1) { + PyObject *old_result = result; + result = PyTuple_New(r); + if (result == NULL) + goto empty; + po->result = result; + for (i=0; i<r ; i++) { + elem = PyTuple_GET_ITEM(old_result, i); + Py_INCREF(elem); + PyTuple_SET_ITEM(result, i, elem); + } + Py_DECREF(old_result); + } + /* Now, we've got the only copy so we can update it in-place */ + assert(r == 0 || Py_REFCNT(result) == 1); + + /* Decrement rightmost cycle, moving leftward upon zero rollover */ + for (i=r-1 ; i>=0 ; i--) { + cycles[i] -= 1; + if (cycles[i] == 0) { + /* rotatation: indices[i:] = indices[i+1:] + indices[i:i+1] */ + index = indices[i]; + for (j=i ; j<n-1 ; j++) + indices[j] = indices[j+1]; + indices[n-1] = index; + cycles[i] = n - i; + } else { + j = cycles[i]; + index = indices[i]; + indices[i] = indices[n-j]; + indices[n-j] = index; + + for (k=i; k<r ; k++) { + /* start with i, the leftmost element that changed */ + /* yield tuple(pool[k] for k in indices[:r]) */ + index = indices[k]; + elem = PyTuple_GET_ITEM(pool, index); + Py_INCREF(elem); + oldelem = PyTuple_GET_ITEM(result, k); + PyTuple_SET_ITEM(result, k, elem); + Py_DECREF(oldelem); + } + break; + } + } + /* If i is negative, then the cycles have all + rolled-over and we're done. */ + if (i < 0) + goto empty; + } + Py_INCREF(result); + return result; + +empty: + po->stopped = 1; + return NULL; +} + +PyDoc_STRVAR(permutations_doc, +"permutations(iterables[, r]) --> permutations object\n\ +\n\ +Return successive r-length permutations of elements in the iterable.\n\n\ +permutations(range(4), 3) --> (0,1,2), (0,1,3), (0,2,3), (1,2,3)"); + +static PyTypeObject permutations_type = { + PyVarObject_HEAD_INIT(NULL, 0) + "itertools.permutations", /* tp_name */ + sizeof(permutationsobject), /* tp_basicsize */ + 0, /* tp_itemsize */ + /* methods */ + (destructor)permutations_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + permutations_doc, /* tp_doc */ + (traverseproc)permutations_traverse, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + PyObject_SelfIter, /* tp_iter */ + (iternextfunc)permutations_next, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + permutations_new, /* tp_new */ + PyObject_GC_Del, /* tp_free */ +}; + + /* ifilter object ************************************************************/ typedef struct { @@ -3295,6 +3568,7 @@ inititertools(void) &count_type, &izip_type, &iziplongest_type, + &permutations_type, &product_type, &repeat_type, &groupby_type, |