summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Doc/library/itertools.rst30
-rw-r--r--Lib/test/test_itertools.py24
-rw-r--r--Misc/NEWS2
-rw-r--r--Modules/itertoolsmodule.c226
4 files changed, 266 insertions, 16 deletions
diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst
index c8f6e33..6dc19a1 100644
--- a/Doc/library/itertools.rst
+++ b/Doc/library/itertools.rst
@@ -97,21 +97,21 @@ loops that truncate the stream.
def combinations(iterable, r):
pool = tuple(iterable)
- if pool:
- n = len(pool)
- vec = range(r)
- yield tuple(pool[i] for i in vec)
- while 1:
- for i in reversed(range(r)):
- if vec[i] == i + n-r:
- continue
- vec[i] += 1
- for j in range(i+1, r):
- vec[j] = vec[j-1] + 1
- yield tuple(pool[i] for i in vec)
- break
- else:
- return
+ n = len(pool)
+ assert 0 <= r <= n
+ vec = range(r)
+ yield tuple(pool[i] for i in vec)
+ while 1:
+ for i in reversed(range(r)):
+ if vec[i] == i + n-r:
+ continue
+ vec[i] += 1
+ for j in range(i+1, r):
+ vec[j] = vec[j-1] + 1
+ yield tuple(pool[i] for i in vec)
+ break
+ else:
+ return
.. versionadded:: 2.6
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index dc9081e..9868325 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -40,6 +40,10 @@ def take(n, seq):
'Convenience function for partially consuming a long of infinite iterable'
return list(islice(seq, n))
+def fact(n):
+ 'Factorial'
+ return reduce(operator.mul, range(1, n+1), 1)
+
class TestBasicOps(unittest.TestCase):
def test_chain(self):
self.assertEqual(list(chain('abc', 'def')), list('abcdef'))
@@ -48,6 +52,26 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(take(4, chain('abc', 'def')), list('abcd'))
self.assertRaises(TypeError, chain, 2, 3)
+ def test_combinations(self):
+ self.assertRaises(TypeError, combinations, 'abc') # missing r argument
+ self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
+ self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative
+ self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big
+ self.assertEqual(list(combinations(range(4), 3)),
+ [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
+ for n in range(6):
+ values = [5*x-12 for x in range(n)]
+ for r in range(n+1):
+ result = list(combinations(values, r))
+ self.assertEqual(len(result), fact(n) / fact(r) / fact(n-r)) # right number of combs
+ self.assertEqual(len(result), len(set(result))) # no repeats
+ self.assertEqual(result, sorted(result)) # lexicographic order
+ for c in result:
+ self.assertEqual(len(c), r) # r-length combinations
+ self.assertEqual(len(set(c)), r) # no duplicate elements
+ self.assertEqual(list(c), sorted(c)) # keep original ordering
+ self.assert_(all(e in values for e in c)) # elements taken from input iterable
+
def test_count(self):
self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)])
self.assertEqual(zip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)])
diff --git a/Misc/NEWS b/Misc/NEWS
index d72fe74..c28fd85 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -670,6 +670,8 @@ Library
- Added itertools.product() which forms the Cartesian product of
the input iterables.
+- Added itertools.combinations().
+
- Patch #1541463: optimize performance of cgi.FieldStorage operations.
- Decimal is fully updated to the latest Decimal Specification (v1.66).
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index d95376a..10c5e0b 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -1982,6 +1982,229 @@ static PyTypeObject product_type = {
};
+/* combinations object ************************************************************/
+
+typedef struct {
+ PyObject_HEAD
+ PyObject *pool; /* input converted to a tuple */
+ Py_ssize_t *indices; /* one index per result element */
+ PyObject *result; /* most recently returned result tuple */
+ Py_ssize_t r; /* size of result tuple */
+ int stopped; /* set to 1 when the combinations iterator is exhausted */
+} combinationsobject;
+
+static PyTypeObject combinations_type;
+
+static PyObject *
+combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+ combinationsobject *co;
+ Py_ssize_t n;
+ Py_ssize_t r;
+ PyObject *pool = NULL;
+ PyObject *iterable = NULL;
+ Py_ssize_t *indices = NULL;
+ Py_ssize_t i;
+ static char *kwargs[] = {"iterable", "r", NULL};
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "On:combinations", kwargs,
+ &iterable, &r))
+ return NULL;
+
+ pool = PySequence_Tuple(iterable);
+ if (pool == NULL)
+ goto error;
+ n = PyTuple_GET_SIZE(pool);
+ if (r < 0) {
+ PyErr_SetString(PyExc_ValueError, "r must be non-negative");
+ goto error;
+ }
+ if (r > n) {
+ PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable");
+ goto error;
+ }
+
+ indices = PyMem_Malloc(r * sizeof(Py_ssize_t));
+ if (indices == NULL) {
+ PyErr_NoMemory();
+ goto error;
+ }
+
+ for (i=0 ; i<r ; i++)
+ indices[i] = i;
+
+ /* create combinationsobject structure */
+ co = (combinationsobject *)type->tp_alloc(type, 0);
+ if (co == NULL)
+ goto error;
+
+ co->pool = pool;
+ co->indices = indices;
+ co->result = NULL;
+ co->r = r;
+ co->stopped = 0;
+
+ return (PyObject *)co;
+
+error:
+ if (indices != NULL)
+ PyMem_Free(indices);
+ Py_XDECREF(pool);
+ return NULL;
+}
+
+static void
+combinations_dealloc(combinationsobject *co)
+{
+ PyObject_GC_UnTrack(co);
+ Py_XDECREF(co->pool);
+ Py_XDECREF(co->result);
+ PyMem_Free(co->indices);
+ Py_TYPE(co)->tp_free(co);
+}
+
+static int
+combinations_traverse(combinationsobject *co, visitproc visit, void *arg)
+{
+ Py_VISIT(co->pool);
+ Py_VISIT(co->result);
+ return 0;
+}
+
+static PyObject *
+combinations_next(combinationsobject *co)
+{
+ PyObject *elem;
+ PyObject *oldelem;
+ PyObject *pool = co->pool;
+ Py_ssize_t *indices = co->indices;
+ PyObject *result = co->result;
+ Py_ssize_t n = PyTuple_GET_SIZE(pool);
+ Py_ssize_t r = co->r;
+ Py_ssize_t i, j, index;
+
+ if (co->stopped)
+ return NULL;
+
+ if (result == NULL) {
+ /* On the first pass, initialize result tuple using the indices */
+ result = PyTuple_New(r);
+ if (result == NULL)
+ goto empty;
+ co->result = result;
+ for (i=0; i<r ; i++) {
+ index = indices[i];
+ elem = PyTuple_GET_ITEM(pool, index);
+ Py_INCREF(elem);
+ PyTuple_SET_ITEM(result, i, elem);
+ }
+ } else {
+ /* Copy the previous result tuple or re-use it if available */
+ if (Py_REFCNT(result) > 1) {
+ PyObject *old_result = result;
+ result = PyTuple_New(r);
+ if (result == NULL)
+ goto empty;
+ co->result = result;
+ for (i=0; i<r ; i++) {
+ elem = PyTuple_GET_ITEM(old_result, i);
+ Py_INCREF(elem);
+ PyTuple_SET_ITEM(result, i, elem);
+ }
+ Py_DECREF(old_result);
+ }
+ /* Now, we've got the only copy so we can update it in-place */
+ assert (Py_REFCNT(result) == 1);
+
+ /* Scan indices right-to-left until finding one that is not
+ at its maximum (i + n - r). */
+ for (i=r-1 ; i >= 0 && indices[i] == i+n-r ; i--)
+ ;
+
+ /* If i is negative, then the indices are all at
+ their maximum value and we're done. */
+ if (i < 0)
+ goto empty;
+
+ /* Increment the current index which we know is not at its
+ maximum. Then move back to the right setting each index
+ to its lowest possible value (one higher than the index
+ to its left -- this maintains the sort order invariant). */
+ indices[i]++;
+ for (j=i+1 ; j<r ; j++)
+ indices[j] = indices[j-1] + 1;
+
+ /* Update the result tuple for the new indices
+ starting with i, the leftmost index that changed */
+ for ( ; i<r ; i++) {
+ index = indices[i];
+ elem = PyTuple_GET_ITEM(pool, index);
+ Py_INCREF(elem);
+ oldelem = PyTuple_GET_ITEM(result, i);
+ PyTuple_SET_ITEM(result, i, elem);
+ Py_DECREF(oldelem);
+ }
+ }
+
+ Py_INCREF(result);
+ return result;
+
+empty:
+ co->stopped = 1;
+ return NULL;
+}
+
+PyDoc_STRVAR(combinations_doc,
+"combinations(iterables) --> combinations object\n\
+\n\
+Return successive r-length combinations of elements in the iterable.\n\n\
+combinations(range(4), 3) --> (0,1,2), (0,1,3), (0,2,3), (1,2,3)");
+
+static PyTypeObject combinations_type = {
+ PyVarObject_HEAD_INIT(NULL, 0)
+ "itertools.combinations", /* tp_name */
+ sizeof(combinationsobject), /* tp_basicsize */
+ 0, /* tp_itemsize */
+ /* methods */
+ (destructor)combinations_dealloc, /* tp_dealloc */
+ 0, /* tp_print */
+ 0, /* tp_getattr */
+ 0, /* tp_setattr */
+ 0, /* tp_compare */
+ 0, /* tp_repr */
+ 0, /* tp_as_number */
+ 0, /* tp_as_sequence */
+ 0, /* tp_as_mapping */
+ 0, /* tp_hash */
+ 0, /* tp_call */
+ 0, /* tp_str */
+ PyObject_GenericGetAttr, /* tp_getattro */
+ 0, /* tp_setattro */
+ 0, /* tp_as_buffer */
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
+ Py_TPFLAGS_BASETYPE, /* tp_flags */
+ combinations_doc, /* tp_doc */
+ (traverseproc)combinations_traverse, /* tp_traverse */
+ 0, /* tp_clear */
+ 0, /* tp_richcompare */
+ 0, /* tp_weaklistoffset */
+ PyObject_SelfIter, /* tp_iter */
+ (iternextfunc)combinations_next, /* tp_iternext */
+ 0, /* tp_methods */
+ 0, /* tp_members */
+ 0, /* tp_getset */
+ 0, /* tp_base */
+ 0, /* tp_dict */
+ 0, /* tp_descr_get */
+ 0, /* tp_descr_set */
+ 0, /* tp_dictoffset */
+ 0, /* tp_init */
+ 0, /* tp_alloc */
+ combinations_new, /* tp_new */
+ PyObject_GC_Del, /* tp_free */
+};
+
+
/* ifilter object ************************************************************/
typedef struct {
@@ -3026,6 +3249,7 @@ inititertools(void)
PyObject *m;
char *name;
PyTypeObject *typelist[] = {
+ &combinations_type,
&cycle_type,
&dropwhile_type,
&takewhile_type,
@@ -3038,7 +3262,7 @@ inititertools(void)
&count_type,
&izip_type,
&iziplongest_type,
- &product_type,
+ &product_type,
&repeat_type,
&groupby_type,
NULL