summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_itertools.py30
-rw-r--r--Modules/itertoolsmodule.c12
2 files changed, 42 insertions, 0 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index 73e8809..54e46e1 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -644,6 +644,36 @@ class RegressionTests(unittest.TestCase):
self.assertEqual(first, second)
+ def test_sf_950057(self):
+ # Make sure that chain() and cycle() catch exceptions immediately
+ # rather than when shifting between input sources
+
+ def gen1():
+ hist.append(0)
+ yield 1
+ hist.append(1)
+ assert False
+ hist.append(2)
+
+ def gen2(x):
+ hist.append(3)
+ yield 2
+ hist.append(4)
+ if x:
+ raise StopIteration
+
+ hist = []
+ self.assertRaises(AssertionError, list, chain(gen1(), gen2(False)))
+ self.assertEqual(hist, [0,1])
+
+ hist = []
+ self.assertRaises(AssertionError, list, chain(gen1(), gen2(True)))
+ self.assertEqual(hist, [0,1])
+
+ hist = []
+ self.assertRaises(AssertionError, list, cycle(gen1()))
+ self.assertEqual(hist, [0,1])
+
libreftest = """ Doctest for examples in the library reference: libitertools.tex
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index 4ce4643..3515bc6 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -699,6 +699,12 @@ cycle_next(cycleobject *lz)
PyList_Append(lz->saved, item);
return item;
}
+ if (PyErr_Occurred()) {
+ if (PyErr_ExceptionMatches(PyExc_StopIteration))
+ PyErr_Clear();
+ else
+ return NULL;
+ }
if (PyList_Size(lz->saved) == 0)
return NULL;
it = PyObject_GetIter(lz->saved);
@@ -1658,6 +1664,12 @@ chain_next(chainobject *lz)
item = PyIter_Next(it);
if (item != NULL)
return item;
+ if (PyErr_Occurred()) {
+ if (PyErr_ExceptionMatches(PyExc_StopIteration))
+ PyErr_Clear();
+ else
+ return NULL;
+ }
lz->iternum++;
}
return NULL;