summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Raymond Hettinger <rhettinger@users.noreply.github.com> 2023-12-16 15:13:50 (GMT)
committer GitHub <noreply@github.com> 2023-12-16 15:13:50 (GMT)
commit 1583c40be938d2caf363c126976bc8757df90b13 (patch)
tree 13c6360133c2dcb3d58c255fba500e9090e383a5
parent fe479fb8a979894224a4d279d1e46a5cdb108fa4 (diff)
download cpython-1583c40be938d2caf363c126976bc8757df90b13.zip
cpython-1583c40be938d2caf363c126976bc8757df90b13.tar.gz
cpython-1583c40be938d2caf363c126976bc8757df90b13.tar.bz2
gh-113202: Add a strict option to itertools.batched() (gh-113203)
-rw-r--r-- Doc/library/itertools.rst 18
-rw-r--r-- Lib/test/test_itertools.py 4
-rw-r--r-- Misc/NEWS.d/next/Library/2023-12-15-18-10-26.gh-issue-113202.xv_Ww8.rst 1
-rw-r--r-- Modules/clinic/itertoolsmodule.c.h 32
-rw-r--r-- Modules/itertoolsmodule.c 29
5 files changed, 60 insertions, 24 deletions
diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst
index 6bcda30..c016fb7 100644
--- a/Doc/library/itertools.rst
+++ b/Doc/library/itertools.rst
@@ -164,11 +164,14 @@ loops that truncate the stream.
Added the optional *initial* parameter.
-.. function:: batched(iterable, n)
+.. function:: batched(iterable, n, *, strict=False)
Batch data from the *iterable* into tuples of length *n*. The last
batch may be shorter than *n*.
+ If *strict* is true, a :exc:`ValueError` is raised if the final
+ batch is shorter than *n*.
+
Loops over the input iterable and accumulates data into tuples up to
size *n*. The input is consumed lazily, just enough to fill a batch.
The result is yielded as soon as the batch is full or when the input
@@ -190,16 +193,21 @@ loops that truncate the stream.
Roughly equivalent to::
- def batched(iterable, n):
+ def batched(iterable, n, *, strict=False):
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError('n must be at least one')
it = iter(iterable)
while batch := tuple(islice(it, n)):
+ if strict and len(batch) != n:
+ raise ValueError('batched(): incomplete batch')
yield batch
.. versionadded:: 3.12
+ .. versionchanged:: 3.13
+ Added the *strict* option.
+
.. function:: chain(*iterables)
@@ -1039,7 +1047,7 @@ The following recipes have a more mathematical flavor:
def reshape(matrix, cols):
"Reshape a 2-D matrix to have a given number of columns."
# reshape([(0, 1), (2, 3), (4, 5)], 3) --> (0, 1, 2), (3, 4, 5)
- return batched(chain.from_iterable(matrix), cols)
+ return batched(chain.from_iterable(matrix), cols, strict=True)
def transpose(matrix):
"Swap the rows and columns of a 2-D matrix."
@@ -1270,6 +1278,10 @@ The following recipes have a more mathematical flavor:
[(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11)]
>>> list(reshape(M, 4))
[(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)]
+ >>> list(reshape(M, 5))
+ Traceback (most recent call last):
+ ...
+ ValueError: batched(): incomplete batch
>>> list(reshape(M, 6))
[(0, 1, 2, 3, 4, 5), (6, 7, 8, 9, 10, 11)]
>>> list(reshape(M, 12))
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index 705e880..9af0730 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -187,7 +187,11 @@ class TestBasicOps(unittest.TestCase):
[('A', 'B'), ('C', 'D'), ('E', 'F'), ('G',)])
self.assertEqual(list(batched('ABCDEFG', 1)),
[('A',), ('B',), ('C',), ('D',), ('E',), ('F',), ('G',)])
+ self.assertEqual(list(batched('ABCDEF', 2, strict=True)),
+ [('A', 'B'), ('C', 'D'), ('E', 'F')])
+ with self.assertRaises(ValueError): # Incomplete batch when strict
+ list(batched('ABCDEFG', 3, strict=True))
with self.assertRaises(TypeError): # Too few arguments
list(batched('ABCDEFG'))
with self.assertRaises(TypeError):
diff --git a/Misc/NEWS.d/next/Library/2023-12-15-18-10-26.gh-issue-113202.xv_Ww8.rst b/Misc/NEWS.d/next/Library/2023-12-15-18-10-26.gh-issue-113202.xv_Ww8.rst
new file mode 100644
index 0000000..44f26ae
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2023-12-15-18-10-26.gh-issue-113202.xv_Ww8.rst
@@ -0,0 +1 @@
+Add a ``strict`` option to ``batched()`` in the ``itertools`` module.
diff --git a/Modules/clinic/itertoolsmodule.c.h b/Modules/clinic/itertoolsmodule.c.h
index fa2c5e0..3ec4799 100644
--- a/Modules/clinic/itertoolsmodule.c.h
+++ b/Modules/clinic/itertoolsmodule.c.h
@@ -10,7 +10,7 @@ preserve
#include "pycore_modsupport.h" // _PyArg_UnpackKeywords()
PyDoc_STRVAR(batched_new__doc__,
-"batched(iterable, n)\n"
+"batched(iterable, n, *, strict=False)\n"
"--\n"
"\n"
"Batch data into tuples of length n. The last batch may be shorter than n.\n"
@@ -25,10 +25,14 @@ PyDoc_STRVAR(batched_new__doc__,
" ...\n"
" (\'A\', \'B\', \'C\')\n"
" (\'D\', \'E\', \'F\')\n"
-" (\'G\',)");
+" (\'G\',)\n"
+"\n"
+"If \"strict\" is True, raises a ValueError if the final batch is shorter\n"
+"than n.");
static PyObject *
-batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n);
+batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n,
+ int strict);
static PyObject *
batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
@@ -36,14 +40,14 @@ batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
PyObject *return_value = NULL;
#if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE)
- #define NUM_KEYWORDS 2
+ #define NUM_KEYWORDS 3
static struct {
PyGC_Head _this_is_not_used;
PyObject_VAR_HEAD
PyObject *ob_item[NUM_KEYWORDS];
} _kwtuple = {
.ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS)
- .ob_item = { &_Py_ID(iterable), &_Py_ID(n), },
+ .ob_item = { &_Py_ID(iterable), &_Py_ID(n), &_Py_ID(strict), },
};
#undef NUM_KEYWORDS
#define KWTUPLE (&_kwtuple.ob_base.ob_base)
@@ -52,18 +56,20 @@ batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
# define KWTUPLE NULL
#endif // !Py_BUILD_CORE
- static const char * const _keywords[] = {"iterable", "n", NULL};
+ static const char * const _keywords[] = {"iterable", "n", "strict", NULL};
static _PyArg_Parser _parser = {
.keywords = _keywords,
.fname = "batched",
.kwtuple = KWTUPLE,
};
#undef KWTUPLE
- PyObject *argsbuf[2];
+ PyObject *argsbuf[3];
PyObject * const *fastargs;
Py_ssize_t nargs = PyTuple_GET_SIZE(args);
+ Py_ssize_t noptargs = nargs + (kwargs ? PyDict_GET_SIZE(kwargs) : 0) - 2;
PyObject *iterable;
Py_ssize_t n;
+ int strict = 0;
fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 2, 2, 0, argsbuf);
if (!fastargs) {
@@ -82,7 +88,15 @@ batched_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
}
n = ival;
}
- return_value = batched_new_impl(type, iterable, n);
+ if (!noptargs) {
+ goto skip_optional_kwonly;
+ }
+ strict = PyObject_IsTrue(fastargs[2]);
+ if (strict < 0) {
+ goto exit;
+ }
+skip_optional_kwonly:
+ return_value = batched_new_impl(type, iterable, n, strict);
exit:
return return_value;
@@ -914,4 +928,4 @@ skip_optional_pos:
exit:
return return_value;
}
-/*[clinic end generated code: output=782fe7e30733779b input=a9049054013a1b77]*/
+/*[clinic end generated code: output=c6a515f765da86b5 input=a9049054013a1b77]*/
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index ab99fa4..1647414 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -105,20 +105,11 @@ class itertools.pairwise "pairwiseobject *" "clinic_state()->pairwise_type"
/* batched object ************************************************************/
-/* Note: The built-in zip() function includes a "strict" argument
- that was needed because that function would silently truncate data,
- and there was no easy way for a user to detect the data loss.
- The same reasoning does not apply to batched() which never drops data.
- Instead, batched() produces a shorter tuple which can be handled
- as the user sees fit. If requested, it would be reasonable to add
- "fillvalue" support which had demonstrated value in zip_longest().
- For now, the API is kept simple and clean.
- */
-
typedef struct {
PyObject_HEAD
PyObject *it;
Py_ssize_t batch_size;
+ bool strict;
} batchedobject;
/*[clinic input]
@@ -126,6 +117,9 @@ typedef struct {
itertools.batched.__new__ as batched_new
iterable: object
n: Py_ssize_t
+ *
+ strict: bool = False
+
Batch data into tuples of length n. The last batch may be shorter than n.
Loops over the input iterable and accumulates data into tuples
@@ -140,11 +134,15 @@ or when the input iterable is exhausted.
('D', 'E', 'F')
('G',)
+If "strict" is True, raises a ValueError if the final batch is shorter
+than n.
+
[clinic start generated code]*/
static PyObject *
-batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
-/*[clinic end generated code: output=7ebc954d655371b6 input=ffd70726927c5129]*/
+batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n,
+ int strict)
+/*[clinic end generated code: output=c6de11b061529d3e input=7814b47e222f5467]*/
{
PyObject *it;
batchedobject *bo;
@@ -170,6 +168,7 @@ batched_new_impl(PyTypeObject *type, PyObject *iterable, Py_ssize_t n)
}
bo->batch_size = n;
bo->it = it;
+ bo->strict = (bool) strict;
return (PyObject *)bo;
}
@@ -233,6 +232,12 @@ batched_next(batchedobject *bo)
Py_DECREF(result);
return NULL;
}
+ if (bo->strict) {
+ Py_CLEAR(bo->it);
+ Py_DECREF(result);
+ PyErr_SetString(PyExc_ValueError, "batched(): incomplete batch");
+ return NULL;
+ }
_PyTuple_Resize(&result, i);
return result;
}