diff options
author | Raymond Hettinger <rhettinger@users.noreply.github.com> | 2022-10-17 23:53:45 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-17 23:53:45 (GMT) |
commit | de3ece769a8bc10c207a648c8a446f520504fa7e (patch) | |
tree | d44b4a58b8267d87e16289d781b0f65cefaaccf2 /Lib/test | |
parent | 70732d8a4c98cdf3cc9efa5241ce33fb9bc323ca (diff) | |
download | cpython-de3ece769a8bc10c207a648c8a446f520504fa7e.zip cpython-de3ece769a8bc10c207a648c8a446f520504fa7e.tar.gz cpython-de3ece769a8bc10c207a648c8a446f520504fa7e.tar.bz2 |
GH-98363: Add itertools.batched() (GH-98364)
Diffstat (limited to 'Lib/test')
-rw-r--r-- | Lib/test/test_itertools.py | 79 |
1 files changed, 79 insertions, 0 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index f469bfe..c0e3571 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -159,6 +159,44 @@ class TestBasicOps(unittest.TestCase): with self.assertRaises(TypeError): list(accumulate([10, 20], 100)) + def test_batched(self): + self.assertEqual(list(batched('ABCDEFG', 3)), + [['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]) + self.assertEqual(list(batched('ABCDEFG', 2)), + [['A', 'B'], ['C', 'D'], ['E', 'F'], ['G']]) + self.assertEqual(list(batched('ABCDEFG', 1)), + [['A'], ['B'], ['C'], ['D'], ['E'], ['F'], ['G']]) + + with self.assertRaises(TypeError): # Too few arguments + list(batched('ABCDEFG')) + with self.assertRaises(TypeError): + list(batched('ABCDEFG', 3, None)) # Too many arguments + with self.assertRaises(TypeError): + list(batched(None, 3)) # Non-iterable input + with self.assertRaises(TypeError): + list(batched('ABCDEFG', 'hello')) # n is a string + with self.assertRaises(ValueError): + list(batched('ABCDEFG', 0)) # n is zero + with self.assertRaises(ValueError): + list(batched('ABCDEFG', -1)) # n is negative + + data = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' + for n in range(1, 6): + for i in range(len(data)): + s = data[:i] + batches = list(batched(s, n)) + with self.subTest(s=s, n=n, batches=batches): + # Order is preserved and no data is lost + self.assertEqual(''.join(chain(*batches)), s) + # Each batch is an exact list + self.assertTrue(all(type(batch) is list for batch in batches)) + # All but the last batch is of size n + if batches: + last_batch = batches.pop() + self.assertTrue(all(len(batch) == n for batch in batches)) + self.assertTrue(len(last_batch) <= n) + batches.append(last_batch) + def test_chain(self): def chain2(*iterables): @@ -1737,6 +1775,31 @@ class TestExamples(unittest.TestCase): class TestPurePythonRoughEquivalents(unittest.TestCase): + def test_batched_recipe(self): + def batched_recipe(iterable, n): + "Batch data into lists of length n. The last batch may be shorter." + # batched('ABCDEFG', 3) --> ABC DEF G + if n < 1: + raise ValueError('n must be at least one') + it = iter(iterable) + while (batch := list(islice(it, n))): + yield batch + + for iterable, n in product( + ['', 'a', 'ab', 'abc', 'abcd', 'abcde', 'abcdef', 'abcdefg', None], + [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, None]): + with self.subTest(iterable=iterable, n=n): + try: + e1, r1 = None, list(batched(iterable, n)) + except Exception as e: + e1, r1 = type(e), None + try: + e2, r2 = None, list(batched_recipe(iterable, n)) + except Exception as e: + e2, r2 = type(e), None + self.assertEqual(r1, r2) + self.assertEqual(e1, e2) + @staticmethod def islice(iterable, *args): s = slice(*args) @@ -1788,6 +1851,10 @@ class TestGC(unittest.TestCase): a = [] self.makecycle(accumulate([1,2,a,3]), a) + def test_batched(self): + a = [] + self.makecycle(batched([1,2,a,3], 2), a) + def test_chain(self): a = [] self.makecycle(chain(a), a) @@ -1972,6 +2039,18 @@ class TestVariousIteratorArgs(unittest.TestCase): self.assertRaises(TypeError, accumulate, N(s)) self.assertRaises(ZeroDivisionError, list, accumulate(E(s))) + def test_batched(self): + s = 'abcde' + r = [['a', 'b'], ['c', 'd'], ['e']] + n = 2 + for g in (G, I, Ig, L, R): + with self.subTest(g=g): + self.assertEqual(list(batched(g(s), n)), r) + self.assertEqual(list(batched(S(s), 2)), []) + self.assertRaises(TypeError, batched, X(s), 2) + self.assertRaises(TypeError, batched, N(s), 2) + self.assertRaises(ZeroDivisionError, list, batched(E(s), 2)) + def test_chain(self): for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)): for g in (G, I, Ig, S, L, R): |