diff options
-rw-r--r-- | Lib/test/test_itertools.py | 30 | ||||
-rw-r--r-- | Modules/itertoolsmodule.c | 12 |
2 files changed, 42 insertions, 0 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 73e8809..54e46e1 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -644,6 +644,36 @@ class RegressionTests(unittest.TestCase): self.assertEqual(first, second) + def test_sf_950057(self): + # Make sure that chain() and cycle() catch exceptions immediately + # rather than when shifting between input sources + + def gen1(): + hist.append(0) + yield 1 + hist.append(1) + assert False + hist.append(2) + + def gen2(x): + hist.append(3) + yield 2 + hist.append(4) + if x: + raise StopIteration + + hist = [] + self.assertRaises(AssertionError, list, chain(gen1(), gen2(False))) + self.assertEqual(hist, [0,1]) + + hist = [] + self.assertRaises(AssertionError, list, chain(gen1(), gen2(True))) + self.assertEqual(hist, [0,1]) + + hist = [] + self.assertRaises(AssertionError, list, cycle(gen1())) + self.assertEqual(hist, [0,1]) + libreftest = """ Doctest for examples in the library reference: libitertools.tex diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index 4ce4643..3515bc6 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -699,6 +699,12 @@ cycle_next(cycleobject *lz) PyList_Append(lz->saved, item); return item; } + if (PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_StopIteration)) + PyErr_Clear(); + else + return NULL; + } if (PyList_Size(lz->saved) == 0) return NULL; it = PyObject_GetIter(lz->saved); @@ -1658,6 +1664,12 @@ chain_next(chainobject *lz) item = PyIter_Next(it); if (item != NULL) return item; + if (PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_StopIteration)) + PyErr_Clear(); + else + return NULL; + } lz->iternum++; } return NULL; |