diff options
author | Raymond Hettinger <python@rcn.com> | 2009-01-08 05:20:19 (GMT) |
---|---|---|
committer | Raymond Hettinger <python@rcn.com> | 2009-01-08 05:20:19 (GMT) |
commit | 825758c50b3c8006db16c1b8627e417db32a1d23 (patch) | |
tree | 119e0903dc54f33806e5eae267ee52328b6d027c /Lib/test | |
parent | cd610ae7f21db5782b76d562002b27959d7e87b7 (diff) | |
download | cpython-825758c50b3c8006db16c1b8627e417db32a1d23.zip cpython-825758c50b3c8006db16c1b8627e417db32a1d23.tar.gz cpython-825758c50b3c8006db16c1b8627e417db32a1d23.tar.bz2 |
- Issue 4816: itertools.combinations() and itertools.product were raising
a ValueError for values of *r* larger than the input iterable. They now
correctly return an empty iterator.
Diffstat (limited to 'Lib/test')
-rw-r--r-- | Lib/test/test_itertools.py | 42 |
1 files changed, 33 insertions, 9 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 029498a..2182cb9 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -71,11 +71,11 @@ class TestBasicOps(unittest.TestCase): self.assertRaises(TypeError, list, chain.from_iterable([2, 3])) def test_combinations(self): - self.assertRaises(TypeError, combinations, 'abc') # missing r argument + self.assertRaises(TypeError, combinations, 'abc') # missing r argument self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments self.assertRaises(TypeError, combinations, None) # pool is not iterable self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative - self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big + self.assertEqual(list(combinations('abc', 32)), []) # r > n self.assertEqual(list(combinations(range(4), 3)), [(0,1,2), (0,1,3), (0,2,3), (1,2,3)]) @@ -83,6 +83,8 @@ class TestBasicOps(unittest.TestCase): 'Pure python version shown in the docs' pool = tuple(iterable) n = len(pool) + if r > n: + return indices = range(r) yield tuple(pool[i] for i in indices) while 1: @@ -106,9 +108,9 @@ class TestBasicOps(unittest.TestCase): for n in range(7): values = [5*x-12 for x in range(n)] - for r in range(n+1): + for r in range(n+2): result = list(combinations(values, r)) - self.assertEqual(len(result), fact(n) / fact(r) / fact(n-r)) # right number of combs + self.assertEqual(len(result), 0 if r>n else fact(n) / fact(r) / fact(n-r)) # right number of combs self.assertEqual(len(result), len(set(result))) # no repeats self.assertEqual(result, sorted(result)) # lexicographic order for c in result: @@ -119,7 +121,7 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(list(c), [e for e in values if e in c]) # comb is a subsequence of the input iterable self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version - self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version + self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version # Test implementation detail: tuple re-use self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1) @@ -130,7 +132,7 @@ class TestBasicOps(unittest.TestCase): self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments self.assertRaises(TypeError, permutations, None) # pool is not iterable self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative - self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big + self.assertEqual(list(permutations('abc', 32)), []) # r > n self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None self.assertEqual(list(permutations(range(3), 2)), [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)]) @@ -140,6 +142,8 @@ class TestBasicOps(unittest.TestCase): pool = tuple(iterable) n = len(pool) r = n if r is None else r + if r > n: + return indices = range(n) cycles = range(n, n-r, -1) yield tuple(pool[i] for i in indices[:r]) @@ -168,9 +172,9 @@ class TestBasicOps(unittest.TestCase): for n in range(7): values = [5*x-12 for x in range(n)] - for r in range(n+1): + for r in range(n+2): result = list(permutations(values, r)) - self.assertEqual(len(result), fact(n) / fact(n-r)) # right number of perms + self.assertEqual(len(result), 0 if r>n else fact(n) / fact(n-r)) # right number of perms self.assertEqual(len(result), len(set(result))) # no repeats self.assertEqual(result, sorted(result)) # lexicographic order for p in result: @@ -178,7 +182,7 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(len(set(p)), r) # no duplicate elements self.assert_(all(e in values for e in p)) # elements taken from input iterable self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version - self.assertEqual(result, list(permutations2(values, r))) # matches first pure python version + self.assertEqual(result, list(permutations2(values, r))) # matches second pure python version if r == n: self.assertEqual(result, list(permutations(values, None))) # test r as None self.assertEqual(result, list(permutations(values))) # test default r @@ -1363,6 +1367,26 @@ perform as purported. >>> list(combinations_with_replacement('abc', 2)) [('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')] +>>> list(combinations_with_replacement('01', 3)) +[('0', '0', '0'), ('0', '0', '1'), ('0', '1', '1'), ('1', '1', '1')] + +>>> def combinations_with_replacement2(iterable, r): +... 'Alternate version that filters from product()' +... pool = tuple(iterable) +... n = len(pool) +... for indices in product(range(n), repeat=r): +... if sorted(indices) == list(indices): +... yield tuple(pool[i] for i in indices) + +>>> list(combinations_with_replacement('abc', 2)) == list(combinations_with_replacement2('abc', 2)) +True + +>>> list(combinations_with_replacement('01', 3)) == list(combinations_with_replacement2('01', 3)) +True + +>>> list(combinations_with_replacement('2310', 6)) == list(combinations_with_replacement2('2310', 6)) +True + >>> list(unique_everseen('AAAABBBCCDAABBB')) ['A', 'B', 'C', 'D'] |