summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorRaymond Hettinger <rhettinger@users.noreply.github.com>2022-10-21 17:31:52 (GMT)
committerGitHub <noreply@github.com>2022-10-21 17:31:52 (GMT)
commita5ff80c8bc96210bace3ffb683b01fbd7f4ab76d (patch)
treeb1aee731a065f83542a5cdb32b769c966f9b86c4
parentec1f6f5f139868dc2c1116a7c7c878c38c668d53 (diff)
downloadcpython-a5ff80c8bc96210bace3ffb683b01fbd7f4ab76d.zip
cpython-a5ff80c8bc96210bace3ffb683b01fbd7f4ab76d.tar.gz
cpython-a5ff80c8bc96210bace3ffb683b01fbd7f4ab76d.tar.bz2
GH-98363: Fix exception handling in batched() (GH-98523)
-rw-r--r--Lib/test/test_itertools.py15
-rw-r--r--Modules/itertoolsmodule.c29
2 files changed, 36 insertions, 8 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index c0e3571..a0a740f 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -2012,6 +2012,20 @@ class E:
def __next__(self):
3 // 0
+class E2:
+ 'Test propagation of exceptions after two iterations'
+ def __init__(self, seqn):
+ self.seqn = seqn
+ self.i = 0
+ def __iter__(self):
+ return self
+ def __next__(self):
+ if self.i == 2:
+ raise ZeroDivisionError
+ v = self.seqn[self.i]
+ self.i += 1
+ return v
+
class S:
'Test immediate stop'
def __init__(self, seqn):
@@ -2050,6 +2064,7 @@ class TestVariousIteratorArgs(unittest.TestCase):
self.assertRaises(TypeError, batched, X(s), 2)
self.assertRaises(TypeError, batched, N(s), 2)
self.assertRaises(ZeroDivisionError, list, batched(E(s), 2))
+ self.assertRaises(ZeroDivisionError, list, batched(E2(s), 4))
def test_chain(self):
for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index 868e8a8..627e698 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -154,23 +154,36 @@ batched_next(batchedobject *bo)
if (result == NULL) {
return NULL;
}
+ iternextfunc iternext = *Py_TYPE(it)->tp_iternext;
+ PyObject **items = PySequence_Fast_ITEMS(result);
for (i=0 ; i < n ; i++) {
- item = PyIter_Next(it);
+ item = iternext(it);
if (item == NULL) {
- break;
+ goto null_item;
+ }
+ items[i] = item;
+ }
+ return result;
+
+ null_item:
+ if (PyErr_Occurred()) {
+ if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
+ PyErr_Clear();
+ } else {
+ /* input raised an exception other than StopIteration */
+ Py_CLEAR(bo->it);
+ Py_DECREF(result);
+ return NULL;
}
- PyList_SET_ITEM(result, i, item);
}
if (i == 0) {
Py_CLEAR(bo->it);
Py_DECREF(result);
return NULL;
}
- if (i < n) {
- PyObject *short_list = PyList_GetSlice(result, 0, i);
- Py_SETREF(result, short_list);
- }
- return result;
+ PyObject *short_list = PyList_GetSlice(result, 0, i);
+ Py_DECREF(result);
+ return short_list;
}
static PyTypeObject batched_type = {