diff options
author | Raymond Hettinger <python@rcn.com> | 2008-02-26 23:40:50 (GMT) |
---|---|---|
committer | Raymond Hettinger <python@rcn.com> | 2008-02-26 23:40:50 (GMT) |
commit | 93e804da9cb2e36801e153d2b9c9f8332c70784f (patch) | |
tree | 660987165ed0e07ae6658b20effbc3a67098e3f6 | |
parent | 3ef2063ec80da1f8ea5ee214caa3f5265d791448 (diff) | |
download | cpython-93e804da9cb2e36801e153d2b9c9f8332c70784f.zip cpython-93e804da9cb2e36801e153d2b9c9f8332c70784f.tar.gz cpython-93e804da9cb2e36801e153d2b9c9f8332c70784f.tar.bz2 |
Add itertools.combinations().
-rw-r--r-- | Doc/library/itertools.rst | 30 | ||||
-rw-r--r-- | Lib/test/test_itertools.py | 24 | ||||
-rw-r--r-- | Misc/NEWS | 2 | ||||
-rw-r--r-- | Modules/itertoolsmodule.c | 226 |
4 files changed, 266 insertions, 16 deletions
diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst index c8f6e33..6dc19a1 100644 --- a/Doc/library/itertools.rst +++ b/Doc/library/itertools.rst @@ -97,21 +97,21 @@ loops that truncate the stream. def combinations(iterable, r): pool = tuple(iterable) - if pool: - n = len(pool) - vec = range(r) - yield tuple(pool[i] for i in vec) - while 1: - for i in reversed(range(r)): - if vec[i] == i + n-r: - continue - vec[i] += 1 - for j in range(i+1, r): - vec[j] = vec[j-1] + 1 - yield tuple(pool[i] for i in vec) - break - else: - return + n = len(pool) + assert 0 <= r <= n + vec = range(r) + yield tuple(pool[i] for i in vec) + while 1: + for i in reversed(range(r)): + if vec[i] == i + n-r: + continue + vec[i] += 1 + for j in range(i+1, r): + vec[j] = vec[j-1] + 1 + yield tuple(pool[i] for i in vec) + break + else: + return .. versionadded:: 2.6 diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index dc9081e..9868325 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -40,6 +40,10 @@ def take(n, seq): 'Convenience function for partially consuming a long of infinite iterable' return list(islice(seq, n)) +def fact(n): + 'Factorial' + return reduce(operator.mul, range(1, n+1), 1) + class TestBasicOps(unittest.TestCase): def test_chain(self): self.assertEqual(list(chain('abc', 'def')), list('abcdef')) @@ -48,6 +52,26 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(take(4, chain('abc', 'def')), list('abcd')) self.assertRaises(TypeError, chain, 2, 3) + def test_combinations(self): + self.assertRaises(TypeError, combinations, 'abc') # missing r argument + self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments + self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative + self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big + self.assertEqual(list(combinations(range(4), 3)), + [(0,1,2), (0,1,3), (0,2,3), (1,2,3)]) + for n in range(6): + values = [5*x-12 for x in range(n)] + for r in range(n+1): + result = list(combinations(values, r)) + self.assertEqual(len(result), fact(n) / fact(r) / fact(n-r)) # right number of combs + self.assertEqual(len(result), len(set(result))) # no repeats + self.assertEqual(result, sorted(result)) # lexicographic order + for c in result: + self.assertEqual(len(c), r) # r-length combinations + self.assertEqual(len(set(c)), r) # no duplicate elements + self.assertEqual(list(c), sorted(c)) # keep original ordering + self.assert_(all(e in values for e in c)) # elements taken from input iterable + def test_count(self): self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)]) self.assertEqual(zip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)]) @@ -670,6 +670,8 @@ Library - Added itertools.product() which forms the Cartesian product of the input iterables. +- Added itertools.combinations(). + - Patch #1541463: optimize performance of cgi.FieldStorage operations. - Decimal is fully updated to the latest Decimal Specification (v1.66). diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index d95376a..10c5e0b 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -1982,6 +1982,229 @@ static PyTypeObject product_type = { }; +/* combinations object ************************************************************/ + +typedef struct { + PyObject_HEAD + PyObject *pool; /* input converted to a tuple */ + Py_ssize_t *indices; /* one index per result element */ + PyObject *result; /* most recently returned result tuple */ + Py_ssize_t r; /* size of result tuple */ + int stopped; /* set to 1 when the combinations iterator is exhausted */ +} combinationsobject; + +static PyTypeObject combinations_type; + +static PyObject * +combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + combinationsobject *co; + Py_ssize_t n; + Py_ssize_t r; + PyObject *pool = NULL; + PyObject *iterable = NULL; + Py_ssize_t *indices = NULL; + Py_ssize_t i; + static char *kwargs[] = {"iterable", "r", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "On:combinations", kwargs, + &iterable, &r)) + return NULL; + + pool = PySequence_Tuple(iterable); + if (pool == NULL) + goto error; + n = PyTuple_GET_SIZE(pool); + if (r < 0) { + PyErr_SetString(PyExc_ValueError, "r must be non-negative"); + goto error; + } + if (r > n) { + PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable"); + goto error; + } + + indices = PyMem_Malloc(r * sizeof(Py_ssize_t)); + if (indices == NULL) { + PyErr_NoMemory(); + goto error; + } + + for (i=0 ; i<r ; i++) + indices[i] = i; + + /* create combinationsobject structure */ + co = (combinationsobject *)type->tp_alloc(type, 0); + if (co == NULL) + goto error; + + co->pool = pool; + co->indices = indices; + co->result = NULL; + co->r = r; + co->stopped = 0; + + return (PyObject *)co; + +error: + if (indices != NULL) + PyMem_Free(indices); + Py_XDECREF(pool); + return NULL; +} + +static void +combinations_dealloc(combinationsobject *co) +{ + PyObject_GC_UnTrack(co); + Py_XDECREF(co->pool); + Py_XDECREF(co->result); + PyMem_Free(co->indices); + Py_TYPE(co)->tp_free(co); +} + +static int +combinations_traverse(combinationsobject *co, visitproc visit, void *arg) +{ + Py_VISIT(co->pool); + Py_VISIT(co->result); + return 0; +} + +static PyObject * +combinations_next(combinationsobject *co) +{ + PyObject *elem; + PyObject *oldelem; + PyObject *pool = co->pool; + Py_ssize_t *indices = co->indices; + PyObject *result = co->result; + Py_ssize_t n = PyTuple_GET_SIZE(pool); + Py_ssize_t r = co->r; + Py_ssize_t i, j, index; + + if (co->stopped) + return NULL; + + if (result == NULL) { + /* On the first pass, initialize result tuple using the indices */ + result = PyTuple_New(r); + if (result == NULL) + goto empty; + co->result = result; + for (i=0; i<r ; i++) { + index = indices[i]; + elem = PyTuple_GET_ITEM(pool, index); + Py_INCREF(elem); + PyTuple_SET_ITEM(result, i, elem); + } + } else { + /* Copy the previous result tuple or re-use it if available */ + if (Py_REFCNT(result) > 1) { + PyObject *old_result = result; + result = PyTuple_New(r); + if (result == NULL) + goto empty; + co->result = result; + for (i=0; i<r ; i++) { + elem = PyTuple_GET_ITEM(old_result, i); + Py_INCREF(elem); + PyTuple_SET_ITEM(result, i, elem); + } + Py_DECREF(old_result); + } + /* Now, we've got the only copy so we can update it in-place */ + assert (Py_REFCNT(result) == 1); + + /* Scan indices right-to-left until finding one that is not + at its maximum (i + n - r). */ + for (i=r-1 ; i >= 0 && indices[i] == i+n-r ; i--) + ; + + /* If i is negative, then the indices are all at + their maximum value and we're done. */ + if (i < 0) + goto empty; + + /* Increment the current index which we know is not at its + maximum. Then move back to the right setting each index + to its lowest possible value (one higher than the index + to its left -- this maintains the sort order invariant). */ + indices[i]++; + for (j=i+1 ; j<r ; j++) + indices[j] = indices[j-1] + 1; + + /* Update the result tuple for the new indices + starting with i, the leftmost index that changed */ + for ( ; i<r ; i++) { + index = indices[i]; + elem = PyTuple_GET_ITEM(pool, index); + Py_INCREF(elem); + oldelem = PyTuple_GET_ITEM(result, i); + PyTuple_SET_ITEM(result, i, elem); + Py_DECREF(oldelem); + } + } + + Py_INCREF(result); + return result; + +empty: + co->stopped = 1; + return NULL; +} + +PyDoc_STRVAR(combinations_doc, +"combinations(iterables) --> combinations object\n\ +\n\ +Return successive r-length combinations of elements in the iterable.\n\n\ +combinations(range(4), 3) --> (0,1,2), (0,1,3), (0,2,3), (1,2,3)"); + +static PyTypeObject combinations_type = { + PyVarObject_HEAD_INIT(NULL, 0) + "itertools.combinations", /* tp_name */ + sizeof(combinationsobject), /* tp_basicsize */ + 0, /* tp_itemsize */ + /* methods */ + (destructor)combinations_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + combinations_doc, /* tp_doc */ + (traverseproc)combinations_traverse, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + PyObject_SelfIter, /* tp_iter */ + (iternextfunc)combinations_next, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + combinations_new, /* tp_new */ + PyObject_GC_Del, /* tp_free */ +}; + + /* ifilter object ************************************************************/ typedef struct { @@ -3026,6 +3249,7 @@ inititertools(void) PyObject *m; char *name; PyTypeObject *typelist[] = { + &combinations_type, &cycle_type, &dropwhile_type, &takewhile_type, @@ -3038,7 +3262,7 @@ inititertools(void) &count_type, &izip_type, &iziplongest_type, - &product_type, + &product_type, &repeat_type, &groupby_type, NULL |