summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_itertools.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_itertools.py')
-rw-r--r--Lib/test/test_itertools.py224
1 files changed, 204 insertions, 20 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index d44235b..335e47d 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -51,22 +51,21 @@ def fact(n):
'Factorial'
return prod(range(1, n+1))
-def permutations(iterable, r=None):
- # XXX use this until real permutations code is added
- pool = tuple(iterable)
- n = len(pool)
- r = n if r is None else r
- for indices in product(range(n), repeat=r):
- if len(set(indices)) == r:
- yield tuple(pool[i] for i in indices)
-
class TestBasicOps(unittest.TestCase):
def test_chain(self):
- self.assertEqual(list(chain('abc', 'def')), list('abcdef'))
- self.assertEqual(list(chain('abc')), list('abc'))
- self.assertEqual(list(chain('')), [])
- self.assertEqual(take(4, chain('abc', 'def')), list('abcd'))
- self.assertRaises(TypeError, list,chain(2, 3))
+
+ def chain2(*iterables):
+ 'Pure python version in the docs'
+ for it in iterables:
+ for element in it:
+ yield element
+
+ for c in (chain, chain2):
+ self.assertEqual(list(c('abc', 'def')), list('abcdef'))
+ self.assertEqual(list(c('abc')), list('abc'))
+ self.assertEqual(list(c('')), [])
+ self.assertEqual(take(4, c('abc', 'def')), list('abcd'))
+ self.assertRaises(TypeError, list,c(2, 3))
def test_chain_from_iterable(self):
self.assertEqual(list(chain.from_iterable(['abc', 'def'])), list('abcdef'))
@@ -121,6 +120,8 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(len(set(c)), r) # no duplicate elements
self.assertEqual(list(c), sorted(c)) # keep original ordering
self.assert_(all(e in values for e in c)) # elements taken from input iterable
+ self.assertEqual(list(c),
+ [e for e in values if e in c]) # comb is a subsequence of the input iterable
self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version
@@ -131,9 +132,10 @@ class TestBasicOps(unittest.TestCase):
def test_permutations(self):
self.assertRaises(TypeError, permutations) # too few arguments
self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
-## self.assertRaises(TypeError, permutations, None) # pool is not iterable
-## self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
-## self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big
+ self.assertRaises(TypeError, permutations, None) # pool is not iterable
+ self.assertRaises(ValueError, permutations, 'abc', -2) # r is negative
+ self.assertRaises(ValueError, permutations, 'abc', 32) # r is too big
+ self.assertRaises(TypeError, permutations, 'abc', 's') # r is not an int or None
self.assertEqual(list(permutations(range(3), 2)),
[(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
@@ -186,7 +188,7 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(result, list(permutations(values))) # test default r
# Test implementation detail: tuple re-use
-## self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
+ self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1)
def test_count(self):
@@ -416,12 +418,46 @@ class TestBasicOps(unittest.TestCase):
list(product(*args, **dict(repeat=r))))
self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
self.assertRaises(TypeError, product, range(6), None)
+
+ def product1(*args, **kwds):
+ pools = list(map(tuple, args)) * kwds.get('repeat', 1)
+ n = len(pools)
+ if n == 0:
+ yield ()
+ return
+ if any(len(pool) == 0 for pool in pools):
+ return
+ indices = [0] * n
+ yield tuple(pool[i] for pool, i in zip(pools, indices))
+ while 1:
+ for i in reversed(range(n)): # right to left
+ if indices[i] == len(pools[i]) - 1:
+ continue
+ indices[i] += 1
+ for j in range(i+1, n):
+ indices[j] = 0
+ yield tuple(pool[i] for pool, i in zip(pools, indices))
+ break
+ else:
+ return
+
+ def product2(*args, **kwds):
+ 'Pure python version used in docs'
+ pools = list(map(tuple, args)) * kwds.get('repeat', 1)
+ result = [[]]
+ for pool in pools:
+ result = [x+[y] for x in result for y in pool]
+ for prod in result:
+ yield tuple(prod)
+
argtypes = ['', 'abc', '', range(0), range(4), dict(a=1, b=2, c=3),
set('abcdefg'), range(11), tuple(range(13))]
for i in range(100):
args = [random.choice(argtypes) for j in range(random.randrange(5))]
expected_len = prod(map(len, args))
self.assertEqual(len(list(product(*args))), expected_len)
+ self.assertEqual(list(product(*args)), list(product1(*args)))
+ self.assertEqual(list(product(*args)), list(product2(*args)))
args = map(iter, args)
self.assertEqual(len(list(product(*args))), expected_len)
@@ -661,6 +697,81 @@ class TestBasicOps(unittest.TestCase):
self.assertRaises(StopIteration, next, f(lambda x:x, []))
self.assertRaises(StopIteration, next, f(lambda x:x, StopNow()))
+class TestExamples(unittest.TestCase):
+
+ def test_chain(self):
+ self.assertEqual(''.join(chain('ABC', 'DEF')), 'ABCDEF')
+
+ def test_chain_from_iterable(self):
+ self.assertEqual(''.join(chain.from_iterable(['ABC', 'DEF'])), 'ABCDEF')
+
+ def test_combinations(self):
+ self.assertEqual(list(combinations('ABCD', 2)),
+ [('A','B'), ('A','C'), ('A','D'), ('B','C'), ('B','D'), ('C','D')])
+ self.assertEqual(list(combinations(range(4), 3)),
+ [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
+
+ def test_count(self):
+ self.assertEqual(list(islice(count(10), 5)), [10, 11, 12, 13, 14])
+
+ def test_cycle(self):
+ self.assertEqual(list(islice(cycle('ABCD'), 12)), list('ABCDABCDABCD'))
+
+ def test_dropwhile(self):
+ self.assertEqual(list(dropwhile(lambda x: x<5, [1,4,6,4,1])), [6,4,1])
+
+ def test_groupby(self):
+ self.assertEqual([k for k, g in groupby('AAAABBBCCDAABBB')],
+ list('ABCDAB'))
+ self.assertEqual([(list(g)) for k, g in groupby('AAAABBBCCD')],
+ [list('AAAA'), list('BBB'), list('CC'), list('D')])
+
+ def test_filter(self):
+ self.assertEqual(list(filter(lambda x: x%2, range(10))), [1,3,5,7,9])
+
+ def test_filterfalse(self):
+ self.assertEqual(list(filterfalse(lambda x: x%2, range(10))), [0,2,4,6,8])
+
+ def test_map(self):
+ self.assertEqual(list(map(pow, (2,3,10), (5,2,3))), [32, 9, 1000])
+
+ def test_islice(self):
+ self.assertEqual(list(islice('ABCDEFG', 2)), list('AB'))
+ self.assertEqual(list(islice('ABCDEFG', 2, 4)), list('CD'))
+ self.assertEqual(list(islice('ABCDEFG', 2, None)), list('CDEFG'))
+ self.assertEqual(list(islice('ABCDEFG', 0, None, 2)), list('ACEG'))
+
+ def test_zip(self):
+ self.assertEqual(list(zip('ABCD', 'xy')), [('A', 'x'), ('B', 'y')])
+
+ def test_zip_longest(self):
+ self.assertEqual(list(zip_longest('ABCD', 'xy', fillvalue='-')),
+ [('A', 'x'), ('B', 'y'), ('C', '-'), ('D', '-')])
+
+ def test_permutations(self):
+ self.assertEqual(list(permutations('ABCD', 2)),
+ list(map(tuple, 'AB AC AD BA BC BD CA CB CD DA DB DC'.split())))
+ self.assertEqual(list(permutations(range(3))),
+ [(0,1,2), (0,2,1), (1,0,2), (1,2,0), (2,0,1), (2,1,0)])
+
+ def test_product(self):
+ self.assertEqual(list(product('ABCD', 'xy')),
+ list(map(tuple, 'Ax Ay Bx By Cx Cy Dx Dy'.split())))
+ self.assertEqual(list(product(range(2), repeat=3)),
+ [(0,0,0), (0,0,1), (0,1,0), (0,1,1),
+ (1,0,0), (1,0,1), (1,1,0), (1,1,1)])
+
+ def test_repeat(self):
+ self.assertEqual(list(repeat(10, 3)), [10, 10, 10])
+
+ def test_stapmap(self):
+ self.assertEqual(list(starmap(pow, [(2,5), (3,2), (10,3)])),
+ [32, 9, 1000])
+
+ def test_takewhile(self):
+ self.assertEqual(list(takewhile(lambda x: x<5, [1,4,6,4,1])), [1,4])
+
+
class TestGC(unittest.TestCase):
def makecycle(self, iterator, container):
@@ -672,6 +783,14 @@ class TestGC(unittest.TestCase):
a = []
self.makecycle(chain(a), a)
+ def test_chain_from_iterable(self):
+ a = []
+ self.makecycle(chain.from_iterable([a]), a)
+
+ def test_combinations(self):
+ a = []
+ self.makecycle(combinations([1,2,a,3], 3), a)
+
def test_cycle(self):
a = []
self.makecycle(cycle([a]*2), a)
@@ -684,6 +803,13 @@ class TestGC(unittest.TestCase):
a = []
self.makecycle(groupby([a]*2, lambda x:x), a)
+ def test_issue2246(self):
+ # Issue 2246 -- the _grouper iterator was not included in GC
+ n = 10
+ keyfunc = lambda x: x
+ for i, j in groupby(range(n), key=keyfunc):
+ keyfunc.__dict__.setdefault('x',[]).append(j)
+
def test_filter(self):
a = []
self.makecycle(filter(lambda x:True, [a]*2), a)
@@ -696,6 +822,12 @@ class TestGC(unittest.TestCase):
a = []
self.makecycle(zip([a]*2, [a]*3), a)
+ def test_zip_longest(self):
+ a = []
+ self.makecycle(zip_longest([a]*2, [a]*3), a)
+ b = [a, None]
+ self.makecycle(zip_longest([a]*2, [a]*3, fillvalue=b), a)
+
def test_map(self):
a = []
self.makecycle(map(lambda x:x, [a]*2), a)
@@ -704,6 +836,14 @@ class TestGC(unittest.TestCase):
a = []
self.makecycle(islice([a]*2, None), a)
+ def test_permutations(self):
+ a = []
+ self.makecycle(permutations([1,2,a,3], 3), a)
+
+ def test_product(self):
+ a = []
+ self.makecycle(product([1,2,a,3], repeat=3), a)
+
def test_repeat(self):
a = []
self.makecycle(repeat(a), a)
@@ -1115,7 +1255,7 @@ Samuele
... return sum(map(operator.mul, vec1, vec2))
>>> def flatten(listOfLists):
-... return list(chain(*listOfLists))
+... return list(chain.from_iterable(listOfLists))
>>> def repeatfunc(func, times=None, *args):
... "Repeat calls to func with specified arguments."
@@ -1134,6 +1274,38 @@ Samuele
... pass
... return zip(a, b)
+>>> def grouper(n, iterable, fillvalue=None):
+... "grouper(3, 'abcdefg', 'x') --> ('a','b','c'), ('d','e','f'), ('g','x','x')"
+... args = [iter(iterable)] * n
+... kwds = dict(fillvalue=fillvalue)
+... return zip_longest(*args, **kwds)
+
+>>> def roundrobin(*iterables):
+... "roundrobin('abc', 'd', 'ef') --> 'a', 'd', 'e', 'b', 'f', 'c'"
+... # Recipe credited to George Sakkis
+... pending = len(iterables)
+... nexts = cycle(iter(it).__next__ for it in iterables)
+... while pending:
+... try:
+... for next in nexts:
+... yield next()
+... except StopIteration:
+... pending -= 1
+... nexts = cycle(islice(nexts, pending))
+
+>>> def powerset(iterable):
+... "powerset('ab') --> set([]), set(['a']), set(['b']), set(['a', 'b'])"
+... # Recipe credited to Eric Raymond
+... pairs = [(2**i, x) for i, x in enumerate(iterable)]
+... for n in range(2**len(pairs)):
+... yield set(x for m, x in pairs if m&n)
+
+>>> def compress(data, selectors):
+... "compress('abcdef', [1,0,1,0,1,1]) --> a c e f"
+... for d, s in zip(data, selectors):
+... if s:
+... yield d
+
This is not part of the examples but it tests to make sure the definitions
perform as purported.
@@ -1199,6 +1371,18 @@ False
>>> dotproduct([1,2,3], [4,5,6])
32
+>>> list(grouper(3, 'abcdefg', 'x'))
+[('a', 'b', 'c'), ('d', 'e', 'f'), ('g', 'x', 'x')]
+
+>>> list(roundrobin('abc', 'd', 'ef'))
+['a', 'd', 'e', 'b', 'f', 'c']
+
+>>> list(map(sorted, powerset('ab')))
+[[], ['a'], ['b'], ['a', 'b']]
+
+>>> list(compress('abcdef', [1,0,1,0,1,1]))
+['a', 'c', 'e', 'f']
+
"""
__test__ = {'libreftest' : libreftest}
@@ -1206,7 +1390,7 @@ __test__ = {'libreftest' : libreftest}
def test_main(verbose=None):
test_classes = (TestBasicOps, TestVariousIteratorArgs, TestGC,
RegressionTests, LengthTransparency,
- SubclassWithKwargsTest)
+ SubclassWithKwargsTest, TestExamples)
test_support.run_unittest(*test_classes)
# verify reference counting