From fbb1c5ee068d209e33f6e15ecb4821d5d8b107fa Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Wed, 30 Mar 2016 20:40:02 +0300 Subject: Issue #26494: Fixed crash on iterating exhausting iterators. Affected classes are generic sequence iterators, iterators of str, bytes, bytearray, list, tuple, set, frozenset, dict, OrderedDict, corresponding views and os.scandir() iterator. --- Lib/test/seq_tests.py | 5 +++++ Lib/test/support/__init__.py | 19 +++++++++++++++++++ Lib/test/test_bytes.py | 4 ++++ Lib/test/test_deque.py | 4 ++++ Lib/test/test_dict.py | 6 ++++++ Lib/test/test_iter.py | 4 ++++ Lib/test/test_ordered_dict.py | 6 ++++++ Lib/test/test_set.py | 3 +++ Lib/test/test_unicode.py | 4 ++++ Misc/NEWS | 5 +++++ Modules/posixmodule.c | 16 ++++++++++------ Objects/bytearrayobject.c | 2 +- Objects/bytesobject.c | 2 +- Objects/dictobject.c | 6 +++--- Objects/iterobject.c | 2 +- Objects/listobject.c | 20 +++++++++++++------- Objects/setobject.c | 2 +- Objects/tupleobject.c | 2 +- Objects/unicodeobject.c | 2 +- 19 files changed, 92 insertions(+), 22 deletions(-) diff --git a/Lib/test/seq_tests.py b/Lib/test/seq_tests.py index 2416249..72f4845 100644 --- a/Lib/test/seq_tests.py +++ b/Lib/test/seq_tests.py @@ -5,6 +5,7 @@ Tests common to tuple, list and UserList.UserList import unittest import sys import pickle +from test import support # Various iterables # This is used for checking the constructor (here and in test_deque.py) @@ -408,3 +409,7 @@ class CommonTest(unittest.TestCase): lst2 = pickle.loads(pickle.dumps(lst, proto)) self.assertEqual(lst2, lst) self.assertNotEqual(id(lst2), id(lst)) + + def test_free_after_iterating(self): + support.check_free_after_iterating(self, iter, self.type2test) + support.check_free_after_iterating(self, reversed, self.type2test) diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index b82f9cb..e124fab 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -2366,3 +2366,22 @@ def run_in_subinterp(code): "memory allocations") import _testcapi return _testcapi.run_in_subinterp(code) + + +def check_free_after_iterating(test, iter, cls, args=()): + class A(cls): + def __del__(self): + nonlocal done + done = True + try: + next(it) + except StopIteration: + pass + + done = False + it = iter(A(*args)) + # Issue 26494: Shouldn't crash + test.assertRaises(StopIteration, next, it) + # The sequence should be deallocated just after the end of iterating + gc_collect() + test.assertTrue(done) diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py index 80798f2..1bd3a1e 100644 --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -747,6 +747,10 @@ class BaseBytesTest: self.assertRaisesRegex(TypeError, r'\bendswith\b', b.endswith, x, None, None, None) + def test_free_after_iterating(self): + test.support.check_free_after_iterating(self, iter, self.type2test) + test.support.check_free_after_iterating(self, reversed, self.type2test) + class BytesTest(BaseBytesTest, unittest.TestCase): type2test = bytes diff --git a/Lib/test/test_deque.py b/Lib/test/test_deque.py index ec2be83..7041d17 100644 --- a/Lib/test/test_deque.py +++ b/Lib/test/test_deque.py @@ -905,6 +905,10 @@ class TestSequence(seq_tests.CommonTest): # For now, bypass tests that require slicing pass + def test_free_after_iterating(self): + # For now, bypass tests that require slicing + self.skipTest("Exhausted deque iterator doesn't free a deque") + #============================================================================== libreftest = """ diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py index 3b42414..075cb5a 100644 --- a/Lib/test/test_dict.py +++ b/Lib/test/test_dict.py @@ -952,6 +952,12 @@ class DictTest(unittest.TestCase): d = {X(): 0, 1: 1} self.assertRaises(RuntimeError, d.update, other) + def test_free_after_iterating(self): + support.check_free_after_iterating(self, iter, dict) + support.check_free_after_iterating(self, lambda d: iter(d.keys()), dict) + support.check_free_after_iterating(self, lambda d: iter(d.values()), dict) + support.check_free_after_iterating(self, lambda d: iter(d.items()), dict) + from test import mapping_tests class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol): diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py index 56e21f8..54ddbaa 100644 --- a/Lib/test/test_iter.py +++ b/Lib/test/test_iter.py @@ -3,6 +3,7 @@ import sys import unittest from test.support import run_unittest, TESTFN, unlink, cpython_only +from test.support import check_free_after_iterating import pickle import collections.abc @@ -980,6 +981,9 @@ class TestCase(unittest.TestCase): self.assertEqual(next(it), 0) self.assertEqual(next(it), 1) + def test_free_after_iterating(self): + check_free_after_iterating(self, iter, SequenceClass, (0,)) + def test_main(): run_unittest(TestCase) diff --git a/Lib/test/test_ordered_dict.py b/Lib/test/test_ordered_dict.py index 8ab0a9f..901d4b2 100644 --- a/Lib/test/test_ordered_dict.py +++ b/Lib/test/test_ordered_dict.py @@ -598,6 +598,12 @@ class OrderedDictTests: gc.collect() self.assertIsNone(r()) + def test_free_after_iterating(self): + support.check_free_after_iterating(self, iter, self.OrderedDict) + support.check_free_after_iterating(self, lambda d: iter(d.keys()), self.OrderedDict) + support.check_free_after_iterating(self, lambda d: iter(d.values()), self.OrderedDict) + support.check_free_after_iterating(self, lambda d: iter(d.items()), self.OrderedDict) + class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase): diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index 54de508..0b99dfc 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -362,6 +362,9 @@ class TestJointOps: gc.collect() self.assertTrue(ref() is None, "Cycle was not collected") + def test_free_after_iterating(self): + support.check_free_after_iterating(self, iter, self.thetype) + class TestSet(TestJointOps, unittest.TestCase): thetype = set basetype = set diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py index c30310e..c281146 100644 --- a/Lib/test/test_unicode.py +++ b/Lib/test/test_unicode.py @@ -2729,6 +2729,10 @@ class UnicodeTest(string_tests.CommonTest, # Check that the second call returns the same result self.assertEqual(getargs_s_hash(s), chr(k).encode() * (i + 1)) + def test_free_after_iterating(self): + support.check_free_after_iterating(self, iter, str) + support.check_free_after_iterating(self, reversed, str) + class StringModuleTest(unittest.TestCase): def test_formatter_parser(self): diff --git a/Misc/NEWS b/Misc/NEWS index 59c0828..391bbf2 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -10,6 +10,11 @@ Release date: tba Core and Builtins ----------------- +- Issue #26494: Fixed crash on iterating exhausting iterators. + Affected classes are generic sequence iterators, iterators of str, bytes, + bytearray, list, tuple, set, frozenset, dict, OrderedDict, corresponding + views and os.scandir() iterator. + - Issue #26581: If coding cookie is specified multiple times on a line in Python source code file, only the first one is taken to account. diff --git a/Modules/posixmodule.c b/Modules/posixmodule.c index 710bcde..c95668b 100644 --- a/Modules/posixmodule.c +++ b/Modules/posixmodule.c @@ -11928,13 +11928,15 @@ typedef struct { static void ScandirIterator_close(ScandirIterator *iterator) { - if (iterator->handle == INVALID_HANDLE_VALUE) + HANDLE handle = iterator->handle; + + if (handle == INVALID_HANDLE_VALUE) return; + iterator->handle = INVALID_HANDLE_VALUE; Py_BEGIN_ALLOW_THREADS - FindClose(iterator->handle); + FindClose(handle); Py_END_ALLOW_THREADS - iterator->handle = INVALID_HANDLE_VALUE; } static PyObject * @@ -11984,13 +11986,15 @@ ScandirIterator_iternext(ScandirIterator *iterator) static void ScandirIterator_close(ScandirIterator *iterator) { - if (!iterator->dirp) + DIR *dirp = iterator->dirp; + + if (!dirp) return; + iterator->dirp = NULL; Py_BEGIN_ALLOW_THREADS - closedir(iterator->dirp); + closedir(dirp); Py_END_ALLOW_THREADS - iterator->dirp = NULL; return; } diff --git a/Objects/bytearrayobject.c b/Objects/bytearrayobject.c index c59ad24..c723a9c 100644 --- a/Objects/bytearrayobject.c +++ b/Objects/bytearrayobject.c @@ -3186,8 +3186,8 @@ bytearrayiter_next(bytesiterobject *it) return item; } - Py_DECREF(seq); it->it_seq = NULL; + Py_DECREF(seq); return NULL; } diff --git a/Objects/bytesobject.c b/Objects/bytesobject.c index 51d0871..495c3eb 100644 --- a/Objects/bytesobject.c +++ b/Objects/bytesobject.c @@ -3628,8 +3628,8 @@ striter_next(striterobject *it) return item; } - Py_DECREF(seq); it->it_seq = NULL; + Py_DECREF(seq); return NULL; } diff --git a/Objects/dictobject.c b/Objects/dictobject.c index e4dff98..d774586 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -2985,8 +2985,8 @@ static PyObject *dictiter_iternextkey(dictiterobject *di) return key; fail: - Py_DECREF(d); di->di_dict = NULL; + Py_DECREF(d); return NULL; } @@ -3066,8 +3066,8 @@ static PyObject *dictiter_iternextvalue(dictiterobject *di) return value; fail: - Py_DECREF(d); di->di_dict = NULL; + Py_DECREF(d); return NULL; } @@ -3161,8 +3161,8 @@ static PyObject *dictiter_iternextitem(dictiterobject *di) return result; fail: - Py_DECREF(d); di->di_dict = NULL; + Py_DECREF(d); return NULL; } diff --git a/Objects/iterobject.c b/Objects/iterobject.c index 2fb0c88..ab29ff8 100644 --- a/Objects/iterobject.c +++ b/Objects/iterobject.c @@ -69,8 +69,8 @@ iter_iternext(PyObject *iterator) PyErr_ExceptionMatches(PyExc_StopIteration)) { PyErr_Clear(); - Py_DECREF(seq); it->it_seq = NULL; + Py_DECREF(seq); } return NULL; } diff --git a/Objects/listobject.c b/Objects/listobject.c index eee7c68..d688179 100644 --- a/Objects/listobject.c +++ b/Objects/listobject.c @@ -2782,8 +2782,8 @@ listiter_next(listiterobject *it) return item; } - Py_DECREF(seq); it->it_seq = NULL; + Py_DECREF(seq); return NULL; } @@ -2912,9 +2912,17 @@ static PyObject * listreviter_next(listreviterobject *it) { PyObject *item; - Py_ssize_t index = it->it_index; - PyListObject *seq = it->it_seq; + Py_ssize_t index; + PyListObject *seq; + + assert(it != NULL); + seq = it->it_seq; + if (seq == NULL) { + return NULL; + } + assert(PyList_Check(seq)); + index = it->it_index; if (index>=0 && index < PyList_GET_SIZE(seq)) { item = PyList_GET_ITEM(seq, index); it->it_index--; @@ -2922,10 +2930,8 @@ listreviter_next(listreviterobject *it) return item; } it->it_index = -1; - if (seq != NULL) { - it->it_seq = NULL; - Py_DECREF(seq); - } + it->it_seq = NULL; + Py_DECREF(seq); return NULL; } diff --git a/Objects/setobject.c b/Objects/setobject.c index 582f280..4ef692d 100644 --- a/Objects/setobject.c +++ b/Objects/setobject.c @@ -839,8 +839,8 @@ static PyObject *setiter_iternext(setiterobject *si) return key; fail: - Py_DECREF(so); si->si_set = NULL; + Py_DECREF(so); return NULL; } diff --git a/Objects/tupleobject.c b/Objects/tupleobject.c index 7efa1a6..7920fec 100644 --- a/Objects/tupleobject.c +++ b/Objects/tupleobject.c @@ -964,8 +964,8 @@ tupleiter_next(tupleiterobject *it) return item; } - Py_DECREF(seq); it->it_seq = NULL; + Py_DECREF(seq); return NULL; } diff --git a/Objects/unicodeobject.c b/Objects/unicodeobject.c index adc4615..230125b 100644 --- a/Objects/unicodeobject.c +++ b/Objects/unicodeobject.c @@ -15149,8 +15149,8 @@ unicodeiter_next(unicodeiterobject *it) return item; } - Py_DECREF(seq); it->it_seq = NULL; + Py_DECREF(seq); return NULL; } -- cgit v0.12