diff options
author | Raymond Hettinger <python@rcn.com> | 2008-03-04 04:17:08 (GMT) |
---|---|---|
committer | Raymond Hettinger <python@rcn.com> | 2008-03-04 04:17:08 (GMT) |
commit | d553d856e7471b4e6c9ede6a445b813a7e468d4f (patch) | |
tree | 9f47f542aae79d4eda81e4bc7443c996140e75c9 /Lib/test/test_itertools.py | |
parent | 378586a844cd1cc8346482b515a24740eedbb59e (diff) | |
download | cpython-d553d856e7471b4e6c9ede6a445b813a7e468d4f.zip cpython-d553d856e7471b4e6c9ede6a445b813a7e468d4f.tar.gz cpython-d553d856e7471b4e6c9ede6a445b813a7e468d4f.tar.bz2 |
Beef-up docs and tests for itertools. Fix-up end-case for product().
Diffstat (limited to 'Lib/test/test_itertools.py')
-rw-r--r-- | Lib/test/test_itertools.py | 118 |
1 files changed, 112 insertions, 6 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 087570c..4197989 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -40,9 +40,21 @@ def take(n, seq): 'Convenience function for partially consuming a long of infinite iterable' return list(islice(seq, n)) +def prod(iterable): + return reduce(operator.mul, iterable, 1) + def fact(n): 'Factorial' - return reduce(operator.mul, range(1, n+1), 1) + return prod(range(1, n+1)) + +def permutations(iterable, r=None): + # XXX use this until real permutations code is added + pool = tuple(iterable) + n = len(pool) + r = n if r is None else r + for indices in product(range(n), repeat=r): + if len(set(indices)) == r: + yield tuple(pool[i] for i in indices) class TestBasicOps(unittest.TestCase): def test_chain(self): @@ -62,11 +74,38 @@ class TestBasicOps(unittest.TestCase): def test_combinations(self): self.assertRaises(TypeError, combinations, 'abc') # missing r argument self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments + self.assertRaises(TypeError, combinations, None) # pool is not iterable self.assertRaises(ValueError, combinations, 'abc', -2) # r is negative self.assertRaises(ValueError, combinations, 'abc', 32) # r is too big self.assertEqual(list(combinations(range(4), 3)), [(0,1,2), (0,1,3), (0,2,3), (1,2,3)]) - for n in range(8): + + def combinations1(iterable, r): + 'Pure python version shown in the docs' + pool = tuple(iterable) + n = len(pool) + indices = range(r) + yield tuple(pool[i] for i in indices) + while 1: + for i in reversed(range(r)): + if indices[i] != i + n - r: + break + else: + return + indices[i] += 1 + for j in range(i+1, r): + indices[j] = indices[j-1] + 1 + yield tuple(pool[i] for i in indices) + + def combinations2(iterable, r): + 'Pure python version shown in the docs' + pool = tuple(iterable) + n = len(pool) + for indices in permutations(range(n), r): + if sorted(indices) == list(indices): + yield tuple(pool[i] for i in indices) + + for n in range(7): values = [5*x-12 for x in range(n)] for r in range(n+1): result = list(combinations(values, r)) @@ -78,6 +117,73 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(len(set(c)), r) # no duplicate elements self.assertEqual(list(c), sorted(c)) # keep original ordering self.assert_(all(e in values for e in c)) # elements taken from input iterable + self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version + self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version + + # Test implementation detail: tuple re-use + self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1) + self.assertNotEqual(len(set(map(id, list(combinations('abcde', 3))))), 1) + + def test_permutations(self): + self.assertRaises(TypeError, permutations) # too few arguments + self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments +## self.assertRaises(TypeError, permutations, None) # pool is not iterable +## self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative +## self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big + self.assertEqual(list(permutations(range(3), 2)), + [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)]) + + def permutations1(iterable, r=None): + 'Pure python version shown in the docs' + pool = tuple(iterable) + n = len(pool) + r = n if r is None else r + indices = range(n) + cycles = range(n-r+1, n+1)[::-1] + yield tuple(pool[i] for i in indices[:r]) + while n: + for i in reversed(range(r)): + cycles[i] -= 1 + if cycles[i] == 0: + indices[i:] = indices[i+1:] + indices[i:i+1] + cycles[i] = n - i + else: + j = cycles[i] + indices[i], indices[-j] = indices[-j], indices[i] + yield tuple(pool[i] for i in indices[:r]) + break + else: + return + + def permutations2(iterable, r=None): + 'Pure python version shown in the docs' + pool = tuple(iterable) + n = len(pool) + r = n if r is None else r + for indices in product(range(n), repeat=r): + if len(set(indices)) == r: + yield tuple(pool[i] for i in indices) + + for n in range(7): + values = [5*x-12 for x in range(n)] + for r in range(n+1): + result = list(permutations(values, r)) + self.assertEqual(len(result), fact(n) / fact(n-r)) # right number of perms + self.assertEqual(len(result), len(set(result))) # no repeats + self.assertEqual(result, sorted(result)) # lexicographic order + for p in result: + self.assertEqual(len(p), r) # r-length permutations + self.assertEqual(len(set(p)), r) # no duplicate elements + self.assert_(all(e in values for e in p)) # elements taken from input iterable + self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version + self.assertEqual(result, list(permutations2(values, r))) # matches first pure python version + if r == n: + self.assertEqual(result, list(permutations(values, None))) # test r as None + self.assertEqual(result, list(permutations(values))) # test default r + + # Test implementation detail: tuple re-use +## self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1) + self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1) def test_count(self): self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)]) @@ -288,7 +394,7 @@ class TestBasicOps(unittest.TestCase): def test_product(self): for args, result in [ - ([], []), # zero iterables ??? is this correct + ([], [()]), # zero iterables (['ab'], [('a',), ('b',)]), # one iterable ([range(2), range(3)], [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)]), # two iterables ([range(0), range(2), range(3)], []), # first iterable with zero length @@ -305,10 +411,10 @@ class TestBasicOps(unittest.TestCase): set('abcdefg'), range(11), tuple(range(13))] for i in range(100): args = [random.choice(argtypes) for j in range(random.randrange(5))] - n = reduce(operator.mul, map(len, args), 1) if args else 0 - self.assertEqual(len(list(product(*args))), n) + expected_len = prod(map(len, args)) + self.assertEqual(len(list(product(*args))), expected_len) args = map(iter, args) - self.assertEqual(len(list(product(*args))), n) + self.assertEqual(len(list(product(*args))), expected_len) # Test implementation detail: tuple re-use self.assertEqual(len(set(map(id, product('abc', 'def')))), 1) |