From 825758c50b3c8006db16c1b8627e417db32a1d23 Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Thu, 8 Jan 2009 05:20:19 +0000 Subject: - Issue 4816: itertools.combinations() and itertools.product were raising a ValueError for values of *r* larger than the input iterable. They now correctly return an empty iterator. --- Doc/library/itertools.rst | 13 ++++++++++++- Lib/test/test_itertools.py | 42 +++++++++++++++++++++++++++++++++--------- Misc/NEWS | 4 ++++ Modules/itertoolsmodule.c | 12 ++---------- 4 files changed, 51 insertions(+), 20 deletions(-) diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst index 67646c6..aef3f6a 100644 --- a/Doc/library/itertools.rst +++ b/Doc/library/itertools.rst @@ -108,6 +108,8 @@ loops that truncate the stream. # combinations(range(4), 3) --> 012 013 023 123 pool = tuple(iterable) n = len(pool) + if r > n: + return indices = range(r) yield tuple(pool[i] for i in indices) while 1: @@ -132,6 +134,9 @@ loops that truncate the stream. if sorted(indices) == list(indices): yield tuple(pool[i] for i in indices) + The number of items returned is ``n! / r! / (n-r)!`` when ``0 <= r <= n`` + or zero when ``r > n``. + .. versionadded:: 2.6 .. function:: count([n]) @@ -399,6 +404,8 @@ loops that truncate the stream. pool = tuple(iterable) n = len(pool) r = n if r is None else r + if r > n: + return indices = range(n) cycles = range(n, n-r, -1) yield tuple(pool[i] for i in indices[:r]) @@ -428,6 +435,9 @@ loops that truncate the stream. if len(set(indices)) == r: yield tuple(pool[i] for i in indices) + The number of items returned is ``n! / (n-r)!`` when ``0 <= r <= n`` + or zero when ``r > n``. + .. versionadded:: 2.6 .. function:: product(*iterables[, repeat]) @@ -674,7 +684,8 @@ which incur interpreter overhead. return (d for d, s in izip(data, selectors) if s) def combinations_with_replacement(iterable, r): - "combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC" + "combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC" + # number items returned: (n+r-1)! / r! / (n-1)! pool = tuple(iterable) n = len(pool) indices = [0] * r diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 029498a..2182cb9 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -71,11 +71,11 @@ class TestBasicOps(unittest.TestCase): self.assertRaises(TypeError, list, chain.from_iterable([2, 3])) def test_combinations(self): - self.assertRaises(TypeError, combinations, 'abc') # missing r argument + self.assertRaises(TypeError, combinations, 'abc') # missing r argument self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments self.assertRaises(TypeError, combinations, None) # pool is not iterable self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative - self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big + self.assertEqual(list(combinations('abc', 32)), []) # r > n self.assertEqual(list(combinations(range(4), 3)), [(0,1,2), (0,1,3), (0,2,3), (1,2,3)]) @@ -83,6 +83,8 @@ class TestBasicOps(unittest.TestCase): 'Pure python version shown in the docs' pool = tuple(iterable) n = len(pool) + if r > n: + return indices = range(r) yield tuple(pool[i] for i in indices) while 1: @@ -106,9 +108,9 @@ class TestBasicOps(unittest.TestCase): for n in range(7): values = [5*x-12 for x in range(n)] - for r in range(n+1): + for r in range(n+2): result = list(combinations(values, r)) - self.assertEqual(len(result), fact(n) / fact(r) / fact(n-r)) # right number of combs + self.assertEqual(len(result), 0 if r>n else fact(n) / fact(r) / fact(n-r)) # right number of combs self.assertEqual(len(result), len(set(result))) # no repeats self.assertEqual(result, sorted(result)) # lexicographic order for c in result: @@ -119,7 +121,7 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(list(c), [e for e in values if e in c]) # comb is a subsequence of the input iterable self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version - self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version + self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version # Test implementation detail: tuple re-use self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1) @@ -130,7 +132,7 @@ class TestBasicOps(unittest.TestCase): self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments self.assertRaises(TypeError, permutations, None) # pool is not iterable self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative - self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big + self.assertEqual(list(permutations('abc', 32)), []) # r > n self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None self.assertEqual(list(permutations(range(3), 2)), [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)]) @@ -140,6 +142,8 @@ class TestBasicOps(unittest.TestCase): pool = tuple(iterable) n = len(pool) r = n if r is None else r + if r > n: + return indices = range(n) cycles = range(n, n-r, -1) yield tuple(pool[i] for i in indices[:r]) @@ -168,9 +172,9 @@ class TestBasicOps(unittest.TestCase): for n in range(7): values = [5*x-12 for x in range(n)] - for r in range(n+1): + for r in range(n+2): result = list(permutations(values, r)) - self.assertEqual(len(result), fact(n) / fact(n-r)) # right number of perms + self.assertEqual(len(result), 0 if r>n else fact(n) / fact(n-r)) # right number of perms self.assertEqual(len(result), len(set(result))) # no repeats self.assertEqual(result, sorted(result)) # lexicographic order for p in result: @@ -178,7 +182,7 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(len(set(p)), r) # no duplicate elements self.assert_(all(e in values for e in p)) # elements taken from input iterable self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version - self.assertEqual(result, list(permutations2(values, r))) # matches first pure python version + self.assertEqual(result, list(permutations2(values, r))) # matches second pure python version if r == n: self.assertEqual(result, list(permutations(values, None))) # test r as None self.assertEqual(result, list(permutations(values))) # test default r @@ -1363,6 +1367,26 @@ perform as purported. >>> list(combinations_with_replacement('abc', 2)) [('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')] +>>> list(combinations_with_replacement('01', 3)) +[('0', '0', '0'), ('0', '0', '1'), ('0', '1', '1'), ('1', '1', '1')] + +>>> def combinations_with_replacement2(iterable, r): +... 'Alternate version that filters from product()' +... pool = tuple(iterable) +... n = len(pool) +... for indices in product(range(n), repeat=r): +... if sorted(indices) == list(indices): +... yield tuple(pool[i] for i in indices) + +>>> list(combinations_with_replacement('abc', 2)) == list(combinations_with_replacement2('abc', 2)) +True + +>>> list(combinations_with_replacement('01', 3)) == list(combinations_with_replacement2('01', 3)) +True + +>>> list(combinations_with_replacement('2310', 6)) == list(combinations_with_replacement2('2310', 6)) +True + >>> list(unique_everseen('AAAABBBCCDAABBB')) ['A', 'B', 'C', 'D'] diff --git a/Misc/NEWS b/Misc/NEWS index eaad519..f9aa06f 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -67,6 +67,10 @@ Core and Builtins Library ------- +- Issue 4816: itertools.combinations() and itertools.product were raising + a ValueError for values of *r* larger than the input iterable. They now + correctly return an empty iterator. + - Fractions.from_float() no longer loses precision for integers too big to cast as floats. diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index 18a1229..5875d10 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -2059,10 +2059,6 @@ combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds) PyErr_SetString(PyExc_ValueError, "r must be non-negative"); goto error; } - if (r > n) { - PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable"); - goto error; - } indices = PyMem_Malloc(r * sizeof(Py_ssize_t)); if (indices == NULL) { @@ -2082,7 +2078,7 @@ combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds) co->indices = indices; co->result = NULL; co->r = r; - co->stopped = 0; + co->stopped = r > n ? 1 : 0; return (PyObject *)co; @@ -2318,10 +2314,6 @@ permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds) PyErr_SetString(PyExc_ValueError, "r must be non-negative"); goto error; } - if (r > n) { - PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable"); - goto error; - } indices = PyMem_Malloc(n * sizeof(Py_ssize_t)); cycles = PyMem_Malloc(r * sizeof(Py_ssize_t)); @@ -2345,7 +2337,7 @@ permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds) po->cycles = cycles; po->result = NULL; po->r = r; - po->stopped = 0; + po->stopped = r > n ? 1 : 0; return (PyObject *)po; -- cgit v0.12