From 7892b1c651d72a5bd08372f40309dec08a7065f0 Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Mon, 12 Apr 2004 18:10:01 +0000 Subject: * Add unittests for iterators that report their length * Document the differences between them * Fix corner cases covered by the unittests * Use Py_RETURN_NONE where possible for dictionaries --- Lib/test/test_iterlen.py | 245 ++++++++++++++++++++++++++++++++++++++++++++ Modules/collectionsmodule.c | 2 + Objects/dictobject.c | 14 +-- Objects/enumobject.c | 13 ++- Objects/iterobject.c | 12 ++- 5 files changed, 276 insertions(+), 10 deletions(-) create mode 100644 Lib/test/test_iterlen.py diff --git a/Lib/test/test_iterlen.py b/Lib/test/test_iterlen.py new file mode 100644 index 0000000..f77169f --- /dev/null +++ b/Lib/test/test_iterlen.py @@ -0,0 +1,245 @@ +""" Test Iterator Length Transparency + +Some functions or methods which accept general iterable arguments have +optional, more efficient code paths if they know how many items to expect. +For instance, map(func, iterable), will pre-allocate the exact amount of +space required whenever the iterable can report its length. + +The desired invariant is: len(it)==len(list(it)). + +A complication is that an iterable and iterator can be the same object. To +maintain the invariant, an iterator needs to dynamically update its length. +For instance, an iterable such as xrange(10) always reports its length as ten, +but it=iter(xrange(10)) starts at ten, and then goes to nine after it.next(). +Having this capability means that map() can ignore the distinction between +map(func, iterable) and map(func, iter(iterable)). + +When the iterable is immutable, the implementation can straight-forwardly +report the original length minus the cumulative number of calls to next(). +This is the case for tuples, xrange objects, and itertools.repeat(). + +Some containers become temporarily immutable during iteration. This includes +dicts, sets, and collections.deque. Their implementation is equally simple +though they need to permantently set their length to zero whenever there is +an attempt to iterate after a length mutation. + +The situation slightly more involved whenever an object allows length mutation +during iteration. Lists and sequence iterators are dynanamically updatable. +So, if a list is extended during iteration, the iterator will continue through +the new items. If it shrinks to a point before the most recent iteration, +then no further items are available and the length is reported at zero. + +Reversed objects can also be wrapped around mutable objects; however, any +appends after the current position are ignored. Any other approach leads +to confusion and possibly returning the same item more than once. + +The iterators not listed above, such as enumerate and the other itertools, +are not length transparent because they have no way to distinguish between +iterables that report static length and iterators whose length changes with +each call (i.e. the difference between enumerate('abc') and +enumerate(iter('abc')). + +""" + +import unittest +from test import test_support +from itertools import repeat, count +from collections import deque +from UserList import UserList + +n = 10 + +class TestInvariantWithoutMutations(unittest.TestCase): + + def test_invariant(self): + it = self.it + for i in reversed(xrange(1, n+1)): + self.assertEqual(len(it), i) + it.next() + self.assertEqual(len(it), 0) + self.assertRaises(StopIteration, it.next) + self.assertEqual(len(it), 0) + +class TestTemporarilyImmutable(TestInvariantWithoutMutations): + + def test_immutable_during_iteration(self): + # objects such as deques, sets, and dictionaries enforce + # length immutability during iteration + + it = self.it + self.assertEqual(len(it), n) + it.next() + self.assertEqual(len(it), n-1) + self.mutate() + self.assertRaises(RuntimeError, it.next) + self.assertEqual(len(it), 0) + +## ------- Concrete Type Tests ------- + +class TestRepeat(TestInvariantWithoutMutations): + + def setUp(self): + self.it = repeat(None, n) + + def test_no_len_for_infinite_repeat(self): + # The repeat() object can also be infinite + self.assertRaises(TypeError, len, repeat(None)) + +class TestXrange(TestInvariantWithoutMutations): + + def setUp(self): + self.it = iter(xrange(n)) + +class TestXrangeCustomReversed(TestInvariantWithoutMutations): + + def setUp(self): + self.it = reversed(xrange(n)) + +class TestTuple(TestInvariantWithoutMutations): + + def setUp(self): + self.it = iter(tuple(xrange(n))) + +## ------- Types that should not be mutated during iteration ------- + +class TestDeque(TestTemporarilyImmutable): + + def setUp(self): + d = deque(xrange(n)) + self.it = iter(d) + self.mutate = d.pop + +class TestDequeReversed(TestTemporarilyImmutable): + + def setUp(self): + d = deque(xrange(n)) + self.it = reversed(d) + self.mutate = d.pop + +class TestDictKeys(TestTemporarilyImmutable): + + def setUp(self): + d = dict.fromkeys(xrange(n)) + self.it = iter(d) + self.mutate = d.popitem + +class TestDictItems(TestTemporarilyImmutable): + + def setUp(self): + d = dict.fromkeys(xrange(n)) + self.it = d.iteritems() + self.mutate = d.popitem + +class TestDictValues(TestTemporarilyImmutable): + + def setUp(self): + d = dict.fromkeys(xrange(n)) + self.it = d.itervalues() + self.mutate = d.popitem + +class TestSet(TestTemporarilyImmutable): + + def setUp(self): + d = set(xrange(n)) + self.it = iter(d) + self.mutate = d.pop + +## ------- Types that can mutate during iteration ------- + +class TestList(TestInvariantWithoutMutations): + + def setUp(self): + self.it = iter(range(n)) + + def test_mutation(self): + d = range(n) + it = iter(d) + it.next() + it.next() + self.assertEqual(len(it), n-2) + d.append(n) + self.assertEqual(len(it), n-1) # grow with append + d[1:] = [] + self.assertEqual(len(it), 0) + self.assertEqual(list(it), []) + d.extend(xrange(20)) + self.assertEqual(len(it), 0) + +class TestListReversed(TestInvariantWithoutMutations): + + def setUp(self): + self.it = reversed(range(n)) + + def test_mutation(self): + d = range(n) + it = reversed(d) + it.next() + it.next() + self.assertEqual(len(it), n-2) + d.append(n) + self.assertEqual(len(it), n-2) # ignore append + d[1:] = [] + self.assertEqual(len(it), 0) + self.assertEqual(list(it), []) # confirm invariant + d.extend(xrange(20)) + self.assertEqual(len(it), 0) + +class TestSeqIter(TestInvariantWithoutMutations): + + def setUp(self): + self.it = iter(UserList(range(n))) + + def test_mutation(self): + d = UserList(range(n)) + it = iter(d) + it.next() + it.next() + self.assertEqual(len(it), n-2) + d.append(n) + self.assertEqual(len(it), n-1) # grow with append + d[1:] = [] + self.assertEqual(len(it), 0) + self.assertEqual(list(it), []) + d.extend(xrange(20)) + self.assertEqual(len(it), 0) + +class TestSeqIterReversed(TestInvariantWithoutMutations): + + def setUp(self): + self.it = reversed(UserList(range(n))) + + def test_mutation(self): + d = UserList(range(n)) + it = reversed(d) + it.next() + it.next() + self.assertEqual(len(it), n-2) + d.append(n) + self.assertEqual(len(it), n-2) # ignore append + d[1:] = [] + self.assertEqual(len(it), 0) + self.assertEqual(list(it), []) # confirm invariant + d.extend(xrange(20)) + self.assertEqual(len(it), 0) + + + +if __name__ == "__main__": + + unittests = [ + TestRepeat, + TestXrange, + TestXrangeCustomReversed, + TestTuple, + TestDeque, + TestDequeReversed, + TestDictKeys, + TestDictItems, + TestDictValues, + TestSet, + TestList, + TestListReversed, + TestSeqIter, + TestSeqIterReversed, + ] + test_support.run_unittest(*unittests) diff --git a/Modules/collectionsmodule.c b/Modules/collectionsmodule.c index cf474f7..fc30c99 100644 --- a/Modules/collectionsmodule.c +++ b/Modules/collectionsmodule.c @@ -770,6 +770,7 @@ dequeiter_next(dequeiterobject *it) if (it->len != it->deque->len) { it->len = -1; /* Make this state sticky */ + it->counter = 0; PyErr_SetString(PyExc_RuntimeError, "deque changed size during iteration"); return NULL; @@ -860,6 +861,7 @@ dequereviter_next(dequeiterobject *it) if (it->len != it->deque->len) { it->len = -1; /* Make this state sticky */ + it->counter = 0; PyErr_SetString(PyExc_RuntimeError, "deque changed size during iteration"); return NULL; diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 0f2a271..84cf482 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -1088,10 +1088,9 @@ dict_update_common(PyObject *self, PyObject *args, PyObject *kwds, char *methnam static PyObject * dict_update(PyObject *self, PyObject *args, PyObject *kwds) { - if (dict_update_common(self, args, kwds, "update") == -1) - return NULL; - Py_INCREF(Py_None); - return Py_None; + if (dict_update_common(self, args, kwds, "update") != -1) + Py_RETURN_NONE; + return NULL; } /* Update unconditionally replaces existing items. @@ -1593,8 +1592,7 @@ static PyObject * dict_clear(register dictobject *mp) { PyDict_Clear((PyObject *)mp); - Py_INCREF(Py_None); - return Py_None; + Py_RETURN_NONE; } static PyObject * @@ -2050,7 +2048,9 @@ dictiter_dealloc(dictiterobject *di) static int dictiter_len(dictiterobject *di) { - return di->len; + if (di->di_dict != NULL && di->di_used == di->di_dict->ma_used) + return di->len; + return 0; } static PySequenceMethods dictiter_as_sequence = { diff --git a/Objects/enumobject.c b/Objects/enumobject.c index 28719a9..549fc9f 100644 --- a/Objects/enumobject.c +++ b/Objects/enumobject.c @@ -225,6 +225,9 @@ reversed_next(reversedobject *ro) ro->index--; return item; } + if (PyErr_ExceptionMatches(PyExc_IndexError) || + PyErr_ExceptionMatches(PyExc_StopIteration)) + PyErr_Clear(); } ro->index = -1; if (ro->seq != NULL) { @@ -242,7 +245,15 @@ PyDoc_STRVAR(reversed_doc, static int reversed_len(reversedobject *ro) { - return ro->index + 1; + int position, seqsize; + + if (ro->seq == NULL) + return 0; + seqsize = PySequence_Size(ro->seq); + if (seqsize == -1) + return -1; + position = ro->index + 1; + return (seqsize < position) ? 0 : position; } static PySequenceMethods reversed_as_sequence = { diff --git a/Objects/iterobject.c b/Objects/iterobject.c index a407dd5..25e4e11 100644 --- a/Objects/iterobject.c +++ b/Objects/iterobject.c @@ -74,8 +74,16 @@ iter_iternext(PyObject *iterator) static int iter_len(seqiterobject *it) { - if (it->it_seq) - return PyObject_Size(it->it_seq) - it->it_index; + int seqsize, len; + + if (it->it_seq) { + seqsize = PySequence_Size(it->it_seq); + if (seqsize == -1) + return -1; + len = seqsize - it->it_index; + if (len >= 0) + return len; + } return 0; } -- cgit v0.12