summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Doc/library/collections.rst3
-rw-r--r--Doc/library/itertools.rst63
-rw-r--r--Lib/test/test_itertools.py116
-rw-r--r--Misc/NEWS3
-rw-r--r--Modules/itertoolsmodule.c253
5 files changed, 379 insertions, 59 deletions
diff --git a/Doc/library/collections.rst b/Doc/library/collections.rst
index b358e38..5b25b47 100644
--- a/Doc/library/collections.rst
+++ b/Doc/library/collections.rst
@@ -291,8 +291,7 @@ counts less than one::
Section 4.6.3, Exercise 19*\.
* To enumerate all distinct multisets of a given size over a given set of
- elements, see :func:`combinations_with_replacement` in the
- :ref:`itertools-recipes` for itertools::
+ elements, see :func:`itertools.combinations_with_replacement`.
map(Counter, combinations_with_replacement('ABC', 2)) --> AA AB AC BB BC CC
diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst
index b7cd431..9aff478 100644
--- a/Doc/library/itertools.rst
+++ b/Doc/library/itertools.rst
@@ -139,6 +139,53 @@ loops that truncate the stream.
.. versionadded:: 2.6
+.. function:: combinations_with_replacement(iterable, r)
+
+ Return *r* length subsequences of elements from the input *iterable*
+ allowing individual elements to be repeated more than once.
+
+ Combinations are emitted in lexicographic sort order. So, if the
+ input *iterable* is sorted, the combination tuples will be produced
+ in sorted order.
+
+ Elements are treated as unique based on their position, not on their
+ value. So if the input elements are unique, the generated combinations
+ will also be unique.
+
+ Equivalent to::
+
+ def combinations_with_replacement(iterable, r):
+ # combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC
+ pool = tuple(iterable)
+ n = len(pool)
+ if not n and r:
+ return
+ indices = [0] * r
+ yield tuple(pool[i] for i in indices)
+ while 1:
+ for i in reversed(range(r)):
+ if indices[i] != n - 1:
+ break
+ else:
+ return
+ indices[i:] = [indices[i] + 1] * (r - i)
+ yield tuple(pool[i] for i in indices)
+
+ The code for :func:`combinations_with_replacement` can be also expressed as
+ a subsequence of :func:`product` after filtering entries where the elements
+ are not in sorted order (according to their position in the input pool)::
+
+ def combinations_with_replacement(iterable, r):
+ pool = tuple(iterable)
+ n = len(pool)
+ for indices in product(range(n), repeat=r):
+ if sorted(indices) == list(indices):
+ yield tuple(pool[i] for i in indices)
+
+ The number of items returned is ``(n+r-1)! / r! / (n-1)!`` when ``n > 0``.
+
+ .. versionadded:: 2.7
+
.. function:: compress(data, selectors)
Make an iterator that filters elements from *data* returning only those that
@@ -691,22 +738,6 @@ which incur interpreter overhead.
s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
- def combinations_with_replacement(iterable, r):
- "combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC"
- # number items returned: (n+r-1)! / r! / (n-1)!
- pool = tuple(iterable)
- n = len(pool)
- indices = [0] * r
- yield tuple(pool[i] for i in indices)
- while 1:
- for i in reversed(range(r)):
- if indices[i] != n - 1:
- break
- else:
- return
- indices[i:] = [indices[i] + 1] * (r - i)
- yield tuple(pool[i] for i in indices)
-
def unique_everseen(iterable, key=None):
"List unique elements, preserving order. Remember all elements ever seen."
# unique_everseen('AAAABBBCCDAABBB') --> A B C D
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index 9b399c0..23a8765 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -127,6 +127,76 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(combinations('abcde', 3))))), 1)
+ def test_combinations_with_replacement(self):
+ cwr = combinations_with_replacement
+ self.assertRaises(TypeError, cwr, 'abc') # missing r argument
+ self.assertRaises(TypeError, cwr, 'abc', 2, 1) # too many arguments
+ self.assertRaises(TypeError, cwr, None) # pool is not iterable
+ self.assertRaises(ValueError, cwr, 'abc', -2) # r is negative
+ self.assertEqual(list(cwr('ABC', 2)),
+ [('A','A'), ('A','B'), ('A','C'), ('B','B'), ('B','C'), ('C','C')])
+
+ def cwr1(iterable, r):
+ 'Pure python version shown in the docs'
+ # number items returned: (n+r-1)! / r! / (n-1)! when n>0
+ pool = tuple(iterable)
+ n = len(pool)
+ if not n and r:
+ return
+ indices = [0] * r
+ yield tuple(pool[i] for i in indices)
+ while 1:
+ for i in reversed(range(r)):
+ if indices[i] != n - 1:
+ break
+ else:
+ return
+ indices[i:] = [indices[i] + 1] * (r - i)
+ yield tuple(pool[i] for i in indices)
+
+ def cwr2(iterable, r):
+ 'Pure python version shown in the docs'
+ pool = tuple(iterable)
+ n = len(pool)
+ for indices in product(range(n), repeat=r):
+ if sorted(indices) == list(indices):
+ yield tuple(pool[i] for i in indices)
+
+ def numcombs(n, r):
+ if not n:
+ return 0 if r else 1
+ return fact(n+r-1) / fact(r)/ fact(n-1)
+
+ for n in range(7):
+ values = [5*x-12 for x in range(n)]
+ for r in range(n+2):
+ result = list(cwr(values, r))
+
+ self.assertEqual(len(result), numcombs(n, r)) # right number of combs
+ self.assertEqual(len(result), len(set(result))) # no repeats
+ self.assertEqual(result, sorted(result)) # lexicographic order
+
+ regular_combs = list(combinations(values, r)) # compare to combs without replacement
+ if n == 0 or r <= 1:
+ self.assertEquals(result, regular_combs) # cases that should be identical
+ else:
+ self.assert_(set(result) >= set(regular_combs)) # rest should be supersets of regular combs
+
+ for c in result:
+ self.assertEqual(len(c), r) # r-length combinations
+ noruns = [k for k,v in groupby(c)] # combo without consecutive repeats
+ self.assertEqual(len(noruns), len(set(noruns))) # no repeats other than consecutive
+ self.assertEqual(list(c), sorted(c)) # keep original ordering
+ self.assert_(all(e in values for e in c)) # elements taken from input iterable
+ self.assertEqual(noruns,
+ [e for e in values if e in c]) # comb is a subsequence of the input iterable
+ self.assertEqual(result, list(cwr1(values, r))) # matches first pure python version
+ self.assertEqual(result, list(cwr2(values, r))) # matches second pure python version
+
+ # Test implementation detail: tuple re-use
+ self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1)
+ self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1)
+
def test_permutations(self):
self.assertRaises(TypeError, permutations) # too few arguments
self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
@@ -716,6 +786,10 @@ class TestExamples(unittest.TestCase):
self.assertEqual(list(combinations(range(4), 3)),
[(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
+ def test_combinations_with_replacement(self):
+ self.assertEqual(list(combinations_with_replacement('ABC', 2)),
+ [('A','A'), ('A','B'), ('A','C'), ('B','B'), ('B','C'), ('C','C')])
+
def test_compress(self):
self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF'))
@@ -799,6 +873,10 @@ class TestGC(unittest.TestCase):
a = []
self.makecycle(combinations([1,2,a,3], 3), a)
+ def test_combinations_with_replacement(self):
+ a = []
+ self.makecycle(combinations_with_replacement([1,2,a,3], 3), a)
+
def test_compress(self):
a = []
self.makecycle(compress('ABCDEF', [1,0,1,0,1,0]), a)
@@ -1291,21 +1369,6 @@ Samuele
... s = list(iterable)
... return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
->>> def combinations_with_replacement(iterable, r):
-... "combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC"
-... pool = tuple(iterable)
-... n = len(pool)
-... indices = [0] * r
-... yield tuple(pool[i] for i in indices)
-... while 1:
-... for i in reversed(range(r)):
-... if indices[i] != n - 1:
-... break
-... else:
-... return
-... indices[i:] = [indices[i] + 1] * (r - i)
-... yield tuple(pool[i] for i in indices)
-
>>> def unique_everseen(iterable, key=None):
... "List unique elements, preserving order. Remember all elements ever seen."
... # unique_everseen('AAAABBBCCDAABBB') --> A B C D
@@ -1386,29 +1449,6 @@ perform as purported.
>>> list(powerset([1,2,3]))
[(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
->>> list(combinations_with_replacement('abc', 2))
-[('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')]
-
->>> list(combinations_with_replacement('01', 3))
-[('0', '0', '0'), ('0', '0', '1'), ('0', '1', '1'), ('1', '1', '1')]
-
->>> def combinations_with_replacement2(iterable, r):
-... 'Alternate version that filters from product()'
-... pool = tuple(iterable)
-... n = len(pool)
-... for indices in product(range(n), repeat=r):
-... if sorted(indices) == list(indices):
-... yield tuple(pool[i] for i in indices)
-
->>> list(combinations_with_replacement('abc', 2)) == list(combinations_with_replacement2('abc', 2))
-True
-
->>> list(combinations_with_replacement('01', 3)) == list(combinations_with_replacement2('01', 3))
-True
-
->>> list(combinations_with_replacement('2310', 6)) == list(combinations_with_replacement2('2310', 6))
-True
-
>>> list(unique_everseen('AAAABBBCCDAABBB'))
['A', 'B', 'C', 'D']
diff --git a/Misc/NEWS b/Misc/NEWS
index da016c1..bb1aeac 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -153,7 +153,8 @@ Library
- Issue #4863: distutils.mwerkscompiler has been removed.
-- Added a new function: itertools.compress().
+- Added a new itertools functions: combinations_with_replacement()
+ and compress().
- Fix and properly document the multiprocessing module's logging
support, expose the internal levels and provide proper usage
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index f66d052..221dbe5 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -1862,7 +1862,8 @@ product_dealloc(productobject *lz)
PyObject_GC_UnTrack(lz);
Py_XDECREF(lz->pools);
Py_XDECREF(lz->result);
- PyMem_Free(lz->indices);
+ if (lz->indices != NULL)
+ PyMem_Free(lz->indices);
Py_TYPE(lz)->tp_free(lz);
}
@@ -2090,7 +2091,8 @@ combinations_dealloc(combinationsobject *co)
PyObject_GC_UnTrack(co);
Py_XDECREF(co->pool);
Py_XDECREF(co->result);
- PyMem_Free(co->indices);
+ if (co->indices != NULL)
+ PyMem_Free(co->indices);
Py_TYPE(co)->tp_free(co);
}
@@ -2239,6 +2241,252 @@ static PyTypeObject combinations_type = {
};
+/* combinations with replacement object *******************************************/
+
+/* Equivalent to:
+
+ def combinations_with_replacement(iterable, r):
+ "combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC"
+ # number items returned: (n+r-1)! / r! / (n-1)!
+ pool = tuple(iterable)
+ n = len(pool)
+ indices = [0] * r
+ yield tuple(pool[i] for i in indices)
+ while 1:
+ for i in reversed(range(r)):
+ if indices[i] != n - 1:
+ break
+ else:
+ return
+ indices[i:] = [indices[i] + 1] * (r - i)
+ yield tuple(pool[i] for i in indices)
+
+ def combinations_with_replacement2(iterable, r):
+ 'Alternate version that filters from product()'
+ pool = tuple(iterable)
+ n = len(pool)
+ for indices in product(range(n), repeat=r):
+ if sorted(indices) == list(indices):
+ yield tuple(pool[i] for i in indices)
+*/
+typedef struct {
+ PyObject_HEAD
+ PyObject *pool; /* input converted to a tuple */
+ Py_ssize_t *indices; /* one index per result element */
+ PyObject *result; /* most recently returned result tuple */
+ Py_ssize_t r; /* size of result tuple */
+ int stopped; /* set to 1 when the cwr iterator is exhausted */
+} cwrobject;
+
+static PyTypeObject cwr_type;
+
+static PyObject *
+cwr_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+ cwrobject *co;
+ Py_ssize_t n;
+ Py_ssize_t r;
+ PyObject *pool = NULL;
+ PyObject *iterable = NULL;
+ Py_ssize_t *indices = NULL;
+ Py_ssize_t i;
+ static char *kwargs[] = {"iterable", "r", NULL};
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "On:combinations_with_replacement", kwargs,
+ &iterable, &r))
+ return NULL;
+
+ pool = PySequence_Tuple(iterable);
+ if (pool == NULL)
+ goto error;
+ n = PyTuple_GET_SIZE(pool);
+ if (r < 0) {
+ PyErr_SetString(PyExc_ValueError, "r must be non-negative");
+ goto error;
+ }
+
+ indices = PyMem_Malloc(r * sizeof(Py_ssize_t));
+ if (indices == NULL) {
+ PyErr_NoMemory();
+ goto error;
+ }
+
+ for (i=0 ; i<r ; i++)
+ indices[i] = 0;
+
+ /* create cwrobject structure */
+ co = (cwrobject *)type->tp_alloc(type, 0);
+ if (co == NULL)
+ goto error;
+
+ co->pool = pool;
+ co->indices = indices;
+ co->result = NULL;
+ co->r = r;
+ co->stopped = !n && r;
+
+ return (PyObject *)co;
+
+error:
+ if (indices != NULL)
+ PyMem_Free(indices);
+ Py_XDECREF(pool);
+ return NULL;
+}
+
+static void
+cwr_dealloc(cwrobject *co)
+{
+ PyObject_GC_UnTrack(co);
+ Py_XDECREF(co->pool);
+ Py_XDECREF(co->result);
+ if (co->indices != NULL)
+ PyMem_Free(co->indices);
+ Py_TYPE(co)->tp_free(co);
+}
+
+static int
+cwr_traverse(cwrobject *co, visitproc visit, void *arg)
+{
+ Py_VISIT(co->pool);
+ Py_VISIT(co->result);
+ return 0;
+}
+
+static PyObject *
+cwr_next(cwrobject *co)
+{
+ PyObject *elem;
+ PyObject *oldelem;
+ PyObject *pool = co->pool;
+ Py_ssize_t *indices = co->indices;
+ PyObject *result = co->result;
+ Py_ssize_t n = PyTuple_GET_SIZE(pool);
+ Py_ssize_t r = co->r;
+ Py_ssize_t i, j, index;
+
+ if (co->stopped)
+ return NULL;
+
+ if (result == NULL) {
+ /* On the first pass, initialize result tuple using the indices */
+ result = PyTuple_New(r);
+ if (result == NULL)
+ goto empty;
+ co->result = result;
+ for (i=0; i<r ; i++) {
+ index = indices[i];
+ elem = PyTuple_GET_ITEM(pool, index);
+ Py_INCREF(elem);
+ PyTuple_SET_ITEM(result, i, elem);
+ }
+ } else {
+ /* Copy the previous result tuple or re-use it if available */
+ if (Py_REFCNT(result) > 1) {
+ PyObject *old_result = result;
+ result = PyTuple_New(r);
+ if (result == NULL)
+ goto empty;
+ co->result = result;
+ for (i=0; i<r ; i++) {
+ elem = PyTuple_GET_ITEM(old_result, i);
+ Py_INCREF(elem);
+ PyTuple_SET_ITEM(result, i, elem);
+ }
+ Py_DECREF(old_result);
+ }
+ /* Now, we've got the only copy so we can update it in-place CPython's
+ empty tuple is a singleton and cached in PyTuple's freelist. */
+ assert(r == 0 || Py_REFCNT(result) == 1);
+
+ /* Scan indices right-to-left until finding one that is not
+ * at its maximum (n-1). */
+ for (i=r-1 ; i >= 0 && indices[i] == n-1; i--)
+ ;
+
+ /* If i is negative, then the indices are all at
+ their maximum value and we're done. */
+ if (i < 0)
+ goto empty;
+
+ /* Increment the current index which we know is not at its
+ maximum. Then set all to the right to the same value. */
+ indices[i]++;
+ for (j=i+1 ; j<r ; j++)
+ indices[j] = indices[j-1];
+
+ /* Update the result tuple for the new indices
+ starting with i, the leftmost index that changed */
+ for ( ; i<r ; i++) {
+ index = indices[i];
+ elem = PyTuple_GET_ITEM(pool, index);
+ Py_INCREF(elem);
+ oldelem = PyTuple_GET_ITEM(result, i);
+ PyTuple_SET_ITEM(result, i, elem);
+ Py_DECREF(oldelem);
+ }
+ }
+
+ Py_INCREF(result);
+ return result;
+
+empty:
+ co->stopped = 1;
+ return NULL;
+}
+
+PyDoc_STRVAR(cwr_doc,
+"combinations_with_replacement(iterable[, r]) --> combinations_with_replacement object\n\
+\n\
+Return successive r-length combinations of elements in the iterable\n\
+allowing individual elements to have successive repeats.\n\
+combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC");
+
+static PyTypeObject cwr_type = {
+ PyVarObject_HEAD_INIT(NULL, 0)
+ "itertools.combinations_with_replacement", /* tp_name */
+ sizeof(cwrobject), /* tp_basicsize */
+ 0, /* tp_itemsize */
+ /* methods */
+ (destructor)cwr_dealloc, /* tp_dealloc */
+ 0, /* tp_print */
+ 0, /* tp_getattr */
+ 0, /* tp_setattr */
+ 0, /* tp_compare */
+ 0, /* tp_repr */
+ 0, /* tp_as_number */
+ 0, /* tp_as_sequence */
+ 0, /* tp_as_mapping */
+ 0, /* tp_hash */
+ 0, /* tp_call */
+ 0, /* tp_str */
+ PyObject_GenericGetAttr, /* tp_getattro */
+ 0, /* tp_setattro */
+ 0, /* tp_as_buffer */
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
+ Py_TPFLAGS_BASETYPE, /* tp_flags */
+ cwr_doc, /* tp_doc */
+ (traverseproc)cwr_traverse, /* tp_traverse */
+ 0, /* tp_clear */
+ 0, /* tp_richcompare */
+ 0, /* tp_weaklistoffset */
+ PyObject_SelfIter, /* tp_iter */
+ (iternextfunc)cwr_next, /* tp_iternext */
+ 0, /* tp_methods */
+ 0, /* tp_members */
+ 0, /* tp_getset */
+ 0, /* tp_base */
+ 0, /* tp_dict */
+ 0, /* tp_descr_get */
+ 0, /* tp_descr_set */
+ 0, /* tp_dictoffset */
+ 0, /* tp_init */
+ 0, /* tp_alloc */
+ cwr_new, /* tp_new */
+ PyObject_GC_Del, /* tp_free */
+};
+
+
/* permutations object ************************************************************
def permutations(iterable, r=None):
@@ -3701,6 +3949,7 @@ inititertools(void)
char *name;
PyTypeObject *typelist[] = {
&combinations_type,
+ &cwr_type,
&cycle_type,
&dropwhile_type,
&takewhile_type,