summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Include/abstract.h26
-rw-r--r--Lib/test/test_iter.py41
-rw-r--r--Misc/NEWS3
-rw-r--r--Objects/abstract.c134
-rw-r--r--Objects/classobject.c3
-rw-r--r--Objects/typeobject.c3
6 files changed, 128 insertions, 82 deletions
diff --git a/Include/abstract.h b/Include/abstract.h
index f4c1b3e..d736efc 100644
--- a/Include/abstract.h
+++ b/Include/abstract.h
@@ -988,14 +988,24 @@ xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx*/
DL_IMPORT(int) PySequence_Contains(PyObject *seq, PyObject *ob);
/*
Return -1 if error; 1 if ob in seq; 0 if ob not in seq.
- Use __contains__ if possible, else _PySequence_IterContains().
- */
-
- DL_IMPORT(int) _PySequence_IterContains(PyObject *seq, PyObject *ob);
- /*
- Return -1 if error; 1 if ob in seq; 0 if ob not in seq.
- Always uses the iteration protocol, and only Py_EQ comparisons.
- */
+ Use __contains__ if possible, else _PySequence_IterSearch().
+ */
+
+#define PY_ITERSEARCH_COUNT 1
+#define PY_ITERSEARCH_INDEX 2
+#define PY_ITERSEARCH_CONTAINS 3
+ DL_IMPORT(int) _PySequence_IterSearch(PyObject *seq, PyObject *obj,
+ int operation);
+ /*
+ Iterate over seq. Result depends on the operation:
+ PY_ITERSEARCH_COUNT: return # of times obj appears in seq; -1 if
+ error.
+ PY_ITERSEARCH_INDEX: return 0-based index of first occurence of
+ obj in seq; set ValueError and return -1 if none found;
+ also return -1 on error.
+ PY_ITERSEARCH_CONTAINS: return 1 if obj in seq, else 0; -1 on
+ error.
+ */
/* For DLL-level backwards compatibility */
#undef PySequence_In
diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py
index 8b6891b..37fab7c 100644
--- a/Lib/test/test_iter.py
+++ b/Lib/test/test_iter.py
@@ -600,6 +600,47 @@ class TestCase(unittest.TestCase):
except OSError:
pass
+ # Test iterators with operator.indexOf (PySequence_Index).
+ def test_indexOf(self):
+ from operator import indexOf
+ self.assertEqual(indexOf([1,2,2,3,2,5], 1), 0)
+ self.assertEqual(indexOf((1,2,2,3,2,5), 2), 1)
+ self.assertEqual(indexOf((1,2,2,3,2,5), 3), 3)
+ self.assertEqual(indexOf((1,2,2,3,2,5), 5), 5)
+ self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 0)
+ self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 6)
+
+ self.assertEqual(indexOf("122325", "2"), 1)
+ self.assertEqual(indexOf("122325", "5"), 5)
+ self.assertRaises(ValueError, indexOf, "122325", "6")
+
+ self.assertRaises(TypeError, indexOf, 42, 1)
+ self.assertRaises(TypeError, indexOf, indexOf, indexOf)
+
+ f = open(TESTFN, "w")
+ try:
+ f.write("a\n" "b\n" "c\n" "d\n" "e\n")
+ finally:
+ f.close()
+ f = open(TESTFN, "r")
+ try:
+ fiter = iter(f)
+ self.assertEqual(indexOf(fiter, "b\n"), 1)
+ self.assertEqual(indexOf(fiter, "d\n"), 1)
+ self.assertEqual(indexOf(fiter, "e\n"), 0)
+ self.assertRaises(ValueError, indexOf, fiter, "a\n")
+ finally:
+ f.close()
+ try:
+ unlink(TESTFN)
+ except OSError:
+ pass
+
+ iclass = IteratingSequenceClass(3)
+ for i in range(3):
+ self.assertEqual(indexOf(iclass, i), i)
+ self.assertRaises(ValueError, indexOf, iclass, -1)
+
# Test iterators on RHS of unpacking assignments.
def test_unpack_iter(self):
a, b = 1, 2
diff --git a/Misc/NEWS b/Misc/NEWS
index 87bf717..ecc4588 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -5,6 +5,9 @@ Core
Library
+- operator.indexOf() (PySequence_Index() in the C API) now works with any
+ iterable object.
+
Tools
Build
diff --git a/Objects/abstract.c b/Objects/abstract.c
index c3a397c..5361b1d 100644
--- a/Objects/abstract.c
+++ b/Objects/abstract.c
@@ -1372,25 +1372,31 @@ PySequence_Fast(PyObject *v, const char *m)
return v;
}
-/* Return # of times o appears in s. */
+/* Iterate over seq. Result depends on the operation:
+ PY_ITERSEARCH_COUNT: -1 if error, else # of times obj appears in seq.
+ PY_ITERSEARCH_INDEX: 0-based index of first occurence of obj in seq;
+ set ValueError and return -1 if none found; also return -1 on error.
+ Py_ITERSEARCH_CONTAINS: return 1 if obj in seq, else 0; -1 on error.
+*/
int
-PySequence_Count(PyObject *s, PyObject *o)
+_PySequence_IterSearch(PyObject *seq, PyObject *obj, int operation)
{
- int n; /* running count of o hits */
- PyObject *it; /* iter(s) */
+ int n;
+ int wrapped; /* for PY_ITERSEARCH_INDEX, true iff n wrapped around */
+ PyObject *it; /* iter(seq) */
- if (s == NULL || o == NULL) {
+ if (seq == NULL || obj == NULL) {
null_error();
return -1;
}
- it = PyObject_GetIter(s);
+ it = PyObject_GetIter(seq);
if (it == NULL) {
- type_error(".count() requires iterable argument");
+ type_error("iterable argument required");
return -1;
}
- n = 0;
+ n = wrapped = 0;
for (;;) {
int cmp;
PyObject *item = PyIter_Next(it);
@@ -1399,61 +1405,70 @@ PySequence_Count(PyObject *s, PyObject *o)
goto Fail;
break;
}
- cmp = PyObject_RichCompareBool(o, item, Py_EQ);
+
+ cmp = PyObject_RichCompareBool(obj, item, Py_EQ);
Py_DECREF(item);
if (cmp < 0)
goto Fail;
if (cmp > 0) {
- if (n == INT_MAX) {
- PyErr_SetString(PyExc_OverflowError,
+ switch (operation) {
+ case PY_ITERSEARCH_COUNT:
+ ++n;
+ if (n <= 0) {
+ PyErr_SetString(PyExc_OverflowError,
"count exceeds C int size");
- goto Fail;
+ goto Fail;
+ }
+ break;
+
+ case PY_ITERSEARCH_INDEX:
+ if (wrapped) {
+ PyErr_SetString(PyExc_OverflowError,
+ "index exceeds C int size");
+ goto Fail;
+ }
+ goto Done;
+
+ case PY_ITERSEARCH_CONTAINS:
+ n = 1;
+ goto Done;
+
+ default:
+ assert(!"unknown operation");
}
- n++;
+ }
+
+ if (operation == PY_ITERSEARCH_INDEX) {
+ ++n;
+ if (n <= 0)
+ wrapped = 1;
}
}
- Py_DECREF(it);
- return n;
+ if (operation != PY_ITERSEARCH_INDEX)
+ goto Done;
+
+ PyErr_SetString(PyExc_ValueError,
+ "sequence.index(x): x not in sequence");
+ /* fall into failure code */
Fail:
+ n = -1;
+ /* fall through */
+Done:
Py_DECREF(it);
- return -1;
+ return n;
+
}
-/* Return -1 if error; 1 if ob in seq; 0 if ob not in seq.
- * Always uses the iteration protocol, and only Py_EQ comparison.
- */
+/* Return # of times o appears in s. */
int
-_PySequence_IterContains(PyObject *seq, PyObject *ob)
+PySequence_Count(PyObject *s, PyObject *o)
{
- int result;
- PyObject *it = PyObject_GetIter(seq);
- if (it == NULL) {
- PyErr_SetString(PyExc_TypeError,
- "'in' or 'not in' needs iterable right argument");
- return -1;
- }
-
- for (;;) {
- int cmp;
- PyObject *item = PyIter_Next(it);
- if (item == NULL) {
- result = PyErr_Occurred() ? -1 : 0;
- break;
- }
- cmp = PyObject_RichCompareBool(ob, item, Py_EQ);
- Py_DECREF(item);
- if (cmp == 0)
- continue;
- result = cmp > 0 ? 1 : -1;
- break;
- }
- Py_DECREF(it);
- return result;
+ return _PySequence_IterSearch(s, o, PY_ITERSEARCH_COUNT);
}
/* Return -1 if error; 1 if ob in seq; 0 if ob not in seq.
- * Use sq_contains if possible, else defer to _PySequence_IterContains().
+ * Use sq_contains if possible, else defer to _PySequence_IterSearch().
*/
int
PySequence_Contains(PyObject *seq, PyObject *ob)
@@ -1463,7 +1478,7 @@ PySequence_Contains(PyObject *seq, PyObject *ob)
if (sqm != NULL && sqm->sq_contains != NULL)
return (*sqm->sq_contains)(seq, ob);
}
- return _PySequence_IterContains(seq, ob);
+ return _PySequence_IterSearch(seq, ob, PY_ITERSEARCH_CONTAINS);
}
/* Backwards compatibility */
@@ -1477,32 +1492,7 @@ PySequence_In(PyObject *w, PyObject *v)
int
PySequence_Index(PyObject *s, PyObject *o)
{
- int l, i, cmp, err;
- PyObject *item;
-
- if (s == NULL || o == NULL) {
- null_error();
- return -1;
- }
-
- l = PySequence_Size(s);
- if (l < 0)
- return -1;
-
- for (i = 0; i < l; i++) {
- item = PySequence_GetItem(s, i);
- if (item == NULL)
- return -1;
- err = PyObject_Cmp(item, o, &cmp);
- Py_DECREF(item);
- if (err < 0)
- return err;
- if (cmp == 0)
- return i;
- }
-
- PyErr_SetString(PyExc_ValueError, "sequence.index(x): x not in list");
- return -1;
+ return _PySequence_IterSearch(s, o, PY_ITERSEARCH_INDEX);
}
/* Operations on mappings */
diff --git a/Objects/classobject.c b/Objects/classobject.c
index 4b69842..9d84173 100644
--- a/Objects/classobject.c
+++ b/Objects/classobject.c
@@ -1224,7 +1224,8 @@ instance_contains(PyInstanceObject *inst, PyObject *member)
* __contains__ attribute, and try iterating instead.
*/
PyErr_Clear();
- return _PySequence_IterContains((PyObject *)inst, member);
+ return _PySequence_IterSearch((PyObject *)inst, member,
+ PY_ITERSEARCH_CONTAINS);
}
else
return -1;
diff --git a/Objects/typeobject.c b/Objects/typeobject.c
index f15b096..430e68c 100644
--- a/Objects/typeobject.c
+++ b/Objects/typeobject.c
@@ -2559,7 +2559,8 @@ slot_sq_contains(PyObject *self, PyObject *value)
}
else {
PyErr_Clear();
- return _PySequence_IterContains(self, value);
+ return _PySequence_IterSearch(self, value,
+ PY_ITERSEARCH_CONTAINS);
}
}