diff options
Diffstat (limited to 'Lib/test/test_itertools.py')
-rw-r--r-- | Lib/test/test_itertools.py | 224 |
1 files changed, 204 insertions, 20 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index d44235b..335e47d 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -51,22 +51,21 @@ def fact(n): 'Factorial' return prod(range(1, n+1)) -def permutations(iterable, r=None): - # XXX use this until real permutations code is added - pool = tuple(iterable) - n = len(pool) - r = n if r is None else r - for indices in product(range(n), repeat=r): - if len(set(indices)) == r: - yield tuple(pool[i] for i in indices) - class TestBasicOps(unittest.TestCase): def test_chain(self): - self.assertEqual(list(chain('abc', 'def')), list('abcdef')) - self.assertEqual(list(chain('abc')), list('abc')) - self.assertEqual(list(chain('')), []) - self.assertEqual(take(4, chain('abc', 'def')), list('abcd')) - self.assertRaises(TypeError, list,chain(2, 3)) + + def chain2(*iterables): + 'Pure python version in the docs' + for it in iterables: + for element in it: + yield element + + for c in (chain, chain2): + self.assertEqual(list(c('abc', 'def')), list('abcdef')) + self.assertEqual(list(c('abc')), list('abc')) + self.assertEqual(list(c('')), []) + self.assertEqual(take(4, c('abc', 'def')), list('abcd')) + self.assertRaises(TypeError, list,c(2, 3)) def test_chain_from_iterable(self): self.assertEqual(list(chain.from_iterable(['abc', 'def'])), list('abcdef')) @@ -121,6 +120,8 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(len(set(c)), r) # no duplicate elements self.assertEqual(list(c), sorted(c)) # keep original ordering self.assert_(all(e in values for e in c)) # elements taken from input iterable + self.assertEqual(list(c), + [e for e in values if e in c]) # comb is a subsequence of the input iterable self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version @@ -131,9 +132,10 @@ class TestBasicOps(unittest.TestCase): def test_permutations(self): self.assertRaises(TypeError, permutations) # too few arguments self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments -## self.assertRaises(TypeError, permutations, None) # pool is not iterable -## self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative -## self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big + self.assertRaises(TypeError, permutations, None) # pool is not iterable + self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative + self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big + self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None self.assertEqual(list(permutations(range(3), 2)), [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)]) @@ -186,7 +188,7 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(result, list(permutations(values))) # test default r # Test implementation detail: tuple re-use -## self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1) + self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1) self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1) def test_count(self): @@ -416,12 +418,46 @@ class TestBasicOps(unittest.TestCase): list(product(*args, **dict(repeat=r)))) self.assertEqual(len(list(product(*[range(7)]*6))), 7**6) self.assertRaises(TypeError, product, range(6), None) + + def product1(*args, **kwds): + pools = list(map(tuple, args)) * kwds.get('repeat', 1) + n = len(pools) + if n == 0: + yield () + return + if any(len(pool) == 0 for pool in pools): + return + indices = [0] * n + yield tuple(pool[i] for pool, i in zip(pools, indices)) + while 1: + for i in reversed(range(n)): # right to left + if indices[i] == len(pools[i]) - 1: + continue + indices[i] += 1 + for j in range(i+1, n): + indices[j] = 0 + yield tuple(pool[i] for pool, i in zip(pools, indices)) + break + else: + return + + def product2(*args, **kwds): + 'Pure python version used in docs' + pools = list(map(tuple, args)) * kwds.get('repeat', 1) + result = [[]] + for pool in pools: + result = [x+[y] for x in result for y in pool] + for prod in result: + yield tuple(prod) + argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3), set('abcdefg'), range(11), tuple(range(13))] for i in range(100): args = [random.choice(argtypes) for j in range(random.randrange(5))] expected_len = prod(map(len, args)) self.assertEqual(len(list(product(*args))), expected_len) + self.assertEqual(list(product(*args)), list(product1(*args))) + self.assertEqual(list(product(*args)), list(product2(*args))) args = map(iter, args) self.assertEqual(len(list(product(*args))), expected_len) @@ -661,6 +697,81 @@ class TestBasicOps(unittest.TestCase): self.assertRaises(StopIteration, next, f(lambda x:x, [])) self.assertRaises(StopIteration, next, f(lambda x:x, StopNow())) +class TestExamples(unittest.TestCase): + + def test_chain(self): + self.assertEqual(''.join(chain('ABC', 'DEF')), 'ABCDEF') + + def test_chain_from_iterable(self): + self.assertEqual(''.join(chain.from_iterable(['ABC', 'DEF'])), 'ABCDEF') + + def test_combinations(self): + self.assertEqual(list(combinations('ABCD', 2)), + [('A','B'), ('A','C'), ('A','D'), ('B','C'), ('B','D'), ('C','D')]) + self.assertEqual(list(combinations(range(4), 3)), + [(0,1,2), (0,1,3), (0,2,3), (1,2,3)]) + + def test_count(self): + self.assertEqual(list(islice(count(10), 5)), [10, 11, 12, 13, 14]) + + def test_cycle(self): + self.assertEqual(list(islice(cycle('ABCD'), 12)), list('ABCDABCDABCD')) + + def test_dropwhile(self): + self.assertEqual(list(dropwhile(lambda x: x<5, [1,4,6,4,1])), [6,4,1]) + + def test_groupby(self): + self.assertEqual([k for k, g in groupby('AAAABBBCCDAABBB')], + list('ABCDAB')) + self.assertEqual([(list(g)) for k, g in groupby('AAAABBBCCD')], + [list('AAAA'), list('BBB'), list('CC'), list('D')]) + + def test_filter(self): + self.assertEqual(list(filter(lambda x: x%2, range(10))), [1,3,5,7,9]) + + def test_filterfalse(self): + self.assertEqual(list(filterfalse(lambda x: x%2, range(10))), [0,2,4,6,8]) + + def test_map(self): + self.assertEqual(list(map(pow, (2,3,10), (5,2,3))), [32, 9, 1000]) + + def test_islice(self): + self.assertEqual(list(islice('ABCDEFG', 2)), list('AB')) + self.assertEqual(list(islice('ABCDEFG', 2, 4)), list('CD')) + self.assertEqual(list(islice('ABCDEFG', 2, None)), list('CDEFG')) + self.assertEqual(list(islice('ABCDEFG', 0, None, 2)), list('ACEG')) + + def test_zip(self): + self.assertEqual(list(zip('ABCD', 'xy')), [('A', 'x'), ('B', 'y')]) + + def test_zip_longest(self): + self.assertEqual(list(zip_longest('ABCD', 'xy', fillvalue='-')), + [('A', 'x'), ('B', 'y'), ('C', '-'), ('D', '-')]) + + def test_permutations(self): + self.assertEqual(list(permutations('ABCD', 2)), + list(map(tuple, 'AB AC AD BA BC BD CA CB CD DA DB DC'.split()))) + self.assertEqual(list(permutations(range(3))), + [(0,1,2), (0,2,1), (1,0,2), (1,2,0), (2,0,1), (2,1,0)]) + + def test_product(self): + self.assertEqual(list(product('ABCD', 'xy')), + list(map(tuple, 'Ax Ay Bx By Cx Cy Dx Dy'.split()))) + self.assertEqual(list(product(range(2), repeat=3)), + [(0,0,0), (0,0,1), (0,1,0), (0,1,1), + (1,0,0), (1,0,1), (1,1,0), (1,1,1)]) + + def test_repeat(self): + self.assertEqual(list(repeat(10, 3)), [10, 10, 10]) + + def test_stapmap(self): + self.assertEqual(list(starmap(pow, [(2,5), (3,2), (10,3)])), + [32, 9, 1000]) + + def test_takewhile(self): + self.assertEqual(list(takewhile(lambda x: x<5, [1,4,6,4,1])), [1,4]) + + class TestGC(unittest.TestCase): def makecycle(self, iterator, container): @@ -672,6 +783,14 @@ class TestGC(unittest.TestCase): a = [] self.makecycle(chain(a), a) + def test_chain_from_iterable(self): + a = [] + self.makecycle(chain.from_iterable([a]), a) + + def test_combinations(self): + a = [] + self.makecycle(combinations([1,2,a,3], 3), a) + def test_cycle(self): a = [] self.makecycle(cycle([a]*2), a) @@ -684,6 +803,13 @@ class TestGC(unittest.TestCase): a = [] self.makecycle(groupby([a]*2, lambda x:x), a) + def test_issue2246(self): + # Issue 2246 -- the _grouper iterator was not included in GC + n = 10 + keyfunc = lambda x: x + for i, j in groupby(range(n), key=keyfunc): + keyfunc.__dict__.setdefault('x',[]).append(j) + def test_filter(self): a = [] self.makecycle(filter(lambda x:True, [a]*2), a) @@ -696,6 +822,12 @@ class TestGC(unittest.TestCase): a = [] self.makecycle(zip([a]*2, [a]*3), a) + def test_zip_longest(self): + a = [] + self.makecycle(zip_longest([a]*2, [a]*3), a) + b = [a, None] + self.makecycle(zip_longest([a]*2, [a]*3, fillvalue=b), a) + def test_map(self): a = [] self.makecycle(map(lambda x:x, [a]*2), a) @@ -704,6 +836,14 @@ class TestGC(unittest.TestCase): a = [] self.makecycle(islice([a]*2, None), a) + def test_permutations(self): + a = [] + self.makecycle(permutations([1,2,a,3], 3), a) + + def test_product(self): + a = [] + self.makecycle(product([1,2,a,3], repeat=3), a) + def test_repeat(self): a = [] self.makecycle(repeat(a), a) @@ -1115,7 +1255,7 @@ Samuele ... return sum(map(operator.mul, vec1, vec2)) >>> def flatten(listOfLists): -... return list(chain(*listOfLists)) +... return list(chain.from_iterable(listOfLists)) >>> def repeatfunc(func, times=None, *args): ... "Repeat calls to func with specified arguments." @@ -1134,6 +1274,38 @@ Samuele ... pass ... return zip(a, b) +>>> def grouper(n, iterable, fillvalue=None): +... "grouper(3, 'abcdefg', 'x') --> ('a','b','c'), ('d','e','f'), ('g','x','x')" +... args = [iter(iterable)] * n +... kwds = dict(fillvalue=fillvalue) +... return zip_longest(*args, **kwds) + +>>> def roundrobin(*iterables): +... "roundrobin('abc', 'd', 'ef') --> 'a', 'd', 'e', 'b', 'f', 'c'" +... # Recipe credited to George Sakkis +... pending = len(iterables) +... nexts = cycle(iter(it).__next__ for it in iterables) +... while pending: +... try: +... for next in nexts: +... yield next() +... except StopIteration: +... pending -= 1 +... nexts = cycle(islice(nexts, pending)) + +>>> def powerset(iterable): +... "powerset('ab') --> set([]), set(['a']), set(['b']), set(['a', 'b'])" +... # Recipe credited to Eric Raymond +... pairs = [(2**i, x) for i, x in enumerate(iterable)] +... for n in range(2**len(pairs)): +... yield set(x for m, x in pairs if m&n) + +>>> def compress(data, selectors): +... "compress('abcdef', [1,0,1,0,1,1]) --> a c e f" +... for d, s in zip(data, selectors): +... if s: +... yield d + This is not part of the examples but it tests to make sure the definitions perform as purported. @@ -1199,6 +1371,18 @@ False >>> dotproduct([1,2,3], [4,5,6]) 32 +>>> list(grouper(3, 'abcdefg', 'x')) +[('a', 'b', 'c'), ('d', 'e', 'f'), ('g', 'x', 'x')] + +>>> list(roundrobin('abc', 'd', 'ef')) +['a', 'd', 'e', 'b', 'f', 'c'] + +>>> list(map(sorted, powerset('ab'))) +[[], ['a'], ['b'], ['a', 'b']] + +>>> list(compress('abcdef', [1,0,1,0,1,1])) +['a', 'c', 'e', 'f'] + """ __test__ = {'libreftest' : libreftest} @@ -1206,7 +1390,7 @@ __test__ = {'libreftest' : libreftest} def test_main(verbose=None): test_classes = (TestBasicOps, TestVariousIteratorArgs, TestGC, RegressionTests, LengthTransparency, - SubclassWithKwargsTest) + SubclassWithKwargsTest, TestExamples) test_support.run_unittest(*test_classes) # verify reference counting |