summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorRaymond Hettinger <python@rcn.com>2004-04-12 18:10:01 (GMT)
committerRaymond Hettinger <python@rcn.com>2004-04-12 18:10:01 (GMT)
commit7892b1c651d72a5bd08372f40309dec08a7065f0 (patch)
tree7ae71a1e81651c4fa7f786ebfbdbc8364a41730e
parent45d0b5cc44ffb6227a2379a39b00d480f253edd5 (diff)
downloadcpython-7892b1c651d72a5bd08372f40309dec08a7065f0.zip
cpython-7892b1c651d72a5bd08372f40309dec08a7065f0.tar.gz
cpython-7892b1c651d72a5bd08372f40309dec08a7065f0.tar.bz2
* Add unittests for iterators that report their length
* Document the differences between them * Fix corner cases covered by the unittests * Use Py_RETURN_NONE where possible for dictionaries
-rw-r--r--Lib/test/test_iterlen.py245
-rw-r--r--Modules/collectionsmodule.c2
-rw-r--r--Objects/dictobject.c14
-rw-r--r--Objects/enumobject.c13
-rw-r--r--Objects/iterobject.c12
5 files changed, 276 insertions, 10 deletions
diff --git a/Lib/test/test_iterlen.py b/Lib/test/test_iterlen.py
new file mode 100644
index 0000000..f77169f
--- /dev/null
+++ b/Lib/test/test_iterlen.py
@@ -0,0 +1,245 @@
+""" Test Iterator Length Transparency
+
+Some functions or methods which accept general iterable arguments have
+optional, more efficient code paths if they know how many items to expect.
+For instance, map(func, iterable), will pre-allocate the exact amount of
+space required whenever the iterable can report its length.
+
+The desired invariant is: len(it)==len(list(it)).
+
+A complication is that an iterable and iterator can be the same object. To
+maintain the invariant, an iterator needs to dynamically update its length.
+For instance, an iterable such as xrange(10) always reports its length as ten,
+but it=iter(xrange(10)) starts at ten, and then goes to nine after it.next().
+Having this capability means that map() can ignore the distinction between
+map(func, iterable) and map(func, iter(iterable)).
+
+When the iterable is immutable, the implementation can straight-forwardly
+report the original length minus the cumulative number of calls to next().
+This is the case for tuples, xrange objects, and itertools.repeat().
+
+Some containers become temporarily immutable during iteration. This includes
+dicts, sets, and collections.deque. Their implementation is equally simple
+though they need to permantently set their length to zero whenever there is
+an attempt to iterate after a length mutation.
+
+The situation slightly more involved whenever an object allows length mutation
+during iteration. Lists and sequence iterators are dynanamically updatable.
+So, if a list is extended during iteration, the iterator will continue through
+the new items. If it shrinks to a point before the most recent iteration,
+then no further items are available and the length is reported at zero.
+
+Reversed objects can also be wrapped around mutable objects; however, any
+appends after the current position are ignored. Any other approach leads
+to confusion and possibly returning the same item more than once.
+
+The iterators not listed above, such as enumerate and the other itertools,
+are not length transparent because they have no way to distinguish between
+iterables that report static length and iterators whose length changes with
+each call (i.e. the difference between enumerate('abc') and
+enumerate(iter('abc')).
+
+"""
+
+import unittest
+from test import test_support
+from itertools import repeat, count
+from collections import deque
+from UserList import UserList
+
+n = 10
+
+class TestInvariantWithoutMutations(unittest.TestCase):
+
+ def test_invariant(self):
+ it = self.it
+ for i in reversed(xrange(1, n+1)):
+ self.assertEqual(len(it), i)
+ it.next()
+ self.assertEqual(len(it), 0)
+ self.assertRaises(StopIteration, it.next)
+ self.assertEqual(len(it), 0)
+
+class TestTemporarilyImmutable(TestInvariantWithoutMutations):
+
+ def test_immutable_during_iteration(self):
+ # objects such as deques, sets, and dictionaries enforce
+ # length immutability during iteration
+
+ it = self.it
+ self.assertEqual(len(it), n)
+ it.next()
+ self.assertEqual(len(it), n-1)
+ self.mutate()
+ self.assertRaises(RuntimeError, it.next)
+ self.assertEqual(len(it), 0)
+
+## ------- Concrete Type Tests -------
+
+class TestRepeat(TestInvariantWithoutMutations):
+
+ def setUp(self):
+ self.it = repeat(None, n)
+
+ def test_no_len_for_infinite_repeat(self):
+ # The repeat() object can also be infinite
+ self.assertRaises(TypeError, len, repeat(None))
+
+class TestXrange(TestInvariantWithoutMutations):
+
+ def setUp(self):
+ self.it = iter(xrange(n))
+
+class TestXrangeCustomReversed(TestInvariantWithoutMutations):
+
+ def setUp(self):
+ self.it = reversed(xrange(n))
+
+class TestTuple(TestInvariantWithoutMutations):
+
+ def setUp(self):
+ self.it = iter(tuple(xrange(n)))
+
+## ------- Types that should not be mutated during iteration -------
+
+class TestDeque(TestTemporarilyImmutable):
+
+ def setUp(self):
+ d = deque(xrange(n))
+ self.it = iter(d)
+ self.mutate = d.pop
+
+class TestDequeReversed(TestTemporarilyImmutable):
+
+ def setUp(self):
+ d = deque(xrange(n))
+ self.it = reversed(d)
+ self.mutate = d.pop
+
+class TestDictKeys(TestTemporarilyImmutable):
+
+ def setUp(self):
+ d = dict.fromkeys(xrange(n))
+ self.it = iter(d)
+ self.mutate = d.popitem
+
+class TestDictItems(TestTemporarilyImmutable):
+
+ def setUp(self):
+ d = dict.fromkeys(xrange(n))
+ self.it = d.iteritems()
+ self.mutate = d.popitem
+
+class TestDictValues(TestTemporarilyImmutable):
+
+ def setUp(self):
+ d = dict.fromkeys(xrange(n))
+ self.it = d.itervalues()
+ self.mutate = d.popitem
+
+class TestSet(TestTemporarilyImmutable):
+
+ def setUp(self):
+ d = set(xrange(n))
+ self.it = iter(d)
+ self.mutate = d.pop
+
+## ------- Types that can mutate during iteration -------
+
+class TestList(TestInvariantWithoutMutations):
+
+ def setUp(self):
+ self.it = iter(range(n))
+
+ def test_mutation(self):
+ d = range(n)
+ it = iter(d)
+ it.next()
+ it.next()
+ self.assertEqual(len(it), n-2)
+ d.append(n)
+ self.assertEqual(len(it), n-1) # grow with append
+ d[1:] = []
+ self.assertEqual(len(it), 0)
+ self.assertEqual(list(it), [])
+ d.extend(xrange(20))
+ self.assertEqual(len(it), 0)
+
+class TestListReversed(TestInvariantWithoutMutations):
+
+ def setUp(self):
+ self.it = reversed(range(n))
+
+ def test_mutation(self):
+ d = range(n)
+ it = reversed(d)
+ it.next()
+ it.next()
+ self.assertEqual(len(it), n-2)
+ d.append(n)
+ self.assertEqual(len(it), n-2) # ignore append
+ d[1:] = []
+ self.assertEqual(len(it), 0)
+ self.assertEqual(list(it), []) # confirm invariant
+ d.extend(xrange(20))
+ self.assertEqual(len(it), 0)
+
+class TestSeqIter(TestInvariantWithoutMutations):
+
+ def setUp(self):
+ self.it = iter(UserList(range(n)))
+
+ def test_mutation(self):
+ d = UserList(range(n))
+ it = iter(d)
+ it.next()
+ it.next()
+ self.assertEqual(len(it), n-2)
+ d.append(n)
+ self.assertEqual(len(it), n-1) # grow with append
+ d[1:] = []
+ self.assertEqual(len(it), 0)
+ self.assertEqual(list(it), [])
+ d.extend(xrange(20))
+ self.assertEqual(len(it), 0)
+
+class TestSeqIterReversed(TestInvariantWithoutMutations):
+
+ def setUp(self):
+ self.it = reversed(UserList(range(n)))
+
+ def test_mutation(self):
+ d = UserList(range(n))
+ it = reversed(d)
+ it.next()
+ it.next()
+ self.assertEqual(len(it), n-2)
+ d.append(n)
+ self.assertEqual(len(it), n-2) # ignore append
+ d[1:] = []
+ self.assertEqual(len(it), 0)
+ self.assertEqual(list(it), []) # confirm invariant
+ d.extend(xrange(20))
+ self.assertEqual(len(it), 0)
+
+
+
+if __name__ == "__main__":
+
+ unittests = [
+ TestRepeat,
+ TestXrange,
+ TestXrangeCustomReversed,
+ TestTuple,
+ TestDeque,
+ TestDequeReversed,
+ TestDictKeys,
+ TestDictItems,
+ TestDictValues,
+ TestSet,
+ TestList,
+ TestListReversed,
+ TestSeqIter,
+ TestSeqIterReversed,
+ ]
+ test_support.run_unittest(*unittests)
diff --git a/Modules/collectionsmodule.c b/Modules/collectionsmodule.c
index cf474f7..fc30c99 100644
--- a/Modules/collectionsmodule.c
+++ b/Modules/collectionsmodule.c
@@ -770,6 +770,7 @@ dequeiter_next(dequeiterobject *it)
if (it->len != it->deque->len) {
it->len = -1; /* Make this state sticky */
+ it->counter = 0;
PyErr_SetString(PyExc_RuntimeError,
"deque changed size during iteration");
return NULL;
@@ -860,6 +861,7 @@ dequereviter_next(dequeiterobject *it)
if (it->len != it->deque->len) {
it->len = -1; /* Make this state sticky */
+ it->counter = 0;
PyErr_SetString(PyExc_RuntimeError,
"deque changed size during iteration");
return NULL;
diff --git a/Objects/dictobject.c b/Objects/dictobject.c
index 0f2a271..84cf482 100644
--- a/Objects/dictobject.c
+++ b/Objects/dictobject.c
@@ -1088,10 +1088,9 @@ dict_update_common(PyObject *self, PyObject *args, PyObject *kwds, char *methnam
static PyObject *
dict_update(PyObject *self, PyObject *args, PyObject *kwds)
{
- if (dict_update_common(self, args, kwds, "update") == -1)
- return NULL;
- Py_INCREF(Py_None);
- return Py_None;
+ if (dict_update_common(self, args, kwds, "update") != -1)
+ Py_RETURN_NONE;
+ return NULL;
}
/* Update unconditionally replaces existing items.
@@ -1593,8 +1592,7 @@ static PyObject *
dict_clear(register dictobject *mp)
{
PyDict_Clear((PyObject *)mp);
- Py_INCREF(Py_None);
- return Py_None;
+ Py_RETURN_NONE;
}
static PyObject *
@@ -2050,7 +2048,9 @@ dictiter_dealloc(dictiterobject *di)
static int
dictiter_len(dictiterobject *di)
{
- return di->len;
+ if (di->di_dict != NULL && di->di_used == di->di_dict->ma_used)
+ return di->len;
+ return 0;
}
static PySequenceMethods dictiter_as_sequence = {
diff --git a/Objects/enumobject.c b/Objects/enumobject.c
index 28719a9..549fc9f 100644
--- a/Objects/enumobject.c
+++ b/Objects/enumobject.c
@@ -225,6 +225,9 @@ reversed_next(reversedobject *ro)
ro->index--;
return item;
}
+ if (PyErr_ExceptionMatches(PyExc_IndexError) ||
+ PyErr_ExceptionMatches(PyExc_StopIteration))
+ PyErr_Clear();
}
ro->index = -1;
if (ro->seq != NULL) {
@@ -242,7 +245,15 @@ PyDoc_STRVAR(reversed_doc,
static int
reversed_len(reversedobject *ro)
{
- return ro->index + 1;
+ int position, seqsize;
+
+ if (ro->seq == NULL)
+ return 0;
+ seqsize = PySequence_Size(ro->seq);
+ if (seqsize == -1)
+ return -1;
+ position = ro->index + 1;
+ return (seqsize < position) ? 0 : position;
}
static PySequenceMethods reversed_as_sequence = {
diff --git a/Objects/iterobject.c b/Objects/iterobject.c
index a407dd5..25e4e11 100644
--- a/Objects/iterobject.c
+++ b/Objects/iterobject.c
@@ -74,8 +74,16 @@ iter_iternext(PyObject *iterator)
static int
iter_len(seqiterobject *it)
{
- if (it->it_seq)
- return PyObject_Size(it->it_seq) - it->it_index;
+ int seqsize, len;
+
+ if (it->it_seq) {
+ seqsize = PySequence_Size(it->it_seq);
+ if (seqsize == -1)
+ return -1;
+ len = seqsize - it->it_index;
+ if (len >= 0)
+ return len;
+ }
return 0;
}