From a5ff80c8bc96210bace3ffb683b01fbd7f4ab76d Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Fri, 21 Oct 2022 12:31:52 -0500 Subject: GH-98363: Fix exception handling in batched() (GH-98523) --- Lib/test/test_itertools.py | 15 +++++++++++++++ Modules/itertoolsmodule.c | 29 +++++++++++++++++++++-------- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index c0e3571..a0a740f 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -2012,6 +2012,20 @@ class E: def __next__(self): 3 // 0 +class E2: + 'Test propagation of exceptions after two iterations' + def __init__(self, seqn): + self.seqn = seqn + self.i = 0 + def __iter__(self): + return self + def __next__(self): + if self.i == 2: + raise ZeroDivisionError + v = self.seqn[self.i] + self.i += 1 + return v + class S: 'Test immediate stop' def __init__(self, seqn): @@ -2050,6 +2064,7 @@ class TestVariousIteratorArgs(unittest.TestCase): self.assertRaises(TypeError, batched, X(s), 2) self.assertRaises(TypeError, batched, N(s), 2) self.assertRaises(ZeroDivisionError, list, batched(E(s), 2)) + self.assertRaises(ZeroDivisionError, list, batched(E2(s), 4)) def test_chain(self): for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)): diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index 868e8a8..627e698 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -154,23 +154,36 @@ batched_next(batchedobject *bo) if (result == NULL) { return NULL; } + iternextfunc iternext = *Py_TYPE(it)->tp_iternext; + PyObject **items = PySequence_Fast_ITEMS(result); for (i=0 ; i < n ; i++) { - item = PyIter_Next(it); + item = iternext(it); if (item == NULL) { - break; + goto null_item; + } + items[i] = item; + } + return result; + + null_item: + if (PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_StopIteration)) { + PyErr_Clear(); + } else { + /* input raised an exception other than StopIteration */ + Py_CLEAR(bo->it); + Py_DECREF(result); + return NULL; } - PyList_SET_ITEM(result, i, item); } if (i == 0) { Py_CLEAR(bo->it); Py_DECREF(result); return NULL; } - if (i < n) { - PyObject *short_list = PyList_GetSlice(result, 0, i); - Py_SETREF(result, short_list); - } - return result; + PyObject *short_list = PyList_GetSlice(result, 0, i); + Py_DECREF(result); + return short_list; } static PyTypeObject batched_type = { -- cgit v0.12