summaryrefslogtreecommitdiffstats
path: root/Lib/test
diff options
context:
space:
mode:
authorRaymond Hettinger <rhettinger@users.noreply.github.com>2022-10-17 23:53:45 (GMT)
committerGitHub <noreply@github.com>2022-10-17 23:53:45 (GMT)
commitde3ece769a8bc10c207a648c8a446f520504fa7e (patch)
treed44b4a58b8267d87e16289d781b0f65cefaaccf2 /Lib/test
parent70732d8a4c98cdf3cc9efa5241ce33fb9bc323ca (diff)
downloadcpython-de3ece769a8bc10c207a648c8a446f520504fa7e.zip
cpython-de3ece769a8bc10c207a648c8a446f520504fa7e.tar.gz
cpython-de3ece769a8bc10c207a648c8a446f520504fa7e.tar.bz2
GH-98363: Add itertools.batched() (GH-98364)
Diffstat (limited to 'Lib/test')
-rw-r--r--Lib/test/test_itertools.py79
1 files changed, 79 insertions, 0 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index f469bfe..c0e3571 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -159,6 +159,44 @@ class TestBasicOps(unittest.TestCase):
with self.assertRaises(TypeError):
list(accumulate([10, 20], 100))
+ def test_batched(self):
+ self.assertEqual(list(batched('ABCDEFG', 3)),
+ [['A', 'B', 'C'], ['D', 'E', 'F'], ['G']])
+ self.assertEqual(list(batched('ABCDEFG', 2)),
+ [['A', 'B'], ['C', 'D'], ['E', 'F'], ['G']])
+ self.assertEqual(list(batched('ABCDEFG', 1)),
+ [['A'], ['B'], ['C'], ['D'], ['E'], ['F'], ['G']])
+
+ with self.assertRaises(TypeError): # Too few arguments
+ list(batched('ABCDEFG'))
+ with self.assertRaises(TypeError):
+ list(batched('ABCDEFG', 3, None)) # Too many arguments
+ with self.assertRaises(TypeError):
+ list(batched(None, 3)) # Non-iterable input
+ with self.assertRaises(TypeError):
+ list(batched('ABCDEFG', 'hello')) # n is a string
+ with self.assertRaises(ValueError):
+ list(batched('ABCDEFG', 0)) # n is zero
+ with self.assertRaises(ValueError):
+ list(batched('ABCDEFG', -1)) # n is negative
+
+ data = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
+ for n in range(1, 6):
+ for i in range(len(data)):
+ s = data[:i]
+ batches = list(batched(s, n))
+ with self.subTest(s=s, n=n, batches=batches):
+ # Order is preserved and no data is lost
+ self.assertEqual(''.join(chain(*batches)), s)
+ # Each batch is an exact list
+ self.assertTrue(all(type(batch) is list for batch in batches))
+ # All but the last batch is of size n
+ if batches:
+ last_batch = batches.pop()
+ self.assertTrue(all(len(batch) == n for batch in batches))
+ self.assertTrue(len(last_batch) <= n)
+ batches.append(last_batch)
+
def test_chain(self):
def chain2(*iterables):
@@ -1737,6 +1775,31 @@ class TestExamples(unittest.TestCase):
class TestPurePythonRoughEquivalents(unittest.TestCase):
+ def test_batched_recipe(self):
+ def batched_recipe(iterable, n):
+ "Batch data into lists of length n. The last batch may be shorter."
+ # batched('ABCDEFG', 3) --> ABC DEF G
+ if n < 1:
+ raise ValueError('n must be at least one')
+ it = iter(iterable)
+ while (batch := list(islice(it, n))):
+ yield batch
+
+ for iterable, n in product(
+ ['', 'a', 'ab', 'abc', 'abcd', 'abcde', 'abcdef', 'abcdefg', None],
+ [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, None]):
+ with self.subTest(iterable=iterable, n=n):
+ try:
+ e1, r1 = None, list(batched(iterable, n))
+ except Exception as e:
+ e1, r1 = type(e), None
+ try:
+ e2, r2 = None, list(batched_recipe(iterable, n))
+ except Exception as e:
+ e2, r2 = type(e), None
+ self.assertEqual(r1, r2)
+ self.assertEqual(e1, e2)
+
@staticmethod
def islice(iterable, *args):
s = slice(*args)
@@ -1788,6 +1851,10 @@ class TestGC(unittest.TestCase):
a = []
self.makecycle(accumulate([1,2,a,3]), a)
+ def test_batched(self):
+ a = []
+ self.makecycle(batched([1,2,a,3], 2), a)
+
def test_chain(self):
a = []
self.makecycle(chain(a), a)
@@ -1972,6 +2039,18 @@ class TestVariousIteratorArgs(unittest.TestCase):
self.assertRaises(TypeError, accumulate, N(s))
self.assertRaises(ZeroDivisionError, list, accumulate(E(s)))
+ def test_batched(self):
+ s = 'abcde'
+ r = [['a', 'b'], ['c', 'd'], ['e']]
+ n = 2
+ for g in (G, I, Ig, L, R):
+ with self.subTest(g=g):
+ self.assertEqual(list(batched(g(s), n)), r)
+ self.assertEqual(list(batched(S(s), 2)), [])
+ self.assertRaises(TypeError, batched, X(s), 2)
+ self.assertRaises(TypeError, batched, N(s), 2)
+ self.assertRaises(ZeroDivisionError, list, batched(E(s), 2))
+
def test_chain(self):
for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
for g in (G, I, Ig, S, L, R):