summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_itertools.py31
-rw-r--r--Misc/NEWS2
-rw-r--r--Modules/itertoolsmodule.c274
3 files changed, 293 insertions, 14 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index 4197989..0692747 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -47,15 +47,6 @@ def fact(n):
'Factorial'
return prod(range(1, n+1))
-def permutations(iterable, r=None):
- # XXX use this until real permutations code is added
- pool = tuple(iterable)
- n = len(pool)
- r = n if r is None else r
- for indices in product(range(n), repeat=r):
- if len(set(indices)) == r:
- yield tuple(pool[i] for i in indices)
-
class TestBasicOps(unittest.TestCase):
def test_chain(self):
self.assertEqual(list(chain('abc', 'def')), list('abcdef'))
@@ -117,6 +108,8 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(len(set(c)), r) # no duplicate elements
self.assertEqual(list(c), sorted(c)) # keep original ordering
self.assert_(all(e in values for e in c)) # elements taken from input iterable
+ self.assertEqual(list(c),
+ [e for e in values if e in c]) # comb is a subsequence of the input iterable
self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version
@@ -127,9 +120,10 @@ class TestBasicOps(unittest.TestCase):
def test_permutations(self):
self.assertRaises(TypeError, permutations) # too few arguments
self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
-## self.assertRaises(TypeError, permutations, None) # pool is not iterable
-## self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
-## self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big
+ self.assertRaises(TypeError, permutations, None) # pool is not iterable
+ self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
+ self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big
+ self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None
self.assertEqual(list(permutations(range(3), 2)),
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
@@ -182,7 +176,7 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(result, list(permutations(values))) # test default r
# Test implementation detail: tuple re-use
-## self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
+ self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1)
def test_count(self):
@@ -407,12 +401,23 @@ class TestBasicOps(unittest.TestCase):
list(product(*args, **dict(repeat=r))))
self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
self.assertRaises(TypeError, product, range(6), None)
+
+ def product2(*args, **kwds):
+ 'Pure python version used in docs'
+ pools = map(tuple, args) * kwds.get('repeat', 1)
+ result = [[]]
+ for pool in pools:
+ result = [x+[y] for x in result for y in pool]
+ for prod in result:
+ yield tuple(prod)
+
argtypes = ['', 'abc', '', xrange(0), xrange(4), dict(a=1, b=2, c=3),
set('abcdefg'), range(11), tuple(range(13))]
for i in range(100):
args = [random.choice(argtypes) for j in range(random.randrange(5))]
expected_len = prod(map(len, args))
self.assertEqual(len(list(product(*args))), expected_len)
+ self.assertEqual(list(product(*args)), list(product2(*args)))
args = map(iter, args)
self.assertEqual(len(list(product(*args))), expected_len)
diff --git a/Misc/NEWS b/Misc/NEWS
index 75d3cc1..ff97473 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -699,7 +699,7 @@ Library
- Added itertools.product() which forms the Cartesian product of
the input iterables.
-- Added itertools.combinations().
+- Added itertools.combinations() and itertools.permutations().
- Patch #1541463: optimize performance of cgi.FieldStorage operations.
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index d8e14d0..fe06ef4 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -2238,6 +2238,279 @@ static PyTypeObject combinations_type = {
};
+/* permutations object ************************************************************
+
+def permutations(iterable, r=None):
+ 'permutations(range(3), 2) --> (0,1) (0,2) (1,0) (1,2) (2,0) (2,1)'
+ pool = tuple(iterable)
+ n = len(pool)
+ r = n if r is None else r
+ indices = range(n)
+ cycles = range(n-r+1, n+1)[::-1]
+ yield tuple(pool[i] for i in indices[:r])
+ while n:
+ for i in reversed(range(r)):
+ cycles[i] -= 1
+ if cycles[i] == 0:
+ indices[i:] = indices[i+1:] + indices[i:i+1]
+ cycles[i] = n - i
+ else:
+ j = cycles[i]
+ indices[i], indices[-j] = indices[-j], indices[i]
+ yield tuple(pool[i] for i in indices[:r])
+ break
+ else:
+ return
+*/
+
+typedef struct {
+ PyObject_HEAD
+ PyObject *pool; /* input converted to a tuple */
+ Py_ssize_t *indices; /* one index per element in the pool */
+ Py_ssize_t *cycles; /* one rollover counter per element in the result */
+ PyObject *result; /* most recently returned result tuple */
+ Py_ssize_t r; /* size of result tuple */
+ int stopped; /* set to 1 when the permutations iterator is exhausted */
+} permutationsobject;
+
+static PyTypeObject permutations_type;
+
+static PyObject *
+permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+ permutationsobject *po;
+ Py_ssize_t n;
+ Py_ssize_t r;
+ PyObject *robj = Py_None;
+ PyObject *pool = NULL;
+ PyObject *iterable = NULL;
+ Py_ssize_t *indices = NULL;
+ Py_ssize_t *cycles = NULL;
+ Py_ssize_t i;
+ static char *kwargs[] = {"iterable", "r", NULL};
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:permutations", kwargs,
+ &iterable, &robj))
+ return NULL;
+
+ pool = PySequence_Tuple(iterable);
+ if (pool == NULL)
+ goto error;
+ n = PyTuple_GET_SIZE(pool);
+
+ r = n;
+ if (robj != Py_None) {
+ r = PyInt_AsSsize_t(robj);
+ if (r == -1 && PyErr_Occurred())
+ goto error;
+ }
+ if (r < 0) {
+ PyErr_SetString(PyExc_ValueError, "r must be non-negative");
+ goto error;
+ }
+ if (r > n) {
+ PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable");
+ goto error;
+ }
+
+ indices = PyMem_Malloc(n * sizeof(Py_ssize_t));
+ cycles = PyMem_Malloc(r * sizeof(Py_ssize_t));
+ if (indices == NULL || cycles == NULL) {
+ PyErr_NoMemory();
+ goto error;
+ }
+
+ for (i=0 ; i<n ; i++)
+ indices[i] = i;
+ for (i=0 ; i<r ; i++)
+ cycles[i] = n - i;
+
+ /* create permutationsobject structure */
+ po = (permutationsobject *)type->tp_alloc(type, 0);
+ if (po == NULL)
+ goto error;
+
+ po->pool = pool;
+ po->indices = indices;
+ po->cycles = cycles;
+ po->result = NULL;
+ po->r = r;
+ po->stopped = 0;
+
+ return (PyObject *)po;
+
+error:
+ if (indices != NULL)
+ PyMem_Free(indices);
+ if (cycles != NULL)
+ PyMem_Free(cycles);
+ Py_XDECREF(pool);
+ return NULL;
+}
+
+static void
+permutations_dealloc(permutationsobject *po)
+{
+ PyObject_GC_UnTrack(po);
+ Py_XDECREF(po->pool);
+ Py_XDECREF(po->result);
+ PyMem_Free(po->indices);
+ PyMem_Free(po->cycles);
+ Py_TYPE(po)->tp_free(po);
+}
+
+static int
+permutations_traverse(permutationsobject *po, visitproc visit, void *arg)
+{
+ if (po->pool != NULL)
+ Py_VISIT(po->pool);
+ if (po->result != NULL)
+ Py_VISIT(po->result);
+ return 0;
+}
+
+static PyObject *
+permutations_next(permutationsobject *po)
+{
+ PyObject *elem;
+ PyObject *oldelem;
+ PyObject *pool = po->pool;
+ Py_ssize_t *indices = po->indices;
+ Py_ssize_t *cycles = po->cycles;
+ PyObject *result = po->result;
+ Py_ssize_t n = PyTuple_GET_SIZE(pool);
+ Py_ssize_t r = po->r;
+ Py_ssize_t i, j, k, index;
+
+ if (po->stopped)
+ return NULL;
+
+ if (result == NULL) {
+ /* On the first pass, initialize result tuple using the indices */
+ result = PyTuple_New(r);
+ if (result == NULL)
+ goto empty;
+ po->result = result;
+ for (i=0; i<r ; i++) {
+ index = indices[i];
+ elem = PyTuple_GET_ITEM(pool, index);
+ Py_INCREF(elem);
+ PyTuple_SET_ITEM(result, i, elem);
+ }
+ } else {
+ if (n == 0)
+ goto empty;
+
+ /* Copy the previous result tuple or re-use it if available */
+ if (Py_REFCNT(result) > 1) {
+ PyObject *old_result = result;
+ result = PyTuple_New(r);
+ if (result == NULL)
+ goto empty;
+ po->result = result;
+ for (i=0; i<r ; i++) {
+ elem = PyTuple_GET_ITEM(old_result, i);
+ Py_INCREF(elem);
+ PyTuple_SET_ITEM(result, i, elem);
+ }
+ Py_DECREF(old_result);
+ }
+ /* Now, we've got the only copy so we can update it in-place */
+ assert(r == 0 || Py_REFCNT(result) == 1);
+
+ /* Decrement rightmost cycle, moving leftward upon zero rollover */
+ for (i=r-1 ; i>=0 ; i--) {
+ cycles[i] -= 1;
+ if (cycles[i] == 0) {
+ /* rotatation: indices[i:] = indices[i+1:] + indices[i:i+1] */
+ index = indices[i];
+ for (j=i ; j<n-1 ; j++)
+ indices[j] = indices[j+1];
+ indices[n-1] = index;
+ cycles[i] = n - i;
+ } else {
+ j = cycles[i];
+ index = indices[i];
+ indices[i] = indices[n-j];
+ indices[n-j] = index;
+
+ for (k=i; k<r ; k++) {
+ /* start with i, the leftmost element that changed */
+ /* yield tuple(pool[k] for k in indices[:r]) */
+ index = indices[k];
+ elem = PyTuple_GET_ITEM(pool, index);
+ Py_INCREF(elem);
+ oldelem = PyTuple_GET_ITEM(result, k);
+ PyTuple_SET_ITEM(result, k, elem);
+ Py_DECREF(oldelem);
+ }
+ break;
+ }
+ }
+ /* If i is negative, then the cycles have all
+ rolled-over and we're done. */
+ if (i < 0)
+ goto empty;
+ }
+ Py_INCREF(result);
+ return result;
+
+empty:
+ po->stopped = 1;
+ return NULL;
+}
+
+PyDoc_STRVAR(permutations_doc,
+"permutations(iterables[, r]) --> permutations object\n\
+\n\
+Return successive r-length permutations of elements in the iterable.\n\n\
+permutations(range(4), 3) --> (0,1,2), (0,1,3), (0,2,3), (1,2,3)");
+
+static PyTypeObject permutations_type = {
+ PyVarObject_HEAD_INIT(NULL, 0)
+ "itertools.permutations", /* tp_name */
+ sizeof(permutationsobject), /* tp_basicsize */
+ 0, /* tp_itemsize */
+ /* methods */
+ (destructor)permutations_dealloc, /* tp_dealloc */
+ 0, /* tp_print */
+ 0, /* tp_getattr */
+ 0, /* tp_setattr */
+ 0, /* tp_compare */
+ 0, /* tp_repr */
+ 0, /* tp_as_number */
+ 0, /* tp_as_sequence */
+ 0, /* tp_as_mapping */
+ 0, /* tp_hash */
+ 0, /* tp_call */
+ 0, /* tp_str */
+ PyObject_GenericGetAttr, /* tp_getattro */
+ 0, /* tp_setattro */
+ 0, /* tp_as_buffer */
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
+ Py_TPFLAGS_BASETYPE, /* tp_flags */
+ permutations_doc, /* tp_doc */
+ (traverseproc)permutations_traverse, /* tp_traverse */
+ 0, /* tp_clear */
+ 0, /* tp_richcompare */
+ 0, /* tp_weaklistoffset */
+ PyObject_SelfIter, /* tp_iter */
+ (iternextfunc)permutations_next, /* tp_iternext */
+ 0, /* tp_methods */
+ 0, /* tp_members */
+ 0, /* tp_getset */
+ 0, /* tp_base */
+ 0, /* tp_dict */
+ 0, /* tp_descr_get */
+ 0, /* tp_descr_set */
+ 0, /* tp_dictoffset */
+ 0, /* tp_init */
+ 0, /* tp_alloc */
+ permutations_new, /* tp_new */
+ PyObject_GC_Del, /* tp_free */
+};
+
+
/* ifilter object ************************************************************/
typedef struct {
@@ -3295,6 +3568,7 @@ inititertools(void)
&count_type,
&izip_type,
&iziplongest_type,
+ &permutations_type,
&product_type,
&repeat_type,
&groupby_type,